diff --git a/pkg/azure/defaultazurecredential/authorizer.go b/pkg/azure/defaultazurecredential/authorizer.go index 20aff08b0..c0a289f80 100644 --- a/pkg/azure/defaultazurecredential/authorizer.go +++ b/pkg/azure/defaultazurecredential/authorizer.go @@ -2,11 +2,14 @@ package defaultazurecredential import ( "context" + "fmt" + "os" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" "k8s.io/klog/v2" ) @@ -26,20 +29,34 @@ func NewAuthorizer() (autorest.Authorizer, error) { return nil, err } + scope := tokenScopeFromEnvironment() + klog.V(7).Infof("Fetching token with scope %s", scope) return autorest.NewBearerAuthorizer(&tokenCredentialWrapper{ - cred: cred, + cred: cred, + scope: scope, }), nil } +func tokenScopeFromEnvironment() string { + cloud := os.Getenv("AZURE_ENVIRONMENT") + env, err := azure.EnvironmentFromName(cloud) + if err != nil { + env = azure.PublicCloud + } + + return fmt.Sprintf("%s.default", env.TokenAudience) +} + type tokenCredentialWrapper struct { - cred azcore.TokenCredential + cred azcore.TokenCredential + scope string } func (w *tokenCredentialWrapper) OAuthToken() string { klog.V(7).Info("Getting Azure token using DefaultAzureCredential") token, err := w.cred.GetToken(context.Background(), policy.TokenRequestOptions{ - Scopes: []string{"https://management.azure.com/.default"}, + Scopes: []string{w.scope}, }) if err != nil { diff --git a/pkg/azure/defaultazurecredential/authorizer_test.go b/pkg/azure/defaultazurecredential/authorizer_test.go new file mode 100644 index 000000000..0646f43a4 --- /dev/null +++ b/pkg/azure/defaultazurecredential/authorizer_test.go @@ -0,0 +1,22 @@ +package defaultazurecredential + +import ( + "os" + "testing" +) + +func TestTokenScopeFromEnvironment(t *testing.T) { + scope := map[string]string{ + "AZUREPUBLICCLOUD": "https://management.azure.com/.default", + "AZURECHINACLOUD": "https://management.chinacloudapi.cn/.default", + "AZUREUSGOVERNMENTCLOUD": "https://management.usgovcloudapi.net/.default", + } + + for env, expectedScope := range scope { + os.Setenv("AZURE_ENVIRONMENT", env) + scope := tokenScopeFromEnvironment() + if scope != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, scope) + } + } +}