Skip to content

Commit

Permalink
Package JWK: if missing kid in jwks, then refresh JWKS (#411)
Browse files Browse the repository at this point in the history
* if missing kid in jwks, then refresh JWKS

* updated comments and removed unused const

* made const lower case / private

* moved webgatewaykid const to test file

* simplified decoder to split out jwks lifecycle management to seperate class

* chore(deps): update module gopkg.in/datadog/dd-trace-go.v1 to v1.63.1 (#413) (#414)

Co-authored-by: cultureamp-renovate[bot] <89962466+cultureamp-renovate[bot]@users.noreply.github.com>
Co-authored-by: Self-hosted Renovate Bot <135776+cultureamp-renovate[bot]@users.noreply.github.com>

* Package LOG: Updates from Go meeting (#412)

* removed error on Debug() Info() etc.

* removed unused func GetEnvBool

* now can create logger with start up properties and use extensions

* removed FromContext wip

* removed nolint errcheck on ld client

* added Child to Logger interface

* example using Child to inherit parent values

* added context helpers - unit tests still to come

* started tests for context helpers

* cleaned up tests and examples and made default fields appears inside properties

* put default/global properties in a 'default_properties' sub-doc so they don't overwrite normal properties

* updated README with logging examples

* Upgraded Go to 1.22.3 to close CVE-2024-24787 and CVE-2024-24788 (#416)

* updated to go 1.22.3 to close CVE-2024-24787 and CVE-2024-24788

* small update to the README to force snyk to pass the build

* Trigger Build

* minor updates to linters (#415)

* minor updates to linters

* enabled all new linters by default and exclude only those we don't (yet) support

* added internal revive linter fixes - no breaking changes

* turned on ireturn linter

* reemoved ex exclude rule in favour of //nolint

* fixed some magic number lint errors

* fixed conflict and golint ireturns

* fixed all magic number lint warnings

* made const lower case / private

* simplified decoder to split out jwks lifecycle management to seperate class

* jwks can rotate by default in 30 seconds

* minor updates to linters (#415)

* minor updates to linters

* enabled all new linters by default and exclude only those we don't (yet) support

* added internal revive linter fixes - no breaking changes

* turned on ireturn linter

* reemoved ex exclude rule in favour of //nolint

* fixed some magic number lint errors

* fixed conflict and golint ireturns

* fixed all magic number lint warnings

* fixed linter warnings

* added jwks tests and cleaned up class

* try and fix ireturn linter warnings

* fixed refresh logic

* fixed test and coverage pipeline steps

* ignore example and gen files from coverage stats

* ignore kafaktest dir in coverage stats

* fixed kafka typo

---------

Co-authored-by: cultureamp-renovate[bot] <89962466+cultureamp-renovate[bot]@users.noreply.github.com>
Co-authored-by: Self-hosted Renovate Bot <135776+cultureamp-renovate[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed May 16, 2024
1 parent 5c9c9ee commit 193a498
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 95 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,16 @@ jobs:
go install github.com/boumenot/gocover-cobertura@v1.2.0
go install github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@v2.5.0
- name: Run all tests
- name: Run all tests with 'race'
run: |
go test ./...
go test -race ./...
- name: Run coverage tests
- name: Run test coverage
run: |
go test -race -json -v -coverprofile=coverage.json -covermode atomic ./... 2>&1 | tee gotest.log | gotestfmt
go test -json -v -coverprofile=coverage.json -covermode atomic ./... 2>&1 | tee gotest.log | gotestfmt
- name: Convert go coverage to corbetura format
run: gocover-cobertura -ignore-files test\*.go < coverage.json > coverage.xml
run: gocover-cobertura -ignore-dirs '(example|kafkatest)' -ignore-files 'test\*.go' -ignore-gen-files < coverage.json > coverage.xml

- name: Generate code coverage report
uses: irongut/CodeCoverageSummary@v1.3.0
Expand Down
8 changes: 8 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ issues:
- lll
- tagliatelle

- path: jwt/decoder.go
linters:
- ireturn

- path: jwt/decoder_jwks.go
linters:
- ireturn

- path: sentry/*
linters:
- lll
51 changes: 51 additions & 0 deletions jwt/decoddr_jwks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package jwt

import (
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestJwkSet(t *testing.T) {
b, err := os.ReadFile(filepath.Clean(testAuthJwks))
require.Nil(t, err)

count := 0
dispatcher := func() string {
count++
return string(b)
}

expiresIn := 200 * time.Millisecond
rotatesIn := 100 * time.Millisecond

// 1. test constructor
jwk := newJWKSet(dispatcher, expiresIn, rotatesIn)
assert.NotNil(t, jwk)

// 2. test get works ok
set, err := jwk.Get()
assert.Nil(t, err)
assert.NotNil(t, set)
assert.Equal(t, 1, count)

// 3. check refresh returns the current set
set, err = jwk.Refresh()
assert.NotNil(t, err)
assert.ErrorContains(t, err, "failed to refresh jwks as just recently updated")
assert.NotNil(t, set)
assert.Equal(t, 1, count)

// Now wait so that the refresh window is reached
time.Sleep(100 * time.Millisecond)

// 4. check refresh returns new set
set, err = jwk.Refresh()
assert.Nil(t, err)
assert.NotNil(t, set)
assert.Equal(t, 2, count)
}
142 changes: 61 additions & 81 deletions jwt/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,22 @@ package jwt
import (
"crypto/ecdsa"
"crypto/rsa"
"sync"
"time"

"github.com/go-errors/errors"
"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/patrickmn/go-cache"
)

const (
kidHeaderKey = "kid"
algorithmHeaderKey = "alg"
webGatewayKid = "web-gateway"
accountIDClaim = "accountId"
realUserIDClaim = "realUserId"
effectiveUserIDClaim = "effectiveUserId"
jwksCacheKey = "decoder_jwks_key"
defaultDecoderExpiration = 60 * time.Minute
defaultDecoderCleanupInterval = 1 * time.Minute
defaultDecoderLeeway = 10 * time.Second
kidHeaderKey = "kid"
algorithmHeaderKey = "alg"
accountIDClaim = "accountId"
realUserIDClaim = "realUserId"
effectiveUserIDClaim = "effectiveUserId"
defaultDecoderExpiration = 60 * time.Minute
defaultDecoderRotationDuration = 30 * time.Second
defaultDecoderLeeway = 10 * time.Second
)

type publicKey interface{} // Only ECDSA (perferred) and RSA public keys allowed
Expand All @@ -32,30 +28,30 @@ type DecoderJwksRetriever func() string

// JwtDecoder can decode a jwt token string.
type JwtDecoder struct {
fetchJwkKeys DecoderJwksRetriever // func provided by clients of this library to supply a refreshed JWKS
mu sync.Mutex // mutex to protect cache.Get/Set race condition
cache *cache.Cache // memory cache holding the jwk.Set
defaultExpiration time.Duration // default is 60 minutes
cleanupInterval time.Duration // default is every 1 minute
dispatcher DecoderJwksRetriever // func provided by clients of this library to supply the current JWKS
expiresWithin time.Duration // default is 60 minutes
rotationWindow time.Duration // default is 30 seconds
jwks *jwkFetcher // manages the life cycle of a JWK Set
}

// NewJwtDecoder creates a new JwtDecoder with the set ECDSA and RSA public keys in the JWK string.
func NewJwtDecoder(fetchJWKS DecoderJwksRetriever, options ...JwtDecoderOption) (*JwtDecoder, error) {
decoder := &JwtDecoder{
fetchJwkKeys: fetchJWKS,
defaultExpiration: defaultDecoderExpiration,
cleanupInterval: defaultDecoderCleanupInterval,
dispatcher: fetchJWKS,
jwks: nil,
expiresWithin: defaultDecoderExpiration,
rotationWindow: defaultDecoderRotationDuration,
}

// Loop through our Decoder options and apply them
for _, option := range options {
option(decoder)
}

decoder.cache = cache.New(decoder.defaultExpiration, decoder.cleanupInterval)
decoder.jwks = newJWKSet(fetchJWKS, decoder.expiresWithin, decoder.rotationWindow)

// call the getJWKS func to make sure its valid and we can parse the JWKS
_, err := decoder.loadJWKSet()
// call the get to make sure its valid and we can parse the JWKS
_, err := decoder.jwks.Get()
if err != nil {
return nil, errors.Errorf("failed to load jwks: %w", err)
}
Expand Down Expand Up @@ -106,7 +102,7 @@ func (d *JwtDecoder) DecodeWithCustomClaims(tokenString string, customClaims jwt
return nil
}

func (d *JwtDecoder) useCorrectPublicKey(token *jwt.Token) (publicKey, error) { //nolint:ireturn
func (d *JwtDecoder) useCorrectPublicKey(token *jwt.Token) (publicKey, error) {
if token == nil {
return nil, errors.Errorf("failed to decode: missing token")
}
Expand All @@ -131,82 +127,66 @@ func (d *JwtDecoder) useCorrectPublicKey(token *jwt.Token) (publicKey, error) {
return nil, errors.Errorf("failed to decode: invalid key_id (kid) header")
}

// check cache and possibly fetch new JWKS
jwkSet, err := d.loadJWKSet()
// check if kid exists in the JWK Set
return d.lookupKeyID(kid)
}

// lookupKeyID returns the public key in the JWKS that matches the "kid".
func (d *JwtDecoder) lookupKeyID(kid string) (publicKey, error) {
// check cache and possibly fetch new JWKS if cache has expired
jwkSet, err := d.jwks.Get()
if err != nil {
return nil, errors.Errorf("failed to load jwks: %w", err)
}

// set if the kid exists in the set
key, found := jwkSet.LookupKeyID(kid)
if found {
// Found a match, so use this key
var rawkey interface{}
err := key.Raw(&rawkey)
if err != nil {
return nil, errors.Errorf("failed to decode: bad public key in jwks")
}

// If the JWKS contains the full key (Private AND Public) then check for that for both ECDSA & RSA
// NOTE: this should never happen in PRPD - but does in the unit tests
if ecdsa, ok := rawkey.(*ecdsa.PrivateKey); ok {
return &ecdsa.PublicKey, nil
}
if rsa, ok := rawkey.(*rsa.PrivateKey); ok {
return &rsa.PublicKey, nil
}

return rawkey, err
// Found a match, so use this key!
return d.getPublicKey(key)
}

// Didn't find a matching kid
return nil, errors.Errorf("failed to decode: no matching key_id (kid) header for: %s", kid)
return d.tryRefreshedLookupKeyID(kid)
}

func (d *JwtDecoder) loadJWKSet() (jwk.Set, error) { //nolint:ireturn
// First check cache, if its there then great, use it!
if jwks, ok := d.getCachedJWKSet(); ok {
return jwks, nil
func (d *JwtDecoder) tryRefreshedLookupKeyID(kid string) (publicKey, error) {
// If the jwks aren't "fresh" and we are being asked for a kid we don't have
// then get a new jwks and try again. This can occur when a new key has been
// added or rotated and we haven't got the latest copy.
// The "canRefresh" check is important here, as for bad kid's we don't want
// blast the client (which in turn might blast Secrets Manager or FushionAuth)
// with a huge number of requests over and over again.
jwkSet, err := d.jwks.Refresh()
if err != nil {
// we didn't refresh, or we did but we failed to parse it
return nil, errors.Errorf("failed to decode: no matching key_id (kid) header for: %s. err: %w", kid, err)
}

// Only allow one thread to fetch, parse and update the cache
d.mu.Lock()
defer d.mu.Unlock()

// check the cache again in case another go routine just updated it
if jwks, ok := d.getCachedJWKSet(); ok {
return jwks, nil
key, found := jwkSet.LookupKeyID(kid)
if found {
// Found a match, so use this key
return d.getPublicKey(key)
}

// Call client retriever func
jwkKeys := d.fetchJwkKeys()
return nil, errors.Errorf("failed to decode: no matching key_id (kid) header for: %s", kid)
}

// Parse all new JWKs JSON keys and make sure its valid
jwkSet, err := d.parseJWKs(jwkKeys)
func (d *JwtDecoder) getPublicKey(key jwk.Key) (publicKey, error) {
var rawkey interface{}
err := key.Raw(&rawkey)
if err != nil {
return nil, err
return nil, errors.Errorf("failed to decode: bad public key in jwks")
}

// Add back into the cache
err = d.cache.Add(jwksCacheKey, jwkSet, cache.DefaultExpiration)
return jwkSet, err
}

func (d *JwtDecoder) getCachedJWKSet() (jwk.Set, bool) { //nolint:ireturn
obj, found := d.cache.Get(jwksCacheKey)
if !found {
return nil, false
// If the JWKS contains the full key (Private AND Public) then only return the public one
// ECDSA & RSA keys only.
// NOTE: this should never happen in production - but does in the unit tests
if ecdsa, ok := rawkey.(*ecdsa.PrivateKey); ok {
return &ecdsa.PublicKey, nil
}

jwks, ok := obj.(jwk.Set)
return jwks, ok
}

func (d *JwtDecoder) parseJWKs(jwks string) (jwk.Set, error) { //nolint:ireturn
if jwks == "" {
// If no jwks json, then returm empty map
return nil, errors.Errorf("missing jwks")
if rsa, ok := rawkey.(*rsa.PrivateKey); ok {
return &rsa.PublicKey, nil
}

// 1. Parse the jwks JSON string to an iterable set
return jwk.ParseString(jwks)
return rawkey, err
}

0 comments on commit 193a498

Please sign in to comment.