Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Identity: prepublish jwt signing keys #12414

Merged
merged 15 commits into from Sep 9, 2021
Merged
143 changes: 107 additions & 36 deletions vault/identity_store_oidc.go
Expand Up @@ -10,6 +10,8 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
mathrand "math/rand"
"net/url"
"sort"
"strings"
Expand Down Expand Up @@ -50,6 +52,7 @@ type namedKey struct {
RotationPeriod time.Duration `json:"rotation_period"`
KeyRing []*expireableKey `json:"key_ring"`
SigningKey *jose.JSONWebKey `json:"signing_key"`
NextSigningKey *jose.JSONWebKey `json:"next_signing_key"`
NextRotation time.Time `json:"next_rotation"`
AllowedClientIDs []string `json:"allowed_client_ids"`
}
Expand Down Expand Up @@ -509,25 +512,48 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica
return logical.ErrorResponse("unknown signing algorithm %q", key.Algorithm), nil
}

now := time.Now()

// Update next rotation time if it is unset or now earlier than previously set.
nextRotation := time.Now().Add(key.RotationPeriod)
nextRotation := now.Add(key.RotationPeriod)
if key.NextRotation.IsZero() || nextRotation.Before(key.NextRotation) {
key.NextRotation = nextRotation
}

// generate keys if creating a new key or changing algorithms
// generate current and next keys if creating a new key or changing algorithms
if key.Algorithm != prevAlgorithm {
signingKey, err := generateKeys(key.Algorithm)
if err != nil {
return nil, err
}

i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID)
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
key.SigningKey = signingKey
key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID})
key.KeyRing = append(key.KeyRing, &expireableKey{
KeyID: signingKey.Public().KeyID,
ExpireAt: now.Add(key.RotationPeriod).Add(key.VerificationTTL),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this means that we're calculating the ExpireAt by factoring the rotation period early on instead of just appending the VerificationTTL upon rotation like it currently is. What happens if there's an update on the named key's rotation_period after the key has been created (i.e. by an update)?

})

if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil {
return nil, err
}

nextSigningKey, err := generateKeys(key.Algorithm)
if err != nil {
return nil, err
}

i.Logger().Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)
key.NextSigningKey = nextSigningKey
key.KeyRing = append(key.KeyRing, &expireableKey{
KeyID: nextSigningKey.Public().KeyID,
ExpireAt: now.Add(key.RotationPeriod).Add(key.RotationPeriod).Add(key.VerificationTTL),
})

if err := saveOIDCPublicKey(ctx, req.Storage, nextSigningKey.Public()); err != nil {
return nil, err
}

}

if err := i.oidcCache.Flush(ns); err != nil {
Expand Down Expand Up @@ -711,7 +737,7 @@ func (i *IdentityStore) pathOIDCRotateKey(ctx context.Context, req *logical.Requ
verificationTTLOverride = time.Duration(ttlRaw.(int)) * time.Second
}

if err := storedNamedKey.rotate(ctx, req.Storage, verificationTTLOverride); err != nil {
if err := storedNamedKey.rotate(ctx, i.Logger(), req.Storage, verificationTTLOverride); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1193,18 +1219,34 @@ func (i *IdentityStore) pathOIDCReadPublicKeys(ctx context.Context, req *logical
return nil, err
}
if len(keys) > 0 {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
// if maxJwksClientCache is set use that, otherwise fall back on the more conservative
// nextRun values
maxJwksClientCache, ok, err := i.oidcCache.Get(noNamespace, "maxJwksClientCache")
if err != nil {
return nil, err
}

if ok {
now := time.Now()
expireAt := v.(time.Time)
if expireAt.After(now) {
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds)
resp.Data[logical.HTTPRawCacheControl] = expireInString
maxDuration := int64(maxJwksClientCache.(time.Duration))
randDuration := mathrand.Int63n(maxDuration)
// truncate to seconds
durationInSeconds := time.Duration(randDuration).Seconds()
durationInString := fmt.Sprintf("max-age=%.0f", durationInSeconds)
resp.Data[logical.HTTPRawCacheControl] = durationInString
} else {
v, ok, err := i.oidcCache.Get(noNamespace, "nextRun")
if err != nil {
return nil, err
}

if ok {
now := time.Now()
expireAt := v.(time.Time)
if expireAt.After(now) {
expireInSeconds := expireAt.Sub(time.Now()).Seconds()
expireInString := fmt.Sprintf("max-age=%.0f", expireInSeconds)
resp.Data[logical.HTTPRawCacheControl] = expireInString
}
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
Expand Down Expand Up @@ -1302,36 +1344,35 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req
return introspectionResp("")
}

// namedKey.rotate(overrides) performs a key rotation on a namedKey and returns the
// verification_ttl that was applied. verification_ttl can be overridden with an
// overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerificationTTL time.Duration) error {
// namedKey.rotate(overrides) performs a key rotation on a namedKey.
// verification_ttl can be overridden with an overrideVerificationTTL value >= 0
func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error {
verificationTTL := k.VerificationTTL

if overrideVerificationTTL >= 0 {
verificationTTL = overrideVerificationTTL
}

// generate new key
signingKey, err := generateKeys(k.Algorithm)
nextSigningKey, err := generateKeys(k.Algorithm)
if err != nil {
return err
}
if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil {
if err := saveOIDCPublicKey(ctx, s, nextSigningKey.Public()); err != nil {
return err
}
logger.Debug("generated OIDC public key for future use", "key_id", nextSigningKey.Public().KeyID)

now := time.Now()

// set the previous public key's expiry time
for _, key := range k.KeyRing {
if key.KeyID == k.SigningKey.KeyID {
key.ExpireAt = now.Add(verificationTTL)
break
}
}
k.SigningKey = signingKey
k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.KeyID})
k.KeyRing = append(k.KeyRing, &expireableKey{
KeyID: nextSigningKey.KeyID,
// this token won't start being used for 1 rotation period,
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
// will be used for 1 rotation period after that,
// and should be deleted verificationTTL after it is no longer used
ExpireAt: now.Add(k.RotationPeriod).Add(k.RotationPeriod).Add(verificationTTL),
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
})
k.SigningKey = k.NextSigningKey
k.NextSigningKey = nextSigningKey
k.NextRotation = now.Add(k.RotationPeriod)

// store named key (it was modified when rotate was called on it)
Expand All @@ -1343,6 +1384,7 @@ func (k *namedKey) rotate(ctx context.Context, s logical.Storage, overrideVerifi
return err
}

logger.Debug("rotated OIDC public key. now using", "key_id", k.SigningKey.Public().KeyID)
return nil
}

Expand Down Expand Up @@ -1575,24 +1617,35 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor
return nextExpiration, nil
}

func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (time.Time, error) {
func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) (nextRotation time.Time, maxJwksClientCacheDuration time.Duration, err error) {
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
// soonestRotation will be the soonest rotation time of all keys. Initialize
// here to a relatively distant time.
now := time.Now()
soonestRotation := now.Add(24 * time.Hour)

// the OIDC JWKS endpoint returns a Cache-Control HTTP header time
// between 0 and the minimum verificationTTL or minimum rotationPeriod out
// of all keys, whichever value is lower.
//
// This smooths calls from services validating JWTs to Vault, while
// ensuring that operators can assert that servers honoring the Cache-Control
// header will always have a superset of all valid keys, and not trust
// any keys longer than a jwksCacheControlMax duration after a key is rotated
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
// out of signing use
jwksClientCacheDuration := time.Duration(math.MaxInt64)

i.oidcLock.Lock()
defer i.oidcLock.Unlock()

keys, err := s.List(ctx, namedKeyConfigPath)
if err != nil {
return now, err
return now, jwksClientCacheDuration, err
}

for _, k := range keys {
entry, err := s.Get(ctx, namedKeyConfigPath+k)
if err != nil {
return now, err
return now, jwksClientCacheDuration, err
}

if entry == nil {
Expand All @@ -1601,10 +1654,18 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)

var key namedKey
if err := entry.DecodeJSON(&key); err != nil {
return now, err
return now, jwksClientCacheDuration, err
}
key.name = k

if key.VerificationTTL < jwksClientCacheDuration {
jwksClientCacheDuration = key.VerificationTTL
}

if key.RotationPeriod < jwksClientCacheDuration {
jwksClientCacheDuration = key.RotationPeriod
}

// Future key rotation that is the earliest we've seen.
if now.Before(key.NextRotation) && key.NextRotation.Before(soonestRotation) {
soonestRotation = key.NextRotation
Expand All @@ -1613,8 +1674,8 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
// Key that is due to be rotated.
if now.After(key.NextRotation) {
i.Logger().Debug("rotating OIDC key", "key", key.name)
if err := key.rotate(ctx, s, -1); err != nil {
return now, err
if err := key.rotate(ctx, i.Logger(), s, -1); err != nil {
return now, jwksClientCacheDuration, err
}

// Possibly save the new rotation time
Expand All @@ -1624,7 +1685,7 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
}
}

return soonestRotation, nil
return soonestRotation, jwksClientCacheDuration, nil
}

// oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key
Expand All @@ -1651,7 +1712,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
// Initialize to a fairly distant next run time. This will be brought in
// based on key rotation times.
nextRun = now.Add(24 * time.Hour)

minJwksClientCacheDuration := time.Duration(math.MaxInt64)
for _, ns := range i.listNamespaces() {
nsPath := ns.Path

Expand All @@ -1661,7 +1722,7 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
continue
}

nextRotation, err := i.oidcKeyRotation(ctx, s)
nextRotation, jwksClientCacheDuration, err := i.oidcKeyRotation(ctx, s)
if err != nil {
i.Logger().Warn("error rotating OIDC keys", "err", err)
}
Expand All @@ -1683,10 +1744,20 @@ func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
if nextExpiration.Before(nextRun) {
nextRun = nextExpiration
}

if jwksClientCacheDuration < minJwksClientCacheDuration {
minJwksClientCacheDuration = jwksClientCacheDuration
}
}
if err := i.oidcCache.SetDefault(noNamespace, "nextRun", nextRun); err != nil {
i.Logger().Error("error setting oidc cache", "err", err)
}
if minJwksClientCacheDuration < math.MaxInt64 {
if err := i.oidcCache.SetDefault(noNamespace, "maxJwksClientCache", minJwksClientCacheDuration); err != nil {
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
i.Logger().Error("error setting maxJwksClientCache in oidc cache", "err", err)
}
}

}
}

Expand Down
24 changes: 14 additions & 10 deletions vault/identity_store_oidc_test.go
Expand Up @@ -469,7 +469,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
Storage: storage,
})

// .well-known/keys should contain 1 public key
// .well-known/keys should contain 2 public keys
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/.well-known/keys",
Operation: logical.ReadOperation,
Expand All @@ -479,8 +479,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
// parse response
responseJWKS := &jose.JSONWebKeySet{}
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 1 {
t.Fatalf("expected 1 public key but instead got %d", len(responseJWKS.Keys))
if len(responseJWKS.Keys) != 2 {
t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys))
}

// rotate test-key a few times, each rotate should increase the length of public keys returned
Expand All @@ -498,7 +498,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
})
expectSuccess(t, resp, err)

// .well-known/keys should contain 3 public keys
// .well-known/keys should contain 4 public keys
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/.well-known/keys",
Operation: logical.ReadOperation,
Expand All @@ -507,8 +507,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
expectSuccess(t, resp, err)
// parse response
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 3 {
t.Fatalf("expected 3 public keys but instead got %d", len(responseJWKS.Keys))
if len(responseJWKS.Keys) != 4 {
t.Fatalf("expected 4 public keys but instead got %d", len(responseJWKS.Keys))
}

// create another named key
Expand All @@ -525,7 +525,7 @@ func TestOIDC_PublicKeys(t *testing.T) {
Storage: storage,
})

// .well-known/keys should contain 1 public key, all of the public keys
// .well-known/keys should contain 2 public key, all of the public keys
// from named key "test-key" should have been deleted
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/.well-known/keys",
Expand All @@ -535,8 +535,8 @@ func TestOIDC_PublicKeys(t *testing.T) {
expectSuccess(t, resp, err)
// parse response
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 1 {
t.Fatalf("expected 1 public keys but instead got %d", len(responseJWKS.Keys))
if len(responseJWKS.Keys) != 2 {
t.Fatalf("expected 2 public keys but instead got %d", len(responseJWKS.Keys))
}
}

Expand Down Expand Up @@ -699,6 +699,7 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
RotationPeriod: 1 * cyclePeriod,
KeyRing: nil,
SigningKey: jwk,
NextSigningKey: jwk,
NextRotation: time.Now(),
},
[]struct {
Expand All @@ -708,8 +709,11 @@ func TestOIDC_PeriodicFunc(t *testing.T) {
}{
{1, 1, 1},
{2, 2, 2},
{3, 2, 2},
{3, 3, 3},
{4, 2, 2},
{5, 3, 3},
{6, 2, 2},
{7, 3, 3},
},
},
}
Expand Down