Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Identity: prepublish jwt signing keys #12414

Merged
merged 15 commits into from Sep 9, 2021
Merged
3 changes: 3 additions & 0 deletions changelog/12414.txt
@@ -0,0 +1,3 @@
```release-note:improvement
identity: fix issue where Cache-Control header causes stampede of requests for JWKS keys
```
148 changes: 116 additions & 32 deletions vault/identity_store_oidc.go
Expand Up @@ -10,6 +10,8 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
mathrand "math/rand"
"net/url"
"sort"
"strings"
Expand Down Expand Up @@ -50,6 +52,7 @@ type namedKey struct {
RotationPeriod time.Duration `json:"rotation_period"`
KeyRing []*expireableKey `json:"key_ring"`
SigningKey *jose.JSONWebKey `json:"signing_key"`
NextSigningKey *jose.JSONWebKey `json:"next_signing_key"`
NextRotation time.Time `json:"next_rotation"`
AllowedClientIDs []string `json:"allowed_client_ids"`
}
Expand Down Expand Up @@ -509,13 +512,15 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
return logical.ErrorResponse("unknown signing algorithm %q", key.Algorithm), nil
}

now := time.Now()

// Update next rotation time if it is unset or now earlier than previously set.
nextRotation := time.Now().Add(key.RotationPeriod)
nextRotation := now.Add(key.RotationPeriod)
if key.NextRotation.IsZero() || nextRotation.Before(key.NextRotation) {
key.NextRotation = nextRotation
}

// generate keys if creating a new key or changing algorithms
// generate current and next keys if creating a new key or changing algorithms
if key.Algorithm != prevAlgorithm {
signingKey, err := generateKeys(key.Algorithm)
if err != nil {
Expand All @@ -528,6 +533,20 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)

nextSigningKey, err := generateKeys(key.Algorithm)
if err != nil {
return nil, err
}

key.NextSigningKey = nextSigningKey
key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: nextSigningKey.Public().KeyID})

if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)
}

if err := i.oidcCache.Flush(ns); err != nil {
Expand Down Expand Up @@ -726,7 +745,7 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
verificationTTLOverride = time.Duration(ttlRaw.(int)) * time.Second
}

if err := storedNamedKey.rotate(ctx, req.Storage, verificationTTLOverride); err != nil {
if err := storedNamedKey.rotate(ctx, i.Logger(), req.Storage, verificationTTLOverride); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1159,6 +1178,40 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return resp, nil
}

// getKeysCacheControlHeader returns the cache control header for all public
// keys at the .well-known/keys endpoint
func (i *IdentityStore) getKeysCacheControlHeader() (string, error) {
// if jwksCacheControlMaxAge is set use that, otherwise fall back on the
// more conservative nextRun values
jwksCacheControlMaxAge, ok, err := i.oidcCache.Get(noNamespace, "jwksCacheControlMaxAge")
if err != nil {
return "", err
}

if ok {
maxDuration := int64(jwksCacheControlMaxAge.(time.Duration))
randDuration := mathrand.Int63n(maxDuration)
durationInSeconds := time.Duration(randDuration).Seconds()
return fmt.Sprintf("max-age=%.0f", durationInSeconds), nil
}

nextRun, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
return "", err
}

if ok {
now := time.Now()
expireAt := nextRun.(time.Time)
if expireAt.After(now) {
i.Logger().Debug("use nextRun value for Cache Control header", "nextRun", nextRun)
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
return fmt.Sprintf("max-age=%.0f", expireInSeconds), nil
}
}
return "", nil
}

// pathOIDCReadPublicKeys is used to retrieve all public keys so that clients can
// verify the validity of a signed OIDC token.
func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
Expand Down Expand Up @@ -1200,27 +1253,19 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
},
}

// set a Cache-Control header only if there are keys, if there aren't keys
// then nextRun should not be used to set Cache-Control header because it chooses
// a time in the future that isn't based on key rotation/expiration values
// set a Cache-Control header only if there are keys
keys, err := listOIDCPublicKeys(ctx, req.Storage)
if err != nil {
return nil, err
}
if len(keys) > 0 {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
header, err := i.getKeysCacheControlHeader()
if err != nil {
return nil, err
}

if ok {
now := time.Now()
expireAt := v.(time.Time)
if expireAt.After(now) {
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds)
resp.Data[logical.HTTPRawCacheControl] = expireInString
}
if header != "" {
resp.Data[logical.HTTPRawCacheControl] = header
}
}

Expand Down Expand Up @@ -1317,36 +1362,37 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
return introspectionResp("")
}

// namedKey.rotate(overrides) performs a key rotation on a namedKey and returns the
// verification_ttl that was applied. verification_ttl can be overridden with an
// overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerificationTTL time.Duration) error {
// namedKey.rotate(overrides) performs a key rotation on a namedKey.
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
verificationTTL := k.VerificationTTL

if overrideVerificationTTL >= 0 {
verificationTTL = overrideVerificationTTL
}

// generate new key
signingKey, err := generateKeys(k.Algorithm)
nextSigningKey, err := generateKeys(k.Algorithm)
if err != nil {
return err
}
if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil {
if err := saveOIDCPublicKey(ctx, s, nextSigningKey.Public()); err != nil {
return err
}
logger.Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)

now := time.Now()

// set the previous public key's expiry time
for _, key := range k.KeyRing {
if key.KeyID == k.SigningKey.KeyID {
key.ExpireAt = now.Add(verificationTTL)
break
}
}
k.SigningKey = signingKey
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.KeyID})

k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: nextSigningKey.KeyID})
k.SigningKey = k.NextSigningKey
k.NextSigningKey = nextSigningKey
k.NextRotation = now.Add(k.RotationPeriod)

// store named key (it was modified when rotate was called on it)
Expand All @@ -1358,6 +1404,7 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
return err
}

logger.Debug("rotated OIDC public key, now using", "key_id", k.SigningKey.Public().KeyID)
return nil
}

Expand Down Expand Up @@ -1590,24 +1637,30 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
return nextExpiration, nil
}

func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, error) {
// oidcKeyRotation will rotate any keys that are due to be rotated.
//
// It will return the time of the soonest rotation and the minimum
// verificationTTL or minimum rotationPeriod out of all the current keys.
func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, time.Duration, error) {
// soonestRotation will be the soonest rotation time of all keys. Initialize
// here to a relatively distant time.
now := time.Now()
soonestRotation := now.Add(24 * time.Hour)

jwksClientCacheDuration := time.Duration(math.MaxInt64)

i.oidcLock.Lock()
defer i.oidcLock.Unlock()

keys, err := s.List(ctx, namedKeyConfigPath)
if err != nil {
return now, err
return now, jwksClientCacheDuration, err
}

for _, k := range keys {
entry, err := s.Get(ctx, namedKeyConfigPath+k)
if err != nil {
return now, err
return now, jwksClientCacheDuration, err
}

if entry == nil {
Expand All @@ -1616,10 +1669,18 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)

var key namedKey
if err := entry.DecodeJSON(&key); err != nil {
return now, err
return now, jwksClientCacheDuration, err
}
key.name = k

if key.VerificationTTL < jwksClientCacheDuration {
jwksClientCacheDuration = key.VerificationTTL
}

if key.RotationPeriod < jwksClientCacheDuration {
jwksClientCacheDuration = key.RotationPeriod
}

// Future key rotation that is the earliest we've seen.
if now.Before(key.NextRotation) && key.NextRotation.Before(soonestRotation) {
soonestRotation = key.NextRotation
Expand All @@ -1628,8 +1689,8 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
// Key that is due to be rotated.
if now.After(key.NextRotation) {
i.Logger().Debug("rotating OIDC key", "key", key.name)
if err := key.rotate(ctx, s, -1); err != nil {
return now, err
if err := key.rotate(ctx, i.Logger(), s, -1); err != nil {
return now, jwksClientCacheDuration, err
}

// Possibly save the new rotation time
Expand All @@ -1639,12 +1700,13 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
}
}

return soonestRotation, nil
return soonestRotation, jwksClientCacheDuration, nil
}

// oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key
// rotations and expiration actions.
func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
i.Logger().Debug("begin oidcPeriodicFunc")
var nextRun time.Time
now := time.Now()

Expand All @@ -1666,6 +1728,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
// Initialize to a fairly distant next run time. This will be brought in
// based on key rotation times.
nextRun = now.Add(24 * time.Hour)
minJwksClientCacheDuration := time.Duration(math.MaxInt64)

for _, ns := range i.namespacer.ListNamespaces() {
nsPath := ns.Path
Expand All @@ -1676,7 +1739,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
continue
}

nextRotation, err := i.oidcKeyRotation(ctx, s)
nextRotation, jwksClientCacheDuration, err := i.oidcKeyRotation(ctx, s)
if err != nil {
i.Logger().Warn("error rotating OIDC keys", "err", err)
}
Expand All @@ -1698,10 +1761,31 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
if nextExpiration.Before(nextRun) {
nextRun = nextExpiration
}

if jwksClientCacheDuration < minJwksClientCacheDuration {
minJwksClientCacheDuration = jwksClientCacheDuration
}
}

if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
i.Logger().Error("error setting oidc cache", "err", err)
}

if minJwksClientCacheDuration < math.MaxInt64 {
// the OIDC JWKS endpoint returns a Cache-Control HTTP header time between
// 0 and the minimum verificationTTL or minimum rotationPeriod out of all
// keys, whichever value is lower.
//
// This smooths calls from services validating JWTs to Vault, while
// ensuring that operators can assert that servers honoring the
// Cache-Control header will always have a superset of all valid keys, and
// not trust any keys longer than a jwksCacheControlMaxAge duration after a
// key is rotated out of signing use
if err := i.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", minJwksClientCacheDuration); err != nil {
i.Logger().Error("error setting jwksCacheControlMaxAge in oidc cache", "err", err)
}
}

}
}

Expand Down