diff --git a/common/build.gradle b/common/build.gradle index 3d831aff2b..a572fec814 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -23,6 +23,8 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.15.2' testImplementation "org.opensearch.test:framework:${opensearch_version}" + compileOnly group: 'org.opensearch', name:'opensearch-security-spi', version:"${opensearch_build}" + compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.11.0' compileOnly group: 'org.json', name: 'json', version: '20231013' diff --git a/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java b/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java new file mode 100644 index 0000000000..5d0978dbc2 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/ResourceSharingClientAccessor.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import org.opensearch.security.spi.resources.client.ResourceSharingClient; + +/** + * Accessor for resource sharing client + */ +public class ResourceSharingClientAccessor { + private ResourceSharingClient CLIENT; + + private static ResourceSharingClientAccessor resourceSharingClientAccessor; + + private ResourceSharingClientAccessor() {} + + public static ResourceSharingClientAccessor getInstance() { + if (resourceSharingClientAccessor == null) { + resourceSharingClientAccessor = new ResourceSharingClientAccessor(); + } + + return resourceSharingClientAccessor; + } + + /** + * Set the resource sharing client + */ + public void setResourceSharingClient(ResourceSharingClient client) { + resourceSharingClientAccessor.CLIENT = client; + } + + /** + * Get the resource sharing client + */ + public ResourceSharingClient getResourceSharingClient() { + return resourceSharingClientAccessor.CLIENT; + } + +} diff --git a/plugin/build.gradle b/plugin/build.gradle index 70ace354ea..d86f752a81 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -45,7 +45,7 @@ opensearchplugin { name 'opensearch-ml' description 'machine learning plugin for opensearch' classname 'org.opensearch.ml.plugin.MachineLearningPlugin' - extendedPlugins = ['opensearch-job-scheduler'] + extendedPlugins = ['opensearch-job-scheduler', 'opensearch-security;optional=true'] } configurations { @@ -71,6 +71,8 @@ dependencies { zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}" + compileOnly group: 'org.opensearch', name:'opensearch-security-spi', version:"${opensearch_build}" + implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" // Multi-tenant SDK Client diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index 92fd0228f1..b04721b954 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -30,6 +30,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; @@ -66,6 +67,7 @@ public class CreateControllerTransportAction extends HandledTransportAction<ActionRequest, MLCreateControllerResponse> { MLIndicesHandler mlIndicesHandler; Client client; + Settings settings; MLModelManager mlModelManager; ClusterService clusterService; MLModelCacheHelper mlModelCacheHelper; @@ -78,6 +80,7 @@ public CreateControllerTransportAction( ActionFilters actionFilters, MLIndicesHandler mlIndicesHandler, Client client, + Settings settings, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, MLModelCacheHelper mlModelCacheHelper, @@ -87,6 +90,7 @@ public CreateControllerTransportAction( super(MLCreateControllerAction.NAME, transportService, actionFilters, MLCreateControllerRequest::new); this.mlIndicesHandler = mlIndicesHandler; this.client = client; + this.settings = settings; this.mlModelManager = mlModelManager; this.clusterService = clusterService; this.mlModelCacheHelper = mlModelCacheHelper; @@ -112,7 +116,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCrea Boolean isHidden = mlModel.getIsHidden(); if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> { if (hasPermission) { if (mlModel.getModelState() != MLModelState.DEPLOYING) { indexAndCreateController(mlModel, controller, wrappedListener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java index 3be5e07a0b..5e8464bad2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java @@ -27,6 +27,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -55,6 +56,7 @@ @FieldDefaults(level = AccessLevel.PRIVATE) public class DeleteControllerTransportAction extends HandledTransportAction<ActionRequest, DeleteResponse> { Client client; + Settings settings; NamedXContentRegistry xContentRegistry; ClusterService clusterService; MLModelManager mlModelManager; @@ -67,6 +69,7 @@ public DeleteControllerTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, NamedXContentRegistry xContentRegistry, ClusterService clusterService, MLModelManager mlModelManager, @@ -98,7 +101,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> { if (hasPermission) { mlModelManager .getController( diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java index d70488948f..0e040672fa 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java @@ -19,6 +19,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -48,6 +49,7 @@ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class GetControllerTransportAction extends HandledTransportAction<ActionRequest, MLControllerGetResponse> { Client client; + Settings settings; NamedXContentRegistry xContentRegistry; ClusterService clusterService; MLModelManager mlModelManager; @@ -59,6 +61,7 @@ public GetControllerTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, NamedXContentRegistry xContentRegistry, ClusterService clusterService, MLModelManager mlModelManager, @@ -67,6 +70,7 @@ public GetControllerTransportAction( ) { super(MLControllerGetAction.NAME, transportService, actionFilters, MLControllerGetRequest::new); this.client = client; + this.settings = settings; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; this.mlModelManager = mlModelManager; @@ -96,34 +100,40 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCont mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { Boolean isHidden = mlModel.getIsHidden(); modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { - if (hasPermission) { - wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build()); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - getErrorMessage( - "User doesn't have privilege to perform this operation on this model controller.", - modelId, - isHidden - ), - RestStatus.FORBIDDEN - ) + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + client, + settings, + ActionListener.wrap(hasPermission -> { + if (hasPermission) { + wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build()); + } else { + wrappedListener + .onFailure( + new OpenSearchStatusException( + getErrorMessage( + "User doesn't have privilege to perform this operation on this model controller.", + modelId, + isHidden + ), + RestStatus.FORBIDDEN + ) + ); + } + }, exception -> { + log + .error( + getErrorMessage( + "Permission denied: Unable to create the model controller for the given model.", + modelId, + isHidden + ), + exception ); - } - }, exception -> { - log - .error( - getErrorMessage( - "Permission denied: Unable to create the model controller for the given model.", - modelId, - isHidden - ), - exception - ); - wrappedListener.onFailure(exception); - })); + wrappedListener.onFailure(exception); + }) + ); }, e -> wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index ac44069930..ca57388192 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -29,6 +29,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; @@ -60,6 +61,7 @@ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) public class UpdateControllerTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> { Client client; + Settings settings; MLModelManager mlModelManager; MLModelCacheHelper mlModelCacheHelper; ClusterService clusterService; @@ -71,6 +73,7 @@ public UpdateControllerTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, MLModelCacheHelper mlModelCacheHelper, @@ -79,6 +82,7 @@ public UpdateControllerTransportAction( ) { super(MLUpdateControllerAction.NAME, transportService, actionFilters, MLUpdateControllerRequest::new); this.client = client; + this.settings = settings; this.mlModelManager = mlModelManager; this.clusterService = clusterService; this.mlModelCacheHelper = mlModelCacheHelper; @@ -104,7 +108,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update Boolean isHidden = mlModel.getIsHidden(); if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(hasPermission -> { if (hasPermission) { mlModelManager.getController(modelId, ActionListener.wrap(controller -> { boolean isDeployRequiredAfterUpdate = controller.isDeployRequiredAfterUpdate(updateControllerInput); diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 9ae795438d..463e3e8ad5 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -177,7 +177,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl } } else { modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, settings, ActionListener.wrap(access -> { if (!access) { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 90584451e0..7ab561cc43 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -20,6 +20,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -65,17 +66,20 @@ public class MLSearchHandler { private ModelAccessControlHelper modelAccessControlHelper; private ClusterService clusterService; + private Settings settings; public MLSearchHandler( Client client, NamedXContentRegistry xContentRegistry, ModelAccessControlHelper modelAccessControlHelper, - ClusterService clusterService + ClusterService clusterService, + Settings settings ) { this.modelAccessControlHelper = modelAccessControlHelper; this.client = client; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; + this.settings = settings; } /** @@ -144,7 +148,7 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId, .searchDataObjectAsync(searchDataObjectRequest) .whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrapperListener)); } else { - SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user); + SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user, settings); SearchRequest modelGroupSearchRequest = new SearchRequest(); sourceBuilder.fetchSource(new String[] { MLModelGroup.MODEL_GROUP_ID_FIELD, }, null); sourceBuilder.size(10000); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index 7a9b3925b4..30abb76b45 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -19,6 +19,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -57,6 +58,7 @@ public class DeleteModelGroupTransportAction extends HandledTransportAction<Acti final SdkClient sdkClient; final NamedXContentRegistry xContentRegistry; final ClusterService clusterService; + final Settings settings; final ModelAccessControlHelper modelAccessControlHelper; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @@ -66,6 +68,7 @@ public DeleteModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, SdkClient sdkClient, NamedXContentRegistry xContentRegistry, ClusterService clusterService, @@ -74,6 +77,7 @@ public DeleteModelGroupTransportAction( ) { super(MLModelGroupDeleteAction.NAME, transportService, actionFilters, MLModelGroupDeleteRequest::new); this.client = client; + this.settings = settings; this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; @@ -107,6 +111,7 @@ private void validateAndDeleteModelGroup(String modelGroupId, String tenantId, A modelGroupId, client, sdkClient, + settings, ActionListener .wrap( hasAccess -> handleAccessValidation(hasAccess, modelGroupId, tenantId, listener), diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java index f1dbe8be48..5bc07565fe 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/GetModelGroupTransportAction.java @@ -17,6 +17,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; @@ -52,6 +53,7 @@ public class GetModelGroupTransportAction extends HandledTransportAction<ActionRequest, MLModelGroupGetResponse> { final Client client; + final Settings settings; final SdkClient sdkClient; final NamedXContentRegistry xContentRegistry; final ClusterService clusterService; @@ -63,6 +65,7 @@ public GetModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, SdkClient sdkClient, NamedXContentRegistry xContentRegistry, ClusterService clusterService, @@ -71,6 +74,7 @@ public GetModelGroupTransportAction( ) { super(MLModelGroupGetAction.NAME, transportService, actionFilters, MLModelGroupGetRequest::new); this.client = client; + this.settings = settings; this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; @@ -183,7 +187,7 @@ private void validateModelGroupAccess( MLModelGroup mlModelGroup, ActionListener<MLModelGroupGetResponse> wrappedListener ) { - modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { + modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, settings, ActionListener.wrap(access -> { if (!access) { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index 96af8cd317..e17af1e3ae 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -6,7 +6,10 @@ package org.opensearch.ml.action.model_group; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -14,6 +17,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -35,6 +39,7 @@ @Log4j2 public class SearchModelGroupTransportAction extends HandledTransportAction<MLSearchActionRequest, SearchResponse> { Client client; + Settings settings; SdkClient sdkClient; ClusterService clusterService; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @@ -46,6 +51,7 @@ public SearchModelGroupTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, SdkClient sdkClient, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, @@ -53,6 +59,7 @@ public SearchModelGroupTransportAction( ) { super(MLModelGroupSearchAction.NAME, transportService, actionFilters, MLSearchActionRequest::new); this.client = client; + this.settings = settings; this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; @@ -76,13 +83,19 @@ private void preProcessRoleAndPerformSearch( User user, ActionListener<SearchResponse> listener ) { + boolean isResourceSharingFeatureEnabled = ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED.get(settings) + && this.settings.getAsBoolean(OPENSEARCH_RESOURCE_SHARING_ENABLED, OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, context::restore); final ActionListener<SearchResponse> doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); - if (!modelAccessControlHelper.skipModelAccessControl(user)) { + // TODO: Remove this feature flag check once feature is GA, as it will be enabled by default + if (isResourceSharingFeatureEnabled) { + // User will be fetched from thread context using persistent header, so stash context will not stash user info + modelAccessControlHelper.addAccessibleModelGroupsFilter(request.source()); + } else if (!modelAccessControlHelper.skipModelAccessControl(user)) { // Security is enabled, filter is enabled and user isn't admin modelAccessControlHelper.addUserBackendRolesFilter(user, request.source()); log.debug("Filtering result by {}", user.getBackendRoles()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index d3ab730f0d..0eae2483f9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -8,13 +8,21 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.utils.MLExceptionUtils.logException; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT; +import static org.opensearch.security.spi.resources.ResourceAccessLevels.PLACE_HOLDER; import java.time.Instant; +import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; @@ -23,6 +31,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; @@ -35,6 +44,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction; @@ -51,6 +61,10 @@ import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; +import org.opensearch.security.spi.resources.sharing.Recipient; +import org.opensearch.security.spi.resources.sharing.Recipients; +import org.opensearch.security.spi.resources.sharing.ShareWith; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; @@ -66,6 +80,7 @@ public class TransportUpdateModelGroupAction extends HandledTransportAction<Acti private final ActionFilters actionFilters; private Client client; final SdkClient sdkClient; + final Settings settings; private NamedXContentRegistry xContentRegistry; ClusterService clusterService; @@ -78,6 +93,7 @@ public TransportUpdateModelGroupAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, SdkClient sdkClient, NamedXContentRegistry xContentRegistry, ClusterService clusterService, @@ -89,6 +105,7 @@ public TransportUpdateModelGroupAction( this.actionFilters = actionFilters; this.transportService = transportService; this.client = client; + this.settings = settings; this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; @@ -107,6 +124,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda return; } User user = RestActionUtils.getUserContext(client); + boolean isResourceSharingFeatureEnabled = ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED.get(settings) + && this.settings.getAsBoolean(OPENSEARCH_RESOURCE_SHARING_ENABLED, OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT); FetchSourceContext fetchSourceContext = new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY); GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest .builder() @@ -146,12 +165,74 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda mlModelGroup.getTenantId(), wrappedListener )) { - if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { - validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); + // TODO: Remove this feature flag check once feature is GA, as it will be enabled by default + if (isResourceSharingFeatureEnabled) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor + .getInstance() + .getResourceSharingClient(); + resourceSharingClient + .verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, ActionListener.wrap(isAuthorized -> { + if (!isAuthorized) { + listener + .onFailure( + new OpenSearchStatusException( + "User " + + user.getName() + + " is not authorized to update ml-model-group: " + + mlModelGroup.getName(), + RestStatus.FORBIDDEN + ) + ); + return; + } + + Pair<ShareWith, ShareWith> shareAndRevoke = buildShareAndRevokeEntities( + r.source(), + updateModelGroupInput, + user + ); + + // Share and revoke based on updates + resourceSharingClient + .share( + modelGroupId, + ML_MODEL_GROUP_INDEX, + shareAndRevoke.getLeft(), + ActionListener.wrap(res -> { + resourceSharingClient + .revoke( + modelGroupId, + ML_MODEL_GROUP_INDEX, + shareAndRevoke.getRight(), + ActionListener.wrap(res2 -> { + // For backwards compatibility we still allow storing backend_roles + // data in ml_model_group + // index + updateModelGroup( + modelGroupId, + r.source(), + updateModelGroupInput, + wrappedListener, + user + ); + }, listener::onFailure) + ); + }, listener::onFailure) + ); + + }, listener::onFailure)); } else { - validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); + // TODO: At some point, this call must be replaced by the one above, (i.e. no user info to + // be stored in model-group index) + if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { + validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); + } else { + validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); + } + + updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user); } - updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user); + } } catch (Exception e) { log.error("Failed to parse ml connector {}", r.id(), e); @@ -177,6 +258,54 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda } } + @SuppressWarnings("unchecked") + private Pair<ShareWith, ShareWith> buildShareAndRevokeEntities( + Map<String, Object> source, + MLUpdateModelGroupInput updateModelGroupInput, + User user + ) { + Map<Recipient, Set<String>> shareMap = new HashMap<>(); + Map<Recipient, Set<String>> revokeMap = new HashMap<>(); + if (updateModelGroupInput.getModelAccessMode() != null) { + Set<String> sourceBRs = new HashSet<>((List<String>) (source.getOrDefault(MLModelGroup.BACKEND_ROLES_FIELD, List.of()))); + switch (updateModelGroupInput.getModelAccessMode()) { + case PRIVATE -> { + // revoke all accesses + revokeMap.put(Recipient.BACKEND_ROLES, sourceBRs); + } + case RESTRICTED -> { + // share with new entries + Set<String> updateBRs = new HashSet<>(updateModelGroupInput.getBackendRoles()); + Set<String> toShare = new HashSet<>(updateBRs); + toShare.removeAll(sourceBRs); + + // Revoke those that are not present in the update but were already shared with + Set<String> toRevoke = new HashSet<>(sourceBRs); + toRevoke.removeAll(updateBRs); + + shareMap = Map.of(Recipient.BACKEND_ROLES, toShare); + revokeMap = Map.of(Recipient.BACKEND_ROLES, toRevoke); + } + case PUBLIC -> // share with * + shareMap = Map.of(Recipient.USERS, Set.of("*"), Recipient.ROLES, Set.of("*"), Recipient.BACKEND_ROLES, Set.of("*")); + default -> { + + } + } + } + if (updateModelGroupInput.getBackendRoles() != null) { + source.put(MLModelGroup.BACKEND_ROLES_FIELD, updateModelGroupInput.getBackendRoles()); + } + if (Boolean.TRUE.equals(updateModelGroupInput.getIsAddAllBackendRoles())) { + source.put(MLModelGroup.BACKEND_ROLES_FIELD, user.getBackendRoles()); + } + + ShareWith share = new ShareWith(Map.of(PLACE_HOLDER, new Recipients(shareMap))); + ShareWith revoke = new ShareWith(Map.of(PLACE_HOLDER, new Recipients(revokeMap))); + + return Pair.of(share, revoke); + } + private void updateModelGroup( String modelGroupId, Map<String, Object> source, diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index c7f9d4ae18..2a121998eb 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -106,12 +106,11 @@ public class DeleteModelTransportAction extends HandledTransportAction<ActionReq Boolean isSafeDelete; final Client client; + final Settings settings; final SdkClient sdkClient; final NamedXContentRegistry xContentRegistry; final ClusterService clusterService; - Settings settings; - final ModelAccessControlHelper modelAccessControlHelper; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; @@ -215,42 +214,55 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete } } else { modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else if (isModelNotDeployed(mlModelState)) { - if (isSafeDelete) { - // We only check downstream task when it's not hidden and cluster setting is true. - checkDownstreamTaskBeforeDeleteModel( - modelId, - tenantId, - mlModel.getAlgorithm().name(), - isHidden, - actionListener - ); + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + client, + settings, + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else if (isModelNotDeployed(mlModelState)) { + if (isSafeDelete) { + // We only check downstream task when it's not hidden and cluster setting is true. + checkDownstreamTaskBeforeDeleteModel( + modelId, + tenantId, + mlModel.getAlgorithm().name(), + isHidden, + actionListener + ); + } else { + deleteModel( + modelId, + tenantId, + mlModel.getAlgorithm().name(), + isHidden, + actionListener + ); + } + // deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, + // actionListener); } else { - deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener); + wrappedListener + .onFailure( + new OpenSearchStatusException( + "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete", + RestStatus.BAD_REQUEST + ) + ); } - // deleteModel(modelId, tenantId, mlModel.getAlgorithm().name(), isHidden, actionListener); - } else { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete", - RestStatus.BAD_REQUEST - ) - ); - } - }, e -> { - log.error(getErrorMessage("Failed to validate Access", modelId, isHidden), e); - wrappedListener.onFailure(e); - })); + }, e -> { + log.error(getErrorMessage("Failed to validate Access", modelId, isHidden), e); + wrappedListener.onFailure(e); + }) + ); } } catch (Exception e) { log.error("Failed to parse ml model {}", r.id(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 64c9eb6676..24246b82b1 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -141,27 +141,33 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode } } else { modelAccessControlHelper - .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure( - new OpenSearchStatusException( - "User doesn't have privilege to perform this operation on this model", - RestStatus.FORBIDDEN - ) - ); - } else { - log.debug("Completed Get Model Request, id:{}", modelId); - Connector connector = mlModel.getConnector(); - if (connector != null) { - connector.removeCredential(); + .validateModelGroupAccess( + user, + mlModel.getModelGroupId(), + client, + settings, + ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have privilege to perform this operation on this model", + RestStatus.FORBIDDEN + ) + ); + } else { + log.debug("Completed Get Model Request, id:{}", modelId); + Connector connector = mlModel.getConnector(); + if (connector != null) { + connector.removeCredential(); + } + wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); } - wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build()); - } - }, e -> { - log.error("Failed to validate Access for Model Id {}", modelId, e); - wrappedListener.onFailure(e); - })); + }, e -> { + log.error("Failed to validate Access for Model Id {}", modelId, e); + wrappedListener.onFailure(e); + }) + ); } } catch (Exception e) { log.error("Failed to parse ml model {}", r.id(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index 4acda2bb21..db5c298888 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -164,6 +164,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update mlModel.getModelGroupId(), client, sdkClient, + settings, ActionListener.wrap(hasPermission -> { if (hasPermission) { updateRemoteOrTextEmbeddingModel( @@ -390,7 +391,7 @@ private void updateModelWithRegisteringToAnotherModelGroup( UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); if (newModelGroupId != null) { modelAccessControlHelper - .validateModelGroupAccess(user, newModelGroupId, client, ActionListener.wrap(hasNewModelGroupPermission -> { + .validateModelGroupAccess(user, newModelGroupId, client, settings, ActionListener.wrap(hasNewModelGroupPermission -> { if (hasNewModelGroupPermission) { mlModelGroupManager.getModelGroupResponse(sdkClient, newModelGroupId, ActionListener.wrap(newModelGroupResponse -> { buildUpdateRequest( diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index ab5944db94..20357ecced 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -57,6 +57,7 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action Client client; SdkClient sdkClient; + Settings settings; ClusterService clusterService; @@ -92,6 +93,7 @@ public TransportPredictionTaskAction( this.clusterService = clusterService; this.client = client; this.sdkClient = sdkClient; + this.settings = settings; this.xContentRegistry = xContentRegistry; this.mlModelManager = mlModelManager; this.modelAccessControlHelper = modelAccessControlHelper; @@ -138,6 +140,7 @@ public void onResponse(MLModel mlModel) { mlModel.getModelGroupId(), client, sdkClient, + settings, ActionListener.wrap(access -> { if (!access) { wrappedListener diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 00c577c8d6..26f3b3f768 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -213,6 +213,7 @@ private void checkUserAccess( registerModelInput.getModelGroupId(), client, sdkClient, + settings, ActionListener.wrap(access -> { if (access) { doRegister(registerModelInput, listener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index 4c1bf76529..e0fafeaa54 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -26,6 +26,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -72,6 +73,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction<Action Client client; NamedXContentRegistry xContentRegistry; + Settings settings; ClusterService clusterService; ScriptService scriptService; @@ -89,6 +91,7 @@ public CancelBatchJobTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, + Settings settings, NamedXContentRegistry xContentRegistry, ClusterService clusterService, ScriptService scriptService, @@ -110,6 +113,7 @@ public CancelBatchJobTransportAction( this.mlTaskManager = mlTaskManager; this.mlModelManager = mlModelManager; this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + this.settings = settings; } @Override @@ -192,35 +196,36 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> { - modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job")); - } else { - if (model.getConnector() != null) { - Connector connector = model.getConnector(); - executeConnector(connector, mlInput, actionListener); - } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { - ActionListener<Connector> listener = ActionListener - .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { - log.error("Failed to get connector {}", model.getConnectorId(), e); - actionListener.onFailure(e); - }); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper - .getConnector( - client, - model.getConnectorId(), - ActionListener.runBefore(listener, threadContext::restore) - ); - } + modelAccessControlHelper + .validateModelGroupAccess(user, model.getModelGroupId(), client, settings, ActionListener.wrap(access -> { + if (!access) { + actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job")); } else { - actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + if (model.getConnector() != null) { + Connector connector = model.getConnector(); + executeConnector(connector, mlInput, actionListener); + } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) { + ActionListener<Connector> listener = ActionListener + .wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> { + log.error("Failed to get connector {}", model.getConnectorId(), e); + actionListener.onFailure(e); + }); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper + .getConnector( + client, + model.getConnectorId(), + ActionListener.runBefore(listener, threadContext::restore) + ); + } + } else { + actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); + } } - } - }, e -> { - log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e); - actionListener.onFailure(e); - })); + }, e -> { + log.error("Failed to validate Access for Model Group " + model.getModelGroupId(), e); + actionListener.onFailure(e); + })); }, e -> { log.error("Failed to retrieve the ML model with the given ID", e); actionListener diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 2422a439d2..ef824350ac 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -111,6 +111,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest Client client; SdkClient sdkClient; NamedXContentRegistry xContentRegistry; + Settings settings; ClusterService clusterService; ScriptService scriptService; @@ -164,6 +165,7 @@ public GetTaskTransportAction( this.mlModelManager = mlModelManager; this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; this.mlEngine = mlEngine; + this.settings = settings; remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it); @@ -374,6 +376,7 @@ private void processRemoteBatchPrediction( model.getModelGroupId(), client, sdkClient, + settings, ActionListener.wrap(access -> { if (!access) { actionListener diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index ada9f1a604..027b1bae7a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -307,6 +307,7 @@ private void validateAccess(String modelId, String tenantId, ActionListener<Bool mlModel.getModelGroupId(), client, sdkClient, + settings, listener ); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java index aaa201b436..a0d042d827 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java @@ -18,6 +18,7 @@ import org.opensearch.action.index.IndexRequest; import org.opensearch.action.support.WriteRequest; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; @@ -45,6 +46,7 @@ public class MLModelChunkUploader { private final MLIndicesHandler mlIndicesHandler; private final Client client; + private final Settings settings; private final NamedXContentRegistry xContentRegistry; ModelAccessControlHelper modelAccessControlHelper; @@ -52,11 +54,13 @@ public class MLModelChunkUploader { public MLModelChunkUploader( MLIndicesHandler mlIndicesHandler, Client client, + Settings settings, final NamedXContentRegistry xContentRegistry, ModelAccessControlHelper modelAccessControlHelper ) { this.mlIndicesHandler = mlIndicesHandler; this.client = client; + this.settings = settings; this.xContentRegistry = xContentRegistry; this.modelAccessControlHelper = modelAccessControlHelper; } @@ -80,115 +84,122 @@ public void uploadModelChunk(MLUploadModelChunkInput uploadModelChunkInput, Acti MLModel existingModel = MLModel.parse(parser, algorithmName); modelAccessControlHelper - .validateModelGroupAccess(user, existingModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - log.error("You don't have permissions to perform this operation on this model."); - wrappedListener - .onFailure( - new IllegalArgumentException( - "You don't have permissions to perform this operation on this model." - ) - ); - } else { - existingModel.setModelId(r.getId()); - if (existingModel.getTotalChunks() <= uploadModelChunkInput.getChunkNumber()) { - throw new Exception("Chunk number exceeds total chunks"); - } - byte[] bytes = uploadModelChunkInput.getContent(); - // Check the size of the content not to exceed 10 mb - if (bytes == null || bytes.length == 0) { - throw new Exception("Chunk size either 0 or null"); - } - if (validateChunkSize(bytes.length)) { - throw new Exception("Chunk size exceeds 10MB"); - } - mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { - if (!res) { - wrappedListener.onFailure(new RuntimeException("No response to create ML Model index")); - return; - } - int chunkNum = uploadModelChunkInput.getChunkNumber(); - MLModel mlModel = MLModel - .builder() - .algorithm(existingModel.getAlgorithm()) - .modelGroupId(existingModel.getModelGroupId()) - .version(existingModel.getVersion()) - .modelId(existingModel.getModelId()) - .modelFormat(existingModel.getModelFormat()) - .totalChunks(existingModel.getTotalChunks()) - .algorithm(existingModel.getAlgorithm()) - .chunkNumber(chunkNum) - .content(Base64.getEncoder().encodeToString(bytes)) - .build(); - IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); - indexRequest.id(uploadModelChunkInput.getModelId() + "_" + uploadModelChunkInput.getChunkNumber()); - indexRequest - .source( - mlModel - .toXContent( - XContentBuilder.builder(XContentType.JSON.xContent()), - ToXContent.EMPTY_PARAMS - ) + .validateModelGroupAccess( + user, + existingModel.getModelGroupId(), + client, + settings, + ActionListener.wrap(access -> { + if (!access) { + log.error("You don't have permissions to perform this operation on this model."); + wrappedListener + .onFailure( + new IllegalArgumentException( + "You don't have permissions to perform this operation on this model." + ) ); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexRequest, ActionListener.wrap(response -> { - log - .info( - "Index model successful for {} for chunk number {}", - uploadModelChunkInput.getModelId(), - chunkNum + 1 + } else { + existingModel.setModelId(r.getId()); + if (existingModel.getTotalChunks() <= uploadModelChunkInput.getChunkNumber()) { + throw new Exception("Chunk number exceeds total chunks"); + } + byte[] bytes = uploadModelChunkInput.getContent(); + // Check the size of the content not to exceed 10 mb + if (bytes == null || bytes.length == 0) { + throw new Exception("Chunk size either 0 or null"); + } + if (validateChunkSize(bytes.length)) { + throw new Exception("Chunk size exceeds 10MB"); + } + mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { + if (!res) { + wrappedListener.onFailure(new RuntimeException("No response to create ML Model index")); + return; + } + int chunkNum = uploadModelChunkInput.getChunkNumber(); + MLModel mlModel = MLModel + .builder() + .algorithm(existingModel.getAlgorithm()) + .modelGroupId(existingModel.getModelGroupId()) + .version(existingModel.getVersion()) + .modelId(existingModel.getModelId()) + .modelFormat(existingModel.getModelFormat()) + .totalChunks(existingModel.getTotalChunks()) + .algorithm(existingModel.getAlgorithm()) + .chunkNumber(chunkNum) + .content(Base64.getEncoder().encodeToString(bytes)) + .build(); + IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); + indexRequest + .id(uploadModelChunkInput.getModelId() + "_" + uploadModelChunkInput.getChunkNumber()); + indexRequest + .source( + mlModel + .toXContent( + XContentBuilder.builder(XContentType.JSON.xContent()), + ToXContent.EMPTY_PARAMS + ) ); - if (existingModel.getTotalChunks() == (uploadModelChunkInput.getChunkNumber() + 1)) { - Semaphore semaphore = new Semaphore(1); - semaphore.acquire(); - MLModel mlModelMeta = MLModel - .builder() - .name(existingModel.getName()) - .algorithm(existingModel.getAlgorithm()) - .version(existingModel.getVersion()) - .modelGroupId((existingModel.getModelGroupId())) - .modelFormat(existingModel.getModelFormat()) - .modelState(MLModelState.REGISTERED) - .modelConfig(existingModel.getModelConfig()) - .totalChunks(existingModel.getTotalChunks()) - .modelContentHash(existingModel.getModelContentHash()) - .modelContentSizeInBytes(existingModel.getModelContentSizeInBytes()) - .createdTime(existingModel.getCreatedTime()) - .build(); - IndexRequest indexReq = new IndexRequest(ML_MODEL_INDEX); - indexReq.id(modelId); - indexReq - .source( - mlModelMeta - .toXContent( - XContentBuilder.builder(XContentType.JSON.xContent()), - ToXContent.EMPTY_PARAMS - ) + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.wrap(response -> { + log + .info( + "Index model successful for {} for chunk number {}", + uploadModelChunkInput.getModelId(), + chunkNum + 1 ); - indexReq.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(indexReq, ActionListener.wrap(re -> { - log.debug("Index model successful", existingModel.getName()); - semaphore.release(); - }, e -> { - log.error("Failed to update model state", e); - semaphore.release(); - wrappedListener.onFailure(e); - })); - } - wrappedListener.onResponse(new MLUploadModelChunkResponse("Uploaded")); - }, e -> { - log.error("Failed to upload chunk model", e); - wrappedListener.onFailure(e); + if (existingModel.getTotalChunks() == (uploadModelChunkInput.getChunkNumber() + 1)) { + Semaphore semaphore = new Semaphore(1); + semaphore.acquire(); + MLModel mlModelMeta = MLModel + .builder() + .name(existingModel.getName()) + .algorithm(existingModel.getAlgorithm()) + .version(existingModel.getVersion()) + .modelGroupId((existingModel.getModelGroupId())) + .modelFormat(existingModel.getModelFormat()) + .modelState(MLModelState.REGISTERED) + .modelConfig(existingModel.getModelConfig()) + .totalChunks(existingModel.getTotalChunks()) + .modelContentHash(existingModel.getModelContentHash()) + .modelContentSizeInBytes(existingModel.getModelContentSizeInBytes()) + .createdTime(existingModel.getCreatedTime()) + .build(); + IndexRequest indexReq = new IndexRequest(ML_MODEL_INDEX); + indexReq.id(modelId); + indexReq + .source( + mlModelMeta + .toXContent( + XContentBuilder.builder(XContentType.JSON.xContent()), + ToXContent.EMPTY_PARAMS + ) + ); + indexReq.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexReq, ActionListener.wrap(re -> { + log.debug("Index model successful", existingModel.getName()); + semaphore.release(); + }, e -> { + log.error("Failed to update model state", e); + semaphore.release(); + wrappedListener.onFailure(e); + })); + } + wrappedListener.onResponse(new MLUploadModelChunkResponse("Uploaded")); + }, e -> { + log.error("Failed to upload chunk model", e); + wrappedListener.onFailure(e); + })); + }, ex -> { + log.error("Failed to init model index", ex); + wrappedListener.onFailure(ex); })); - }, ex -> { - log.error("Failed to init model index", ex); - wrappedListener.onFailure(ex); - })); - } - }, e -> { - logException("Failed to validate model access", e, log); - wrappedListener.onFailure(e); - })); + } + }, e -> { + logException("Failed to validate model access", e, log); + wrappedListener.onFailure(e); + }) + ); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); wrappedListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java index ec3b67949c..7eafc5b78b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaAction.java @@ -12,6 +12,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLTaskState; @@ -37,6 +38,7 @@ public class TransportRegisterModelMetaAction extends HandledTransportAction<Act ActionFilters actionFilters; MLModelManager mlModelManager; Client client; + Settings settings; ModelAccessControlHelper modelAccessControlHelper; MLModelGroupManager mlModelGroupManager; @@ -46,6 +48,7 @@ public TransportRegisterModelMetaAction( ActionFilters actionFilters, MLModelManager mlModelManager, Client client, + Settings settings, ModelAccessControlHelper modelAccessControlHelper, MLModelGroupManager mlModelGroupManager ) { @@ -54,6 +57,7 @@ public TransportRegisterModelMetaAction( this.actionFilters = actionFilters; this.mlModelManager = mlModelManager; this.client = client; + this.settings = settings; this.modelAccessControlHelper = modelAccessControlHelper; this.mlModelGroupManager = mlModelGroupManager; } @@ -92,30 +96,31 @@ private void checkUserAccess( ) { User user = RestActionUtils.getUserContext(client); - modelAccessControlHelper.validateModelGroupAccess(user, mlUploadInput.getModelGroupId(), client, ActionListener.wrap(access -> { - if (access) { - createModelGroup(mlUploadInput, listener); - return; - } - if (isModelNameAlreadyExisting) { - listener - .onFailure( - new IllegalArgumentException( - "The name {" - + mlUploadInput.getName() - + "} you provided is unavailable because it is used by another model group with id {" - + mlUploadInput.getModelGroupId() - + "} to which you do not have access. Please provide a different name." - ) - ); - } else { - log.error("You don't have permissions to perform this operation on this model."); - listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); - } - }, e -> { - logException("Failed to validate model access", e, log); - listener.onFailure(e); - })); + modelAccessControlHelper + .validateModelGroupAccess(user, mlUploadInput.getModelGroupId(), client, settings, ActionListener.wrap(access -> { + if (access) { + createModelGroup(mlUploadInput, listener); + return; + } + if (isModelNameAlreadyExisting) { + listener + .onFailure( + new IllegalArgumentException( + "The name {" + + mlUploadInput.getName() + + "} you provided is unavailable because it is used by another model group with id {" + + mlUploadInput.getModelGroupId() + + "} to which you do not have access. Please provide a different name." + ) + ); + } else { + log.error("You don't have permissions to perform this operation on this model."); + listener.onFailure(new IllegalArgumentException("You don't have permissions to perform this operation on this model.")); + } + }, e -> { + logException("Failed to validate model access", e, log); + listener.onFailure(e); + })); } private void createModelGroup(MLRegisterModelMetaInput mlUploadInput, ActionListener<MLRegisterModelMetaResponse> listener) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index ac2cfded6c..4cb4229f83 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -11,6 +11,8 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT; import java.util.Collections; import java.util.HashSet; @@ -19,6 +21,7 @@ import org.apache.lucene.search.join.ScoreMode; import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.cluster.service.ClusterService; @@ -28,6 +31,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; @@ -45,6 +49,7 @@ import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; @@ -54,6 +59,7 @@ import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; import org.opensearch.transport.client.Client; import com.google.common.collect.ImmutableList; @@ -84,9 +90,42 @@ public ModelAccessControlHelper(ClusterService clusterService, Settings settings RangeQueryBuilder.class ); + private boolean isResourceSharingFeatureEnabled(Settings settings) { + return isModelAccessControlEnabled() + && settings.getAsBoolean(OPENSEARCH_RESOURCE_SHARING_ENABLED, OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT); + } + // TODO Eventually remove this when all usages of it have been migrated to the SdkClient version - public void validateModelGroupAccess(User user, String modelGroupId, Client client, ActionListener<Boolean> listener) { - if (modelGroupId == null || isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { + public void validateModelGroupAccess( + User user, + String modelGroupId, + Client client, + Settings settings, + ActionListener<Boolean> listener + ) { + if (modelGroupId == null) { + listener.onResponse(true); + return; + } + boolean isResourceSharingFeatureEnabled = isResourceSharingFeatureEnabled(settings); + if (isResourceSharingFeatureEnabled) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, ActionListener.wrap(isAuthorized -> { + if (!isAuthorized) { + listener + .onFailure( + new OpenSearchStatusException( + "User " + user.getName() + " is not authorized to delete ml-model-group id: " + modelGroupId, + RestStatus.FORBIDDEN + ) + ); + return; + } + listener.onResponse(true); + }, listener::onFailure)); + return; + } + if (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)) { listener.onResponse(true); return; } @@ -132,11 +171,32 @@ public void validateModelGroupAccess( String modelGroupId, Client client, SdkClient sdkClient, + Settings settings, ActionListener<Boolean> listener ) { - if (modelGroupId == null - || (!mlFeatureEnabledSetting.isMultiTenancyEnabled() - && (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user)))) { + if (modelGroupId == null) { + listener.onResponse(true); + return; + } + boolean isResourceSharingFeatureEnabled = isResourceSharingFeatureEnabled(settings); + if (isResourceSharingFeatureEnabled) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + resourceSharingClient.verifyAccess(modelGroupId, ML_MODEL_GROUP_INDEX, ActionListener.wrap(isAuthorized -> { + if (!isAuthorized) { + listener + .onFailure( + new OpenSearchStatusException( + "User " + user.getName() + " is not authorized to delete ml-model-group id: " + modelGroupId, + RestStatus.FORBIDDEN + ) + ); + return; + } + listener.onResponse(true); + }, listener::onFailure)); + return; + } + if (!mlFeatureEnabledSetting.isMultiTenancyEnabled() && (isAdmin(user) || !isSecurityEnabledAndModelAccessControlEnabled(user))) { listener.onResponse(true); return; } @@ -313,7 +373,30 @@ public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuil return searchSourceBuilder; } - public SearchSourceBuilder createSearchSourceBuilder(User user) { + public SearchSourceBuilder createSearchSourceBuilder(User user, Settings settings) { + boolean isResourceSharingFeatureEnabled = isResourceSharingFeatureEnabled(settings); + // TODO: Remove this feature flag check once feature is GA, as it will be enabled by default + if (isResourceSharingFeatureEnabled) { + return addAccessibleModelGroupsFilter(new SearchSourceBuilder()); + } return addUserBackendRolesFilter(user, new SearchSourceBuilder()); } + + public SearchSourceBuilder addAccessibleModelGroupsFilter(SearchSourceBuilder searchSourceBuilder) { + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + + resourceSharingClient.getAccessibleResourceIds(ML_MODEL_GROUP_INDEX, ActionListener.wrap(modelGroupIds -> { + if (modelGroupIds.isEmpty()) { + // User has no access → return nothing + searchSourceBuilder.query(QueryBuilders.boolQuery().mustNot(QueryBuilders.matchAllQuery())); + } else { + // Restrict search strictly to these ids + searchSourceBuilder.query(QueryBuilders.termsQuery(MLModelGroup.MODEL_GROUP_ID_FIELD + ".keyword", modelGroupIds)); + } + }, failure -> { + // do nothing to the source or return empty set? + searchSourceBuilder.query(QueryBuilders.boolQuery().mustNot(QueryBuilders.matchAllQuery())); + })); + return searchSourceBuilder; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index 7e264d3347..1ba0fbe568 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -7,9 +7,16 @@ import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED; +import static org.opensearch.security.spi.resources.FeatureConfigConstants.OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT; +import static org.opensearch.security.spi.resources.ResourceAccessLevels.PLACE_HOLDER; import java.time.Instant; import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; @@ -19,6 +26,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.commons.authuser.User; @@ -34,6 +42,7 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; @@ -48,6 +57,10 @@ import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; +import org.opensearch.security.spi.resources.sharing.Recipient; +import org.opensearch.security.spi.resources.sharing.Recipients; +import org.opensearch.security.spi.resources.sharing.ShareWith; import org.opensearch.transport.client.Client; import lombok.extern.log4j.Log4j2; @@ -56,6 +69,7 @@ public class MLModelGroupManager { private final MLIndicesHandler mlIndicesHandler; private final Client client; + private final Settings settings; private final SdkClient sdkClient; ClusterService clusterService; @@ -66,6 +80,7 @@ public class MLModelGroupManager { public MLModelGroupManager( MLIndicesHandler mlIndicesHandler, Client client, + Settings settings, SdkClient sdkClient, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, @@ -73,6 +88,7 @@ public MLModelGroupManager( ) { this.mlIndicesHandler = mlIndicesHandler; this.client = client; + this.settings = settings; this.sdkClient = sdkClient; this.clusterService = clusterService; this.modelAccessControlHelper = modelAccessControlHelper; @@ -83,6 +99,11 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str try { String modelName = input.getName(); User user = RestActionUtils.getUserContext(client); + // Create a recipient sharing list + AtomicReference<Map<Recipient, Set<String>>> recipientMap = new AtomicReference<>(); + boolean isResourceSharingFeatureEnabled = ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED.get(settings) + && this.settings.getAsBoolean(OPENSEARCH_RESOURCE_SHARING_ENABLED, OPENSEARCH_RESOURCE_SHARING_ENABLED_DEFAULT); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener<String> wrappedListener = ActionListener.runBefore(listener, context::restore); validateUniqueModelGroupName(input.getName(), input.getTenantId(), ActionListener.wrap(modelGroups -> { @@ -101,12 +122,22 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str } else { MLModelGroup.MLModelGroupBuilder builder = MLModelGroup.builder(); MLModelGroup mlModelGroup; + + // TODO: Remove security-related entries from MLModelGroup builder if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { validateRequestForAccessControl(input, user); builder = builder.access(input.getModelAccessMode().getValue()); + if (Boolean.TRUE.equals(input.getIsAddAllBackendRoles())) { input.setBackendRoles(user.getBackendRoles()); + // share with current user's backend-roles + // TODO: check if resource should be shared with user's backend roles by default + recipientMap.set(Map.of(Recipient.BACKEND_ROLES, Set.copyOf(user.getBackendRoles()))); + } else { + // set to specified backend_roles + recipientMap.set(Map.of(Recipient.BACKEND_ROLES, Set.copyOf(input.getBackendRoles()))); } + mlModelGroup = builder .name(modelName) .description(input.getDescription()) @@ -118,6 +149,14 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str .build(); } else { validateSecurityDisabledOrModelAccessControlDisabled(input); + + // TODO: Check if following line is actually required since by default the model will be pass-through when sec + // is disabled + recipientMap + .set( + Map.of(Recipient.USERS, Set.of("*"), Recipient.ROLES, Set.of("*"), Recipient.BACKEND_ROLES, Set.of("*")) + ); + mlModelGroup = builder .name(modelName) .description(input.getDescription()) @@ -152,7 +191,40 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str indexResponse.getResult(), indexResponse.getId() ); - wrappedListener.onResponse(r.id()); + + // TODO: Remove this feature flag check once feature is GA, as it will be enabled by default + if (isResourceSharingFeatureEnabled) { + // Create an entry in resource-sharing index + String modelGroupId = indexResponse.getId(); + String modelGroupIndex = indexResponse.getIndex(); + ShareWith shareWith = new ShareWith( + Map.of(PLACE_HOLDER, new Recipients(recipientMap.get())) + ); + + ResourceSharingClient resourceSharingClient = ResourceSharingClientAccessor + .getInstance() + .getResourceSharingClient(); + + resourceSharingClient + .share( + modelGroupId, + modelGroupIndex, + shareWith, + ActionListener.wrap(resourceSharing -> { + log + .debug( + "Successfully shared ml-model-group: {} with entities: {}", + modelName, + recipientMap + ); + + wrappedListener.onResponse(r.id()); + }, listener::onFailure) + ); + } else { + wrappedListener.onResponse(r.id()); + } + } catch (Exception e) { wrappedListener.onFailure(e); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index dac65fb845..2ee5902a24 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -650,7 +650,7 @@ public Collection<Object> createComponents( mlFeatureEnabledSetting ); - mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper); + mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, settings, xContentRegistry, modelAccessControlHelper); MLTaskDispatcher mlTaskDispatcher = new MLTaskDispatcher(clusterService, client, settings, nodeHelper); mlTrainingTaskRunner = new MLTrainingTaskRunner( @@ -762,7 +762,7 @@ public Collection<Object> createComponents( MetricsCorrelation metricsCorrelation = new MetricsCorrelation(client, settings, clusterService); MLEngineClassLoader.register(FunctionName.METRICS_CORRELATION, metricsCorrelation); - MLSearchHandler mlSearchHandler = new MLSearchHandler(client, xContentRegistry, modelAccessControlHelper, clusterService); + MLSearchHandler mlSearchHandler = new MLSearchHandler(client, xContentRegistry, modelAccessControlHelper, clusterService, settings); MLModelAutoReDeployer mlModelAutoRedeployer = new MLModelAutoReDeployer( clusterService, client, diff --git a/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java b/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java new file mode 100644 index 0000000000..13405ead58 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.resources; + +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; + +import java.util.Set; + +import org.opensearch.ml.common.MLModelGroup; +import org.opensearch.ml.common.ResourceSharingClientAccessor; +import org.opensearch.security.spi.resources.ResourceProvider; +import org.opensearch.security.spi.resources.ResourceSharingExtension; +import org.opensearch.security.spi.resources.client.ResourceSharingClient; + +public class MLResourceSharingExtension implements ResourceSharingExtension { + @Override + public Set<ResourceProvider> getResourceProviders() { + return Set.of(new ResourceProvider(MLModelGroup.class.getCanonicalName(), ML_MODEL_GROUP_INDEX)); + } + + @Override + public void assignResourceSharingClient(ResourceSharingClient resourceSharingClient) { + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(resourceSharingClient); + } +} diff --git a/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension b/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension new file mode 100644 index 0000000000..00dfe98b5a --- /dev/null +++ b/plugin/src/main/resources/META-INF/services/org.opensearch.security.spi.resources.ResourceSharingExtension @@ -0,0 +1 @@ +org.opensearch.ml.resources.MLResourceSharingExtension diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java index 61c0282ac2..68238b30bc 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java @@ -149,6 +149,7 @@ public void setup() throws IOException { actionFilters, mlIndicesHandler, client, + Settings.EMPTY, clusterService, modelAccessControlHelper, mlModelCacheHelper, @@ -171,7 +172,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<MLModel> listener = invocation.getArgument(3); @@ -236,7 +237,7 @@ public void testCreateControllerWithModelAccessControlNoPermission() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -253,7 +254,7 @@ public void testCreateControllerWithModelAccessControlOtherException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java index 52bbfdad3d..403a7e806d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java @@ -147,6 +147,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + Settings.EMPTY, xContentRegistry, clusterService, mlModelManager, @@ -162,7 +163,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<MLModel> listener = invocation.getArgument(3); @@ -216,7 +217,7 @@ public void testDeleteControllerWithModelAccessControlNoPermission() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -241,7 +242,7 @@ public void testDeleteControllerWithModelAccessControlNoPermissionHiddenModel() ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -258,7 +259,7 @@ public void testDeleteControllerWithModelAccessControlOtherException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -283,7 +284,7 @@ public void testDeleteControllerWithModelAccessControlOtherExceptionHiddenModel( new RuntimeException("Permission denied: Unable to delete the model controller with the provided model. Details: ") ); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java index f414572b02..462a6fe739 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java @@ -103,6 +103,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, xContentRegistry, clusterService, mlModelManager, @@ -122,7 +123,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareControllerGetResponse(); doAnswer(invocation -> { @@ -160,7 +161,7 @@ public void testGetControllerWithModelAccessControlNoPermission() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -177,7 +178,7 @@ public void testGetControllerWithModelAccessControlOtherException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java index 2bdef9c022..13c958d82d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java @@ -151,6 +151,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + Settings.EMPTY, clusterService, modelAccessControlHelper, mlModelCacheHelper, @@ -181,7 +182,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<MLModel> listener = invocation.getArgument(3); @@ -246,7 +247,7 @@ public void testUpdateControllerWithModelAccessControlNoPermission() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -271,7 +272,7 @@ public void testUpdateControllerWithModelAccessControlNoPermissionHiddenModel() ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -288,7 +289,7 @@ public void testUpdateControllerWithModelAccessControlOtherException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -310,7 +311,7 @@ public void testUpdateControllerWithModelAccessControlOtherExceptionHiddenModel( ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Permission denied: Unable to create the model controller for the model. Details: ")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index deb8c054af..f618b85b0a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -177,7 +177,7 @@ public void setup() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true); @@ -358,7 +358,7 @@ public void testDoExecute_userHasNoAccessException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); ActionListener<MLDeployModelResponse> deployModelResponseListener = mock(ActionListener.class); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); @@ -414,7 +414,7 @@ public void test_ValidationFailedException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); ActionListener<MLDeployModelResponse> deployModelResponseListener = mock(ActionListener.class); transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java index 0bf67454b9..83cf3a12b5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportActionTests.java @@ -106,6 +106,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, sdkClient, xContentRegistry, clusterService, @@ -118,7 +119,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -229,7 +230,7 @@ public void test_UserHasNoAccessException() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -243,7 +244,7 @@ public void test_ValidationFailedException() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); deleteModelGroupTransportAction.doExecute(null, mlModelGroupDeleteRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java index aa2ceb20ce..1a54ee2faf 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/GetModelGroupTransportActionTests.java @@ -97,6 +97,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, sdkClient, xContentRegistry, clusterService, @@ -109,7 +110,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -135,7 +136,7 @@ public void testGetModel_UserHasNoAccess() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModelGroup(); doAnswer(invocation -> { @@ -155,7 +156,7 @@ public void testGetModel_ValidateAccessFailed() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModelGroup(); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java index f6aac6ec92..f97a4b622e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportActionTests.java @@ -97,6 +97,7 @@ public void setup() { transportService, actionFilters, client, + Settings.EMPTY, sdkClient, clusterService, modelAccessControlHelper, diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index c62716d793..839dfa9d55 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -120,6 +120,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, sdkClient, xContentRegistry, clusterService, diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index 2f91a1837c..fb4c420183 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -190,7 +190,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(clusterService.getSettings()).thenReturn(settings); @@ -420,7 +420,7 @@ public void test_UserHasNoAccessException() throws IOException, InterruptedExcep ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); @@ -501,7 +501,7 @@ public void test_ValidationFailedException() throws IOException, InterruptedExce ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java index e534e26505..9b0cc019a4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/GetModelTransportActionTests.java @@ -117,7 +117,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); @@ -137,7 +137,7 @@ public void testGetModel_UserHasNodeAccess() throws IOException, InterruptedExce ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); getModelTransportAction.doExecute(null, mlModelGetRequest, actionListener); @@ -199,7 +199,7 @@ public void testGetModel_ValidateAccessFailed() throws IOException, InterruptedE ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLModel(false); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index d1a0279bb6..68999b0a50 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -114,7 +114,7 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { public void setup() { MockitoAnnotations.openMocks(this); sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService)); + mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService, Settings.EMPTY)); searchModelTransportAction = new SearchModelTransportAction( transportService, actionFilters, @@ -184,7 +184,7 @@ public void test_DoExecute_addBackendRoles() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); @@ -196,7 +196,7 @@ public void test_DoExecute_addBackendRoles_without_groupIds() { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); @@ -208,7 +208,7 @@ public void test_DoExecute_addBackendRoles_exception() { listener.onFailure(new RuntimeException("runtime exception")); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(1)).search(any(), any()); @@ -281,7 +281,7 @@ public void test_DoExecute_addBackendRoles_boolQuery() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.matchQuery("name", "model_IT"))); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); @@ -295,7 +295,7 @@ public void test_DoExecute_addBackendRoles_termQuery() throws IOException { listener.onResponse(searchResponse); return null; }).when(client).search(any(), isA(ActionListener.class)); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); @@ -330,7 +330,7 @@ public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws In return null; }).when(client).search(any(), any()); - when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); + when(modelAccessControlHelper.createSearchSourceBuilder(any(), Settings.EMPTY)).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); mlSearchActionRequest = new MLSearchActionRequest(searchRequest, "123456"); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 6d7e3ea9a9..cc4d88deb5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -296,7 +296,9 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), eq("test_model_group_id"), any(), isA(ActionListener.class)); + }) + .when(modelAccessControlHelper) + .validateModelGroupAccess(any(), eq("test_model_group_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener<Boolean> listener = invocation.getArgument(6); @@ -311,6 +313,7 @@ public void setup() throws IOException { eq("test_model_group_id"), any(), any(SdkClient.class), + any(), isA(ActionListener.class) ); @@ -321,7 +324,7 @@ public void setup() throws IOException { return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener<Boolean> listener = invocation.getArgument(6); @@ -336,6 +339,7 @@ public void setup() throws IOException { eq("updated_test_model_group_id"), any(), any(SdkClient.class), + any(), isA(ActionListener.class) ); @@ -602,7 +606,7 @@ public void testUpdateModelWithModelAccessControlNoPermission() throws Interrupt return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), any(), any(), any(), any(), any(SdkClient.class), isA(ActionListener.class)); + .validateModelGroupAccess(any(), any(), any(), any(), any(), any(SdkClient.class), any(), isA(ActionListener.class)); CountDownLatch latch = new CountDownLatch(1); LatchedActionListener<UpdateResponse> latchedActionListener = new LatchedActionListener<>(actionListener, latch); @@ -628,7 +632,7 @@ public void testUpdateModelWithModelAccessControlOtherException() { ) ); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), isA(ActionListener.class)); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -647,7 +651,7 @@ public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlNoPermis return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -671,7 +675,7 @@ public void testUpdateModelWithRegisterToNewModelGroupModelAccessControlOtherExc return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), isA(ActionListener.class)); + .validateModelGroupAccess(any(), eq("updated_test_model_group_id"), any(), any(), isA(ActionListener.class)); transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -831,7 +835,16 @@ public void testUpdateRequestDocInRegisterToNewModelGroupIOException() throws IO return null; }) .when(modelAccessControlHelper) - .validateModelGroupAccess(any(), any(), any(), eq("mockUpdateModelGroupId"), any(), eq(sdkClient), isA(ActionListener.class)); + .validateModelGroupAccess( + any(), + any(), + any(), + eq("mockUpdateModelGroupId"), + any(), + eq(sdkClient), + any(), + isA(ActionListener.class) + ); MLModelGroup modelGroup = MLModelGroup .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index 9b1036f731..2539afa897 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -170,7 +170,7 @@ public void testPrediction_default_exception() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null); @@ -209,7 +209,7 @@ public void testPrediction_OpenSearchStatusException() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onFailure(new OpenSearchStatusException("Testing OpenSearchStatusException", RestStatus.BAD_REQUEST)); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null); @@ -232,7 +232,7 @@ public void testPrediction_MLResourceNotFoundException() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onFailure(new MLResourceNotFoundException("Testing MLResourceNotFoundException")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null); @@ -255,7 +255,7 @@ public void testPrediction_MLLimitExceededException() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT)); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index b0e290a693..dcff39c59c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -212,7 +212,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLStat mlStat = mock(MLStat.class); when(mlStats.getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT))).thenReturn(mlStat); @@ -292,7 +292,7 @@ public void testDoExecute_userHasNoAccessException() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("test url", "testModelGroupsID"), actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -456,7 +456,7 @@ public void test_ValidationFailedException() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("http://test_url", "modelGroupID"), actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -706,7 +706,7 @@ public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLRegisterModelInput registerModelInput = MLRegisterModelInput .builder() @@ -754,7 +754,7 @@ public void test_NoAccessWhenModelNameAlreadyExists() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); transportRegisterModelAction.doExecute(task, prepareRequest("Test URL", null), actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java index 88ae44c70b..1d78899439 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java @@ -140,6 +140,7 @@ public void setup() throws IOException { transportService, actionFilters, client, + settings, xContentRegistry, clusterService, scriptService, @@ -188,7 +189,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<Connector> listener = invocation.getArgument(2); @@ -289,7 +290,7 @@ public void test_BatchPredictCancel_NoModelGroupAccess() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index ca511306a4..616eedb0e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -247,7 +247,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<Connector> listener = invocation.getArgument(2); @@ -337,7 +337,7 @@ public void test_BatchPredictStatus_NoModelGroupAccess() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -362,7 +362,7 @@ public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); @@ -391,7 +391,7 @@ public void test_BatchPredictStatus_NoConnectorFound() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<Connector> listener = invocation.getArgument(2); @@ -422,7 +422,7 @@ public void test_BatchPredictStatus_NoModel() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<Connector> listener = invocation.getArgument(2); diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 72af11264d..55dfebd100 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -448,7 +448,7 @@ public void testDoExecute() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); List<MLUndeployModelNodeResponse> responseList = new ArrayList<>(); List<FailedNodeException> failuresList = new ArrayList<>(); @@ -479,7 +479,7 @@ public void testDoExecute_modelAccessControl_notEnabled() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); MLUndeployModelsResponse mlUndeployModelsResponse = new MLUndeployModelsResponse(mock(MLUndeployModelNodesResponse.class)); doAnswer(invocation -> { @@ -497,7 +497,7 @@ public void testDoExecute_validate_false() { ActionListener<Boolean> listener = invocation.getArgument(6); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<MLUndeployModelsResponse> listener = invocation.getArgument(2); diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java index 8375ae5fca..f6730459ad 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploaderTests.java @@ -95,7 +95,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<IndexResponse> listener = invocation.getArgument(1); @@ -117,7 +117,7 @@ public void setup() throws IOException { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); - mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, xContentRegistry, modelAccessControlHelper); + mlModelChunkUploader = new MLModelChunkUploader(mlIndicesHandler, client, settings, xContentRegistry, modelAccessControlHelper); MLModel mlModel = MLModel .builder() @@ -184,7 +184,7 @@ public void testDoExecute_userHasNoAccessException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); MLUploadModelChunkInput uploadModelChunkInput = prepareRequest(); uploadModelChunkInput.setChunkNumber(1); diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java index 7c04103a0e..2597bdbf1f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java @@ -86,6 +86,7 @@ public void setup() throws IOException { actionFilters, mlModelManager, client, + settings, modelAccessControlHelper, mlModelGroupManager ); @@ -94,7 +95,7 @@ public void setup() throws IOException { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(true); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); doAnswer(invocation -> { ActionListener<String> listener = invocation.getArgument(1); @@ -163,7 +164,7 @@ public void testDoExecute_userHasNoAccessException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); @@ -180,7 +181,7 @@ public void test_ValidationFailedException() { ActionListener<Boolean> listener = invocation.getArgument(3); listener.onFailure(new Exception("Failed to validate access")); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); @@ -213,7 +214,7 @@ public void testDoExecute_NoAccessWhenModelNameAlreadyExists() throws IOExceptio ActionListener<Boolean> listener = invocation.getArgument(3); listener.onResponse(false); return null; - }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any(), any()); SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java index 5083211d91..e31d556e87 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java @@ -114,14 +114,15 @@ public void setupModelGroup(String owner, String access, List<String> backendRol // TODO Remove when all calls are migrated to SdkClient version public void test_UndefinedModelGroupID_NoSdkClient() { - modelAccessControlHelper.validateModelGroupAccess(null, null, client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(null, null, client, Settings.EMPTY, actionListener); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); } public void test_UndefinedModelGroupID() { - modelAccessControlHelper.validateModelGroupAccess(null, mlFeatureEnabledSetting, null, null, client, sdkClient, actionListener); + modelAccessControlHelper + .validateModelGroupAccess(null, mlFeatureEnabledSetting, null, null, client, sdkClient, Settings.EMPTY, actionListener); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -130,7 +131,7 @@ public void test_UndefinedModelGroupID() { // TODO Remove when all calls are migrated to SdkClient version public void test_UndefinedOwner_NoSdkClient() throws IOException { getResponse = modelGroupBuilder(null, null, null); - modelAccessControlHelper.validateModelGroupAccess(null, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(null, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -139,7 +140,16 @@ public void test_UndefinedOwner_NoSdkClient() throws IOException { public void test_UndefinedOwner() throws IOException { getResponse = modelGroupBuilder(null, null, null); modelAccessControlHelper - .validateModelGroupAccess(null, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, actionListener); + .validateModelGroupAccess( + null, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + actionListener + ); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -150,7 +160,7 @@ public void test_ExceptionEmptyBackendRoles_NoSdkClient() throws IOException { String owner = "owner|IT,HR|myTenant"; User user = User.parse("owner|IT,HR|myTenant"); getResponse = modelGroupBuilder(null, AccessMode.RESTRICTED.getValue(), owner); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Backend roles shouldn't be null", argumentCaptor.getValue().getMessage()); @@ -168,7 +178,16 @@ public void test_ExceptionEmptyBackendRoles() throws IOException, InterruptedExc CountDownLatch latch = new CountDownLatch(1); LatchedActionListener<Boolean> latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -182,7 +201,7 @@ public void test_MatchingBackendRoles_NoSdkClient() throws IOException { List<String> backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.RESTRICTED.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -201,7 +220,16 @@ public void test_MatchingBackendRoles() throws IOException, InterruptedException CountDownLatch latch = new CountDownLatch(1); LatchedActionListener<Boolean> latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -215,7 +243,7 @@ public void test_PublicModelGroup_NoSdkClient() throws IOException { List<String> backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PUBLIC.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -234,7 +262,16 @@ public void test_PublicModelGroup() throws IOException, InterruptedException { CountDownLatch latch = new CountDownLatch(1); LatchedActionListener<Boolean> latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -248,7 +285,7 @@ public void test_PrivateModelGroupWithSameOwner_NoSdkClient() throws IOException List<String> backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); User user = User.parse("owner|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertTrue(argumentCaptor.getValue()); @@ -267,7 +304,16 @@ public void test_PrivateModelGroupWithSameOwner() throws IOException, Interrupte CountDownLatch latch = new CountDownLatch(1); LatchedActionListener<Boolean> latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -281,7 +327,7 @@ public void test_PrivateModelGroupWithDifferentOwner_NoSdkClient() throws IOExce List<String> backendRoles = Arrays.asList("IT", "HR"); setupModelGroup(owner, AccessMode.PRIVATE.getValue(), backendRoles); User user = User.parse("user|IT,HR|myTenant"); - modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, actionListener); + modelAccessControlHelper.validateModelGroupAccess(user, "testGroupID", client, Settings.EMPTY, actionListener); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); verify(actionListener).onResponse(argumentCaptor.capture()); assertFalse(argumentCaptor.getValue()); @@ -300,7 +346,16 @@ public void test_PrivateModelGroupWithDifferentOwner() throws IOException, Inter CountDownLatch latch = new CountDownLatch(1); LatchedActionListener<Boolean> latchedActionListener = new LatchedActionListener<>(actionListener, latch); modelAccessControlHelper - .validateModelGroupAccess(user, mlFeatureEnabledSetting, null, "testGroupID", client, sdkClient, latchedActionListener); + .validateModelGroupAccess( + user, + mlFeatureEnabledSetting, + null, + "testGroupID", + client, + sdkClient, + Settings.EMPTY, + latchedActionListener + ); latch.await(500, TimeUnit.MILLISECONDS); ArgumentCaptor<Boolean> argumentCaptor = ArgumentCaptor.forClass(Boolean.class); @@ -415,7 +470,7 @@ public void test_AddUserBackendRolesFilter() { public void test_CreateSearchSourceBuilder() { User user = User.parse("owner|IT,HR|myTenant"); - assertNotNull(modelAccessControlHelper.createSearchSourceBuilder(user)); + assertNotNull(modelAccessControlHelper.createSearchSourceBuilder(user, Settings.EMPTY)); } private GetResponse modelGroupBuilder(List<String> backendRoles, String access, String owner) throws IOException { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 36ecd569b0..03dd82524f 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -112,6 +112,7 @@ public void setup() throws IOException { mlModelGroupManager = new MLModelGroupManager( mlIndicesHandler, client, + settings, sdkClient, clusterService, modelAccessControlHelper,