Skip to content

Commit

Permalink
Merge pull request #2641 from indeewari/user_attributes_for_access_to…
Browse files Browse the repository at this point in the history
…ken_in_implici_hybrid

Implementing the user attribute handling for implicit flow
  • Loading branch information
sadilchamishka authored Dec 5, 2024
2 parents 9c32672 + 0a9904a commit f9ec15d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,20 @@ public JWTClaimsSet handleCustomClaims(JWTClaimsSet.Builder builder, OAuthTokenR
public JWTClaimsSet handleCustomClaims(JWTClaimsSet.Builder builder, OAuthAuthzReqMessageContext request)
throws IdentityOAuth2Exception {

// TODO : Implement this method for implicit flow and hybrid flow.
return builder.build();
/*
Handling the user attributes for the access token. There is no requirement of the consent
to manage user attributes for the access token.
*/
String clientId = request.getAuthorizationReqDTO().getConsumerKey();
String spTenantDomain = getServiceProviderTenantDomain(request);
AuthenticatedUser authenticatedUser = request.getAuthorizationReqDTO().getUser();

Map<String, Object> claims = getAccessTokenUserClaims(authenticatedUser, clientId, spTenantDomain);
if (claims == null || claims.isEmpty()) {
return builder.build();
}
Map<String, Object> filteredClaims = handleClaimsFormat(claims, clientId, spTenantDomain);
return setClaimsToJwtClaimSet(builder, filteredClaims);
}

private Map<String, Object> getAccessTokenUserClaims(AuthenticatedUser authenticatedUser, String clientId,
Expand Down Expand Up @@ -298,4 +310,21 @@ private String getServiceProviderTenantDomain(OAuthTokenReqMessageContext reques
}
return spTenantDomain;
}

/**
* Retrieves the service provider tenant domain from the OAuthAuthzReqMessageContext.
*
* @param requestMsgCtx OAuthAuthzReqMessageContext containing the tenant domain.
* @return The tenant domain.
*/
private String getServiceProviderTenantDomain(OAuthAuthzReqMessageContext requestMsgCtx) {

String spTenantDomain = (String) requestMsgCtx.getProperty(MultitenantConstants.TENANT_DOMAIN);
// There are certain flows where tenant domain is not added as a message context property.
if (spTenantDomain == null) {
spTenantDomain = requestMsgCtx.getAuthorizationReqDTO().getTenantDomain();
}
return spTenantDomain;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@
import org.wso2.carbon.identity.oauth.config.OAuthServerConfiguration;
import org.wso2.carbon.identity.oauth.dao.OAuthAppDO;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception;
import org.wso2.carbon.identity.oauth2.authz.OAuthAuthzReqMessageContext;
import org.wso2.carbon.identity.oauth2.dao.OAuthTokenPersistenceFactory;
import org.wso2.carbon.identity.oauth2.dto.OAuth2AccessTokenReqDTO;
import org.wso2.carbon.identity.oauth2.dto.OAuth2AuthorizeReqDTO;
import org.wso2.carbon.identity.oauth2.internal.OAuth2ServiceComponentHolder;
import org.wso2.carbon.identity.oauth2.token.OAuthTokenReqMessageContext;
import org.wso2.carbon.identity.oauth2.util.AuthzUtil;
Expand Down Expand Up @@ -364,6 +366,42 @@ public void testHandleCustomClaimsWithAddressClaimForOAuthTokenReqMsgContext() t
}
}

@Test
public void testHandleCustomClaimsForOAuthAuthzReqMsgContext() throws Exception {

try (MockedStatic<JDBCPersistenceManager> jdbcPersistenceManager = mockStatic(JDBCPersistenceManager.class);
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration = mockStatic(
OAuthServerConfiguration.class);
MockedStatic<ClaimMetadataHandler> claimMetadataHandler = mockStatic(ClaimMetadataHandler.class)) {
OAuthServerConfiguration oauthServerConfigurationMock = mock(OAuthServerConfiguration.class);
oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance)
.thenReturn(oauthServerConfigurationMock);
try (MockedStatic<OAuth2Util> oAuth2Util = mockStatic(OAuth2Util.class);
MockedStatic<AuthzUtil> authzUtil = mockStatic(AuthzUtil.class);
MockedStatic<IdentityTenantUtil> identityTenantUtil = mockStatic(IdentityTenantUtil.class);
MockedStatic<IdentityUtil> identityUtil = mockStatic(IdentityUtil.class, Mockito.CALLS_REAL_METHODS)) {
identityUtil.when(IdentityUtil::isGroupsVsRolesSeparationImprovementsEnabled).thenReturn(true);
oAuth2Util.when(() -> OAuth2Util.getAppInformationByClientId(any(), any())).thenReturn(
getoAuthAppDO(jwtAccessTokenClaims));
Map<String, String> mappings = getOIDCtoLocalClaimsMapping();
claimMetadataHandler.when(ClaimMetadataHandler::getInstance).thenReturn(mockClaimMetadataHandler);
lenient().when(mockClaimMetadataHandler.getMappingsMapFromOtherDialectToCarbon(
anyString(), isNull(), anyString(), anyBoolean())).thenReturn(mappings);
OAuthAuthzReqMessageContext requestMsgCtx = getOAuthAuthzReqMessageContextForLocalUser();
mockApplicationManagementService();
authzUtil.when(() -> AuthzUtil.getUserRoles(any(), anyString())).thenReturn(new ArrayList<>());
UserRealm userRealm = getUserRealmWithUserClaims(USER_CLAIMS_MAP);
mockUserRealm(requestMsgCtx.getAuthorizationReqDTO().getUser().toString(), userRealm,
identityTenantUtil);
JWTClaimsSet.Builder jwtClaimsSetBuilder = new JWTClaimsSet.Builder();
JWTClaimsSet jwtClaimsSet = getJwtClaimSet(jwtClaimsSetBuilder, requestMsgCtx, jdbcPersistenceManager,
oAuthServerConfiguration);
assertNotNull(jwtClaimsSet);
assertFalse(jwtClaimsSet.getClaims().isEmpty());
}
}
}

private static Map<String, String> getOIDCtoLocalClaimsMapping() {

Map<String, String> mappings = new HashMap<>();
Expand Down Expand Up @@ -405,6 +443,26 @@ private JWTClaimsSet getJwtClaimSet(JWTClaimsSet.Builder jwtClaimsSetBuilder,
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration)
throws IdentityOAuth2Exception {

JWTAccessTokenOIDCClaimsHandler jWTAccessTokenOIDCClaimsHandler =
getJwtAccessTokenOIDCClaimsHandler(jdbcPersistenceManager, oAuthServerConfiguration);
return jWTAccessTokenOIDCClaimsHandler.handleCustomClaims(jwtClaimsSetBuilder, requestMsgCtx);
}

private JWTClaimsSet getJwtClaimSet(JWTClaimsSet.Builder jwtClaimsSetBuilder,
OAuthAuthzReqMessageContext requestMsgCtx,
MockedStatic<JDBCPersistenceManager> jdbcPersistenceManager,
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration)
throws IdentityOAuth2Exception {

JWTAccessTokenOIDCClaimsHandler jWTAccessTokenOIDCClaimsHandler =
getJwtAccessTokenOIDCClaimsHandler(jdbcPersistenceManager, oAuthServerConfiguration);
return jWTAccessTokenOIDCClaimsHandler.handleCustomClaims(jwtClaimsSetBuilder, requestMsgCtx);
}

private JWTAccessTokenOIDCClaimsHandler getJwtAccessTokenOIDCClaimsHandler(
MockedStatic<JDBCPersistenceManager> jdbcPersistenceManager,
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration) {

OAuthServerConfiguration mockOAuthServerConfiguration = mock(OAuthServerConfiguration.class);
oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance).thenReturn(mockOAuthServerConfiguration);
DataSource dataSource = mock(DataSource.class);
Expand Down Expand Up @@ -432,9 +490,7 @@ private JWTClaimsSet getJwtClaimSet(JWTClaimsSet.Builder jwtClaimsSetBuilder,
jdbcPersistenceManager.when(JDBCPersistenceManager::getInstance).thenReturn(mockJdbcPersistenceManager);
lenient().when(mockJdbcPersistenceManager.getDataSource()).thenReturn(dataSource);

JWTAccessTokenOIDCClaimsHandler jWTAccessTokenOIDCClaimsHandler =
new JWTAccessTokenOIDCClaimsHandler();
return jWTAccessTokenOIDCClaimsHandler.handleCustomClaims(jwtClaimsSetBuilder, requestMsgCtx);
return new JWTAccessTokenOIDCClaimsHandler();
}

private OAuthTokenReqMessageContext getTokenReqMessageContextForLocalUser() {
Expand Down Expand Up @@ -515,4 +571,13 @@ private void setPrivateField(Object object, String fieldName, Object value) thro
field.set(object, value);
}

private OAuthAuthzReqMessageContext getOAuthAuthzReqMessageContextForLocalUser() {

OAuth2AuthorizeReqDTO oAuth2AuthorizeReqDTO = new OAuth2AuthorizeReqDTO();
oAuth2AuthorizeReqDTO.setTenantDomain(TENANT_DOMAIN);
oAuth2AuthorizeReqDTO.setConsumerKey(DUMMY_CLIENT_ID);
oAuth2AuthorizeReqDTO.setUser(getDefaultAuthenticatedLocalUser());

return new OAuthAuthzReqMessageContext(oAuth2AuthorizeReqDTO);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@
<class name="org.wso2.carbon.identity.openidconnect.OpenIDConnectSystemClaimImplTest"/>
<class name="org.wso2.carbon.identity.openidconnect.OpenIDConnectClaimFilterImplTest"/>
<class name="org.wso2.carbon.identity.openidconnect.util.ClaimHandlerUtilTest"/>
<class name="org.wso2.carbon.identity.openidconnect.JWTAccessTokenOIDCClaimsHandlerTest"/>
</classes>
</test>
</suite>

0 comments on commit f9ec15d

Please sign in to comment.