Skip to content

Commit

Permalink
Allow Azure Auth order to be specified via azureAuthMethods compone…
Browse files Browse the repository at this point in the history
…nt metadata (#3217)

Signed-off-by: Bernd Verst <github@bernd.dev>
Signed-off-by: luigirende <luigirende@gmail.com>
Signed-off-by: luiren <luigirende@gmail.com>
Signed-off-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: luigirende <luigirende@gmail.com>
Co-authored-by: luiren <luigi.rende@assistdigital.com>
  • Loading branch information
4 people committed Nov 10, 2023
1 parent 8fe74b1 commit 3bcd0c7
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 56 deletions.
176 changes: 120 additions & 56 deletions internal/authentication/azure/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,94 +105,158 @@ func (s EnvironmentSettings) GetAzureEnvironment() (*cloud.Configuration, error)
}
}

// GetTokenCredential returns an azcore.TokenCredential retrieved from, in order:
// 1. Client credentials
// 2. Client certificate
// 3. Workload identity
// 4. MSI (we use a timeout of 1 second when no compatible managed identity implementation is available)
// 5. Azure CLI
//
// This order and timeout (with the exception of the additional step 5) matches the DefaultAzureCredential.
func (s EnvironmentSettings) GetTokenCredential() (azcore.TokenCredential, error) {
// Create a chain
var creds []azcore.TokenCredential
errs := make([]error, 0, 3)

// 1. Client credentials
func (s EnvironmentSettings) addClientCredentialsProvider(creds *[]azcore.TokenCredential, errs *[]error) {
if c, e := s.GetClientCredentials(); e == nil {
cred, err := c.GetTokenCredential()
if err == nil {
creds = append(creds, cred)
*creds = append(*creds, cred)
} else {
errs = append(errs, err)
*errs = append(*errs, err)
}
}
}

// 2. Client certificate
func (s EnvironmentSettings) addClientCertificateProvider(creds *[]azcore.TokenCredential, errs *[]error) {
if c, e := s.GetClientCert(); e == nil {
cred, err := c.GetTokenCredential()
if err == nil {
creds = append(creds, cred)
*creds = append(*creds, cred)
} else {
errs = append(errs, err)
*errs = append(*errs, err)
}
}
}

// 3. Workload identity
func (s EnvironmentSettings) addWorkloadIdentityProvider(creds *[]azcore.TokenCredential, errs *[]error) {
// workload identity requires values for AZURE_AUTHORITY_HOST, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, AZURE_TENANT_ID
// The workload identity mutating admissions webhook in Kubernetes injects these values into the pod.
// These environment variables are read using the default WorkloadIdentityCredentialOptions

workloadCred, err := azidentity.NewWorkloadIdentityCredential(nil)
if err == nil {
creds = append(creds, workloadCred)
*creds = append(*creds, workloadCred)
} else {
errs = append(errs, err)
*errs = append(*errs, err)
}
}

// 4. MSI with timeout of 1 second (same as DefaultAzureCredential)
{
c := s.GetMSI()
msiCred, err := c.GetTokenCredential()
func (s EnvironmentSettings) addManagedIdentityProvider(timeout time.Duration, creds *[]azcore.TokenCredential, errs *[]error) {
c := s.GetMSI()
msiCred, err := c.GetTokenCredential()

useTimeout := true
if _, ok := os.LookupEnv(identityEndpoint); ok {
// App Service & Service Fabric
useTimeout := true
if _, ok := os.LookupEnv(identityEndpoint); ok {
// App Service, Functions, Service Fabric and Container Apps
useTimeout = false
} else {
if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
// Azure Arc
useTimeout = false
} else {
if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
// Azure Arc
if _, ok := os.LookupEnv(msiEndpoint); ok {
// Cloud Shell
useTimeout = false
} else if isVirtualMachineWithManagedIdentity() {
// Azure VM with MSI enabled
useTimeout = false
} else {
if _, ok := os.LookupEnv(msiEndpoint); ok {
// Cloud Shell
useTimeout = false
} else if isVirtualMachineWithManagedIdentity() {
// Azure VM with MSI enabled
useTimeout = false
}
}
}
}

// We need to use a timeout for MSI on environments where it is not available because the request for the default IMDS endpoint can hang for several minutes.
if useTimeout {
msiCred = &timeoutWrapper{cred: msiCred, authmethod: "managed identity", timeout: 1 * time.Second}
}
// We need to use a timeout for MSI on environments where it is not available because the request for the default IMDS endpoint can hang for several minutes.
if useTimeout {
msiCred = &timeoutWrapper{cred: msiCred, authmethod: "managed identity", timeout: timeout}
}

if err == nil {
creds = append(creds, msiCred)
} else {
errs = append(errs, err)
}
if err == nil {
*creds = append(*creds, msiCred)
} else {
*errs = append(*errs, err)
}
}

// 5. AzureCLICredential
{
cred, credErr := azidentity.NewAzureCLICredential(nil)
if credErr == nil {
creds = append(creds, &timeoutWrapper{cred: cred, authmethod: "Azure CLI", timeout: 30 * time.Second})
} else {
errs = append(errs, credErr)
func (s EnvironmentSettings) addCLIProvider(timeout time.Duration, creds *[]azcore.TokenCredential, errs *[]error) {
cred, credErr := azidentity.NewAzureCLICredential(nil)
if credErr == nil {
*creds = append(*creds, &timeoutWrapper{cred: cred, authmethod: "Azure CLI", timeout: 30 * time.Second})
} else {
*errs = append(*errs, credErr)
}
}

func (s EnvironmentSettings) addProviderByAuthMethodName(authMethod string, creds *[]azcore.TokenCredential, errs *[]error) {
switch authMethod {
case "clientcredentials", "creds":
s.addClientCredentialsProvider(creds, errs)
case "clientcertificate", "cert":
s.addClientCertificateProvider(creds, errs)
case "workloadidentity", "wi":
s.addWorkloadIdentityProvider(creds, errs)
case "managedidentity", "mi":
s.addManagedIdentityProvider(1*time.Second, creds, errs)
case "commandlineinterface", "cli":
s.addCLIProvider(30*time.Second, creds, errs)
}
}

func getAzureAuthMethods() []string {
return []string{"clientcredentials", "creds", "clientcertificate", "cert", "workloadidentity", "wi", "managedidentity", "mi", "commandlineinterface", "cli", "none"}
}

// GetTokenCredential returns an azcore.TokenCredential retrieved from the order specified via
// the azureAuthMethods component metadata property which denotes a comma-separated list of auth methods to try in order.
// The possible values contained are (case-insensitive):
// ServicePrincipal, Certificate, WorkloadIdentity, ManagedIdentity, CLI
// The string "None" can be used to disable Azure authentication.
//
// If the azureAuthMethods property is not present, the following order is used (which with the exception of step 5
// matches the DefaultAzureCredential order):
// 1. Client credentials
// 2. Client certificate
// 3. Workload identity
// 4. MSI (we use a timeout of 1 second when no compatible managed identity implementation is available)
// 5. Azure CLI
func (s EnvironmentSettings) GetTokenCredential() (azcore.TokenCredential, error) {
// Create a chain
var creds []azcore.TokenCredential
errs := make([]error, 0, 3)

authMethods, ok := s.GetEnvironment("AzureAuthMethods")
if !ok || strings.TrimSpace(authMethods) == "" {
// 1. Client credentials
s.addClientCredentialsProvider(&creds, &errs)

// 2. Client certificate
s.addClientCertificateProvider(&creds, &errs)

// 3. Workload identity
s.addWorkloadIdentityProvider(&creds, &errs)

// 4. MSI with timeout of 1 second (same as DefaultAzureCredential)
s.addManagedIdentityProvider(1*time.Second, &creds, &errs)

// 5. AzureCLICredential
s.addCLIProvider(30*time.Second, &creds, &errs)
} else {
authMethodIdentifiers := getAzureAuthMethods()
authMethods := strings.Split(strings.ToLower(strings.TrimSpace(authMethods)), ",")
for _, authMethod := range authMethods {
authMethod = strings.TrimSpace(authMethod)
found := false
for _, authMethodIdentifier := range authMethodIdentifiers {
if authMethod == authMethodIdentifier {
found = true
if authMethod != "none" {
s.addProviderByAuthMethodName(authMethod, &creds, &errs)
break
} else {
// If authMethod is "none", we don't add any provider and return an error
return nil, fmt.Errorf("all Azure auth methods have been disabled with auth method 'None'")
}
}
}
if !found {
return nil, fmt.Errorf("invalid Azure auth method: %v", authMethod)
}
}
}

Expand Down
74 changes: 74 additions & 0 deletions internal/authentication/azure/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,80 @@ func TestAuthorizorWithMSI(t *testing.T) {
assert.NotNil(t, spt)
}

func TestFallbackToMSIbutAzureAuthDisallowed(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
map[string]string{
"azureClientId": fakeClientID,
"vaultName": "vaultName",
"azureAuthMethods": "None",
},
)
assert.NoError(t, err)

_, err = settings.GetTokenCredential()
assert.Error(t, err)
assert.ErrorContains(t, err, "all Azure auth methods have been disabled")
}

func TestFallbackToMSIandInAllowedList(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
map[string]string{
"azureClientId": fakeClientID,
"vaultName": "vaultName",
"azureAuthMethods": "clientcredentials,clientcertificate,workloadidentity,managedIdentity",
},
)
assert.NoError(t, err)

testCertConfig := settings.GetMSI()
assert.NotNil(t, testCertConfig)

spt, err := settings.GetTokenCredential()
assert.NoError(t, err)
assert.NotNil(t, spt)
}

func TestFallbackToMSIandNotInAllowedList(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
map[string]string{
"azureClientId": fakeClientID,
"vaultName": "vaultName",
"azureAuthMethods": "clientcredentials,clientcertificate,workloadidentity",
},
)
assert.NoError(t, err)

_, err = settings.GetTokenCredential()
assert.Error(t, err)
assert.ErrorContains(t, err, "no suitable token provider for Azure AD")
}

func TestFallbackToMSIandInvalidAuthMethod(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
map[string]string{
"azureClientId": fakeClientID,
"vaultName": "vaultName",
"azureAuthMethods": "clientcredentials,clientcertificate,workloadidentity,managedIdentity,cli,SUPERAUTH",
},
)
require.NoError(t, err)

testCertConfig := settings.GetMSI()
require.NotNil(t, testCertConfig)

_, err = settings.GetTokenCredential()
assert.Error(t, err)
assert.ErrorContains(t, err, "invalid Azure auth method: superauth")
}

func TestAuthorizorWithMSIAndUserAssignedID(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
Expand Down
3 changes: 3 additions & 0 deletions internal/authentication/azure/metadata-properties.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ var MetadataKeys = map[string][]string{ //nolint:gochecknoglobals
// Identifier for the Azure environment
// Allowed values (case-insensitive): AzurePublicCloud/AzurePublic, AzureChinaCloud/AzureChina, AzureUSGovernmentCloud/AzureUSGovernment
"AzureEnvironment": {"azureEnvironment", "azureCloud"},
// Identifier for the Azure authentication methods to try (in order), comma-separated
// Allowed values (case-insensitive): ClientCredentials, creds, ClientCertificate, cert, WorkloadIdentity, wi, ManagedIdentity, mi, CommandLineInterface, cli, None
"AzureAuthMethods": {"azureAuthMethods", "azureADAuthMethods", "entraIDAuthMethods", "microsoftEntraIDAuthMethods"},

// Metadata keys for storage components

Expand Down

0 comments on commit 3bcd0c7

Please sign in to comment.