Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unnecessary DB queries when revoking tokens #2640

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@
import org.wso2.carbon.identity.application.authentication.framework.model.AuthenticatedUser;
import org.wso2.carbon.identity.application.common.IdentityApplicationManagementException;
import org.wso2.carbon.identity.application.common.model.IdentityProvider;
import org.wso2.carbon.identity.application.common.model.InboundAuthenticationRequestConfig;
import org.wso2.carbon.identity.application.common.model.ServiceProvider;
import org.wso2.carbon.identity.application.common.model.User;
import org.wso2.carbon.identity.application.mgt.ApplicationConstants;
import org.wso2.carbon.identity.application.mgt.ApplicationConstants.StandardInboundProtocols;
import org.wso2.carbon.identity.application.mgt.ApplicationManagementService;
import org.wso2.carbon.identity.base.IdentityConstants;
import org.wso2.carbon.identity.central.log.mgt.utils.LoggerUtils;
Expand Down Expand Up @@ -66,8 +67,7 @@
import org.wso2.carbon.identity.role.v2.mgt.core.RoleConstants;
import org.wso2.carbon.identity.role.v2.mgt.core.RoleManagementService;
import org.wso2.carbon.identity.role.v2.mgt.core.exception.IdentityRoleManagementException;
import org.wso2.carbon.identity.role.v2.mgt.core.model.AssociatedApplication;
import org.wso2.carbon.identity.role.v2.mgt.core.model.Role;
import org.wso2.carbon.identity.role.v2.mgt.core.model.RoleBasicInfo;
import org.wso2.carbon.idp.mgt.IdentityProviderManagementException;
import org.wso2.carbon.user.api.Tenant;
import org.wso2.carbon.user.core.UserStoreException;
Expand All @@ -86,7 +86,6 @@
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
Expand Down Expand Up @@ -775,51 +774,43 @@ private static AuthenticatedUser buildAuthenticatedUser(UserStoreManager userSto

/**
* Get clientIds of associated application of an application role.
* @param role Role object.
*
* @param role Role basic info object.
* @param authenticatedUser Authenticated user.
* @return Set of clientIds of associated applications.
*/
private static Set<String> getClientIdsOfAssociatedApplications(Role role, AuthenticatedUser authenticatedUser)
private static Optional<String> getClientIdOfAssociatedApplication(RoleBasicInfo role,
AuthenticatedUser authenticatedUser)
throws UserStoreException {

ApplicationManagementService applicationManagementService =
OAuthComponentServiceHolder.getInstance().getApplicationManagementService();
List<String> associatedApplications = role.getAssociatedApplications().stream()
.map(AssociatedApplication::getId).collect(Collectors.toList());
String associatedApplication = role.getAudienceId();
try {
if (authenticatedUser.getUserResidentOrganization() != null) {
List<String> newAssociatedApplications = new ArrayList<>();
for (String app : associatedApplications) {
newAssociatedApplications.add(
SharedAppResolveDAO.getMainApplication(app, authenticatedUser.getAccessingOrganization()));
}
associatedApplications = newAssociatedApplications;
associatedApplication = SharedAppResolveDAO.getMainApplication(
associatedApplication, authenticatedUser.getAccessingOrganization());
}
} catch (IdentityOAuth2Exception e) {
throw new UserStoreException("Error occurred while getting the main applications of the shared apps.", e);
}
Set<String> clientIds = new HashSet<>();
associatedApplications.forEach(associatedApplication -> {
try {
ServiceProvider application = applicationManagementService
.getApplicationByResourceId(associatedApplication, authenticatedUser.getTenantDomain());
if (application == null || application.getInboundAuthenticationConfig() == null) {
return;
}
Arrays.stream(application.getInboundAuthenticationConfig().getInboundAuthenticationRequestConfigs())
.forEach(inboundAuthenticationRequestConfig -> {
if (ApplicationConstants.StandardInboundProtocols.OAUTH2.equals(
inboundAuthenticationRequestConfig.getInboundAuthType())) {
clientIds.add(inboundAuthenticationRequestConfig.getInboundAuthKey());
}
});
} catch (IdentityApplicationManagementException e) {
String errorMessage = "Error occurred while retrieving application of id : " +
associatedApplication;
LOG.error(errorMessage);
try {
ServiceProvider application = applicationManagementService
.getApplicationByResourceId(associatedApplication, authenticatedUser.getTenantDomain());
if (application != null && application.getInboundAuthenticationConfig() != null) {
InboundAuthenticationRequestConfig[] inboundAuthenticationRequestConfigs =
application.getInboundAuthenticationConfig().getInboundAuthenticationRequestConfigs();
return Arrays.stream(inboundAuthenticationRequestConfigs)
.filter(config -> StandardInboundProtocols.OAUTH2.equals(config.getInboundAuthType()))
.map(InboundAuthenticationRequestConfig::getInboundAuthKey)
.findFirst();
}
});
return clientIds;
} catch (IdentityApplicationManagementException e) {
String errorMessage = "Error occurred while retrieving application of id : " +
associatedApplication;
LOG.error(errorMessage);
}
return Optional.empty();
}

private static Set<String> filterClientIdsWithOrganizationAudience(List<String> clientIds, String tenantDomain) {
Expand Down Expand Up @@ -849,14 +840,14 @@ private static Set<String> filterClientIdsWithOrganizationAudience(List<String>
* @param tenantDomain Tenant domain.
* @return Role.
*/
private static Role getRole(String roleId, String tenantDomain) throws UserStoreException {
private static RoleBasicInfo getRoleBasicInfo(String roleId, String tenantDomain) throws UserStoreException {

try {
RoleManagementService roleV2ManagementService =
OAuthComponentServiceHolder.getInstance().getRoleV2ManagementService();
return roleV2ManagementService.getRole(roleId, tenantDomain);
return roleV2ManagementService.getRoleBasicInfoById(roleId, tenantDomain);
} catch (IdentityRoleManagementException e) {
String errorMessage = "Error occurred while retrieving role of id : " + roleId;
String errorMessage = "Error occurred while retrieving basic role info of id : " + roleId;
throw new UserStoreException(errorMessage, e);
}
}
Expand Down Expand Up @@ -1011,18 +1002,19 @@ public static boolean revokeTokens(String username, UserStoreManager userStoreMa
}

// Get details about the role to identify the audience and associated applications.
Set<String> clientIds = null;
Role role = null;
Set<String> clientIds = new HashSet<>();;
RoleBasicInfo role = null;
boolean getClientIdsFromUser = false;
if (roleId != null) {
role = getRole(roleId, IdentityTenantUtil.getTenantDomain(userStoreManager.getTenantId()));
role = getRoleBasicInfo(roleId, tenantDomain);
if (role != null && RoleConstants.APPLICATION.equals(role.getAudience())) {
// Get clientIds of associated applications for the specific application role.
if (LOG.isDebugEnabled()) {
LOG.debug("Get clientIds of associated applications for the application role: "
+ role.getName());
}
clientIds = getClientIdsOfAssociatedApplications(role, authenticatedUser);
getClientIdOfAssociatedApplication(role, authenticatedUser)
.ifPresent(clientIds::add);
} else {
// Get all the distinct client Ids authorized by this user since this is an organization role.
if (LOG.isDebugEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* Copyright (c) 2017, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
* Copyright (c) 2017-2025, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 Inc. licenses this file to you under the Apache License,
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
Expand All @@ -11,23 +11,58 @@
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.wso2.carbon.identity.oauth;

import org.apache.commons.lang.StringUtils;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import org.wso2.carbon.identity.application.authentication.framework.model.AuthenticatedUser;
import org.wso2.carbon.identity.application.common.model.InboundAuthenticationConfig;
import org.wso2.carbon.identity.application.common.model.InboundAuthenticationRequestConfig;
import org.wso2.carbon.identity.application.common.model.ServiceProvider;
import org.wso2.carbon.identity.application.common.model.User;
import org.wso2.carbon.identity.application.mgt.ApplicationConstants;
import org.wso2.carbon.identity.application.mgt.ApplicationManagementService;
import org.wso2.carbon.identity.common.testng.WithCarbonHome;
import org.wso2.carbon.identity.common.testng.WithRealmService;
import org.wso2.carbon.identity.oauth.cache.CacheEntry;
import org.wso2.carbon.identity.oauth.cache.OAuthCache;
import org.wso2.carbon.identity.oauth.cache.OAuthCacheKey;

import org.wso2.carbon.identity.oauth.internal.OAuthComponentServiceHolder;
import org.wso2.carbon.identity.oauth2.dao.AccessTokenDAO;
import org.wso2.carbon.identity.oauth2.dao.OAuthTokenPersistenceFactory;
import org.wso2.carbon.identity.oauth2.model.AccessTokenDO;
import org.wso2.carbon.identity.oauth2.util.OAuth2Util;
import org.wso2.carbon.identity.organization.management.service.util.OrganizationManagementUtil;
import org.wso2.carbon.identity.role.v2.mgt.core.RoleConstants;
import org.wso2.carbon.identity.role.v2.mgt.core.RoleManagementService;
import org.wso2.carbon.identity.role.v2.mgt.core.model.RoleBasicInfo;
import org.wso2.carbon.user.api.RealmConfiguration;
import org.wso2.carbon.user.core.UserStoreManager;
import org.wso2.carbon.utils.multitenancy.MultitenantConstants;

import java.util.HashSet;
import java.util.Set;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
Expand All @@ -40,6 +75,37 @@
@WithCarbonHome
@WithRealmService
public class OAuthUtilTest {

@Mock
RoleManagementService roleManagementService;
@Mock
ApplicationManagementService applicationManagementService;

private AutoCloseable closeable;
private MockedStatic<OrganizationManagementUtil> organizationManagementUtil;
private MockedStatic<OAuthComponentServiceHolder> oAuthComponentServiceHolder;
private MockedStatic<OAuth2Util> oAuth2Util;
private MockedStatic<OAuthTokenPersistenceFactory> oAuthTokenPersistenceFactory;

@BeforeMethod
public void setUp() throws Exception {

organizationManagementUtil = mockStatic(OrganizationManagementUtil.class);
oAuthComponentServiceHolder = mockStatic(OAuthComponentServiceHolder.class);
oAuth2Util = mockStatic(OAuth2Util.class);
oAuthTokenPersistenceFactory = mockStatic(OAuthTokenPersistenceFactory.class);
closeable = MockitoAnnotations.openMocks(this);
}

@AfterMethod
public void tearDown() throws Exception {

organizationManagementUtil.close();
oAuthComponentServiceHolder.close();
oAuth2Util.close();
oAuthTokenPersistenceFactory.close();
closeable.close();
}

@DataProvider(name = "testGetAuthenticatedUser")
public Object[][] fullQualifiedUserName() {
Expand Down Expand Up @@ -160,6 +226,71 @@ public void testGetAuthenticatedUserException() throws Exception {
OAuthUtil.getAuthenticatedUser("");
}

@Test
public void testRevokeTokensForApplicationAudienceRoles() throws Exception {

String username = "testUser";
String roleId = "testRoleId";
String roleName = "testRole";
String appId = "testAppId";
String clientId = "testClientId";
String accessToken = "testAccessToken";

UserStoreManager userStoreManager = mock(UserStoreManager.class);
when(userStoreManager.getTenantId()).thenReturn(-1234);
when(userStoreManager.getRealmConfiguration()).thenReturn(mock(RealmConfiguration.class));
when(userStoreManager.getRealmConfiguration().getUserStoreProperty(anyString())).thenReturn("PRIMARY");

when(OrganizationManagementUtil.isOrganization(anyString())).thenReturn(false);
when(OAuth2Util.getTenantId(anyString())).thenReturn(-1234);

OAuthComponentServiceHolder mockOAuthComponentServiceHolder = mock(OAuthComponentServiceHolder.class);
when(OAuthComponentServiceHolder.getInstance()).thenReturn(mockOAuthComponentServiceHolder);

when(mockOAuthComponentServiceHolder.getRoleV2ManagementService()).thenReturn(roleManagementService);
RoleBasicInfo roleBasicInfo = new RoleBasicInfo();
roleBasicInfo.setId(roleId);
roleBasicInfo.setAudience(RoleConstants.APPLICATION);
roleBasicInfo.setAudienceId(appId);
roleBasicInfo.setName(roleName);
when(roleManagementService.getRoleBasicInfoById(roleId, MultitenantConstants.SUPER_TENANT_DOMAIN_NAME))
.thenReturn(roleBasicInfo);

when(mockOAuthComponentServiceHolder.getApplicationManagementService())
.thenReturn(applicationManagementService);
ServiceProvider serviceProvider = new ServiceProvider();
InboundAuthenticationConfig inboundAuthenticationConfig = new InboundAuthenticationConfig();
InboundAuthenticationRequestConfig[] inboundAuthenticationRequestConfigs =
new InboundAuthenticationRequestConfig[1];
InboundAuthenticationRequestConfig inboundAuthenticationRequestConfig =
new InboundAuthenticationRequestConfig();
inboundAuthenticationRequestConfig.setInboundAuthKey(clientId);
inboundAuthenticationRequestConfig.setInboundAuthType(ApplicationConstants.StandardInboundProtocols.OAUTH2);
inboundAuthenticationRequestConfigs[0] = inboundAuthenticationRequestConfig;
inboundAuthenticationConfig.setInboundAuthenticationRequestConfigs(inboundAuthenticationRequestConfigs);
serviceProvider.setInboundAuthenticationConfig(inboundAuthenticationConfig);
when(applicationManagementService.getApplicationByResourceId(
appId, MultitenantConstants.SUPER_TENANT_DOMAIN_NAME)).thenReturn(serviceProvider);

OAuthTokenPersistenceFactory mockOAuthTokenPersistenceFactory = mock(OAuthTokenPersistenceFactory.class);
when(OAuthTokenPersistenceFactory.getInstance()).thenReturn(mockOAuthTokenPersistenceFactory);
AccessTokenDAO mockAccessTokenDAO = mock(AccessTokenDAO.class);
when(mockOAuthTokenPersistenceFactory.getAccessTokenDAO()).thenReturn(mockAccessTokenDAO);
Set<AccessTokenDO> accessTokens = new HashSet<>();
AccessTokenDO accessTokenDO = new AccessTokenDO();
accessTokenDO.setAccessToken(accessToken);
accessTokenDO.setConsumerKey(clientId);
accessTokenDO.setScope(new String[]{"default"});
accessTokenDO.setAuthzUser(new AuthenticatedUser());
accessTokens.add(accessTokenDO);
when(mockAccessTokenDAO.getAccessTokens(anyString(),
any(AuthenticatedUser.class), nullable(String.class), anyBoolean())).thenReturn(accessTokens);

boolean result = OAuthUtil.revokeTokens(username, userStoreManager, roleId);
verify(mockAccessTokenDAO, times(1)).revokeAccessTokens(any(), anyBoolean());
assertTrue(result, "Token revocation failed.");
}

private OAuthCache getOAuthCache(OAuthCacheKey oAuthCacheKey) {


Expand Down
Loading