Skip to content

Commit

Permalink
feature: OIDC keys endpoint (#12525)
Browse files Browse the repository at this point in the history
* add keys path and initial handler

* read provider public keys

* add test cases

* remove some debug logs

* update tests after merging main

* refactor list all clients

* refactor logic to collect Key IDs
  • Loading branch information
fairclothjm committed Sep 14, 2021
1 parent 0d8d454 commit b86d300
Show file tree
Hide file tree
Showing 3 changed files with 303 additions and 2 deletions.
1 change: 0 additions & 1 deletion vault/identity_store_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -1715,7 +1715,6 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage)
// oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key
// rotations and expiration actions.
func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) {
i.Logger().Debug("begin oidcPeriodicFunc")
var nextRun time.Time
now := time.Now()

Expand Down
159 changes: 159 additions & 0 deletions vault/identity_store_oidc_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/identitytpl"
"github.com/hashicorp/vault/sdk/logical"
"gopkg.in/square/go-jose.v2"
)

type assignment struct {
Expand Down Expand Up @@ -277,7 +278,165 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path {
HelpSynopsis: "Query OIDC configurations",
HelpDescription: "Query this path to retrieve the configured OIDC Issuer and Keys endpoints, response types, subject types, and signing algorithms used by the OIDC backend.",
},
{
Pattern: "oidc/provider/" + framework.GenericNameRegex("name") + "/.well-known/keys",
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Description: "Name of the provider",
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: i.pathOIDCReadProviderPublicKeys,
},
HelpSynopsis: "Retrieve public keys",
HelpDescription: "Returns the public portion of keys for a named OIDC provider. Clients can use them to validate the authenticity of an ID token.",
},
}
}

func (i *IdentityStore) listClients(ctx context.Context, s logical.Storage) ([]*client, error) {
clientNames, err := s.List(ctx, clientPath)
if err != nil {
return nil, err
}

var clients []*client
for _, name := range clientNames {
entry, err := s.Get(ctx, clientPath+name)
if err != nil {
return nil, err
}
if entry == nil {
continue
}

var client client
if err := entry.DecodeJSON(&client); err != nil {
return nil, err
}
clients = append(clients, &client)
}

return clients, nil
}

// TODO: load clients into memory (go-memdb) to look this up
func (i *IdentityStore) clientByID(ctx context.Context, s logical.Storage, id string) (*client, error) {
clients, err := i.listClients(ctx, s)
if err != nil {
return nil, err
}

for _, client := range clients {
if client.ClientID == id {
return client, nil
}
}

return nil, nil
}

// keyIDsReferencedByTargetClientIDs returns a slice of key IDs that are
// referenced by the clients' targetIDs.
// If targetIDs contains "*" then the IDs for all public keys are returned.
func (i *IdentityStore) keyIDsReferencedByTargetClientIDs(ctx context.Context, s logical.Storage, targetIDs []string) ([]string, error) {
keyNames := make(map[string]bool)

// Get all key names referenced by clients if wildcard "*" in target client IDs
if strutil.StrListContains(targetIDs, "*") {
clients, err := i.listClients(ctx, s)
if err != nil {
return nil, err
}

for _, client := range clients {
keyNames[client.Key] = true
}
}

// Otherwise, get the key names referenced by each target client ID
if len(keyNames) == 0 {
for _, clientID := range targetIDs {
client, err := i.clientByID(ctx, s, clientID)
if err != nil {
return nil, err
}

if client != nil {
keyNames[client.Key] = true
}
}
}

// Collect the key IDs
var keyIDs []string
for name, _ := range keyNames {
entry, err := s.Get(ctx, namedKeyConfigPath+name)
if err != nil {
return nil, err
}

var key namedKey
if err := entry.DecodeJSON(&key); err != nil {
return nil, err
}
for _, expirableKey := range key.KeyRing {
keyIDs = append(keyIDs, expirableKey.KeyID)
}
}
return keyIDs, nil
}

// pathOIDCReadProviderPublicKeys is used to retrieve all public keys for a
// named provider so that clients can verify the validity of a signed OIDC token.
func (i *IdentityStore) pathOIDCReadProviderPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
providerName := d.Get("name").(string)

var provider provider

providerEntry, err := req.Storage.Get(ctx, providerPath+providerName)
if err != nil {
return nil, err
}
if providerEntry == nil {
return nil, nil
}
if err := providerEntry.DecodeJSON(&provider); err != nil {
return nil, err
}

keyIDs, err := i.keyIDsReferencedByTargetClientIDs(ctx, req.Storage, provider.AllowedClientIDs)
if err != nil {
return nil, err
}

jwks := &jose.JSONWebKeySet{
Keys: make([]jose.JSONWebKey, 0, len(keyIDs)),
}

for _, keyID := range keyIDs {
key, err := loadOIDCPublicKey(ctx, req.Storage, keyID)
if err != nil {
return nil, err
}
jwks.Keys = append(jwks.Keys, *key)
}

data, err := json.Marshal(jwks)
if err != nil {
return nil, err
}

resp := &logical.Response{
Data: map[string]interface{}{
logical.HTTPStatusCode: 200,
logical.HTTPRawBody: data,
logical.HTTPContentType: "application/json",
},
}

return resp, nil
}

func (i *IdentityStore) pathOIDCProviderDiscovery(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
Expand Down
145 changes: 144 additions & 1 deletion vault/identity_store_oidc_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,151 @@ import (
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"gopkg.in/square/go-jose.v2"
)

// TestOIDC_Path_OIDC_ProviderReadPublicKey_ProviderDoesNotExist tests that the
// path can handle the read operation when the provider does not exist
func TestOIDC_Path_OIDC_ProviderReadPublicKey_ProviderDoesNotExist(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
ctx := namespace.RootContext(nil)
storage := &logical.InmemStorage{}

// Read "test-provider" .well-known keys
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/provider/test-provider/.well-known/keys",
Operation: logical.ReadOperation,
Storage: storage,
})
expectedResp := &logical.Response{}
if resp != expectedResp && err != nil {
t.Fatalf("expected empty response but got success; error:\n%v\nresp: %#v", err, resp)
}
}

// TestOIDC_Path_OIDC_ProviderReadPublicKey tests the provider .well-known
// keys endpoint read operations
func TestOIDC_Path_OIDC_ProviderReadPublicKey(t *testing.T) {
c, _, _ := TestCoreUnsealed(t)
ctx := namespace.RootContext(nil)
storage := &logical.InmemStorage{}

// Create a test key "test-key-1"
c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/key/test-key-1",
Operation: logical.CreateOperation,
Data: map[string]interface{}{
"verification_ttl": "2m",
"rotation_period": "2m",
},
Storage: storage,
})

// Create a test client "test-client-1"
c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/client/test-client-1",
Operation: logical.CreateOperation,
Storage: storage,
Data: map[string]interface{}{
"key": "test-key-1",
},
})

// get the clientID
resp, _ := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/client/test-client-1",
Operation: logical.ReadOperation,
Storage: storage,
})
clientID := resp.Data["client_id"].(string)

// Create a test provider "test-provider" and allow all client IDs -- should succeed
resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/provider/test-provider",
Operation: logical.CreateOperation,
Storage: storage,
Data: map[string]interface{}{
"issuer": "https://example.com:8200",
"allowed_client_ids": []string{"*"},
},
})
expectSuccess(t, resp, err)

// Read "test-provider" .well-known keys
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/provider/test-provider/.well-known/keys",
Operation: logical.ReadOperation,
Storage: storage,
})
expectSuccess(t, resp, err)

responseJWKS := &jose.JSONWebKeySet{}
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 2 {
t.Fatalf("expected 2 public key but instead got %d", len(responseJWKS.Keys))
}

// Create a test key "test-key-2"
c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/key/test-key-2",
Operation: logical.CreateOperation,
Data: map[string]interface{}{
"verification_ttl": "2m",
"rotation_period": "2m",
},
Storage: storage,
})

// Create a test client "test-client-2"
c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/client/test-client-2",
Operation: logical.CreateOperation,
Storage: storage,
Data: map[string]interface{}{
"key": "test-key-2",
},
})

// Read "test-provider" .well-known keys
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/provider/test-provider/.well-known/keys",
Operation: logical.ReadOperation,
Storage: storage,
})
expectSuccess(t, resp, err)

responseJWKS = &jose.JSONWebKeySet{}
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 4 {
t.Fatalf("expected 4 public key but instead got %d", len(responseJWKS.Keys))
}

// Update the test provider "test-provider" to only allow test-client-1 -- should succeed
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/provider/test-provider",
Operation: logical.UpdateOperation,
Storage: storage,
Data: map[string]interface{}{
"allowed_client_ids": []string{clientID},
},
})
expectSuccess(t, resp, err)

// Read "test-provider" .well-known keys
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/provider/test-provider/.well-known/keys",
Operation: logical.ReadOperation,
Storage: storage,
})
expectSuccess(t, resp, err)

responseJWKS = &jose.JSONWebKeySet{}
json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS)
if len(responseJWKS.Keys) != 2 {
t.Fatalf("expected 2 public key but instead got %d", len(responseJWKS.Keys))
}
}

// TestOIDC_Path_OIDC_ProviderClient_NoKeyParameter tests that a client cannot
// be created without a key parameter
func TestOIDC_Path_OIDC_ProviderClient_NoKeyParameter(t *testing.T) {
Expand Down Expand Up @@ -97,7 +240,7 @@ func TestOIDC_Path_OIDC_ProviderClient_UpdateKey(t *testing.T) {
})
expectSuccess(t, resp, err)

// Create a test client "test-client" -- should fail
// Update the test client "test-client" -- should fail
resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{
Path: "oidc/client/test-client",
Operation: logical.UpdateOperation,
Expand Down

0 comments on commit b86d300

Please sign in to comment.