Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shashimalcse committed Dec 17, 2024
1 parent 63c5648 commit 6867a69
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,8 @@ private void addToAuthenticationResultDetailsToOAuthMessage(OAuthMessage oAuthMe
authnResult.getProperty(FrameworkConstants.AnalyticsAttributes.SESSION_ID));
// Adding federated tokens come with the authentication result of the authorization call.
addFederatedTokensToSessionCache(oAuthMessage, authnResult);
// Adding federated user claims come with the authentication result to resolve access token claims in
// federated flow.
addUnfilteredFederatedUserClaimsToSessionCache(oAuthMessage, authnResult);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ public JWTClaimsSet handleCustomClaims(JWTClaimsSet.Builder builder, OAuthAuthzR
}

/**
* Get response map.
* Get user claims in OIDC dialect.
*
* @param requestMsgCtx Token request message context
* @return Mapped claimed
* @throws OAuthSystemException
* @param requestMsgCtx OAuthTokenReqMessageContext
* @return User claims in OIDC dialect
* @throws IdentityOAuth2Exception IdentityOAuth2Exception
*/
private Map<String, Object> getUserClaimsInOIDCDialect(OAuthTokenReqMessageContext requestMsgCtx)
throws IdentityOAuth2Exception {
Expand Down Expand Up @@ -147,6 +147,7 @@ private Map<String, Object> getUserClaimsInOIDCDialect(OAuthTokenReqMessageConte
} else {
// Get claim map from the cached attributes
userClaimsInOIDCDialect = getOIDCClaimsFromUserAttributes(userAttributes, requestMsgCtx);
// Since this is a federated flow, we need to get the federated user attributes as well.
Map<ClaimMapping, String> federatedUserAttributes = getCachedUserAttributes(requestMsgCtx, true);
Map<String, Object> federatedUserClaimsInOIDCDialect =
getOIDCClaimsFromFederatedUserAttributes(federatedUserAttributes, requestMsgCtx);
Expand All @@ -162,6 +163,14 @@ private Map<String, Object> getUserClaimsInOIDCDialect(OAuthTokenReqMessageConte
}
}

/**
* Filter claims with allowed access token claims
*
* @param userClaimsInOIDCDialect User claims in OIDC dialect
* @param requestMsgCtx OAuthTokenReqMessageContext
* @return Filtered claims
* @throws IdentityOAuth2Exception IdentityOAuth2Exception
*/
private Map<String, Object> filterClaims(Map<String, Object> userClaimsInOIDCDialect,
OAuthTokenReqMessageContext requestMsgCtx) throws IdentityOAuth2Exception {

Expand All @@ -178,6 +187,14 @@ private Map<String, Object> filterClaims(Map<String, Object> userClaimsInOIDCDia
return handleClaimsFormat(claims, clientId, spTenantDomain);
}

/**
* Filter claims with allowed access token claims
*
* @param userClaimsInOIDCDialect User claims in OIDC dialect
* @param authzReqMessageContext OAuthAuthzReqMessageContext
* @return Filtered claims
* @throws IdentityOAuth2Exception IdentityOAuth2Exception
*/
private Map<String, Object> filterClaims(Map<String, Object> userClaimsInOIDCDialect,
OAuthAuthzReqMessageContext authzReqMessageContext)
throws IdentityOAuth2Exception {
Expand All @@ -194,6 +211,13 @@ private Map<String, Object> filterClaims(Map<String, Object> userClaimsInOIDCDia
return handleClaimsFormat(claims, clientId, spTenantDomain);
}

/**
* Get claims for local user form userstore.
*
* @param requestMsgCtx OAuthTokenReqMessageContext
* @return Local user claims
* @throws IdentityOAuth2Exception IdentityOAuth2Exception
*/
private Map<String, Object> retrieveClaimsForLocalUser(OAuthTokenReqMessageContext requestMsgCtx)
throws IdentityOAuth2Exception {

Expand All @@ -202,7 +226,7 @@ private Map<String, Object> retrieveClaimsForLocalUser(OAuthTokenReqMessageConte
String clientId = requestMsgCtx.getOauth2AccessTokenReqDTO().getClientId();
AuthenticatedUser authenticatedUser = requestMsgCtx.getAuthorizedUser();

return getUserClaimsInOIDCDialect(spTenantDomain, clientId, authenticatedUser);
return getLocalUserClaimsInOIDCDialect(spTenantDomain, clientId, authenticatedUser);
} catch (UserStoreException | IdentityApplicationManagementException | IdentityException |
OrganizationManagementException e) {
if (FrameworkUtils.isContinueOnClaimHandlingErrorAllowed()) {
Expand Down Expand Up @@ -284,6 +308,13 @@ private Map<String, Object> getOIDCClaimsFromFederatedUserAttributes(Map<ClaimMa
return OIDCClaimUtil.getMergedUserClaimsInOIDCDialect(spTenantDomain, userClaimsInOidcDialect);
}

/**
* Get user claims in OIDC dialect.
*
* @param authzReqMessageContext OAuthAuthzReqMessageContext
* @return User claims in OIDC dialect
* @throws IdentityOAuth2Exception IdentityOAuth2Exception
*/
private Map<String, Object> getUserClaimsInOIDCDialect(OAuthAuthzReqMessageContext authzReqMessageContext)
throws IdentityOAuth2Exception {

Expand All @@ -308,11 +339,12 @@ private Map<String, Object> getUserClaimsInOIDCDialect(OAuthAuthzReqMessageConte
}
} else {
userClaimsInOIDCDialect = getOIDCClaimMapFromUserAttributes(userAttributes);
Map<ClaimMapping, String> unfilteredUserAttributes =
// Since this is a federated flow we are retrieving the federated user attributes as well.
Map<ClaimMapping, String> federatedUserAttributes =
getUserAttributesCachedAgainstToken(getAccessToken(authzReqMessageContext), true);
Map<String, Object> federatedUserClaimsInOIDCDialect =
getUserClaimsInOIDCDialectFromFederatedUserAttributed(authzReqMessageContext
.getAuthorizationReqDTO().getTenantDomain(), unfilteredUserAttributes);
getUserClaimsInOIDCDialectFromFederatedUserAttributes(authzReqMessageContext
.getAuthorizationReqDTO().getTenantDomain(), federatedUserAttributes);
userClaimsInOIDCDialect.putAll(federatedUserClaimsInOIDCDialect);
}
return filterClaims(userClaimsInOIDCDialect, authzReqMessageContext);
Expand Down Expand Up @@ -412,6 +444,13 @@ entries in the persistence layer(SessionStore).
return userAttributes;
}

/**
* Get user attributes cached against the authorization code.
*
* @param authorizationCode Authorization Code
* @param fetchFederatedUserAttr Flag to indicate whether to fetch federated user attributes.
* @return User attributes cached against the authorization code
*/
private Map<ClaimMapping, String> getUserAttributesCachedAgainstAuthorizationCode(String authorizationCode,
boolean fetchFederatedUserAttr) {

Expand All @@ -423,6 +462,13 @@ private Map<ClaimMapping, String> getUserAttributesCachedAgainstAuthorizationCod
return userAttributes;
}

/**
* GEt user attributes cached against the device code.
*
* @param deviceCode Device Code
* @param fetchFederatedUserAttributes Flag to indicate whether to fetch federated user attributes.
* @return User attributes cached against the device code
*/
private Map<ClaimMapping, String> getUserAttributesCachedAgainstDeviceCode(String deviceCode,
boolean fetchFederatedUserAttributes) {

Expand Down Expand Up @@ -463,6 +509,13 @@ private Map<ClaimMapping, String> getUserAttributesFromCacheUsingCode(String aut
return cacheEntry == null ? new HashMap<>() : cacheEntry.getUserAttributes();
}

/**
* Get user attributes cached against the access token.
*
* @param accessToken Access Token
* @param fetchFederatedUserAttributes Flag to indicate whether to fetch federated user attributes.
* @return User attributes cached against the access token
*/
private Map<ClaimMapping, String> getUserAttributesCachedAgainstToken(String accessToken,
boolean fetchFederatedUserAttributes) {
Map<ClaimMapping, String> userAttributes = Collections.emptyMap();
Expand All @@ -473,6 +526,13 @@ private Map<ClaimMapping, String> getUserAttributesCachedAgainstToken(String acc
return userAttributes;
}

/**
* Get claims for local user form userstore.
*
* @param authzReqMessageContext OAuthAuthzReqMessageContext
* @return Local user claims
* @throws IdentityOAuth2Exception IdentityOAuth2Exception
*/
private Map<String, Object> retrieveClaimsForLocalUser(OAuthAuthzReqMessageContext authzReqMessageContext)
throws IdentityOAuth2Exception {

Expand All @@ -481,7 +541,7 @@ private Map<String, Object> retrieveClaimsForLocalUser(OAuthAuthzReqMessageConte
String clientId = authzReqMessageContext.getAuthorizationReqDTO().getConsumerKey();
AuthenticatedUser authenticatedUser = authzReqMessageContext.getAuthorizationReqDTO().getUser();

return getUserClaimsInOIDCDialect(spTenantDomain, clientId, authenticatedUser);
return getLocalUserClaimsInOIDCDialect(spTenantDomain, clientId, authenticatedUser);
} catch (UserStoreException | IdentityApplicationManagementException | IdentityException |
OrganizationManagementException e) {
if (FrameworkUtils.isContinueOnClaimHandlingErrorAllowed()) {
Expand Down Expand Up @@ -522,12 +582,13 @@ private Map<String, Object> retrieveClaimsForFederatedUser(OAuthAuthzReqMessageC
return userClaimsMappedToOIDCDialect;
}
Map<ClaimMapping, String> userAttributes = authenticatedUser.getUserAttributes();
Map<ClaimMapping, String> unfilteredFederatedUserAttributes =
// Since this is a federated flow we are retrieving the federated user attributes as well.
Map<ClaimMapping, String> federatedUserAttributes =
oAuth2AuthorizeReqDTO.getUnfilteredFederatedUserAttributes();
userClaimsMappedToOIDCDialect = getOIDCClaimMapFromUserAttributes(userAttributes);
Map<String, Object> federatedUserClaimsMappedToOIDCDialect =
getUserClaimsInOIDCDialectFromFederatedUserAttributed(authzReqMessageContext.getAuthorizationReqDTO()
.getTenantDomain(), unfilteredFederatedUserAttributes);
getUserClaimsInOIDCDialectFromFederatedUserAttributes(authzReqMessageContext.getAuthorizationReqDTO()
.getTenantDomain(), federatedUserAttributes);
userClaimsMappedToOIDCDialect.putAll(federatedUserClaimsMappedToOIDCDialect);
return userClaimsMappedToOIDCDialect;
}
Expand All @@ -549,9 +610,17 @@ private Map<String, Object> getOIDCClaimMapFromUserAttributes(Map<ClaimMapping,
return claims;
}

private static Map<String, Object> getUserClaimsInOIDCDialectFromFederatedUserAttributed(String spTenantDomain,
Map<ClaimMapping, String>
unfilteredUserAttributed)
/**
* Get user claims in OIDC claim dialect from federated user attributes.
*
* @param spTenantDomain Service Provider Tenant Domain
* @param federatedUserAttr Federated User Attributes
* @return User claims in OIDC dialect
* @throws IdentityOAuth2Exception Identity OAuth2 Exception
*/
private static Map<String, Object> getUserClaimsInOIDCDialectFromFederatedUserAttributes(String spTenantDomain,
Map<ClaimMapping, String>
federatedUserAttr)
throws IdentityOAuth2Exception {

// Retrieve OIDC to Local Claim Mappings.
Expand All @@ -564,8 +633,8 @@ private static Map<String, Object> getUserClaimsInOIDCDialectFromFederatedUserAt
}
// Get user claims in OIDC dialect.
Map<String, Object> userClaimsInOidcDialect = new HashMap<>();
if (MapUtils.isNotEmpty(unfilteredUserAttributed)) {
for (Map.Entry<ClaimMapping, String> userAttribute : unfilteredUserAttributed.entrySet()) {
if (MapUtils.isNotEmpty(federatedUserAttr)) {
for (Map.Entry<ClaimMapping, String> userAttribute : federatedUserAttr.entrySet()) {
ClaimMapping claimMapping = userAttribute.getKey();
String claimValue = userAttribute.getValue();
if (oidcToLocalClaimMappings.containsValue(claimMapping.getLocalClaim().getClaimUri())) {
Expand All @@ -590,36 +659,19 @@ private static Map<String, Object> getUserClaimsInOIDCDialectFromFederatedUserAt
}

/**
* Get user claims in OIDC claim dialect.
* Get user claims in OIDC claim dialect from userstore.
*
* @param oidcToLocalClaimMappings OIDC dialect to Local dialect claim mappings
* @param userClaims User claims in local dialect
* @return Map of user claim values in OIDC dialect.
* @param spTenantDomain Service Provider Tenant Domain
* @param clientId Client Id
* @param authenticatedUser Authenticated User
* @return User claims in OIDC dialect
* @throws IdentityApplicationManagementException Identity Application Management Exception
* @throws IdentityException Identity Exception
* @throws UserStoreException User Store Exception
* @throws OrganizationManagementException Organization Management Exception
*/
private static Map<String, Object> getUserClaimsInOidcDialect(Map<String, String> oidcToLocalClaimMappings,
Map<String, String> userClaims) {

Map<String, Object> userClaimsInOidcDialect = new HashMap<>();
if (MapUtils.isNotEmpty(userClaims)) {
// Map<"email", "http://wso2.org/claims/emailaddress">
for (Map.Entry<String, String> claimMapping : oidcToLocalClaimMappings.entrySet()) {
String claimValue = userClaims.get(claimMapping.getValue());
if (claimValue != null) {
String oidcClaimUri = claimMapping.getKey();
userClaimsInOidcDialect.put(oidcClaimUri, claimValue);
if (log.isDebugEnabled() &&
IdentityUtil.isTokenLoggable(IdentityConstants.IdentityTokens.USER_CLAIMS)) {
log.debug("Mapped claim: key - " + oidcClaimUri + " value - " + claimValue);
}
}
}
}

return userClaimsInOidcDialect;
}

private Map<String, Object> getUserClaimsInOIDCDialect(String spTenantDomain, String clientId,
AuthenticatedUser authenticatedUser)
private Map<String, Object> getLocalUserClaimsInOIDCDialect(String spTenantDomain, String clientId,
AuthenticatedUser authenticatedUser)
throws IdentityApplicationManagementException, IdentityException, UserStoreException,
OrganizationManagementException {

Expand All @@ -644,23 +696,6 @@ private Map<String, Object> getUserClaimsInOIDCDialect(String spTenantDomain, St
return OIDCClaimUtil.getUserClaimsInOIDCDialect(serviceProvider, authenticatedUser, localClaimURIs);
}

/**
* Get claims map.
*
* @param userAttributes User Attributes
* @return User attribute map
*/
private Map<String, Object> getClaimMapFromUserAttributes(Map<ClaimMapping, String> userAttributes) {

Map<String, Object> claims = new HashMap<>();
if (isNotEmpty(userAttributes)) {
for (Map.Entry<ClaimMapping, String> entry : userAttributes.entrySet()) {
claims.put(entry.getKey().getRemoteClaim().getClaimUri(), entry.getValue());
}
}
return claims;
}

/**
* Get user attribute cached against the access token.
*
Expand Down Expand Up @@ -688,14 +723,17 @@ private Map<ClaimMapping, String> getUserAttributesFromCacheUsingToken(String ac
}

private String getAuthorizationCode(OAuthTokenReqMessageContext requestMsgCtx) {

return (String) requestMsgCtx.getProperty(AUTHZ_CODE);
}

private String getAccessToken(OAuthAuthzReqMessageContext authzReqMessageContext) {

return (String) authzReqMessageContext.getProperty(ACCESS_TOKEN);
}

private String getAccessToken(OAuthTokenReqMessageContext requestMsgCtx) {

return (String) requestMsgCtx.getProperty(ACCESS_TOKEN);
}

Expand Down Expand Up @@ -966,6 +1004,12 @@ private boolean isOrganizationSsoUserSwitchingOrganization(AuthenticatedUser aut
(accessingOrganization);
}

/**
* Check whether grant type is organization switch grant.
*
* @param requestMsgCtx OAuthTokenReqMessageContext
* @return true if grant type is organization switch grant.
*/
private boolean isOrganizationSwitchGrantType(OAuthTokenReqMessageContext requestMsgCtx) {

return StringUtils.equals(requestMsgCtx.getOauth2AccessTokenReqDTO().getGrantType(),
Expand Down
Loading

0 comments on commit 6867a69

Please sign in to comment.