From b5fa247ab3b17206c4d47982ecebb60e56311eb0 Mon Sep 17 00:00:00 2001 From: Matt Greenfield Date: Mon, 15 Mar 2021 09:10:12 -0600 Subject: [PATCH 1/8] Use the MS Graph API for atomic add/remove password operations Azure Active Directory Graph API, now deprecated, does not provide support for atomically creating/removing passwords on an application. As a result, there is a race conditions that can occur when creds are being created for roles configured with an existing service principal that is configured on multiple mounts or across multiple Vault clusters. Unfortunately, [`Azure/azure-sdk-for-go`](https://github.com/Azure/azure-sdk-for-go) does not yet offer a MS Graph API client, therefore, this PR utilizes [`Azure/go-autorest`](https://github.com/Azure/go-autorest) to construct a client the same as [`Azure/azure-sdk-for-go`](https://github.com/Azure/azure-sdk-for-go). This changeset preserves using the AAD Graph API by default but provides a mount configuration option for toggling to the new MS Graph API. This is because the two APIs require different API permissions. This allows users to upgrade to the new plugin version and then switch to the new API. Additionally, although using the MS Graph API is a net benefit, it itself has reliability issues when handling multiple requests in parallel. More details can be found in https://github.com/mdgreenfield/microsoft-graph-api-reliability and I am working with Microsoft to try to get some of these reliability issues resolved. Fixes #58 --- backend.go | 12 +- backend_test.go | 72 ++++--- client.go | 88 ++------- graph_api_client.go | 334 +++++++++++++++++++++++++++++++++ path_config.go | 18 +- path_config_test.go | 31 +-- path_service_principal.go | 2 +- path_service_principal_test.go | 12 +- provider.go | 209 ++++++++++++++++++--- 9 files changed, 613 insertions(+), 165 deletions(-) create mode 100644 graph_api_client.go diff --git a/backend.go b/backend.go index 0e647353..07bac01e 100644 --- a/backend.go +++ b/backend.go @@ -14,7 +14,7 @@ import ( type azureSecretBackend struct { *framework.Backend - getProvider func(*clientSettings) (AzureProvider, error) + getProvider func(*clientSettings, bool, passwords) (AzureProvider, error) client *client settings *clientSettings lock sync.RWMutex @@ -121,16 +121,16 @@ func (b *azureSecretBackend) getClient(ctx context.Context, s logical.Storage) ( b.settings = settings } - p, err := b.getProvider(b.settings) - if err != nil { - return nil, err - } - passwords := passwords{ policyGenerator: b.System(), policyName: config.PasswordPolicy, } + p, err := b.getProvider(b.settings, config.UseMsGraphAPI, passwords) + if err != nil { + return nil, err + } + c := &client{ provider: p, settings: b.settings, diff --git a/backend_test.go b/backend_test.go index 9c7f5453..6ecb36c1 100644 --- a/backend_test.go +++ b/backend_test.go @@ -6,6 +6,7 @@ import ( "fmt" "regexp" "strings" + "sync" "testing" "time" @@ -13,6 +14,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/to" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/sdk/helper/logging" @@ -44,7 +46,7 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical b.settings = new(clientSettings) mockProvider := newMockProvider() - b.getProvider = func(s *clientSettings) (AzureProvider, error) { + b.getProvider = func(s *clientSettings, usMsGraphApi bool, p passwords) (AzureProvider, error) { return mockProvider, nil } @@ -69,8 +71,9 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical type mockProvider struct { subscriptionID string applications map[string]bool - passwords map[string]bool + passwords map[string]passwordCredential failNextCreateApplication bool + lock sync.Mutex } // errMockProvider simulates a normal provider which fails to associate a role, @@ -88,15 +91,15 @@ func (e *errMockProvider) CreateRoleAssignment(ctx context.Context, scope string // key is found, unlike mockProvider which returns the same application object // id each time. Existing tests depend on the mockProvider behavior, which is // why errMockProvider has it's own version. -func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) { +func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (ApplicationResult, error) { for s := range e.applications { if s == applicationObjectID { - return graphrbac.Application{ + return ApplicationResult{ AppID: to.StringPtr(s), }, nil } } - return graphrbac.Application{}, errors.New("not found") + return ApplicationResult{}, errors.New("not found") } func newErrMockProvider() AzureProvider { @@ -104,7 +107,7 @@ func newErrMockProvider() AzureProvider { mockProvider: &mockProvider{ subscriptionID: generateUUID(), applications: make(map[string]bool), - passwords: make(map[string]bool), + passwords: make(map[string]passwordCredential), }, } } @@ -113,7 +116,7 @@ func newMockProvider() AzureProvider { return &mockProvider{ subscriptionID: generateUUID(), applications: make(map[string]bool), - passwords: make(map[string]bool), + passwords: make(map[string]passwordCredential), } } @@ -174,22 +177,26 @@ func (m *mockProvider) CreateServicePrincipal(ctx context.Context, parameters gr }, nil } -func (m *mockProvider) CreateApplication(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) { +func (m *mockProvider) CreateApplication(ctx context.Context, displayName string) (ApplicationResult, error) { if m.failNextCreateApplication { m.failNextCreateApplication = false - return graphrbac.Application{}, errors.New("Mock: fail to create application") + return ApplicationResult{}, errors.New("Mock: fail to create application") } appObjID := generateUUID() + + m.lock.Lock() + defer m.lock.Unlock() + m.applications[appObjID] = true - return graphrbac.Application{ - AppID: to.StringPtr(generateUUID()), - ObjectID: &appObjID, + return ApplicationResult{ + AppID: to.StringPtr(generateUUID()), + ID: &appObjID, }, nil } -func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) { - return graphrbac.Application{ +func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (ApplicationResult, error) { + return ApplicationResult{ AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"), }, nil } @@ -199,24 +206,32 @@ func (m *mockProvider) DeleteApplication(ctx context.Context, applicationObjectI return autorest.Response{}, nil } -func (m *mockProvider) UpdateApplicationPasswordCredentials(ctx context.Context, applicationObjectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (result autorest.Response, err error) { - m.passwords = make(map[string]bool) - for _, v := range *parameters.Value { - m.passwords[*v.KeyID] = true +func (m *mockProvider) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { + keyID := generateUUID() + cred := passwordCredential{ + DisplayName: to.StringPtr(displayName), + StartDate: &date.Time{Time: time.Now()}, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + SecretText: to.StringPtr(generateUUID()), } - return autorest.Response{}, nil + m.lock.Lock() + defer m.lock.Unlock() + m.passwords[keyID] = cred + + return PasswordCredentialResult{ + passwordCredential: cred, + }, nil } -func (m *mockProvider) ListApplicationPasswordCredentials(ctx context.Context, applicationObjectID string) (result graphrbac.PasswordCredentialListResult, err error) { - var creds []graphrbac.PasswordCredential - for keyID := range m.passwords { - creds = append(creds, graphrbac.PasswordCredential{KeyID: &keyID}) - } +func (m *mockProvider) RemoveApplicationPassword(background context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + m.lock.Lock() + defer m.lock.Unlock() - return graphrbac.PasswordCredentialListResult{ - Value: &creds, - }, nil + delete(m.passwords, keyID) + + return autorest.Response{}, nil } func (m *mockProvider) appExists(s string) bool { @@ -224,7 +239,8 @@ func (m *mockProvider) appExists(s string) bool { } func (m *mockProvider) passwordExists(s string) bool { - return m.passwords[s] + _, ok := m.passwords[s] + return ok } func (m *mockProvider) VMGet(ctx context.Context, resourceGroupName string, VMName string, expand compute.InstanceViewTypes) (result compute.VirtualMachine, err error) { diff --git a/client.go b/client.go index 0e0be999..56095a6b 100644 --- a/client.go +++ b/client.go @@ -44,7 +44,7 @@ func (c *client) Valid() bool { // createApp creates a new Azure application. // An Application is a needed to create service principals used by // the caller for authentication. -func (c *client) createApp(ctx context.Context) (app *graphrbac.Application, err error) { +func (c *client) createApp(ctx context.Context) (app *ApplicationResult, err error) { name, err := uuid.GenerateUUID() if err != nil { return nil, err @@ -52,14 +52,7 @@ func (c *client) createApp(ctx context.Context) (app *graphrbac.Application, err name = appNamePrefix + name - appURL := fmt.Sprintf("https://%s", name) - - result, err := c.provider.CreateApplication(ctx, graphrbac.ApplicationCreateParameters{ - AvailableToOtherTenants: to.BoolPtr(false), - DisplayName: to.StringPtr(name), - Homepage: to.StringPtr(appURL), - IdentifierUris: to.StringSlicePtr([]string{appURL}), - }) + result, err := c.provider.CreateApplication(ctx, name) return &result, err } @@ -67,7 +60,7 @@ func (c *client) createApp(ctx context.Context) (app *graphrbac.Application, err // createSP creates a new service principal. func (c *client) createSP( ctx context.Context, - app *graphrbac.Application, + app *ApplicationResult, duration time.Duration) (svcPrinc *graphrbac.ServicePrincipal, password string, err error) { // Generate a random key (which must be a UUID) and password @@ -114,85 +107,26 @@ func (c *client) createSP( } // addAppPassword adds a new password to an App's credentials list. -func (c *client) addAppPassword(ctx context.Context, appObjID string, duration time.Duration) (keyID string, password string, err error) { - keyID, err = uuid.GenerateUUID() - if err != nil { - return "", "", err - } - - // Key IDs are not secret, and they're a convenient way for an operator to identify Vault-generated - // passwords. These must be UUIDs, so the three leading bytes will be used as an indicator. - keyID = "ffffff" + keyID[6:] - - password, err = c.passwords.generate(ctx) - if err != nil { - return "", "", err - } - - now := time.Now().UTC() - cred := graphrbac.PasswordCredential{ - StartDate: &date.Time{Time: now}, - EndDate: &date.Time{Time: now.Add(duration)}, - KeyID: to.StringPtr(keyID), - Value: to.StringPtr(password), - } - - // Load current credentials - resp, err := c.provider.ListApplicationPasswordCredentials(ctx, appObjID) +func (c *client) addAppPassword(ctx context.Context, appObjID string, expiresIn time.Duration) (string, string, error) { + exp := date.Time{Time: time.Now().Add(expiresIn)} + resp, err := c.provider.AddApplicationPassword(ctx, appObjID, "vault-plugin-secrets-azure", exp) if err != nil { - return "", "", errwrap.Wrapf("error fetching credentials: {{err}}", err) - } - curCreds := *resp.Value - - // Add and save credentials - curCreds = append(curCreds, cred) - - if _, err := c.provider.UpdateApplicationPasswordCredentials(ctx, appObjID, - graphrbac.PasswordCredentialsUpdateParameters{ - Value: &curCreds, - }, - ); err != nil { if strings.Contains(err.Error(), "size of the object has exceeded its limit") { err = errors.New("maximum number of Application passwords reached") } return "", "", errwrap.Wrapf("error updating credentials: {{err}}", err) } - return keyID, password, nil + return to.String(resp.KeyID), to.String(resp.SecretText), nil } // deleteAppPassword removes a password, if present, from an App's credentials list. func (c *client) deleteAppPassword(ctx context.Context, appObjID, keyID string) error { - // Load current credentials - resp, err := c.provider.ListApplicationPasswordCredentials(ctx, appObjID) - if err != nil { - return errwrap.Wrapf("error fetching credentials: {{err}}", err) - } - curCreds := *resp.Value - - // Remove credential - found := false - for i := range curCreds { - if to.String(curCreds[i].KeyID) == keyID { - curCreds[i] = curCreds[len(curCreds)-1] - curCreds = curCreds[:len(curCreds)-1] - found = true - break + if _, err := c.provider.RemoveApplicationPassword(ctx, appObjID, keyID); err != nil { + if strings.Contains(err.Error(), "No password credential found with keyId") { + return nil } - } - - // KeyID is not present, so nothing to do - if !found { - return nil - } - - // Save new credentials list - if _, err := c.provider.UpdateApplicationPasswordCredentials(ctx, appObjID, - graphrbac.PasswordCredentialsUpdateParameters{ - Value: &curCreds, - }, - ); err != nil { - return errwrap.Wrapf("error updating credentials: {{err}}", err) + return errwrap.Wrapf("error removing credentials: {{err}}", err) } return nil diff --git a/graph_api_client.go b/graph_api_client.go new file mode 100644 index 00000000..7b74b54e --- /dev/null +++ b/graph_api_client.go @@ -0,0 +1,334 @@ +package azuresecrets + +import ( + "context" + "net/http" + + "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" +) + +const ( + // defaultGraphMicrosoftComURI is the default URI used for the service MS Graph API + defaultGraphMicrosoftComURI = "https://graph.microsoft.com" +) + +type msGraphApplicationsClient struct { + authorization.BaseClient +} + +func newMSGraphApplicationClient(subscriptionId string) msGraphApplicationsClient { + return msGraphApplicationsClient{authorization.NewWithBaseURI(defaultGraphMicrosoftComURI, subscriptionId)} +} + +func (p *msGraphApplicationsClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { + req, err := p.getApplicationPreparer(ctx, applicationObjectID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "GetApplication", nil, "Failure preparing request") + return + } + + resp, err := p.getApplicationSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "GetApplication", resp, "Failure sending request") + return + } + + result, err = p.getApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "GetApplication", resp, "Failure responding to request") + } + + return +} + +// CreateApplication create a new Azure application object. +func (p *msGraphApplicationsClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { + req, err := p.createApplicationPreparer(ctx, displayName) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", nil, "Failure preparing request") + return + } + + resp, err := p.createApplicationSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", resp, "Failure sending request") + return + } + + result, err = p.createApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", resp, "Failure responding to request") + } + + return +} + +// DeleteApplication deletes an Azure application object. +// This will in turn remove the service principal (but not the role assignments). +func (p *msGraphApplicationsClient) DeleteApplication(ctx context.Context, applicationObjectID string) (result autorest.Response, err error) { + req, err := p.deleteApplicationPreparer(ctx, applicationObjectID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", nil, "Failure preparing request") + return + } + + resp, err := p.deleteApplicationSender(req) + if err != nil { + result.Response = resp + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", resp, "Failure sending request") + return + } + + result, err = p.deleteApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", resp, "Failure responding to request") + } + + return +} + +func (p *msGraphApplicationsClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { + req, err := p.addPasswordPreparer(ctx, applicationObjectID, displayName, endDateTime) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", nil, "Failure preparing request") + return + } + + resp, err := p.addPasswordSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", resp, "Failure sending request") + return + } + + result, err = p.addPasswordResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", resp, "Failure responding to request") + } + + return +} + +func (p *msGraphApplicationsClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + req, err := p.removePasswordPreparer(ctx, applicationObjectID, keyID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", nil, "Failure preparing request") + return + } + + resp, err := p.removePasswordSender(req) + if err != nil { + result.Response = resp + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", resp, "Failure sending request") + return + } + + result, err = p.removePasswordResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", resp, "Failure responding to request") + } + + return +} + +func (client msGraphApplicationsClient) getApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsGet(), + autorest.WithBaseURL(client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}", pathParameters), + client.Authorizer.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (client msGraphApplicationsClient) getApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(client, req, sd...) +} + +func (client msGraphApplicationsClient) getApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { + err = autorest.Respond( + resp, + client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return +} + +func (client msGraphApplicationsClient) addPasswordPreparer(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + parameters := struct { + PasswordCredential *passwordCredential `json:"passwordCredential"` + }{ + PasswordCredential: &passwordCredential{ + DisplayName: to.StringPtr(displayName), + EndDate: &endDateTime, + }, + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}/addPassword", pathParameters), + autorest.WithJSON(parameters), + client.Authorizer.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (client msGraphApplicationsClient) addPasswordSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(client, req, sd...) +} + +func (client msGraphApplicationsClient) addPasswordResponder(resp *http.Response) (result PasswordCredentialResult, err error) { + err = autorest.Respond( + resp, + client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return +} + +func (client msGraphApplicationsClient) removePasswordPreparer(ctx context.Context, applicationObjectID string, keyID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + parameters := struct { + KeyID string `json:"keyId"` + }{ + KeyID: keyID, + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}/removePassword", pathParameters), + autorest.WithJSON(parameters), + client.Authorizer.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (client msGraphApplicationsClient) removePasswordSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(client, req, sd...) +} + +func (client msGraphApplicationsClient) removePasswordResponder(resp *http.Response) (result autorest.Response, err error) { + err = autorest.Respond( + resp, + client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusNoContent), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = resp + return +} + +func (client msGraphApplicationsClient) createApplicationPreparer(ctx context.Context, displayName string) (*http.Request, error) { + parameters := struct { + DisplayName *string `json:"displayName"` + }{ + DisplayName: to.StringPtr(displayName), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(client.BaseURI), + autorest.WithPath("/v1.0/applications"), + autorest.WithJSON(parameters), + client.Authorizer.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (client msGraphApplicationsClient) createApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(client, req, sd...) +} + +func (client msGraphApplicationsClient) createApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { + err = autorest.Respond( + resp, + client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusCreated), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return +} + +func (client msGraphApplicationsClient) deleteApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsDelete(), + autorest.WithBaseURL(client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}", pathParameters), + client.Authorizer.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (client msGraphApplicationsClient) deleteApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(client, req, sd...) +} + +func (client msGraphApplicationsClient) deleteApplicationResponder(resp *http.Response) (result autorest.Response, err error) { + err = autorest.Respond( + resp, + client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusNoContent), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = resp + return +} + +type passwordCredential struct { + DisplayName *string `json:"displayName"` + // StartDate - Start date. + StartDate *date.Time `json:"startDateTime,omitempty"` + // EndDate - End date. + EndDate *date.Time `json:"endDateTime,omitempty"` + // KeyID - Key ID. + KeyID *string `json:"keyId,omitempty"` + // Value - Key value. + SecretText *string `json:"secretText,omitempty"` +} + +type PasswordCredentialResult struct { + autorest.Response `json:"-"` + + passwordCredential +} + +type ApplicationResult struct { + autorest.Response `json:"-"` + + AppID *string `json:"appId,omitempty"` + ID *string `json:"id,omitempty"` + PasswordCredentials []*passwordCredential `json:"passwordCredentials,omitempty"` +} diff --git a/path_config.go b/path_config.go index 431a0681..e405aa0b 100644 --- a/path_config.go +++ b/path_config.go @@ -24,6 +24,7 @@ type azureConfig struct { ClientSecret string `json:"client_secret"` Environment string `json:"environment"` PasswordPolicy string `json:"password_policy"` + UseMsGraphAPI bool `json:"use_microsoft_graph_api"` } func pathConfig(b *azureSecretBackend) *framework.Path { @@ -59,6 +60,10 @@ func pathConfig(b *azureSecretBackend) *framework.Path { Type: framework.TypeString, Description: "Name of the password policy to use to generate passwords for dynamic credentials.", }, + "use_microsoft_graph_api": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: "Enable usage of the Microsoft Graph API over the deprecated Azure AD Graph API.", + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ logical.ReadOperation: b.pathConfigRead, @@ -112,6 +117,10 @@ func (b *azureSecretBackend) pathConfigWrite(ctx context.Context, req *logical.R config.ClientSecret = clientSecret.(string) } + if useMsGraphApi, ok := data.GetOk("use_microsoft_graph_api"); ok { + config.UseMsGraphAPI = useMsGraphApi.(bool) + } + config.PasswordPolicy = data.Get("password_policy").(string) if merr.ErrorOrNil() != nil { @@ -136,10 +145,11 @@ func (b *azureSecretBackend) pathConfigRead(ctx context.Context, req *logical.Re resp := &logical.Response{ Data: map[string]interface{}{ - "subscription_id": config.SubscriptionID, - "tenant_id": config.TenantID, - "environment": config.Environment, - "client_id": config.ClientID, + "subscription_id": config.SubscriptionID, + "tenant_id": config.TenantID, + "environment": config.Environment, + "client_id": config.ClientID, + "use_microsoft_graph_api": config.UseMsGraphAPI, }, } return resp, nil diff --git a/path_config_test.go b/path_config_test.go index bed8e539..a43c471a 100644 --- a/path_config_test.go +++ b/path_config_test.go @@ -12,11 +12,12 @@ func TestConfig(t *testing.T) { // Test valid config config := map[string]interface{}{ - "subscription_id": "a228ceec-bf1a-4411-9f95-39678d8cdb34", - "tenant_id": "7ac36e27-80fc-4209-a453-e8ad83dc18c2", - "client_id": "testClientId", - "client_secret": "testClientSecret", - "environment": "AZURECHINACLOUD", + "subscription_id": "a228ceec-bf1a-4411-9f95-39678d8cdb34", + "tenant_id": "7ac36e27-80fc-4209-a453-e8ad83dc18c2", + "client_id": "testClientId", + "client_secret": "testClientSecret", + "environment": "AZURECHINACLOUD", + "use_microsoft_graph_api": false, } testConfigCreate(t, b, s, config) @@ -54,11 +55,12 @@ func TestConfigDelete(t *testing.T) { // Test valid config config := map[string]interface{}{ - "subscription_id": "a228ceec-bf1a-4411-9f95-39678d8cdb34", - "tenant_id": "7ac36e27-80fc-4209-a453-e8ad83dc18c2", - "client_id": "testClientId", - "client_secret": "testClientSecret", - "environment": "AZURECHINACLOUD", + "subscription_id": "a228ceec-bf1a-4411-9f95-39678d8cdb34", + "tenant_id": "7ac36e27-80fc-4209-a453-e8ad83dc18c2", + "client_id": "testClientId", + "client_secret": "testClientSecret", + "environment": "AZURECHINACLOUD", + "use_microsoft_graph_api": false, } testConfigCreate(t, b, s, config) @@ -79,10 +81,11 @@ func TestConfigDelete(t *testing.T) { } config = map[string]interface{}{ - "subscription_id": "", - "tenant_id": "", - "client_id": "", - "environment": "", + "subscription_id": "", + "tenant_id": "", + "client_id": "", + "environment": "", + "use_microsoft_graph_api": false, } testConfigRead(t, b, s, config) } diff --git a/path_service_principal.go b/path_service_principal.go index 2505c8ac..251c5866 100644 --- a/path_service_principal.go +++ b/path_service_principal.go @@ -105,7 +105,7 @@ func (b *azureSecretBackend) createSPSecret(ctx context.Context, s logical.Stora return nil, err } appID := to.String(app.AppID) - appObjID := to.String(app.ObjectID) + appObjID := to.String(app.ID) // Write a WAL entry in case the SP create process doesn't complete walID, err := framework.PutWAL(ctx, s, walAppKey, &walApp{ diff --git a/path_service_principal_test.go b/path_service_principal_test.go index 9abad420..e380db8e 100644 --- a/path_service_principal_test.go +++ b/path_service_principal_test.go @@ -59,7 +59,7 @@ func TestSP_WAL_Cleanup(t *testing.T) { // overwrite the normal test backend provider with the errMockProvider errMockProvider := newErrMockProvider() - b.getProvider = func(s *clientSettings) (AzureProvider, error) { + b.getProvider = func(s *clientSettings, useMsGraphApi bool, p passwords) (AzureProvider, error) { return errMockProvider, nil } @@ -271,8 +271,8 @@ func TestStaticSPRead(t *testing.T) { assertErrorIsNil(t, err) keyID := resp.Secret.InternalData["key_id"].(string) - if !strings.HasPrefix(keyID, "ffffff") { - t.Fatalf("expected prefix 'ffffff': %s", keyID) + if len(keyID) == 0 { + t.Fatalf("expected keyId to not be empty") } client, err := b.getClient(context.Background(), s) @@ -413,8 +413,8 @@ func TestStaticSPRevoke(t *testing.T) { assertErrorIsNil(t, err) keyID := resp.Secret.InternalData["key_id"].(string) - if !strings.HasPrefix(keyID, "ffffff") { - t.Fatalf("expected prefix 'ffffff': %s", keyID) + if len(keyID) == 0 { + t.Fatalf("expected keyId to not be empty") } client, err := b.getClient(context.Background(), s) @@ -733,7 +733,7 @@ func TestCredentialInteg(t *testing.T) { for i := 0; i < 8; i++ { // New credentials are only tested during an actual operation, not provider creation. // This step should never fail. - p, err := newAzureProvider(settings) + p, err := newAzureProvider(settings, true, passwords{}) if err != nil { t.Fatal(err) } diff --git a/provider.go b/provider.go index 99498817..7aaa2b5c 100644 --- a/provider.go +++ b/provider.go @@ -2,12 +2,19 @@ package azuresecrets import ( "context" + "errors" + "fmt" "strings" + "time" "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure/auth" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/helper/useragent" "github.com/hashicorp/vault/sdk/version" ) @@ -24,14 +31,11 @@ type AzureProvider interface { } type ApplicationsClient interface { - CreateApplication(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) + GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) + CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) - GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) - UpdateApplicationPasswordCredentials( - ctx context.Context, - applicationObjectID string, - parameters graphrbac.PasswordCredentialsUpdateParameters) (result autorest.Response, err error) - ListApplicationPasswordCredentials(ctx context.Context, applicationObjectID string) (result graphrbac.PasswordCredentialListResult, err error) + AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) + RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) } type ServicePrincipalsClient interface { @@ -65,7 +69,7 @@ type RoleDefinitionsClient interface { type provider struct { settings *clientSettings - appClient *graphrbac.ApplicationsClient + appClient ApplicationsClient spClient *graphrbac.ServicePrincipalsClient groupsClient *graphrbac.GroupsClient raClient *authorization.RoleAssignmentsClient @@ -73,9 +77,9 @@ type provider struct { } // newAzureProvider creates an azureProvider, backed by Azure client objects for underlying services. -func newAzureProvider(settings *clientSettings) (AzureProvider, error) { +func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords passwords) (AzureProvider, error) { // build clients that use the GraphRBAC endpoint - authorizer, err := getAuthorizer(settings, settings.Environment.GraphEndpoint) + graphAuthorizer, err := getAuthorizer(settings, settings.Environment.GraphEndpoint) if err != nil { return nil, err } @@ -105,36 +109,52 @@ func newAzureProvider(settings *clientSettings) (AzureProvider, error) { } userAgent = strings.Replace(userAgent, ")", vaultIDString, 1) - appClient := graphrbac.NewApplicationsClient(settings.TenantID) - appClient.Authorizer = authorizer - appClient.AddToUserAgent(userAgent) - spClient := graphrbac.NewServicePrincipalsClient(settings.TenantID) - spClient.Authorizer = authorizer + spClient.Authorizer = graphAuthorizer spClient.AddToUserAgent(userAgent) groupsClient := graphrbac.NewGroupsClient(settings.TenantID) - groupsClient.Authorizer = authorizer + groupsClient.Authorizer = graphAuthorizer groupsClient.AddToUserAgent(userAgent) + var appClient ApplicationsClient + if useMsGraphApi { + graphApiAuthorizer, err := getAuthorizer(settings, defaultGraphMicrosoftComURI) + if err != nil { + return nil, err + } + + msGraphAppClient := newMSGraphApplicationClient(settings.SubscriptionID) + msGraphAppClient.Authorizer = graphApiAuthorizer + msGraphAppClient.AddToUserAgent(userAgent) + + appClient = &msGraphAppClient + } else { + aadGraphClient := graphrbac.NewApplicationsClient(settings.TenantID) + aadGraphClient.Authorizer = graphAuthorizer + aadGraphClient.AddToUserAgent(userAgent) + + appClient = &aadGraphApplicationsClient{appClient: &aadGraphClient, passwords: passwords} + } + // build clients that use the Resource Manager endpoint - authorizer, err = getAuthorizer(settings, settings.Environment.ResourceManagerEndpoint) + resourceManagerAuthorizer, err := getAuthorizer(settings, settings.Environment.ResourceManagerEndpoint) if err != nil { return nil, err } raClient := authorization.NewRoleAssignmentsClientWithBaseURI(settings.Environment.ResourceManagerEndpoint, settings.SubscriptionID) - raClient.Authorizer = authorizer + raClient.Authorizer = resourceManagerAuthorizer raClient.AddToUserAgent(userAgent) rdClient := authorization.NewRoleDefinitionsClientWithBaseURI(settings.Environment.ResourceManagerEndpoint, settings.SubscriptionID) - rdClient.Authorizer = authorizer + rdClient.Authorizer = resourceManagerAuthorizer rdClient.AddToUserAgent(userAgent) p := &provider{ settings: settings, - appClient: &appClient, + appClient: appClient, spClient: &spClient, groupsClient: &groupsClient, raClient: &raClient, @@ -169,26 +189,26 @@ func getAuthorizer(settings *clientSettings, resource string) (authorizer autore } // CreateApplication create a new Azure application object. -func (p *provider) CreateApplication(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) { - return p.appClient.Create(ctx, parameters) +func (p *provider) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { + return p.appClient.CreateApplication(ctx, displayName) } -func (p *provider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) { - return p.appClient.Get(ctx, applicationObjectID) +func (p *provider) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { + return p.appClient.GetApplication(ctx, applicationObjectID) } // DeleteApplication deletes an Azure application object. // This will in turn remove the service principal (but not the role assignments). func (p *provider) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { - return p.appClient.Delete(ctx, applicationObjectID) + return p.appClient.DeleteApplication(ctx, applicationObjectID) } -func (p *provider) UpdateApplicationPasswordCredentials(ctx context.Context, applicationObjectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (result autorest.Response, err error) { - return p.appClient.UpdatePasswordCredentials(ctx, applicationObjectID, parameters) +func (p *provider) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { + return p.appClient.AddApplicationPassword(ctx, applicationObjectID, displayName, endDateTime) } -func (p *provider) ListApplicationPasswordCredentials(ctx context.Context, applicationObjectID string) (result graphrbac.PasswordCredentialListResult, err error) { - return p.appClient.ListPasswordCredentials(ctx, applicationObjectID) +func (p *provider) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + return p.appClient.RemoveApplicationPassword(ctx, applicationObjectID, keyID) } // CreateServicePrincipal creates a new Azure service principal. @@ -265,3 +285,134 @@ func (p *provider) ListGroups(ctx context.Context, filter string) (result []grap return page.Values(), nil } + +type aadGraphApplicationsClient struct { + appClient *graphrbac.ApplicationsClient + passwords passwords +} + +func (a *aadGraphApplicationsClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { + app, err := a.appClient.Get(ctx, applicationObjectID) + if err != nil { + return ApplicationResult{}, err + } + + return ApplicationResult{ + AppID: app.AppID, + ID: app.ObjectID, + }, nil +} + +func (a *aadGraphApplicationsClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { + appURL := fmt.Sprintf("https://%s", displayName) + + app, err := a.appClient.Create(ctx, graphrbac.ApplicationCreateParameters{ + AvailableToOtherTenants: to.BoolPtr(false), + DisplayName: to.StringPtr(displayName), + Homepage: to.StringPtr(appURL), + IdentifierUris: to.StringSlicePtr([]string{appURL}), + }) + if err != nil { + return ApplicationResult{}, err + } + + return ApplicationResult{ + AppID: app.AppID, + ID: app.ObjectID, + }, nil +} + +func (a *aadGraphApplicationsClient) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { + return a.appClient.Delete(ctx, applicationObjectID) +} + +func (a *aadGraphApplicationsClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { + keyID, err := uuid.GenerateUUID() + if err != nil { + return PasswordCredentialResult{}, err + } + + // Key IDs are not secret, and they're a convenient way for an operator to identify Vault-generated + // passwords. These must be UUIDs, so the three leading bytes will be used as an indicator. + keyID = "ffffff" + keyID[6:] + + password, err := a.passwords.generate(ctx) + if err != nil { + return PasswordCredentialResult{}, err + } + + now := date.Time{Time: time.Now().UTC()} + cred := graphrbac.PasswordCredential{ + StartDate: &now, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + Value: to.StringPtr(password), + } + + // Load current credentials + resp, err := a.appClient.ListPasswordCredentials(ctx, applicationObjectID) + if err != nil { + return PasswordCredentialResult{}, errwrap.Wrapf("error fetching credentials: {{err}}", err) + } + curCreds := *resp.Value + + // Add and save credentials + curCreds = append(curCreds, cred) + + if _, err := a.appClient.UpdatePasswordCredentials(ctx, applicationObjectID, + graphrbac.PasswordCredentialsUpdateParameters{ + Value: &curCreds, + }, + ); err != nil { + if strings.Contains(err.Error(), "size of the object has exceeded its limit") { + err = errors.New("maximum number of Application passwords reached") + } + return PasswordCredentialResult{}, errwrap.Wrapf("error updating credentials: {{err}}", err) + } + + return PasswordCredentialResult{ + passwordCredential: passwordCredential{ + DisplayName: to.StringPtr(displayName), + StartDate: &now, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + SecretText: to.StringPtr(password), + }, + }, nil +} + +func (a *aadGraphApplicationsClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + // Load current credentials + resp, err := a.appClient.ListPasswordCredentials(ctx, applicationObjectID) + if err != nil { + return autorest.Response{}, errwrap.Wrapf("error fetching credentials: {{err}}", err) + } + curCreds := *resp.Value + + // Remove credential + found := false + for i := range curCreds { + if to.String(curCreds[i].KeyID) == keyID { + curCreds[i] = curCreds[len(curCreds)-1] + curCreds = curCreds[:len(curCreds)-1] + found = true + break + } + } + + // KeyID is not present, so nothing to do + if !found { + return autorest.Response{}, nil + } + + // Save new credentials list + if _, err := a.appClient.UpdatePasswordCredentials(ctx, applicationObjectID, + graphrbac.PasswordCredentialsUpdateParameters{ + Value: &curCreds, + }, + ); err != nil { + return autorest.Response{}, errwrap.Wrapf("error updating credentials: {{err}}", err) + } + + return autorest.Response{}, nil +} From adc49107b451fe2d8dacb7f1dcbe2b25e627b957 Mon Sep 17 00:00:00 2001 From: Jason O'Donnell <2160810+jasonodonnell@users.noreply.github.com> Date: Fri, 3 Sep 2021 10:59:16 -0400 Subject: [PATCH 2/8] Move to separate package --- api/aad_application.go | 147 +++++++++++++ api/api.go | 80 +++++++ .../graph_application.go | 80 +++---- api/passwords.go | 31 +++ backend.go | 9 +- backend_test.go | 39 ++-- client.go | 15 +- passwords.go | 31 --- path_roles.go | 2 +- path_service_principal_test.go | 13 +- provider.go | 202 +----------------- 11 files changed, 337 insertions(+), 312 deletions(-) create mode 100644 api/aad_application.go create mode 100644 api/api.go rename graph_api_client.go => api/graph_application.go (69%) create mode 100644 api/passwords.go delete mode 100644 passwords.go diff --git a/api/aad_application.go b/api/aad_application.go new file mode 100644 index 00000000..6820c715 --- /dev/null +++ b/api/aad_application.go @@ -0,0 +1,147 @@ +package api + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/go-uuid" +) + +type ActiveDirectoryApplicatinClient struct { + Client *graphrbac.ApplicationsClient + Passwords Passwords +} + +func (a *ActiveDirectoryApplicatinClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { + app, err := a.Client.Get(ctx, applicationObjectID) + if err != nil { + return ApplicationResult{}, err + } + + return ApplicationResult{ + AppID: app.AppID, + ID: app.ObjectID, + }, nil +} + +func (a *ActiveDirectoryApplicatinClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { + appURL := fmt.Sprintf("https://%s", displayName) + + app, err := a.Client.Create(ctx, graphrbac.ApplicationCreateParameters{ + AvailableToOtherTenants: to.BoolPtr(false), + DisplayName: to.StringPtr(displayName), + Homepage: to.StringPtr(appURL), + IdentifierUris: to.StringSlicePtr([]string{appURL}), + }) + if err != nil { + return ApplicationResult{}, err + } + + return ApplicationResult{ + AppID: app.AppID, + ID: app.ObjectID, + }, nil +} + +func (a *ActiveDirectoryApplicatinClient) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { + return a.Client.Delete(ctx, applicationObjectID) +} + +func (a *ActiveDirectoryApplicatinClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { + keyID, err := uuid.GenerateUUID() + if err != nil { + return PasswordCredentialResult{}, err + } + + // Key IDs are not secret, and they're a convenient way for an operator to identify Vault-generated + // passwords. These must be UUIDs, so the three leading bytes will be used as an indicator. + keyID = "ffffff" + keyID[6:] + + password, err := a.Passwords.Generate(ctx) + if err != nil { + return PasswordCredentialResult{}, err + } + + now := date.Time{Time: time.Now().UTC()} + cred := graphrbac.PasswordCredential{ + StartDate: &now, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + Value: to.StringPtr(password), + } + + // Load current credentials + resp, err := a.Client.ListPasswordCredentials(ctx, applicationObjectID) + if err != nil { + return PasswordCredentialResult{}, errwrap.Wrapf("error fetching credentials: {{err}}", err) + } + curCreds := *resp.Value + + // Add and save credentials + curCreds = append(curCreds, cred) + + if _, err := a.Client.UpdatePasswordCredentials(ctx, applicationObjectID, + graphrbac.PasswordCredentialsUpdateParameters{ + Value: &curCreds, + }, + ); err != nil { + if strings.Contains(err.Error(), "size of the object has exceeded its limit") { + err = errors.New("maximum number of Application passwords reached") + } + return PasswordCredentialResult{}, errwrap.Wrapf("error updating credentials: {{err}}", err) + } + + return PasswordCredentialResult{ + PasswordCredential: PasswordCredential{ + DisplayName: to.StringPtr(displayName), + StartDate: &now, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + SecretText: to.StringPtr(password), + }, + }, nil +} + +func (a *ActiveDirectoryApplicatinClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + // Load current credentials + resp, err := a.Client.ListPasswordCredentials(ctx, applicationObjectID) + if err != nil { + return autorest.Response{}, errwrap.Wrapf("error fetching credentials: {{err}}", err) + } + curCreds := *resp.Value + + // Remove credential + found := false + for i := range curCreds { + if to.String(curCreds[i].KeyID) == keyID { + curCreds[i] = curCreds[len(curCreds)-1] + curCreds = curCreds[:len(curCreds)-1] + found = true + break + } + } + + // KeyID is not present, so nothing to do + if !found { + return autorest.Response{}, nil + } + + // Save new credentials list + if _, err := a.Client.UpdatePasswordCredentials(ctx, applicationObjectID, + graphrbac.PasswordCredentialsUpdateParameters{ + Value: &curCreds, + }, + ); err != nil { + return autorest.Response{}, errwrap.Wrapf("error updating credentials: {{err}}", err) + } + + return autorest.Response{}, nil +} diff --git a/api/api.go b/api/api.go new file mode 100644 index 00000000..f4fd09a8 --- /dev/null +++ b/api/api.go @@ -0,0 +1,80 @@ +package api + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/date" +) + +// AzureProvider is an interface to access underlying Azure client objects and supporting services. +// Where practical the original function signature is preserved. client provides higher +// level operations atop AzureProvider. +type AzureProvider interface { + ApplicationsClient + ServicePrincipalsClient + ADGroupsClient + RoleAssignmentsClient + RoleDefinitionsClient +} + +type ApplicationsClient interface { + GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) + CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) + DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) + AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) + RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) +} + +type ServicePrincipalsClient interface { + CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) +} + +type ADGroupsClient interface { + AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) + RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) + GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) + ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) +} + +type RoleAssignmentsClient interface { + CreateRoleAssignment( + ctx context.Context, + scope string, + roleAssignmentName string, + parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) + DeleteRoleAssignmentByID(ctx context.Context, roleID string) (authorization.RoleAssignment, error) +} + +type RoleDefinitionsClient interface { + ListRoles(ctx context.Context, scope string, filter string) ([]authorization.RoleDefinition, error) + GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) +} + +type PasswordCredential struct { + DisplayName *string `json:"displayName"` + // StartDate - Start date. + StartDate *date.Time `json:"startDateTime,omitempty"` + // EndDate - End date. + EndDate *date.Time `json:"endDateTime,omitempty"` + // KeyID - Key ID. + KeyID *string `json:"keyId,omitempty"` + // Value - Key value. + SecretText *string `json:"secretText,omitempty"` +} + +type PasswordCredentialResult struct { + autorest.Response `json:"-"` + + PasswordCredential +} + +type ApplicationResult struct { + autorest.Response `json:"-"` + + AppID *string `json:"appId,omitempty"` + ID *string `json:"id,omitempty"` + PasswordCredentials []*PasswordCredential `json:"passwordCredentials,omitempty"` +} diff --git a/graph_api_client.go b/api/graph_application.go similarity index 69% rename from graph_api_client.go rename to api/graph_application.go index 7b74b54e..8765ab46 100644 --- a/graph_api_client.go +++ b/api/graph_application.go @@ -1,4 +1,4 @@ -package azuresecrets +package api import ( "context" @@ -13,18 +13,18 @@ import ( const ( // defaultGraphMicrosoftComURI is the default URI used for the service MS Graph API - defaultGraphMicrosoftComURI = "https://graph.microsoft.com" + DefaultGraphMicrosoftComURI = "https://graph.microsoft.com" ) -type msGraphApplicationsClient struct { +type AppClient struct { authorization.BaseClient } -func newMSGraphApplicationClient(subscriptionId string) msGraphApplicationsClient { - return msGraphApplicationsClient{authorization.NewWithBaseURI(defaultGraphMicrosoftComURI, subscriptionId)} +func NewGraphApplicationClient(subscriptionId string) AppClient { + return AppClient{authorization.NewWithBaseURI(DefaultGraphMicrosoftComURI, subscriptionId)} } -func (p *msGraphApplicationsClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { +func (p *AppClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { req, err := p.getApplicationPreparer(ctx, applicationObjectID) if err != nil { err = autorest.NewErrorWithError(err, "provider", "GetApplication", nil, "Failure preparing request") @@ -47,7 +47,7 @@ func (p *msGraphApplicationsClient) GetApplication(ctx context.Context, applicat } // CreateApplication create a new Azure application object. -func (p *msGraphApplicationsClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { +func (p *AppClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { req, err := p.createApplicationPreparer(ctx, displayName) if err != nil { err = autorest.NewErrorWithError(err, "provider", "CreateApplication", nil, "Failure preparing request") @@ -71,7 +71,7 @@ func (p *msGraphApplicationsClient) CreateApplication(ctx context.Context, displ // DeleteApplication deletes an Azure application object. // This will in turn remove the service principal (but not the role assignments). -func (p *msGraphApplicationsClient) DeleteApplication(ctx context.Context, applicationObjectID string) (result autorest.Response, err error) { +func (p *AppClient) DeleteApplication(ctx context.Context, applicationObjectID string) (result autorest.Response, err error) { req, err := p.deleteApplicationPreparer(ctx, applicationObjectID) if err != nil { err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", nil, "Failure preparing request") @@ -93,7 +93,7 @@ func (p *msGraphApplicationsClient) DeleteApplication(ctx context.Context, appli return } -func (p *msGraphApplicationsClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { +func (p *AppClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { req, err := p.addPasswordPreparer(ctx, applicationObjectID, displayName, endDateTime) if err != nil { err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", nil, "Failure preparing request") @@ -115,7 +115,7 @@ func (p *msGraphApplicationsClient) AddApplicationPassword(ctx context.Context, return } -func (p *msGraphApplicationsClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { +func (p *AppClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { req, err := p.removePasswordPreparer(ctx, applicationObjectID, keyID) if err != nil { err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", nil, "Failure preparing request") @@ -137,7 +137,7 @@ func (p *msGraphApplicationsClient) RemoveApplicationPassword(ctx context.Contex return } -func (client msGraphApplicationsClient) getApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { +func (client AppClient) getApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { pathParameters := map[string]interface{}{ "applicationObjectId": autorest.Encode("path", applicationObjectID), } @@ -151,12 +151,12 @@ func (client msGraphApplicationsClient) getApplicationPreparer(ctx context.Conte return preparer.Prepare((&http.Request{}).WithContext(ctx)) } -func (client msGraphApplicationsClient) getApplicationSender(req *http.Request) (*http.Response, error) { +func (client AppClient) getApplicationSender(req *http.Request) (*http.Response, error) { sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) return autorest.SendWithSender(client, req, sd...) } -func (client msGraphApplicationsClient) getApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { +func (client AppClient) getApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { err = autorest.Respond( resp, client.ByInspecting(), @@ -167,15 +167,15 @@ func (client msGraphApplicationsClient) getApplicationResponder(resp *http.Respo return } -func (client msGraphApplicationsClient) addPasswordPreparer(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (*http.Request, error) { +func (client AppClient) addPasswordPreparer(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (*http.Request, error) { pathParameters := map[string]interface{}{ "applicationObjectId": autorest.Encode("path", applicationObjectID), } parameters := struct { - PasswordCredential *passwordCredential `json:"passwordCredential"` + PasswordCredential *PasswordCredential `json:"passwordCredential"` }{ - PasswordCredential: &passwordCredential{ + PasswordCredential: &PasswordCredential{ DisplayName: to.StringPtr(displayName), EndDate: &endDateTime, }, @@ -191,12 +191,12 @@ func (client msGraphApplicationsClient) addPasswordPreparer(ctx context.Context, return preparer.Prepare((&http.Request{}).WithContext(ctx)) } -func (client msGraphApplicationsClient) addPasswordSender(req *http.Request) (*http.Response, error) { +func (client AppClient) addPasswordSender(req *http.Request) (*http.Response, error) { sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) return autorest.SendWithSender(client, req, sd...) } -func (client msGraphApplicationsClient) addPasswordResponder(resp *http.Response) (result PasswordCredentialResult, err error) { +func (client AppClient) addPasswordResponder(resp *http.Response) (result PasswordCredentialResult, err error) { err = autorest.Respond( resp, client.ByInspecting(), @@ -207,7 +207,7 @@ func (client msGraphApplicationsClient) addPasswordResponder(resp *http.Response return } -func (client msGraphApplicationsClient) removePasswordPreparer(ctx context.Context, applicationObjectID string, keyID string) (*http.Request, error) { +func (client AppClient) removePasswordPreparer(ctx context.Context, applicationObjectID string, keyID string) (*http.Request, error) { pathParameters := map[string]interface{}{ "applicationObjectId": autorest.Encode("path", applicationObjectID), } @@ -228,12 +228,12 @@ func (client msGraphApplicationsClient) removePasswordPreparer(ctx context.Conte return preparer.Prepare((&http.Request{}).WithContext(ctx)) } -func (client msGraphApplicationsClient) removePasswordSender(req *http.Request) (*http.Response, error) { +func (client AppClient) removePasswordSender(req *http.Request) (*http.Response, error) { sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) return autorest.SendWithSender(client, req, sd...) } -func (client msGraphApplicationsClient) removePasswordResponder(resp *http.Response) (result autorest.Response, err error) { +func (client AppClient) removePasswordResponder(resp *http.Response) (result autorest.Response, err error) { err = autorest.Respond( resp, client.ByInspecting(), @@ -244,7 +244,7 @@ func (client msGraphApplicationsClient) removePasswordResponder(resp *http.Respo return } -func (client msGraphApplicationsClient) createApplicationPreparer(ctx context.Context, displayName string) (*http.Request, error) { +func (client AppClient) createApplicationPreparer(ctx context.Context, displayName string) (*http.Request, error) { parameters := struct { DisplayName *string `json:"displayName"` }{ @@ -261,12 +261,12 @@ func (client msGraphApplicationsClient) createApplicationPreparer(ctx context.Co return preparer.Prepare((&http.Request{}).WithContext(ctx)) } -func (client msGraphApplicationsClient) createApplicationSender(req *http.Request) (*http.Response, error) { +func (client AppClient) createApplicationSender(req *http.Request) (*http.Response, error) { sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) return autorest.SendWithSender(client, req, sd...) } -func (client msGraphApplicationsClient) createApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { +func (client AppClient) createApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { err = autorest.Respond( resp, client.ByInspecting(), @@ -277,7 +277,7 @@ func (client msGraphApplicationsClient) createApplicationResponder(resp *http.Re return } -func (client msGraphApplicationsClient) deleteApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { +func (client AppClient) deleteApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { pathParameters := map[string]interface{}{ "applicationObjectId": autorest.Encode("path", applicationObjectID), } @@ -291,12 +291,12 @@ func (client msGraphApplicationsClient) deleteApplicationPreparer(ctx context.Co return preparer.Prepare((&http.Request{}).WithContext(ctx)) } -func (client msGraphApplicationsClient) deleteApplicationSender(req *http.Request) (*http.Response, error) { +func (client AppClient) deleteApplicationSender(req *http.Request) (*http.Response, error) { sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) return autorest.SendWithSender(client, req, sd...) } -func (client msGraphApplicationsClient) deleteApplicationResponder(resp *http.Response) (result autorest.Response, err error) { +func (client AppClient) deleteApplicationResponder(resp *http.Response) (result autorest.Response, err error) { err = autorest.Respond( resp, client.ByInspecting(), @@ -306,29 +306,3 @@ func (client msGraphApplicationsClient) deleteApplicationResponder(resp *http.Re result.Response = resp return } - -type passwordCredential struct { - DisplayName *string `json:"displayName"` - // StartDate - Start date. - StartDate *date.Time `json:"startDateTime,omitempty"` - // EndDate - End date. - EndDate *date.Time `json:"endDateTime,omitempty"` - // KeyID - Key ID. - KeyID *string `json:"keyId,omitempty"` - // Value - Key value. - SecretText *string `json:"secretText,omitempty"` -} - -type PasswordCredentialResult struct { - autorest.Response `json:"-"` - - passwordCredential -} - -type ApplicationResult struct { - autorest.Response `json:"-"` - - AppID *string `json:"appId,omitempty"` - ID *string `json:"id,omitempty"` - PasswordCredentials []*passwordCredential `json:"passwordCredentials,omitempty"` -} diff --git a/api/passwords.go b/api/passwords.go new file mode 100644 index 00000000..7ac8b927 --- /dev/null +++ b/api/passwords.go @@ -0,0 +1,31 @@ +package api + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-secure-stdlib/base62" +) + +const ( + PasswordLength = 36 +) + +type PasswordGenerator interface { + GeneratePasswordFromPolicy(ctx context.Context, policyName string) (password string, err error) +} + +type Passwords struct { + PolicyGenerator PasswordGenerator + PolicyName string +} + +func (p Passwords) Generate(ctx context.Context) (password string, err error) { + if p.PolicyName == "" { + return base62.Random(PasswordLength) + } + if p.PolicyGenerator == nil { + return "", fmt.Errorf("policy set, but no policy generator specified") + } + return p.PolicyGenerator.GeneratePasswordFromPolicy(ctx, p.PolicyName) +} diff --git a/backend.go b/backend.go index 07bac01e..034d5ce6 100644 --- a/backend.go +++ b/backend.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/locksutil" "github.com/hashicorp/vault/sdk/logical" @@ -14,7 +15,7 @@ import ( type azureSecretBackend struct { *framework.Backend - getProvider func(*clientSettings, bool, passwords) (AzureProvider, error) + getProvider func(*clientSettings, bool, api.Passwords) (api.AzureProvider, error) client *client settings *clientSettings lock sync.RWMutex @@ -121,9 +122,9 @@ func (b *azureSecretBackend) getClient(ctx context.Context, s logical.Storage) ( b.settings = settings } - passwords := passwords{ - policyGenerator: b.System(), - policyName: config.PasswordPolicy, + passwords := api.Passwords{ + PolicyGenerator: b.System(), + PolicyName: config.PasswordPolicy, } p, err := b.getProvider(b.settings, config.UseMsGraphAPI, passwords) diff --git a/backend_test.go b/backend_test.go index 6ecb36c1..1f964f9b 100644 --- a/backend_test.go +++ b/backend_test.go @@ -10,13 +10,14 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute" "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/to" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" ) @@ -46,7 +47,7 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical b.settings = new(clientSettings) mockProvider := newMockProvider() - b.getProvider = func(s *clientSettings, usMsGraphApi bool, p passwords) (AzureProvider, error) { + b.getProvider = func(s *clientSettings, usMsGraphApi bool, p api.Passwords) (api.AzureProvider, error) { return mockProvider, nil } @@ -71,7 +72,7 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical type mockProvider struct { subscriptionID string applications map[string]bool - passwords map[string]passwordCredential + passwords map[string]api.PasswordCredential failNextCreateApplication bool lock sync.Mutex } @@ -91,32 +92,32 @@ func (e *errMockProvider) CreateRoleAssignment(ctx context.Context, scope string // key is found, unlike mockProvider which returns the same application object // id each time. Existing tests depend on the mockProvider behavior, which is // why errMockProvider has it's own version. -func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (ApplicationResult, error) { +func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (api.ApplicationResult, error) { for s := range e.applications { if s == applicationObjectID { - return ApplicationResult{ + return api.ApplicationResult{ AppID: to.StringPtr(s), }, nil } } - return ApplicationResult{}, errors.New("not found") + return api.ApplicationResult{}, errors.New("not found") } -func newErrMockProvider() AzureProvider { +func newErrMockProvider() api.AzureProvider { return &errMockProvider{ mockProvider: &mockProvider{ subscriptionID: generateUUID(), applications: make(map[string]bool), - passwords: make(map[string]passwordCredential), + passwords: make(map[string]api.PasswordCredential), }, } } -func newMockProvider() AzureProvider { +func newMockProvider() api.AzureProvider { return &mockProvider{ subscriptionID: generateUUID(), applications: make(map[string]bool), - passwords: make(map[string]passwordCredential), + passwords: make(map[string]api.PasswordCredential), } } @@ -177,10 +178,10 @@ func (m *mockProvider) CreateServicePrincipal(ctx context.Context, parameters gr }, nil } -func (m *mockProvider) CreateApplication(ctx context.Context, displayName string) (ApplicationResult, error) { +func (m *mockProvider) CreateApplication(ctx context.Context, displayName string) (api.ApplicationResult, error) { if m.failNextCreateApplication { m.failNextCreateApplication = false - return ApplicationResult{}, errors.New("Mock: fail to create application") + return api.ApplicationResult{}, errors.New("Mock: fail to create application") } appObjID := generateUUID() @@ -189,14 +190,14 @@ func (m *mockProvider) CreateApplication(ctx context.Context, displayName string m.applications[appObjID] = true - return ApplicationResult{ + return api.ApplicationResult{ AppID: to.StringPtr(generateUUID()), ID: &appObjID, }, nil } -func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (ApplicationResult, error) { - return ApplicationResult{ +func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (api.ApplicationResult, error) { + return api.ApplicationResult{ AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"), }, nil } @@ -206,9 +207,9 @@ func (m *mockProvider) DeleteApplication(ctx context.Context, applicationObjectI return autorest.Response{}, nil } -func (m *mockProvider) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { +func (m *mockProvider) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result api.PasswordCredentialResult, err error) { keyID := generateUUID() - cred := passwordCredential{ + cred := api.PasswordCredential{ DisplayName: to.StringPtr(displayName), StartDate: &date.Time{Time: time.Now()}, EndDate: &endDateTime, @@ -220,8 +221,8 @@ func (m *mockProvider) AddApplicationPassword(ctx context.Context, applicationOb defer m.lock.Unlock() m.passwords[keyID] = cred - return PasswordCredentialResult{ - passwordCredential: cred, + return api.PasswordCredentialResult{ + PasswordCredential: cred, }, nil } diff --git a/client.go b/client.go index 56095a6b..f1047b09 100644 --- a/client.go +++ b/client.go @@ -9,14 +9,15 @@ import ( "strings" "time" + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/to" "github.com/hashicorp/errwrap" multierror "github.com/hashicorp/go-multierror" uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/logical" ) @@ -30,10 +31,10 @@ const ( // for handlers. It in turn relies on a Provider interface to access the lower level // Azure Client SDK methods. type client struct { - provider AzureProvider + provider api.AzureProvider settings *clientSettings expiration time.Time - passwords passwords + passwords api.Passwords } // Valid returns whether the client defined and not expired. @@ -44,7 +45,7 @@ func (c *client) Valid() bool { // createApp creates a new Azure application. // An Application is a needed to create service principals used by // the caller for authentication. -func (c *client) createApp(ctx context.Context) (app *ApplicationResult, err error) { +func (c *client) createApp(ctx context.Context) (app *api.ApplicationResult, err error) { name, err := uuid.GenerateUUID() if err != nil { return nil, err @@ -60,7 +61,7 @@ func (c *client) createApp(ctx context.Context) (app *ApplicationResult, err err // createSP creates a new service principal. func (c *client) createSP( ctx context.Context, - app *ApplicationResult, + app *api.ApplicationResult, duration time.Duration) (svcPrinc *graphrbac.ServicePrincipal, password string, err error) { // Generate a random key (which must be a UUID) and password @@ -69,7 +70,7 @@ func (c *client) createSP( return nil, "", err } - password, err = c.passwords.generate(ctx) + password, err = c.passwords.Generate(ctx) if err != nil { return nil, "", err } @@ -157,7 +158,7 @@ func (c *client) assignRoles(ctx context.Context, sp *graphrbac.ServicePrincipal resultRaw, err := retry(ctx, func() (interface{}, bool, error) { ra, err := c.provider.CreateRoleAssignment(ctx, role.Scope, assignmentID, authorization.RoleAssignmentCreateParameters{ - RoleAssignmentProperties: &authorization.RoleAssignmentProperties{ + Properties: &authorization.RoleAssignmentProperties{ RoleDefinitionID: to.StringPtr(role.RoleID), PrincipalID: sp.ObjectID, }, diff --git a/passwords.go b/passwords.go deleted file mode 100644 index 25b58afa..00000000 --- a/passwords.go +++ /dev/null @@ -1,31 +0,0 @@ -package azuresecrets - -import ( - "context" - "fmt" - - "github.com/hashicorp/go-secure-stdlib/base62" -) - -const ( - passwordLength = 36 -) - -type passwordGenerator interface { - GeneratePasswordFromPolicy(ctx context.Context, policyName string) (password string, err error) -} - -type passwords struct { - policyGenerator passwordGenerator - policyName string -} - -func (p passwords) generate(ctx context.Context) (password string, err error) { - if p.policyName == "" { - return base62.Random(passwordLength) - } - if p.policyGenerator == nil { - return "", fmt.Errorf("policy set, but no policy generator specified") - } - return p.policyGenerator.GeneratePasswordFromPolicy(ctx, p.policyName) -} diff --git a/path_roles.go b/path_roles.go index be292989..4b88b5ed 100644 --- a/path_roles.go +++ b/path_roles.go @@ -7,8 +7,8 @@ import ( "strings" "time" + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" "github.com/Azure/go-autorest/autorest/to" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/sdk/framework" diff --git a/path_service_principal_test.go b/path_service_principal_test.go index e380db8e..2a0de71a 100644 --- a/path_service_principal_test.go +++ b/path_service_principal_test.go @@ -11,6 +11,7 @@ import ( "github.com/Azure/go-autorest/autorest/to" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" @@ -59,7 +60,7 @@ func TestSP_WAL_Cleanup(t *testing.T) { // overwrite the normal test backend provider with the errMockProvider errMockProvider := newErrMockProvider() - b.getProvider = func(s *clientSettings, useMsGraphApi bool, p passwords) (AzureProvider, error) { + b.getProvider = func(s *clientSettings, useMsGraphApi bool, p api.Passwords) (api.AzureProvider, error) { return errMockProvider, nil } @@ -90,7 +91,7 @@ func TestSP_WAL_Cleanup(t *testing.T) { }) } -func assertEmptyWAL(t *testing.T, b *azureSecretBackend, emp AzureProvider, s logical.Storage) { +func assertEmptyWAL(t *testing.T, b *azureSecretBackend, emp api.AzureProvider, s logical.Storage) { t.Helper() wal, err := framework.ListWAL(context.Background(), s) @@ -592,7 +593,7 @@ func TestCredentialInteg(t *testing.T) { roleDefs, err := client.provider.ListRoles(context.Background(), fmt.Sprintf("subscriptions/%s", subscriptionID), "") assertErrorIsNil(t, err) - defID := *ra.RoleAssignmentPropertiesWithScope.RoleDefinitionID + defID := *ra.Properties.RoleDefinitionID found := false for _, def := range roleDefs { if *def.ID == defID && *def.RoleName == "Reader" { @@ -733,7 +734,7 @@ func TestCredentialInteg(t *testing.T) { for i := 0; i < 8; i++ { // New credentials are only tested during an actual operation, not provider creation. // This step should never fail. - p, err := newAzureProvider(settings, true, passwords{}) + p, err := newAzureProvider(settings, true, api.Passwords{}) if err != nil { t.Fatal(err) } @@ -772,7 +773,7 @@ func assertClientSecret(tb testing.TB, data map[string]interface{}) { if !ok { tb.Fatalf("client_secret is not a string") } - if len(actualPassword) != passwordLength { - tb.Fatalf("client_secret is not the correct length: expected %d but was %d", passwordLength, len(actualPassword)) + if len(actualPassword) != api.PasswordLength { + tb.Fatalf("client_secret is not the correct length: expected %d but was %d", api.PasswordLength, len(actualPassword)) } } diff --git a/provider.go b/provider.go index 7aaa2b5c..63d18fb5 100644 --- a/provider.go +++ b/provider.go @@ -2,74 +2,25 @@ package azuresecrets import ( "context" - "errors" - "fmt" "strings" - "time" + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure/auth" "github.com/Azure/go-autorest/autorest/date" - "github.com/Azure/go-autorest/autorest/to" - "github.com/hashicorp/errwrap" - "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/helper/useragent" "github.com/hashicorp/vault/sdk/version" ) -// AzureProvider is an interface to access underlying Azure client objects and supporting services. -// Where practical the original function signature is preserved. client provides higher -// level operations atop AzureProvider. -type AzureProvider interface { - ApplicationsClient - ServicePrincipalsClient - ADGroupsClient - RoleAssignmentsClient - RoleDefinitionsClient -} - -type ApplicationsClient interface { - GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) - CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) - DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) - AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) - RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) -} - -type ServicePrincipalsClient interface { - CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) -} - -type ADGroupsClient interface { - AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) - RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) - GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) - ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) -} - -type RoleAssignmentsClient interface { - CreateRoleAssignment( - ctx context.Context, - scope string, - roleAssignmentName string, - parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) - DeleteRoleAssignmentByID(ctx context.Context, roleID string) (authorization.RoleAssignment, error) -} - -type RoleDefinitionsClient interface { - ListRoles(ctx context.Context, scope string, filter string) ([]authorization.RoleDefinition, error) - GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) -} - // provider is a concrete implementation of AzureProvider. In most cases it is a simple passthrough // to the appropriate client object. But if the response requires processing that is more practical // at this layer, the response signature may different from the Azure signature. type provider struct { settings *clientSettings - appClient ApplicationsClient + appClient api.ApplicationsClient spClient *graphrbac.ServicePrincipalsClient groupsClient *graphrbac.GroupsClient raClient *authorization.RoleAssignmentsClient @@ -77,7 +28,7 @@ type provider struct { } // newAzureProvider creates an azureProvider, backed by Azure client objects for underlying services. -func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords passwords) (AzureProvider, error) { +func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords api.Passwords) (api.AzureProvider, error) { // build clients that use the GraphRBAC endpoint graphAuthorizer, err := getAuthorizer(settings, settings.Environment.GraphEndpoint) if err != nil { @@ -117,14 +68,14 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords pa groupsClient.Authorizer = graphAuthorizer groupsClient.AddToUserAgent(userAgent) - var appClient ApplicationsClient + var appClient api.ApplicationsClient if useMsGraphApi { - graphApiAuthorizer, err := getAuthorizer(settings, defaultGraphMicrosoftComURI) + graphApiAuthorizer, err := getAuthorizer(settings, api.DefaultGraphMicrosoftComURI) if err != nil { return nil, err } - msGraphAppClient := newMSGraphApplicationClient(settings.SubscriptionID) + msGraphAppClient := api.NewGraphApplicationClient(settings.SubscriptionID) msGraphAppClient.Authorizer = graphApiAuthorizer msGraphAppClient.AddToUserAgent(userAgent) @@ -134,7 +85,7 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords pa aadGraphClient.Authorizer = graphAuthorizer aadGraphClient.AddToUserAgent(userAgent) - appClient = &aadGraphApplicationsClient{appClient: &aadGraphClient, passwords: passwords} + appClient = &api.ActiveDirectoryApplicatinClient{Client: &aadGraphClient, Passwords: passwords} } // build clients that use the Resource Manager endpoint @@ -189,11 +140,11 @@ func getAuthorizer(settings *clientSettings, resource string) (authorizer autore } // CreateApplication create a new Azure application object. -func (p *provider) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { +func (p *provider) CreateApplication(ctx context.Context, displayName string) (result api.ApplicationResult, err error) { return p.appClient.CreateApplication(ctx, displayName) } -func (p *provider) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { +func (p *provider) GetApplication(ctx context.Context, applicationObjectID string) (result api.ApplicationResult, err error) { return p.appClient.GetApplication(ctx, applicationObjectID) } @@ -203,7 +154,7 @@ func (p *provider) DeleteApplication(ctx context.Context, applicationObjectID st return p.appClient.DeleteApplication(ctx, applicationObjectID) } -func (p *provider) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { +func (p *provider) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result api.PasswordCredentialResult, err error) { return p.appClient.AddApplicationPassword(ctx, applicationObjectID, displayName, endDateTime) } @@ -285,134 +236,3 @@ func (p *provider) ListGroups(ctx context.Context, filter string) (result []grap return page.Values(), nil } - -type aadGraphApplicationsClient struct { - appClient *graphrbac.ApplicationsClient - passwords passwords -} - -func (a *aadGraphApplicationsClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { - app, err := a.appClient.Get(ctx, applicationObjectID) - if err != nil { - return ApplicationResult{}, err - } - - return ApplicationResult{ - AppID: app.AppID, - ID: app.ObjectID, - }, nil -} - -func (a *aadGraphApplicationsClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { - appURL := fmt.Sprintf("https://%s", displayName) - - app, err := a.appClient.Create(ctx, graphrbac.ApplicationCreateParameters{ - AvailableToOtherTenants: to.BoolPtr(false), - DisplayName: to.StringPtr(displayName), - Homepage: to.StringPtr(appURL), - IdentifierUris: to.StringSlicePtr([]string{appURL}), - }) - if err != nil { - return ApplicationResult{}, err - } - - return ApplicationResult{ - AppID: app.AppID, - ID: app.ObjectID, - }, nil -} - -func (a *aadGraphApplicationsClient) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { - return a.appClient.Delete(ctx, applicationObjectID) -} - -func (a *aadGraphApplicationsClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { - keyID, err := uuid.GenerateUUID() - if err != nil { - return PasswordCredentialResult{}, err - } - - // Key IDs are not secret, and they're a convenient way for an operator to identify Vault-generated - // passwords. These must be UUIDs, so the three leading bytes will be used as an indicator. - keyID = "ffffff" + keyID[6:] - - password, err := a.passwords.generate(ctx) - if err != nil { - return PasswordCredentialResult{}, err - } - - now := date.Time{Time: time.Now().UTC()} - cred := graphrbac.PasswordCredential{ - StartDate: &now, - EndDate: &endDateTime, - KeyID: to.StringPtr(keyID), - Value: to.StringPtr(password), - } - - // Load current credentials - resp, err := a.appClient.ListPasswordCredentials(ctx, applicationObjectID) - if err != nil { - return PasswordCredentialResult{}, errwrap.Wrapf("error fetching credentials: {{err}}", err) - } - curCreds := *resp.Value - - // Add and save credentials - curCreds = append(curCreds, cred) - - if _, err := a.appClient.UpdatePasswordCredentials(ctx, applicationObjectID, - graphrbac.PasswordCredentialsUpdateParameters{ - Value: &curCreds, - }, - ); err != nil { - if strings.Contains(err.Error(), "size of the object has exceeded its limit") { - err = errors.New("maximum number of Application passwords reached") - } - return PasswordCredentialResult{}, errwrap.Wrapf("error updating credentials: {{err}}", err) - } - - return PasswordCredentialResult{ - passwordCredential: passwordCredential{ - DisplayName: to.StringPtr(displayName), - StartDate: &now, - EndDate: &endDateTime, - KeyID: to.StringPtr(keyID), - SecretText: to.StringPtr(password), - }, - }, nil -} - -func (a *aadGraphApplicationsClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { - // Load current credentials - resp, err := a.appClient.ListPasswordCredentials(ctx, applicationObjectID) - if err != nil { - return autorest.Response{}, errwrap.Wrapf("error fetching credentials: {{err}}", err) - } - curCreds := *resp.Value - - // Remove credential - found := false - for i := range curCreds { - if to.String(curCreds[i].KeyID) == keyID { - curCreds[i] = curCreds[len(curCreds)-1] - curCreds = curCreds[:len(curCreds)-1] - found = true - break - } - } - - // KeyID is not present, so nothing to do - if !found { - return autorest.Response{}, nil - } - - // Save new credentials list - if _, err := a.appClient.UpdatePasswordCredentials(ctx, applicationObjectID, - graphrbac.PasswordCredentialsUpdateParameters{ - Value: &curCreds, - }, - ); err != nil { - return autorest.Response{}, errwrap.Wrapf("error updating credentials: {{err}}", err) - } - - return autorest.Response{}, nil -} From 58d747bfae3fb3373b89547ee381ef2a0b8b3a4a Mon Sep 17 00:00:00 2001 From: Michael Golowka <72365+pcman312@users.noreply.github.com> Date: Tue, 14 Sep 2021 15:50:50 -0600 Subject: [PATCH 3/8] Add group API support for ms-graph --- api/api.go | 38 +- ...{aad_application.go => application_aad.go} | 12 +- api/application_msgraph.go | 524 ++++++++++++++++++ api/graph_application.go | 308 ---------- api/groups.go | 17 + api/groups_aad.go | 72 +++ backend.go | 5 +- backend_test.go | 259 --------- client.go | 26 +- client_test.go | 6 +- path_roles.go | 9 +- provider.go | 127 ++--- provider_mock_test.go | 267 +++++++++ 13 files changed, 983 insertions(+), 687 deletions(-) rename api/{aad_application.go => application_aad.go} (81%) create mode 100644 api/application_msgraph.go delete mode 100644 api/graph_application.go create mode 100644 api/groups.go create mode 100644 api/groups_aad.go create mode 100644 provider_mock_test.go diff --git a/api/api.go b/api/api.go index f4fd09a8..eb7777ea 100644 --- a/api/api.go +++ b/api/api.go @@ -9,50 +9,34 @@ import ( "github.com/Azure/go-autorest/autorest/date" ) -// AzureProvider is an interface to access underlying Azure client objects and supporting services. -// Where practical the original function signature is preserved. client provides higher +// AzureProvider is an interface to access underlying Azure Client objects and supporting services. +// Where practical the original function signature is preserved. Client provides higher // level operations atop AzureProvider. type AzureProvider interface { ApplicationsClient - ServicePrincipalsClient - ADGroupsClient - RoleAssignmentsClient - RoleDefinitionsClient -} + GroupsClient -type ApplicationsClient interface { - GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) - CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) - DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) - AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) - RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) -} - -type ServicePrincipalsClient interface { CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) -} -type ADGroupsClient interface { - AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) - RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) - GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) - ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) -} - -type RoleAssignmentsClient interface { CreateRoleAssignment( ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) DeleteRoleAssignmentByID(ctx context.Context, roleID string) (authorization.RoleAssignment, error) -} -type RoleDefinitionsClient interface { ListRoles(ctx context.Context, scope string, filter string) ([]authorization.RoleDefinition, error) GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) } +type ApplicationsClient interface { + GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) + CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) + DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) + AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) + RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) +} + type PasswordCredential struct { DisplayName *string `json:"displayName"` // StartDate - Start date. diff --git a/api/aad_application.go b/api/application_aad.go similarity index 81% rename from api/aad_application.go rename to api/application_aad.go index 6820c715..2039ceb9 100644 --- a/api/aad_application.go +++ b/api/application_aad.go @@ -15,12 +15,12 @@ import ( "github.com/hashicorp/go-uuid" ) -type ActiveDirectoryApplicatinClient struct { +type ActiveDirectoryApplicationClient struct { Client *graphrbac.ApplicationsClient Passwords Passwords } -func (a *ActiveDirectoryApplicatinClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { +func (a *ActiveDirectoryApplicationClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { app, err := a.Client.Get(ctx, applicationObjectID) if err != nil { return ApplicationResult{}, err @@ -32,7 +32,7 @@ func (a *ActiveDirectoryApplicatinClient) GetApplication(ctx context.Context, ap }, nil } -func (a *ActiveDirectoryApplicatinClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { +func (a *ActiveDirectoryApplicationClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { appURL := fmt.Sprintf("https://%s", displayName) app, err := a.Client.Create(ctx, graphrbac.ApplicationCreateParameters{ @@ -51,11 +51,11 @@ func (a *ActiveDirectoryApplicatinClient) CreateApplication(ctx context.Context, }, nil } -func (a *ActiveDirectoryApplicatinClient) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { +func (a *ActiveDirectoryApplicationClient) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { return a.Client.Delete(ctx, applicationObjectID) } -func (a *ActiveDirectoryApplicatinClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { +func (a *ActiveDirectoryApplicationClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { keyID, err := uuid.GenerateUUID() if err != nil { return PasswordCredentialResult{}, err @@ -110,7 +110,7 @@ func (a *ActiveDirectoryApplicatinClient) AddApplicationPassword(ctx context.Con }, nil } -func (a *ActiveDirectoryApplicatinClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { +func (a *ActiveDirectoryApplicationClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { // Load current credentials resp, err := a.Client.ListPasswordCredentials(ctx, applicationObjectID) if err != nil { diff --git a/api/application_msgraph.go b/api/application_msgraph.go new file mode 100644 index 00000000..7cfb5860 --- /dev/null +++ b/api/application_msgraph.go @@ -0,0 +1,524 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" +) + +const ( + // DefaultGraphMicrosoftComURI is the default URI used for the service MS Graph API + DefaultGraphMicrosoftComURI = "https://graph.microsoft.com" +) + +var _ ApplicationsClient = (*AppClient)(nil) +var _ GroupsClient = (*AppClient)(nil) + +type AppClient struct { + client authorization.BaseClient +} + +func NewMSGraphApplicationClient(subscriptionId string, userAgentExtension string, auth autorest.Authorizer) (*AppClient, error) { + client := authorization.NewWithBaseURI(DefaultGraphMicrosoftComURI, subscriptionId) + client.Authorizer = auth + + if userAgentExtension != "" { + err := client.AddToUserAgent(userAgentExtension) + if err != nil { + return nil, fmt.Errorf("failed to add extension to user agent") + } + } + + ac := &AppClient{ + client: client, + } + return ac, nil +} + +func (c *AppClient) AddToUserAgent(extension string) error { + return c.client.AddToUserAgent(extension) +} + +func (c *AppClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { + req, err := c.getApplicationPreparer(ctx, applicationObjectID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "GetApplication", nil, "Failure preparing request") + return + } + + resp, err := c.getApplicationSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "GetApplication", resp, "Failure sending request") + return + } + + result, err = c.getApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "GetApplication", resp, "Failure responding to request") + } + + return +} + +// CreateApplication create a new Azure application object. +func (c *AppClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { + req, err := c.createApplicationPreparer(ctx, displayName) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", nil, "Failure preparing request") + return + } + + resp, err := c.createApplicationSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", resp, "Failure sending request") + return + } + + result, err = c.createApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", resp, "Failure responding to request") + } + + return +} + +// DeleteApplication deletes an Azure application object. +// This will in turn remove the service principal (but not the role assignments). +func (c *AppClient) DeleteApplication(ctx context.Context, applicationObjectID string) (result autorest.Response, err error) { + req, err := c.deleteApplicationPreparer(ctx, applicationObjectID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", nil, "Failure preparing request") + return + } + + resp, err := c.deleteApplicationSender(req) + if err != nil { + result.Response = resp + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", resp, "Failure sending request") + return + } + + result, err = c.deleteApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", resp, "Failure responding to request") + } + + return +} + +func (c *AppClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { + req, err := c.addPasswordPreparer(ctx, applicationObjectID, displayName, endDateTime) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", nil, "Failure preparing request") + return + } + + resp, err := c.addPasswordSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", resp, "Failure sending request") + return + } + + result, err = c.addPasswordResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", resp, "Failure responding to request") + } + + return +} + +func (c *AppClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + req, err := c.removePasswordPreparer(ctx, applicationObjectID, keyID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", nil, "Failure preparing request") + return + } + + resp, err := c.removePasswordSender(req) + if err != nil { + result.Response = resp + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", resp, "Failure sending request") + return + } + + result, err = c.removePasswordResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", resp, "Failure responding to request") + } + + return +} + +func (c AppClient) getApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsGet(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}", pathParameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) getApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) getApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return result, err +} + +func (c AppClient) addPasswordPreparer(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + parameters := struct { + PasswordCredential *PasswordCredential `json:"passwordCredential"` + }{ + PasswordCredential: &PasswordCredential{ + DisplayName: to.StringPtr(displayName), + EndDate: &endDateTime, + }, + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}/addPassword", pathParameters), + autorest.WithJSON(parameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) addPasswordSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) addPasswordResponder(resp *http.Response) (result PasswordCredentialResult, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return +} + +func (c AppClient) removePasswordPreparer(ctx context.Context, applicationObjectID string, keyID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + parameters := struct { + KeyID string `json:"keyId"` + }{ + KeyID: keyID, + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}/removePassword", pathParameters), + autorest.WithJSON(parameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) removePasswordSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) removePasswordResponder(resp *http.Response) (result autorest.Response, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusNoContent), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = resp + return +} + +func (c AppClient) createApplicationPreparer(ctx context.Context, displayName string) (*http.Request, error) { + parameters := struct { + DisplayName *string `json:"displayName"` + }{ + DisplayName: to.StringPtr(displayName), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPath("/v1.0/applications"), + autorest.WithJSON(parameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) createApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) createApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusCreated), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return +} + +func (c AppClient) deleteApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsDelete(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}", pathParameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) deleteApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) deleteApplicationResponder(resp *http.Response) (result autorest.Response, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusNoContent), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = resp + return +} + +func (c AppClient) AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error { + if groupObjectID == "" { + return fmt.Errorf("missing groupObjectID") + } + pathParams := map[string]interface{}{ + "groupObjectID": groupObjectID, + } + body := map[string]interface{}{ + "@odata.id": fmt.Sprintf("%s/v1.0/directoryObjects/%s", DefaultGraphMicrosoftComURI, memberObjectID), + } + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/groups/{groupObjectID}/members/$ref", pathParams), + autorest.WithJSON(body), + c.client.WithAuthorization()) + req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) + if err != nil { + return err + } + + sender := autorest.GetSendDecorators(req.Context(), + autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), + ) + resp, err := autorest.SendWithSender(c.client, req, sender...) + if err != nil { + return err + } + + respBody := map[string]interface{}{} + + return autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), + autorest.ByUnmarshallingJSON(&respBody), + autorest.ByClosing(), + ) +} + +func (c AppClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error { + if groupObjectID == "" { + return fmt.Errorf("missing groupObjectID") + } + if memberObjectID == "" { + return fmt.Errorf("missing memberObjectID") + } + pathParams := map[string]interface{}{ + "groupObjectID": groupObjectID, + "memberObjectID": memberObjectID, + } + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsDelete(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/groups/{groupObjectID}/members/{memberObjectID}/$ref", pathParams), + c.client.WithAuthorization()) + req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) + if err != nil { + return err + } + + sender := autorest.GetSendDecorators(req.Context(), + autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), + ) + resp, err := autorest.SendWithSender(c.client, req, sender...) + if err != nil { + return err + } + + return autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), + autorest.ByClosing(), + ) +} + +// groupResponse is a struct representation of the data we care about coming back from +// the ms-graph API. This is not the same as ADGroup because this information is +// slightly different from the AAD implementation and there should be an abstraction +// between the ms-graph API itself and the API this package presents. +type groupResponse struct { + ID string `json:"id"` + DisplayName string `json:"displayName"` +} + +func (c AppClient) GetGroup(ctx context.Context, groupID string) (result ADGroup, err error) { + if groupID == "" { + return ADGroup{}, fmt.Errorf("missing groupID") + } + pathParams := map[string]interface{}{ + "groupID": groupID, + } + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsGet(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/groups/{groupID}", pathParams), + c.client.WithAuthorization()) + req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) + if err != nil { + return ADGroup{}, err + } + + sender := autorest.GetSendDecorators(req.Context(), + autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), + ) + resp, err := autorest.SendWithSender(c.client, req, sender...) + if err != nil { + return ADGroup{}, err + } + + groupResp := groupResponse{} + + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), + autorest.ByUnmarshallingJSON(&groupResp), + autorest.ByClosing(), + ) + if err != nil { + return ADGroup{}, err + } + + group := ADGroup{ + ID: groupResp.ID, + DisplayName: groupResp.DisplayName, + } + + return group, nil +} + +// listGroupsResponse is a struct representation of the data we care about +// coming back from the ms-graph API +type listGroupsResponse struct { + Groups []groupResponse `json:"value"` +} + +func (c AppClient) ListGroups(ctx context.Context, filter string) (result []ADGroup, err error) { + filterArgs := url.Values{} + if filter != "" { + filterArgs.Set("$filter", filter) + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsGet(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPath(fmt.Sprintf("/v1.0/groups?%s", filterArgs.Encode())), + c.client.WithAuthorization()) + req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) + if err != nil { + return nil, err + } + + sender := autorest.GetSendDecorators(req.Context(), + autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), + ) + resp, err := autorest.SendWithSender(c.client, req, sender...) + if err != nil { + return nil, err + } + + groupsResp := listGroupsResponse{} + + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), + autorest.ByUnmarshallingJSON(&groupsResp), + autorest.ByClosing(), + ) + if err != nil { + return nil, err + } + + groups := []ADGroup{} + for _, rawGroup := range groupsResp.Groups { + if rawGroup.ID == "" { + return nil, fmt.Errorf("missing group ID from response") + } + + group := ADGroup{ + ID: rawGroup.ID, + DisplayName: rawGroup.DisplayName, + } + groups = append(groups, group) + } + return groups, nil +} diff --git a/api/graph_application.go b/api/graph_application.go deleted file mode 100644 index 8765ab46..00000000 --- a/api/graph_application.go +++ /dev/null @@ -1,308 +0,0 @@ -package api - -import ( - "context" - "net/http" - - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/Azure/go-autorest/autorest/date" - "github.com/Azure/go-autorest/autorest/to" -) - -const ( - // defaultGraphMicrosoftComURI is the default URI used for the service MS Graph API - DefaultGraphMicrosoftComURI = "https://graph.microsoft.com" -) - -type AppClient struct { - authorization.BaseClient -} - -func NewGraphApplicationClient(subscriptionId string) AppClient { - return AppClient{authorization.NewWithBaseURI(DefaultGraphMicrosoftComURI, subscriptionId)} -} - -func (p *AppClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { - req, err := p.getApplicationPreparer(ctx, applicationObjectID) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "GetApplication", nil, "Failure preparing request") - return - } - - resp, err := p.getApplicationSender(req) - if err != nil { - result.Response = autorest.Response{Response: resp} - err = autorest.NewErrorWithError(err, "provider", "GetApplication", resp, "Failure sending request") - return - } - - result, err = p.getApplicationResponder(resp) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "GetApplication", resp, "Failure responding to request") - } - - return -} - -// CreateApplication create a new Azure application object. -func (p *AppClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { - req, err := p.createApplicationPreparer(ctx, displayName) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "CreateApplication", nil, "Failure preparing request") - return - } - - resp, err := p.createApplicationSender(req) - if err != nil { - result.Response = autorest.Response{Response: resp} - err = autorest.NewErrorWithError(err, "provider", "CreateApplication", resp, "Failure sending request") - return - } - - result, err = p.createApplicationResponder(resp) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "CreateApplication", resp, "Failure responding to request") - } - - return -} - -// DeleteApplication deletes an Azure application object. -// This will in turn remove the service principal (but not the role assignments). -func (p *AppClient) DeleteApplication(ctx context.Context, applicationObjectID string) (result autorest.Response, err error) { - req, err := p.deleteApplicationPreparer(ctx, applicationObjectID) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", nil, "Failure preparing request") - return - } - - resp, err := p.deleteApplicationSender(req) - if err != nil { - result.Response = resp - err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", resp, "Failure sending request") - return - } - - result, err = p.deleteApplicationResponder(resp) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", resp, "Failure responding to request") - } - - return -} - -func (p *AppClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { - req, err := p.addPasswordPreparer(ctx, applicationObjectID, displayName, endDateTime) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", nil, "Failure preparing request") - return - } - - resp, err := p.addPasswordSender(req) - if err != nil { - result.Response = autorest.Response{Response: resp} - err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", resp, "Failure sending request") - return - } - - result, err = p.addPasswordResponder(resp) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", resp, "Failure responding to request") - } - - return -} - -func (p *AppClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { - req, err := p.removePasswordPreparer(ctx, applicationObjectID, keyID) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", nil, "Failure preparing request") - return - } - - resp, err := p.removePasswordSender(req) - if err != nil { - result.Response = resp - err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", resp, "Failure sending request") - return - } - - result, err = p.removePasswordResponder(resp) - if err != nil { - err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", resp, "Failure responding to request") - } - - return -} - -func (client AppClient) getApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { - pathParameters := map[string]interface{}{ - "applicationObjectId": autorest.Encode("path", applicationObjectID), - } - - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), - autorest.AsGet(), - autorest.WithBaseURL(client.BaseURI), - autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}", pathParameters), - client.Authorizer.WithAuthorization()) - return preparer.Prepare((&http.Request{}).WithContext(ctx)) -} - -func (client AppClient) getApplicationSender(req *http.Request) (*http.Response, error) { - sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) - return autorest.SendWithSender(client, req, sd...) -} - -func (client AppClient) getApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { - err = autorest.Respond( - resp, - client.ByInspecting(), - azure.WithErrorUnlessStatusCode(http.StatusOK), - autorest.ByUnmarshallingJSON(&result), - autorest.ByClosing()) - result.Response = autorest.Response{Response: resp} - return -} - -func (client AppClient) addPasswordPreparer(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (*http.Request, error) { - pathParameters := map[string]interface{}{ - "applicationObjectId": autorest.Encode("path", applicationObjectID), - } - - parameters := struct { - PasswordCredential *PasswordCredential `json:"passwordCredential"` - }{ - PasswordCredential: &PasswordCredential{ - DisplayName: to.StringPtr(displayName), - EndDate: &endDateTime, - }, - } - - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), - autorest.AsPost(), - autorest.WithBaseURL(client.BaseURI), - autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}/addPassword", pathParameters), - autorest.WithJSON(parameters), - client.Authorizer.WithAuthorization()) - return preparer.Prepare((&http.Request{}).WithContext(ctx)) -} - -func (client AppClient) addPasswordSender(req *http.Request) (*http.Response, error) { - sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) - return autorest.SendWithSender(client, req, sd...) -} - -func (client AppClient) addPasswordResponder(resp *http.Response) (result PasswordCredentialResult, err error) { - err = autorest.Respond( - resp, - client.ByInspecting(), - azure.WithErrorUnlessStatusCode(http.StatusOK), - autorest.ByUnmarshallingJSON(&result), - autorest.ByClosing()) - result.Response = autorest.Response{Response: resp} - return -} - -func (client AppClient) removePasswordPreparer(ctx context.Context, applicationObjectID string, keyID string) (*http.Request, error) { - pathParameters := map[string]interface{}{ - "applicationObjectId": autorest.Encode("path", applicationObjectID), - } - - parameters := struct { - KeyID string `json:"keyId"` - }{ - KeyID: keyID, - } - - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), - autorest.AsPost(), - autorest.WithBaseURL(client.BaseURI), - autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}/removePassword", pathParameters), - autorest.WithJSON(parameters), - client.Authorizer.WithAuthorization()) - return preparer.Prepare((&http.Request{}).WithContext(ctx)) -} - -func (client AppClient) removePasswordSender(req *http.Request) (*http.Response, error) { - sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) - return autorest.SendWithSender(client, req, sd...) -} - -func (client AppClient) removePasswordResponder(resp *http.Response) (result autorest.Response, err error) { - err = autorest.Respond( - resp, - client.ByInspecting(), - azure.WithErrorUnlessStatusCode(http.StatusNoContent), - autorest.ByUnmarshallingJSON(&result), - autorest.ByClosing()) - result.Response = resp - return -} - -func (client AppClient) createApplicationPreparer(ctx context.Context, displayName string) (*http.Request, error) { - parameters := struct { - DisplayName *string `json:"displayName"` - }{ - DisplayName: to.StringPtr(displayName), - } - - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), - autorest.AsPost(), - autorest.WithBaseURL(client.BaseURI), - autorest.WithPath("/v1.0/applications"), - autorest.WithJSON(parameters), - client.Authorizer.WithAuthorization()) - return preparer.Prepare((&http.Request{}).WithContext(ctx)) -} - -func (client AppClient) createApplicationSender(req *http.Request) (*http.Response, error) { - sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) - return autorest.SendWithSender(client, req, sd...) -} - -func (client AppClient) createApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { - err = autorest.Respond( - resp, - client.ByInspecting(), - azure.WithErrorUnlessStatusCode(http.StatusCreated), - autorest.ByUnmarshallingJSON(&result), - autorest.ByClosing()) - result.Response = autorest.Response{Response: resp} - return -} - -func (client AppClient) deleteApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { - pathParameters := map[string]interface{}{ - "applicationObjectId": autorest.Encode("path", applicationObjectID), - } - - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), - autorest.AsDelete(), - autorest.WithBaseURL(client.BaseURI), - autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}", pathParameters), - client.Authorizer.WithAuthorization()) - return preparer.Prepare((&http.Request{}).WithContext(ctx)) -} - -func (client AppClient) deleteApplicationSender(req *http.Request) (*http.Response, error) { - sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(client.RetryAttempts, client.RetryDuration, autorest.StatusCodesForRetry...)) - return autorest.SendWithSender(client, req, sd...) -} - -func (client AppClient) deleteApplicationResponder(resp *http.Response) (result autorest.Response, err error) { - err = autorest.Respond( - resp, - client.ByInspecting(), - azure.WithErrorUnlessStatusCode(http.StatusNoContent), - autorest.ByUnmarshallingJSON(&result), - autorest.ByClosing()) - result.Response = resp - return -} diff --git a/api/groups.go b/api/groups.go new file mode 100644 index 00000000..b9b0ee10 --- /dev/null +++ b/api/groups.go @@ -0,0 +1,17 @@ +package api + +import ( + "context" +) + +type GroupsClient interface { + AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error + RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error + GetGroup(ctx context.Context, objectID string) (result ADGroup, err error) + ListGroups(ctx context.Context, filter string) (result []ADGroup, err error) +} + +type ADGroup struct { + ID string + DisplayName string +} diff --git a/api/groups_aad.go b/api/groups_aad.go new file mode 100644 index 00000000..f784fc10 --- /dev/null +++ b/api/groups_aad.go @@ -0,0 +1,72 @@ +package api + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/to" +) + +var _ GroupsClient = (*ActiveDirectoryApplicationGroupsClient)(nil) + +type aadGroupsClient interface { + AddMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) + RemoveMember(ctx context.Context, groupObjectID string, memberObjectID string) (result autorest.Response, err error) + Get(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) + List(ctx context.Context, filter string) (result graphrbac.GroupListResultPage, err error) +} + +type ActiveDirectoryApplicationGroupsClient struct { + BaseURI string + TenantID string + Client aadGroupsClient +} + +func (a ActiveDirectoryApplicationGroupsClient) AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error { + uri := fmt.Sprintf("%s/%s/directoryObjects/%s", a.BaseURI, a.TenantID, memberObjectID) + aadParams := graphrbac.GroupAddMemberParameters{ + URL: to.StringPtr(uri), + } + _, err := a.Client.AddMember(ctx, groupObjectID, aadParams) + return err +} + +func (a ActiveDirectoryApplicationGroupsClient) RemoveGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error { + _, err := a.Client.RemoveMember(ctx, groupObjectID, memberObjectID) + return err +} + +func (a ActiveDirectoryApplicationGroupsClient) GetGroup(ctx context.Context, objectID string) (result ADGroup, err error) { + resp, err := a.Client.Get(ctx, objectID) + if err != nil { + return ADGroup{}, err + } + + grp := getGroupFromRBAC(resp) + + return grp, nil +} + +func getGroupFromRBAC(resp graphrbac.ADGroup) ADGroup { + grp := ADGroup{ + ID: *resp.ObjectID, + DisplayName: *resp.DisplayName, + } + return grp +} + +func (a ActiveDirectoryApplicationGroupsClient) ListGroups(ctx context.Context, filter string) (result []ADGroup, err error) { + resp, err := a.Client.List(ctx, filter) + if err != nil { + return nil, err + } + + grps := []ADGroup{} + for _, aadGrp := range resp.Values() { + grp := getGroupFromRBAC(aadGrp) + grps = append(grps, grp) + } + return grps, nil +} diff --git a/backend.go b/backend.go index 034d5ce6..2815d039 100644 --- a/backend.go +++ b/backend.go @@ -90,16 +90,15 @@ func (b *azureSecretBackend) invalidate(ctx context.Context, key string) { func (b *azureSecretBackend) getClient(ctx context.Context, s logical.Storage) (*client, error) { b.lock.RLock() - unlockFunc := b.lock.RUnlock - defer func() { unlockFunc() }() if b.client.Valid() { + b.lock.RUnlock() return b.client, nil } b.lock.RUnlock() b.lock.Lock() - unlockFunc = b.lock.Unlock + defer b.lock.Unlock() if b.client.Valid() { return b.client, nil diff --git a/backend_test.go b/backend_test.go index 1f964f9b..c22cef3e 100644 --- a/backend_test.go +++ b/backend_test.go @@ -2,20 +2,9 @@ package azuresecrets import ( "context" - "errors" - "fmt" - "regexp" - "strings" - "sync" "testing" "time" - "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" - "github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/date" - "github.com/Azure/go-autorest/autorest/to" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/helper/logging" @@ -67,251 +56,3 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical return b, config.StorageView } - -// mockProvider is a Provider that provides stubs and simple, deterministic responses. -type mockProvider struct { - subscriptionID string - applications map[string]bool - passwords map[string]api.PasswordCredential - failNextCreateApplication bool - lock sync.Mutex -} - -// errMockProvider simulates a normal provider which fails to associate a role, -// returning an error -type errMockProvider struct { - *mockProvider -} - -// CreateRoleAssignment for the errMockProvider intentionally fails -func (e *errMockProvider) CreateRoleAssignment(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { - return authorization.RoleAssignment{}, errors.New("PrincipalNotFound") -} - -// GetApplication for the errMockProvider only returns an application if that -// key is found, unlike mockProvider which returns the same application object -// id each time. Existing tests depend on the mockProvider behavior, which is -// why errMockProvider has it's own version. -func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (api.ApplicationResult, error) { - for s := range e.applications { - if s == applicationObjectID { - return api.ApplicationResult{ - AppID: to.StringPtr(s), - }, nil - } - } - return api.ApplicationResult{}, errors.New("not found") -} - -func newErrMockProvider() api.AzureProvider { - return &errMockProvider{ - mockProvider: &mockProvider{ - subscriptionID: generateUUID(), - applications: make(map[string]bool), - passwords: make(map[string]api.PasswordCredential), - }, - } -} - -func newMockProvider() api.AzureProvider { - return &mockProvider{ - subscriptionID: generateUUID(), - applications: make(map[string]bool), - passwords: make(map[string]api.PasswordCredential), - } -} - -// ListRoles returns a single fake role based on the inbound filter -func (m *mockProvider) ListRoles(ctx context.Context, scope string, filter string) (result []authorization.RoleDefinition, err error) { - reRoleName := regexp.MustCompile("roleName eq '(.*)'") - - match := reRoleName.FindAllStringSubmatch(filter, -1) - if len(match) > 0 { - name := match[0][1] - if name == "multiple" { - return []authorization.RoleDefinition{ - { - ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s-1", name)), - RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ - RoleName: to.StringPtr(name), - }, - }, - { - ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s-2", name)), - RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ - RoleName: to.StringPtr(name), - }, - }, - }, nil - } - return []authorization.RoleDefinition{ - { - ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s", name)), - RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ - RoleName: to.StringPtr(name), - }, - }, - }, nil - } - - return []authorization.RoleDefinition{}, nil -} - -// GetRoleByID will returns a fake role definition from the povided ID -// Assumes an ID format of: .*FAKE_ROLE-{rolename} -func (m *mockProvider) GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) { - d := authorization.RoleDefinition{} - s := strings.Split(roleID, "FAKE_ROLE-") - if len(s) > 1 { - d.ID = to.StringPtr(roleID) - d.RoleDefinitionProperties = &authorization.RoleDefinitionProperties{ - RoleName: to.StringPtr(s[1]), - } - } - - return d, nil -} - -func (m *mockProvider) CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) { - return graphrbac.ServicePrincipal{ - ObjectID: to.StringPtr(generateUUID()), - }, nil -} - -func (m *mockProvider) CreateApplication(ctx context.Context, displayName string) (api.ApplicationResult, error) { - if m.failNextCreateApplication { - m.failNextCreateApplication = false - return api.ApplicationResult{}, errors.New("Mock: fail to create application") - } - appObjID := generateUUID() - - m.lock.Lock() - defer m.lock.Unlock() - - m.applications[appObjID] = true - - return api.ApplicationResult{ - AppID: to.StringPtr(generateUUID()), - ID: &appObjID, - }, nil -} - -func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (api.ApplicationResult, error) { - return api.ApplicationResult{ - AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"), - }, nil -} - -func (m *mockProvider) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { - delete(m.applications, applicationObjectID) - return autorest.Response{}, nil -} - -func (m *mockProvider) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result api.PasswordCredentialResult, err error) { - keyID := generateUUID() - cred := api.PasswordCredential{ - DisplayName: to.StringPtr(displayName), - StartDate: &date.Time{Time: time.Now()}, - EndDate: &endDateTime, - KeyID: to.StringPtr(keyID), - SecretText: to.StringPtr(generateUUID()), - } - - m.lock.Lock() - defer m.lock.Unlock() - m.passwords[keyID] = cred - - return api.PasswordCredentialResult{ - PasswordCredential: cred, - }, nil -} - -func (m *mockProvider) RemoveApplicationPassword(background context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { - m.lock.Lock() - defer m.lock.Unlock() - - delete(m.passwords, keyID) - - return autorest.Response{}, nil -} - -func (m *mockProvider) appExists(s string) bool { - return m.applications[s] -} - -func (m *mockProvider) passwordExists(s string) bool { - _, ok := m.passwords[s] - return ok -} - -func (m *mockProvider) VMGet(ctx context.Context, resourceGroupName string, VMName string, expand compute.InstanceViewTypes) (result compute.VirtualMachine, err error) { - return compute.VirtualMachine{}, nil -} - -func (m *mockProvider) VMUpdate(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachineUpdate) (result compute.VirtualMachinesUpdateFuture, err error) { - return compute.VirtualMachinesUpdateFuture{}, nil -} - -func (m *mockProvider) CreateRoleAssignment(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { - return authorization.RoleAssignment{ - ID: to.StringPtr(generateUUID()), - }, nil -} - -func (m *mockProvider) DeleteRoleAssignmentByID(ctx context.Context, roleID string) (result authorization.RoleAssignment, err error) { - return authorization.RoleAssignment{}, nil -} - -// AddGroupMember adds a member to a AAD Group. -func (m *mockProvider) AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) { - return autorest.Response{}, nil -} - -// RemoveGroupMember removes a member from a AAD Group. -func (m *mockProvider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) { - return autorest.Response{}, nil -} - -// GetGroup gets group information from the directory. -func (m *mockProvider) GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) { - g := graphrbac.ADGroup{ - ObjectID: to.StringPtr(objectID), - } - s := strings.Split(objectID, "FAKE_GROUP-") - if len(s) > 1 { - g.DisplayName = to.StringPtr(s[1]) - } - - return g, nil -} - -// ListGroups gets list of groups for the current tenant. -func (m *mockProvider) ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) { - reGroupName := regexp.MustCompile("displayName eq '(.*)'") - - match := reGroupName.FindAllStringSubmatch(filter, -1) - if len(match) > 0 { - name := match[0][1] - if name == "multiple" { - return []graphrbac.ADGroup{ - { - ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name)), - DisplayName: to.StringPtr(name), - }, - { - ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-2", name)), - DisplayName: to.StringPtr(name), - }, - }, nil - } - - return []graphrbac.ADGroup{ - { - ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name)), - DisplayName: to.StringPtr(name), - }, - }, nil - } - - return []graphrbac.ADGroup{}, nil -} diff --git a/client.go b/client.go index f1047b09..18ecb482 100644 --- a/client.go +++ b/client.go @@ -202,16 +202,7 @@ func (c *client) unassignRoles(ctx context.Context, roleIDs []string) error { func (c *client) addGroupMemberships(ctx context.Context, sp *graphrbac.ServicePrincipal, groups []*AzureGroup) error { for _, group := range groups { _, err := retry(ctx, func() (interface{}, bool, error) { - _, err := c.provider.AddGroupMember(ctx, group.ObjectID, - graphrbac.GroupAddMemberParameters{ - URL: to.StringPtr( - fmt.Sprintf("%s%s/directoryObjects/%s", - c.settings.Environment.GraphEndpoint, - c.settings.TenantID, - *sp.ObjectID, - ), - ), - }) + err := c.provider.AddGroupMember(ctx, group.ObjectID, *sp.ObjectID) // Propagation delays within Azure can cause this error occasionally, so don't quit on it. if err != nil && strings.Contains(err.Error(), "Request_ResourceNotFound") { @@ -237,7 +228,7 @@ func (c *client) removeGroupMemberships(ctx context.Context, servicePrincipalObj var merr *multierror.Error for _, id := range groupIDs { - if _, err := c.provider.RemoveGroupMember(ctx, servicePrincipalObjectID, id); err != nil { + if err := c.provider.RemoveGroupMember(ctx, servicePrincipalObjectID, id); err != nil { merr = multierror.Append(merr, errwrap.Wrapf("error removing group membership: {{err}}", err)) } } @@ -263,7 +254,7 @@ func (c *client) findRoles(ctx context.Context, roleName string) ([]authorizatio // findGroups is used to find a group by name. It returns all groups matching // the passsed name. -func (c *client) findGroups(ctx context.Context, groupName string) ([]graphrbac.ADGroup, error) { +func (c *client) findGroups(ctx context.Context, groupName string) ([]api.ADGroup, error) { return c.provider.ListGroups(ctx, fmt.Sprintf("displayName eq '%s'", groupName)) } @@ -337,10 +328,13 @@ func retry(ctx context.Context, f func() (interface{}, bool, error)) (interface{ } rng := rand.New(rand.NewSource(time.Now().UnixNano())) + var lastErr error for { - if result, done, err := f(); done { + result, done, err := f() + if done { return result, err } + lastErr = err delay := time.Duration(2000+rng.Intn(6000)) * time.Millisecond delayTimer.Reset(delay) @@ -349,7 +343,11 @@ func retry(ctx context.Context, f func() (interface{}, bool, error)) (interface{ case <-delayTimer.C: // Retry loop case <-ctx.Done(): - return nil, fmt.Errorf("retry failed: %w", ctx.Err()) + err := lastErr + if err == nil { + err = ctx.Err() + } + return nil, fmt.Errorf("retry failed: %w", err) } } } diff --git a/client_test.go b/client_test.go index 0382fc09..4c6c2887 100644 --- a/client_test.go +++ b/client_test.go @@ -83,7 +83,7 @@ func TestRetry(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { - time.Sleep(7 * time.Second) + time.Sleep(1 * time.Second) cancel() }() @@ -92,8 +92,8 @@ func TestRetry(t *testing.T) { return nil, false, nil }) elapsed := time.Now().Sub(start).Seconds() - if elapsed < 6 || elapsed > 8 { - t.Fatalf("expected time of ~7 seconds, got: %f", elapsed) + if elapsed < 0 || elapsed > 2 { + t.Fatalf("expected time of ~1 second, got: %f", elapsed) } if err == nil { diff --git a/path_roles.go b/path_roles.go index 4b88b5ed..9160749b 100644 --- a/path_roles.go +++ b/path_roles.go @@ -8,9 +8,9 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" "github.com/Azure/go-autorest/autorest/to" "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" @@ -238,7 +238,7 @@ func (b *azureSecretBackend) pathRoleUpdate(ctx context.Context, req *logical.Re // update and verify Azure groups, including looking up each group by ID or name. groupSet := make(map[string]bool) for _, r := range role.AzureGroups { - var groupDef graphrbac.ADGroup + var groupDef api.ADGroup if r.ObjectID != "" { groupDef, err = client.provider.GetGroup(ctx, r.ObjectID) if err != nil { @@ -260,9 +260,8 @@ func (b *azureSecretBackend) pathRoleUpdate(ctx context.Context, req *logical.Re groupDef = defs[0] } - groupDefID := to.String(groupDef.ObjectID) - groupDefName := to.String(groupDef.DisplayName) - r.GroupName, r.ObjectID = groupDefName, groupDefID + r.ObjectID = groupDef.ID + r.GroupName = groupDef.DisplayName if groupSet[r.ObjectID] { return logical.ErrorResponse("duplicate object_id '%s'", r.ObjectID), nil diff --git a/provider.go b/provider.go index 63d18fb5..9f071c2e 100644 --- a/provider.go +++ b/provider.go @@ -14,6 +14,8 @@ import ( "github.com/hashicorp/vault/sdk/version" ) +var _ api.AzureProvider = (*provider)(nil) + // provider is a concrete implementation of AzureProvider. In most cases it is a simple passthrough // to the appropriate client object. But if the response requires processing that is more practical // at this layer, the response signature may different from the Azure signature. @@ -22,7 +24,7 @@ type provider struct { appClient api.ApplicationsClient spClient *graphrbac.ServicePrincipalsClient - groupsClient *graphrbac.GroupsClient + groupsClient api.GroupsClient raClient *authorization.RoleAssignmentsClient rdClient *authorization.RoleDefinitionsClient } @@ -35,57 +37,43 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords ap return nil, err } - var userAgent string - if settings.PluginEnv != nil { - userAgent = useragent.PluginString(settings.PluginEnv, "azure-secrets") - } else { - userAgent = useragent.String() - } - - // Sets a unique ID in the user-agent - // Normal user-agent looks like this: - // - // Vault/1.6.0 (+https://www.vaultproject.io/; azure-secrets; go1.15.7) - // - // Here we append a unique code if it's an enterprise version, where - // VersionMetadata will contain a non-empty string like "ent" or "prem". - // Otherwise use the default identifier for OSS Vault. The end result looks - // like so: - // - // Vault/1.6.0 (+https://www.vaultproject.io/; azure-secrets; go1.15.7; b2c13ec1-60e8-4733-9a76-88dbb2ce2471) - vaultIDString := "; 15cd22ce-24af-43a4-aa83-4c1a36a4b177)" - ver := version.GetVersion() - if ver.VersionMetadata != "" { - vaultIDString = "; b2c13ec1-60e8-4733-9a76-88dbb2ce2471)" - } - userAgent = strings.Replace(userAgent, ")", vaultIDString, 1) + userAgent := getUserAgent(settings) spClient := graphrbac.NewServicePrincipalsClient(settings.TenantID) spClient.Authorizer = graphAuthorizer spClient.AddToUserAgent(userAgent) - groupsClient := graphrbac.NewGroupsClient(settings.TenantID) - groupsClient.Authorizer = graphAuthorizer - groupsClient.AddToUserAgent(userAgent) - var appClient api.ApplicationsClient + var groupsClient api.GroupsClient if useMsGraphApi { graphApiAuthorizer, err := getAuthorizer(settings, api.DefaultGraphMicrosoftComURI) if err != nil { return nil, err } - msGraphAppClient := api.NewGraphApplicationClient(settings.SubscriptionID) - msGraphAppClient.Authorizer = graphApiAuthorizer - msGraphAppClient.AddToUserAgent(userAgent) + msGraphAppClient, err := api.NewMSGraphApplicationClient(settings.SubscriptionID, userAgent, graphApiAuthorizer) + if err != nil { + return nil, err + } - appClient = &msGraphAppClient + appClient = msGraphAppClient + groupsClient = msGraphAppClient } else { aadGraphClient := graphrbac.NewApplicationsClient(settings.TenantID) aadGraphClient.Authorizer = graphAuthorizer aadGraphClient.AddToUserAgent(userAgent) - appClient = &api.ActiveDirectoryApplicatinClient{Client: &aadGraphClient, Passwords: passwords} + appClient = &api.ActiveDirectoryApplicationClient{Client: &aadGraphClient, Passwords: passwords} + + aadGroupsClient := graphrbac.NewGroupsClient(settings.TenantID) + aadGroupsClient.Authorizer = graphAuthorizer + aadGroupsClient.AddToUserAgent(userAgent) + + groupsClient = api.ActiveDirectoryApplicationGroupsClient{ + BaseURI: aadGroupsClient.BaseURI, + TenantID: aadGroupsClient.TenantID, + Client: aadGroupsClient, + } } // build clients that use the Resource Manager endpoint @@ -107,7 +95,7 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords ap appClient: appClient, spClient: &spClient, - groupsClient: &groupsClient, + groupsClient: groupsClient, raClient: &raClient, rdClient: &rdClient, } @@ -115,28 +103,48 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords ap return p, nil } +func getUserAgent(settings *clientSettings) string { + var userAgent string + if settings.PluginEnv != nil { + userAgent = useragent.PluginString(settings.PluginEnv, "azure-secrets") + } else { + userAgent = useragent.String() + } + + // Sets a unique ID in the user-agent + // Normal user-agent looks like this: + // + // Vault/1.6.0 (+https://www.vaultproject.io/; azure-secrets; go1.15.7) + // + // Here we append a unique code if it's an enterprise version, where + // VersionMetadata will contain a non-empty string like "ent" or "prem". + // Otherwise use the default identifier for OSS Vault. The end result looks + // like so: + // + // Vault/1.6.0 (+https://www.vaultproject.io/; azure-secrets; go1.15.7; b2c13ec1-60e8-4733-9a76-88dbb2ce2471) + vaultIDString := "; 15cd22ce-24af-43a4-aa83-4c1a36a4b177)" + ver := version.GetVersion() + if ver.VersionMetadata != "" { + vaultIDString = "; b2c13ec1-60e8-4733-9a76-88dbb2ce2471)" + } + userAgent = strings.Replace(userAgent, ")", vaultIDString, 1) + + return userAgent +} + // getAuthorizer attempts to create an authorizer, preferring ClientID/Secret if present, // and falling back to MSI if not. -func getAuthorizer(settings *clientSettings, resource string) (authorizer autorest.Authorizer, err error) { - +func getAuthorizer(settings *clientSettings, resource string) (autorest.Authorizer, error) { if settings.ClientID != "" && settings.ClientSecret != "" && settings.TenantID != "" { config := auth.NewClientCredentialsConfig(settings.ClientID, settings.ClientSecret, settings.TenantID) config.AADEndpoint = settings.Environment.ActiveDirectoryEndpoint config.Resource = resource - authorizer, err = config.Authorizer() - if err != nil { - return nil, err - } - } else { - config := auth.NewMSIConfig() - config.Resource = resource - authorizer, err = config.Authorizer() - if err != nil { - return nil, err - } + return config.Authorizer() } - return authorizer, nil + config := auth.NewMSIConfig() + config.Resource = resource + return config.Authorizer() } // CreateApplication create a new Azure application object. @@ -213,26 +221,21 @@ func (p *provider) ListRoleAssignments(ctx context.Context, filter string) ([]au } // AddGroupMember adds a member to a AAD Group. -func (p *provider) AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) { - return p.groupsClient.AddMember(ctx, groupObjectID, parameters) +func (p *provider) AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) (err error) { + return p.groupsClient.AddGroupMember(ctx, groupObjectID, memberObjectID) } // RemoveGroupMember removes a member from a AAD Group. -func (p *provider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) { - return p.groupsClient.RemoveMember(ctx, groupObjectID, memberObjectID) +func (p *provider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (err error) { + return p.groupsClient.RemoveGroupMember(ctx, groupObjectID, memberObjectID) } // GetGroup gets group information from the directory. -func (p *provider) GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) { - return p.groupsClient.Get(ctx, objectID) +func (p *provider) GetGroup(ctx context.Context, objectID string) (result api.ADGroup, err error) { + return p.groupsClient.GetGroup(ctx, objectID) } // ListGroups gets list of groups for the current tenant. -func (p *provider) ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) { - page, err := p.groupsClient.List(ctx, filter) - if err != nil { - return nil, err - } - - return page.Values(), nil +func (p *provider) ListGroups(ctx context.Context, filter string) (result []api.ADGroup, err error) { + return p.groupsClient.ListGroups(ctx, filter) } diff --git a/provider_mock_test.go b/provider_mock_test.go new file mode 100644 index 00000000..dbfd9b0b --- /dev/null +++ b/provider_mock_test.go @@ -0,0 +1,267 @@ +package azuresecrets + +import ( + "context" + "errors" + "fmt" + "regexp" + "strings" + "sync" + "time" + + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" + "github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute" + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/vault-plugin-secrets-azure/api" +) + +// mockProvider is a Provider that provides stubs and simple, deterministic responses. +type mockProvider struct { + subscriptionID string + applications map[string]bool + passwords map[string]api.PasswordCredential + failNextCreateApplication bool + lock sync.Mutex +} + +func newMockProvider() api.AzureProvider { + return &mockProvider{ + subscriptionID: generateUUID(), + applications: make(map[string]bool), + passwords: make(map[string]api.PasswordCredential), + } +} + +// ListRoles returns a single fake role based on the inbound filter +func (m *mockProvider) ListRoles(_ context.Context, _ string, filter string) (result []authorization.RoleDefinition, err error) { + reRoleName := regexp.MustCompile("roleName eq '(.*)'") + + match := reRoleName.FindAllStringSubmatch(filter, -1) + if len(match) > 0 { + name := match[0][1] + if name == "multiple" { + return []authorization.RoleDefinition{ + { + ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s-1", name)), + RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ + RoleName: to.StringPtr(name), + }, + }, + { + ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s-2", name)), + RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ + RoleName: to.StringPtr(name), + }, + }, + }, nil + } + return []authorization.RoleDefinition{ + { + ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s", name)), + RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ + RoleName: to.StringPtr(name), + }, + }, + }, nil + } + + return []authorization.RoleDefinition{}, nil +} + +// GetRoleByID will returns a fake role definition from the povided ID +// Assumes an ID format of: .*FAKE_ROLE-{rolename} +func (m *mockProvider) GetRoleByID(_ context.Context, roleID string) (result authorization.RoleDefinition, err error) { + d := authorization.RoleDefinition{} + s := strings.Split(roleID, "FAKE_ROLE-") + if len(s) > 1 { + d.ID = to.StringPtr(roleID) + d.RoleDefinitionProperties = &authorization.RoleDefinitionProperties{ + RoleName: to.StringPtr(s[1]), + } + } + + return d, nil +} + +func (m *mockProvider) CreateServicePrincipal(_ context.Context, _ graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) { + return graphrbac.ServicePrincipal{ + ObjectID: to.StringPtr(generateUUID()), + }, nil +} + +func (m *mockProvider) CreateApplication(_ context.Context, _ string) (api.ApplicationResult, error) { + if m.failNextCreateApplication { + m.failNextCreateApplication = false + return api.ApplicationResult{}, errors.New("Mock: fail to create application") + } + appObjID := generateUUID() + + m.lock.Lock() + defer m.lock.Unlock() + + m.applications[appObjID] = true + + return api.ApplicationResult{ + AppID: to.StringPtr(generateUUID()), + ID: &appObjID, + }, nil +} + +func (m *mockProvider) GetApplication(_ context.Context, _ string) (api.ApplicationResult, error) { + return api.ApplicationResult{ + AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"), + }, nil +} + +func (m *mockProvider) DeleteApplication(_ context.Context, applicationObjectID string) (autorest.Response, error) { + delete(m.applications, applicationObjectID) + return autorest.Response{}, nil +} + +func (m *mockProvider) AddApplicationPassword(_ context.Context, _ string, displayName string, endDateTime date.Time) (result api.PasswordCredentialResult, err error) { + keyID := generateUUID() + cred := api.PasswordCredential{ + DisplayName: to.StringPtr(displayName), + StartDate: &date.Time{Time: time.Now()}, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + SecretText: to.StringPtr(generateUUID()), + } + + m.lock.Lock() + defer m.lock.Unlock() + m.passwords[keyID] = cred + + return api.PasswordCredentialResult{ + PasswordCredential: cred, + }, nil +} + +func (m *mockProvider) RemoveApplicationPassword(_ context.Context, _ string, keyID string) (result autorest.Response, err error) { + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.passwords, keyID) + + return autorest.Response{}, nil +} + +func (m *mockProvider) appExists(s string) bool { + return m.applications[s] +} + +func (m *mockProvider) passwordExists(s string) bool { + _, ok := m.passwords[s] + return ok +} + +func (m *mockProvider) VMGet(_ context.Context, _ string, _ string, _ compute.InstanceViewTypes) (result compute.VirtualMachine, err error) { + return compute.VirtualMachine{}, nil +} + +func (m *mockProvider) VMUpdate(_ context.Context, _ string, _ string, _ compute.VirtualMachineUpdate) (result compute.VirtualMachinesUpdateFuture, err error) { + return compute.VirtualMachinesUpdateFuture{}, nil +} + +func (m *mockProvider) CreateRoleAssignment(_ context.Context, _ string, _ string, _ authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { + return authorization.RoleAssignment{ + ID: to.StringPtr(generateUUID()), + }, nil +} + +func (m *mockProvider) DeleteRoleAssignmentByID(_ context.Context, _ string) (result authorization.RoleAssignment, err error) { + return authorization.RoleAssignment{}, nil +} + +// AddGroupMember adds a member to a AAD Group. +func (m *mockProvider) AddGroupMember(_ context.Context, _ string, _ string) (err error) { + return nil +} + +// RemoveGroupMember removes a member from a AAD Group. +func (m *mockProvider) RemoveGroupMember(_ context.Context, _ string, _ string) (err error) { + return nil +} + +// GetGroup gets group information from the directory. +func (m *mockProvider) GetGroup(_ context.Context, objectID string) (api.ADGroup, error) { + g := api.ADGroup{ + ID: objectID, + } + s := strings.Split(objectID, "FAKE_GROUP-") + if len(s) > 1 { + g.DisplayName = s[1] + } + + return g, nil +} + +// ListGroups gets list of groups for the current tenant. +func (m *mockProvider) ListGroups(_ context.Context, filter string) (result []api.ADGroup, err error) { + reGroupName := regexp.MustCompile("displayName eq '(.*)'") + + match := reGroupName.FindAllStringSubmatch(filter, -1) + if len(match) > 0 { + name := match[0][1] + if name == "multiple" { + return []api.ADGroup{ + { + ID: fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name), + DisplayName: name, + }, + { + ID: fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-2", name), + DisplayName: name, + }, + }, nil + } + + return []api.ADGroup{ + { + ID: fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name), + DisplayName: name, + }, + }, nil + } + + return []api.ADGroup{}, nil +} + +// errMockProvider simulates a normal provider which fails to associate a role, +// returning an error +type errMockProvider struct { + *mockProvider +} + +func newErrMockProvider() api.AzureProvider { + return &errMockProvider{ + mockProvider: &mockProvider{ + subscriptionID: generateUUID(), + applications: make(map[string]bool), + passwords: make(map[string]api.PasswordCredential), + }, + } +} + +// CreateRoleAssignment for the errMockProvider intentionally fails +func (e *errMockProvider) CreateRoleAssignment(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { + return authorization.RoleAssignment{}, errors.New("PrincipalNotFound") +} + +// GetApplication for the errMockProvider only returns an application if that +// key is found, unlike mockProvider which returns the same application object +// id each time. Existing tests depend on the mockProvider behavior, which is +// why errMockProvider has it's own version. +func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (api.ApplicationResult, error) { + for s := range e.applications { + if s == applicationObjectID { + return api.ApplicationResult{ + AppID: to.StringPtr(s), + }, nil + } + } + return api.ApplicationResult{}, errors.New("not found") +} From fec9a43b9388300699658d918281e33bc6054aee Mon Sep 17 00:00:00 2001 From: Michael Golowka <72365+pcman312@users.noreply.github.com> Date: Wed, 15 Sep 2021 15:26:45 -0600 Subject: [PATCH 4/8] Updated elapsed time assertions --- client_test.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/client_test.go b/client_test.go index 4c6c2887..d28d9c40 100644 --- a/client_test.go +++ b/client_test.go @@ -35,10 +35,8 @@ func TestRetry(t *testing.T) { equal(t, count, 3) // each sleep can last from 2 to 8 seconds - elapsed := time.Now().Sub(start).Seconds() - if elapsed < 4 || elapsed > 16 { - t.Fatalf("expected time of 4-16 seconds, got: %f", elapsed) - } + elapsed := time.Now().Sub(start) + assertDuration(t, elapsed, 5*time.Second, 3*time.Second) assertErrorIsNil(t, err) }) @@ -75,7 +73,7 @@ func TestRetry(t *testing.T) { if called == 0 { t.Fatalf("retryable function was never called") } - assertDuration(t, elapsed, timeout, 100*time.Millisecond) + assertDuration(t, elapsed, timeout, 250*time.Millisecond) }) t.Run("Cancellation", func(t *testing.T) { @@ -91,10 +89,8 @@ func TestRetry(t *testing.T) { _, err := retry(ctx, func() (interface{}, bool, error) { return nil, false, nil }) - elapsed := time.Now().Sub(start).Seconds() - if elapsed < 0 || elapsed > 2 { - t.Fatalf("expected time of ~1 second, got: %f", elapsed) - } + elapsed := time.Now().Sub(start) + assertDuration(t, elapsed, 1*time.Second, 250*time.Millisecond) if err == nil { t.Fatalf("expected err: got nil") From 0c52a466bfc75f35624a9bc97879511940a99a9e Mon Sep 17 00:00:00 2001 From: Michael Golowka <72365+pcman312@users.noreply.github.com> Date: Wed, 15 Sep 2021 15:48:43 -0600 Subject: [PATCH 5/8] Adjust tests for retry logic --- client.go | 18 ++++++++---------- client_test.go | 4 ---- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/client.go b/client.go index 18ecb482..7d3ee577 100644 --- a/client.go +++ b/client.go @@ -330,18 +330,16 @@ func retry(ctx context.Context, f func() (interface{}, bool, error)) (interface{ rng := rand.New(rand.NewSource(time.Now().UnixNano())) var lastErr error for { - result, done, err := f() - if done { - return result, err - } - lastErr = err - - delay := time.Duration(2000+rng.Intn(6000)) * time.Millisecond - delayTimer.Reset(delay) - select { case <-delayTimer.C: - // Retry loop + result, done, err := f() + if done { + return result, err + } + lastErr = err + + delay := time.Duration(2+rng.Intn(6)) * time.Second + delayTimer.Reset(delay) case <-ctx.Done(): err := lastErr if err == nil { diff --git a/client_test.go b/client_test.go index d28d9c40..19d5b293 100644 --- a/client_test.go +++ b/client_test.go @@ -22,7 +22,6 @@ func TestRetry(t *testing.T) { t.Run("Three retries", func(t *testing.T) { t.Parallel() - start := time.Now() count := 0 _, err := retry(context.Background(), func() (interface{}, bool, error) { @@ -34,9 +33,6 @@ func TestRetry(t *testing.T) { }) equal(t, count, 3) - // each sleep can last from 2 to 8 seconds - elapsed := time.Now().Sub(start) - assertDuration(t, elapsed, 5*time.Second, 3*time.Second) assertErrorIsNil(t, err) }) From b2a290a36ef51b9397d3ca0877d0f74b622fb498 Mon Sep 17 00:00:00 2001 From: Michael Golowka <72365+pcman312@users.noreply.github.com> Date: Mon, 20 Sep 2021 10:56:21 -0600 Subject: [PATCH 6/8] ADGroup -> Group --- api/application_msgraph.go | 20 ++++++++++---------- api/groups.go | 6 +++--- api/groups_aad.go | 12 ++++++------ client.go | 2 +- path_roles.go | 2 +- provider.go | 4 ++-- provider_mock_test.go | 12 ++++++------ 7 files changed, 29 insertions(+), 29 deletions(-) diff --git a/api/application_msgraph.go b/api/application_msgraph.go index 7cfb5860..18f3e849 100644 --- a/api/application_msgraph.go +++ b/api/application_msgraph.go @@ -409,7 +409,7 @@ func (c AppClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberO } // groupResponse is a struct representation of the data we care about coming back from -// the ms-graph API. This is not the same as ADGroup because this information is +// the ms-graph API. This is not the same as `Group` because this information is // slightly different from the AAD implementation and there should be an abstraction // between the ms-graph API itself and the API this package presents. type groupResponse struct { @@ -417,9 +417,9 @@ type groupResponse struct { DisplayName string `json:"displayName"` } -func (c AppClient) GetGroup(ctx context.Context, groupID string) (result ADGroup, err error) { +func (c AppClient) GetGroup(ctx context.Context, groupID string) (result Group, err error) { if groupID == "" { - return ADGroup{}, fmt.Errorf("missing groupID") + return Group{}, fmt.Errorf("missing groupID") } pathParams := map[string]interface{}{ "groupID": groupID, @@ -432,7 +432,7 @@ func (c AppClient) GetGroup(ctx context.Context, groupID string) (result ADGroup c.client.WithAuthorization()) req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) if err != nil { - return ADGroup{}, err + return Group{}, err } sender := autorest.GetSendDecorators(req.Context(), @@ -440,7 +440,7 @@ func (c AppClient) GetGroup(ctx context.Context, groupID string) (result ADGroup ) resp, err := autorest.SendWithSender(c.client, req, sender...) if err != nil { - return ADGroup{}, err + return Group{}, err } groupResp := groupResponse{} @@ -453,10 +453,10 @@ func (c AppClient) GetGroup(ctx context.Context, groupID string) (result ADGroup autorest.ByClosing(), ) if err != nil { - return ADGroup{}, err + return Group{}, err } - group := ADGroup{ + group := Group{ ID: groupResp.ID, DisplayName: groupResp.DisplayName, } @@ -470,7 +470,7 @@ type listGroupsResponse struct { Groups []groupResponse `json:"value"` } -func (c AppClient) ListGroups(ctx context.Context, filter string) (result []ADGroup, err error) { +func (c AppClient) ListGroups(ctx context.Context, filter string) (result []Group, err error) { filterArgs := url.Values{} if filter != "" { filterArgs.Set("$filter", filter) @@ -508,13 +508,13 @@ func (c AppClient) ListGroups(ctx context.Context, filter string) (result []ADGr return nil, err } - groups := []ADGroup{} + groups := []Group{} for _, rawGroup := range groupsResp.Groups { if rawGroup.ID == "" { return nil, fmt.Errorf("missing group ID from response") } - group := ADGroup{ + group := Group{ ID: rawGroup.ID, DisplayName: rawGroup.DisplayName, } diff --git a/api/groups.go b/api/groups.go index b9b0ee10..24e1b571 100644 --- a/api/groups.go +++ b/api/groups.go @@ -7,11 +7,11 @@ import ( type GroupsClient interface { AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error - GetGroup(ctx context.Context, objectID string) (result ADGroup, err error) - ListGroups(ctx context.Context, filter string) (result []ADGroup, err error) + GetGroup(ctx context.Context, objectID string) (result Group, err error) + ListGroups(ctx context.Context, filter string) (result []Group, err error) } -type ADGroup struct { +type Group struct { ID string DisplayName string } diff --git a/api/groups_aad.go b/api/groups_aad.go index f784fc10..b4d007c0 100644 --- a/api/groups_aad.go +++ b/api/groups_aad.go @@ -38,10 +38,10 @@ func (a ActiveDirectoryApplicationGroupsClient) RemoveGroupMember(ctx context.Co return err } -func (a ActiveDirectoryApplicationGroupsClient) GetGroup(ctx context.Context, objectID string) (result ADGroup, err error) { +func (a ActiveDirectoryApplicationGroupsClient) GetGroup(ctx context.Context, objectID string) (result Group, err error) { resp, err := a.Client.Get(ctx, objectID) if err != nil { - return ADGroup{}, err + return Group{}, err } grp := getGroupFromRBAC(resp) @@ -49,21 +49,21 @@ func (a ActiveDirectoryApplicationGroupsClient) GetGroup(ctx context.Context, ob return grp, nil } -func getGroupFromRBAC(resp graphrbac.ADGroup) ADGroup { - grp := ADGroup{ +func getGroupFromRBAC(resp graphrbac.ADGroup) Group { + grp := Group{ ID: *resp.ObjectID, DisplayName: *resp.DisplayName, } return grp } -func (a ActiveDirectoryApplicationGroupsClient) ListGroups(ctx context.Context, filter string) (result []ADGroup, err error) { +func (a ActiveDirectoryApplicationGroupsClient) ListGroups(ctx context.Context, filter string) (result []Group, err error) { resp, err := a.Client.List(ctx, filter) if err != nil { return nil, err } - grps := []ADGroup{} + grps := []Group{} for _, aadGrp := range resp.Values() { grp := getGroupFromRBAC(aadGrp) grps = append(grps, grp) diff --git a/client.go b/client.go index 7d3ee577..fcb6ab84 100644 --- a/client.go +++ b/client.go @@ -254,7 +254,7 @@ func (c *client) findRoles(ctx context.Context, roleName string) ([]authorizatio // findGroups is used to find a group by name. It returns all groups matching // the passsed name. -func (c *client) findGroups(ctx context.Context, groupName string) ([]api.ADGroup, error) { +func (c *client) findGroups(ctx context.Context, groupName string) ([]api.Group, error) { return c.provider.ListGroups(ctx, fmt.Sprintf("displayName eq '%s'", groupName)) } diff --git a/path_roles.go b/path_roles.go index 9160749b..0494870a 100644 --- a/path_roles.go +++ b/path_roles.go @@ -238,7 +238,7 @@ func (b *azureSecretBackend) pathRoleUpdate(ctx context.Context, req *logical.Re // update and verify Azure groups, including looking up each group by ID or name. groupSet := make(map[string]bool) for _, r := range role.AzureGroups { - var groupDef api.ADGroup + var groupDef api.Group if r.ObjectID != "" { groupDef, err = client.provider.GetGroup(ctx, r.ObjectID) if err != nil { diff --git a/provider.go b/provider.go index 9f071c2e..9ab5013a 100644 --- a/provider.go +++ b/provider.go @@ -231,11 +231,11 @@ func (p *provider) RemoveGroupMember(ctx context.Context, groupObjectID, memberO } // GetGroup gets group information from the directory. -func (p *provider) GetGroup(ctx context.Context, objectID string) (result api.ADGroup, err error) { +func (p *provider) GetGroup(ctx context.Context, objectID string) (result api.Group, err error) { return p.groupsClient.GetGroup(ctx, objectID) } // ListGroups gets list of groups for the current tenant. -func (p *provider) ListGroups(ctx context.Context, filter string) (result []api.ADGroup, err error) { +func (p *provider) ListGroups(ctx context.Context, filter string) (result []api.Group, err error) { return p.groupsClient.ListGroups(ctx, filter) } diff --git a/provider_mock_test.go b/provider_mock_test.go index dbfd9b0b..9b8b1121 100644 --- a/provider_mock_test.go +++ b/provider_mock_test.go @@ -187,8 +187,8 @@ func (m *mockProvider) RemoveGroupMember(_ context.Context, _ string, _ string) } // GetGroup gets group information from the directory. -func (m *mockProvider) GetGroup(_ context.Context, objectID string) (api.ADGroup, error) { - g := api.ADGroup{ +func (m *mockProvider) GetGroup(_ context.Context, objectID string) (api.Group, error) { + g := api.Group{ ID: objectID, } s := strings.Split(objectID, "FAKE_GROUP-") @@ -200,14 +200,14 @@ func (m *mockProvider) GetGroup(_ context.Context, objectID string) (api.ADGroup } // ListGroups gets list of groups for the current tenant. -func (m *mockProvider) ListGroups(_ context.Context, filter string) (result []api.ADGroup, err error) { +func (m *mockProvider) ListGroups(_ context.Context, filter string) (result []api.Group, err error) { reGroupName := regexp.MustCompile("displayName eq '(.*)'") match := reGroupName.FindAllStringSubmatch(filter, -1) if len(match) > 0 { name := match[0][1] if name == "multiple" { - return []api.ADGroup{ + return []api.Group{ { ID: fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name), DisplayName: name, @@ -219,7 +219,7 @@ func (m *mockProvider) ListGroups(_ context.Context, filter string) (result []ap }, nil } - return []api.ADGroup{ + return []api.Group{ { ID: fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name), DisplayName: name, @@ -227,7 +227,7 @@ func (m *mockProvider) ListGroups(_ context.Context, filter string) (result []ap }, nil } - return []api.ADGroup{}, nil + return []api.Group{}, nil } // errMockProvider simulates a normal provider which fails to associate a role, From 2b072688d3cdcc4903666ce907d10fdc2753b430 Mon Sep 17 00:00:00 2001 From: Michael Golowka <72365+pcman312@users.noreply.github.com> Date: Wed, 22 Sep 2021 16:57:06 -0600 Subject: [PATCH 7/8] Use msgraph for Service Principals --- api/api.go | 8 +- api/application_msgraph.go | 218 +++++++++++++++++++-------------- api/service_principals.go | 16 +++ api/service_principals_aad.go | 48 ++++++++ client.go | 55 ++++----- path_roles.go | 2 +- path_service_principal.go | 8 +- path_service_principal_test.go | 91 +++++++++----- provider.go | 28 +++-- provider_mock_test.go | 13 +- 10 files changed, 304 insertions(+), 183 deletions(-) create mode 100644 api/service_principals.go create mode 100644 api/service_principals_aad.go diff --git a/api/api.go b/api/api.go index eb7777ea..53492423 100644 --- a/api/api.go +++ b/api/api.go @@ -4,7 +4,6 @@ import ( "context" "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/date" ) @@ -15,8 +14,7 @@ import ( type AzureProvider interface { ApplicationsClient GroupsClient - - CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) + ServicePrincipalClient CreateRoleAssignment( ctx context.Context, @@ -25,8 +23,8 @@ type AzureProvider interface { parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) DeleteRoleAssignmentByID(ctx context.Context, roleID string) (authorization.RoleAssignment, error) - ListRoles(ctx context.Context, scope string, filter string) ([]authorization.RoleDefinition, error) - GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) + ListRoleDefinitions(ctx context.Context, scope string, filter string) ([]authorization.RoleDefinition, error) + GetRoleDefinitionByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) } type ApplicationsClient interface { diff --git a/api/application_msgraph.go b/api/application_msgraph.go index 18f3e849..024bd803 100644 --- a/api/application_msgraph.go +++ b/api/application_msgraph.go @@ -5,12 +5,14 @@ import ( "fmt" "net/http" "net/url" + "time" "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/go-multierror" ) const ( @@ -20,6 +22,7 @@ const ( var _ ApplicationsClient = (*AppClient)(nil) var _ GroupsClient = (*AppClient)(nil) +var _ ServicePrincipalClient = (*AppClient)(nil) type AppClient struct { client authorization.BaseClient @@ -339,35 +342,12 @@ func (c AppClient) AddGroupMember(ctx context.Context, groupObjectID string, mem body := map[string]interface{}{ "@odata.id": fmt.Sprintf("%s/v1.0/directoryObjects/%s", DefaultGraphMicrosoftComURI, memberObjectID), } - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), + preparer := c.getPreparer( autorest.AsPost(), - autorest.WithBaseURL(c.client.BaseURI), autorest.WithPathParameters("/v1.0/groups/{groupObjectID}/members/$ref", pathParams), autorest.WithJSON(body), - c.client.WithAuthorization()) - req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) - if err != nil { - return err - } - - sender := autorest.GetSendDecorators(req.Context(), - autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), - ) - resp, err := autorest.SendWithSender(c.client, req, sender...) - if err != nil { - return err - } - - respBody := map[string]interface{}{} - - return autorest.Respond( - resp, - c.client.ByInspecting(), - azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), - autorest.ByUnmarshallingJSON(&respBody), - autorest.ByClosing(), ) + return c.sendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) } func (c AppClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error { @@ -381,31 +361,12 @@ func (c AppClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberO "groupObjectID": groupObjectID, "memberObjectID": memberObjectID, } - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), + + preparer := c.getPreparer( autorest.AsDelete(), - autorest.WithBaseURL(c.client.BaseURI), autorest.WithPathParameters("/v1.0/groups/{groupObjectID}/members/{memberObjectID}/$ref", pathParams), - c.client.WithAuthorization()) - req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) - if err != nil { - return err - } - - sender := autorest.GetSendDecorators(req.Context(), - autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), - ) - resp, err := autorest.SendWithSender(c.client, req, sender...) - if err != nil { - return err - } - - return autorest.Respond( - resp, - c.client.ByInspecting(), - azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), - autorest.ByClosing(), ) + return c.sendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) } // groupResponse is a struct representation of the data we care about coming back from @@ -424,33 +385,16 @@ func (c AppClient) GetGroup(ctx context.Context, groupID string) (result Group, pathParams := map[string]interface{}{ "groupID": groupID, } - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), + + preparer := c.getPreparer( autorest.AsGet(), - autorest.WithBaseURL(c.client.BaseURI), autorest.WithPathParameters("/v1.0/groups/{groupID}", pathParams), - c.client.WithAuthorization()) - req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) - if err != nil { - return Group{}, err - } - - sender := autorest.GetSendDecorators(req.Context(), - autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), ) - resp, err := autorest.SendWithSender(c.client, req, sender...) - if err != nil { - return Group{}, err - } groupResp := groupResponse{} - - err = autorest.Respond( - resp, - c.client.ByInspecting(), + err = c.sendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), autorest.ByUnmarshallingJSON(&groupResp), - autorest.ByClosing(), ) if err != nil { return Group{}, err @@ -476,40 +420,22 @@ func (c AppClient) ListGroups(ctx context.Context, filter string) (result []Grou filterArgs.Set("$filter", filter) } - preparer := autorest.CreatePreparer( - autorest.AsContentType("application/json; charset=utf-8"), + preparer := c.getPreparer( autorest.AsGet(), - autorest.WithBaseURL(c.client.BaseURI), autorest.WithPath(fmt.Sprintf("/v1.0/groups?%s", filterArgs.Encode())), - c.client.WithAuthorization()) - req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) - if err != nil { - return nil, err - } - - sender := autorest.GetSendDecorators(req.Context(), - autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), ) - resp, err := autorest.SendWithSender(c.client, req, sender...) - if err != nil { - return nil, err - } - groupsResp := listGroupsResponse{} - - err = autorest.Respond( - resp, - c.client.ByInspecting(), + respBody := listGroupsResponse{} + err = c.sendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), - autorest.ByUnmarshallingJSON(&groupsResp), - autorest.ByClosing(), + autorest.ByUnmarshallingJSON(&respBody), ) if err != nil { return nil, err } groups := []Group{} - for _, rawGroup := range groupsResp.Groups { + for _, rawGroup := range respBody.Groups { if rawGroup.ID == "" { return nil, fmt.Errorf("missing group ID from response") } @@ -522,3 +448,115 @@ func (c AppClient) ListGroups(ctx context.Context, filter string) (result []Grou } return groups, nil } + +func (c *AppClient) CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (spID string, password string, err error) { + spID, err = c.createServicePrincipal(ctx, appID) + if err != nil { + return "", "", err + } + password, err = c.setPasswordForServicePrincipal(ctx, spID, startDate, endDate) + if err != nil { + dErr := c.deleteServicePrincipal(ctx, spID) + merr := multierror.Append(err, dErr) + return "", "", merr.ErrorOrNil() + } + return spID, password, nil +} + +func (c *AppClient) createServicePrincipal(ctx context.Context, appID string) (id string, err error) { + body := map[string]interface{}{ + "appId": appID, + "accountEnabled": true, + } + preparer := c.getPreparer( + autorest.AsPost(), + autorest.WithPath("/v1.0/servicePrincipals"), + autorest.WithJSON(body), + ) + + respBody := createServicePrincipalResponse{} + err = c.sendRequest(ctx, preparer, + autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusCreated), + autorest.ByUnmarshallingJSON(&respBody), + ) + if err != nil { + return "", err + } + + return respBody.ID, nil +} + +func (c *AppClient) setPasswordForServicePrincipal(ctx context.Context, spID string, startDate time.Time, endDate time.Time) (password string, err error) { + pathParams := map[string]interface{}{ + "id": spID, + } + reqBody := map[string]interface{}{ + "startDateTime": startDate.UTC().Format("2006-01-02T15:04:05Z"), + "endDateTime": startDate.UTC().Format("2006-01-02T15:04:05Z"), + } + + preparer := c.getPreparer( + autorest.AsPost(), + autorest.WithPathParameters("/v1.0/servicePrincipals/{id}/addPassword", pathParams), + autorest.WithJSON(reqBody), + ) + + respBody := PasswordCredential{} + err = c.sendRequest(ctx, preparer, + autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), + autorest.ByUnmarshallingJSON(&respBody), + ) + if err != nil { + return "", err + } + return *respBody.SecretText, nil +} + +type createServicePrincipalResponse struct { + ID string `json:"id"` +} + +func (c *AppClient) deleteServicePrincipal(ctx context.Context, spID string) error { + pathParams := map[string]interface{}{ + "id": spID, + } + + preparer := c.getPreparer( + autorest.AsDelete(), + autorest.WithPathParameters("/v1.0/servicePrincipals/{id}", pathParams), + ) + + return c.sendRequest(ctx, preparer, autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) +} + +func (c *AppClient) getPreparer(prepareDecorators ...autorest.PrepareDecorator) autorest.Preparer { + decs := []autorest.PrepareDecorator{ + autorest.AsContentType("application/json; charset=utf-8"), + autorest.WithBaseURL(c.client.BaseURI), + c.client.WithAuthorization(), + } + decs = append(decs, prepareDecorators...) + preparer := autorest.CreatePreparer(decs...) + return preparer +} + +func (c *AppClient) sendRequest(ctx context.Context, preparer autorest.Preparer, respDecs ...autorest.RespondDecorator) error { + req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) + if err != nil { + return err + } + + sender := autorest.GetSendDecorators(req.Context(), + autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), + ) + resp, err := autorest.SendWithSender(c.client, req, sender...) + if err != nil { + return err + } + + // Put ByInspecting() before any provided decorators + respDecs = append([]autorest.RespondDecorator{c.client.ByInspecting()}, respDecs...) + respDecs = append(respDecs, autorest.ByClosing()) + + return autorest.Respond(resp, respDecs...) +} diff --git a/api/service_principals.go b/api/service_principals.go new file mode 100644 index 00000000..7dbd2432 --- /dev/null +++ b/api/service_principals.go @@ -0,0 +1,16 @@ +package api + +import ( + "context" + "time" +) + +type ServicePrincipalClient interface { + // CreateServicePrincipal in Azure. The password returned is the actual password that the appID was created with + CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (id string, password string, err error) +} + +type ServicePrincipal struct { + ObjectID string + AppID string +} diff --git a/api/service_principals_aad.go b/api/service_principals_aad.go new file mode 100644 index 00000000..784f935b --- /dev/null +++ b/api/service_principals_aad.go @@ -0,0 +1,48 @@ +package api + +import ( + "context" + "time" + + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/go-uuid" +) + +var _ ServicePrincipalClient = (*AADServicePrincipalsClient)(nil) + +type AADServicePrincipalsClient struct { + Client graphrbac.ServicePrincipalsClient + Passwords Passwords +} + +func (c AADServicePrincipalsClient) CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (id string, password string, err error) { + keyID, err := uuid.GenerateUUID() + if err != nil { + return "", "", err + } + + password, err = c.Passwords.Generate(ctx) + if err != nil { + return "", "", err + } + + clientParams := graphrbac.ServicePrincipalCreateParameters{ + AppID: to.StringPtr(appID), + AccountEnabled: to.BoolPtr(true), + PasswordCredentials: &[]graphrbac.PasswordCredential{ + graphrbac.PasswordCredential{ + StartDate: &date.Time{startDate}, + EndDate: &date.Time{endDate}, + KeyID: &keyID, + Value: &password, + }, + }, + } + sp, err := c.Client.Create(ctx, clientParams) + if err != nil { + return "", "", err + } + return *sp.ObjectID, password, nil +} diff --git a/client.go b/client.go index fcb6ab84..37a88d37 100644 --- a/client.go +++ b/client.go @@ -10,7 +10,6 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/to" @@ -62,49 +61,37 @@ func (c *client) createApp(ctx context.Context) (app *api.ApplicationResult, err func (c *client) createSP( ctx context.Context, app *api.ApplicationResult, - duration time.Duration) (svcPrinc *graphrbac.ServicePrincipal, password string, err error) { + duration time.Duration) (spID string, password string, err error) { - // Generate a random key (which must be a UUID) and password - keyID, err := uuid.GenerateUUID() - if err != nil { - return nil, "", err - } - - password, err = c.passwords.Generate(ctx) - if err != nil { - return nil, "", err + type idPass struct { + ID string + Password string } resultRaw, err := retry(ctx, func() (interface{}, bool, error) { - now := time.Now().UTC() - result, err := c.provider.CreateServicePrincipal(ctx, graphrbac.ServicePrincipalCreateParameters{ - AppID: app.AppID, - AccountEnabled: to.BoolPtr(true), - PasswordCredentials: &[]graphrbac.PasswordCredential{ - graphrbac.PasswordCredential{ - StartDate: &date.Time{Time: now}, - EndDate: &date.Time{Time: now.Add(duration)}, - KeyID: to.StringPtr(keyID), - Value: to.StringPtr(password), - }, - }, - }) + now := time.Now() + spID, password, err := c.provider.CreateServicePrincipal(ctx, *app.AppID, now, now.Add(duration)) // Propagation delays within Azure can cause this error occasionally, so don't quit on it. if err != nil && strings.Contains(err.Error(), "does not reference a valid application object") { return nil, false, nil } + result := idPass{ + ID: spID, + Password: password, + } + return result, true, err }) if err != nil { - return nil, "", errwrap.Wrapf("error creating service principal: {{err}}", err) + return "", "", errwrap.Wrapf("error creating service principal: {{err}}", err) } - result := resultRaw.(graphrbac.ServicePrincipal) + result := resultRaw.(idPass) - return &result, password, nil + return result.ID, result.Password, nil } // addAppPassword adds a new password to an App's credentials list. @@ -146,7 +133,7 @@ func (c *client) deleteApp(ctx context.Context, appObjectID string) error { } // assignRoles assigns Azure roles to a service principal. -func (c *client) assignRoles(ctx context.Context, sp *graphrbac.ServicePrincipal, roles []*AzureRole) ([]string, error) { +func (c *client) assignRoles(ctx context.Context, spID string, roles []*AzureRole) ([]string, error) { var ids []string for _, role := range roles { @@ -159,8 +146,8 @@ func (c *client) assignRoles(ctx context.Context, sp *graphrbac.ServicePrincipal ra, err := c.provider.CreateRoleAssignment(ctx, role.Scope, assignmentID, authorization.RoleAssignmentCreateParameters{ Properties: &authorization.RoleAssignmentProperties{ - RoleDefinitionID: to.StringPtr(role.RoleID), - PrincipalID: sp.ObjectID, + RoleDefinitionID: &role.RoleID, + PrincipalID: &spID, }, }) @@ -199,10 +186,10 @@ func (c *client) unassignRoles(ctx context.Context, roleIDs []string) error { } // addGroupMemberships adds the service principal to the Azure groups. -func (c *client) addGroupMemberships(ctx context.Context, sp *graphrbac.ServicePrincipal, groups []*AzureGroup) error { +func (c *client) addGroupMemberships(ctx context.Context, spID string, groups []*AzureGroup) error { for _, group := range groups { _, err := retry(ctx, func() (interface{}, bool, error) { - err := c.provider.AddGroupMember(ctx, group.ObjectID, *sp.ObjectID) + err := c.provider.AddGroupMember(ctx, group.ObjectID, spID) // Propagation delays within Azure can cause this error occasionally, so don't quit on it. if err != nil && strings.Contains(err.Error(), "Request_ResourceNotFound") { @@ -249,11 +236,11 @@ func groupObjectIDs(groups []*AzureGroup) []string { // search for roles by name func (c *client) findRoles(ctx context.Context, roleName string) ([]authorization.RoleDefinition, error) { - return c.provider.ListRoles(ctx, fmt.Sprintf("subscriptions/%s", c.settings.SubscriptionID), fmt.Sprintf("roleName eq '%s'", roleName)) + return c.provider.ListRoleDefinitions(ctx, fmt.Sprintf("subscriptions/%s", c.settings.SubscriptionID), fmt.Sprintf("roleName eq '%s'", roleName)) } // findGroups is used to find a group by name. It returns all groups matching -// the passsed name. +// the provided name. func (c *client) findGroups(ctx context.Context, groupName string) ([]api.Group, error) { return c.provider.ListGroups(ctx, fmt.Sprintf("displayName eq '%s'", groupName)) } diff --git a/path_roles.go b/path_roles.go index 0494870a..244f5452 100644 --- a/path_roles.go +++ b/path_roles.go @@ -203,7 +203,7 @@ func (b *azureSecretBackend) pathRoleUpdate(ctx context.Context, req *logical.Re for _, r := range role.AzureRoles { var roleDef authorization.RoleDefinition if r.RoleID != "" { - roleDef, err = client.provider.GetRoleByID(ctx, r.RoleID) + roleDef, err = client.provider.GetRoleDefinitionByID(ctx, r.RoleID) if err != nil { if strings.Contains(err.Error(), "RoleDefinitionDoesNotExist") { return logical.ErrorResponse("no role found for role_id: '%s'", r.RoleID), nil diff --git a/path_service_principal.go b/path_service_principal.go index 251c5866..ab3c3eb1 100644 --- a/path_service_principal.go +++ b/path_service_principal.go @@ -118,19 +118,19 @@ func (b *azureSecretBackend) createSPSecret(ctx context.Context, s logical.Stora } // Create a service principal associated with the new App - sp, password, err := c.createSP(ctx, app, spExpiration) + spID, password, err := c.createSP(ctx, app, spExpiration) if err != nil { return nil, err } // Assign Azure roles to the new SP - raIDs, err := c.assignRoles(ctx, sp, role.AzureRoles) + raIDs, err := c.assignRoles(ctx, spID, role.AzureRoles) if err != nil { return nil, err } // Assign Azure group memberships to the new SP - if err := c.addGroupMemberships(ctx, sp, role.AzureGroups); err != nil { + if err := c.addGroupMemberships(ctx, spID, role.AzureGroups); err != nil { return nil, err } @@ -145,7 +145,7 @@ func (b *azureSecretBackend) createSPSecret(ctx context.Context, s logical.Stora } internalData := map[string]interface{}{ "app_object_id": appObjID, - "sp_object_id": sp.ObjectID, + "sp_object_id": spID, "role_assignment_ids": raIDs, "group_membership_ids": groupObjectIDs(role.AzureGroups), "role": roleName, diff --git a/path_service_principal_test.go b/path_service_principal_test.go index 2a0de71a..5c22714e 100644 --- a/path_service_principal_test.go +++ b/path_service_principal_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/Azure/go-autorest/autorest/to" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault-plugin-secrets-azure/api" @@ -492,7 +491,7 @@ func TestCredentialInteg(t *testing.T) { t.Skip("Azure Secrets: Azure environment variables not set. Skipping.") } - t.Run("SP", func(t *testing.T) { + t.Run("service principals", func(t *testing.T) { t.Parallel() b := backend() @@ -557,30 +556,9 @@ func TestCredentialInteg(t *testing.T) { client, err := b.getClient(context.Background(), s) assertErrorIsNil(t, err) provider := client.provider.(*provider) + spObjID := findServicePrincipalID(t, provider.spClient, appID) - // recover the SP Object ID, which is not used by the application but - // is helpful for verification testing - spList, err := provider.spClient.List(context.Background(), "") - assertErrorIsNil(t, err) - - var spObjID string - for spList.NotDone() { - for _, v := range spList.Values() { - if to.String(v.AppID) == appID { - spObjID = to.String(v.ObjectID) - goto FOUND - } - } - spList.Next() - } - t.Fatal("Couldn't find SP Object ID") - - FOUND: - // verify the new SP can be accessed - _, err = provider.spClient.Get(context.Background(), spObjID) - if err != nil { - t.Fatalf("Expected nil error on GET of new SP, got: %#v", err) - } + assertServicePrincipalExists(t, provider.spClient, spObjID) // Verify that the role assignments were created. Get the assignment // info from Azure and verify it matches the Reader role. @@ -590,7 +568,7 @@ func TestCredentialInteg(t *testing.T) { ra, err := provider.raClient.GetByID(context.Background(), raIDs[0]) assertErrorIsNil(t, err) - roleDefs, err := client.provider.ListRoles(context.Background(), fmt.Sprintf("subscriptions/%s", subscriptionID), "") + roleDefs, err := provider.ListRoleDefinitions(context.Background(), fmt.Sprintf("subscriptions/%s", subscriptionID), "") assertErrorIsNil(t, err) defID := *ra.Properties.RoleDefinitionID @@ -620,14 +598,10 @@ func TestCredentialInteg(t *testing.T) { // Verify that SP get is an error after delete. Expected there // to be a delay and that this step would take some time/retries, // but that seems not to be the case. - _, err = provider.spClient.Get(context.Background(), spObjID) - - if err == nil { - t.Fatal("Expected error reading deleted SP") - } + assertServicePrincipalDoesNotExist(t, provider.spClient, spObjID) }) - t.Run("Static SP", func(t *testing.T) { + t.Run("Static service principals", func(t *testing.T) { t.Parallel() b := backend() @@ -777,3 +751,56 @@ func assertClientSecret(tb testing.TB, data map[string]interface{}) { tb.Fatalf("client_secret is not the correct length: expected %d but was %d", api.PasswordLength, len(actualPassword)) } } + +func findServicePrincipalID(t *testing.T, client api.ServicePrincipalClient, appID string) (spID string) { + t.Helper() + + switch spClient := client.(type) { + case api.AADServicePrincipalsClient: + spList, err := spClient.Client.List(context.Background(), "") + assertErrorIsNil(t, err) + for spList.NotDone() { + for _, sp := range spList.Values() { + if *sp.AppID == appID { + return *sp.ObjectID + } + } + err = spList.NextWithContext(context.Background()) + assertErrorIsNil(t, err) + } + // TODO: Add MSGraph + default: + t.Fatalf("Unrecognized service principal client type: %T", spClient) + } + + t.Fatalf("Failed to find service principal with application ID") + return "" // Because compilers +} + +func assertServicePrincipalExists(t *testing.T, client api.ServicePrincipalClient, spID string) { + t.Helper() + + switch spClient := client.(type) { + case api.AADServicePrincipalsClient: + _, err := spClient.Client.Get(context.Background(), spID) + if err != nil { + t.Fatalf("Expected nil error on GET of new SP, got: %#v", err) + } + default: + t.Fatalf("Unrecognized service principal client type: %T", spClient) + } +} + +func assertServicePrincipalDoesNotExist(t *testing.T, client api.ServicePrincipalClient, spID string) { + t.Helper() + + switch spClient := client.(type) { + case api.AADServicePrincipalsClient: + _, err := spClient.Client.Get(context.Background(), spID) + if err == nil { + t.Fatalf("Expected error on GET of new SP") + } + default: + t.Fatalf("Unrecognized service principal client type: %T", spClient) + } +} diff --git a/provider.go b/provider.go index 9ab5013a..b6f8ab09 100644 --- a/provider.go +++ b/provider.go @@ -3,6 +3,7 @@ package azuresecrets import ( "context" "strings" + "time" "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" @@ -23,7 +24,7 @@ type provider struct { settings *clientSettings appClient api.ApplicationsClient - spClient *graphrbac.ServicePrincipalsClient + spClient api.ServicePrincipalClient groupsClient api.GroupsClient raClient *authorization.RoleAssignmentsClient rdClient *authorization.RoleDefinitionsClient @@ -39,12 +40,9 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords ap userAgent := getUserAgent(settings) - spClient := graphrbac.NewServicePrincipalsClient(settings.TenantID) - spClient.Authorizer = graphAuthorizer - spClient.AddToUserAgent(userAgent) - var appClient api.ApplicationsClient var groupsClient api.GroupsClient + var spClient api.ServicePrincipalClient if useMsGraphApi { graphApiAuthorizer, err := getAuthorizer(settings, api.DefaultGraphMicrosoftComURI) if err != nil { @@ -58,6 +56,7 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords ap appClient = msGraphAppClient groupsClient = msGraphAppClient + spClient = msGraphAppClient } else { aadGraphClient := graphrbac.NewApplicationsClient(settings.TenantID) aadGraphClient.Authorizer = graphAuthorizer @@ -74,6 +73,15 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords ap TenantID: aadGroupsClient.TenantID, Client: aadGroupsClient, } + + servicePrincipalClient := graphrbac.NewServicePrincipalsClient(settings.TenantID) + servicePrincipalClient.Authorizer = graphAuthorizer + servicePrincipalClient.AddToUserAgent(userAgent) + + spClient = api.AADServicePrincipalsClient{ + Client: servicePrincipalClient, + Passwords: passwords, + } } // build clients that use the Resource Manager endpoint @@ -94,7 +102,7 @@ func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords ap settings: settings, appClient: appClient, - spClient: &spClient, + spClient: spClient, groupsClient: groupsClient, raClient: &raClient, rdClient: &rdClient, @@ -172,12 +180,12 @@ func (p *provider) RemoveApplicationPassword(ctx context.Context, applicationObj // CreateServicePrincipal creates a new Azure service principal. // An Application must be created prior to calling this and pass in parameters. -func (p *provider) CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) { - return p.spClient.Create(ctx, parameters) +func (p *provider) CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (id string, password string, err error) { + return p.spClient.CreateServicePrincipal(ctx, appID, startDate, endDate) } // ListRoles like all Azure roles with a scope (often subscription). -func (p *provider) ListRoles(ctx context.Context, scope string, filter string) (result []authorization.RoleDefinition, err error) { +func (p *provider) ListRoleDefinitions(ctx context.Context, scope string, filter string) (result []authorization.RoleDefinition, err error) { page, err := p.rdClient.List(ctx, scope, filter) if err != nil { @@ -188,7 +196,7 @@ func (p *provider) ListRoles(ctx context.Context, scope string, filter string) ( } // GetRoleByID fetches the full role definition given a roleID. -func (p *provider) GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) { +func (p *provider) GetRoleDefinitionByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) { return p.rdClient.GetByID(ctx, roleID) } diff --git a/provider_mock_test.go b/provider_mock_test.go index 9b8b1121..471e8abf 100644 --- a/provider_mock_test.go +++ b/provider_mock_test.go @@ -11,7 +11,6 @@ import ( "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/to" @@ -36,7 +35,7 @@ func newMockProvider() api.AzureProvider { } // ListRoles returns a single fake role based on the inbound filter -func (m *mockProvider) ListRoles(_ context.Context, _ string, filter string) (result []authorization.RoleDefinition, err error) { +func (m *mockProvider) ListRoleDefinitions(_ context.Context, _ string, filter string) (result []authorization.RoleDefinition, err error) { reRoleName := regexp.MustCompile("roleName eq '(.*)'") match := reRoleName.FindAllStringSubmatch(filter, -1) @@ -73,7 +72,7 @@ func (m *mockProvider) ListRoles(_ context.Context, _ string, filter string) (re // GetRoleByID will returns a fake role definition from the povided ID // Assumes an ID format of: .*FAKE_ROLE-{rolename} -func (m *mockProvider) GetRoleByID(_ context.Context, roleID string) (result authorization.RoleDefinition, err error) { +func (m *mockProvider) GetRoleDefinitionByID(_ context.Context, roleID string) (result authorization.RoleDefinition, err error) { d := authorization.RoleDefinition{} s := strings.Split(roleID, "FAKE_ROLE-") if len(s) > 1 { @@ -86,10 +85,10 @@ func (m *mockProvider) GetRoleByID(_ context.Context, roleID string) (result aut return d, nil } -func (m *mockProvider) CreateServicePrincipal(_ context.Context, _ graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) { - return graphrbac.ServicePrincipal{ - ObjectID: to.StringPtr(generateUUID()), - }, nil +func (m *mockProvider) CreateServicePrincipal(_ context.Context, _ string, _ time.Time, _ time.Time) (string, string, error) { + id := generateUUID() + pass := generateUUID() + return id, pass, nil } func (m *mockProvider) CreateApplication(_ context.Context, _ string) (api.ApplicationResult, error) { From 3f4f563c55896aaf64df4bee71e9a8082c04b4c4 Mon Sep 17 00:00:00 2001 From: Michael Golowka <72365+pcman312@users.noreply.github.com> Date: Thu, 23 Sep 2021 16:55:24 -0600 Subject: [PATCH 8/8] Add integration test for msgraph --- api/application_msgraph.go | 32 ++-- path_service_principal_test.go | 283 ++++++++++++++++++++++++++++----- 2 files changed, 263 insertions(+), 52 deletions(-) diff --git a/api/application_msgraph.go b/api/application_msgraph.go index 024bd803..f3448cbf 100644 --- a/api/application_msgraph.go +++ b/api/application_msgraph.go @@ -342,12 +342,12 @@ func (c AppClient) AddGroupMember(ctx context.Context, groupObjectID string, mem body := map[string]interface{}{ "@odata.id": fmt.Sprintf("%s/v1.0/directoryObjects/%s", DefaultGraphMicrosoftComURI, memberObjectID), } - preparer := c.getPreparer( + preparer := c.GetPreparer( autorest.AsPost(), autorest.WithPathParameters("/v1.0/groups/{groupObjectID}/members/$ref", pathParams), autorest.WithJSON(body), ) - return c.sendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) + return c.SendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) } func (c AppClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error { @@ -362,11 +362,11 @@ func (c AppClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberO "memberObjectID": memberObjectID, } - preparer := c.getPreparer( + preparer := c.GetPreparer( autorest.AsDelete(), autorest.WithPathParameters("/v1.0/groups/{groupObjectID}/members/{memberObjectID}/$ref", pathParams), ) - return c.sendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) + return c.SendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) } // groupResponse is a struct representation of the data we care about coming back from @@ -386,13 +386,13 @@ func (c AppClient) GetGroup(ctx context.Context, groupID string) (result Group, "groupID": groupID, } - preparer := c.getPreparer( + preparer := c.GetPreparer( autorest.AsGet(), autorest.WithPathParameters("/v1.0/groups/{groupID}", pathParams), ) groupResp := groupResponse{} - err = c.sendRequest(ctx, preparer, + err = c.SendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), autorest.ByUnmarshallingJSON(&groupResp), ) @@ -420,13 +420,13 @@ func (c AppClient) ListGroups(ctx context.Context, filter string) (result []Grou filterArgs.Set("$filter", filter) } - preparer := c.getPreparer( + preparer := c.GetPreparer( autorest.AsGet(), autorest.WithPath(fmt.Sprintf("/v1.0/groups?%s", filterArgs.Encode())), ) respBody := listGroupsResponse{} - err = c.sendRequest(ctx, preparer, + err = c.SendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), autorest.ByUnmarshallingJSON(&respBody), ) @@ -468,14 +468,14 @@ func (c *AppClient) createServicePrincipal(ctx context.Context, appID string) (i "appId": appID, "accountEnabled": true, } - preparer := c.getPreparer( + preparer := c.GetPreparer( autorest.AsPost(), autorest.WithPath("/v1.0/servicePrincipals"), autorest.WithJSON(body), ) respBody := createServicePrincipalResponse{} - err = c.sendRequest(ctx, preparer, + err = c.SendRequest(ctx, preparer, autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusCreated), autorest.ByUnmarshallingJSON(&respBody), ) @@ -495,14 +495,14 @@ func (c *AppClient) setPasswordForServicePrincipal(ctx context.Context, spID str "endDateTime": startDate.UTC().Format("2006-01-02T15:04:05Z"), } - preparer := c.getPreparer( + preparer := c.GetPreparer( autorest.AsPost(), autorest.WithPathParameters("/v1.0/servicePrincipals/{id}/addPassword", pathParams), autorest.WithJSON(reqBody), ) respBody := PasswordCredential{} - err = c.sendRequest(ctx, preparer, + err = c.SendRequest(ctx, preparer, autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), autorest.ByUnmarshallingJSON(&respBody), ) @@ -521,15 +521,15 @@ func (c *AppClient) deleteServicePrincipal(ctx context.Context, spID string) err "id": spID, } - preparer := c.getPreparer( + preparer := c.GetPreparer( autorest.AsDelete(), autorest.WithPathParameters("/v1.0/servicePrincipals/{id}", pathParams), ) - return c.sendRequest(ctx, preparer, autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) + return c.SendRequest(ctx, preparer, autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) } -func (c *AppClient) getPreparer(prepareDecorators ...autorest.PrepareDecorator) autorest.Preparer { +func (c *AppClient) GetPreparer(prepareDecorators ...autorest.PrepareDecorator) autorest.Preparer { decs := []autorest.PrepareDecorator{ autorest.AsContentType("application/json; charset=utf-8"), autorest.WithBaseURL(c.client.BaseURI), @@ -540,7 +540,7 @@ func (c *AppClient) getPreparer(prepareDecorators ...autorest.PrepareDecorator) return preparer } -func (c *AppClient) sendRequest(ctx context.Context, preparer autorest.Preparer, respDecs ...autorest.RespondDecorator) error { +func (c *AppClient) SendRequest(ctx context.Context, preparer autorest.Preparer, respDecs ...autorest.RespondDecorator) error { req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) if err != nil { return err diff --git a/path_service_principal_test.go b/path_service_principal_test.go index 5c22714e..8e553297 100644 --- a/path_service_principal_test.go +++ b/path_service_principal_test.go @@ -3,11 +3,14 @@ package azuresecrets import ( "context" "fmt" + "net/http" + "net/url" "os" "strings" "testing" "time" + "github.com/Azure/go-autorest/autorest" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault-plugin-secrets-azure/api" @@ -482,7 +485,7 @@ func TestCredentialReadProviderError(t *testing.T) { // TestCredentialInteg is an integration test against the live Azure service. It requires // valid, sufficiently-privileged Azure credentials in env variables. -func TestCredentialInteg(t *testing.T) { +func TestCredentialInteg_aad(t *testing.T) { if os.Getenv("VAULT_ACC") != "1" { t.SkipNow() } @@ -531,24 +534,15 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) // Request credentials resp, err = b.HandleRequest(context.Background(), &logical.Request{ Operation: logical.ReadOperation, Path: fmt.Sprintf("creds/%s", rolename), - Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) appID := resp.Data["client_id"].(string) @@ -601,7 +595,7 @@ func TestCredentialInteg(t *testing.T) { assertServicePrincipalDoesNotExist(t, provider.spClient, spObjID) }) - t.Run("Static service principals", func(t *testing.T) { + t.Run("static service principals", func(t *testing.T) { t.Parallel() b := backend() @@ -621,7 +615,7 @@ func TestCredentialInteg(t *testing.T) { // Add a Vault role that will provide creds with Azure "Reader" permissions subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") - rolename := "test_role" + rolename := "static_test_role" role := map[string]interface{}{ "azure_roles": fmt.Sprintf(`[{ "role_name": "Reader", @@ -634,11 +628,7 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) // Request credentials resp, err = b.HandleRequest(context.Background(), &logical.Request{ @@ -647,11 +637,7 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) origResp := resp @@ -671,11 +657,7 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) // Request credentials resp, err = b.HandleRequest(context.Background(), &logical.Request{ @@ -684,11 +666,7 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) // Test the added password by creating a new Azure provider with these // creds and attempting an operation with it. @@ -741,6 +719,153 @@ func TestCredentialInteg(t *testing.T) { }) } +// Similar to TestCredentialInteg, this is an integration test against the live Azure service. It requires +// valid, sufficiently-privileged Azure credentials in env variables. +// The credentials provided to this must include permissions to use MS Graph and not AAD +// Unfortunately this means that this test cannot be run within the same test execution as TestCredentialInteg +func TestCredentialInteg_msgraph(t *testing.T) { + if os.Getenv("VAULT_ACC") != "1" { + t.SkipNow() + } + + if os.Getenv("AZURE_CLIENT_SECRET") == "" { + t.Skip("Azure Secrets: Azure environment variables not set. Skipping.") + } + + t.Run("service principals", func(t *testing.T) { + t.Parallel() + + skipIfMissingEnvVars(t, + "AZURE_SUBSCRIPTION_ID", + "AZURE_CLIENT_ID", + "AZURE_CLIENT_SECRET", + "AZURE_TENANT_ID", + ) + + b := backend() + s := new(logical.InmemStorage) + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + clientID := os.Getenv("AZURE_CLIENT_ID") + clientSecret := os.Getenv("AZURE_CLIENT_SECRET") + tenantID := os.Getenv("AZURE_TENANT_ID") + + config := &logical.BackendConfig{ + Logger: logging.NewVaultLogger(log.Trace), + System: &logical.StaticSystemView{ + DefaultLeaseTTLVal: defaultLeaseTTLHr, + MaxLeaseTTLVal: maxLeaseTTLHr, + }, + StorageView: s, + } + err := b.Setup(context.Background(), config) + assertErrorIsNil(t, err) + + configData := map[string]interface{}{ + "subscription_id": subscriptionID, + "client_id": clientID, + "client_secret": clientSecret, + "tenant_id": tenantID, + "use_microsoft_graph_api": true, + } + + configResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.CreateOperation, + Path: "config", + Data: configData, + Storage: s, + }) + assertRespNoError(t, configResp, err) + + roleName := "test_role_msgraph" + + roleData := map[string]interface{}{ + "azure_roles": fmt.Sprintf(`[ + { + "role_name": "Reader", + "scope": "/subscriptions/%s/resourceGroups/vault-azure-secrets-test1" + }, + { + "role_name": "Reader", + "scope": "/subscriptions/%s/resourceGroups/vault-azure-secrets-test2" + }]`, subscriptionID, subscriptionID), + } + + roleResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.CreateOperation, + Path: fmt.Sprintf("roles/%s", roleName), + Data: roleData, + Storage: s, + }) + assertRespNoError(t, roleResp, err) + + credsResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.ReadOperation, + Path: fmt.Sprintf("creds/%s", roleName), + Storage: s, + }) + assertRespNoError(t, credsResp, err) + + appID := credsResp.Data["client_id"].(string) + + // Use the underlying provider to access clients directly for testing + client, err := b.getClient(context.Background(), s) + assertErrorIsNil(t, err) + provider := client.provider.(*provider) + spObjID := findServicePrincipalID(t, provider.spClient, appID) + + assertServicePrincipalExists(t, provider.spClient, spObjID) + + // Verify that the role assignments were created. Get the assignment + // info from Azure and verify it matches the Reader role. + raIDs := credsResp.Secret.InternalData["role_assignment_ids"].([]string) + equal(t, 2, len(raIDs)) + + ra, err := provider.raClient.GetByID(context.Background(), raIDs[0]) + assertErrorIsNil(t, err) + + roleDefs, err := provider.ListRoleDefinitions(context.Background(), fmt.Sprintf("subscriptions/%s", subscriptionID), "") + assertErrorIsNil(t, err) + + defID := *ra.Properties.RoleDefinitionID + found := false + for _, def := range roleDefs { + if *def.ID == defID && *def.RoleName == "Reader" { + found = true + break + } + } + + if !found { + t.Fatal("'Reader' role assignment not found") + } + + // Serialize and deserialize the secret to remove typing, as will really happen. + fakeSaveLoad(credsResp.Secret) + + // Revoke the Service Principal by sending back the secret we just received + req := &logical.Request{ + Secret: credsResp.Secret, + Storage: s, + } + + b.spRevoke(context.Background(), req, nil) + + // Verify that SP get is an error after delete. Expected there + // to be a delay and that this step would take some time/retries, + // but that seems not to be the case. + assertServicePrincipalDoesNotExist(t, provider.spClient, spObjID) + }) +} + +func skipIfMissingEnvVars(t *testing.T, envVars ...string) { + t.Helper() + for _, envVar := range envVars { + if os.Getenv(envVar) == "" { + t.Skipf("Missing env variable: [%s] - skipping test", envVar) + } + } +} + func assertClientSecret(tb testing.TB, data map[string]interface{}) { assertKeyExists(tb, data, "client_secret") actualPassword, ok := data["client_secret"].(string) @@ -752,6 +877,11 @@ func assertClientSecret(tb testing.TB, data map[string]interface{}) { } } +type servicePrincipalResp struct { + AppID string `json:"appId"` + ID string `json:"id"` +} + func findServicePrincipalID(t *testing.T, client api.ServicePrincipalClient, appID string) (spID string) { t.Helper() @@ -768,12 +898,41 @@ func findServicePrincipalID(t *testing.T, client api.ServicePrincipalClient, app err = spList.NextWithContext(context.Background()) assertErrorIsNil(t, err) } - // TODO: Add MSGraph + case *api.AppClient: + pathVals := &url.Values{} + pathVals.Set("$filter", fmt.Sprintf("appId eq '%s'", appID)) + + prep := spClient.GetPreparer( + autorest.AsGet(), + autorest.WithPath(fmt.Sprintf("/v1.0/servicePrincipals?%s", pathVals.Encode())), + ) + + type listSPsResponse struct { + ServicePrincipals []servicePrincipalResp `json:"value"` + } + + respBody := listSPsResponse{} + + err := spClient.SendRequest(context.Background(), prep, + autorest.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&respBody), + ) + assertErrorIsNil(t, err) + + if len(respBody.ServicePrincipals) == 0 { + t.Fatalf("Failed to find service principals from application ID") + } + + for _, sp := range respBody.ServicePrincipals { + if sp.AppID == appID { + return sp.ID + } + } default: t.Fatalf("Unrecognized service principal client type: %T", spClient) } - t.Fatalf("Failed to find service principal with application ID") + t.Fatalf("Failed to find service principal with application ID: %s", appID) return "" // Because compilers } @@ -786,6 +945,27 @@ func assertServicePrincipalExists(t *testing.T, client api.ServicePrincipalClien if err != nil { t.Fatalf("Expected nil error on GET of new SP, got: %#v", err) } + case *api.AppClient: + pathParams := map[string]interface{}{ + "id": spID, + } + + prep := spClient.GetPreparer( + autorest.AsGet(), + autorest.WithPathParameters("/v1.0/servicePrincipals/{id}", pathParams), + ) + + respBody := servicePrincipalResp{} + + err := spClient.SendRequest(context.Background(), prep, + autorest.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&respBody), + ) + assertErrorIsNil(t, err) + + if respBody.ID == "" { + t.Fatalf("Failed to find service principal") + } default: t.Fatalf("Unrecognized service principal client type: %T", spClient) } @@ -800,7 +980,38 @@ func assertServicePrincipalDoesNotExist(t *testing.T, client api.ServicePrincipa if err == nil { t.Fatalf("Expected error on GET of new SP") } + case *api.AppClient: + pathParams := map[string]interface{}{ + "id": spID, + } + + prep := spClient.GetPreparer( + autorest.AsGet(), + autorest.WithPathParameters("/v1.0/servicePrincipals/{id}", pathParams), + ) + + respBody := servicePrincipalResp{} + + err := spClient.SendRequest(context.Background(), prep, + autorest.WithErrorUnlessStatusCode(http.StatusNotFound), + autorest.ByUnmarshallingJSON(&respBody), + ) + assertErrorIsNil(t, err) + + if respBody.ID != "" { + t.Fatalf("Found service principal when it shouldn't exist") + } default: t.Fatalf("Unrecognized service principal client type: %T", spClient) } } + +func assertRespNoError(t *testing.T, resp *logical.Response, err error) { + t.Helper() + + assertErrorIsNil(t, err) + + if resp != nil && resp.IsError() { + t.Fatal(resp.Error()) + } +}