Skip to content

Commit

Permalink
Package JWT: allow clients to optionally enforce matching of aud, iss…
Browse files Browse the repository at this point in the history
… and sub claims (#419)

* allow clients to optional enforce matching of aud iss and sub claims

* cleaned up comments to better explain nbf claim

* moved ireturn linter to config file

* updated README for MustMatch methods

* updated README and added specific MustMatch tests
  • Loading branch information
techmanmike committed May 16, 2024
1 parent 193a498 commit 5fc8487
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 39 deletions.
4 changes: 4 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ issues:
- lll
- tagliatelle

- path: jwt/decoder.go
linters:
- ireturn

- path: launchdarkly/*
linters:
- err113
Expand Down
27 changes: 17 additions & 10 deletions jwt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ If you are managing Encoders and Decoders yourself, then you can provide a func
- type EncoderKeyRetriever func() (string, string) // return your private PEM key + key_id
- type DecoderJwksRetriever func() string // return your JSON JWKs


## Managing Encoders and Decoders Yourself

While we recommend using the package level methods for their ease of use, you may desire to create and manage encoders or decoers yourself, which you can do by calling:
Expand Down Expand Up @@ -54,7 +53,6 @@ decoder, err := NewJwtDecoder(jwksRetriever)

## Claims


You MUST set the `Issuer`, `Subject`, and `Audience` claims along with the standard authentication claim `AccountId`. If the JWT is for authenticaton other than to the Public API, it MUST also include the `RealUserId`, and `EffectiveUserId` claims.

- [Issuer](https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.1) `iss` claim.
Expand Down Expand Up @@ -104,6 +102,15 @@ type Claims interface {
}
```

### Enforcing Audience, Subject and Issuer

When Decoding you should enforce that the jwt matches the expect Audience (`aud` claim), Subject (`sub` claim) and Issuer (`iss` claim).
To do this you can pass `MustMatch<Type>` to the `Decode` and `DecodeWithCustomClaims` methods:

- func MustMatchAudience(aud string)
- func MustMatchIssuer(iss string)
- func MustMatchSubject(sub string)

## Examples
```
package cago
Expand All @@ -128,8 +135,8 @@ func BasicExamples() {
token, err := jwt.Encode(claims)
fmt.Printf("The encoded token is '%s' (err='%v')\n", token, err)
// Decode it back again using the key that matches the kid header using the default JWKS JSON keys
sc, err := jwt.Decode(token)
// Decode it back again using the key that matches the kid header using the default JWKS JSON keys and matching on aud, sub and iss.
sc, err := jwt.Decode(token, MustMatchAudience("who-i-am"), MustMatchIssuer("webgateway"), MustMatchSubject("user-auth"))
fmt.Printf("The decode token is '%v' (err='%+v')\n", sc, err)
}
```
Expand All @@ -141,8 +148,8 @@ the `Encoder` or `Decoder` interface.

- Encode(claims *StandardClaims) (string, error)
- EncodeWithCustomClaims(customClaims jwt.Claims) (string, error)
- Decode(tokenString string) (*StandardClaims, error)
- DecodeWithCustomClaims(tokenString string, customClaims jwt.Claims) error
- Decode(tokenString string, options ...DecoderParserOption) (*StandardClaims, error)
- DecodeWithCustomClaims(tokenString string, customClaims jwt.Claims, options ...DecoderParserOption) error

```
import (
Expand Down Expand Up @@ -214,14 +221,14 @@ func (m *mockedEncoderDecoder) Encode(claims *jwt.StandardClaims) (string, error
}
// Decrypt on the test runner just returns the "encryptedStr" as the decrypted plainstr.
func (m *mockedEncoderDecoder) Decode(tokenString string) (*jwt.StandardClaims, error) {
args := m.Called(tokenString)
func (m *mockedEncoderDecoder) Decode(tokenString string, options ...DecoderParserOption) (*jwt.StandardClaims, error) {
args := m.Called(tokenString, options)
output, _ := args.Get(0).(*jwt.StandardClaims)
return output, args.Error(1)
}
func (m *mockedEncoderDecoder) DecodeWithCustomClaims(tokenString string, customClaims gojwt.Claims) error {
args := m.Called(tokenString, customClaims)
func (m *mockedEncoderDecoder) DecodeWithCustomClaims(tokenString string, customClaims gojwt.Claims, options ...DecoderParserOption) error {
args := m.Called(tokenString, customClaims, options)
return args.Error(0)
}
```
62 changes: 45 additions & 17 deletions jwt/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,40 +60,27 @@ func NewJwtDecoder(fetchJWKS DecoderJwksRetriever, options ...JwtDecoderOption)
}

// Decode a jwt token string and return the Standard Culture Amp Claims.
func (d *JwtDecoder) Decode(tokenString string) (*StandardClaims, error) {
func (d *JwtDecoder) Decode(tokenString string, options ...DecoderParserOption) (*StandardClaims, error) {
claims := jwt.MapClaims{}
err := d.DecodeWithCustomClaims(tokenString, claims)
err := d.DecodeWithCustomClaims(tokenString, claims, options...)
if err != nil {
return nil, err
}
return newStandardClaims(claims), nil
}

// DecodeWithCustomClaims takes a jwt token string and populate the customClaims.
func (d *JwtDecoder) DecodeWithCustomClaims(tokenString string, customClaims jwt.Claims) error {
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
validAlgs := []string{"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}

func (d *JwtDecoder) DecodeWithCustomClaims(tokenString string, customClaims jwt.Claims, options ...DecoderParserOption) error {
// sample token string in the form "header.payload.signature"
// eg. "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJuYmYiOjE0NDQ0Nzg0MDB9.u1riaD1rW97opCoAuRCTy4w58Br-Zk-bh7vLiRIsrpU"

// Eng Std: https://cultureamp.atlassian.net/wiki/spaces/TV/pages/3253240053/JWT+Authentication

// Exp
// Expiry claim is currently MANDATORY, but until all producing services are reliably setting the Expiry claim,
// we MAY still accept verified JWTs with no Expiry claim.
// Nbf
// NotBefore claim is currently MANDATORY, but until all producing services are reliably settings the NotBEfore claim,
// we MAY still accept verificed JWT's with no NotBefore claim.
token, err := jwt.ParseWithClaims(
tokenString,
customClaims,
func(token *jwt.Token) (interface{}, error) {
return d.useCorrectPublicKey(token)
},
jwt.WithValidMethods(validAlgs), // only keys with these "alg's" will be considered
jwt.WithLeeway(defaultDecoderLeeway), // as per the JWT eng std: clock skew set to 10 seconds
// jwt.WithExpirationRequired(), // add this if we want to enforce that tokens MUST have an expiry
d.enforceParsingOptions(options...)...,
)
if err != nil || !token.Valid {
return err
Expand All @@ -102,6 +89,47 @@ func (d *JwtDecoder) DecodeWithCustomClaims(tokenString string, customClaims jwt
return nil
}

func (d *JwtDecoder) enforceParsingOptions(options ...DecoderParserOption) []jwt.ParserOption {
// Eng Std: https://cultureamp.atlassian.net/wiki/spaces/TV/pages/3253240053/JWT+Authentication
var opts []jwt.ParserOption

// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
validAlgs := []string{"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
opts = append(opts,
jwt.WithValidMethods(validAlgs), // only keys with these "alg's" will be considered
jwt.WithLeeway(defaultDecoderLeeway), // as per the JWT eng std: clock skew set to 10 seconds

// Exp
// Expiry claim is currently MANDATORY, but until all producing services are reliably setting the Expiry claim,
// we MAY still accept verified JWTs with no Expiry claim.
// jwt.WithExpirationRequired(),

// Nbf
// If the NotBefore claim is set it will automatically be enforced.
// Note: There is no parsing option for this.
)

// Loop through any client provided parsing options and apply them
parserOptions := newDecoderParser()
for _, option := range options {
option(parserOptions)
}

if parserOptions.expectedAud != "" {
opts = append(opts, jwt.WithAudience(parserOptions.expectedAud))
}

if parserOptions.expectedIss != "" {
opts = append(opts, jwt.WithIssuer(parserOptions.expectedIss))
}

if parserOptions.expectedSub != "" {
opts = append(opts, jwt.WithSubject(parserOptions.expectedSub))
}

return opts
}

func (d *JwtDecoder) useCorrectPublicKey(token *jwt.Token) (publicKey, error) {
if token == nil {
return nil, errors.Errorf("failed to decode: missing token")
Expand Down
42 changes: 41 additions & 1 deletion jwt/decoder_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"time"
)

// JwtDecoderOption function signature for added JWT Decoder options.
// JwtDecoderOption function signature for adding JWT Decoder options.
type JwtDecoderOption func(*JwtDecoder)

// WithDecoderJwksExpiry sets the JwtDecoder JWKs expiry time.Duration
Expand All @@ -21,3 +21,43 @@ func WithDecoderRotateWindow(rotate time.Duration) JwtDecoderOption {
decoder.rotationWindow = rotate
}
}

// DecoderParserOption function signature for adding JWT Decoder Parsing options.
type DecoderParserOption func(*decoderParser)

type decoderParser struct {
expectedAud string
expectedIss string
expectedSub string
}

func newDecoderParser() *decoderParser {
return &decoderParser{}
}

// MustMatchAudience configures the jwt parser to require the specified audience in
// the `aud` claim. Validation will fail if the audience is not listed in the
// token or the `aud` claim is missing.
func MustMatchAudience(aud string) DecoderParserOption {
return func(p *decoderParser) {
p.expectedAud = aud
}
}

// MustMatchIssuer configures the jwt parser to require the specified issuer in the
// `iss` claim. Validation will fail if a different issuer is specified in the
// token or the `iss` claim is missing.
func MustMatchIssuer(iss string) DecoderParserOption {
return func(p *decoderParser) {
p.expectedIss = iss
}
}

// MustMatchSubject configures the jwt parser to require the specified subject in the
// `sub` claim. Validation will fail if a different subject is specified in the
// token or the `sub` claim is missing.
func MustMatchSubject(sub string) DecoderParserOption {
return func(p *decoderParser) {
p.expectedSub = sub
}
}
51 changes: 51 additions & 0 deletions jwt/decoder_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,54 @@ func TestDecoderOptions(t *testing.T) {

assert.Greater(t, i, 1)
}

func TestDecoderParsingOptions(t *testing.T) {
token := "eyJhbGciOiJSUzUxMiIsImtpZCI6InJzYS01MTIiLCJ0eXAiOiJKV1QifQ.eyJhY2NvdW50SWQiOiJhYmMxMjMiLCJlZmZlY3RpdmVVc2VySWQiOiJ4eXozNDUiLCJyZWFsVXNlcklkIjoieHl6MjM0IiwiaXNzIjoiZW5jb2Rlci1uYW1lIiwic3ViIjoidGVzdCIsImF1ZCI6WyJkZWNvZGVyLW5hbWUiXSwiZXhwIjoyMjExNzk3NTMyLCJuYmYiOjE1ODA2MDg5MjIsImlhdCI6MTU4MDYwODkyMn0.hQLuOe8qZHUgstYe0A0n4-Pww7ovlReyKiDR1Y02ltUnUlgbm9qpp-Ef6YNFuIKdHmS-ynQbDx5pbI36szsggzi80apNpI48cwSXshx82TwuU-_Z4wNBXu7MdPvbA5FdjhxCvRqaqhglsGJ6NofC1bP9awVyyy4j9LGfkVuVEXJQrVpdvEs8Ks-LxlWz7_9Cr7BrZcLuBJnujhe4CbdSudkrfeFl19EY3i1wH9OatGjfjwOSJVqv-ZLnn3QkaZmrQ1xwXTm3MlMUH3KSQjBn8h6vbqosIB5iHDFtqR11mLCgYExGHBpzFjM1d5NEmcTNLV9MtZ_qDZwG0wkgv9O4rXVQ0JfdXypMwhchED2Z45_mc2OiLidtKtDmeoE5g0Daq8YpM0ZpVRbXUFeYIZ1doQKUNsbWNdITmrjVOC3Zn8BecYPu1pC4Hk1y-ViArDzxlCMHA7Bua64BfzVuaJ8pBTEmbqMiZ9VujWcimCOtJ5yfCks_RPAhFYOErcqy3B56fmyYdIN__mKl7VvRDtBSiiPGCq07BUjGywaMoZIULbyXYSV4zs3hX_R4_o4asGiVWCZgn7k4pZzCJo_y2e-Mf85nYoRlyr1MXx7IM4srFQCgO-KTjDWL_TXqpMJU5zDzKyelrMFkc6EaMQ2KP_yBhOrh4UW-Pm7ghusox_-bV1U"

b, err := os.ReadFile(filepath.Clean(testAuthJwks))
assert.Nil(t, err)
validJwks := string(b)

jwks := func() string {
return validJwks
}

decoder, err := NewJwtDecoder(jwks)
assert.Nil(t, err)
assert.NotNil(t, decoder)

// Decode with no MustMatch parsing options
claim, err := decoder.Decode(token)
assert.Nil(t, err)
assert.NotNil(t, claim)

// Mis-Matched aud claim returns error on Decode
claim, err = decoder.Decode(token, MustMatchAudience("abc"))
assert.NotNil(t, err)
assert.ErrorContains(t, err, "token has invalid audience")

// Matched aud claim returns success
claim, err = decoder.Decode(token, MustMatchAudience("decoder-name"))
assert.Nil(t, err)
assert.NotNil(t, claim)

// Mis-Matched iss claim returns error on Decode
claim, err = decoder.Decode(token, MustMatchIssuer("abc"))
assert.NotNil(t, err)
assert.ErrorContains(t, err, "token has invalid issuer")

// Matched iss claim returns success
claim, err = decoder.Decode(token, MustMatchIssuer("encoder-name"))
assert.Nil(t, err)
assert.NotNil(t, claim)

// Mis-Matched sub claim returns error on Decode
claim, err = decoder.Decode(token, MustMatchSubject("abc"))
assert.NotNil(t, err)
assert.ErrorContains(t, err, "token has invalid subject")

// Matched sub claim returns success
claim, err = decoder.Decode(token, MustMatchSubject("test"))
assert.Nil(t, err)
assert.NotNil(t, claim)
}
10 changes: 5 additions & 5 deletions jwt/jwt_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func ExampleDefaultJwtDecoder() {
}

mockEncDec := newMockedEncoderDecoder()
mockEncDec.On("Decode", mock.Anything).Return(claims, nil)
mockEncDec.On("Decode", mock.Anything, mock.Anything).Return(claims, nil)

// Overwrite the Default package level encoder and decoder
oldDecoder := jwt.DefaultJwtDecoder
Expand Down Expand Up @@ -156,13 +156,13 @@ func (m *mockedEncoderDecoder) EncodeWithCustomClaims(customClaims gojwt.Claims)
return output, args.Error(1)
}

func (m *mockedEncoderDecoder) Decode(tokenString string) (*jwt.StandardClaims, error) {
args := m.Called(tokenString)
func (m *mockedEncoderDecoder) Decode(tokenString string, options ...jwt.DecoderParserOption) (*jwt.StandardClaims, error) {
args := m.Called(tokenString, options)
output, _ := args.Get(0).(*jwt.StandardClaims)
return output, args.Error(1)
}

func (m *mockedEncoderDecoder) DecodeWithCustomClaims(tokenString string, customClaims gojwt.Claims) error {
args := m.Called(tokenString, customClaims)
func (m *mockedEncoderDecoder) DecodeWithCustomClaims(tokenString string, customClaims gojwt.Claims, options ...jwt.DecoderParserOption) error {
args := m.Called(tokenString, customClaims, options)
return args.Error(0)
}
12 changes: 6 additions & 6 deletions jwt/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ type Encoder interface {

// Decoder interface allows for mocking of the Decoder.
type Decoder interface {
Decode(tokenString string) (*StandardClaims, error)
DecodeWithCustomClaims(tokenString string, customClaims jwt.Claims) error
Decode(tokenString string, options ...DecoderParserOption) (*StandardClaims, error)
DecodeWithCustomClaims(tokenString string, customClaims jwt.Claims, options ...DecoderParserOption) error
}

var (
Expand All @@ -29,21 +29,21 @@ var (
)

// Decode a jwt token string and return the Standard Culture Amp Claims.
func Decode(tokenString string) (*StandardClaims, error) {
func Decode(tokenString string, options ...DecoderParserOption) (*StandardClaims, error) {
err := mustHaveDefaultJwtDecoder()
if err != nil {
return nil, err
}
return DefaultJwtDecoder.Decode(tokenString)
return DefaultJwtDecoder.Decode(tokenString, options...)
}

// DecodeWithCustomClaims takes a jwt token string and populate the customClaims.
func DecodeWithCustomClaims(tokenString string, customClaims jwt.Claims) error {
func DecodeWithCustomClaims(tokenString string, customClaims jwt.Claims, options ...DecoderParserOption) error {
err := mustHaveDefaultJwtDecoder()
if err != nil {
return err
}
return DefaultJwtDecoder.DecodeWithCustomClaims(tokenString, customClaims)
return DefaultJwtDecoder.DecodeWithCustomClaims(tokenString, customClaims, options...)
}

// Encode the Standard Culture Amp Claims in a jwt token string.
Expand Down
24 changes: 24 additions & 0 deletions jwt/package_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,30 @@ func TestPackageEncodeDecode(t *testing.T) {
assert.Equal(t, "abc123", sc.AccountId)
assert.Equal(t, "xyz234", sc.RealUserId)
assert.Equal(t, "xyz345", sc.EffectiveUserId)

// Decode it back again, checking aud, iss and sub all match
sc, err = Decode(token, MustMatchAudience("decoder-name"), MustMatchIssuer("encoder-name"), MustMatchSubject("test"))
assert.Nil(t, err)

// check it matches
assert.Equal(t, "abc123", sc.AccountId)
assert.Equal(t, "xyz234", sc.RealUserId)
assert.Equal(t, "xyz345", sc.EffectiveUserId)

// Decode it back again, checking aud should fail
sc, err = Decode(token, MustMatchAudience("incorrect-aud"))
assert.NotNil(t, err)
assert.ErrorContains(t, err, "token has invalid audience")

// Decode it back again, checking iss should fail
sc, err = Decode(token, MustMatchIssuer("incorrect-iss"))
assert.NotNil(t, err)
assert.ErrorContains(t, err, "token has invalid issuer")

// Decode it back again, checking sub should fail
sc, err = Decode(token, MustMatchSubject("incorrect-sub"))
assert.NotNil(t, err)
assert.ErrorContains(t, err, "token has invalid subject")
}

func TestPackageEncodeDecodeNotBeforeExpiryChecks(t *testing.T) {
Expand Down

0 comments on commit 5fc8487

Please sign in to comment.