Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing the user attribute handling for implicit flow #2641

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,20 @@ public JWTClaimsSet handleCustomClaims(JWTClaimsSet.Builder builder, OAuthTokenR
public JWTClaimsSet handleCustomClaims(JWTClaimsSet.Builder builder, OAuthAuthzReqMessageContext request)
throws IdentityOAuth2Exception {

// TODO : Implement this method for implicit flow and hybrid flow.
return builder.build();
/*
Handling the user attributes for the access token. There is no requirement of the consent
to manage user attributes for the access token.
*/
String clientId = request.getAuthorizationReqDTO().getConsumerKey();
String spTenantDomain = getServiceProviderTenantDomain(request);
AuthenticatedUser authenticatedUser = request.getAuthorizationReqDTO().getUser();

Map<String, Object> claims = getAccessTokenUserClaims(authenticatedUser, clientId, spTenantDomain);
if (claims == null || claims.isEmpty()) {
return builder.build();
}
Map<String, Object> filteredClaims = handleClaimsFormat(claims, clientId, spTenantDomain);
return setClaimsToJwtClaimSet(builder, filteredClaims);
}

private Map<String, Object> getAccessTokenUserClaims(AuthenticatedUser authenticatedUser, String clientId,
Expand Down Expand Up @@ -298,4 +310,21 @@ private String getServiceProviderTenantDomain(OAuthTokenReqMessageContext reques
}
return spTenantDomain;
}

/**
* Retrieves the service provider tenant domain from the OAuthAuthzReqMessageContext.
*
* @param requestMsgCtx OAuthAuthzReqMessageContext containing the tenant domain.
* @return The tenant domain.
*/
private String getServiceProviderTenantDomain(OAuthAuthzReqMessageContext requestMsgCtx) {

String spTenantDomain = (String) requestMsgCtx.getProperty(MultitenantConstants.TENANT_DOMAIN);
// There are certain flows where tenant domain is not added as a message context property.
if (spTenantDomain == null) {
spTenantDomain = requestMsgCtx.getAuthorizationReqDTO().getTenantDomain();
}
return spTenantDomain;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@
import org.wso2.carbon.identity.oauth.config.OAuthServerConfiguration;
import org.wso2.carbon.identity.oauth.dao.OAuthAppDO;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception;
import org.wso2.carbon.identity.oauth2.authz.OAuthAuthzReqMessageContext;
import org.wso2.carbon.identity.oauth2.dao.OAuthTokenPersistenceFactory;
import org.wso2.carbon.identity.oauth2.dto.OAuth2AccessTokenReqDTO;
import org.wso2.carbon.identity.oauth2.dto.OAuth2AuthorizeReqDTO;
import org.wso2.carbon.identity.oauth2.internal.OAuth2ServiceComponentHolder;
import org.wso2.carbon.identity.oauth2.token.OAuthTokenReqMessageContext;
import org.wso2.carbon.identity.oauth2.util.AuthzUtil;
Expand Down Expand Up @@ -364,6 +366,42 @@ public void testHandleCustomClaimsWithAddressClaimForOAuthTokenReqMsgContext() t
}
}

@Test
public void testHandleCustomClaimsForOAuthAuthzReqMsgContext() throws Exception {

try (MockedStatic<JDBCPersistenceManager> jdbcPersistenceManager = mockStatic(JDBCPersistenceManager.class);
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration = mockStatic(
OAuthServerConfiguration.class);
MockedStatic<ClaimMetadataHandler> claimMetadataHandler = mockStatic(ClaimMetadataHandler.class)) {
OAuthServerConfiguration oauthServerConfigurationMock = mock(OAuthServerConfiguration.class);
oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance)
.thenReturn(oauthServerConfigurationMock);
try (MockedStatic<OAuth2Util> oAuth2Util = mockStatic(OAuth2Util.class);
MockedStatic<AuthzUtil> authzUtil = mockStatic(AuthzUtil.class);
MockedStatic<IdentityTenantUtil> identityTenantUtil = mockStatic(IdentityTenantUtil.class);
MockedStatic<IdentityUtil> identityUtil = mockStatic(IdentityUtil.class, Mockito.CALLS_REAL_METHODS)) {
identityUtil.when(IdentityUtil::isGroupsVsRolesSeparationImprovementsEnabled).thenReturn(true);
oAuth2Util.when(() -> OAuth2Util.getAppInformationByClientId(any(), any())).thenReturn(
getoAuthAppDO(jwtAccessTokenClaims));
Map<String, String> mappings = getOIDCtoLocalClaimsMapping();
claimMetadataHandler.when(ClaimMetadataHandler::getInstance).thenReturn(mockClaimMetadataHandler);
lenient().when(mockClaimMetadataHandler.getMappingsMapFromOtherDialectToCarbon(
anyString(), isNull(), anyString(), anyBoolean())).thenReturn(mappings);
OAuthAuthzReqMessageContext requestMsgCtx = getOAuthAuthzReqMessageContextForLocalUser();
mockApplicationManagementService();
authzUtil.when(() -> AuthzUtil.getUserRoles(any(), anyString())).thenReturn(new ArrayList<>());
UserRealm userRealm = getUserRealmWithUserClaims(USER_CLAIMS_MAP);
mockUserRealm(requestMsgCtx.getAuthorizationReqDTO().getUser().toString(), userRealm,
identityTenantUtil);
JWTClaimsSet.Builder jwtClaimsSetBuilder = new JWTClaimsSet.Builder();
JWTClaimsSet jwtClaimsSet = getJwtClaimSet(jwtClaimsSetBuilder, requestMsgCtx, jdbcPersistenceManager,
oAuthServerConfiguration);
assertNotNull(jwtClaimsSet);
assertFalse(jwtClaimsSet.getClaims().isEmpty());
}
}
}

private static Map<String, String> getOIDCtoLocalClaimsMapping() {

Map<String, String> mappings = new HashMap<>();
Expand Down Expand Up @@ -405,6 +443,26 @@ private JWTClaimsSet getJwtClaimSet(JWTClaimsSet.Builder jwtClaimsSetBuilder,
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration)
throws IdentityOAuth2Exception {

JWTAccessTokenOIDCClaimsHandler jWTAccessTokenOIDCClaimsHandler =
getJwtAccessTokenOIDCClaimsHandler(jdbcPersistenceManager, oAuthServerConfiguration);
return jWTAccessTokenOIDCClaimsHandler.handleCustomClaims(jwtClaimsSetBuilder, requestMsgCtx);
}

private JWTClaimsSet getJwtClaimSet(JWTClaimsSet.Builder jwtClaimsSetBuilder,
OAuthAuthzReqMessageContext requestMsgCtx,
MockedStatic<JDBCPersistenceManager> jdbcPersistenceManager,
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration)
throws IdentityOAuth2Exception {

JWTAccessTokenOIDCClaimsHandler jWTAccessTokenOIDCClaimsHandler =
getJwtAccessTokenOIDCClaimsHandler(jdbcPersistenceManager, oAuthServerConfiguration);
return jWTAccessTokenOIDCClaimsHandler.handleCustomClaims(jwtClaimsSetBuilder, requestMsgCtx);
}

private JWTAccessTokenOIDCClaimsHandler getJwtAccessTokenOIDCClaimsHandler(
MockedStatic<JDBCPersistenceManager> jdbcPersistenceManager,
MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration) {

OAuthServerConfiguration mockOAuthServerConfiguration = mock(OAuthServerConfiguration.class);
oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance).thenReturn(mockOAuthServerConfiguration);
DataSource dataSource = mock(DataSource.class);
Expand Down Expand Up @@ -432,9 +490,7 @@ private JWTClaimsSet getJwtClaimSet(JWTClaimsSet.Builder jwtClaimsSetBuilder,
jdbcPersistenceManager.when(JDBCPersistenceManager::getInstance).thenReturn(mockJdbcPersistenceManager);
lenient().when(mockJdbcPersistenceManager.getDataSource()).thenReturn(dataSource);

JWTAccessTokenOIDCClaimsHandler jWTAccessTokenOIDCClaimsHandler =
new JWTAccessTokenOIDCClaimsHandler();
return jWTAccessTokenOIDCClaimsHandler.handleCustomClaims(jwtClaimsSetBuilder, requestMsgCtx);
return new JWTAccessTokenOIDCClaimsHandler();
}

private OAuthTokenReqMessageContext getTokenReqMessageContextForLocalUser() {
Expand Down Expand Up @@ -515,4 +571,13 @@ private void setPrivateField(Object object, String fieldName, Object value) thro
field.set(object, value);
}

private OAuthAuthzReqMessageContext getOAuthAuthzReqMessageContextForLocalUser() {

OAuth2AuthorizeReqDTO oAuth2AuthorizeReqDTO = new OAuth2AuthorizeReqDTO();
oAuth2AuthorizeReqDTO.setTenantDomain(TENANT_DOMAIN);
oAuth2AuthorizeReqDTO.setConsumerKey(DUMMY_CLIENT_ID);
oAuth2AuthorizeReqDTO.setUser(getDefaultAuthenticatedLocalUser());

return new OAuthAuthzReqMessageContext(oAuth2AuthorizeReqDTO);
}
}
Loading