Skip to content

Commit

Permalink
[release-1.11] AUTH: Detect Azure MSI compatible environments and rem…
Browse files Browse the repository at this point in the history
…ove timeout (#2916)

Signed-off-by: Bernd Verst <[email protected]>
  • Loading branch information
berndverst authored Jun 17, 2023
1 parent 85be43a commit ca0d8e8
Showing 1 changed file with 62 additions and 10 deletions.
72 changes: 62 additions & 10 deletions internal/authentication/azure/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/pem"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"
Expand All @@ -39,6 +40,13 @@ type EnvironmentSettings struct {
Cloud *cloud.Configuration
}

const (
arcIMDSEndpoint = "IMDS_ENDPOINT"
identityEndpoint = "IDENTITY_ENDPOINT"
msiEndpoint = "MSI_ENDPOINT"
imdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"
)

// timeoutWrapper prevents a potentially very long timeout when managed identity or CLI credential aren't available
type timeoutWrapper struct {
cred azcore.TokenCredential
Expand Down Expand Up @@ -101,9 +109,8 @@ func (s EnvironmentSettings) GetAzureEnvironment() (*cloud.Configuration, error)
// 1. Client credentials
// 2. Client certificate
// 3. Workload identity
// 4. MSI with a timeout of 1 second
// 4. MSI (we use a timeout of 1 second when no compatible managed identity implementation is available)
// 5. Azure CLI
// 6. Retry MSI without timeout
//
// This order and timeout (with the exception of the additional step 5) matches the DefaultAzureCredential.
func (s EnvironmentSettings) GetTokenCredential() (azcore.TokenCredential, error) {
Expand Down Expand Up @@ -144,12 +151,36 @@ func (s EnvironmentSettings) GetTokenCredential() (azcore.TokenCredential, error
}

// 4. MSI with timeout of 1 second (same as DefaultAzureCredential)
var msiCred *azcore.TokenCredential
{
c := s.GetMSI()
msiCred, err := c.GetTokenCredential()

useTimeout := true
if _, ok := os.LookupEnv(identityEndpoint); ok {
// App Service & Service Fabric
useTimeout = false
} else {
if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
// Azure Arc
useTimeout = false
} else {
if _, ok := os.LookupEnv(msiEndpoint); ok {
// Cloud Shell
useTimeout = false
} else if isVirtualMachineWithManagedIdentity() {
// Azure VM with MSI enabled
useTimeout = false
}
}
}

// We need to use a timeout for MSI on environments where it is not available because the request for the default IMDS endpoint can hang for several minutes.
if useTimeout {
msiCred = &timeoutWrapper{cred: msiCred, authmethod: "managed identity", timeout: 1 * time.Second}
}

if err == nil {
creds = append(creds, &timeoutWrapper{cred: msiCred, authmethod: "managed identity", timeout: 1 * time.Second})
creds = append(creds, msiCred)
} else {
errs = append(errs, err)
}
Expand All @@ -159,17 +190,12 @@ func (s EnvironmentSettings) GetTokenCredential() (azcore.TokenCredential, error
{
cred, credErr := azidentity.NewAzureCLICredential(nil)
if credErr == nil {
creds = append(creds, &timeoutWrapper{cred: cred, authmethod: "Azure CLI", timeout: 5 * time.Second})
creds = append(creds, &timeoutWrapper{cred: cred, authmethod: "Azure CLI", timeout: 30 * time.Second})
} else {
errs = append(errs, credErr)
}
}

// 6. Retry MSI without timeout
if msiCred != nil {
creds = append(creds, *msiCred)
}

if len(creds) == 0 {
return nil, fmt.Errorf("no suitable token provider for Azure AD; errors: %w", errors.Join(errs...))
}
Expand Down Expand Up @@ -408,3 +434,29 @@ func (c MSIConfig) GetTokenCredential() (token azcore.TokenCredential, err error
func (s EnvironmentSettings) GetEnvironment(key string) (val string, ok bool) {
return metadata.GetMetadataProperty(s.Metadata, MetadataKeys[key]...)
}

// isVirtualMachineWithManagedIdentity returns true if the code is running on a virtual machine with managed identity enabled.
// This is indicated by the standard IMDS endpoint being reachable.
func isVirtualMachineWithManagedIdentity() bool {
client := http.Client{
Timeout: time.Second * 3,
}

req, err := http.NewRequest(http.MethodGet, imdsEndpoint, nil)
if err != nil {
return false
}

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

req = req.WithContext(ctx)

resp, err := client.Do(req)
if err != nil {
return false
}
defer resp.Body.Close()

return true
}

0 comments on commit ca0d8e8

Please sign in to comment.