diff --git a/auth/device_flow.go b/auth/device_flow.go index 7934c1df1..3769dd9ff 100644 --- a/auth/device_flow.go +++ b/auth/device_flow.go @@ -46,7 +46,7 @@ type RegistrationConfig struct { const deviceBasePath = "api/private/unauth/account/device" // RequestCode initiates the authorization flow by requesting a code. -func (c Config) RequestCode(ctx context.Context) (*DeviceCode, *atlas.Response, error) { +func (c *Config) RequestCode(ctx context.Context) (*DeviceCode, *atlas.Response, error) { req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/authorize", url.Values{ "client_id": {c.ClientID}, @@ -62,7 +62,7 @@ func (c Config) RequestCode(ctx context.Context) (*DeviceCode, *atlas.Response, } // GetToken gets a device token. -func (c Config) GetToken(ctx context.Context, deviceCode string) (*Token, *atlas.Response, error) { +func (c *Config) GetToken(ctx context.Context, deviceCode string) (*Token, *atlas.Response, error) { req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token", url.Values{ "client_id": {c.ClientID}, @@ -85,7 +85,7 @@ func (c Config) GetToken(ctx context.Context, deviceCode string) (*Token, *atlas var ErrTimeout = errors.New("authentication timed out") // PollToken polls the server until an access token is granted or denied. -func (c Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *atlas.Response, error) { +func (c *Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *atlas.Response, error) { timeNow := code.timeNow if timeNow == nil { timeNow = time.Now @@ -117,7 +117,7 @@ func (c Config) PollToken(ctx context.Context, code *DeviceCode) (*Token, *atlas } // RefreshToken takes a refresh token and gets a new access token. -func (c Config) RefreshToken(ctx context.Context, token string) (*Token, *atlas.Response, error) { +func (c *Config) RefreshToken(ctx context.Context, token string) (*Token, *atlas.Response, error) { req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/token", url.Values{ "client_id": {c.ClientID}, @@ -138,7 +138,7 @@ func (c Config) RefreshToken(ctx context.Context, token string) (*Token, *atlas. } // RevokeToken takes an access or refresh token and revokes it. -func (c Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (*atlas.Response, error) { +func (c *Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (*atlas.Response, error) { req, err := c.NewRequest(ctx, http.MethodPost, deviceBasePath+"/revoke", url.Values{ "client_id": {c.ClientID}, @@ -154,7 +154,7 @@ func (c Config) RevokeToken(ctx context.Context, token, tokenTypeHint string) (* } // RegistrationConfig retrieves the config used for registration. -func (c Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *atlas.Response, error) { +func (c *Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *atlas.Response, error) { req, err := c.NewRequest(ctx, http.MethodGet, deviceBasePath+"/registration", url.Values{}) if err != nil { return nil, nil, err @@ -167,6 +167,7 @@ func (c Config) RegistrationConfig(ctx context.Context) (*RegistrationConfig, *a return rc, resp, err } +// IsTimeoutErr checks if the given error is for the case where the device flow has expired. func IsTimeoutErr(err error) bool { var target *atlas.ErrorResponse return errors.Is(err, ErrTimeout) || (errors.As(err, &target) && target.ErrorCode == authExpiredError) diff --git a/auth/device_flow_test.go b/auth/device_flow_test.go index 6e75a22c8..9d48f2ab2 100644 --- a/auth/device_flow_test.go +++ b/auth/device_flow_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/go-test/deep" + atlas "go.mongodb.org/atlas/mongodbatlas" ) func TestConfig_RequestCode(t *testing.T) { @@ -135,7 +136,7 @@ func TestConfig_PollToken(t *testing.T) { mux.HandleFunc("/api/private/unauth/account/device/token", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r) - fmt.Fprint(w, `{ + _, _ = fmt.Fprint(w, `{ "access_token": "secret1", "refresh_token": "secret2", "scope": "openid", @@ -209,3 +210,12 @@ func TestConfig_RegistrationConfig(t *testing.T) { t.Error(diff) } } + +func TestIsTimeoutErr(t *testing.T) { + err := &atlas.ErrorResponse{ + ErrorCode: "DEVICE_AUTHORIZATION_EXPIRED", + } + if !IsTimeoutErr(err) { + t.Error("expected to be a timeout error") + } +}