Skip to content

Commit

Permalink
Identity: prepublish jwt signing keys (#12414)
Browse files Browse the repository at this point in the history
* pre-publish new signing keys for `rotation_period` of time before using

* Work In Progress: Prepublish JWKS and even cache control

* remove comments

* use math/rand instead of math/big

* update tests

* remove debug comment

* refactor cache control logic into func

* don't set expiry when create/update key

* update cachecontrol name in oidccache for test

* fix bug in periodicfunc test case

* add changelog

* remove confusing comment

* add logging and comments

* update change log from bug to improvement

Co-authored-by: Ian Ferguson <ian.ferguson@datadoghq.com>
  • Loading branch information
fairclothjm and ianferguson committed Sep 9, 2021
1 parent 22ea738 commit 7b49d57
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 47 deletions.
3 changes: 3 additions & 0 deletions changelog/12414.txt
@@ -0,0 +1,3 @@
```release-note:improvement
identity: fix issue where Cache-Control header causes stampede of requests for JWKS keys
```
148 changes: 116 additions & 32 deletions vault/identity_store_oidc.go
Expand Up @@ -10,6 +10,8 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
mathrand "math/rand"
"net/url"
"sort"
"strings"
Expand Down Expand Up @@ -50,6 +52,7 @@ type namedKey struct {
RotationPeriod time.Duration `json:"rotation_period"`
KeyRing []*expireableKey `json:"key_ring"`
SigningKey *jose.JSONWebKey `json:"signing_key"`
NextSigningKey *jose.JSONWebKey `json:"next_signing_key"`
NextRotation time.Time `json:"next_rotation"`
AllowedClientIDs []string `json:"allowed_client_ids"`
}
Expand Down Expand Up @@ -510,13 +513,15 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
return logical.ErrorResponse("unknown signing algorithm %q", key.Algorithm), nil
}

now := time.Now()

// Update next rotation time if it is unset or now earlier than previously set.
nextRotation := time.Now().Add(key.RotationPeriod)
nextRotation := now.Add(key.RotationPeriod)
if key.NextRotation.IsZero() || nextRotation.Before(key.NextRotation) {
key.NextRotation = nextRotation
}

// generate keys if creating a new key or changing algorithms
// generate current and next keys if creating a new key or changing algorithms
if key.Algorithm != prevAlgorithm {
signingKey, err := generateKeys(key.Algorithm)
if err != nil {
Expand All @@ -529,6 +534,20 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)

nextSigningKey, err := generateKeys(key.Algorithm)
if err != nil {
return nil, err
}

key.NextSigningKey = nextSigningKey
key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: nextSigningKey.Public().KeyID})

if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil {
return nil, err
}
i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)
}

if err := i.oidcCache.Flush(ns); err != nil {
Expand Down Expand Up @@ -727,7 +746,7 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
verificationTTLOverride = time.Duration(ttlRaw.(int)) * time.Second
}

if err := storedNamedKey.rotate(ctx, req.Storage, verificationTTLOverride); err != nil {
if err := storedNamedKey.rotate(ctx, i.Logger(), req.Storage, verificationTTLOverride); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1168,6 +1187,40 @@ func (i *IdentityStore) pathOIDCDiscovery(ctx context.Context, req *logical.Requ
return resp, nil
}

// getKeysCacheControlHeader returns the cache control header for all public
// keys at the .well-known/keys endpoint
func (i *IdentityStore) getKeysCacheControlHeader() (string, error) {
// if jwksCacheControlMaxAge is set use that, otherwise fall back on the
// more conservative nextRun values
jwksCacheControlMaxAge, ok, err := i.oidcCache.Get(noNamespace, "jwksCacheControlMaxAge")
if err != nil {
return "", err
}

if ok {
maxDuration := int64(jwksCacheControlMaxAge.(time.Duration))
randDuration := mathrand.Int63n(maxDuration)
durationInSeconds := time.Duration(randDuration).Seconds()
return fmt.Sprintf("max-age=%.0f", durationInSeconds), nil
}

nextRun, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
return "", err
}

if ok {
now := time.Now()
expireAt := nextRun.(time.Time)
if expireAt.After(now) {
i.Logger().Debug("use nextRun value for Cache Control header", "nextRun", nextRun)
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
return fmt.Sprintf("max-age=%.0f", expireInSeconds), nil
}
}
return "", nil
}

// pathOIDCReadPublicKeys is used to retrieve all public keys so that clients can
// verify the validity of a signed OIDC token.
func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
Expand Down Expand Up @@ -1209,27 +1262,19 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
},
}

// set a Cache-Control header only if there are keys, if there aren't keys
// then nextRun should not be used to set Cache-Control header because it chooses
// a time in the future that isn't based on key rotation/expiration values
// set a Cache-Control header only if there are keys
keys, err := listOIDCPublicKeys(ctx, req.Storage)
if err != nil {
return nil, err
}
if len(keys) > 0 {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
header, err := i.getKeysCacheControlHeader()
if err != nil {
return nil, err
}

if ok {
now := time.Now()
expireAt := v.(time.Time)
if expireAt.After(now) {
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds)
resp.Data[logical.HTTPRawCacheControl] = expireInString
}
if header != "" {
resp.Data[logical.HTTPRawCacheControl] = header
}
}

Expand Down Expand Up @@ -1326,36 +1371,37 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
return introspectionResp("")
}

// namedKey.rotate(overrides) performs a key rotation on a namedKey and returns the
// verification_ttl that was applied. verification_ttl can be overridden with an
// overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerificationTTL time.Duration) error {
// namedKey.rotate(overrides) performs a key rotation on a namedKey.
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
verificationTTL := k.VerificationTTL

if overrideVerificationTTL >= 0 {
verificationTTL = overrideVerificationTTL
}

// generate new key
signingKey, err := generateKeys(k.Algorithm)
nextSigningKey, err := generateKeys(k.Algorithm)
if err != nil {
return err
}
if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil {
if err := saveOIDCPublicKey(ctx, s, nextSigningKey.Public()); err != nil {
return err
}
logger.Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)

now := time.Now()

// set the previous public key's expiry time
for _, key := range k.KeyRing {
if key.KeyID == k.SigningKey.KeyID {
key.ExpireAt = now.Add(verificationTTL)
break
}
}
k.SigningKey = signingKey
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.KeyID})

k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: nextSigningKey.KeyID})
k.SigningKey = k.NextSigningKey
k.NextSigningKey = nextSigningKey
k.NextRotation = now.Add(k.RotationPeriod)

// store named key (it was modified when rotate was called on it)
Expand All @@ -1367,6 +1413,7 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
return err
}

logger.Debug("rotated OIDC public key, now using", "key_id", k.SigningKey.Public().KeyID)
return nil
}

Expand Down Expand Up @@ -1599,24 +1646,30 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
return nextExpiration, nil
}

func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, error) {
// oidcKeyRotation will rotate any keys that are due to be rotated.
//
// It will return the time of the soonest rotation and the minimum
// verificationTTL or minimum rotationPeriod out of all the current keys.
func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, time.Duration, error) {
// soonestRotation will be the soonest rotation time of all keys. Initialize
// here to a relatively distant time.
now := time.Now()
soonestRotation := now.Add(24 * time.Hour)

jwksClientCacheDuration := time.Duration(math.MaxInt64)

i.oidcLock.Lock()
defer i.oidcLock.Unlock()

keys, err := s.List(ctx, namedKeyConfigPath)
if err != nil {
return now, err
return now, jwksClientCacheDuration, err
}

for _, k := range keys {
entry, err := s.Get(ctx, namedKeyConfigPath+k)
if err != nil {
return now, err
return now, jwksClientCacheDuration, err
}

if entry == nil {
Expand All @@ -1625,10 +1678,18 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)

var key namedKey
if err := entry.DecodeJSON(&key); err != nil {
return now, err
return now, jwksClientCacheDuration, err
}
key.name = k

if key.VerificationTTL < jwksClientCacheDuration {
jwksClientCacheDuration = key.VerificationTTL
}

if key.RotationPeriod < jwksClientCacheDuration {
jwksClientCacheDuration = key.RotationPeriod
}

// Future key rotation that is the earliest we've seen.
if now.Before(key.NextRotation) && key.NextRotation.Before(soonestRotation) {
soonestRotation = key.NextRotation
Expand All @@ -1637,8 +1698,8 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
// Key that is due to be rotated.
if now.After(key.NextRotation) {
i.Logger().Debug("rotating OIDC key", "key", key.name)
if err := key.rotate(ctx, s, -1); err != nil {
return now, err
if err := key.rotate(ctx, i.Logger(), s, -1); err != nil {
return now, jwksClientCacheDuration, err
}

// Possibly save the new rotation time
Expand All @@ -1648,12 +1709,13 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
}
}

return soonestRotation, nil
return soonestRotation, jwksClientCacheDuration, nil
}

// oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key
// rotations and expiration actions.
func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
i.Logger().Debug("begin oidcPeriodicFunc")
var nextRun time.Time
now := time.Now()

Expand All @@ -1675,6 +1737,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
// Initialize to a fairly distant next run time. This will be brought in
// based on key rotation times.
nextRun = now.Add(24 * time.Hour)
minJwksClientCacheDuration := time.Duration(math.MaxInt64)

for _, ns := range i.namespacer.ListNamespaces() {
nsPath := ns.Path
Expand All @@ -1685,7 +1748,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
continue
}

nextRotation, err := i.oidcKeyRotation(ctx, s)
nextRotation, jwksClientCacheDuration, err := i.oidcKeyRotation(ctx, s)
if err != nil {
i.Logger().Warn("error rotating OIDC keys", "err", err)
}
Expand All @@ -1707,10 +1770,31 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
if nextExpiration.Before(nextRun) {
nextRun = nextExpiration
}

if jwksClientCacheDuration < minJwksClientCacheDuration {
minJwksClientCacheDuration = jwksClientCacheDuration
}
}

if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
i.Logger().Error("error setting oidc cache", "err", err)
}

if minJwksClientCacheDuration < math.MaxInt64 {
// the OIDC JWKS endpoint returns a Cache-Control HTTP header time between
// 0 and the minimum verificationTTL or minimum rotationPeriod out of all
// keys, whichever value is lower.
//
// This smooths calls from services validating JWTs to Vault, while
// ensuring that operators can assert that servers honoring the
// Cache-Control header will always have a superset of all valid keys, and
// not trust any keys longer than a jwksCacheControlMaxAge duration after a
// key is rotated out of signing use
if err := i.oidcCache.SetDefault(noNamespace, "jwksCacheControlMaxAge", minJwksClientCacheDuration); err != nil {
i.Logger().Error("error setting jwksCacheControlMaxAge in oidc cache", "err", err)
}
}

}
}

Expand Down

0 comments on commit 7b49d57

Please sign in to comment.