Skip to content

Commit

Permalink
feat(auth): add session refresh interface (#269)
Browse files Browse the repository at this point in the history
Co-authored-by: Christopher Schwab <christopher.schwab@aoe.com>
Co-authored-by: Carsten Dietrich <carsten.dietrich@omnevo.net>
  • Loading branch information
3 people committed Jul 4, 2022
1 parent fb27119 commit 3eba347
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 9 deletions.
49 changes: 40 additions & 9 deletions core/auth/oauth/oidc.go
Expand Up @@ -88,6 +88,12 @@ func init() {
var (
_ OpenIDIdentity = new(oidcIdentity)

_ auth.RequestIdentifier = new(openIDIdentifier)
_ auth.WebAuthenticater = new(openIDIdentifier)
_ auth.WebCallbacker = new(openIDIdentifier)
_ auth.WebIdentityRefresher = new(openIDIdentifier)
_ auth.WebLogoutWithRedirect = new(openIDIdentifier)

// OpenIDTypeChecker checks the Identity for OpenID Identity
OpenIDTypeChecker = func(identity auth.Identity) bool {
_, ok := identity.(OpenIDIdentity)
Expand Down Expand Up @@ -186,15 +192,9 @@ func (i *openIDIdentifier) Inject(
func (i *openIDIdentifier) Identify(ctx context.Context, request *web.Request) (auth.Identity, error) {
sessionCode := i.sessionCode("sessiondata")

data, ok := request.Session().Load(sessionCode)
if !ok {
return nil, errors.New("no sessiondata")
}

sessiondata, ok := data.(sessionData)
if !ok {
request.Session().Delete(sessionCode)
return nil, errors.New("broken sessiondata")
sessiondata, err := i.sessionData(request, sessionCode)
if err != nil {
return nil, err
}

verifierConfig := &oidc.Config{ClientID: i.oauth2Config.ClientID}
Expand Down Expand Up @@ -229,6 +229,37 @@ func (i *openIDIdentifier) Identify(ctx context.Context, request *web.Request) (
return identity, nil
}

func (i *openIDIdentifier) sessionData(request *web.Request, sessionCode string) (sessionData, error) {
data, ok := request.Session().Load(sessionCode)
if !ok {
return sessionData{}, errors.New("no sessiondata")
}

sessiondata, ok := data.(sessionData)
if !ok {
request.Session().Delete(sessionCode)
return sessionData{}, errors.New("broken sessiondata")
}

return sessiondata, nil
}

// RefreshIdentity by invalidating the access token from the token stored in the session data
// which will cause a refresh request the next time an identity is requested
func (i *openIDIdentifier) RefreshIdentity(_ context.Context, request *web.Request) error {
sessionCode := i.sessionCode("sessiondata")

sessiondata, err := i.sessionData(request, sessionCode)
if err != nil {
return err
}

sessiondata.Token.AccessToken = ""
request.Session().Store(sessionCode, sessiondata)

return nil
}

func (i *oidcIdentity) tokens(ctx context.Context) (*oauth2.Token, *oidc.IDToken, error) {
token, err := i.token.tokenSource.Token()
if err != nil {
Expand Down
31 changes: 31 additions & 0 deletions core/auth/oauth/oidc_test.go
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"time"

"flamingo.me/flamingo/v3/core/auth"
"flamingo.me/flamingo/v3/framework/flamingo"
"flamingo.me/flamingo/v3/framework/web"
"github.com/coreos/go-oidc/v3/oidc"
Expand Down Expand Up @@ -170,3 +171,33 @@ func TestOidcCallback(t *testing.T) {
assert.Equal(t, "at-claim-1-value", testClaims.Claim1)
assert.Equal(t, "legacy-token-response-claim-value", testClaims.LegacyClaim)
}

func Test_openIDIdentifier_RefreshIdentity(t *testing.T) {
var identifier auth.RequestIdentifier = &openIDIdentifier{broker: "broker"}
session := web.EmptySession()
session.Store("core.auth.oidc.broker.sessiondata", sessionData{
Token: &oauth2.Token{
AccessToken: "access-token",
RefreshToken: "refresh-token",
Expiry: time.Now().Add(time.Minute * 5),
},
})
ctx := web.ContextWithSession(context.Background(), session)

req := web.CreateRequest(nil, session)
ctx = web.ContextWithRequest(ctx, req)

refresher, ok := identifier.(auth.WebIdentityRefresher)
assert.True(t, ok)

err := refresher.RefreshIdentity(ctx, req)
assert.NoError(t, err)

data, found := session.Load("core.auth.oidc.broker.sessiondata")
assert.True(t, found)

sessiondata, ok := data.(sessionData)
assert.True(t, ok)
assert.Empty(t, sessiondata.Token.AccessToken)
assert.Equal(t, "refresh-token", sessiondata.Token.RefreshToken)
}
5 changes: 5 additions & 0 deletions core/auth/webidentity.go
Expand Up @@ -41,6 +41,11 @@ type (
Logout(ctx context.Context, request *web.Request) *url.URL
}

// WebIdentityRefresher refreshs an existing identity, e.g. by invalidating cached session data
WebIdentityRefresher interface {
RefreshIdentity(ctx context.Context, request *web.Request) error
}

// WebIdentityService calls one or more identifier to get all possible identities of a user
WebIdentityService struct {
identityProviders []RequestIdentifier
Expand Down
16 changes: 16 additions & 0 deletions core/auth/webidentity_test.go
Expand Up @@ -19,6 +19,10 @@ func (*testIdentifier) Identify(context.Context, *web.Request) (Identity, error)
return &testIdentity{}, nil
}

func (*testIdentifier) RefreshIdentity(ctx context.Context, request *web.Request) error {
return nil
}

type testIdentity struct{}

func (*testIdentity) Subject() string {
Expand Down Expand Up @@ -68,4 +72,16 @@ func Test_WebIdentityServiceIdentifyAs(t *testing.T) {
t.Log(err)
assert.Nil(t, identity)
})

t.Run("refresh identity", func(t *testing.T) {
identity := s.Identify(context.Background(), nil)
assert.NotNil(t, identity)

identifier := s.RequestIdentifier(identity.Broker())
assert.NotNil(t, identifier)

refresher, ok := identifier.(WebIdentityRefresher)
assert.True(t, ok)
assert.NoError(t, refresher.RefreshIdentity(context.Background(), nil))
})
}

0 comments on commit 3eba347

Please sign in to comment.