Skip to content

Commit

Permalink
feat(oidc): optionally query OIDC UserInfo to gather group claims
Browse files Browse the repository at this point in the history
Signed-off-by: Nathanael Liechti <[email protected]>
  • Loading branch information
the-technat authored Sep 6, 2023
1 parent eba40d4 commit b82ed12
Show file tree
Hide file tree
Showing 8 changed files with 454 additions and 48 deletions.
14 changes: 14 additions & 0 deletions docs/operator-manual/user-management/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,20 @@ For a simple case this can be:
oidc.config: |
requestedIDTokenClaims: {"groups": {"essential": true}}
```

### Retrieving group claims when not in the token

Some OIDC providers don't return the group information for a user in the token, even if explicitly requested using the `requestedIDTokenClaims` setting (Okta for example). They instead provide the groups on the user info endpoint. With the following config, Argo CD queries the user info endpoint during login for groups information of a user:

```yaml
oidc.config: |
enableUserInfoGroups: true
userInfoPath: /userinfo
userInfoCacheExpiration: "5m"
```

**Note: If you omit the `userInfoCacheExpiration` setting, the argocd-server will cache group information as long as the OIDC token is valid!**

### Configuring a custom logout URL for your OIDC provider

Optionally, if your OIDC provider exposes a logout API and you wish to configure a custom logout URL for the purposes of invalidating
Expand Down
30 changes: 28 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,7 @@ func (a *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
// Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex)
var err error
mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(a.DexServerAddr, a.BaseHRef, a.DexTLSConfig))
a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.DexTLSConfig, a.BaseHRef)
a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.DexTLSConfig, a.BaseHRef, cacheutil.NewRedisCache(a.RedisClient, a.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
errorsutil.CheckError(err)
mux.HandleFunc(common.LoginEndpoint, a.ssoClientApp.HandleLogin)
mux.HandleFunc(common.CallbackEndpoint, a.ssoClientApp.HandleCallback)
Expand Down Expand Up @@ -1303,7 +1303,33 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error
if err != nil {
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
}
return claims, newToken, nil

// Some SSO implementations (Okta) require a call to
// the OIDC user info path to get attributes like groups
// we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims
// otherwise this would cause a panic
var groupClaims jwt.MapClaims
if groupClaims, ok = claims.(jwt.MapClaims); !ok {
groupClaims = *claims.(*jwt.MapClaims)
}
iss := jwtutil.StringField(groupClaims, "iss")
if iss != "argocd" && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" {
sub := jwtutil.StringField(groupClaims, "sub")
userInfo, unauthorized, err := a.ssoClientApp.GetUserInfo(sub, a.settings.IssuerURL(), a.settings.UserInfoPath())
if unauthorized {
log.Errorf("error while quering userinfo endpoint: %v", err)
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session")
}
if err != nil {
return claims, "", status.Errorf(codes.Internal, "error fetching user info endpoint: %v", err)
}
if groupClaims["sub"] != userInfo["sub"] {
return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
}
groupClaims["groups"] = userInfo["groups"]
}

return groupClaims, newToken, nil
}

// getToken extracts the token from gRPC metadata or cookie headers
Expand Down
125 changes: 113 additions & 12 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ import (
"github.com/argoproj/argo-cd/v2/server/rbacpolicy"
"github.com/argoproj/argo-cd/v2/test"
"github.com/argoproj/argo-cd/v2/util/assets"
"github.com/argoproj/argo-cd/v2/util/cache"
cacheutil "github.com/argoproj/argo-cd/v2/util/cache"
appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate"
"github.com/argoproj/argo-cd/v2/util/oidc"
"github.com/argoproj/argo-cd/v2/util/rbac"
settings_util "github.com/argoproj/argo-cd/v2/util/settings"
testutil "github.com/argoproj/argo-cd/v2/util/test"
Expand Down Expand Up @@ -533,7 +535,7 @@ func dexMockHandler(t *testing.T, url string) func(http.ResponseWriter, *http.Re
}
}

func getTestServer(t *testing.T, anonymousEnabled bool, withFakeSSO bool, useDexForSSO bool) (argocd *ArgoCDServer, oidcURL string) {
func getTestServer(t *testing.T, anonymousEnabled bool, withFakeSSO bool, useDexForSSO bool, additionalOIDCConfig settings_util.OIDCConfig) (argocd *ArgoCDServer, oidcURL string) {
cm := test.NewFakeConfigMap()
if anonymousEnabled {
cm.Data["users.anonymous.enabled"] = "true"
Expand Down Expand Up @@ -562,13 +564,12 @@ connectors:
clientID: test-client
clientSecret: $dex.oidc.clientSecret`
} else {
oidcConfig := settings_util.OIDCConfig{
Name: "Okta",
Issuer: oidcServer.URL,
ClientID: "argo-cd",
ClientSecret: "$oidc.okta.clientSecret",
}
oidcConfigString, err := yaml.Marshal(oidcConfig)
// override required oidc config fields but keep other configs as passed in
additionalOIDCConfig.Name = "Okta"
additionalOIDCConfig.Issuer = oidcServer.URL
additionalOIDCConfig.ClientID = "argo-cd"
additionalOIDCConfig.ClientSecret = "$oidc.okta.clientSecret"
oidcConfigString, err := yaml.Marshal(additionalOIDCConfig)
require.NoError(t, err)
cm.Data["oidc.config"] = string(oidcConfigString)
// Avoid bothering with certs for local tests.
Expand All @@ -589,9 +590,109 @@ connectors:
argoCDOpts.DexServerAddr = ts.URL
}
argocd = NewServer(context.Background(), argoCDOpts)
var err error
argocd.ssoClientApp, err = oidc.NewClientApp(argocd.settings, argocd.DexServerAddr, argocd.DexTLSConfig, argocd.BaseHRef, cache.NewInMemoryCache(24*time.Hour))
require.NoError(t, err)
return argocd, oidcServer.URL
}

func TestGetClaims(t *testing.T) {

defaultExpiry := jwt.NewNumericDate(time.Now().Add(time.Hour * 24))
defaultExpiryUnix := float64(defaultExpiry.Unix())

type testData struct {
test string
claims jwt.MapClaims
expectedErrorContains string
expectedClaims jwt.MapClaims
expectNewToken bool
additionalOIDCConfig settings_util.OIDCConfig
}
var tests = []testData{
{
test: "GetClaims",
claims: jwt.MapClaims{
"aud": "argo-cd",
"exp": defaultExpiry,
"sub": "randomUser",
},
expectedErrorContains: "",
expectedClaims: jwt.MapClaims{
"aud": "argo-cd",
"exp": defaultExpiryUnix,
"sub": "randomUser",
},
expectNewToken: false,
additionalOIDCConfig: settings_util.OIDCConfig{},
},
{
// note: a passing test with user info groups can never be achieved since the user never logged in properly
// therefore the oidcClient's cache contains no accessToken for the user info endpoint
// and since the oidcClient cache is unexported (for good reasons) we can't mock this behaviour
test: "GetClaimsWithUserInfoGroupsEnabled",
claims: jwt.MapClaims{
"aud": common.ArgoCDClientAppID,
"exp": defaultExpiry,
"sub": "randomUser",
},
expectedErrorContains: "invalid session",
expectedClaims: jwt.MapClaims{
"aud": common.ArgoCDClientAppID,
"exp": defaultExpiryUnix,
"sub": "randomUser",
},
expectNewToken: false,
additionalOIDCConfig: settings_util.OIDCConfig{
EnableUserInfoGroups: true,
UserInfoPath: "/userinfo",
UserInfoCacheExpiration: "5m",
},
},
}

for _, testData := range tests {
testDataCopy := testData

t.Run(testDataCopy.test, func(t *testing.T) {
t.Parallel()

// Must be declared here to avoid race.
ctx := context.Background() //nolint:ineffassign,staticcheck

argocd, oidcURL := getTestServer(t, false, true, false, testDataCopy.additionalOIDCConfig)

// create new JWT and store it on the context to simulate an incoming request
testDataCopy.claims["iss"] = oidcURL
testDataCopy.expectedClaims["iss"] = oidcURL
token := jwt.NewWithClaims(jwt.SigningMethodRS512, testDataCopy.claims)
key, err := jwt.ParseRSAPrivateKeyFromPEM(testutil.PrivateKey)
require.NoError(t, err)
tokenString, err := token.SignedString(key)
require.NoError(t, err)
ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(apiclient.MetaDataTokenKey, tokenString))

gotClaims, newToken, err := argocd.getClaims(ctx)

// Note: testutil.oidcMockHandler currently doesn't implement reissuing expired tokens
// so newToken will always be empty
if testDataCopy.expectNewToken {
assert.NotEmpty(t, newToken)
}
if testDataCopy.expectedClaims == nil {
assert.Nil(t, gotClaims)
} else {
assert.Equal(t, testDataCopy.expectedClaims, gotClaims)
}
if testDataCopy.expectedErrorContains != "" {
assert.ErrorContains(t, err, testDataCopy.expectedErrorContains, "getClaims should have thrown an error and return an error")
} else {
assert.NoError(t, err)
}
})
}
}

func TestAuthenticate_3rd_party_JWTs(t *testing.T) {
// Marshaling single strings to strings is typical, so we test for this relatively common behavior.
jwt.MarshalSingleStringAsArray = false
Expand Down Expand Up @@ -723,7 +824,7 @@ func TestAuthenticate_3rd_party_JWTs(t *testing.T) {
// Must be declared here to avoid race.
ctx := context.Background() //nolint:ineffassign,staticcheck

argocd, oidcURL := getTestServer(t, testDataCopy.anonymousEnabled, true, testDataCopy.useDex)
argocd, oidcURL := getTestServer(t, testDataCopy.anonymousEnabled, true, testDataCopy.useDex, settings_util.OIDCConfig{})

if testDataCopy.useDex {
testDataCopy.claims.Issuer = fmt.Sprintf("%s/api/dex", oidcURL)
Expand Down Expand Up @@ -779,7 +880,7 @@ func TestAuthenticate_no_request_metadata(t *testing.T) {
t.Run(testDataCopy.test, func(t *testing.T) {
t.Parallel()

argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true)
argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true, settings_util.OIDCConfig{})
ctx := context.Background()

ctx, err := argocd.Authenticate(ctx)
Expand Down Expand Up @@ -825,7 +926,7 @@ func TestAuthenticate_no_SSO(t *testing.T) {
// Must be declared here to avoid race.
ctx := context.Background() //nolint:ineffassign,staticcheck

argocd, dexURL := getTestServer(t, testDataCopy.anonymousEnabled, false, true)
argocd, dexURL := getTestServer(t, testDataCopy.anonymousEnabled, false, true, settings_util.OIDCConfig{})
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{Issuer: fmt.Sprintf("%s/api/dex", dexURL)})
tokenString, err := token.SignedString([]byte("key"))
require.NoError(t, err)
Expand Down Expand Up @@ -933,7 +1034,7 @@ func TestAuthenticate_bad_request_metadata(t *testing.T) {
// Must be declared here to avoid race.
ctx := context.Background() //nolint:ineffassign,staticcheck

argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true)
argocd, _ := getTestServer(t, testDataCopy.anonymousEnabled, true, true, settings_util.OIDCConfig{})
ctx = metadata.NewIncomingContext(context.Background(), testDataCopy.metadata)

ctx, err := argocd.Authenticate(ctx)
Expand Down
4 changes: 4 additions & 0 deletions util/cache/inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ func NewInMemoryCache(expiration time.Duration) *InMemoryCache {
}
}

func init() {
gob.Register([]interface{}{})
}

// compile-time validation of adherance of the CacheClient contract
var _ CacheClient = &InMemoryCache{}

Expand Down
Loading

0 comments on commit b82ed12

Please sign in to comment.