diff --git a/api/api.go b/api/api.go new file mode 100644 index 0000000..5349242 --- /dev/null +++ b/api/api.go @@ -0,0 +1,62 @@ +package api + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/date" +) + +// AzureProvider is an interface to access underlying Azure Client objects and supporting services. +// Where practical the original function signature is preserved. Client provides higher +// level operations atop AzureProvider. +type AzureProvider interface { + ApplicationsClient + GroupsClient + ServicePrincipalClient + + CreateRoleAssignment( + ctx context.Context, + scope string, + roleAssignmentName string, + parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) + DeleteRoleAssignmentByID(ctx context.Context, roleID string) (authorization.RoleAssignment, error) + + ListRoleDefinitions(ctx context.Context, scope string, filter string) ([]authorization.RoleDefinition, error) + GetRoleDefinitionByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) +} + +type ApplicationsClient interface { + GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) + CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) + DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) + AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) + RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) +} + +type PasswordCredential struct { + DisplayName *string `json:"displayName"` + // StartDate - Start date. + StartDate *date.Time `json:"startDateTime,omitempty"` + // EndDate - End date. + EndDate *date.Time `json:"endDateTime,omitempty"` + // KeyID - Key ID. + KeyID *string `json:"keyId,omitempty"` + // Value - Key value. + SecretText *string `json:"secretText,omitempty"` +} + +type PasswordCredentialResult struct { + autorest.Response `json:"-"` + + PasswordCredential +} + +type ApplicationResult struct { + autorest.Response `json:"-"` + + AppID *string `json:"appId,omitempty"` + ID *string `json:"id,omitempty"` + PasswordCredentials []*PasswordCredential `json:"passwordCredentials,omitempty"` +} diff --git a/api/application_aad.go b/api/application_aad.go new file mode 100644 index 0000000..2039ceb --- /dev/null +++ b/api/application_aad.go @@ -0,0 +1,147 @@ +package api + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/errwrap" + "github.com/hashicorp/go-uuid" +) + +type ActiveDirectoryApplicationClient struct { + Client *graphrbac.ApplicationsClient + Passwords Passwords +} + +func (a *ActiveDirectoryApplicationClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { + app, err := a.Client.Get(ctx, applicationObjectID) + if err != nil { + return ApplicationResult{}, err + } + + return ApplicationResult{ + AppID: app.AppID, + ID: app.ObjectID, + }, nil +} + +func (a *ActiveDirectoryApplicationClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { + appURL := fmt.Sprintf("https://%s", displayName) + + app, err := a.Client.Create(ctx, graphrbac.ApplicationCreateParameters{ + AvailableToOtherTenants: to.BoolPtr(false), + DisplayName: to.StringPtr(displayName), + Homepage: to.StringPtr(appURL), + IdentifierUris: to.StringSlicePtr([]string{appURL}), + }) + if err != nil { + return ApplicationResult{}, err + } + + return ApplicationResult{ + AppID: app.AppID, + ID: app.ObjectID, + }, nil +} + +func (a *ActiveDirectoryApplicationClient) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { + return a.Client.Delete(ctx, applicationObjectID) +} + +func (a *ActiveDirectoryApplicationClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { + keyID, err := uuid.GenerateUUID() + if err != nil { + return PasswordCredentialResult{}, err + } + + // Key IDs are not secret, and they're a convenient way for an operator to identify Vault-generated + // passwords. These must be UUIDs, so the three leading bytes will be used as an indicator. + keyID = "ffffff" + keyID[6:] + + password, err := a.Passwords.Generate(ctx) + if err != nil { + return PasswordCredentialResult{}, err + } + + now := date.Time{Time: time.Now().UTC()} + cred := graphrbac.PasswordCredential{ + StartDate: &now, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + Value: to.StringPtr(password), + } + + // Load current credentials + resp, err := a.Client.ListPasswordCredentials(ctx, applicationObjectID) + if err != nil { + return PasswordCredentialResult{}, errwrap.Wrapf("error fetching credentials: {{err}}", err) + } + curCreds := *resp.Value + + // Add and save credentials + curCreds = append(curCreds, cred) + + if _, err := a.Client.UpdatePasswordCredentials(ctx, applicationObjectID, + graphrbac.PasswordCredentialsUpdateParameters{ + Value: &curCreds, + }, + ); err != nil { + if strings.Contains(err.Error(), "size of the object has exceeded its limit") { + err = errors.New("maximum number of Application passwords reached") + } + return PasswordCredentialResult{}, errwrap.Wrapf("error updating credentials: {{err}}", err) + } + + return PasswordCredentialResult{ + PasswordCredential: PasswordCredential{ + DisplayName: to.StringPtr(displayName), + StartDate: &now, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + SecretText: to.StringPtr(password), + }, + }, nil +} + +func (a *ActiveDirectoryApplicationClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + // Load current credentials + resp, err := a.Client.ListPasswordCredentials(ctx, applicationObjectID) + if err != nil { + return autorest.Response{}, errwrap.Wrapf("error fetching credentials: {{err}}", err) + } + curCreds := *resp.Value + + // Remove credential + found := false + for i := range curCreds { + if to.String(curCreds[i].KeyID) == keyID { + curCreds[i] = curCreds[len(curCreds)-1] + curCreds = curCreds[:len(curCreds)-1] + found = true + break + } + } + + // KeyID is not present, so nothing to do + if !found { + return autorest.Response{}, nil + } + + // Save new credentials list + if _, err := a.Client.UpdatePasswordCredentials(ctx, applicationObjectID, + graphrbac.PasswordCredentialsUpdateParameters{ + Value: &curCreds, + }, + ); err != nil { + return autorest.Response{}, errwrap.Wrapf("error updating credentials: {{err}}", err) + } + + return autorest.Response{}, nil +} diff --git a/api/application_msgraph.go b/api/application_msgraph.go new file mode 100644 index 0000000..f3448cb --- /dev/null +++ b/api/application_msgraph.go @@ -0,0 +1,562 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/go-multierror" +) + +const ( + // DefaultGraphMicrosoftComURI is the default URI used for the service MS Graph API + DefaultGraphMicrosoftComURI = "https://graph.microsoft.com" +) + +var _ ApplicationsClient = (*AppClient)(nil) +var _ GroupsClient = (*AppClient)(nil) +var _ ServicePrincipalClient = (*AppClient)(nil) + +type AppClient struct { + client authorization.BaseClient +} + +func NewMSGraphApplicationClient(subscriptionId string, userAgentExtension string, auth autorest.Authorizer) (*AppClient, error) { + client := authorization.NewWithBaseURI(DefaultGraphMicrosoftComURI, subscriptionId) + client.Authorizer = auth + + if userAgentExtension != "" { + err := client.AddToUserAgent(userAgentExtension) + if err != nil { + return nil, fmt.Errorf("failed to add extension to user agent") + } + } + + ac := &AppClient{ + client: client, + } + return ac, nil +} + +func (c *AppClient) AddToUserAgent(extension string) error { + return c.client.AddToUserAgent(extension) +} + +func (c *AppClient) GetApplication(ctx context.Context, applicationObjectID string) (result ApplicationResult, err error) { + req, err := c.getApplicationPreparer(ctx, applicationObjectID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "GetApplication", nil, "Failure preparing request") + return + } + + resp, err := c.getApplicationSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "GetApplication", resp, "Failure sending request") + return + } + + result, err = c.getApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "GetApplication", resp, "Failure responding to request") + } + + return +} + +// CreateApplication create a new Azure application object. +func (c *AppClient) CreateApplication(ctx context.Context, displayName string) (result ApplicationResult, err error) { + req, err := c.createApplicationPreparer(ctx, displayName) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", nil, "Failure preparing request") + return + } + + resp, err := c.createApplicationSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", resp, "Failure sending request") + return + } + + result, err = c.createApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "CreateApplication", resp, "Failure responding to request") + } + + return +} + +// DeleteApplication deletes an Azure application object. +// This will in turn remove the service principal (but not the role assignments). +func (c *AppClient) DeleteApplication(ctx context.Context, applicationObjectID string) (result autorest.Response, err error) { + req, err := c.deleteApplicationPreparer(ctx, applicationObjectID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", nil, "Failure preparing request") + return + } + + resp, err := c.deleteApplicationSender(req) + if err != nil { + result.Response = resp + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", resp, "Failure sending request") + return + } + + result, err = c.deleteApplicationResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "DeleteApplication", resp, "Failure responding to request") + } + + return +} + +func (c *AppClient) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result PasswordCredentialResult, err error) { + req, err := c.addPasswordPreparer(ctx, applicationObjectID, displayName, endDateTime) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", nil, "Failure preparing request") + return + } + + resp, err := c.addPasswordSender(req) + if err != nil { + result.Response = autorest.Response{Response: resp} + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", resp, "Failure sending request") + return + } + + result, err = c.addPasswordResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "AddApplicationPassword", resp, "Failure responding to request") + } + + return +} + +func (c *AppClient) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + req, err := c.removePasswordPreparer(ctx, applicationObjectID, keyID) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", nil, "Failure preparing request") + return + } + + resp, err := c.removePasswordSender(req) + if err != nil { + result.Response = resp + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", resp, "Failure sending request") + return + } + + result, err = c.removePasswordResponder(resp) + if err != nil { + err = autorest.NewErrorWithError(err, "provider", "RemoveApplicationPassword", resp, "Failure responding to request") + } + + return +} + +func (c AppClient) getApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsGet(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}", pathParameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) getApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) getApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return result, err +} + +func (c AppClient) addPasswordPreparer(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + parameters := struct { + PasswordCredential *PasswordCredential `json:"passwordCredential"` + }{ + PasswordCredential: &PasswordCredential{ + DisplayName: to.StringPtr(displayName), + EndDate: &endDateTime, + }, + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}/addPassword", pathParameters), + autorest.WithJSON(parameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) addPasswordSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) addPasswordResponder(resp *http.Response) (result PasswordCredentialResult, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return +} + +func (c AppClient) removePasswordPreparer(ctx context.Context, applicationObjectID string, keyID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + parameters := struct { + KeyID string `json:"keyId"` + }{ + KeyID: keyID, + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}/removePassword", pathParameters), + autorest.WithJSON(parameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) removePasswordSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) removePasswordResponder(resp *http.Response) (result autorest.Response, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusNoContent), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = resp + return +} + +func (c AppClient) createApplicationPreparer(ctx context.Context, displayName string) (*http.Request, error) { + parameters := struct { + DisplayName *string `json:"displayName"` + }{ + DisplayName: to.StringPtr(displayName), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsPost(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPath("/v1.0/applications"), + autorest.WithJSON(parameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) createApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) createApplicationResponder(resp *http.Response) (result ApplicationResult, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusCreated), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = autorest.Response{Response: resp} + return +} + +func (c AppClient) deleteApplicationPreparer(ctx context.Context, applicationObjectID string) (*http.Request, error) { + pathParameters := map[string]interface{}{ + "applicationObjectId": autorest.Encode("path", applicationObjectID), + } + + preparer := autorest.CreatePreparer( + autorest.AsContentType("application/json; charset=utf-8"), + autorest.AsDelete(), + autorest.WithBaseURL(c.client.BaseURI), + autorest.WithPathParameters("/v1.0/applications/{applicationObjectId}", pathParameters), + c.client.WithAuthorization()) + return preparer.Prepare((&http.Request{}).WithContext(ctx)) +} + +func (c AppClient) deleteApplicationSender(req *http.Request) (*http.Response, error) { + sd := autorest.GetSendDecorators(req.Context(), autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...)) + return autorest.SendWithSender(c.client, req, sd...) +} + +func (c AppClient) deleteApplicationResponder(resp *http.Response) (result autorest.Response, err error) { + err = autorest.Respond( + resp, + c.client.ByInspecting(), + azure.WithErrorUnlessStatusCode(http.StatusNoContent), + autorest.ByUnmarshallingJSON(&result), + autorest.ByClosing()) + result.Response = resp + return +} + +func (c AppClient) AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error { + if groupObjectID == "" { + return fmt.Errorf("missing groupObjectID") + } + pathParams := map[string]interface{}{ + "groupObjectID": groupObjectID, + } + body := map[string]interface{}{ + "@odata.id": fmt.Sprintf("%s/v1.0/directoryObjects/%s", DefaultGraphMicrosoftComURI, memberObjectID), + } + preparer := c.GetPreparer( + autorest.AsPost(), + autorest.WithPathParameters("/v1.0/groups/{groupObjectID}/members/$ref", pathParams), + autorest.WithJSON(body), + ) + return c.SendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) +} + +func (c AppClient) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error { + if groupObjectID == "" { + return fmt.Errorf("missing groupObjectID") + } + if memberObjectID == "" { + return fmt.Errorf("missing memberObjectID") + } + pathParams := map[string]interface{}{ + "groupObjectID": groupObjectID, + "memberObjectID": memberObjectID, + } + + preparer := c.GetPreparer( + autorest.AsDelete(), + autorest.WithPathParameters("/v1.0/groups/{groupObjectID}/members/{memberObjectID}/$ref", pathParams), + ) + return c.SendRequest(ctx, preparer, azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) +} + +// groupResponse is a struct representation of the data we care about coming back from +// the ms-graph API. This is not the same as `Group` because this information is +// slightly different from the AAD implementation and there should be an abstraction +// between the ms-graph API itself and the API this package presents. +type groupResponse struct { + ID string `json:"id"` + DisplayName string `json:"displayName"` +} + +func (c AppClient) GetGroup(ctx context.Context, groupID string) (result Group, err error) { + if groupID == "" { + return Group{}, fmt.Errorf("missing groupID") + } + pathParams := map[string]interface{}{ + "groupID": groupID, + } + + preparer := c.GetPreparer( + autorest.AsGet(), + autorest.WithPathParameters("/v1.0/groups/{groupID}", pathParams), + ) + + groupResp := groupResponse{} + err = c.SendRequest(ctx, preparer, + azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), + autorest.ByUnmarshallingJSON(&groupResp), + ) + if err != nil { + return Group{}, err + } + + group := Group{ + ID: groupResp.ID, + DisplayName: groupResp.DisplayName, + } + + return group, nil +} + +// listGroupsResponse is a struct representation of the data we care about +// coming back from the ms-graph API +type listGroupsResponse struct { + Groups []groupResponse `json:"value"` +} + +func (c AppClient) ListGroups(ctx context.Context, filter string) (result []Group, err error) { + filterArgs := url.Values{} + if filter != "" { + filterArgs.Set("$filter", filter) + } + + preparer := c.GetPreparer( + autorest.AsGet(), + autorest.WithPath(fmt.Sprintf("/v1.0/groups?%s", filterArgs.Encode())), + ) + + respBody := listGroupsResponse{} + err = c.SendRequest(ctx, preparer, + azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), + autorest.ByUnmarshallingJSON(&respBody), + ) + if err != nil { + return nil, err + } + + groups := []Group{} + for _, rawGroup := range respBody.Groups { + if rawGroup.ID == "" { + return nil, fmt.Errorf("missing group ID from response") + } + + group := Group{ + ID: rawGroup.ID, + DisplayName: rawGroup.DisplayName, + } + groups = append(groups, group) + } + return groups, nil +} + +func (c *AppClient) CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (spID string, password string, err error) { + spID, err = c.createServicePrincipal(ctx, appID) + if err != nil { + return "", "", err + } + password, err = c.setPasswordForServicePrincipal(ctx, spID, startDate, endDate) + if err != nil { + dErr := c.deleteServicePrincipal(ctx, spID) + merr := multierror.Append(err, dErr) + return "", "", merr.ErrorOrNil() + } + return spID, password, nil +} + +func (c *AppClient) createServicePrincipal(ctx context.Context, appID string) (id string, err error) { + body := map[string]interface{}{ + "appId": appID, + "accountEnabled": true, + } + preparer := c.GetPreparer( + autorest.AsPost(), + autorest.WithPath("/v1.0/servicePrincipals"), + autorest.WithJSON(body), + ) + + respBody := createServicePrincipalResponse{} + err = c.SendRequest(ctx, preparer, + autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusCreated), + autorest.ByUnmarshallingJSON(&respBody), + ) + if err != nil { + return "", err + } + + return respBody.ID, nil +} + +func (c *AppClient) setPasswordForServicePrincipal(ctx context.Context, spID string, startDate time.Time, endDate time.Time) (password string, err error) { + pathParams := map[string]interface{}{ + "id": spID, + } + reqBody := map[string]interface{}{ + "startDateTime": startDate.UTC().Format("2006-01-02T15:04:05Z"), + "endDateTime": startDate.UTC().Format("2006-01-02T15:04:05Z"), + } + + preparer := c.GetPreparer( + autorest.AsPost(), + autorest.WithPathParameters("/v1.0/servicePrincipals/{id}/addPassword", pathParams), + autorest.WithJSON(reqBody), + ) + + respBody := PasswordCredential{} + err = c.SendRequest(ctx, preparer, + autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent), + autorest.ByUnmarshallingJSON(&respBody), + ) + if err != nil { + return "", err + } + return *respBody.SecretText, nil +} + +type createServicePrincipalResponse struct { + ID string `json:"id"` +} + +func (c *AppClient) deleteServicePrincipal(ctx context.Context, spID string) error { + pathParams := map[string]interface{}{ + "id": spID, + } + + preparer := c.GetPreparer( + autorest.AsDelete(), + autorest.WithPathParameters("/v1.0/servicePrincipals/{id}", pathParams), + ) + + return c.SendRequest(ctx, preparer, autorest.WithErrorUnlessStatusCode(http.StatusOK, http.StatusNoContent)) +} + +func (c *AppClient) GetPreparer(prepareDecorators ...autorest.PrepareDecorator) autorest.Preparer { + decs := []autorest.PrepareDecorator{ + autorest.AsContentType("application/json; charset=utf-8"), + autorest.WithBaseURL(c.client.BaseURI), + c.client.WithAuthorization(), + } + decs = append(decs, prepareDecorators...) + preparer := autorest.CreatePreparer(decs...) + return preparer +} + +func (c *AppClient) SendRequest(ctx context.Context, preparer autorest.Preparer, respDecs ...autorest.RespondDecorator) error { + req, err := preparer.Prepare((&http.Request{}).WithContext(ctx)) + if err != nil { + return err + } + + sender := autorest.GetSendDecorators(req.Context(), + autorest.DoRetryForStatusCodes(c.client.RetryAttempts, c.client.RetryDuration, autorest.StatusCodesForRetry...), + ) + resp, err := autorest.SendWithSender(c.client, req, sender...) + if err != nil { + return err + } + + // Put ByInspecting() before any provided decorators + respDecs = append([]autorest.RespondDecorator{c.client.ByInspecting()}, respDecs...) + respDecs = append(respDecs, autorest.ByClosing()) + + return autorest.Respond(resp, respDecs...) +} diff --git a/api/groups.go b/api/groups.go new file mode 100644 index 0000000..24e1b57 --- /dev/null +++ b/api/groups.go @@ -0,0 +1,17 @@ +package api + +import ( + "context" +) + +type GroupsClient interface { + AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error + RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) error + GetGroup(ctx context.Context, objectID string) (result Group, err error) + ListGroups(ctx context.Context, filter string) (result []Group, err error) +} + +type Group struct { + ID string + DisplayName string +} diff --git a/api/groups_aad.go b/api/groups_aad.go new file mode 100644 index 0000000..b4d007c --- /dev/null +++ b/api/groups_aad.go @@ -0,0 +1,72 @@ +package api + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/to" +) + +var _ GroupsClient = (*ActiveDirectoryApplicationGroupsClient)(nil) + +type aadGroupsClient interface { + AddMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) + RemoveMember(ctx context.Context, groupObjectID string, memberObjectID string) (result autorest.Response, err error) + Get(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) + List(ctx context.Context, filter string) (result graphrbac.GroupListResultPage, err error) +} + +type ActiveDirectoryApplicationGroupsClient struct { + BaseURI string + TenantID string + Client aadGroupsClient +} + +func (a ActiveDirectoryApplicationGroupsClient) AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error { + uri := fmt.Sprintf("%s/%s/directoryObjects/%s", a.BaseURI, a.TenantID, memberObjectID) + aadParams := graphrbac.GroupAddMemberParameters{ + URL: to.StringPtr(uri), + } + _, err := a.Client.AddMember(ctx, groupObjectID, aadParams) + return err +} + +func (a ActiveDirectoryApplicationGroupsClient) RemoveGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) error { + _, err := a.Client.RemoveMember(ctx, groupObjectID, memberObjectID) + return err +} + +func (a ActiveDirectoryApplicationGroupsClient) GetGroup(ctx context.Context, objectID string) (result Group, err error) { + resp, err := a.Client.Get(ctx, objectID) + if err != nil { + return Group{}, err + } + + grp := getGroupFromRBAC(resp) + + return grp, nil +} + +func getGroupFromRBAC(resp graphrbac.ADGroup) Group { + grp := Group{ + ID: *resp.ObjectID, + DisplayName: *resp.DisplayName, + } + return grp +} + +func (a ActiveDirectoryApplicationGroupsClient) ListGroups(ctx context.Context, filter string) (result []Group, err error) { + resp, err := a.Client.List(ctx, filter) + if err != nil { + return nil, err + } + + grps := []Group{} + for _, aadGrp := range resp.Values() { + grp := getGroupFromRBAC(aadGrp) + grps = append(grps, grp) + } + return grps, nil +} diff --git a/api/passwords.go b/api/passwords.go new file mode 100644 index 0000000..7ac8b92 --- /dev/null +++ b/api/passwords.go @@ -0,0 +1,31 @@ +package api + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-secure-stdlib/base62" +) + +const ( + PasswordLength = 36 +) + +type PasswordGenerator interface { + GeneratePasswordFromPolicy(ctx context.Context, policyName string) (password string, err error) +} + +type Passwords struct { + PolicyGenerator PasswordGenerator + PolicyName string +} + +func (p Passwords) Generate(ctx context.Context) (password string, err error) { + if p.PolicyName == "" { + return base62.Random(PasswordLength) + } + if p.PolicyGenerator == nil { + return "", fmt.Errorf("policy set, but no policy generator specified") + } + return p.PolicyGenerator.GeneratePasswordFromPolicy(ctx, p.PolicyName) +} diff --git a/api/service_principals.go b/api/service_principals.go new file mode 100644 index 0000000..7dbd243 --- /dev/null +++ b/api/service_principals.go @@ -0,0 +1,16 @@ +package api + +import ( + "context" + "time" +) + +type ServicePrincipalClient interface { + // CreateServicePrincipal in Azure. The password returned is the actual password that the appID was created with + CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (id string, password string, err error) +} + +type ServicePrincipal struct { + ObjectID string + AppID string +} diff --git a/api/service_principals_aad.go b/api/service_principals_aad.go new file mode 100644 index 0000000..784f935 --- /dev/null +++ b/api/service_principals_aad.go @@ -0,0 +1,48 @@ +package api + +import ( + "context" + "time" + + "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/go-uuid" +) + +var _ ServicePrincipalClient = (*AADServicePrincipalsClient)(nil) + +type AADServicePrincipalsClient struct { + Client graphrbac.ServicePrincipalsClient + Passwords Passwords +} + +func (c AADServicePrincipalsClient) CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (id string, password string, err error) { + keyID, err := uuid.GenerateUUID() + if err != nil { + return "", "", err + } + + password, err = c.Passwords.Generate(ctx) + if err != nil { + return "", "", err + } + + clientParams := graphrbac.ServicePrincipalCreateParameters{ + AppID: to.StringPtr(appID), + AccountEnabled: to.BoolPtr(true), + PasswordCredentials: &[]graphrbac.PasswordCredential{ + graphrbac.PasswordCredential{ + StartDate: &date.Time{startDate}, + EndDate: &date.Time{endDate}, + KeyID: &keyID, + Value: &password, + }, + }, + } + sp, err := c.Client.Create(ctx, clientParams) + if err != nil { + return "", "", err + } + return *sp.ObjectID, password, nil +} diff --git a/backend.go b/backend.go index 0e64735..2815d03 100644 --- a/backend.go +++ b/backend.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/locksutil" "github.com/hashicorp/vault/sdk/logical" @@ -14,7 +15,7 @@ import ( type azureSecretBackend struct { *framework.Backend - getProvider func(*clientSettings) (AzureProvider, error) + getProvider func(*clientSettings, bool, api.Passwords) (api.AzureProvider, error) client *client settings *clientSettings lock sync.RWMutex @@ -89,16 +90,15 @@ func (b *azureSecretBackend) invalidate(ctx context.Context, key string) { func (b *azureSecretBackend) getClient(ctx context.Context, s logical.Storage) (*client, error) { b.lock.RLock() - unlockFunc := b.lock.RUnlock - defer func() { unlockFunc() }() if b.client.Valid() { + b.lock.RUnlock() return b.client, nil } b.lock.RUnlock() b.lock.Lock() - unlockFunc = b.lock.Unlock + defer b.lock.Unlock() if b.client.Valid() { return b.client, nil @@ -121,14 +121,14 @@ func (b *azureSecretBackend) getClient(ctx context.Context, s logical.Storage) ( b.settings = settings } - p, err := b.getProvider(b.settings) - if err != nil { - return nil, err + passwords := api.Passwords{ + PolicyGenerator: b.System(), + PolicyName: config.PasswordPolicy, } - passwords := passwords{ - policyGenerator: b.System(), - policyName: config.PasswordPolicy, + p, err := b.getProvider(b.settings, config.UseMsGraphAPI, passwords) + if err != nil { + return nil, err } c := &client{ diff --git a/backend_test.go b/backend_test.go index 9c7f545..c22cef3 100644 --- a/backend_test.go +++ b/backend_test.go @@ -2,19 +2,11 @@ package azuresecrets import ( "context" - "errors" - "fmt" - "regexp" - "strings" "testing" "time" - "github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/to" log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" ) @@ -44,7 +36,7 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical b.settings = new(clientSettings) mockProvider := newMockProvider() - b.getProvider = func(s *clientSettings) (AzureProvider, error) { + b.getProvider = func(s *clientSettings, usMsGraphApi bool, p api.Passwords) (api.AzureProvider, error) { return mockProvider, nil } @@ -64,237 +56,3 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical return b, config.StorageView } - -// mockProvider is a Provider that provides stubs and simple, deterministic responses. -type mockProvider struct { - subscriptionID string - applications map[string]bool - passwords map[string]bool - failNextCreateApplication bool -} - -// errMockProvider simulates a normal provider which fails to associate a role, -// returning an error -type errMockProvider struct { - *mockProvider -} - -// CreateRoleAssignment for the errMockProvider intentionally fails -func (e *errMockProvider) CreateRoleAssignment(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { - return authorization.RoleAssignment{}, errors.New("PrincipalNotFound") -} - -// GetApplication for the errMockProvider only returns an application if that -// key is found, unlike mockProvider which returns the same application object -// id each time. Existing tests depend on the mockProvider behavior, which is -// why errMockProvider has it's own version. -func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) { - for s := range e.applications { - if s == applicationObjectID { - return graphrbac.Application{ - AppID: to.StringPtr(s), - }, nil - } - } - return graphrbac.Application{}, errors.New("not found") -} - -func newErrMockProvider() AzureProvider { - return &errMockProvider{ - mockProvider: &mockProvider{ - subscriptionID: generateUUID(), - applications: make(map[string]bool), - passwords: make(map[string]bool), - }, - } -} - -func newMockProvider() AzureProvider { - return &mockProvider{ - subscriptionID: generateUUID(), - applications: make(map[string]bool), - passwords: make(map[string]bool), - } -} - -// ListRoles returns a single fake role based on the inbound filter -func (m *mockProvider) ListRoles(ctx context.Context, scope string, filter string) (result []authorization.RoleDefinition, err error) { - reRoleName := regexp.MustCompile("roleName eq '(.*)'") - - match := reRoleName.FindAllStringSubmatch(filter, -1) - if len(match) > 0 { - name := match[0][1] - if name == "multiple" { - return []authorization.RoleDefinition{ - { - ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s-1", name)), - RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ - RoleName: to.StringPtr(name), - }, - }, - { - ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s-2", name)), - RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ - RoleName: to.StringPtr(name), - }, - }, - }, nil - } - return []authorization.RoleDefinition{ - { - ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s", name)), - RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ - RoleName: to.StringPtr(name), - }, - }, - }, nil - } - - return []authorization.RoleDefinition{}, nil -} - -// GetRoleByID will returns a fake role definition from the povided ID -// Assumes an ID format of: .*FAKE_ROLE-{rolename} -func (m *mockProvider) GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) { - d := authorization.RoleDefinition{} - s := strings.Split(roleID, "FAKE_ROLE-") - if len(s) > 1 { - d.ID = to.StringPtr(roleID) - d.RoleDefinitionProperties = &authorization.RoleDefinitionProperties{ - RoleName: to.StringPtr(s[1]), - } - } - - return d, nil -} - -func (m *mockProvider) CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) { - return graphrbac.ServicePrincipal{ - ObjectID: to.StringPtr(generateUUID()), - }, nil -} - -func (m *mockProvider) CreateApplication(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) { - if m.failNextCreateApplication { - m.failNextCreateApplication = false - return graphrbac.Application{}, errors.New("Mock: fail to create application") - } - appObjID := generateUUID() - m.applications[appObjID] = true - - return graphrbac.Application{ - AppID: to.StringPtr(generateUUID()), - ObjectID: &appObjID, - }, nil -} - -func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) { - return graphrbac.Application{ - AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"), - }, nil -} - -func (m *mockProvider) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { - delete(m.applications, applicationObjectID) - return autorest.Response{}, nil -} - -func (m *mockProvider) UpdateApplicationPasswordCredentials(ctx context.Context, applicationObjectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (result autorest.Response, err error) { - m.passwords = make(map[string]bool) - for _, v := range *parameters.Value { - m.passwords[*v.KeyID] = true - } - - return autorest.Response{}, nil -} - -func (m *mockProvider) ListApplicationPasswordCredentials(ctx context.Context, applicationObjectID string) (result graphrbac.PasswordCredentialListResult, err error) { - var creds []graphrbac.PasswordCredential - for keyID := range m.passwords { - creds = append(creds, graphrbac.PasswordCredential{KeyID: &keyID}) - } - - return graphrbac.PasswordCredentialListResult{ - Value: &creds, - }, nil -} - -func (m *mockProvider) appExists(s string) bool { - return m.applications[s] -} - -func (m *mockProvider) passwordExists(s string) bool { - return m.passwords[s] -} - -func (m *mockProvider) VMGet(ctx context.Context, resourceGroupName string, VMName string, expand compute.InstanceViewTypes) (result compute.VirtualMachine, err error) { - return compute.VirtualMachine{}, nil -} - -func (m *mockProvider) VMUpdate(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachineUpdate) (result compute.VirtualMachinesUpdateFuture, err error) { - return compute.VirtualMachinesUpdateFuture{}, nil -} - -func (m *mockProvider) CreateRoleAssignment(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { - return authorization.RoleAssignment{ - ID: to.StringPtr(generateUUID()), - }, nil -} - -func (m *mockProvider) DeleteRoleAssignmentByID(ctx context.Context, roleID string) (result authorization.RoleAssignment, err error) { - return authorization.RoleAssignment{}, nil -} - -// AddGroupMember adds a member to a AAD Group. -func (m *mockProvider) AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) { - return autorest.Response{}, nil -} - -// RemoveGroupMember removes a member from a AAD Group. -func (m *mockProvider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) { - return autorest.Response{}, nil -} - -// GetGroup gets group information from the directory. -func (m *mockProvider) GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) { - g := graphrbac.ADGroup{ - ObjectID: to.StringPtr(objectID), - } - s := strings.Split(objectID, "FAKE_GROUP-") - if len(s) > 1 { - g.DisplayName = to.StringPtr(s[1]) - } - - return g, nil -} - -// ListGroups gets list of groups for the current tenant. -func (m *mockProvider) ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) { - reGroupName := regexp.MustCompile("displayName eq '(.*)'") - - match := reGroupName.FindAllStringSubmatch(filter, -1) - if len(match) > 0 { - name := match[0][1] - if name == "multiple" { - return []graphrbac.ADGroup{ - { - ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name)), - DisplayName: to.StringPtr(name), - }, - { - ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-2", name)), - DisplayName: to.StringPtr(name), - }, - }, nil - } - - return []graphrbac.ADGroup{ - { - ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name)), - DisplayName: to.StringPtr(name), - }, - }, nil - } - - return []graphrbac.ADGroup{}, nil -} diff --git a/client.go b/client.go index 0e0be99..37a88d3 100644 --- a/client.go +++ b/client.go @@ -9,14 +9,14 @@ import ( "strings" "time" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/to" "github.com/hashicorp/errwrap" multierror "github.com/hashicorp/go-multierror" uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/logical" ) @@ -30,10 +30,10 @@ const ( // for handlers. It in turn relies on a Provider interface to access the lower level // Azure Client SDK methods. type client struct { - provider AzureProvider + provider api.AzureProvider settings *clientSettings expiration time.Time - passwords passwords + passwords api.Passwords } // Valid returns whether the client defined and not expired. @@ -44,7 +44,7 @@ func (c *client) Valid() bool { // createApp creates a new Azure application. // An Application is a needed to create service principals used by // the caller for authentication. -func (c *client) createApp(ctx context.Context) (app *graphrbac.Application, err error) { +func (c *client) createApp(ctx context.Context) (app *api.ApplicationResult, err error) { name, err := uuid.GenerateUUID() if err != nil { return nil, err @@ -52,14 +52,7 @@ func (c *client) createApp(ctx context.Context) (app *graphrbac.Application, err name = appNamePrefix + name - appURL := fmt.Sprintf("https://%s", name) - - result, err := c.provider.CreateApplication(ctx, graphrbac.ApplicationCreateParameters{ - AvailableToOtherTenants: to.BoolPtr(false), - DisplayName: to.StringPtr(name), - Homepage: to.StringPtr(appURL), - IdentifierUris: to.StringSlicePtr([]string{appURL}), - }) + result, err := c.provider.CreateApplication(ctx, name) return &result, err } @@ -67,132 +60,61 @@ func (c *client) createApp(ctx context.Context) (app *graphrbac.Application, err // createSP creates a new service principal. func (c *client) createSP( ctx context.Context, - app *graphrbac.Application, - duration time.Duration) (svcPrinc *graphrbac.ServicePrincipal, password string, err error) { + app *api.ApplicationResult, + duration time.Duration) (spID string, password string, err error) { - // Generate a random key (which must be a UUID) and password - keyID, err := uuid.GenerateUUID() - if err != nil { - return nil, "", err - } - - password, err = c.passwords.generate(ctx) - if err != nil { - return nil, "", err + type idPass struct { + ID string + Password string } resultRaw, err := retry(ctx, func() (interface{}, bool, error) { - now := time.Now().UTC() - result, err := c.provider.CreateServicePrincipal(ctx, graphrbac.ServicePrincipalCreateParameters{ - AppID: app.AppID, - AccountEnabled: to.BoolPtr(true), - PasswordCredentials: &[]graphrbac.PasswordCredential{ - graphrbac.PasswordCredential{ - StartDate: &date.Time{Time: now}, - EndDate: &date.Time{Time: now.Add(duration)}, - KeyID: to.StringPtr(keyID), - Value: to.StringPtr(password), - }, - }, - }) + now := time.Now() + spID, password, err := c.provider.CreateServicePrincipal(ctx, *app.AppID, now, now.Add(duration)) // Propagation delays within Azure can cause this error occasionally, so don't quit on it. if err != nil && strings.Contains(err.Error(), "does not reference a valid application object") { return nil, false, nil } + result := idPass{ + ID: spID, + Password: password, + } + return result, true, err }) if err != nil { - return nil, "", errwrap.Wrapf("error creating service principal: {{err}}", err) + return "", "", errwrap.Wrapf("error creating service principal: {{err}}", err) } - result := resultRaw.(graphrbac.ServicePrincipal) + result := resultRaw.(idPass) - return &result, password, nil + return result.ID, result.Password, nil } // addAppPassword adds a new password to an App's credentials list. -func (c *client) addAppPassword(ctx context.Context, appObjID string, duration time.Duration) (keyID string, password string, err error) { - keyID, err = uuid.GenerateUUID() - if err != nil { - return "", "", err - } - - // Key IDs are not secret, and they're a convenient way for an operator to identify Vault-generated - // passwords. These must be UUIDs, so the three leading bytes will be used as an indicator. - keyID = "ffffff" + keyID[6:] - - password, err = c.passwords.generate(ctx) - if err != nil { - return "", "", err - } - - now := time.Now().UTC() - cred := graphrbac.PasswordCredential{ - StartDate: &date.Time{Time: now}, - EndDate: &date.Time{Time: now.Add(duration)}, - KeyID: to.StringPtr(keyID), - Value: to.StringPtr(password), - } - - // Load current credentials - resp, err := c.provider.ListApplicationPasswordCredentials(ctx, appObjID) +func (c *client) addAppPassword(ctx context.Context, appObjID string, expiresIn time.Duration) (string, string, error) { + exp := date.Time{Time: time.Now().Add(expiresIn)} + resp, err := c.provider.AddApplicationPassword(ctx, appObjID, "vault-plugin-secrets-azure", exp) if err != nil { - return "", "", errwrap.Wrapf("error fetching credentials: {{err}}", err) - } - curCreds := *resp.Value - - // Add and save credentials - curCreds = append(curCreds, cred) - - if _, err := c.provider.UpdateApplicationPasswordCredentials(ctx, appObjID, - graphrbac.PasswordCredentialsUpdateParameters{ - Value: &curCreds, - }, - ); err != nil { if strings.Contains(err.Error(), "size of the object has exceeded its limit") { err = errors.New("maximum number of Application passwords reached") } return "", "", errwrap.Wrapf("error updating credentials: {{err}}", err) } - return keyID, password, nil + return to.String(resp.KeyID), to.String(resp.SecretText), nil } // deleteAppPassword removes a password, if present, from an App's credentials list. func (c *client) deleteAppPassword(ctx context.Context, appObjID, keyID string) error { - // Load current credentials - resp, err := c.provider.ListApplicationPasswordCredentials(ctx, appObjID) - if err != nil { - return errwrap.Wrapf("error fetching credentials: {{err}}", err) - } - curCreds := *resp.Value - - // Remove credential - found := false - for i := range curCreds { - if to.String(curCreds[i].KeyID) == keyID { - curCreds[i] = curCreds[len(curCreds)-1] - curCreds = curCreds[:len(curCreds)-1] - found = true - break + if _, err := c.provider.RemoveApplicationPassword(ctx, appObjID, keyID); err != nil { + if strings.Contains(err.Error(), "No password credential found with keyId") { + return nil } - } - - // KeyID is not present, so nothing to do - if !found { - return nil - } - - // Save new credentials list - if _, err := c.provider.UpdateApplicationPasswordCredentials(ctx, appObjID, - graphrbac.PasswordCredentialsUpdateParameters{ - Value: &curCreds, - }, - ); err != nil { - return errwrap.Wrapf("error updating credentials: {{err}}", err) + return errwrap.Wrapf("error removing credentials: {{err}}", err) } return nil @@ -211,7 +133,7 @@ func (c *client) deleteApp(ctx context.Context, appObjectID string) error { } // assignRoles assigns Azure roles to a service principal. -func (c *client) assignRoles(ctx context.Context, sp *graphrbac.ServicePrincipal, roles []*AzureRole) ([]string, error) { +func (c *client) assignRoles(ctx context.Context, spID string, roles []*AzureRole) ([]string, error) { var ids []string for _, role := range roles { @@ -223,9 +145,9 @@ func (c *client) assignRoles(ctx context.Context, sp *graphrbac.ServicePrincipal resultRaw, err := retry(ctx, func() (interface{}, bool, error) { ra, err := c.provider.CreateRoleAssignment(ctx, role.Scope, assignmentID, authorization.RoleAssignmentCreateParameters{ - RoleAssignmentProperties: &authorization.RoleAssignmentProperties{ - RoleDefinitionID: to.StringPtr(role.RoleID), - PrincipalID: sp.ObjectID, + Properties: &authorization.RoleAssignmentProperties{ + RoleDefinitionID: &role.RoleID, + PrincipalID: &spID, }, }) @@ -264,19 +186,10 @@ func (c *client) unassignRoles(ctx context.Context, roleIDs []string) error { } // addGroupMemberships adds the service principal to the Azure groups. -func (c *client) addGroupMemberships(ctx context.Context, sp *graphrbac.ServicePrincipal, groups []*AzureGroup) error { +func (c *client) addGroupMemberships(ctx context.Context, spID string, groups []*AzureGroup) error { for _, group := range groups { _, err := retry(ctx, func() (interface{}, bool, error) { - _, err := c.provider.AddGroupMember(ctx, group.ObjectID, - graphrbac.GroupAddMemberParameters{ - URL: to.StringPtr( - fmt.Sprintf("%s%s/directoryObjects/%s", - c.settings.Environment.GraphEndpoint, - c.settings.TenantID, - *sp.ObjectID, - ), - ), - }) + err := c.provider.AddGroupMember(ctx, group.ObjectID, spID) // Propagation delays within Azure can cause this error occasionally, so don't quit on it. if err != nil && strings.Contains(err.Error(), "Request_ResourceNotFound") { @@ -302,7 +215,7 @@ func (c *client) removeGroupMemberships(ctx context.Context, servicePrincipalObj var merr *multierror.Error for _, id := range groupIDs { - if _, err := c.provider.RemoveGroupMember(ctx, servicePrincipalObjectID, id); err != nil { + if err := c.provider.RemoveGroupMember(ctx, servicePrincipalObjectID, id); err != nil { merr = multierror.Append(merr, errwrap.Wrapf("error removing group membership: {{err}}", err)) } } @@ -323,12 +236,12 @@ func groupObjectIDs(groups []*AzureGroup) []string { // search for roles by name func (c *client) findRoles(ctx context.Context, roleName string) ([]authorization.RoleDefinition, error) { - return c.provider.ListRoles(ctx, fmt.Sprintf("subscriptions/%s", c.settings.SubscriptionID), fmt.Sprintf("roleName eq '%s'", roleName)) + return c.provider.ListRoleDefinitions(ctx, fmt.Sprintf("subscriptions/%s", c.settings.SubscriptionID), fmt.Sprintf("roleName eq '%s'", roleName)) } // findGroups is used to find a group by name. It returns all groups matching -// the passsed name. -func (c *client) findGroups(ctx context.Context, groupName string) ([]graphrbac.ADGroup, error) { +// the provided name. +func (c *client) findGroups(ctx context.Context, groupName string) ([]api.Group, error) { return c.provider.ListGroups(ctx, fmt.Sprintf("displayName eq '%s'", groupName)) } @@ -402,19 +315,24 @@ func retry(ctx context.Context, f func() (interface{}, bool, error)) (interface{ } rng := rand.New(rand.NewSource(time.Now().UnixNano())) + var lastErr error for { - if result, done, err := f(); done { - return result, err - } - - delay := time.Duration(2000+rng.Intn(6000)) * time.Millisecond - delayTimer.Reset(delay) - select { case <-delayTimer.C: - // Retry loop + result, done, err := f() + if done { + return result, err + } + lastErr = err + + delay := time.Duration(2+rng.Intn(6)) * time.Second + delayTimer.Reset(delay) case <-ctx.Done(): - return nil, fmt.Errorf("retry failed: %w", ctx.Err()) + err := lastErr + if err == nil { + err = ctx.Err() + } + return nil, fmt.Errorf("retry failed: %w", err) } } } diff --git a/client_test.go b/client_test.go index 0382fc0..19d5b29 100644 --- a/client_test.go +++ b/client_test.go @@ -22,7 +22,6 @@ func TestRetry(t *testing.T) { t.Run("Three retries", func(t *testing.T) { t.Parallel() - start := time.Now() count := 0 _, err := retry(context.Background(), func() (interface{}, bool, error) { @@ -34,11 +33,6 @@ func TestRetry(t *testing.T) { }) equal(t, count, 3) - // each sleep can last from 2 to 8 seconds - elapsed := time.Now().Sub(start).Seconds() - if elapsed < 4 || elapsed > 16 { - t.Fatalf("expected time of 4-16 seconds, got: %f", elapsed) - } assertErrorIsNil(t, err) }) @@ -75,7 +69,7 @@ func TestRetry(t *testing.T) { if called == 0 { t.Fatalf("retryable function was never called") } - assertDuration(t, elapsed, timeout, 100*time.Millisecond) + assertDuration(t, elapsed, timeout, 250*time.Millisecond) }) t.Run("Cancellation", func(t *testing.T) { @@ -83,7 +77,7 @@ func TestRetry(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) go func() { - time.Sleep(7 * time.Second) + time.Sleep(1 * time.Second) cancel() }() @@ -91,10 +85,8 @@ func TestRetry(t *testing.T) { _, err := retry(ctx, func() (interface{}, bool, error) { return nil, false, nil }) - elapsed := time.Now().Sub(start).Seconds() - if elapsed < 6 || elapsed > 8 { - t.Fatalf("expected time of ~7 seconds, got: %f", elapsed) - } + elapsed := time.Now().Sub(start) + assertDuration(t, elapsed, 1*time.Second, 250*time.Millisecond) if err == nil { t.Fatalf("expected err: got nil") diff --git a/passwords.go b/passwords.go deleted file mode 100644 index 25b58af..0000000 --- a/passwords.go +++ /dev/null @@ -1,31 +0,0 @@ -package azuresecrets - -import ( - "context" - "fmt" - - "github.com/hashicorp/go-secure-stdlib/base62" -) - -const ( - passwordLength = 36 -) - -type passwordGenerator interface { - GeneratePasswordFromPolicy(ctx context.Context, policyName string) (password string, err error) -} - -type passwords struct { - policyGenerator passwordGenerator - policyName string -} - -func (p passwords) generate(ctx context.Context) (password string, err error) { - if p.policyName == "" { - return base62.Random(passwordLength) - } - if p.policyGenerator == nil { - return "", fmt.Errorf("policy set, but no policy generator specified") - } - return p.policyGenerator.GeneratePasswordFromPolicy(ctx, p.policyName) -} diff --git a/path_config.go b/path_config.go index 431a068..e405aa0 100644 --- a/path_config.go +++ b/path_config.go @@ -24,6 +24,7 @@ type azureConfig struct { ClientSecret string `json:"client_secret"` Environment string `json:"environment"` PasswordPolicy string `json:"password_policy"` + UseMsGraphAPI bool `json:"use_microsoft_graph_api"` } func pathConfig(b *azureSecretBackend) *framework.Path { @@ -59,6 +60,10 @@ func pathConfig(b *azureSecretBackend) *framework.Path { Type: framework.TypeString, Description: "Name of the password policy to use to generate passwords for dynamic credentials.", }, + "use_microsoft_graph_api": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: "Enable usage of the Microsoft Graph API over the deprecated Azure AD Graph API.", + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ logical.ReadOperation: b.pathConfigRead, @@ -112,6 +117,10 @@ func (b *azureSecretBackend) pathConfigWrite(ctx context.Context, req *logical.R config.ClientSecret = clientSecret.(string) } + if useMsGraphApi, ok := data.GetOk("use_microsoft_graph_api"); ok { + config.UseMsGraphAPI = useMsGraphApi.(bool) + } + config.PasswordPolicy = data.Get("password_policy").(string) if merr.ErrorOrNil() != nil { @@ -136,10 +145,11 @@ func (b *azureSecretBackend) pathConfigRead(ctx context.Context, req *logical.Re resp := &logical.Response{ Data: map[string]interface{}{ - "subscription_id": config.SubscriptionID, - "tenant_id": config.TenantID, - "environment": config.Environment, - "client_id": config.ClientID, + "subscription_id": config.SubscriptionID, + "tenant_id": config.TenantID, + "environment": config.Environment, + "client_id": config.ClientID, + "use_microsoft_graph_api": config.UseMsGraphAPI, }, } return resp, nil diff --git a/path_config_test.go b/path_config_test.go index bed8e53..a43c471 100644 --- a/path_config_test.go +++ b/path_config_test.go @@ -12,11 +12,12 @@ func TestConfig(t *testing.T) { // Test valid config config := map[string]interface{}{ - "subscription_id": "a228ceec-bf1a-4411-9f95-39678d8cdb34", - "tenant_id": "7ac36e27-80fc-4209-a453-e8ad83dc18c2", - "client_id": "testClientId", - "client_secret": "testClientSecret", - "environment": "AZURECHINACLOUD", + "subscription_id": "a228ceec-bf1a-4411-9f95-39678d8cdb34", + "tenant_id": "7ac36e27-80fc-4209-a453-e8ad83dc18c2", + "client_id": "testClientId", + "client_secret": "testClientSecret", + "environment": "AZURECHINACLOUD", + "use_microsoft_graph_api": false, } testConfigCreate(t, b, s, config) @@ -54,11 +55,12 @@ func TestConfigDelete(t *testing.T) { // Test valid config config := map[string]interface{}{ - "subscription_id": "a228ceec-bf1a-4411-9f95-39678d8cdb34", - "tenant_id": "7ac36e27-80fc-4209-a453-e8ad83dc18c2", - "client_id": "testClientId", - "client_secret": "testClientSecret", - "environment": "AZURECHINACLOUD", + "subscription_id": "a228ceec-bf1a-4411-9f95-39678d8cdb34", + "tenant_id": "7ac36e27-80fc-4209-a453-e8ad83dc18c2", + "client_id": "testClientId", + "client_secret": "testClientSecret", + "environment": "AZURECHINACLOUD", + "use_microsoft_graph_api": false, } testConfigCreate(t, b, s, config) @@ -79,10 +81,11 @@ func TestConfigDelete(t *testing.T) { } config = map[string]interface{}{ - "subscription_id": "", - "tenant_id": "", - "client_id": "", - "environment": "", + "subscription_id": "", + "tenant_id": "", + "client_id": "", + "environment": "", + "use_microsoft_graph_api": false, } testConfigRead(t, b, s, config) } diff --git a/path_roles.go b/path_roles.go index be29298..244f545 100644 --- a/path_roles.go +++ b/path_roles.go @@ -7,10 +7,10 @@ import ( "strings" "time" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/go-autorest/autorest/to" "github.com/hashicorp/errwrap" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" @@ -203,7 +203,7 @@ func (b *azureSecretBackend) pathRoleUpdate(ctx context.Context, req *logical.Re for _, r := range role.AzureRoles { var roleDef authorization.RoleDefinition if r.RoleID != "" { - roleDef, err = client.provider.GetRoleByID(ctx, r.RoleID) + roleDef, err = client.provider.GetRoleDefinitionByID(ctx, r.RoleID) if err != nil { if strings.Contains(err.Error(), "RoleDefinitionDoesNotExist") { return logical.ErrorResponse("no role found for role_id: '%s'", r.RoleID), nil @@ -238,7 +238,7 @@ func (b *azureSecretBackend) pathRoleUpdate(ctx context.Context, req *logical.Re // update and verify Azure groups, including looking up each group by ID or name. groupSet := make(map[string]bool) for _, r := range role.AzureGroups { - var groupDef graphrbac.ADGroup + var groupDef api.Group if r.ObjectID != "" { groupDef, err = client.provider.GetGroup(ctx, r.ObjectID) if err != nil { @@ -260,9 +260,8 @@ func (b *azureSecretBackend) pathRoleUpdate(ctx context.Context, req *logical.Re groupDef = defs[0] } - groupDefID := to.String(groupDef.ObjectID) - groupDefName := to.String(groupDef.DisplayName) - r.GroupName, r.ObjectID = groupDefName, groupDefID + r.ObjectID = groupDef.ID + r.GroupName = groupDef.DisplayName if groupSet[r.ObjectID] { return logical.ErrorResponse("duplicate object_id '%s'", r.ObjectID), nil diff --git a/path_service_principal.go b/path_service_principal.go index 2505c8a..ab3c3eb 100644 --- a/path_service_principal.go +++ b/path_service_principal.go @@ -105,7 +105,7 @@ func (b *azureSecretBackend) createSPSecret(ctx context.Context, s logical.Stora return nil, err } appID := to.String(app.AppID) - appObjID := to.String(app.ObjectID) + appObjID := to.String(app.ID) // Write a WAL entry in case the SP create process doesn't complete walID, err := framework.PutWAL(ctx, s, walAppKey, &walApp{ @@ -118,19 +118,19 @@ func (b *azureSecretBackend) createSPSecret(ctx context.Context, s logical.Stora } // Create a service principal associated with the new App - sp, password, err := c.createSP(ctx, app, spExpiration) + spID, password, err := c.createSP(ctx, app, spExpiration) if err != nil { return nil, err } // Assign Azure roles to the new SP - raIDs, err := c.assignRoles(ctx, sp, role.AzureRoles) + raIDs, err := c.assignRoles(ctx, spID, role.AzureRoles) if err != nil { return nil, err } // Assign Azure group memberships to the new SP - if err := c.addGroupMemberships(ctx, sp, role.AzureGroups); err != nil { + if err := c.addGroupMemberships(ctx, spID, role.AzureGroups); err != nil { return nil, err } @@ -145,7 +145,7 @@ func (b *azureSecretBackend) createSPSecret(ctx context.Context, s logical.Stora } internalData := map[string]interface{}{ "app_object_id": appObjID, - "sp_object_id": sp.ObjectID, + "sp_object_id": spID, "role_assignment_ids": raIDs, "group_membership_ids": groupObjectIDs(role.AzureGroups), "role": roleName, diff --git a/path_service_principal_test.go b/path_service_principal_test.go index 9abad42..8e55329 100644 --- a/path_service_principal_test.go +++ b/path_service_principal_test.go @@ -3,14 +3,17 @@ package azuresecrets import ( "context" "fmt" + "net/http" + "net/url" "os" "strings" "testing" "time" - "github.com/Azure/go-autorest/autorest/to" + "github.com/Azure/go-autorest/autorest" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/hashicorp/vault/sdk/logical" @@ -59,7 +62,7 @@ func TestSP_WAL_Cleanup(t *testing.T) { // overwrite the normal test backend provider with the errMockProvider errMockProvider := newErrMockProvider() - b.getProvider = func(s *clientSettings) (AzureProvider, error) { + b.getProvider = func(s *clientSettings, useMsGraphApi bool, p api.Passwords) (api.AzureProvider, error) { return errMockProvider, nil } @@ -90,7 +93,7 @@ func TestSP_WAL_Cleanup(t *testing.T) { }) } -func assertEmptyWAL(t *testing.T, b *azureSecretBackend, emp AzureProvider, s logical.Storage) { +func assertEmptyWAL(t *testing.T, b *azureSecretBackend, emp api.AzureProvider, s logical.Storage) { t.Helper() wal, err := framework.ListWAL(context.Background(), s) @@ -271,8 +274,8 @@ func TestStaticSPRead(t *testing.T) { assertErrorIsNil(t, err) keyID := resp.Secret.InternalData["key_id"].(string) - if !strings.HasPrefix(keyID, "ffffff") { - t.Fatalf("expected prefix 'ffffff': %s", keyID) + if len(keyID) == 0 { + t.Fatalf("expected keyId to not be empty") } client, err := b.getClient(context.Background(), s) @@ -413,8 +416,8 @@ func TestStaticSPRevoke(t *testing.T) { assertErrorIsNil(t, err) keyID := resp.Secret.InternalData["key_id"].(string) - if !strings.HasPrefix(keyID, "ffffff") { - t.Fatalf("expected prefix 'ffffff': %s", keyID) + if len(keyID) == 0 { + t.Fatalf("expected keyId to not be empty") } client, err := b.getClient(context.Background(), s) @@ -482,7 +485,7 @@ func TestCredentialReadProviderError(t *testing.T) { // TestCredentialInteg is an integration test against the live Azure service. It requires // valid, sufficiently-privileged Azure credentials in env variables. -func TestCredentialInteg(t *testing.T) { +func TestCredentialInteg_aad(t *testing.T) { if os.Getenv("VAULT_ACC") != "1" { t.SkipNow() } @@ -491,7 +494,7 @@ func TestCredentialInteg(t *testing.T) { t.Skip("Azure Secrets: Azure environment variables not set. Skipping.") } - t.Run("SP", func(t *testing.T) { + t.Run("service principals", func(t *testing.T) { t.Parallel() b := backend() @@ -531,24 +534,15 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) // Request credentials resp, err = b.HandleRequest(context.Background(), &logical.Request{ Operation: logical.ReadOperation, Path: fmt.Sprintf("creds/%s", rolename), - Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) appID := resp.Data["client_id"].(string) @@ -556,30 +550,9 @@ func TestCredentialInteg(t *testing.T) { client, err := b.getClient(context.Background(), s) assertErrorIsNil(t, err) provider := client.provider.(*provider) + spObjID := findServicePrincipalID(t, provider.spClient, appID) - // recover the SP Object ID, which is not used by the application but - // is helpful for verification testing - spList, err := provider.spClient.List(context.Background(), "") - assertErrorIsNil(t, err) - - var spObjID string - for spList.NotDone() { - for _, v := range spList.Values() { - if to.String(v.AppID) == appID { - spObjID = to.String(v.ObjectID) - goto FOUND - } - } - spList.Next() - } - t.Fatal("Couldn't find SP Object ID") - - FOUND: - // verify the new SP can be accessed - _, err = provider.spClient.Get(context.Background(), spObjID) - if err != nil { - t.Fatalf("Expected nil error on GET of new SP, got: %#v", err) - } + assertServicePrincipalExists(t, provider.spClient, spObjID) // Verify that the role assignments were created. Get the assignment // info from Azure and verify it matches the Reader role. @@ -589,10 +562,10 @@ func TestCredentialInteg(t *testing.T) { ra, err := provider.raClient.GetByID(context.Background(), raIDs[0]) assertErrorIsNil(t, err) - roleDefs, err := client.provider.ListRoles(context.Background(), fmt.Sprintf("subscriptions/%s", subscriptionID), "") + roleDefs, err := provider.ListRoleDefinitions(context.Background(), fmt.Sprintf("subscriptions/%s", subscriptionID), "") assertErrorIsNil(t, err) - defID := *ra.RoleAssignmentPropertiesWithScope.RoleDefinitionID + defID := *ra.Properties.RoleDefinitionID found := false for _, def := range roleDefs { if *def.ID == defID && *def.RoleName == "Reader" { @@ -619,14 +592,10 @@ func TestCredentialInteg(t *testing.T) { // Verify that SP get is an error after delete. Expected there // to be a delay and that this step would take some time/retries, // but that seems not to be the case. - _, err = provider.spClient.Get(context.Background(), spObjID) - - if err == nil { - t.Fatal("Expected error reading deleted SP") - } + assertServicePrincipalDoesNotExist(t, provider.spClient, spObjID) }) - t.Run("Static SP", func(t *testing.T) { + t.Run("static service principals", func(t *testing.T) { t.Parallel() b := backend() @@ -646,7 +615,7 @@ func TestCredentialInteg(t *testing.T) { // Add a Vault role that will provide creds with Azure "Reader" permissions subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") - rolename := "test_role" + rolename := "static_test_role" role := map[string]interface{}{ "azure_roles": fmt.Sprintf(`[{ "role_name": "Reader", @@ -659,11 +628,7 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) // Request credentials resp, err = b.HandleRequest(context.Background(), &logical.Request{ @@ -672,11 +637,7 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) origResp := resp @@ -696,11 +657,7 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) // Request credentials resp, err = b.HandleRequest(context.Background(), &logical.Request{ @@ -709,11 +666,7 @@ func TestCredentialInteg(t *testing.T) { Data: role, Storage: s, }) - assertErrorIsNil(t, err) - - if resp != nil && resp.IsError() { - t.Fatal(resp.Error()) - } + assertRespNoError(t, resp, err) // Test the added password by creating a new Azure provider with these // creds and attempting an operation with it. @@ -733,7 +686,7 @@ func TestCredentialInteg(t *testing.T) { for i := 0; i < 8; i++ { // New credentials are only tested during an actual operation, not provider creation. // This step should never fail. - p, err := newAzureProvider(settings) + p, err := newAzureProvider(settings, true, api.Passwords{}) if err != nil { t.Fatal(err) } @@ -766,13 +719,299 @@ func TestCredentialInteg(t *testing.T) { }) } +// Similar to TestCredentialInteg, this is an integration test against the live Azure service. It requires +// valid, sufficiently-privileged Azure credentials in env variables. +// The credentials provided to this must include permissions to use MS Graph and not AAD +// Unfortunately this means that this test cannot be run within the same test execution as TestCredentialInteg +func TestCredentialInteg_msgraph(t *testing.T) { + if os.Getenv("VAULT_ACC") != "1" { + t.SkipNow() + } + + if os.Getenv("AZURE_CLIENT_SECRET") == "" { + t.Skip("Azure Secrets: Azure environment variables not set. Skipping.") + } + + t.Run("service principals", func(t *testing.T) { + t.Parallel() + + skipIfMissingEnvVars(t, + "AZURE_SUBSCRIPTION_ID", + "AZURE_CLIENT_ID", + "AZURE_CLIENT_SECRET", + "AZURE_TENANT_ID", + ) + + b := backend() + s := new(logical.InmemStorage) + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + clientID := os.Getenv("AZURE_CLIENT_ID") + clientSecret := os.Getenv("AZURE_CLIENT_SECRET") + tenantID := os.Getenv("AZURE_TENANT_ID") + + config := &logical.BackendConfig{ + Logger: logging.NewVaultLogger(log.Trace), + System: &logical.StaticSystemView{ + DefaultLeaseTTLVal: defaultLeaseTTLHr, + MaxLeaseTTLVal: maxLeaseTTLHr, + }, + StorageView: s, + } + err := b.Setup(context.Background(), config) + assertErrorIsNil(t, err) + + configData := map[string]interface{}{ + "subscription_id": subscriptionID, + "client_id": clientID, + "client_secret": clientSecret, + "tenant_id": tenantID, + "use_microsoft_graph_api": true, + } + + configResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.CreateOperation, + Path: "config", + Data: configData, + Storage: s, + }) + assertRespNoError(t, configResp, err) + + roleName := "test_role_msgraph" + + roleData := map[string]interface{}{ + "azure_roles": fmt.Sprintf(`[ + { + "role_name": "Reader", + "scope": "/subscriptions/%s/resourceGroups/vault-azure-secrets-test1" + }, + { + "role_name": "Reader", + "scope": "/subscriptions/%s/resourceGroups/vault-azure-secrets-test2" + }]`, subscriptionID, subscriptionID), + } + + roleResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.CreateOperation, + Path: fmt.Sprintf("roles/%s", roleName), + Data: roleData, + Storage: s, + }) + assertRespNoError(t, roleResp, err) + + credsResp, err := b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.ReadOperation, + Path: fmt.Sprintf("creds/%s", roleName), + Storage: s, + }) + assertRespNoError(t, credsResp, err) + + appID := credsResp.Data["client_id"].(string) + + // Use the underlying provider to access clients directly for testing + client, err := b.getClient(context.Background(), s) + assertErrorIsNil(t, err) + provider := client.provider.(*provider) + spObjID := findServicePrincipalID(t, provider.spClient, appID) + + assertServicePrincipalExists(t, provider.spClient, spObjID) + + // Verify that the role assignments were created. Get the assignment + // info from Azure and verify it matches the Reader role. + raIDs := credsResp.Secret.InternalData["role_assignment_ids"].([]string) + equal(t, 2, len(raIDs)) + + ra, err := provider.raClient.GetByID(context.Background(), raIDs[0]) + assertErrorIsNil(t, err) + + roleDefs, err := provider.ListRoleDefinitions(context.Background(), fmt.Sprintf("subscriptions/%s", subscriptionID), "") + assertErrorIsNil(t, err) + + defID := *ra.Properties.RoleDefinitionID + found := false + for _, def := range roleDefs { + if *def.ID == defID && *def.RoleName == "Reader" { + found = true + break + } + } + + if !found { + t.Fatal("'Reader' role assignment not found") + } + + // Serialize and deserialize the secret to remove typing, as will really happen. + fakeSaveLoad(credsResp.Secret) + + // Revoke the Service Principal by sending back the secret we just received + req := &logical.Request{ + Secret: credsResp.Secret, + Storage: s, + } + + b.spRevoke(context.Background(), req, nil) + + // Verify that SP get is an error after delete. Expected there + // to be a delay and that this step would take some time/retries, + // but that seems not to be the case. + assertServicePrincipalDoesNotExist(t, provider.spClient, spObjID) + }) +} + +func skipIfMissingEnvVars(t *testing.T, envVars ...string) { + t.Helper() + for _, envVar := range envVars { + if os.Getenv(envVar) == "" { + t.Skipf("Missing env variable: [%s] - skipping test", envVar) + } + } +} + func assertClientSecret(tb testing.TB, data map[string]interface{}) { assertKeyExists(tb, data, "client_secret") actualPassword, ok := data["client_secret"].(string) if !ok { tb.Fatalf("client_secret is not a string") } - if len(actualPassword) != passwordLength { - tb.Fatalf("client_secret is not the correct length: expected %d but was %d", passwordLength, len(actualPassword)) + if len(actualPassword) != api.PasswordLength { + tb.Fatalf("client_secret is not the correct length: expected %d but was %d", api.PasswordLength, len(actualPassword)) + } +} + +type servicePrincipalResp struct { + AppID string `json:"appId"` + ID string `json:"id"` +} + +func findServicePrincipalID(t *testing.T, client api.ServicePrincipalClient, appID string) (spID string) { + t.Helper() + + switch spClient := client.(type) { + case api.AADServicePrincipalsClient: + spList, err := spClient.Client.List(context.Background(), "") + assertErrorIsNil(t, err) + for spList.NotDone() { + for _, sp := range spList.Values() { + if *sp.AppID == appID { + return *sp.ObjectID + } + } + err = spList.NextWithContext(context.Background()) + assertErrorIsNil(t, err) + } + case *api.AppClient: + pathVals := &url.Values{} + pathVals.Set("$filter", fmt.Sprintf("appId eq '%s'", appID)) + + prep := spClient.GetPreparer( + autorest.AsGet(), + autorest.WithPath(fmt.Sprintf("/v1.0/servicePrincipals?%s", pathVals.Encode())), + ) + + type listSPsResponse struct { + ServicePrincipals []servicePrincipalResp `json:"value"` + } + + respBody := listSPsResponse{} + + err := spClient.SendRequest(context.Background(), prep, + autorest.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&respBody), + ) + assertErrorIsNil(t, err) + + if len(respBody.ServicePrincipals) == 0 { + t.Fatalf("Failed to find service principals from application ID") + } + + for _, sp := range respBody.ServicePrincipals { + if sp.AppID == appID { + return sp.ID + } + } + default: + t.Fatalf("Unrecognized service principal client type: %T", spClient) + } + + t.Fatalf("Failed to find service principal with application ID: %s", appID) + return "" // Because compilers +} + +func assertServicePrincipalExists(t *testing.T, client api.ServicePrincipalClient, spID string) { + t.Helper() + + switch spClient := client.(type) { + case api.AADServicePrincipalsClient: + _, err := spClient.Client.Get(context.Background(), spID) + if err != nil { + t.Fatalf("Expected nil error on GET of new SP, got: %#v", err) + } + case *api.AppClient: + pathParams := map[string]interface{}{ + "id": spID, + } + + prep := spClient.GetPreparer( + autorest.AsGet(), + autorest.WithPathParameters("/v1.0/servicePrincipals/{id}", pathParams), + ) + + respBody := servicePrincipalResp{} + + err := spClient.SendRequest(context.Background(), prep, + autorest.WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByUnmarshallingJSON(&respBody), + ) + assertErrorIsNil(t, err) + + if respBody.ID == "" { + t.Fatalf("Failed to find service principal") + } + default: + t.Fatalf("Unrecognized service principal client type: %T", spClient) + } +} + +func assertServicePrincipalDoesNotExist(t *testing.T, client api.ServicePrincipalClient, spID string) { + t.Helper() + + switch spClient := client.(type) { + case api.AADServicePrincipalsClient: + _, err := spClient.Client.Get(context.Background(), spID) + if err == nil { + t.Fatalf("Expected error on GET of new SP") + } + case *api.AppClient: + pathParams := map[string]interface{}{ + "id": spID, + } + + prep := spClient.GetPreparer( + autorest.AsGet(), + autorest.WithPathParameters("/v1.0/servicePrincipals/{id}", pathParams), + ) + + respBody := servicePrincipalResp{} + + err := spClient.SendRequest(context.Background(), prep, + autorest.WithErrorUnlessStatusCode(http.StatusNotFound), + autorest.ByUnmarshallingJSON(&respBody), + ) + assertErrorIsNil(t, err) + + if respBody.ID != "" { + t.Fatalf("Found service principal when it shouldn't exist") + } + default: + t.Fatalf("Unrecognized service principal client type: %T", spClient) + } +} + +func assertRespNoError(t *testing.T, resp *logical.Response, err error) { + t.Helper() + + assertErrorIsNil(t, err) + + if resp != nil && resp.IsError() { + t.Fatal(resp.Error()) } } diff --git a/provider.go b/provider.go index 9949881..b6f8ab0 100644 --- a/provider.go +++ b/provider.go @@ -3,61 +3,19 @@ package azuresecrets import ( "context" "strings" + "time" + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure/auth" + "github.com/Azure/go-autorest/autorest/date" + "github.com/hashicorp/vault-plugin-secrets-azure/api" "github.com/hashicorp/vault/sdk/helper/useragent" "github.com/hashicorp/vault/sdk/version" ) -// AzureProvider is an interface to access underlying Azure client objects and supporting services. -// Where practical the original function signature is preserved. client provides higher -// level operations atop AzureProvider. -type AzureProvider interface { - ApplicationsClient - ServicePrincipalsClient - ADGroupsClient - RoleAssignmentsClient - RoleDefinitionsClient -} - -type ApplicationsClient interface { - CreateApplication(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) - DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) - GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) - UpdateApplicationPasswordCredentials( - ctx context.Context, - applicationObjectID string, - parameters graphrbac.PasswordCredentialsUpdateParameters) (result autorest.Response, err error) - ListApplicationPasswordCredentials(ctx context.Context, applicationObjectID string) (result graphrbac.PasswordCredentialListResult, err error) -} - -type ServicePrincipalsClient interface { - CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) -} - -type ADGroupsClient interface { - AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) - RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) - GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) - ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) -} - -type RoleAssignmentsClient interface { - CreateRoleAssignment( - ctx context.Context, - scope string, - roleAssignmentName string, - parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) - DeleteRoleAssignmentByID(ctx context.Context, roleID string) (authorization.RoleAssignment, error) -} - -type RoleDefinitionsClient interface { - ListRoles(ctx context.Context, scope string, filter string) ([]authorization.RoleDefinition, error) - GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) -} +var _ api.AzureProvider = (*provider)(nil) // provider is a concrete implementation of AzureProvider. In most cases it is a simple passthrough // to the appropriate client object. But if the response requires processing that is more practical @@ -65,78 +23,87 @@ type RoleDefinitionsClient interface { type provider struct { settings *clientSettings - appClient *graphrbac.ApplicationsClient - spClient *graphrbac.ServicePrincipalsClient - groupsClient *graphrbac.GroupsClient + appClient api.ApplicationsClient + spClient api.ServicePrincipalClient + groupsClient api.GroupsClient raClient *authorization.RoleAssignmentsClient rdClient *authorization.RoleDefinitionsClient } // newAzureProvider creates an azureProvider, backed by Azure client objects for underlying services. -func newAzureProvider(settings *clientSettings) (AzureProvider, error) { +func newAzureProvider(settings *clientSettings, useMsGraphApi bool, passwords api.Passwords) (api.AzureProvider, error) { // build clients that use the GraphRBAC endpoint - authorizer, err := getAuthorizer(settings, settings.Environment.GraphEndpoint) + graphAuthorizer, err := getAuthorizer(settings, settings.Environment.GraphEndpoint) if err != nil { return nil, err } - var userAgent string - if settings.PluginEnv != nil { - userAgent = useragent.PluginString(settings.PluginEnv, "azure-secrets") + userAgent := getUserAgent(settings) + + var appClient api.ApplicationsClient + var groupsClient api.GroupsClient + var spClient api.ServicePrincipalClient + if useMsGraphApi { + graphApiAuthorizer, err := getAuthorizer(settings, api.DefaultGraphMicrosoftComURI) + if err != nil { + return nil, err + } + + msGraphAppClient, err := api.NewMSGraphApplicationClient(settings.SubscriptionID, userAgent, graphApiAuthorizer) + if err != nil { + return nil, err + } + + appClient = msGraphAppClient + groupsClient = msGraphAppClient + spClient = msGraphAppClient } else { - userAgent = useragent.String() - } + aadGraphClient := graphrbac.NewApplicationsClient(settings.TenantID) + aadGraphClient.Authorizer = graphAuthorizer + aadGraphClient.AddToUserAgent(userAgent) - // Sets a unique ID in the user-agent - // Normal user-agent looks like this: - // - // Vault/1.6.0 (+https://www.vaultproject.io/; azure-secrets; go1.15.7) - // - // Here we append a unique code if it's an enterprise version, where - // VersionMetadata will contain a non-empty string like "ent" or "prem". - // Otherwise use the default identifier for OSS Vault. The end result looks - // like so: - // - // Vault/1.6.0 (+https://www.vaultproject.io/; azure-secrets; go1.15.7; b2c13ec1-60e8-4733-9a76-88dbb2ce2471) - vaultIDString := "; 15cd22ce-24af-43a4-aa83-4c1a36a4b177)" - ver := version.GetVersion() - if ver.VersionMetadata != "" { - vaultIDString = "; b2c13ec1-60e8-4733-9a76-88dbb2ce2471)" - } - userAgent = strings.Replace(userAgent, ")", vaultIDString, 1) + appClient = &api.ActiveDirectoryApplicationClient{Client: &aadGraphClient, Passwords: passwords} - appClient := graphrbac.NewApplicationsClient(settings.TenantID) - appClient.Authorizer = authorizer - appClient.AddToUserAgent(userAgent) + aadGroupsClient := graphrbac.NewGroupsClient(settings.TenantID) + aadGroupsClient.Authorizer = graphAuthorizer + aadGroupsClient.AddToUserAgent(userAgent) - spClient := graphrbac.NewServicePrincipalsClient(settings.TenantID) - spClient.Authorizer = authorizer - spClient.AddToUserAgent(userAgent) + groupsClient = api.ActiveDirectoryApplicationGroupsClient{ + BaseURI: aadGroupsClient.BaseURI, + TenantID: aadGroupsClient.TenantID, + Client: aadGroupsClient, + } - groupsClient := graphrbac.NewGroupsClient(settings.TenantID) - groupsClient.Authorizer = authorizer - groupsClient.AddToUserAgent(userAgent) + servicePrincipalClient := graphrbac.NewServicePrincipalsClient(settings.TenantID) + servicePrincipalClient.Authorizer = graphAuthorizer + servicePrincipalClient.AddToUserAgent(userAgent) + + spClient = api.AADServicePrincipalsClient{ + Client: servicePrincipalClient, + Passwords: passwords, + } + } // build clients that use the Resource Manager endpoint - authorizer, err = getAuthorizer(settings, settings.Environment.ResourceManagerEndpoint) + resourceManagerAuthorizer, err := getAuthorizer(settings, settings.Environment.ResourceManagerEndpoint) if err != nil { return nil, err } raClient := authorization.NewRoleAssignmentsClientWithBaseURI(settings.Environment.ResourceManagerEndpoint, settings.SubscriptionID) - raClient.Authorizer = authorizer + raClient.Authorizer = resourceManagerAuthorizer raClient.AddToUserAgent(userAgent) rdClient := authorization.NewRoleDefinitionsClientWithBaseURI(settings.Environment.ResourceManagerEndpoint, settings.SubscriptionID) - rdClient.Authorizer = authorizer + rdClient.Authorizer = resourceManagerAuthorizer rdClient.AddToUserAgent(userAgent) p := &provider{ settings: settings, - appClient: &appClient, - spClient: &spClient, - groupsClient: &groupsClient, + appClient: appClient, + spClient: spClient, + groupsClient: groupsClient, raClient: &raClient, rdClient: &rdClient, } @@ -144,61 +111,81 @@ func newAzureProvider(settings *clientSettings) (AzureProvider, error) { return p, nil } +func getUserAgent(settings *clientSettings) string { + var userAgent string + if settings.PluginEnv != nil { + userAgent = useragent.PluginString(settings.PluginEnv, "azure-secrets") + } else { + userAgent = useragent.String() + } + + // Sets a unique ID in the user-agent + // Normal user-agent looks like this: + // + // Vault/1.6.0 (+https://www.vaultproject.io/; azure-secrets; go1.15.7) + // + // Here we append a unique code if it's an enterprise version, where + // VersionMetadata will contain a non-empty string like "ent" or "prem". + // Otherwise use the default identifier for OSS Vault. The end result looks + // like so: + // + // Vault/1.6.0 (+https://www.vaultproject.io/; azure-secrets; go1.15.7; b2c13ec1-60e8-4733-9a76-88dbb2ce2471) + vaultIDString := "; 15cd22ce-24af-43a4-aa83-4c1a36a4b177)" + ver := version.GetVersion() + if ver.VersionMetadata != "" { + vaultIDString = "; b2c13ec1-60e8-4733-9a76-88dbb2ce2471)" + } + userAgent = strings.Replace(userAgent, ")", vaultIDString, 1) + + return userAgent +} + // getAuthorizer attempts to create an authorizer, preferring ClientID/Secret if present, // and falling back to MSI if not. -func getAuthorizer(settings *clientSettings, resource string) (authorizer autorest.Authorizer, err error) { - +func getAuthorizer(settings *clientSettings, resource string) (autorest.Authorizer, error) { if settings.ClientID != "" && settings.ClientSecret != "" && settings.TenantID != "" { config := auth.NewClientCredentialsConfig(settings.ClientID, settings.ClientSecret, settings.TenantID) config.AADEndpoint = settings.Environment.ActiveDirectoryEndpoint config.Resource = resource - authorizer, err = config.Authorizer() - if err != nil { - return nil, err - } - } else { - config := auth.NewMSIConfig() - config.Resource = resource - authorizer, err = config.Authorizer() - if err != nil { - return nil, err - } + return config.Authorizer() } - return authorizer, nil + config := auth.NewMSIConfig() + config.Resource = resource + return config.Authorizer() } // CreateApplication create a new Azure application object. -func (p *provider) CreateApplication(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) { - return p.appClient.Create(ctx, parameters) +func (p *provider) CreateApplication(ctx context.Context, displayName string) (result api.ApplicationResult, err error) { + return p.appClient.CreateApplication(ctx, displayName) } -func (p *provider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) { - return p.appClient.Get(ctx, applicationObjectID) +func (p *provider) GetApplication(ctx context.Context, applicationObjectID string) (result api.ApplicationResult, err error) { + return p.appClient.GetApplication(ctx, applicationObjectID) } // DeleteApplication deletes an Azure application object. // This will in turn remove the service principal (but not the role assignments). func (p *provider) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) { - return p.appClient.Delete(ctx, applicationObjectID) + return p.appClient.DeleteApplication(ctx, applicationObjectID) } -func (p *provider) UpdateApplicationPasswordCredentials(ctx context.Context, applicationObjectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (result autorest.Response, err error) { - return p.appClient.UpdatePasswordCredentials(ctx, applicationObjectID, parameters) +func (p *provider) AddApplicationPassword(ctx context.Context, applicationObjectID string, displayName string, endDateTime date.Time) (result api.PasswordCredentialResult, err error) { + return p.appClient.AddApplicationPassword(ctx, applicationObjectID, displayName, endDateTime) } -func (p *provider) ListApplicationPasswordCredentials(ctx context.Context, applicationObjectID string) (result graphrbac.PasswordCredentialListResult, err error) { - return p.appClient.ListPasswordCredentials(ctx, applicationObjectID) +func (p *provider) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyID string) (result autorest.Response, err error) { + return p.appClient.RemoveApplicationPassword(ctx, applicationObjectID, keyID) } // CreateServicePrincipal creates a new Azure service principal. // An Application must be created prior to calling this and pass in parameters. -func (p *provider) CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) { - return p.spClient.Create(ctx, parameters) +func (p *provider) CreateServicePrincipal(ctx context.Context, appID string, startDate time.Time, endDate time.Time) (id string, password string, err error) { + return p.spClient.CreateServicePrincipal(ctx, appID, startDate, endDate) } // ListRoles like all Azure roles with a scope (often subscription). -func (p *provider) ListRoles(ctx context.Context, scope string, filter string) (result []authorization.RoleDefinition, err error) { +func (p *provider) ListRoleDefinitions(ctx context.Context, scope string, filter string) (result []authorization.RoleDefinition, err error) { page, err := p.rdClient.List(ctx, scope, filter) if err != nil { @@ -209,7 +196,7 @@ func (p *provider) ListRoles(ctx context.Context, scope string, filter string) ( } // GetRoleByID fetches the full role definition given a roleID. -func (p *provider) GetRoleByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) { +func (p *provider) GetRoleDefinitionByID(ctx context.Context, roleID string) (result authorization.RoleDefinition, err error) { return p.rdClient.GetByID(ctx, roleID) } @@ -242,26 +229,21 @@ func (p *provider) ListRoleAssignments(ctx context.Context, filter string) ([]au } // AddGroupMember adds a member to a AAD Group. -func (p *provider) AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) { - return p.groupsClient.AddMember(ctx, groupObjectID, parameters) +func (p *provider) AddGroupMember(ctx context.Context, groupObjectID string, memberObjectID string) (err error) { + return p.groupsClient.AddGroupMember(ctx, groupObjectID, memberObjectID) } // RemoveGroupMember removes a member from a AAD Group. -func (p *provider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) { - return p.groupsClient.RemoveMember(ctx, groupObjectID, memberObjectID) +func (p *provider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (err error) { + return p.groupsClient.RemoveGroupMember(ctx, groupObjectID, memberObjectID) } // GetGroup gets group information from the directory. -func (p *provider) GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) { - return p.groupsClient.Get(ctx, objectID) +func (p *provider) GetGroup(ctx context.Context, objectID string) (result api.Group, err error) { + return p.groupsClient.GetGroup(ctx, objectID) } // ListGroups gets list of groups for the current tenant. -func (p *provider) ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) { - page, err := p.groupsClient.List(ctx, filter) - if err != nil { - return nil, err - } - - return page.Values(), nil +func (p *provider) ListGroups(ctx context.Context, filter string) (result []api.Group, err error) { + return p.groupsClient.ListGroups(ctx, filter) } diff --git a/provider_mock_test.go b/provider_mock_test.go new file mode 100644 index 0000000..471e8ab --- /dev/null +++ b/provider_mock_test.go @@ -0,0 +1,266 @@ +package azuresecrets + +import ( + "context" + "errors" + "fmt" + "regexp" + "strings" + "sync" + "time" + + "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" + "github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/date" + "github.com/Azure/go-autorest/autorest/to" + "github.com/hashicorp/vault-plugin-secrets-azure/api" +) + +// mockProvider is a Provider that provides stubs and simple, deterministic responses. +type mockProvider struct { + subscriptionID string + applications map[string]bool + passwords map[string]api.PasswordCredential + failNextCreateApplication bool + lock sync.Mutex +} + +func newMockProvider() api.AzureProvider { + return &mockProvider{ + subscriptionID: generateUUID(), + applications: make(map[string]bool), + passwords: make(map[string]api.PasswordCredential), + } +} + +// ListRoles returns a single fake role based on the inbound filter +func (m *mockProvider) ListRoleDefinitions(_ context.Context, _ string, filter string) (result []authorization.RoleDefinition, err error) { + reRoleName := regexp.MustCompile("roleName eq '(.*)'") + + match := reRoleName.FindAllStringSubmatch(filter, -1) + if len(match) > 0 { + name := match[0][1] + if name == "multiple" { + return []authorization.RoleDefinition{ + { + ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s-1", name)), + RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ + RoleName: to.StringPtr(name), + }, + }, + { + ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s-2", name)), + RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ + RoleName: to.StringPtr(name), + }, + }, + }, nil + } + return []authorization.RoleDefinition{ + { + ID: to.StringPtr(fmt.Sprintf("/subscriptions/FAKE_SUB_ID/providers/Microsoft.Authorization/roleDefinitions/FAKE_ROLE-%s", name)), + RoleDefinitionProperties: &authorization.RoleDefinitionProperties{ + RoleName: to.StringPtr(name), + }, + }, + }, nil + } + + return []authorization.RoleDefinition{}, nil +} + +// GetRoleByID will returns a fake role definition from the povided ID +// Assumes an ID format of: .*FAKE_ROLE-{rolename} +func (m *mockProvider) GetRoleDefinitionByID(_ context.Context, roleID string) (result authorization.RoleDefinition, err error) { + d := authorization.RoleDefinition{} + s := strings.Split(roleID, "FAKE_ROLE-") + if len(s) > 1 { + d.ID = to.StringPtr(roleID) + d.RoleDefinitionProperties = &authorization.RoleDefinitionProperties{ + RoleName: to.StringPtr(s[1]), + } + } + + return d, nil +} + +func (m *mockProvider) CreateServicePrincipal(_ context.Context, _ string, _ time.Time, _ time.Time) (string, string, error) { + id := generateUUID() + pass := generateUUID() + return id, pass, nil +} + +func (m *mockProvider) CreateApplication(_ context.Context, _ string) (api.ApplicationResult, error) { + if m.failNextCreateApplication { + m.failNextCreateApplication = false + return api.ApplicationResult{}, errors.New("Mock: fail to create application") + } + appObjID := generateUUID() + + m.lock.Lock() + defer m.lock.Unlock() + + m.applications[appObjID] = true + + return api.ApplicationResult{ + AppID: to.StringPtr(generateUUID()), + ID: &appObjID, + }, nil +} + +func (m *mockProvider) GetApplication(_ context.Context, _ string) (api.ApplicationResult, error) { + return api.ApplicationResult{ + AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"), + }, nil +} + +func (m *mockProvider) DeleteApplication(_ context.Context, applicationObjectID string) (autorest.Response, error) { + delete(m.applications, applicationObjectID) + return autorest.Response{}, nil +} + +func (m *mockProvider) AddApplicationPassword(_ context.Context, _ string, displayName string, endDateTime date.Time) (result api.PasswordCredentialResult, err error) { + keyID := generateUUID() + cred := api.PasswordCredential{ + DisplayName: to.StringPtr(displayName), + StartDate: &date.Time{Time: time.Now()}, + EndDate: &endDateTime, + KeyID: to.StringPtr(keyID), + SecretText: to.StringPtr(generateUUID()), + } + + m.lock.Lock() + defer m.lock.Unlock() + m.passwords[keyID] = cred + + return api.PasswordCredentialResult{ + PasswordCredential: cred, + }, nil +} + +func (m *mockProvider) RemoveApplicationPassword(_ context.Context, _ string, keyID string) (result autorest.Response, err error) { + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.passwords, keyID) + + return autorest.Response{}, nil +} + +func (m *mockProvider) appExists(s string) bool { + return m.applications[s] +} + +func (m *mockProvider) passwordExists(s string) bool { + _, ok := m.passwords[s] + return ok +} + +func (m *mockProvider) VMGet(_ context.Context, _ string, _ string, _ compute.InstanceViewTypes) (result compute.VirtualMachine, err error) { + return compute.VirtualMachine{}, nil +} + +func (m *mockProvider) VMUpdate(_ context.Context, _ string, _ string, _ compute.VirtualMachineUpdate) (result compute.VirtualMachinesUpdateFuture, err error) { + return compute.VirtualMachinesUpdateFuture{}, nil +} + +func (m *mockProvider) CreateRoleAssignment(_ context.Context, _ string, _ string, _ authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { + return authorization.RoleAssignment{ + ID: to.StringPtr(generateUUID()), + }, nil +} + +func (m *mockProvider) DeleteRoleAssignmentByID(_ context.Context, _ string) (result authorization.RoleAssignment, err error) { + return authorization.RoleAssignment{}, nil +} + +// AddGroupMember adds a member to a AAD Group. +func (m *mockProvider) AddGroupMember(_ context.Context, _ string, _ string) (err error) { + return nil +} + +// RemoveGroupMember removes a member from a AAD Group. +func (m *mockProvider) RemoveGroupMember(_ context.Context, _ string, _ string) (err error) { + return nil +} + +// GetGroup gets group information from the directory. +func (m *mockProvider) GetGroup(_ context.Context, objectID string) (api.Group, error) { + g := api.Group{ + ID: objectID, + } + s := strings.Split(objectID, "FAKE_GROUP-") + if len(s) > 1 { + g.DisplayName = s[1] + } + + return g, nil +} + +// ListGroups gets list of groups for the current tenant. +func (m *mockProvider) ListGroups(_ context.Context, filter string) (result []api.Group, err error) { + reGroupName := regexp.MustCompile("displayName eq '(.*)'") + + match := reGroupName.FindAllStringSubmatch(filter, -1) + if len(match) > 0 { + name := match[0][1] + if name == "multiple" { + return []api.Group{ + { + ID: fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name), + DisplayName: name, + }, + { + ID: fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-2", name), + DisplayName: name, + }, + }, nil + } + + return []api.Group{ + { + ID: fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name), + DisplayName: name, + }, + }, nil + } + + return []api.Group{}, nil +} + +// errMockProvider simulates a normal provider which fails to associate a role, +// returning an error +type errMockProvider struct { + *mockProvider +} + +func newErrMockProvider() api.AzureProvider { + return &errMockProvider{ + mockProvider: &mockProvider{ + subscriptionID: generateUUID(), + applications: make(map[string]bool), + passwords: make(map[string]api.PasswordCredential), + }, + } +} + +// CreateRoleAssignment for the errMockProvider intentionally fails +func (e *errMockProvider) CreateRoleAssignment(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { + return authorization.RoleAssignment{}, errors.New("PrincipalNotFound") +} + +// GetApplication for the errMockProvider only returns an application if that +// key is found, unlike mockProvider which returns the same application object +// id each time. Existing tests depend on the mockProvider behavior, which is +// why errMockProvider has it's own version. +func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (api.ApplicationResult, error) { + for s := range e.applications { + if s == applicationObjectID { + return api.ApplicationResult{ + AppID: to.StringPtr(s), + }, nil + } + } + return api.ApplicationResult{}, errors.New("not found") +}