Skip to content

fix: remove azure claim overage code. #2005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 0 additions & 210 deletions internal/api/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,10 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"unicode/utf8"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
"github.com/supabase/auth/internal/conf"
"golang.org/x/oauth2"
)
Expand Down Expand Up @@ -167,208 +162,3 @@ func (g azureProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*Use

return nil, fmt.Errorf("azure: no OIDC ID token present in response")
}

type AzureIDTokenClaimSource struct {
Endpoint string `json:"endpoint"`
}

type AzureIDTokenClaims struct {
jwt.RegisteredClaims

Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
XMicrosoftEmailDomainOwnerVerified any `json:"xms_edov"`

ClaimNames map[string]string `json:"_claim_names"`
ClaimSources map[string]AzureIDTokenClaimSource `json:"_claim_sources"`
}

// ResolveIndirectClaims resolves claims in the Azure Token that require a call to the Microsoft Graph API. This is typically to an API like this: https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects?view=graph-rest-1.0&tabs=http
func (c *AzureIDTokenClaims) ResolveIndirectClaims(ctx context.Context, httpClient *http.Client, accessToken string) (map[string]any, error) {
if len(c.ClaimNames) == 0 || len(c.ClaimSources) == 0 {
return nil, nil
}

result := make(map[string]any)

for claimName, claimSource := range c.ClaimNames {
claimEndpointObject, ok := c.ClaimSources[claimSource]

if !ok || !strings.HasPrefix(claimEndpointObject.Endpoint, "https://") {
continue
}

u, err := url.ParseRequestURI(claimEndpointObject.Endpoint)
if err != nil {
return nil, fmt.Errorf("azure: failed to parse endpoint URL %q (resolving overage claim %q): %w", claimEndpointObject.Endpoint, claimName, err)
}

queryParams := u.Query()
if !queryParams.Has("api-version") {
// https://stackoverflow.com/questions/51085863/retrieve-group-claims-using-claim-sources-returns-the-specified-api-version-is
queryParams.Add("api-version", "1.6")
u.RawQuery = queryParams.Encode()
}

claimEndpoint := u.String()

req, err := http.NewRequestWithContext(ctx, http.MethodPost, claimEndpoint, strings.NewReader(`{"securityEnabledOnly":true}`))
if err != nil {
return nil, fmt.Errorf("azure: failed to create POST request to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err)
}

req.Header.Add("Authorization", "Bearer "+accessToken)
req.Header.Add("Content-Type", "application/json")

resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("azure: failed to send POST request to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err)
}

defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
resBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2*1024))

body := "<empty>"
if len(resBody) > 0 {
if utf8.Valid(resBody) {
body = string(resBody)
} else {
body = "<invalid-utf8>"
}
}

readErrString := ""
if readErr != nil {
readErrString = fmt.Sprintf(" with read error %q", readErr.Error())
}

return nil, fmt.Errorf("azure: received %d but expected 200 HTTP status code when sending POST to %q (resolving overage claim %q) with response body %q%s", resp.StatusCode, claimEndpoint, claimName, body, readErrString)
}

var responseResult struct {
Value any `json:"value"`
}

if err := json.NewDecoder(resp.Body).Decode(&responseResult); err != nil {
return nil, fmt.Errorf("azure: failed to parse JSON response from POST to %q (resolving overage claim %q): %w", claimEndpoint, claimName, err)
}

result[claimName] = responseResult.Value
}

return result, nil
}

func (c *AzureIDTokenClaims) IsEmailVerified() bool {
emailVerified := false

edov := c.XMicrosoftEmailDomainOwnerVerified

// If xms_edov is not set, and an email is present or xms_edov is true,
// only then is the email regarded as verified.
// https://learn.microsoft.com/en-us/azure/active-directory/develop/migrate-off-email-claim-authorization#using-the-xms_edov-optional-claim-to-determine-email-verification-status-and-migrate-users
if edov == nil {
// An email is provided, but xms_edov is not -- probably not
// configured, so we must assume the email is verified as Azure
// will only send out a potentially unverified email address in
// single-tenanat apps.
emailVerified = c.Email != ""
} else {
edovBool := false

// Azure can't be trusted with how they encode the xms_edov
// claim. Sometimes it's "xms_edov": "1", sometimes "xms_edov": true.
switch v := edov.(type) {
case bool:
edovBool = v

case string:
edovBool = v == "1" || v == "true"

default:
edovBool = false
}

emailVerified = c.Email != "" && edovBool
}

return emailVerified
}

// removeAzureClaimsFromCustomClaims contains the list of claims to be removed
// from the CustomClaims map. See:
// https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference
var removeAzureClaimsFromCustomClaims = []string{
"aud",
"iss",
"iat",
"nbf",
"exp",
"c_hash",
"at_hash",
"aio",
"nonce",
"rh",
"uti",
"jti",
"ver",
"sub",
"name",
"preferred_username",
}

func parseAzureIDToken(ctx context.Context, token *oidc.IDToken, accessToken string) (*oidc.IDToken, *UserProvidedData, error) {
var data UserProvidedData

var azureClaims AzureIDTokenClaims
if err := token.Claims(&azureClaims); err != nil {
return nil, nil, err
}

data.Metadata = &Claims{
Issuer: token.Issuer,
Subject: token.Subject,
ProviderId: token.Subject,
PreferredUsername: azureClaims.PreferredUsername,
FullName: azureClaims.Name,
CustomClaims: make(map[string]any),
}

if azureClaims.Email != "" {
data.Emails = []Email{{
Email: azureClaims.Email,
Verified: azureClaims.IsEmailVerified(),
Primary: true,
}}
}

if err := token.Claims(&data.Metadata.CustomClaims); err != nil {
return nil, nil, err
}

resolvedClaims, err := azureClaims.ResolveIndirectClaims(ctx, http.DefaultClient, accessToken)
if err != nil {
return nil, nil, err
}

if data.Metadata.CustomClaims == nil {
if resolvedClaims != nil {
data.Metadata.CustomClaims = make(map[string]any, len(resolvedClaims))
}
}

if data.Metadata.CustomClaims != nil {
for _, claim := range removeAzureClaimsFromCustomClaims {
delete(data.Metadata.CustomClaims, claim)
}
}

for k, v := range resolvedClaims {
data.Metadata.CustomClaims[k] = v
}

return token, &data, nil
}
146 changes: 1 addition & 145 deletions internal/api/provider/azure_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
package provider

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/require"
)
import "testing"

func TestIsAzureIssuer(t *testing.T) {
positiveExamples := []string{
Expand All @@ -36,138 +27,3 @@ func TestIsAzureIssuer(t *testing.T) {
}
}
}

func TestAzureResolveIndirectClaims(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)

w.Write([]byte(`{
"@odata.context": "https://graph.microsoft.com/v1.0/$metadata#Collection(Edm.String)",
"value": [
"fee2c45b-915a-4a64-b130-f4eb9e75525e",
"4fe90ae7-065a-478b-9400-e0a0e1cbd540",
"c9ee2d50-9e8a-4352-b97c-4c2c99557c22",
"e0c3beaf-eeb4-43d8-abc5-94f037a65697"
]
}`))
}))

defer server.Close()

var claims AzureIDTokenClaims

resolvedClaims, err := claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token")
require.Nil(t, resolvedClaims)
require.Nil(t, err)

claims.ClaimNames = make(map[string]string)

resolvedClaims, err = claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token")
require.Nil(t, resolvedClaims)
require.Nil(t, err)

claims.ClaimNames = map[string]string{
"groups": "src1",
"missing-source": "src2",
"not-https": "src3",
}
claims.ClaimSources = map[string]AzureIDTokenClaimSource{
"src1": {
Endpoint: server.URL,
},
"src3": {
Endpoint: "http://example.com",
},
}

resolvedClaims, err = claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token")
require.NoError(t, err)
require.NotNil(t, resolvedClaims)
require.Equal(t, 1, len(resolvedClaims))
require.Equal(t, 4, len(resolvedClaims["groups"].([]interface{})))
}

func TestAzureResolveIndirectClaimsFailures(t *testing.T) {
examples := []struct {
name string
urlSuffix string
statusCode int
body []byte
expectedError string
}{
{
name: "invalid url",
urlSuffix: "\000",
expectedError: "azure: failed to parse endpoint URL \"SERVER-URL\\x00\" (resolving overage claim \"groups\"): parse \"SERVER-URL\\x00\": net/url: invalid control character in URL",
},
{
name: "no such server",
urlSuffix: "000",
expectedError: "azure: failed to send POST request to \"SERVER-URL000\" (resolving overage claim \"groups\"): Post \"SERVER-URL000\": dial tcp: address PORT000: invalid port",
},
{
name: "non 200 status code",
statusCode: 500,
body: []byte(`something is wrong`),
expectedError: "azure: received 500 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"something is wrong\"",
},
{
name: "non 200 status code, non utf8 valid body",
statusCode: 201,
body: []byte{255, 255, 255, 255},
expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"<invalid-utf8>\"",
},
{
name: "non 200 status code, empty body",
statusCode: 201,
body: []byte{},
expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"<empty>\"",
},
{
name: "non 200 status code, body over 2KB",
statusCode: 201,
body: []byte(strings.Repeat("x", 2*1024+1)),
expectedError: "azure: received 201 but expected 200 HTTP status code when sending POST to \"SERVER-URL\" (resolving overage claim \"groups\") with response body \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\"",
},
{
name: "ok response, not json",
statusCode: 200,
body: []byte("not json"),
expectedError: "azure: failed to parse JSON response from POST to \"SERVER-URL\" (resolving overage claim \"groups\"): invalid character 'o' in literal null (expecting 'u')",
},
}

for _, example := range examples {
t.Run(example.name, func(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "1.6", r.URL.Query().Get("api-version"))

w.WriteHeader(example.statusCode)

w.Write(example.body)
}))

defer server.Close()

u, _ := url.Parse(server.URL)

var claims AzureIDTokenClaims

claims.ClaimNames = map[string]string{
"groups": "src1",
}
claims.ClaimSources = map[string]AzureIDTokenClaimSource{
"src1": {
Endpoint: server.URL + example.urlSuffix,
},
}

resolvedClaims, err := claims.ResolveIndirectClaims(context.Background(), server.Client(), "access-token")
require.Nil(t, resolvedClaims)
require.Error(t, err)
require.Equal(t, example.expectedError, strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(err.Error(), server.URL, "SERVER-URL"), u.Port(), "PORT"), "?api-version=1.6", ""))
})
}

}
Loading
Loading