Skip to content

Commit

Permalink
Merge pull request #2640 from SujanSanjula96/role-v2-patch-token-revo…
Browse files Browse the repository at this point in the history
…cation

Fix unnecessary DB queries when revoking tokens
  • Loading branch information
SujanSanjula96 authored Jan 7, 2025
2 parents f511999 + c14d5ef commit 8649c99
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@
import org.wso2.carbon.identity.application.authentication.framework.model.AuthenticatedUser;
import org.wso2.carbon.identity.application.common.IdentityApplicationManagementException;
import org.wso2.carbon.identity.application.common.model.IdentityProvider;
import org.wso2.carbon.identity.application.common.model.InboundAuthenticationRequestConfig;
import org.wso2.carbon.identity.application.common.model.ServiceProvider;
import org.wso2.carbon.identity.application.common.model.User;
import org.wso2.carbon.identity.application.mgt.ApplicationConstants;
import org.wso2.carbon.identity.application.mgt.ApplicationConstants.StandardInboundProtocols;
import org.wso2.carbon.identity.application.mgt.ApplicationManagementService;
import org.wso2.carbon.identity.base.IdentityConstants;
import org.wso2.carbon.identity.central.log.mgt.utils.LoggerUtils;
Expand Down Expand Up @@ -66,8 +67,7 @@
import org.wso2.carbon.identity.role.v2.mgt.core.RoleConstants;
import org.wso2.carbon.identity.role.v2.mgt.core.RoleManagementService;
import org.wso2.carbon.identity.role.v2.mgt.core.exception.IdentityRoleManagementException;
import org.wso2.carbon.identity.role.v2.mgt.core.model.AssociatedApplication;
import org.wso2.carbon.identity.role.v2.mgt.core.model.Role;
import org.wso2.carbon.identity.role.v2.mgt.core.model.RoleBasicInfo;
import org.wso2.carbon.idp.mgt.IdentityProviderManagementException;
import org.wso2.carbon.user.api.Tenant;
import org.wso2.carbon.user.core.UserStoreException;
Expand All @@ -86,7 +86,6 @@
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
Expand Down Expand Up @@ -775,51 +774,43 @@ private static AuthenticatedUser buildAuthenticatedUser(UserStoreManager userSto

/**
* Get clientIds of associated application of an application role.
* @param role Role object.
*
* @param role Role basic info object.
* @param authenticatedUser Authenticated user.
* @return Set of clientIds of associated applications.
*/
private static Set<String> getClientIdsOfAssociatedApplications(Role role, AuthenticatedUser authenticatedUser)
private static Optional<String> getClientIdOfAssociatedApplication(RoleBasicInfo role,
AuthenticatedUser authenticatedUser)
throws UserStoreException {

ApplicationManagementService applicationManagementService =
OAuthComponentServiceHolder.getInstance().getApplicationManagementService();
List<String> associatedApplications = role.getAssociatedApplications().stream()
.map(AssociatedApplication::getId).collect(Collectors.toList());
String associatedApplication = role.getAudienceId();
try {
if (authenticatedUser.getUserResidentOrganization() != null) {
List<String> newAssociatedApplications = new ArrayList<>();
for (String app : associatedApplications) {
newAssociatedApplications.add(
SharedAppResolveDAO.getMainApplication(app, authenticatedUser.getAccessingOrganization()));
}
associatedApplications = newAssociatedApplications;
associatedApplication = SharedAppResolveDAO.getMainApplication(
associatedApplication, authenticatedUser.getAccessingOrganization());
}
} catch (IdentityOAuth2Exception e) {
throw new UserStoreException("Error occurred while getting the main applications of the shared apps.", e);
}
Set<String> clientIds = new HashSet<>();
associatedApplications.forEach(associatedApplication -> {
try {
ServiceProvider application = applicationManagementService
.getApplicationByResourceId(associatedApplication, authenticatedUser.getTenantDomain());
if (application == null || application.getInboundAuthenticationConfig() == null) {
return;
}
Arrays.stream(application.getInboundAuthenticationConfig().getInboundAuthenticationRequestConfigs())
.forEach(inboundAuthenticationRequestConfig -> {
if (ApplicationConstants.StandardInboundProtocols.OAUTH2.equals(
inboundAuthenticationRequestConfig.getInboundAuthType())) {
clientIds.add(inboundAuthenticationRequestConfig.getInboundAuthKey());
}
});
} catch (IdentityApplicationManagementException e) {
String errorMessage = "Error occurred while retrieving application of id : " +
associatedApplication;
LOG.error(errorMessage);
try {
ServiceProvider application = applicationManagementService
.getApplicationByResourceId(associatedApplication, authenticatedUser.getTenantDomain());
if (application != null && application.getInboundAuthenticationConfig() != null) {
InboundAuthenticationRequestConfig[] inboundAuthenticationRequestConfigs =
application.getInboundAuthenticationConfig().getInboundAuthenticationRequestConfigs();
return Arrays.stream(inboundAuthenticationRequestConfigs)
.filter(config -> StandardInboundProtocols.OAUTH2.equals(config.getInboundAuthType()))
.map(InboundAuthenticationRequestConfig::getInboundAuthKey)
.findFirst();
}
});
return clientIds;
} catch (IdentityApplicationManagementException e) {
String errorMessage = "Error occurred while retrieving application of id : " +
associatedApplication;
LOG.error(errorMessage);
}
return Optional.empty();
}

private static Set<String> filterClientIdsWithOrganizationAudience(List<String> clientIds, String tenantDomain) {
Expand Down Expand Up @@ -849,14 +840,14 @@ private static Set<String> filterClientIdsWithOrganizationAudience(List<String>
* @param tenantDomain Tenant domain.
* @return Role.
*/
private static Role getRole(String roleId, String tenantDomain) throws UserStoreException {
private static RoleBasicInfo getRoleBasicInfo(String roleId, String tenantDomain) throws UserStoreException {

try {
RoleManagementService roleV2ManagementService =
OAuthComponentServiceHolder.getInstance().getRoleV2ManagementService();
return roleV2ManagementService.getRole(roleId, tenantDomain);
return roleV2ManagementService.getRoleBasicInfoById(roleId, tenantDomain);
} catch (IdentityRoleManagementException e) {
String errorMessage = "Error occurred while retrieving role of id : " + roleId;
String errorMessage = "Error occurred while retrieving basic role info of id : " + roleId;
throw new UserStoreException(errorMessage, e);
}
}
Expand Down Expand Up @@ -1011,18 +1002,19 @@ public static boolean revokeTokens(String username, UserStoreManager userStoreMa
}

// Get details about the role to identify the audience and associated applications.
Set<String> clientIds = null;
Role role = null;
Set<String> clientIds = new HashSet<>();;
RoleBasicInfo role = null;
boolean getClientIdsFromUser = false;
if (roleId != null) {
role = getRole(roleId, IdentityTenantUtil.getTenantDomain(userStoreManager.getTenantId()));
role = getRoleBasicInfo(roleId, tenantDomain);
if (role != null && RoleConstants.APPLICATION.equals(role.getAudience())) {
// Get clientIds of associated applications for the specific application role.
if (LOG.isDebugEnabled()) {
LOG.debug("Get clientIds of associated applications for the application role: "
+ role.getName());
}
clientIds = getClientIdsOfAssociatedApplications(role, authenticatedUser);
getClientIdOfAssociatedApplication(role, authenticatedUser)
.ifPresent(clientIds::add);
} else {
// Get all the distinct client Ids authorized by this user since this is an organization role.
if (LOG.isDebugEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* Copyright (c) 2017, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
* Copyright (c) 2017-2025, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 Inc. licenses this file to you under the Apache License,
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
Expand All @@ -11,23 +11,58 @@
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.wso2.carbon.identity.oauth;

import org.apache.commons.lang.StringUtils;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import org.wso2.carbon.identity.application.authentication.framework.model.AuthenticatedUser;
import org.wso2.carbon.identity.application.common.model.InboundAuthenticationConfig;
import org.wso2.carbon.identity.application.common.model.InboundAuthenticationRequestConfig;
import org.wso2.carbon.identity.application.common.model.ServiceProvider;
import org.wso2.carbon.identity.application.common.model.User;
import org.wso2.carbon.identity.application.mgt.ApplicationConstants;
import org.wso2.carbon.identity.application.mgt.ApplicationManagementService;
import org.wso2.carbon.identity.common.testng.WithCarbonHome;
import org.wso2.carbon.identity.common.testng.WithRealmService;
import org.wso2.carbon.identity.oauth.cache.CacheEntry;
import org.wso2.carbon.identity.oauth.cache.OAuthCache;
import org.wso2.carbon.identity.oauth.cache.OAuthCacheKey;

import org.wso2.carbon.identity.oauth.internal.OAuthComponentServiceHolder;
import org.wso2.carbon.identity.oauth2.dao.AccessTokenDAO;
import org.wso2.carbon.identity.oauth2.dao.OAuthTokenPersistenceFactory;
import org.wso2.carbon.identity.oauth2.model.AccessTokenDO;
import org.wso2.carbon.identity.oauth2.util.OAuth2Util;
import org.wso2.carbon.identity.organization.management.service.util.OrganizationManagementUtil;
import org.wso2.carbon.identity.role.v2.mgt.core.RoleConstants;
import org.wso2.carbon.identity.role.v2.mgt.core.RoleManagementService;
import org.wso2.carbon.identity.role.v2.mgt.core.model.RoleBasicInfo;
import org.wso2.carbon.user.api.RealmConfiguration;
import org.wso2.carbon.user.core.UserStoreManager;
import org.wso2.carbon.utils.multitenancy.MultitenantConstants;

import java.util.HashSet;
import java.util.Set;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
Expand All @@ -40,6 +75,37 @@
@WithCarbonHome
@WithRealmService
public class OAuthUtilTest {

@Mock
RoleManagementService roleManagementService;
@Mock
ApplicationManagementService applicationManagementService;

private AutoCloseable closeable;
private MockedStatic<OrganizationManagementUtil> organizationManagementUtil;
private MockedStatic<OAuthComponentServiceHolder> oAuthComponentServiceHolder;
private MockedStatic<OAuth2Util> oAuth2Util;
private MockedStatic<OAuthTokenPersistenceFactory> oAuthTokenPersistenceFactory;

@BeforeMethod
public void setUp() throws Exception {

organizationManagementUtil = mockStatic(OrganizationManagementUtil.class);
oAuthComponentServiceHolder = mockStatic(OAuthComponentServiceHolder.class);
oAuth2Util = mockStatic(OAuth2Util.class);
oAuthTokenPersistenceFactory = mockStatic(OAuthTokenPersistenceFactory.class);
closeable = MockitoAnnotations.openMocks(this);
}

@AfterMethod
public void tearDown() throws Exception {

organizationManagementUtil.close();
oAuthComponentServiceHolder.close();
oAuth2Util.close();
oAuthTokenPersistenceFactory.close();
closeable.close();
}

@DataProvider(name = "testGetAuthenticatedUser")
public Object[][] fullQualifiedUserName() {
Expand Down Expand Up @@ -160,6 +226,71 @@ public void testGetAuthenticatedUserException() throws Exception {
OAuthUtil.getAuthenticatedUser("");
}

@Test
public void testRevokeTokensForApplicationAudienceRoles() throws Exception {

String username = "testUser";
String roleId = "testRoleId";
String roleName = "testRole";
String appId = "testAppId";
String clientId = "testClientId";
String accessToken = "testAccessToken";

UserStoreManager userStoreManager = mock(UserStoreManager.class);
when(userStoreManager.getTenantId()).thenReturn(-1234);
when(userStoreManager.getRealmConfiguration()).thenReturn(mock(RealmConfiguration.class));
when(userStoreManager.getRealmConfiguration().getUserStoreProperty(anyString())).thenReturn("PRIMARY");

when(OrganizationManagementUtil.isOrganization(anyString())).thenReturn(false);
when(OAuth2Util.getTenantId(anyString())).thenReturn(-1234);

OAuthComponentServiceHolder mockOAuthComponentServiceHolder = mock(OAuthComponentServiceHolder.class);
when(OAuthComponentServiceHolder.getInstance()).thenReturn(mockOAuthComponentServiceHolder);

when(mockOAuthComponentServiceHolder.getRoleV2ManagementService()).thenReturn(roleManagementService);
RoleBasicInfo roleBasicInfo = new RoleBasicInfo();
roleBasicInfo.setId(roleId);
roleBasicInfo.setAudience(RoleConstants.APPLICATION);
roleBasicInfo.setAudienceId(appId);
roleBasicInfo.setName(roleName);
when(roleManagementService.getRoleBasicInfoById(roleId, MultitenantConstants.SUPER_TENANT_DOMAIN_NAME))
.thenReturn(roleBasicInfo);

when(mockOAuthComponentServiceHolder.getApplicationManagementService())
.thenReturn(applicationManagementService);
ServiceProvider serviceProvider = new ServiceProvider();
InboundAuthenticationConfig inboundAuthenticationConfig = new InboundAuthenticationConfig();
InboundAuthenticationRequestConfig[] inboundAuthenticationRequestConfigs =
new InboundAuthenticationRequestConfig[1];
InboundAuthenticationRequestConfig inboundAuthenticationRequestConfig =
new InboundAuthenticationRequestConfig();
inboundAuthenticationRequestConfig.setInboundAuthKey(clientId);
inboundAuthenticationRequestConfig.setInboundAuthType(ApplicationConstants.StandardInboundProtocols.OAUTH2);
inboundAuthenticationRequestConfigs[0] = inboundAuthenticationRequestConfig;
inboundAuthenticationConfig.setInboundAuthenticationRequestConfigs(inboundAuthenticationRequestConfigs);
serviceProvider.setInboundAuthenticationConfig(inboundAuthenticationConfig);
when(applicationManagementService.getApplicationByResourceId(
appId, MultitenantConstants.SUPER_TENANT_DOMAIN_NAME)).thenReturn(serviceProvider);

OAuthTokenPersistenceFactory mockOAuthTokenPersistenceFactory = mock(OAuthTokenPersistenceFactory.class);
when(OAuthTokenPersistenceFactory.getInstance()).thenReturn(mockOAuthTokenPersistenceFactory);
AccessTokenDAO mockAccessTokenDAO = mock(AccessTokenDAO.class);
when(mockOAuthTokenPersistenceFactory.getAccessTokenDAO()).thenReturn(mockAccessTokenDAO);
Set<AccessTokenDO> accessTokens = new HashSet<>();
AccessTokenDO accessTokenDO = new AccessTokenDO();
accessTokenDO.setAccessToken(accessToken);
accessTokenDO.setConsumerKey(clientId);
accessTokenDO.setScope(new String[]{"default"});
accessTokenDO.setAuthzUser(new AuthenticatedUser());
accessTokens.add(accessTokenDO);
when(mockAccessTokenDAO.getAccessTokens(anyString(),
any(AuthenticatedUser.class), nullable(String.class), anyBoolean())).thenReturn(accessTokens);

boolean result = OAuthUtil.revokeTokens(username, userStoreManager, roleId);
verify(mockAccessTokenDAO, times(1)).revokeAccessTokens(any(), anyBoolean());
assertTrue(result, "Token revocation failed.");
}

private OAuthCache getOAuthCache(OAuthCacheKey oAuthCacheKey) {


Expand Down

0 comments on commit 8649c99

Please sign in to comment.