Skip to content

Commit

Permalink
test: add test for IsTimeoutErr (#530)
Browse files Browse the repository at this point in the history
  • Loading branch information
gssbzn committed Mar 4, 2024
1 parent 9acce38 commit cc26263
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
13 changes: 7 additions & 6 deletions auth/device_flow.go
Expand Up @@ -46,7 +46,7 @@ type RegistrationConfig struct {
const deviceBasePath = "api/private/unauth/account/device"

// RequestCode initiates the authorization flow by requesting a code.
func (c Config) RequestCode(ctx context.Context) (*DeviceCode, *atlas.Response, error) {
func (c *Config) RequestCode(ctx context.Context) (*DeviceCode, *atlas.Response, error) {
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/authorize",
url.Values{
"client_id": {c.ClientID},
Expand All @@ -62,7 +62,7 @@ func (c Config) RequestCode(ctx context.Context) (*DeviceCode, *atlas.Response,
}

// GetToken gets a device token.
func (c Config) GetToken(ctx context.Context, deviceCode string) (*Token, *atlas.Response, error) {
func (c *Config) GetToken(ctx context.Context, deviceCode string) (*Token, *atlas.Response, error) {
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token",
url.Values{
"client_id": {c.ClientID},
Expand All @@ -85,7 +85,7 @@ func (c Config) GetToken(ctx context.Context, deviceCode string) (*Token, *atlas
var ErrTimeout = errors.New("authentication timed out")

// PollToken polls the server until an access token is granted or denied.
func (c Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *atlas.Response, error) {
func (c *Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *atlas.Response, error) {
timeNow := code.timeNow
if timeNow == nil {
timeNow = time.Now
Expand Down Expand Up @@ -117,7 +117,7 @@ func (c Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *atlas
}

// RefreshToken takes a refresh token and gets a new access token.
func (c Config) RefreshToken(ctx context.Context, token string) (*Token, *atlas.Response, error) {
func (c *Config) RefreshToken(ctx context.Context, token string) (*Token, *atlas.Response, error) {
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token",
url.Values{
"client_id": {c.ClientID},
Expand All @@ -138,7 +138,7 @@ func (c Config) RefreshToken(ctx context.Context, token string) (*Token, *atlas.
}

// RevokeToken takes an access or refresh token and revokes it.
func (c Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (*atlas.Response, error) {
func (c *Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (*atlas.Response, error) {
req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/revoke",
url.Values{
"client_id": {c.ClientID},
Expand All @@ -154,7 +154,7 @@ func (c Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (*
}

// RegistrationConfig retrieves the config used for registration.
func (c Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *atlas.Response, error) {
func (c *Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *atlas.Response, error) {
req, err := c.NewRequest(ctx, http.MethodGet, deviceBasePath+"/registration", url.Values{})
if err != nil {
return nil, nil, err
Expand All @@ -167,6 +167,7 @@ func (c Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *a
return rc, resp, err
}

// IsTimeoutErr checks if the given error is for the case where the device flow has expired.
func IsTimeoutErr(err error) bool {
var target *atlas.ErrorResponse
return errors.Is(err, ErrTimeout) || (errors.As(err, &target) && target.ErrorCode == authExpiredError)
Expand Down
12 changes: 11 additions & 1 deletion auth/device_flow_test.go
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/go-test/deep"
atlas "go.mongodb.org/atlas/mongodbatlas"
)

func TestConfig_RequestCode(t *testing.T) {
Expand Down Expand Up @@ -135,7 +136,7 @@ func TestConfig_PollToken(t *testing.T) {

mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r)
fmt.Fprint(w, `{
_, _ = fmt.Fprint(w, `{
"access_token": "secret1",
"refresh_token": "secret2",
"scope": "openid",
Expand Down Expand Up @@ -209,3 +210,12 @@ func TestConfig_RegistrationConfig(t *testing.T) {
t.Error(diff)
}
}

func TestIsTimeoutErr(t *testing.T) {
err := &atlas.ErrorResponse{
ErrorCode: "DEVICE_AUTHORIZATION_EXPIRED",
}
if !IsTimeoutErr(err) {
t.Error("expected to be a timeout error")
}
}

0 comments on commit cc26263

Please sign in to comment.