From 1f69a4b0d43b381a5dc7c4f9f590ba1c174efc96 Mon Sep 17 00:00:00 2001 From: John-Michael Faircloth Date: Thu, 9 Sep 2021 13:47:42 -0500 Subject: [PATCH] Identity: prepublish jwt signing keys (#12414) * pre-publish new signing keys for `rotation_period` of time before using * Work In Progress: Prepublish JWKS and even cache control * remove comments * use math/rand instead of math/big * update tests * remove debug comment * refactor cache control logic into func * don't set expiry when create/update key * update cachecontrol name in oidccache for test * fix bug in periodicfunc test case * add changelog * remove confusing comment * add logging and comments * update change log from bug to improvement Co-authored-by: Ian Ferguson --- changelog/12414.txt | 3 + vault/identity_store_oidc.go | 148 +++++++++++++++++++++++------- vault/identity_store_oidc_test.go | 100 +++++++++++++++++--- 3 files changed, 204 insertions(+), 47 deletions(-) create mode 100644 changelog/12414.txt diff --git a/changelog/12414.txt b/changelog/12414.txt new file mode 100644 index 0000000000000..5f3cfdd7324e9 --- /dev/null +++ b/changelog/12414.txt @@ -0,0 +1,3 @@ +```release-note:improvement +identity: fix issue where Cache-Control header causes stampede of requests for JWKS keys +``` diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index cb2a73745f6d4..2f3e8b3ad19dc 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -10,6 +10,8 @@ import ( "encoding/json" "errors" "fmt" + "math" + mathrand "math/rand" "net/url" "sort" "strings" @@ -50,6 +52,7 @@ type namedKey struct { RotationPeriod time.Duration `json:"rotation_period"` KeyRing []*expireableKey `json:"key_ring"` SigningKey *jose.JSONWebKey `json:"signing_key"` + NextSigningKey *jose.JSONWebKey `json:"next_signing_key"` NextRotation time.Time `json:"next_rotation"` AllowedClientIDs []string `json:"allowed_client_ids"` } @@ -510,13 +513,15 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica return logical.ErrorResponse("unknown signing algorithm %q", key.Algorithm), nil } + now := time.Now() + // Update next rotation time if it is unset or now earlier than previously set. - nextRotation := time.Now().Add(key.RotationPeriod) + nextRotation := now.Add(key.RotationPeriod) if key.NextRotation.IsZero() || nextRotation.Before(key.NextRotation) { key.NextRotation = nextRotation } - // generate keys if creating a new key or changing algorithms + // generate current and next keys if creating a new key or changing algorithms if key.Algorithm != prevAlgorithm { signingKey, err := generateKeys(key.Algorithm) if err != nil { @@ -529,6 +534,20 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil { return nil, err } + i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID) + + nextSigningKey, err := generateKeys(key.Algorithm) + if err != nil { + return nil, err + } + + key.NextSigningKey = nextSigningKey + key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: nextSigningKey.Public().KeyID}) + + if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil { + return nil, err + } + i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID) } if err := i.oidcCache.Flush(ns); err != nil { @@ -727,7 +746,7 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ verificationTTLOverride = time.Duration(ttlRaw.(int)) * time.Second } - if err := storedNamedKey.rotate(ctx, req.Storage, verificationTTLOverride); err != nil { + if err := storedNamedKey.rotate(ctx, i.Logger(), req.Storage, verificationTTLOverride); err != nil { return nil, err } @@ -1168,6 +1187,40 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ return resp, nil } +// getKeysCacheControlHeader returns the cache control header for all public +// keys at the .well-known/keys endpoint +func (i *IdentityStore) getKeysCacheControlHeader() (string, error) { + // if jwksCacheControlMaxAge is set use that, otherwise fall back on the + // more conservative nextRun values + jwksCacheControlMaxAge, ok, err := i.oidcCache.Get(noNamespace, "jwksCacheControlMaxAge") + if err != nil { + return "", err + } + + if ok { + maxDuration := int64(jwksCacheControlMaxAge.(time.Duration)) + randDuration := mathrand.Int63n(maxDuration) + durationInSeconds := time.Duration(randDuration).Seconds() + return fmt.Sprintf("max-age=%.0f", durationInSeconds), nil + } + + nextRun, ok, err := i.oidcCache.Get(noNamespace, "nextRun") + if err != nil { + return "", err + } + + if ok { + now := time.Now() + expireAt := nextRun.(time.Time) + if expireAt.After(now) { + i.Logger().Debug("use nextRun value for Cache Control header", "nextRun", nextRun) + expireInSeconds := expireAt.Sub(time.Now()).Seconds() + return fmt.Sprintf("max-age=%.0f", expireInSeconds), nil + } + } + return "", nil +} + // pathOIDCReadPublicKeys is used to retrieve all public keys so that clients can // verify the validity of a signed OIDC token. func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { @@ -1209,27 +1262,19 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical }, } - // set a Cache-Control header only if there are keys, if there aren't keys - // then nextRun should not be used to set Cache-Control header because it chooses - // a time in the future that isn't based on key rotation/expiration values + // set a Cache-Control header only if there are keys keys, err := listOIDCPublicKeys(ctx, req.Storage) if err != nil { return nil, err } if len(keys) > 0 { - v, ok, err := i.oidcCache.Get(noNamespace, "nextRun") + header, err := i.getKeysCacheControlHeader() if err != nil { return nil, err } - if ok { - now := time.Now() - expireAt := v.(time.Time) - if expireAt.After(now) { - expireInSeconds := expireAt.Sub(time.Now()).Seconds() - expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds) - resp.Data[logical.HTTPRawCacheControl] = expireInString - } + if header != "" { + resp.Data[logical.HTTPRawCacheControl] = header } } @@ -1326,10 +1371,9 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req return introspectionResp("") } -// namedKey.rotate(overrides) performs a key rotation on a namedKey and returns the -// verification_ttl that was applied. verification_ttl can be overridden with an -// overrideVerificationTTL value >= 0 -func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerificationTTL time.Duration) error { +// namedKey.rotate(overrides) performs a key rotation on a namedKey. +// verification_ttl can be overridden with an overrideVerificationTTL value >= 0 +func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error { verificationTTL := k.VerificationTTL if overrideVerificationTTL >= 0 { @@ -1337,16 +1381,16 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi } // generate new key - signingKey, err := generateKeys(k.Algorithm) + nextSigningKey, err := generateKeys(k.Algorithm) if err != nil { return err } - if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil { + if err := saveOIDCPublicKey(ctx, s, nextSigningKey.Public()); err != nil { return err } + logger.Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID) now := time.Now() - // set the previous public key's expiry time for _, key := range k.KeyRing { if key.KeyID == k.SigningKey.KeyID { @@ -1354,8 +1398,10 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi break } } - k.SigningKey = signingKey - k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.KeyID}) + + k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: nextSigningKey.KeyID}) + k.SigningKey = k.NextSigningKey + k.NextSigningKey = nextSigningKey k.NextRotation = now.Add(k.RotationPeriod) // store named key (it was modified when rotate was called on it) @@ -1367,6 +1413,7 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi return err } + logger.Debug("rotated OIDC public key, now using", "key_id", k.SigningKey.Public().KeyID) return nil } @@ -1599,24 +1646,30 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor return nextExpiration, nil } -func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, error) { +// oidcKeyRotation will rotate any keys that are due to be rotated. +// +// It will return the time of the soonest rotation and the minimum +// verificationTTL or minimum rotationPeriod out of all the current keys. +func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, time.Duration, error) { // soonestRotation will be the soonest rotation time of all keys. Initialize // here to a relatively distant time. now := time.Now() soonestRotation := now.Add(24 * time.Hour) + jwksClientCacheDuration := time.Duration(math.MaxInt64) + i.oidcLock.Lock() defer i.oidcLock.Unlock() keys, err := s.List(ctx, namedKeyConfigPath) if err != nil { - return now, err + return now, jwksClientCacheDuration, err } for _, k := range keys { entry, err := s.Get(ctx, namedKeyConfigPath+k) if err != nil { - return now, err + return now, jwksClientCacheDuration, err } if entry == nil { @@ -1625,10 +1678,18 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) var key namedKey if err := entry.DecodeJSON(&key); err != nil { - return now, err + return now, jwksClientCacheDuration, err } key.name = k + if key.VerificationTTL < jwksClientCacheDuration { + jwksClientCacheDuration = key.VerificationTTL + } + + if key.RotationPeriod < jwksClientCacheDuration { + jwksClientCacheDuration = key.RotationPeriod + } + // Future key rotation that is the earliest we've seen. if now.Before(key.NextRotation) && key.NextRotation.Before(soonestRotation) { soonestRotation = key.NextRotation @@ -1637,8 +1698,8 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) // Key that is due to be rotated. if now.After(key.NextRotation) { i.Logger().Debug("rotating OIDC key", "key", key.name) - if err := key.rotate(ctx, s, -1); err != nil { - return now, err + if err := key.rotate(ctx, i.Logger(), s, -1); err != nil { + return now, jwksClientCacheDuration, err } // Possibly save the new rotation time @@ -1648,12 +1709,13 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) } } - return soonestRotation, nil + return soonestRotation, jwksClientCacheDuration, nil } // oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key // rotations and expiration actions. func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { + i.Logger().Debug("begin oidcPeriodicFunc") var nextRun time.Time now := time.Now() @@ -1675,6 +1737,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { // Initialize to a fairly distant next run time. This will be brought in // based on key rotation times. nextRun = now.Add(24 * time.Hour) + minJwksClientCacheDuration := time.Duration(math.MaxInt64) for _, ns := range i.namespacer.ListNamespaces() { nsPath := ns.Path @@ -1685,7 +1748,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { continue } - nextRotation, err := i.oidcKeyRotation(ctx, s) + nextRotation, jwksClientCacheDuration, err := i.oidcKeyRotation(ctx, s) if err != nil { i.Logger().Warn("error rotating OIDC keys", "err", err) } @@ -1707,10 +1770,31 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { if nextExpiration.Before(nextRun) { nextRun = nextExpiration } + + if jwksClientCacheDuration < minJwksClientCacheDuration { + minJwksClientCacheDuration = jwksClientCacheDuration + } } + if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil { i.Logger().Error("error setting oidc cache", "err", err) } + + if minJwksClientCacheDuration < math.MaxInt64 { + // the OIDC JWKS endpoint returns a Cache-Control HTTP header time between + // 0 and the minimum verificationTTL or minimum rotationPeriod out of all + // keys, whichever value is lower. + // + // This smooths calls from services validating JWTs to Vault, while + // ensuring that operators can assert that servers honoring the + // Cache-Control header will always have a superset of all valid keys, and + // not trust any keys longer than a jwksCacheControlMaxAge duration after a + // key is rotated out of signing use + if err := i.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", minJwksClientCacheDuration); err != nil { + i.Logger().Error("error setting jwksCacheControlMaxAge in oidc cache", "err", err) + } + } + } } diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index f916cd97ecefb..1e4088fff6269 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -4,6 +4,8 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" + "strconv" + "strings" "testing" "time" @@ -626,7 +628,7 @@ func TestOIDC_PublicKeys(t *testing.T) { Storage: storage, }) - // .well-known/keys should contain 1 public key + // .well-known/keys should contain 2 public keys resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/.well-known/keys", Operation: logical.ReadOperation, @@ -636,8 +638,8 @@ func TestOIDC_PublicKeys(t *testing.T) { // parse response responseJWKS := &jose.JSONWebKeySet{} json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 1 { - t.Fatalf("expected 1 public key but instead got %d", len(responseJWKS.Keys)) + if len(responseJWKS.Keys) != 2 { + t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys)) } // rotate test-key a few times, each rotate should increase the length of public keys returned @@ -655,7 +657,7 @@ func TestOIDC_PublicKeys(t *testing.T) { }) expectSuccess(t, resp, err) - // .well-known/keys should contain 3 public keys + // .well-known/keys should contain 4 public keys resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/.well-known/keys", Operation: logical.ReadOperation, @@ -664,8 +666,8 @@ func TestOIDC_PublicKeys(t *testing.T) { expectSuccess(t, resp, err) // parse response json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 3 { - t.Fatalf("expected 3 public keys but instead got %d", len(responseJWKS.Keys)) + if len(responseJWKS.Keys) != 4 { + t.Fatalf("expected 4 public keys but instead got %d", len(responseJWKS.Keys)) } // create another named key @@ -682,7 +684,7 @@ func TestOIDC_PublicKeys(t *testing.T) { Storage: storage, }) - // .well-known/keys should contain 1 public key, all of the public keys + // .well-known/keys should contain 2 public key, all of the public keys // from named key "test-key" should have been deleted resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/.well-known/keys", @@ -692,8 +694,8 @@ func TestOIDC_PublicKeys(t *testing.T) { expectSuccess(t, resp, err) // parse response json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - if len(responseJWKS.Keys) != 1 { - t.Fatalf("expected 1 public keys but instead got %d", len(responseJWKS.Keys)) + if len(responseJWKS.Keys) != 2 { + t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys)) } } @@ -814,10 +816,18 @@ func TestOIDC_SignIDToken(t *testing.T) { responseJWKS := &jose.JSONWebKeySet{} json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) - // Validate the signature - claims := &jwt.Claims{} - if err := parsedToken.Claims(responseJWKS.Keys[0], claims); err != nil { - t.Fatalf("unable to validate signed token, err:\n%#v", err) + keyCount := len(responseJWKS.Keys) + errorCount := 0 + for _, key := range responseJWKS.Keys { + // Validate the signature + claims := &jwt.Claims{} + if err := parsedToken.Claims(key, claims); err != nil { + t.Logf("unable to validate signed token, err:\n%#v", err) + errorCount += 1 + } + } + if errorCount == keyCount { + t.Fatalf("unable to validate signed token with any of the .well-known keys") } } @@ -856,6 +866,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) { RotationPeriod: 1 * cyclePeriod, KeyRing: nil, SigningKey: jwk, + NextSigningKey: jwk, NextRotation: time.Now(), }, []struct { @@ -865,8 +876,11 @@ func TestOIDC_PeriodicFunc(t *testing.T) { }{ {1, 1, 1}, {2, 2, 2}, - {3, 2, 2}, - {4, 2, 2}, + {3, 3, 3}, + {4, 3, 3}, + {5, 3, 3}, + {6, 3, 3}, + {7, 3, 3}, }, }, } @@ -1368,6 +1382,62 @@ func TestOIDC_CacheNamespaceNilCheck(t *testing.T) { } } +func TestOIDC_GetKeysCacheControlHeader(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + + // get default value + header, err := c.identityStore.getKeysCacheControlHeader() + if err != nil { + t.Fatalf("expected success, got error:\n%v", err) + } + + expectedHeader := "" + if header != expectedHeader { + t.Fatalf("expected %s, got %s", expectedHeader, header) + } + + // set nextRun + nextRun := time.Now().Add(24 * time.Hour) + if err = c.identityStore.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil { + t.Fatal(err) + } + + header, err = c.identityStore.getKeysCacheControlHeader() + if err != nil { + t.Fatalf("expected success, got error:\n%v", err) + } + + expectedNextRun := "max-age=86400" + if header != expectedNextRun { + t.Fatalf("expected %s, got %s", expectedNextRun, header) + } + + // set jwksCacheControlMaxAge + jwksCacheControlMaxAge := time.Duration(60) * time.Second + if err = c.identityStore.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", jwksCacheControlMaxAge); err != nil { + t.Fatal(err) + } + + header, err = c.identityStore.getKeysCacheControlHeader() + if err != nil { + t.Fatalf("expected success, got error:\n%v", err) + } + + if header == "" { + t.Fatalf("expected header to be set, got %s", header) + } + + maxAgeValue := strings.Split(header, "=")[1] + headerVal, err := strconv.Atoi(maxAgeValue) + if err != nil { + t.Fatal(err) + } + // headerVal will be a random value between 0 and jwksCacheControlMaxAge + if headerVal > int(jwksCacheControlMaxAge) { + t.Fatalf("unexpected header value, got %d", headerVal) + } +} + // some helpers func expectSuccess(t *testing.T, resp *logical.Response, err error) { t.Helper()