diff --git a/core/auth/oauth/oidc.go b/core/auth/oauth/oidc.go index b03375cc..32fdb9ed 100644 --- a/core/auth/oauth/oidc.go +++ b/core/auth/oauth/oidc.go @@ -88,6 +88,12 @@ func init() { var ( _ OpenIDIdentity = new(oidcIdentity) + _ auth.RequestIdentifier = new(openIDIdentifier) + _ auth.WebAuthenticater = new(openIDIdentifier) + _ auth.WebCallbacker = new(openIDIdentifier) + _ auth.WebIdentityRefresher = new(openIDIdentifier) + _ auth.WebLogoutWithRedirect = new(openIDIdentifier) + // OpenIDTypeChecker checks the Identity for OpenID Identity OpenIDTypeChecker = func(identity auth.Identity) bool { _, ok := identity.(OpenIDIdentity) @@ -186,15 +192,9 @@ func (i *openIDIdentifier) Inject( func (i *openIDIdentifier) Identify(ctx context.Context, request *web.Request) (auth.Identity, error) { sessionCode := i.sessionCode("sessiondata") - data, ok := request.Session().Load(sessionCode) - if !ok { - return nil, errors.New("no sessiondata") - } - - sessiondata, ok := data.(sessionData) - if !ok { - request.Session().Delete(sessionCode) - return nil, errors.New("broken sessiondata") + sessiondata, err := i.sessionData(request, sessionCode) + if err != nil { + return nil, err } verifierConfig := &oidc.Config{ClientID: i.oauth2Config.ClientID} @@ -229,6 +229,37 @@ func (i *openIDIdentifier) Identify(ctx context.Context, request *web.Request) ( return identity, nil } +func (i *openIDIdentifier) sessionData(request *web.Request, sessionCode string) (sessionData, error) { + data, ok := request.Session().Load(sessionCode) + if !ok { + return sessionData{}, errors.New("no sessiondata") + } + + sessiondata, ok := data.(sessionData) + if !ok { + request.Session().Delete(sessionCode) + return sessionData{}, errors.New("broken sessiondata") + } + + return sessiondata, nil +} + +// RefreshIdentity by invalidating the access token from the token stored in the session data +// which will cause a refresh request the next time an identity is requested +func (i *openIDIdentifier) RefreshIdentity(_ context.Context, request *web.Request) error { + sessionCode := i.sessionCode("sessiondata") + + sessiondata, err := i.sessionData(request, sessionCode) + if err != nil { + return err + } + + sessiondata.Token.AccessToken = "" + request.Session().Store(sessionCode, sessiondata) + + return nil +} + func (i *oidcIdentity) tokens(ctx context.Context) (*oauth2.Token, *oidc.IDToken, error) { token, err := i.token.tokenSource.Token() if err != nil { diff --git a/core/auth/oauth/oidc_test.go b/core/auth/oauth/oidc_test.go index 427ef4d0..0d42e72f 100644 --- a/core/auth/oauth/oidc_test.go +++ b/core/auth/oauth/oidc_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "flamingo.me/flamingo/v3/core/auth" "flamingo.me/flamingo/v3/framework/flamingo" "flamingo.me/flamingo/v3/framework/web" "github.com/coreos/go-oidc/v3/oidc" @@ -170,3 +171,33 @@ func TestOidcCallback(t *testing.T) { assert.Equal(t, "at-claim-1-value", testClaims.Claim1) assert.Equal(t, "legacy-token-response-claim-value", testClaims.LegacyClaim) } + +func Test_openIDIdentifier_RefreshIdentity(t *testing.T) { + var identifier auth.RequestIdentifier = &openIDIdentifier{broker: "broker"} + session := web.EmptySession() + session.Store("core.auth.oidc.broker.sessiondata", sessionData{ + Token: &oauth2.Token{ + AccessToken: "access-token", + RefreshToken: "refresh-token", + Expiry: time.Now().Add(time.Minute * 5), + }, + }) + ctx := web.ContextWithSession(context.Background(), session) + + req := web.CreateRequest(nil, session) + ctx = web.ContextWithRequest(ctx, req) + + refresher, ok := identifier.(auth.WebIdentityRefresher) + assert.True(t, ok) + + err := refresher.RefreshIdentity(ctx, req) + assert.NoError(t, err) + + data, found := session.Load("core.auth.oidc.broker.sessiondata") + assert.True(t, found) + + sessiondata, ok := data.(sessionData) + assert.True(t, ok) + assert.Empty(t, sessiondata.Token.AccessToken) + assert.Equal(t, "refresh-token", sessiondata.Token.RefreshToken) +} diff --git a/core/auth/webidentity.go b/core/auth/webidentity.go index 0f1db4b7..3c927a25 100644 --- a/core/auth/webidentity.go +++ b/core/auth/webidentity.go @@ -41,6 +41,11 @@ type ( Logout(ctx context.Context, request *web.Request) *url.URL } + // WebIdentityRefresher refreshs an existing identity, e.g. by invalidating cached session data + WebIdentityRefresher interface { + RefreshIdentity(ctx context.Context, request *web.Request) error + } + // WebIdentityService calls one or more identifier to get all possible identities of a user WebIdentityService struct { identityProviders []RequestIdentifier diff --git a/core/auth/webidentity_test.go b/core/auth/webidentity_test.go index 4939a145..8d89156f 100644 --- a/core/auth/webidentity_test.go +++ b/core/auth/webidentity_test.go @@ -19,6 +19,10 @@ func (*testIdentifier) Identify(context.Context, *web.Request) (Identity, error) return &testIdentity{}, nil } +func (*testIdentifier) RefreshIdentity(ctx context.Context, request *web.Request) error { + return nil +} + type testIdentity struct{} func (*testIdentity) Subject() string { @@ -68,4 +72,16 @@ func Test_WebIdentityServiceIdentifyAs(t *testing.T) { t.Log(err) assert.Nil(t, identity) }) + + t.Run("refresh identity", func(t *testing.T) { + identity := s.Identify(context.Background(), nil) + assert.NotNil(t, identity) + + identifier := s.RequestIdentifier(identity.Broker()) + assert.NotNil(t, identifier) + + refresher, ok := identifier.(WebIdentityRefresher) + assert.True(t, ok) + assert.NoError(t, refresher.RefreshIdentity(context.Background(), nil)) + }) }