From 3bdd22a342c6dc759c50b73a6a893ef6407ec743 Mon Sep 17 00:00:00 2001 From: Ian Ferguson Date: Thu, 22 Apr 2021 17:40:15 -0400 Subject: [PATCH 01/14] pre-publish new signing keys for `rotation_period` of time before using --- vault/identity_store_oidc.go | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 9a521994dbed9..9e0b92da81d0b 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -50,6 +50,7 @@ type namedKey struct { RotationPeriod time.Duration `json:"rotation_period"` KeyRing []*expireableKey `json:"key_ring"` SigningKey *jose.JSONWebKey `json:"signing_key"` + NextSigningKey *jose.JSONWebKey `json:"next_signing_key"` NextRotation time.Time `json:"next_rotation"` AllowedClientIDs []string `json:"allowed_client_ids"` } @@ -515,7 +516,7 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica key.NextRotation = nextRotation } - // generate keys if creating a new key or changing algorithms + // generate current and next keys if creating a new key or changing algorithms if key.Algorithm != prevAlgorithm { signingKey, err := generateKeys(key.Algorithm) if err != nil { @@ -528,6 +529,19 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil { return nil, err } + + nextSigningKey, err := generateKeys(key.Algorithm) + if err != nil { + return nil, err + } + + key.NextSigningKey = nextSigningKey + key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: nextSigningKey.Public().KeyID}) + + if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil { + return nil, err + } + } if err := i.oidcCache.Flush(ns); err != nil { @@ -1313,11 +1327,11 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi } // generate new key - signingKey, err := generateKeys(k.Algorithm) + nextSigningKey, err := generateKeys(k.Algorithm) if err != nil { return err } - if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil { + if err := saveOIDCPublicKey(ctx, s, nextSigningKey.Public()); err != nil { return err } @@ -1330,8 +1344,9 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi break } } - k.SigningKey = signingKey - k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.KeyID}) + k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: nextSigningKey.KeyID}) + k.SigningKey = k.NextSigningKey + k.NextSigningKey = nextSigningKey k.NextRotation = now.Add(k.RotationPeriod) // store named key (it was modified when rotate was called on it) From 36c6b2b50afd8cf740277e1f007b020d1fb2de99 Mon Sep 17 00:00:00 2001 From: Ian Ferguson Date: Thu, 22 Apr 2021 20:00:53 -0400 Subject: [PATCH 02/14] Work In Progress: Prepublish JWKS and even cache control --- vault/identity_store_oidc.go | 132 ++++++++++++++++++++++++++--------- 1 file changed, 99 insertions(+), 33 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 9e0b92da81d0b..7e153ccc2c955 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -10,6 +10,8 @@ import ( "encoding/json" "errors" "fmt" + "math" + "math/big" "net/url" "sort" "strings" @@ -510,8 +512,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica return logical.ErrorResponse("unknown signing algorithm %q", key.Algorithm), nil } + now := time.Now() + // Update next rotation time if it is unset or now earlier than previously set. - nextRotation := time.Now().Add(key.RotationPeriod) + nextRotation := now.Add(key.RotationPeriod) if key.NextRotation.IsZero() || nextRotation.Before(key.NextRotation) { key.NextRotation = nextRotation } @@ -523,8 +527,12 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica return nil, err } + i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID) key.SigningKey = signingKey - key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID}) + key.KeyRing = append(key.KeyRing, &expireableKey{ + KeyID: signingKey.Public().KeyID, + ExpireAt: now.Add(key.RotationPeriod).Add(key.VerificationTTL), + }) if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil { return nil, err @@ -535,8 +543,12 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica return nil, err } + i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID) key.NextSigningKey = nextSigningKey - key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: nextSigningKey.Public().KeyID}) + key.KeyRing = append(key.KeyRing, &expireableKey{ + KeyID: nextSigningKey.Public().KeyID, + ExpireAt: now.Add(key.RotationPeriod).Add(key.RotationPeriod).Add(key.VerificationTTL), + }) if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil { return nil, err @@ -544,6 +556,10 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica } + // TKTK storing the cache control max value in the i.oidcCache, + // and populating it only in the periodicFunc background worker means + // up to 60 seconds of `Cache-control: none` at Vault startup and + // anytime a key configuration isn't changed, which is a problem if err := i.oidcCache.Flush(ns); err != nil { return nil, err } @@ -725,7 +741,7 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ verificationTTLOverride = time.Duration(ttlRaw.(int)) * time.Second } - if err := storedNamedKey.rotate(ctx, req.Storage, verificationTTLOverride); err != nil { + if err := storedNamedKey.rotate(ctx, i.Logger(), req.Storage, verificationTTLOverride); err != nil { return nil, err } @@ -1207,18 +1223,37 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical return nil, err } if len(keys) > 0 { - v, ok, err := i.oidcCache.Get(noNamespace, "nextRun") + // if maxJwksClientCache is set use that, otherwise fall back on the more conservative + // nextRun values + maxJwksClientCache, ok, err := i.oidcCache.Get(noNamespace, "maxJwksClientCache") if err != nil { return nil, err } if ok { - now := time.Now() - expireAt := v.(time.Time) - if expireAt.After(now) { - expireInSeconds := expireAt.Sub(time.Now()).Seconds() - expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds) - resp.Data[logical.HTTPRawCacheControl] = expireInString + maxDuration := int64(maxJwksClientCache.(time.Duration)) + randDuration, err := rand.Int(rand.Reader, big.NewInt(maxDuration)) + if err != nil { + return nil, err + } + // truncate to seconds + durationInSeconds := time.Duration(randDuration.Int64()).Seconds() + durationInString := fmt.Sprintf("max-age=%.0f", durationInSeconds) + resp.Data[logical.HTTPRawCacheControl] = durationInString + } else { + v, ok, err := i.oidcCache.Get(noNamespace, "nextRun") + if err != nil { + return nil, err + } + + if ok { + now := time.Now() + expireAt := v.(time.Time) + if expireAt.After(now) { + expireInSeconds := expireAt.Sub(time.Now()).Seconds() + expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds) + resp.Data[logical.HTTPRawCacheControl] = expireInString + } } } } @@ -1316,10 +1351,9 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req return introspectionResp("") } -// namedKey.rotate(overrides) performs a key rotation on a namedKey and returns the -// verification_ttl that was applied. verification_ttl can be overridden with an -// overrideVerificationTTL value >= 0 -func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerificationTTL time.Duration) error { +// namedKey.rotate(overrides) performs a key rotation on a namedKey. +// verification_ttl can be overridden with an overrideVerificationTTL value >= 0 +func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error { verificationTTL := k.VerificationTTL if overrideVerificationTTL >= 0 { @@ -1334,17 +1368,16 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi if err := saveOIDCPublicKey(ctx, s, nextSigningKey.Public()); err != nil { return err } + logger.Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID) now := time.Now() - - // set the previous public key's expiry time - for _, key := range k.KeyRing { - if key.KeyID == k.SigningKey.KeyID { - key.ExpireAt = now.Add(verificationTTL) - break - } - } - k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: nextSigningKey.KeyID}) + k.KeyRing = append(k.KeyRing, &expireableKey{ + KeyID: nextSigningKey.KeyID, + // this token won't start being used for 1 rotation period, + // will be used for 1 rotation period after that, + // and should be deleted verificationTTL after it is no longer used + ExpireAt: now.Add(k.RotationPeriod).Add(k.RotationPeriod).Add(verificationTTL), + }) k.SigningKey = k.NextSigningKey k.NextSigningKey = nextSigningKey k.NextRotation = now.Add(k.RotationPeriod) @@ -1358,6 +1391,7 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi return err } + logger.Debug("rotated OIDC public key. now using", "key_id", k.SigningKey.Public().KeyID) return nil } @@ -1590,24 +1624,35 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor return nextExpiration, nil } -func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, error) { +func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (nextRotation time.Time, maxJwksClientCacheDuration time.Duration, err error) { // soonestRotation will be the soonest rotation time of all keys. Initialize // here to a relatively distant time. now := time.Now() soonestRotation := now.Add(24 * time.Hour) + // the OIDC JWKS endpoint returns a Cache-Control HTTP header time + // between 0 and the minimum verificationTTL or minimum rotationPeriod out + // of all keys, whichever value is lower. + // + // This smooths calls from services validating JWTs to Vault, while + // ensuring that operators can assert that servers honoring the Cache-Control + // header will always have a superset of all valid keys, and not trust + // any keys longer than a jwksCacheControlMax duration after a key is rotated + // out of signing use + jwksClientCacheDuration := time.Duration(math.MaxInt64) + i.oidcLock.Lock() defer i.oidcLock.Unlock() keys, err := s.List(ctx, namedKeyConfigPath) if err != nil { - return now, err + return now, jwksClientCacheDuration, err } for _, k := range keys { entry, err := s.Get(ctx, namedKeyConfigPath+k) if err != nil { - return now, err + return now, jwksClientCacheDuration, err } if entry == nil { @@ -1616,10 +1661,18 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) var key namedKey if err := entry.DecodeJSON(&key); err != nil { - return now, err + return now, jwksClientCacheDuration, err } key.name = k + if key.VerificationTTL < jwksClientCacheDuration { + jwksClientCacheDuration = key.VerificationTTL + } + + if key.RotationPeriod < jwksClientCacheDuration { + jwksClientCacheDuration = key.RotationPeriod + } + // Future key rotation that is the earliest we've seen. if now.Before(key.NextRotation) && key.NextRotation.Before(soonestRotation) { soonestRotation = key.NextRotation @@ -1628,8 +1681,8 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) // Key that is due to be rotated. if now.After(key.NextRotation) { i.Logger().Debug("rotating OIDC key", "key", key.name) - if err := key.rotate(ctx, s, -1); err != nil { - return now, err + if err := key.rotate(ctx, i.Logger(), s, -1); err != nil { + return now, jwksClientCacheDuration, err } // Possibly save the new rotation time @@ -1639,11 +1692,14 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) } } - return soonestRotation, nil + return soonestRotation, jwksClientCacheDuration, nil } // oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key // rotations and expiration actions. +// TKTK consider offsetting nextRun to be 60 seconds early and then +// TKTK sleepingthe remaining upto 60 seconds until the actual rotation time +// TKTK but only if this function is called in a non blocking way func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { var nextRun time.Time now := time.Now() @@ -1666,7 +1722,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { // Initialize to a fairly distant next run time. This will be brought in // based on key rotation times. nextRun = now.Add(24 * time.Hour) - + minJwksClientCacheDuration := time.Duration(math.MaxInt64) for _, ns := range i.listNamespaces() { nsPath := ns.Path @@ -1676,7 +1732,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { continue } - nextRotation, err := i.oidcKeyRotation(ctx, s) + nextRotation, jwksClientCacheDuration, err := i.oidcKeyRotation(ctx, s) if err != nil { i.Logger().Warn("error rotating OIDC keys", "err", err) } @@ -1698,10 +1754,20 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { if nextExpiration.Before(nextRun) { nextRun = nextExpiration } + + if jwksClientCacheDuration < minJwksClientCacheDuration { + minJwksClientCacheDuration = jwksClientCacheDuration + } } if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil { i.Logger().Error("error setting oidc cache", "err", err) } + if minJwksClientCacheDuration < math.MaxInt64 { + if err := i.oidcCache.SetDefault(noNamespace, "maxJwksClientCache", minJwksClientCacheDuration); err != nil { + i.Logger().Error("error setting maxJwksClientCache in oidc cache", "err", err) + } + } + } } From 333046b6bc04b40ecfbcff12b92c00bb917cc77b Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Mon, 23 Aug 2021 10:10:23 -0500 Subject: [PATCH 03/14] remove comments --- vault/identity_store_oidc.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 7e153ccc2c955..b9eda9dc6b256 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -556,10 +556,6 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica } - // TKTK storing the cache control max value in the i.oidcCache, - // and populating it only in the periodicFunc background worker means - // up to 60 seconds of `Cache-control: none` at Vault startup and - // anytime a key configuration isn't changed, which is a problem if err := i.oidcCache.Flush(ns); err != nil { return nil, err } @@ -1697,9 +1693,6 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) // oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key // rotations and expiration actions. -// TKTK consider offsetting nextRun to be 60 seconds early and then -// TKTK sleepingthe remaining upto 60 seconds until the actual rotation time -// TKTK but only if this function is called in a non blocking way func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { var nextRun time.Time now := time.Now() From 1f0a386f38178427d12eda1d0d9668be22ceb001 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Mon, 23 Aug 2021 14:49:03 -0500 Subject: [PATCH 04/14] use math/rand instead of math/big --- vault/identity_store_oidc.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index b9eda9dc6b256..6d79b9c5553ac 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -11,7 +11,7 @@ import ( "errors" "fmt" "math" - "math/big" + mathrand "math/rand" "net/url" "sort" "strings" @@ -1228,12 +1228,9 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical if ok { maxDuration := int64(maxJwksClientCache.(time.Duration)) - randDuration, err := rand.Int(rand.Reader, big.NewInt(maxDuration)) - if err != nil { - return nil, err - } + randDuration := mathrand.Int63n(maxDuration) // truncate to seconds - durationInSeconds := time.Duration(randDuration.Int64()).Seconds() + durationInSeconds := time.Duration(randDuration).Seconds() durationInString := fmt.Sprintf("max-age=%.0f", durationInSeconds) resp.Data[logical.HTTPRawCacheControl] = durationInString } else { From 0b2549d34f1607537814381f825d4a7381f7155e Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Mon, 23 Aug 2021 15:50:21 -0500 Subject: [PATCH 05/14] update tests --- vault/identity_store_oidc.go | 1 + vault/identity_store_oidc_test.go | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 6d79b9c5553ac..5cc6f70989825 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1384,6 +1384,7 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St return err } + logger.Debug("jmf here") logger.Debug("rotated OIDC public key. now using", "key_id", k.SigningKey.Public().KeyID) return nil } diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 09cbd2cdf51e5..cec68b4fdc905 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -469,7 +469,7 @@ func TestOIDC_PublicKeys(t *testing.T) { Storage: storage, }) - // .well-known/keys should contain 1 public key + // .well-known/keys should contain 2 public keys resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/.well-known/keys", Operation: logical.ReadOperation, @@ -479,8 +479,8 @@ func TestOIDC_PublicKeys(t *testing.T) { // parse response responseJWKS := &jose.JSONWebKeySet{} json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 1 { - t.Fatalf("expected 1 public key but instead got %d", len(responseJWKS.Keys)) + if len(responseJWKS.Keys) != 2 { + t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys)) } // rotate test-key a few times, each rotate should increase the length of public keys returned @@ -498,7 +498,7 @@ func TestOIDC_PublicKeys(t *testing.T) { }) expectSuccess(t, resp, err) - // .well-known/keys should contain 3 public keys + // .well-known/keys should contain 4 public keys resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/.well-known/keys", Operation: logical.ReadOperation, @@ -507,8 +507,8 @@ func TestOIDC_PublicKeys(t *testing.T) { expectSuccess(t, resp, err) // parse response json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 3 { - t.Fatalf("expected 3 public keys but instead got %d", len(responseJWKS.Keys)) + if len(responseJWKS.Keys) != 4 { + t.Fatalf("expected 4 public keys but instead got %d", len(responseJWKS.Keys)) } // create another named key @@ -525,7 +525,7 @@ func TestOIDC_PublicKeys(t *testing.T) { Storage: storage, }) - // .well-known/keys should contain 1 public key, all of the public keys + // .well-known/keys should contain 2 public key, all of the public keys // from named key "test-key" should have been deleted resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/.well-known/keys", @@ -535,8 +535,8 @@ func TestOIDC_PublicKeys(t *testing.T) { expectSuccess(t, resp, err) // parse response json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 1 { - t.Fatalf("expected 1 public keys but instead got %d", len(responseJWKS.Keys)) + if len(responseJWKS.Keys) != 2 { + t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys)) } } @@ -699,6 +699,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) { RotationPeriod: 1 * cyclePeriod, KeyRing: nil, SigningKey: jwk, + NextSigningKey: jwk, NextRotation: time.Now(), }, []struct { @@ -708,8 +709,11 @@ func TestOIDC_PeriodicFunc(t *testing.T) { }{ {1, 1, 1}, {2, 2, 2}, - {3, 2, 2}, + {3, 3, 3}, {4, 2, 2}, + {5, 3, 3}, + {6, 2, 2}, + {7, 3, 3}, }, }, } From 45077ce8f1f67f71c165146aa210cb940ce3e299 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 24 Aug 2021 09:27:59 -0500 Subject: [PATCH 06/14] remove debug comment --- vault/identity_store_oidc.go | 1 - 1 file changed, 1 deletion(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 5cc6f70989825..6d79b9c5553ac 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1384,7 +1384,6 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St return err } - logger.Debug("jmf here") logger.Debug("rotated OIDC public key. now using", "key_id", k.SigningKey.Public().KeyID) return nil } From 8a24df4a8e346c98952c9b0be14d4af78c132fbe Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 31 Aug 2021 14:16:57 -0500 Subject: [PATCH 07/14] refactor cache control logic into func --- vault/identity_store_oidc.go | 66 ++++++++++++++++++------------- vault/identity_store_oidc_test.go | 58 +++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 28 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 25a37850fa656..e14fdc50b3856 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -527,7 +527,6 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica return nil, err } - i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID) key.SigningKey = signingKey key.KeyRing = append(key.KeyRing, &expireableKey{ KeyID: signingKey.Public().KeyID, @@ -537,13 +536,13 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil { return nil, err } + i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID) nextSigningKey, err := generateKeys(key.Algorithm) if err != nil { return nil, err } - i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID) key.NextSigningKey = nextSigningKey key.KeyRing = append(key.KeyRing, &expireableKey{ KeyID: nextSigningKey.Public().KeyID, @@ -553,7 +552,7 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil { return nil, err } - + i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID) } if err := i.oidcCache.Flush(ns); err != nil { @@ -1185,6 +1184,39 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ return resp, nil } +// getKeysCacheControlHeader returns the cache control header for all public +// keys at the .well-known/keys endpoint +func (i *IdentityStore) getKeysCacheControlHeader() (string, error) { + // if maxJwksClientCache is set use that, otherwise fall back on the + // more conservative nextRun values + maxJwksClientCache, ok, err := i.oidcCache.Get(noNamespace, "maxJwksClientCache") + if err != nil { + return "", err + } + + if ok { + maxDuration := int64(maxJwksClientCache.(time.Duration)) + randDuration := mathrand.Int63n(maxDuration) + durationInSeconds := time.Duration(randDuration).Seconds() + return fmt.Sprintf("max-age=%.0f", durationInSeconds), nil + } + + nextRun, ok, err := i.oidcCache.Get(noNamespace, "nextRun") + if err != nil { + return "", err + } + + if ok { + now := time.Now() + expireAt := nextRun.(time.Time) + if expireAt.After(now) { + expireInSeconds := expireAt.Sub(time.Now()).Seconds() + return fmt.Sprintf("max-age=%.0f", expireInSeconds), nil + } + } + return "", nil +} + // pathOIDCReadPublicKeys is used to retrieve all public keys so that clients can // verify the validity of a signed OIDC token. func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { @@ -1234,35 +1266,13 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical return nil, err } if len(keys) > 0 { - // if maxJwksClientCache is set use that, otherwise fall back on the more conservative - // nextRun values - maxJwksClientCache, ok, err := i.oidcCache.Get(noNamespace, "maxJwksClientCache") + header, err := i.getKeysCacheControlHeader() if err != nil { return nil, err } - if ok { - maxDuration := int64(maxJwksClientCache.(time.Duration)) - randDuration := mathrand.Int63n(maxDuration) - // truncate to seconds - durationInSeconds := time.Duration(randDuration).Seconds() - durationInString := fmt.Sprintf("max-age=%.0f", durationInSeconds) - resp.Data[logical.HTTPRawCacheControl] = durationInString - } else { - v, ok, err := i.oidcCache.Get(noNamespace, "nextRun") - if err != nil { - return nil, err - } - - if ok { - now := time.Now() - expireAt := v.(time.Time) - if expireAt.After(now) { - expireInSeconds := expireAt.Sub(time.Now()).Seconds() - expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds) - resp.Data[logical.HTTPRawCacheControl] = expireInString - } - } + if header != "" { + resp.Data[logical.HTTPRawCacheControl] = header } } diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 8eb7b7a3d7f4a..f6467aa4fec8c 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -4,6 +4,8 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" + "strconv" + "strings" "testing" "time" @@ -1249,6 +1251,62 @@ func TestOIDC_CacheNamespaceNilCheck(t *testing.T) { } } +func TestOIDC_GetKeysCacheControlHeader(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + // get default value + header, err := c.identityStore.getKeysCacheControlHeader() + if err != nil { + t.Fatalf("expected success, got error:\n%v", err) + } + + expectedHeader := "" + if header != expectedHeader { + t.Fatalf("expected %s, got %s", expectedHeader, header) + } + + // set nextRun + nextRun := time.Now().Add(24 * time.Hour) + if err = c.identityStore.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil { + t.Fatal(err) + } + + header, err = c.identityStore.getKeysCacheControlHeader() + if err != nil { + t.Fatalf("expected success, got error:\n%v", err) + } + + expectedNextRun := "max-age=86400" + if header != expectedNextRun { + t.Fatalf("expected %s, got %s", expectedNextRun, header) + } + + // set maxJwksClientCache + maxJwksClientCache := time.Duration(60) * time.Second + if err = c.identityStore.oidcCache.SetDefault(noNamespace, "maxJwksClientCache", maxJwksClientCache); err != nil { + t.Fatal(err) + } + + header, err = c.identityStore.getKeysCacheControlHeader() + if err != nil { + t.Fatalf("expected success, got error:\n%v", err) + } + + if header == "" { + t.Fatalf("expected header to be set, got %s", header) + } + + maxAgeValue := strings.Split(header, "=")[1] + headerVal, err := strconv.Atoi(maxAgeValue) + if err != nil { + t.Fatal(err) + } + // headerVal will be a random value between 0 and maxJwksClientCache + if headerVal > int(maxJwksClientCache) { + t.Fatalf("unexpected header value, got %d", headerVal) + } +} + // some helpers func expectSuccess(t *testing.T, resp *logical.Response, err error) { t.Helper() From d12cb045968df31bf56f56fdfaf4857b4c13dcaa Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Wed, 1 Sep 2021 10:50:49 -0500 Subject: [PATCH 08/14] don't set expiry when create/update key --- vault/identity_store_oidc.go | 54 ++++++++++++++----------------- vault/identity_store_oidc_test.go | 16 ++++++--- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index e14fdc50b3856..539a73222fdd6 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -528,10 +528,7 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica } key.SigningKey = signingKey - key.KeyRing = append(key.KeyRing, &expireableKey{ - KeyID: signingKey.Public().KeyID, - ExpireAt: now.Add(key.RotationPeriod).Add(key.VerificationTTL), - }) + key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID}) if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil { return nil, err @@ -544,10 +541,7 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica } key.NextSigningKey = nextSigningKey - key.KeyRing = append(key.KeyRing, &expireableKey{ - KeyID: nextSigningKey.Public().KeyID, - ExpireAt: now.Add(key.RotationPeriod).Add(key.RotationPeriod).Add(key.VerificationTTL), - }) + key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: nextSigningKey.Public().KeyID}) if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil { return nil, err @@ -1187,15 +1181,15 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ // getKeysCacheControlHeader returns the cache control header for all public // keys at the .well-known/keys endpoint func (i *IdentityStore) getKeysCacheControlHeader() (string, error) { - // if maxJwksClientCache is set use that, otherwise fall back on the + // if jwksCacheControlMaxAge is set use that, otherwise fall back on the // more conservative nextRun values - maxJwksClientCache, ok, err := i.oidcCache.Get(noNamespace, "maxJwksClientCache") + jwksCacheControlMaxAge, ok, err := i.oidcCache.Get(noNamespace, "jwksCacheControlMaxAge") if err != nil { return "", err } if ok { - maxDuration := int64(maxJwksClientCache.(time.Duration)) + maxDuration := int64(jwksCacheControlMaxAge.(time.Duration)) randDuration := mathrand.Int63n(maxDuration) durationInSeconds := time.Duration(randDuration).Seconds() return fmt.Sprintf("max-age=%.0f", durationInSeconds), nil @@ -1389,13 +1383,15 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St logger.Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID) now := time.Now() - k.KeyRing = append(k.KeyRing, &expireableKey{ - KeyID: nextSigningKey.KeyID, - // this token won't start being used for 1 rotation period, - // will be used for 1 rotation period after that, - // and should be deleted verificationTTL after it is no longer used - ExpireAt: now.Add(k.RotationPeriod).Add(k.RotationPeriod).Add(verificationTTL), - }) + // set the previous public key's expiry time + for _, key := range k.KeyRing { + if key.KeyID == k.SigningKey.KeyID { + key.ExpireAt = now.Add(verificationTTL) + break + } + } + + k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: nextSigningKey.KeyID}) k.SigningKey = k.NextSigningKey k.NextSigningKey = nextSigningKey k.NextRotation = now.Add(k.RotationPeriod) @@ -1409,7 +1405,7 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St return err } - logger.Debug("rotated OIDC public key. now using", "key_id", k.SigningKey.Public().KeyID) + logger.Debug("rotated OIDC public key, now using", "key_id", k.SigningKey.Public().KeyID) return nil } @@ -1642,21 +1638,21 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor return nextExpiration, nil } -func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (nextRotation time.Time, maxJwksClientCacheDuration time.Duration, err error) { +func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, time.Duration, error) { // soonestRotation will be the soonest rotation time of all keys. Initialize // here to a relatively distant time. now := time.Now() soonestRotation := now.Add(24 * time.Hour) - // the OIDC JWKS endpoint returns a Cache-Control HTTP header time - // between 0 and the minimum verificationTTL or minimum rotationPeriod out - // of all keys, whichever value is lower. + // the OIDC JWKS endpoint returns a Cache-Control HTTP header time between + // 0 and the minimum verificationTTL or minimum rotationPeriod out of all + // keys, whichever value is lower. // // This smooths calls from services validating JWTs to Vault, while - // ensuring that operators can assert that servers honoring the Cache-Control - // header will always have a superset of all valid keys, and not trust - // any keys longer than a jwksCacheControlMax duration after a key is rotated - // out of signing use + // ensuring that operators can assert that servers honoring the + // Cache-Control header will always have a superset of all valid keys, and + // not trust any keys longer than a jwksCacheControlMaxAge duration after a + // key is rotated out of signing use jwksClientCacheDuration := time.Duration(math.MaxInt64) i.oidcLock.Lock() @@ -1779,8 +1775,8 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { i.Logger().Error("error setting oidc cache", "err", err) } if minJwksClientCacheDuration < math.MaxInt64 { - if err := i.oidcCache.SetDefault(noNamespace, "maxJwksClientCache", minJwksClientCacheDuration); err != nil { - i.Logger().Error("error setting maxJwksClientCache in oidc cache", "err", err) + if err := i.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", minJwksClientCacheDuration); err != nil { + i.Logger().Error("error setting jwksCacheControlMaxAge in oidc cache", "err", err) } } diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index f6467aa4fec8c..ee466e859904c 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -693,10 +693,18 @@ func TestOIDC_SignIDToken(t *testing.T) { responseJWKS := &jose.JSONWebKeySet{} json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - // Validate the signature - claims := &jwt.Claims{} - if err := parsedToken.Claims(responseJWKS.Keys[0], claims); err != nil { - t.Fatalf("unable to validate signed token, err:\n%#v", err) + keyCount := len(responseJWKS.Keys) + errorCount := 0 + for _, key := range responseJWKS.Keys { + // Validate the signature + claims := &jwt.Claims{} + if err := parsedToken.Claims(key, claims); err != nil { + t.Logf("unable to validate signed token, err:\n%#v", err) + errorCount += 1 + } + } + if errorCount == keyCount { + t.Fatalf("unable to validate signed token with any of the .well-known keys") } } From 7fa6beb514010da62cce017003c2cd2f317360f3 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Wed, 1 Sep 2021 14:32:09 -0500 Subject: [PATCH 09/14] update cachecontrol name in oidccache for test --- vault/identity_store_oidc_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index ee466e859904c..3b5e66e8c6d72 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -1289,9 +1289,9 @@ func TestOIDC_GetKeysCacheControlHeader(t *testing.T) { t.Fatalf("expected %s, got %s", expectedNextRun, header) } - // set maxJwksClientCache - maxJwksClientCache := time.Duration(60) * time.Second - if err = c.identityStore.oidcCache.SetDefault(noNamespace, "maxJwksClientCache", maxJwksClientCache); err != nil { + // set jwksCacheControlMaxAge + jwksCacheControlMaxAge := time.Duration(60) * time.Second + if err = c.identityStore.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", jwksCacheControlMaxAge); err != nil { t.Fatal(err) } @@ -1309,8 +1309,8 @@ func TestOIDC_GetKeysCacheControlHeader(t *testing.T) { if err != nil { t.Fatal(err) } - // headerVal will be a random value between 0 and maxJwksClientCache - if headerVal > int(maxJwksClientCache) { + // headerVal will be a random value between 0 and jwksCacheControlMaxAge + if headerVal > int(jwksCacheControlMaxAge) { t.Fatalf("unexpected header value, got %d", headerVal) } } From 7e3d250564ca0263f2673ced65adcfc191336a15 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Wed, 1 Sep 2021 16:13:44 -0500 Subject: [PATCH 10/14] fix bug in periodicfunc test case --- vault/identity_store_oidc_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 3b5e66e8c6d72..e3f70384c8ce5 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -754,9 +754,9 @@ func TestOIDC_PeriodicFunc(t *testing.T) { {1, 1, 1}, {2, 2, 2}, {3, 3, 3}, - {4, 2, 2}, + {4, 3, 3}, {5, 3, 3}, - {6, 2, 2}, + {6, 3, 3}, {7, 3, 3}, }, }, From c9fef0dc9118491caab87bde4eadfc1b04e048bd Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Thu, 2 Sep 2021 09:37:21 -0500 Subject: [PATCH 11/14] add changelog --- changelog/12414.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changelog/12414.txt diff --git a/changelog/12414.txt b/changelog/12414.txt new file mode 100644 index 0000000000000..31b3070b7f349 --- /dev/null +++ b/changelog/12414.txt @@ -0,0 +1,3 @@ +```release-note:bug +identity: fix issue where Cache-Control header causes stampede of requests for JWKS keys +``` From 9cadd12fb0e0e07f408555e22f4887b22cda4eee Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Fri, 3 Sep 2021 10:40:58 -0500 Subject: [PATCH 12/14] remove confusing comment --- vault/identity_store_oidc.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 539a73222fdd6..49b5488e7e602 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1252,9 +1252,7 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical }, } - // set a Cache-Control header only if there are keys, if there aren't keys - // then nextRun should not be used to set Cache-Control header because it chooses - // a time in the future that isn't based on key rotation/expiration values + // set a Cache-Control header only if there are keys keys, err := listOIDCPublicKeys(ctx, req.Storage) if err != nil { return nil, err From a39d7dfaf8a61a95fcaa734fff0a85262de72836 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Tue, 7 Sep 2021 11:15:30 -0500 Subject: [PATCH 13/14] add logging and comments --- vault/identity_store_oidc.go | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 49b5488e7e602..ea917dfbd929f 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1204,6 +1204,7 @@ func (i *IdentityStore) getKeysCacheControlHeader() (string, error) { now := time.Now() expireAt := nextRun.(time.Time) if expireAt.After(now) { + i.Logger().Debug("use nextRun value for Cache Control header", "nextRun", nextRun) expireInSeconds := expireAt.Sub(time.Now()).Seconds() return fmt.Sprintf("max-age=%.0f", expireInSeconds), nil } @@ -1636,21 +1637,16 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor return nextExpiration, nil } +// oidcKeyRotation will rotate any keys that are due to be rotated. +// +// It will return the time of the soonest rotation and the minimum +// verificationTTL or minimum rotationPeriod out of all the current keys. func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, time.Duration, error) { // soonestRotation will be the soonest rotation time of all keys. Initialize // here to a relatively distant time. now := time.Now() soonestRotation := now.Add(24 * time.Hour) - // the OIDC JWKS endpoint returns a Cache-Control HTTP header time between - // 0 and the minimum verificationTTL or minimum rotationPeriod out of all - // keys, whichever value is lower. - // - // This smooths calls from services validating JWTs to Vault, while - // ensuring that operators can assert that servers honoring the - // Cache-Control header will always have a superset of all valid keys, and - // not trust any keys longer than a jwksCacheControlMaxAge duration after a - // key is rotated out of signing use jwksClientCacheDuration := time.Duration(math.MaxInt64) i.oidcLock.Lock() @@ -1710,6 +1706,7 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) // oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key // rotations and expiration actions. func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { + i.Logger().Debug("begin oidcPeriodicFunc") var nextRun time.Time now := time.Now() @@ -1769,10 +1766,21 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { minJwksClientCacheDuration = jwksClientCacheDuration } } + if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil { i.Logger().Error("error setting oidc cache", "err", err) } + if minJwksClientCacheDuration < math.MaxInt64 { + // the OIDC JWKS endpoint returns a Cache-Control HTTP header time between + // 0 and the minimum verificationTTL or minimum rotationPeriod out of all + // keys, whichever value is lower. + // + // This smooths calls from services validating JWTs to Vault, while + // ensuring that operators can assert that servers honoring the + // Cache-Control header will always have a superset of all valid keys, and + // not trust any keys longer than a jwksCacheControlMaxAge duration after a + // key is rotated out of signing use if err := i.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", minJwksClientCacheDuration); err != nil { i.Logger().Error("error setting jwksCacheControlMaxAge in oidc cache", "err", err) } From f86b83ce6ad09f03fc95f1c811294966301e6d93 Mon Sep 17 00:00:00 2001 From: JM Faircloth Date: Thu, 9 Sep 2021 08:54:37 -0500 Subject: [PATCH 14/14] update change log from bug to improvement --- changelog/12414.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog/12414.txt b/changelog/12414.txt index 31b3070b7f349..5f3cfdd7324e9 100644 --- a/changelog/12414.txt +++ b/changelog/12414.txt @@ -1,3 +1,3 @@ -```release-note:bug +```release-note:improvement identity: fix issue where Cache-Control header causes stampede of requests for JWKS keys ```