Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Azure Auth order to be specified via azureAuthMethods component metadata #3217

Merged
merged 10 commits into from
Nov 10, 2023
176 changes: 120 additions & 56 deletions internal/authentication/azure/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,94 +105,158 @@ func (s EnvironmentSettings) GetAzureEnvironment() (*cloud.Configuration, error)
}
}

// GetTokenCredential returns an azcore.TokenCredential retrieved from, in order:
// 1. Client credentials
// 2. Client certificate
// 3. Workload identity
// 4. MSI (we use a timeout of 1 second when no compatible managed identity implementation is available)
// 5. Azure CLI
//
// This order and timeout (with the exception of the additional step 5) matches the DefaultAzureCredential.
func (s EnvironmentSettings) GetTokenCredential() (azcore.TokenCredential, error) {
// Create a chain
var creds []azcore.TokenCredential
errs := make([]error, 0, 3)

// 1. Client credentials
func (s EnvironmentSettings) addClientCredentialsProvider(creds *[]azcore.TokenCredential, errs *[]error) {
if c, e := s.GetClientCredentials(); e == nil {
cred, err := c.GetTokenCredential()
if err == nil {
creds = append(creds, cred)
*creds = append(*creds, cred)
} else {
errs = append(errs, err)
*errs = append(*errs, err)
}
}
}

// 2. Client certificate
func (s EnvironmentSettings) addClientCertificateProvider(creds *[]azcore.TokenCredential, errs *[]error) {
if c, e := s.GetClientCert(); e == nil {
cred, err := c.GetTokenCredential()
if err == nil {
creds = append(creds, cred)
*creds = append(*creds, cred)
} else {
errs = append(errs, err)
*errs = append(*errs, err)
}
}
}

// 3. Workload identity
func (s EnvironmentSettings) addWorkloadIdentityProvider(creds *[]azcore.TokenCredential, errs *[]error) {
// workload identity requires values for AZURE_AUTHORITY_HOST, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, AZURE_TENANT_ID
// The workload identity mutating admissions webhook in Kubernetes injects these values into the pod.
// These environment variables are read using the default WorkloadIdentityCredentialOptions

workloadCred, err := azidentity.NewWorkloadIdentityCredential(nil)
if err == nil {
creds = append(creds, workloadCred)
*creds = append(*creds, workloadCred)
} else {
errs = append(errs, err)
*errs = append(*errs, err)
}
}

// 4. MSI with timeout of 1 second (same as DefaultAzureCredential)
{
c := s.GetMSI()
msiCred, err := c.GetTokenCredential()
func (s EnvironmentSettings) addManagedIdentityProvider(timeout time.Duration, creds *[]azcore.TokenCredential, errs *[]error) {
c := s.GetMSI()
msiCred, err := c.GetTokenCredential()

useTimeout := true
if _, ok := os.LookupEnv(identityEndpoint); ok {
// App Service & Service Fabric
useTimeout := true
if _, ok := os.LookupEnv(identityEndpoint); ok {
// App Service, Functions, Service Fabric and Container Apps
useTimeout = false
} else {
if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
// Azure Arc
useTimeout = false
} else {
if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
// Azure Arc
if _, ok := os.LookupEnv(msiEndpoint); ok {
// Cloud Shell
useTimeout = false
} else if isVirtualMachineWithManagedIdentity() {
// Azure VM with MSI enabled
useTimeout = false
} else {
if _, ok := os.LookupEnv(msiEndpoint); ok {
// Cloud Shell
useTimeout = false
} else if isVirtualMachineWithManagedIdentity() {
// Azure VM with MSI enabled
useTimeout = false
}
}
}
}

// We need to use a timeout for MSI on environments where it is not available because the request for the default IMDS endpoint can hang for several minutes.
if useTimeout {
msiCred = &timeoutWrapper{cred: msiCred, authmethod: "managed identity", timeout: 1 * time.Second}
}
// We need to use a timeout for MSI on environments where it is not available because the request for the default IMDS endpoint can hang for several minutes.
if useTimeout {
msiCred = &timeoutWrapper{cred: msiCred, authmethod: "managed identity", timeout: timeout}
}

if err == nil {
creds = append(creds, msiCred)
} else {
errs = append(errs, err)
}
if err == nil {
*creds = append(*creds, msiCred)
} else {
*errs = append(*errs, err)
}
}

// 5. AzureCLICredential
{
cred, credErr := azidentity.NewAzureCLICredential(nil)
if credErr == nil {
creds = append(creds, &timeoutWrapper{cred: cred, authmethod: "Azure CLI", timeout: 30 * time.Second})
} else {
errs = append(errs, credErr)
func (s EnvironmentSettings) addCLIProvider(timeout time.Duration, creds *[]azcore.TokenCredential, errs *[]error) {
cred, credErr := azidentity.NewAzureCLICredential(nil)
if credErr == nil {
*creds = append(*creds, &timeoutWrapper{cred: cred, authmethod: "Azure CLI", timeout: 30 * time.Second})
} else {
*errs = append(*errs, credErr)
}
}

func (s EnvironmentSettings) addProviderByAuthMethodName(authMethod string, creds *[]azcore.TokenCredential, errs *[]error) {
switch authMethod {
case "clientcredentials", "creds":
s.addClientCredentialsProvider(creds, errs)
case "clientcertificate", "cert":
s.addClientCertificateProvider(creds, errs)
case "workloadidentity", "wi":
s.addWorkloadIdentityProvider(creds, errs)
case "managedidentity", "mi":
s.addManagedIdentityProvider(1*time.Second, creds, errs)
case "commandlineinterface", "cli":
s.addCLIProvider(30*time.Second, creds, errs)
}
}

func getAzureAuthMethods() []string {
return []string{"clientcredentials", "creds", "clientcertificate", "cert", "workloadidentity", "wi", "managedidentity", "mi", "commandlineinterface", "cli", "none"}
}

// GetTokenCredential returns an azcore.TokenCredential retrieved from the order specified via
// the azureAuthMethods component metadata property which denotes a comma-separated list of auth methods to try in order.
// The possible values contained are (case-insensitive):
// ServicePrincipal, Certificate, WorkloadIdentity, ManagedIdentity, CLI
// The string "None" can be used to disable Azure authentication.
//
// If the azureAuthMethods property is not present, the following order is used (which with the exception of step 5
// matches the DefaultAzureCredential order):
// 1. Client credentials
// 2. Client certificate
// 3. Workload identity
// 4. MSI (we use a timeout of 1 second when no compatible managed identity implementation is available)
// 5. Azure CLI
func (s EnvironmentSettings) GetTokenCredential() (azcore.TokenCredential, error) {
// Create a chain
var creds []azcore.TokenCredential
errs := make([]error, 0, 3)

authMethods, ok := s.GetEnvironment("AzureAuthMethods")
if !ok || strings.TrimSpace(authMethods) == "" {
// 1. Client credentials
s.addClientCredentialsProvider(&creds, &errs)

// 2. Client certificate
s.addClientCertificateProvider(&creds, &errs)

// 3. Workload identity
s.addWorkloadIdentityProvider(&creds, &errs)

// 4. MSI with timeout of 1 second (same as DefaultAzureCredential)
s.addManagedIdentityProvider(1*time.Second, &creds, &errs)

// 5. AzureCLICredential
s.addCLIProvider(30*time.Second, &creds, &errs)
} else {
authMethodIdentifiers := getAzureAuthMethods()
authMethods := strings.Split(strings.ToLower(strings.TrimSpace(authMethods)), ",")
for _, authMethod := range authMethods {
berndverst marked this conversation as resolved.
Show resolved Hide resolved
authMethod = strings.TrimSpace(authMethod)
found := false
for _, authMethodIdentifier := range authMethodIdentifiers {
if authMethod == authMethodIdentifier {
found = true
if authMethod != "none" {
s.addProviderByAuthMethodName(authMethod, &creds, &errs)
break
} else {
// If authMethod is "none", we don't add any provider and return an error
return nil, fmt.Errorf("all Azure auth methods have been disabled with auth method 'None'")
}
}
}
if !found {
return nil, fmt.Errorf("invalid Azure auth method: %v", authMethod)
}
}
}

Expand Down
74 changes: 74 additions & 0 deletions internal/authentication/azure/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,80 @@ func TestAuthorizorWithMSI(t *testing.T) {
assert.NotNil(t, spt)
}

func TestFallbackToMSIbutAzureAuthDisallowed(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
map[string]string{
"azureClientId": fakeClientID,
"vaultName": "vaultName",
"azureAuthMethods": "None",
},
)
assert.NoError(t, err)

_, err = settings.GetTokenCredential()
assert.Error(t, err)
assert.ErrorContains(t, err, "all Azure auth methods have been disabled")
}

func TestFallbackToMSIandInAllowedList(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
map[string]string{
"azureClientId": fakeClientID,
"vaultName": "vaultName",
"azureAuthMethods": "clientcredentials,clientcertificate,workloadidentity,managedIdentity",
},
)
assert.NoError(t, err)

testCertConfig := settings.GetMSI()
assert.NotNil(t, testCertConfig)

spt, err := settings.GetTokenCredential()
assert.NoError(t, err)
assert.NotNil(t, spt)
}

func TestFallbackToMSIandNotInAllowedList(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
map[string]string{
"azureClientId": fakeClientID,
"vaultName": "vaultName",
"azureAuthMethods": "clientcredentials,clientcertificate,workloadidentity",
},
)
assert.NoError(t, err)

_, err = settings.GetTokenCredential()
assert.Error(t, err)
assert.ErrorContains(t, err, "no suitable token provider for Azure AD")
}

func TestFallbackToMSIandInvalidAuthMethod(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
map[string]string{
"azureClientId": fakeClientID,
"vaultName": "vaultName",
"azureAuthMethods": "clientcredentials,clientcertificate,workloadidentity,managedIdentity,cli,SUPERAUTH",
},
)
require.NoError(t, err)

testCertConfig := settings.GetMSI()
require.NotNil(t, testCertConfig)

_, err = settings.GetTokenCredential()
assert.Error(t, err)
assert.ErrorContains(t, err, "invalid Azure auth method: superauth")
}

func TestAuthorizorWithMSIAndUserAssignedID(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
Expand Down
3 changes: 3 additions & 0 deletions internal/authentication/azure/metadata-properties.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ var MetadataKeys = map[string][]string{ //nolint:gochecknoglobals
// Identifier for the Azure environment
// Allowed values (case-insensitive): AzurePublicCloud/AzurePublic, AzureChinaCloud/AzureChina, AzureUSGovernmentCloud/AzureUSGovernment
"AzureEnvironment": {"azureEnvironment", "azureCloud"},
// Identifier for the Azure authentication methods to try (in order), comma-separated
// Allowed values (case-insensitive): ClientCredentials, creds, ClientCertificate, cert, WorkloadIdentity, wi, ManagedIdentity, mi, CommandLineInterface, cli, None
"AzureAuthMethods": {"azureAuthMethods", "azureADAuthMethods", "entraIDAuthMethods", "microsoftEntraIDAuthMethods"},

// Metadata keys for storage components

Expand Down