From 6867a697107c11c3262d83f6f91f1ab8585b9ab3 Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath Date: Tue, 17 Dec 2024 23:49:53 +0530 Subject: [PATCH] add tests --- .../endpoint/authz/OAuth2AuthzEndpoint.java | 2 + .../JWTAccessTokenOIDCClaimsHandler.java | 168 +++++++++++------- .../JWTAccessTokenOIDCClaimsHandlerTest.java | 116 +++++++++++- 3 files changed, 222 insertions(+), 64 deletions(-) diff --git a/components/org.wso2.carbon.identity.oauth.endpoint/src/main/java/org/wso2/carbon/identity/oauth/endpoint/authz/OAuth2AuthzEndpoint.java b/components/org.wso2.carbon.identity.oauth.endpoint/src/main/java/org/wso2/carbon/identity/oauth/endpoint/authz/OAuth2AuthzEndpoint.java index 07cdc75a72..d30ebf4cdc 100644 --- a/components/org.wso2.carbon.identity.oauth.endpoint/src/main/java/org/wso2/carbon/identity/oauth/endpoint/authz/OAuth2AuthzEndpoint.java +++ b/components/org.wso2.carbon.identity.oauth.endpoint/src/main/java/org/wso2/carbon/identity/oauth/endpoint/authz/OAuth2AuthzEndpoint.java @@ -1420,6 +1420,8 @@ private void addToAuthenticationResultDetailsToOAuthMessage(OAuthMessage oAuthMe authnResult.getProperty(FrameworkConstants.AnalyticsAttributes.SESSION_ID)); // Adding federated tokens come with the authentication result of the authorization call. addFederatedTokensToSessionCache(oAuthMessage, authnResult); + // Adding federated user claims come with the authentication result to resolve access token claims in + // federated flow. addUnfilteredFederatedUserClaimsToSessionCache(oAuthMessage, authnResult); } diff --git a/components/org.wso2.carbon.identity.oauth/src/main/java/org/wso2/carbon/identity/openidconnect/JWTAccessTokenOIDCClaimsHandler.java b/components/org.wso2.carbon.identity.oauth/src/main/java/org/wso2/carbon/identity/openidconnect/JWTAccessTokenOIDCClaimsHandler.java index c430157634..08854a2353 100644 --- a/components/org.wso2.carbon.identity.oauth/src/main/java/org/wso2/carbon/identity/openidconnect/JWTAccessTokenOIDCClaimsHandler.java +++ b/components/org.wso2.carbon.identity.oauth/src/main/java/org/wso2/carbon/identity/openidconnect/JWTAccessTokenOIDCClaimsHandler.java @@ -114,11 +114,11 @@ public JWTClaimsSet handleCustomClaims(JWTClaimsSet.Builder builder, OAuthAuthzR } /** - * Get response map. + * Get user claims in OIDC dialect. * - * @param requestMsgCtx Token request message context - * @return Mapped claimed - * @throws OAuthSystemException + * @param requestMsgCtx OAuthTokenReqMessageContext + * @return User claims in OIDC dialect + * @throws IdentityOAuth2Exception IdentityOAuth2Exception */ private Map getUserClaimsInOIDCDialect(OAuthTokenReqMessageContext requestMsgCtx) throws IdentityOAuth2Exception { @@ -147,6 +147,7 @@ private Map getUserClaimsInOIDCDialect(OAuthTokenReqMessageConte } else { // Get claim map from the cached attributes userClaimsInOIDCDialect = getOIDCClaimsFromUserAttributes(userAttributes, requestMsgCtx); + // Since this is a federated flow, we need to get the federated user attributes as well. Map federatedUserAttributes = getCachedUserAttributes(requestMsgCtx, true); Map federatedUserClaimsInOIDCDialect = getOIDCClaimsFromFederatedUserAttributes(federatedUserAttributes, requestMsgCtx); @@ -162,6 +163,14 @@ private Map getUserClaimsInOIDCDialect(OAuthTokenReqMessageConte } } + /** + * Filter claims with allowed access token claims + * + * @param userClaimsInOIDCDialect User claims in OIDC dialect + * @param requestMsgCtx OAuthTokenReqMessageContext + * @return Filtered claims + * @throws IdentityOAuth2Exception IdentityOAuth2Exception + */ private Map filterClaims(Map userClaimsInOIDCDialect, OAuthTokenReqMessageContext requestMsgCtx) throws IdentityOAuth2Exception { @@ -178,6 +187,14 @@ private Map filterClaims(Map userClaimsInOIDCDia return handleClaimsFormat(claims, clientId, spTenantDomain); } + /** + * Filter claims with allowed access token claims + * + * @param userClaimsInOIDCDialect User claims in OIDC dialect + * @param authzReqMessageContext OAuthAuthzReqMessageContext + * @return Filtered claims + * @throws IdentityOAuth2Exception IdentityOAuth2Exception + */ private Map filterClaims(Map userClaimsInOIDCDialect, OAuthAuthzReqMessageContext authzReqMessageContext) throws IdentityOAuth2Exception { @@ -194,6 +211,13 @@ private Map filterClaims(Map userClaimsInOIDCDia return handleClaimsFormat(claims, clientId, spTenantDomain); } + /** + * Get claims for local user form userstore. + * + * @param requestMsgCtx OAuthTokenReqMessageContext + * @return Local user claims + * @throws IdentityOAuth2Exception IdentityOAuth2Exception + */ private Map retrieveClaimsForLocalUser(OAuthTokenReqMessageContext requestMsgCtx) throws IdentityOAuth2Exception { @@ -202,7 +226,7 @@ private Map retrieveClaimsForLocalUser(OAuthTokenReqMessageConte String clientId = requestMsgCtx.getOauth2AccessTokenReqDTO().getClientId(); AuthenticatedUser authenticatedUser = requestMsgCtx.getAuthorizedUser(); - return getUserClaimsInOIDCDialect(spTenantDomain, clientId, authenticatedUser); + return getLocalUserClaimsInOIDCDialect(spTenantDomain, clientId, authenticatedUser); } catch (UserStoreException | IdentityApplicationManagementException | IdentityException | OrganizationManagementException e) { if (FrameworkUtils.isContinueOnClaimHandlingErrorAllowed()) { @@ -284,6 +308,13 @@ private Map getOIDCClaimsFromFederatedUserAttributes(Map getUserClaimsInOIDCDialect(OAuthAuthzReqMessageContext authzReqMessageContext) throws IdentityOAuth2Exception { @@ -308,11 +339,12 @@ private Map getUserClaimsInOIDCDialect(OAuthAuthzReqMessageConte } } else { userClaimsInOIDCDialect = getOIDCClaimMapFromUserAttributes(userAttributes); - Map unfilteredUserAttributes = + // Since this is a federated flow we are retrieving the federated user attributes as well. + Map federatedUserAttributes = getUserAttributesCachedAgainstToken(getAccessToken(authzReqMessageContext), true); Map federatedUserClaimsInOIDCDialect = - getUserClaimsInOIDCDialectFromFederatedUserAttributed(authzReqMessageContext - .getAuthorizationReqDTO().getTenantDomain(), unfilteredUserAttributes); + getUserClaimsInOIDCDialectFromFederatedUserAttributes(authzReqMessageContext + .getAuthorizationReqDTO().getTenantDomain(), federatedUserAttributes); userClaimsInOIDCDialect.putAll(federatedUserClaimsInOIDCDialect); } return filterClaims(userClaimsInOIDCDialect, authzReqMessageContext); @@ -412,6 +444,13 @@ entries in the persistence layer(SessionStore). return userAttributes; } + /** + * Get user attributes cached against the authorization code. + * + * @param authorizationCode Authorization Code + * @param fetchFederatedUserAttr Flag to indicate whether to fetch federated user attributes. + * @return User attributes cached against the authorization code + */ private Map getUserAttributesCachedAgainstAuthorizationCode(String authorizationCode, boolean fetchFederatedUserAttr) { @@ -423,6 +462,13 @@ private Map getUserAttributesCachedAgainstAuthorizationCod return userAttributes; } + /** + * GEt user attributes cached against the device code. + * + * @param deviceCode Device Code + * @param fetchFederatedUserAttributes Flag to indicate whether to fetch federated user attributes. + * @return User attributes cached against the device code + */ private Map getUserAttributesCachedAgainstDeviceCode(String deviceCode, boolean fetchFederatedUserAttributes) { @@ -463,6 +509,13 @@ private Map getUserAttributesFromCacheUsingCode(String aut return cacheEntry == null ? new HashMap<>() : cacheEntry.getUserAttributes(); } + /** + * Get user attributes cached against the access token. + * + * @param accessToken Access Token + * @param fetchFederatedUserAttributes Flag to indicate whether to fetch federated user attributes. + * @return User attributes cached against the access token + */ private Map getUserAttributesCachedAgainstToken(String accessToken, boolean fetchFederatedUserAttributes) { Map userAttributes = Collections.emptyMap(); @@ -473,6 +526,13 @@ private Map getUserAttributesCachedAgainstToken(String acc return userAttributes; } + /** + * Get claims for local user form userstore. + * + * @param authzReqMessageContext OAuthAuthzReqMessageContext + * @return Local user claims + * @throws IdentityOAuth2Exception IdentityOAuth2Exception + */ private Map retrieveClaimsForLocalUser(OAuthAuthzReqMessageContext authzReqMessageContext) throws IdentityOAuth2Exception { @@ -481,7 +541,7 @@ private Map retrieveClaimsForLocalUser(OAuthAuthzReqMessageConte String clientId = authzReqMessageContext.getAuthorizationReqDTO().getConsumerKey(); AuthenticatedUser authenticatedUser = authzReqMessageContext.getAuthorizationReqDTO().getUser(); - return getUserClaimsInOIDCDialect(spTenantDomain, clientId, authenticatedUser); + return getLocalUserClaimsInOIDCDialect(spTenantDomain, clientId, authenticatedUser); } catch (UserStoreException | IdentityApplicationManagementException | IdentityException | OrganizationManagementException e) { if (FrameworkUtils.isContinueOnClaimHandlingErrorAllowed()) { @@ -522,12 +582,13 @@ private Map retrieveClaimsForFederatedUser(OAuthAuthzReqMessageC return userClaimsMappedToOIDCDialect; } Map userAttributes = authenticatedUser.getUserAttributes(); - Map unfilteredFederatedUserAttributes = + // Since this is a federated flow we are retrieving the federated user attributes as well. + Map federatedUserAttributes = oAuth2AuthorizeReqDTO.getUnfilteredFederatedUserAttributes(); userClaimsMappedToOIDCDialect = getOIDCClaimMapFromUserAttributes(userAttributes); Map federatedUserClaimsMappedToOIDCDialect = - getUserClaimsInOIDCDialectFromFederatedUserAttributed(authzReqMessageContext.getAuthorizationReqDTO() - .getTenantDomain(), unfilteredFederatedUserAttributes); + getUserClaimsInOIDCDialectFromFederatedUserAttributes(authzReqMessageContext.getAuthorizationReqDTO() + .getTenantDomain(), federatedUserAttributes); userClaimsMappedToOIDCDialect.putAll(federatedUserClaimsMappedToOIDCDialect); return userClaimsMappedToOIDCDialect; } @@ -549,9 +610,17 @@ private Map getOIDCClaimMapFromUserAttributes(Map getUserClaimsInOIDCDialectFromFederatedUserAttributed(String spTenantDomain, - Map - unfilteredUserAttributed) + /** + * Get user claims in OIDC claim dialect from federated user attributes. + * + * @param spTenantDomain Service Provider Tenant Domain + * @param federatedUserAttr Federated User Attributes + * @return User claims in OIDC dialect + * @throws IdentityOAuth2Exception Identity OAuth2 Exception + */ + private static Map getUserClaimsInOIDCDialectFromFederatedUserAttributes(String spTenantDomain, + Map + federatedUserAttr) throws IdentityOAuth2Exception { // Retrieve OIDC to Local Claim Mappings. @@ -564,8 +633,8 @@ private static Map getUserClaimsInOIDCDialectFromFederatedUserAt } // Get user claims in OIDC dialect. Map userClaimsInOidcDialect = new HashMap<>(); - if (MapUtils.isNotEmpty(unfilteredUserAttributed)) { - for (Map.Entry userAttribute : unfilteredUserAttributed.entrySet()) { + if (MapUtils.isNotEmpty(federatedUserAttr)) { + for (Map.Entry userAttribute : federatedUserAttr.entrySet()) { ClaimMapping claimMapping = userAttribute.getKey(); String claimValue = userAttribute.getValue(); if (oidcToLocalClaimMappings.containsValue(claimMapping.getLocalClaim().getClaimUri())) { @@ -590,36 +659,19 @@ private static Map getUserClaimsInOIDCDialectFromFederatedUserAt } /** - * Get user claims in OIDC claim dialect. + * Get user claims in OIDC claim dialect from userstore. * - * @param oidcToLocalClaimMappings OIDC dialect to Local dialect claim mappings - * @param userClaims User claims in local dialect - * @return Map of user claim values in OIDC dialect. + * @param spTenantDomain Service Provider Tenant Domain + * @param clientId Client Id + * @param authenticatedUser Authenticated User + * @return User claims in OIDC dialect + * @throws IdentityApplicationManagementException Identity Application Management Exception + * @throws IdentityException Identity Exception + * @throws UserStoreException User Store Exception + * @throws OrganizationManagementException Organization Management Exception */ - private static Map getUserClaimsInOidcDialect(Map oidcToLocalClaimMappings, - Map userClaims) { - - Map userClaimsInOidcDialect = new HashMap<>(); - if (MapUtils.isNotEmpty(userClaims)) { - // Map<"email", "http://wso2.org/claims/emailaddress"> - for (Map.Entry claimMapping : oidcToLocalClaimMappings.entrySet()) { - String claimValue = userClaims.get(claimMapping.getValue()); - if (claimValue != null) { - String oidcClaimUri = claimMapping.getKey(); - userClaimsInOidcDialect.put(oidcClaimUri, claimValue); - if (log.isDebugEnabled() && - IdentityUtil.isTokenLoggable(IdentityConstants.IdentityTokens.USER_CLAIMS)) { - log.debug("Mapped claim: key - " + oidcClaimUri + " value - " + claimValue); - } - } - } - } - - return userClaimsInOidcDialect; - } - - private Map getUserClaimsInOIDCDialect(String spTenantDomain, String clientId, - AuthenticatedUser authenticatedUser) + private Map getLocalUserClaimsInOIDCDialect(String spTenantDomain, String clientId, + AuthenticatedUser authenticatedUser) throws IdentityApplicationManagementException, IdentityException, UserStoreException, OrganizationManagementException { @@ -644,23 +696,6 @@ private Map getUserClaimsInOIDCDialect(String spTenantDomain, St return OIDCClaimUtil.getUserClaimsInOIDCDialect(serviceProvider, authenticatedUser, localClaimURIs); } - /** - * Get claims map. - * - * @param userAttributes User Attributes - * @return User attribute map - */ - private Map getClaimMapFromUserAttributes(Map userAttributes) { - - Map claims = new HashMap<>(); - if (isNotEmpty(userAttributes)) { - for (Map.Entry entry : userAttributes.entrySet()) { - claims.put(entry.getKey().getRemoteClaim().getClaimUri(), entry.getValue()); - } - } - return claims; - } - /** * Get user attribute cached against the access token. * @@ -688,14 +723,17 @@ private Map getUserAttributesFromCacheUsingToken(String ac } private String getAuthorizationCode(OAuthTokenReqMessageContext requestMsgCtx) { + return (String) requestMsgCtx.getProperty(AUTHZ_CODE); } private String getAccessToken(OAuthAuthzReqMessageContext authzReqMessageContext) { + return (String) authzReqMessageContext.getProperty(ACCESS_TOKEN); } private String getAccessToken(OAuthTokenReqMessageContext requestMsgCtx) { + return (String) requestMsgCtx.getProperty(ACCESS_TOKEN); } @@ -966,6 +1004,12 @@ private boolean isOrganizationSsoUserSwitchingOrganization(AuthenticatedUser aut (accessingOrganization); } + /** + * Check whether grant type is organization switch grant. + * + * @param requestMsgCtx OAuthTokenReqMessageContext + * @return true if grant type is organization switch grant. + */ private boolean isOrganizationSwitchGrantType(OAuthTokenReqMessageContext requestMsgCtx) { return StringUtils.equals(requestMsgCtx.getOauth2AccessTokenReqDTO().getGrantType(), diff --git a/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/openidconnect/JWTAccessTokenOIDCClaimsHandlerTest.java b/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/openidconnect/JWTAccessTokenOIDCClaimsHandlerTest.java index 408459fdc0..51d2c80fa7 100644 --- a/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/openidconnect/JWTAccessTokenOIDCClaimsHandlerTest.java +++ b/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/openidconnect/JWTAccessTokenOIDCClaimsHandlerTest.java @@ -32,6 +32,7 @@ import org.mockito.Mockito; import org.mockito.stubbing.Answer; import org.mockito.testng.MockitoTestNGListener; +import org.testng.Assert; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.BeforeTest; @@ -43,6 +44,7 @@ import org.wso2.carbon.context.PrivilegedCarbonContext; import org.wso2.carbon.identity.application.authentication.framework.model.AuthenticatedUser; import org.wso2.carbon.identity.application.authentication.framework.util.FrameworkUtils; +import org.wso2.carbon.identity.application.common.model.ClaimMapping; import org.wso2.carbon.identity.application.common.model.PermissionsAndRoleConfig; import org.wso2.carbon.identity.application.common.model.ServiceProvider; import org.wso2.carbon.identity.application.mgt.ApplicationManagementService; @@ -51,15 +53,20 @@ import org.wso2.carbon.identity.core.persistence.JDBCPersistenceManager; import org.wso2.carbon.identity.core.util.IdentityTenantUtil; import org.wso2.carbon.identity.core.util.IdentityUtil; +import org.wso2.carbon.identity.oauth.cache.AuthorizationGrantCache; +import org.wso2.carbon.identity.oauth.cache.AuthorizationGrantCacheEntry; +import org.wso2.carbon.identity.oauth.cache.AuthorizationGrantCacheKey; import org.wso2.carbon.identity.oauth.config.OAuthServerConfiguration; import org.wso2.carbon.identity.oauth.dao.OAuthAppDO; import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception; +import org.wso2.carbon.identity.oauth2.TestConstants; import org.wso2.carbon.identity.oauth2.authz.OAuthAuthzReqMessageContext; import org.wso2.carbon.identity.oauth2.dao.OAuthTokenPersistenceFactory; import org.wso2.carbon.identity.oauth2.dto.OAuth2AccessTokenReqDTO; import org.wso2.carbon.identity.oauth2.dto.OAuth2AuthorizeReqDTO; import org.wso2.carbon.identity.oauth2.internal.OAuth2ServiceComponentHolder; import org.wso2.carbon.identity.oauth2.token.OAuthTokenReqMessageContext; +import org.wso2.carbon.identity.oauth2.token.handlers.grant.saml.SAML2BearerGrantHandlerTest; import org.wso2.carbon.identity.oauth2.util.AuthzUtil; import org.wso2.carbon.identity.oauth2.util.OAuth2Util; import org.wso2.carbon.identity.openidconnect.dao.CacheBackedScopeClaimMappingDAOImpl; @@ -246,11 +253,19 @@ public void testHandleCustomClaimsWithoutRegisteredOIDCClaimsForOAuthTokenReqMsg oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance) .thenReturn(oauthServerConfigurationMock); try (MockedStatic oAuth2Util = mockStatic(OAuth2Util.class); - MockedStatic claimMetadataHandler = mockStatic(ClaimMetadataHandler.class)) { - claimMetadataHandler.when(ClaimMetadataHandler::getInstance).thenReturn(this.mockClaimMetadataHandler); + MockedStatic claimMetadataHandler = mockStatic(ClaimMetadataHandler.class); + MockedStatic identityTenantUtil = mockStatic(IdentityTenantUtil.class); + MockedStatic authzUtil = mockStatic(AuthzUtil.class)) { + claimMetadataHandler.when(ClaimMetadataHandler::getInstance).thenReturn(mockClaimMetadataHandler); + lenient().when(mockClaimMetadataHandler.getMappingsMapFromOtherDialectToCarbon( + anyString(), isNull(), anyString(), anyBoolean())).thenReturn(new HashMap<>()); oAuth2Util.when(() -> OAuth2Util.getAppInformationByClientId(any(), any())).thenReturn( getoAuthAppDO(jwtAccessTokenClaims)); + mockApplicationManagementService(); OAuthTokenReqMessageContext requestMsgCtx = getTokenReqMessageContextForLocalUser(); + authzUtil.when(() -> AuthzUtil.getUserRoles(any(), anyString())).thenReturn(new ArrayList<>()); + UserRealm userRealm = getUserRealmWithUserClaims(USER_CLAIMS_MAP); + mockUserRealm(requestMsgCtx.getAuthorizedUser().toString(), userRealm, identityTenantUtil); JWTClaimsSet.Builder jwtClaimsSetBuilder = new JWTClaimsSet.Builder(); JWTClaimsSet jwtClaimsSet = getJwtClaimSet(jwtClaimsSetBuilder, requestMsgCtx, jdbcPersistenceManager, oAuthServerConfiguration); @@ -402,6 +417,71 @@ public void testHandleCustomClaimsForOAuthAuthzReqMsgContext() throws Exception } } + @Test + public void testHandleClaimsForOAuthTokenReqMessageContextWithAuthorizationCode() throws Exception { + + try (MockedStatic jdbcPersistenceManager = mockStatic(JDBCPersistenceManager.class); + MockedStatic oAuthServerConfiguration = mockStatic( + OAuthServerConfiguration.class); + MockedStatic claimMetadataHandler = mockStatic(ClaimMetadataHandler.class)) { + OAuthServerConfiguration oauthServerConfigurationMock = mock(OAuthServerConfiguration.class); + oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance) + .thenReturn(oauthServerConfigurationMock); + try (MockedStatic oAuth2Util = mockStatic(OAuth2Util.class); + MockedStatic authzUtil = mockStatic(AuthzUtil.class); + MockedStatic identityTenantUtil = mockStatic(IdentityTenantUtil.class); + MockedStatic identityUtil = mockStatic(IdentityUtil.class, Mockito.CALLS_REAL_METHODS)) { + MockedStatic authorizationGrantCache = + mockStatic(AuthorizationGrantCache.class); + identityUtil.when(IdentityUtil::isGroupsVsRolesSeparationImprovementsEnabled).thenReturn(true); + authzUtil.when(() -> AuthzUtil.getUserRoles(any(), anyString())).thenReturn(new ArrayList<>()); + oAuth2Util.when(() -> OAuth2Util.getAppInformationByClientId(any(), any())).thenReturn( + getoAuthAppDO(jwtAccessTokenClaims)); + Map mappings = getOIDCtoLocalClaimsMapping(); + claimMetadataHandler.when(ClaimMetadataHandler::getInstance).thenReturn(mockClaimMetadataHandler); + lenient().when(mockClaimMetadataHandler.getMappingsMapFromOtherDialectToCarbon( + anyString(), isNull(), anyString(), anyBoolean())).thenReturn(mappings); + Map userAttributes = new HashMap<>(); + OAuthTokenReqMessageContext requestMsgCtx = getTokenReqMessageContextForFederatedUser(userAttributes); + requestMsgCtx.addProperty("AuthorizationCode", "dummyAuthorizationCode"); + Map federatedUserAttributes = new HashMap<>(); + federatedUserAttributes.put(SAML2BearerGrantHandlerTest.buildClaimMapping(LOCAL_COUNTRY_CLAIM_URI), + TestConstants.CLAIM_VALUE1); + federatedUserAttributes.put(SAML2BearerGrantHandlerTest.buildClaimMapping(LOCAL_EMAIL_CLAIM_URI), + TestConstants.CLAIM_VALUE2); + AuthorizationGrantCacheEntry authorizationGrantCacheEntry = new + AuthorizationGrantCacheEntry(); + authorizationGrantCacheEntry.setUnfilteredFederatedUserAttributes(federatedUserAttributes); + mockAuthorizationGrantCache(authorizationGrantCacheEntry, authorizationGrantCache); + + UserRealm userRealm = getUserRealmWithUserClaims(USER_CLAIMS_MAP); + mockUserRealm(requestMsgCtx.getAuthorizedUser().toString(), userRealm, identityTenantUtil); + JWTClaimsSet.Builder jwtClaimsSetBuilder = new JWTClaimsSet.Builder(); + JWTClaimsSet jwtClaimsSet = getJwtClaimSet(jwtClaimsSetBuilder, requestMsgCtx, jdbcPersistenceManager, + oAuthServerConfiguration); + assertNotNull(jwtClaimsSet, "JWT Custom claim handling failed."); + assertFalse(jwtClaimsSet.getClaims().isEmpty(), "JWT custom claim handling failed"); + Assert.assertEquals(jwtClaimsSet.getClaims().size(), 2, + "Expected custom claims are not set."); + Assert.assertEquals(jwtClaimsSet.getClaim("email"), TestConstants.CLAIM_VALUE2, + "OIDC claim email is not added with the JWT token"); + } + } + } + + private void mockAuthorizationGrantCache(AuthorizationGrantCacheEntry authorizationGrantCacheEntry, + MockedStatic authorizationGrantCache) { + + AuthorizationGrantCache mockAuthorizationGrantCache = mock(AuthorizationGrantCache.class); + + if (authorizationGrantCacheEntry == null) { + authorizationGrantCacheEntry = mock(AuthorizationGrantCacheEntry.class); + } + authorizationGrantCache.when(AuthorizationGrantCache::getInstance).thenReturn(mockAuthorizationGrantCache); + lenient().when(mockAuthorizationGrantCache.getValueFromCacheByCode(any(AuthorizationGrantCacheKey.class))). + thenReturn(authorizationGrantCacheEntry); + } + private static Map getOIDCtoLocalClaimsMapping() { Map mappings = new HashMap<>(); @@ -504,6 +584,38 @@ private OAuthTokenReqMessageContext getTokenReqMessageContextForLocalUser() { return requestMsgCtx; } + /** + * To get token request message context for federates user. + * + * @param userAttributes Relevant user attributes need to be added to authenticates user. + * @return relevant token request context for federated authenticated user. + */ + private OAuthTokenReqMessageContext getTokenReqMessageContextForFederatedUser(Map userAttributes) { + + OAuth2AccessTokenReqDTO accessTokenReqDTO = new OAuth2AccessTokenReqDTO(); + accessTokenReqDTO.setTenantDomain(TENANT_DOMAIN); + accessTokenReqDTO.setClientId(DUMMY_CLIENT_ID); + OAuthTokenReqMessageContext requestMsgCtx = new OAuthTokenReqMessageContext(accessTokenReqDTO); + requestMsgCtx.addProperty(MultitenantConstants.TENANT_DOMAIN, TENANT_DOMAIN); + AuthenticatedUser authenticatedUser = getDefaultAuthenticatedUserFederatedUser(); + + if (userAttributes != null) { + authenticatedUser.setUserAttributes(userAttributes); + } + requestMsgCtx.setAuthorizedUser(authenticatedUser); + return requestMsgCtx; + } + + private AuthenticatedUser getDefaultAuthenticatedUserFederatedUser() { + + AuthenticatedUser authenticatedUser = new AuthenticatedUser(); + authenticatedUser.setUserName(USER_NAME); + authenticatedUser.setUserId(StringUtils.EMPTY); + authenticatedUser.setFederatedUser(true); + return authenticatedUser; + } + private AuthenticatedUser getDefaultAuthenticatedLocalUser() { AuthenticatedUser authenticatedUser = new AuthenticatedUser();