Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/129090.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 129090
summary: Enable force inference endpoint deleting for invalid models and after stopping
model deployment fails
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
Expand Down Expand Up @@ -128,10 +129,38 @@ private void doExecuteForked(
}

var service = serviceRegistry.getService(unparsedModel.service());
Model model;
if (service.isPresent()) {
var model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
service.get().stop(model, listener);
try {
model = service.get()
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
} catch (Exception e) {
if (request.isForceDelete()) {
listener.onResponse(true);
return;
} else {
listener.onFailure(
new ElasticsearchStatusException(
Strings.format(
"Failed to parse model configuration for inference endpoint [%s]",
request.getInferenceEndpointId()
),
RestStatus.INTERNAL_SERVER_ERROR,
e
)
);
return;
}
}
service.get().stop(model, listener.delegateResponse((l, e) -> {
if (request.isForceDelete()) {
l.onResponse(true);
} else {
l.onFailure(e);
}
}));
} else if (request.isForceDelete()) {
listener.onResponse(true);
} else {
listener.onFailure(
new ElasticsearchStatusException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -32,11 +34,17 @@
import java.util.Optional;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
Expand Down Expand Up @@ -130,4 +138,213 @@ public void testDeletesDefaultEndpoint_WhenForceIsTrue() {

assertTrue(response.isAcknowledged());
}

public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() {
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
var serviceName = randomAlphanumericOfLength(10);
var taskType = randomFrom(TaskType.values());
var mockService = mock(InferenceService.class);
mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);

var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
action.masterOperation(
mock(Task.class),
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
ClusterState.EMPTY_STATE,
listener
);

var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Failed to parse model configuration for inference endpoint"));

verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
}

public void testDeletesUnparsableEndpoint_WhenForceIsTrue() {
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
var serviceName = randomAlphanumericOfLength(10);
var taskType = randomFrom(TaskType.values());
var mockService = mock(InferenceService.class);
mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
doAnswer(invocationOnMock -> {
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
listener.onResponse(true);
return Void.TYPE;
}).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());

var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();

action.masterOperation(
mock(Task.class),
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
ClusterState.EMPTY_STATE,
listener
);

var response = listener.actionGet(TIMEOUT);
assertTrue(response.isAcknowledged());

verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
}

private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) {
doAnswer(invocationOnMock -> {
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
return Void.TYPE;
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
doThrow(new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR)).when(mockService)
.parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
}

public void testDeletesEndpointWithNoService_WhenForceIsTrue() {
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
var serviceName = randomAlphanumericOfLength(10);
var taskType = randomFrom(TaskType.values());
mockNoService(inferenceEndpointId, serviceName, taskType);
doAnswer(invocationOnMock -> {
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
listener.onResponse(true);
return Void.TYPE;
}).when(mockModelRegistry).deleteModel(anyString(), any());

var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();

action.masterOperation(
mock(Task.class),
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
ClusterState.EMPTY_STATE,
listener
);

var response = listener.actionGet(TIMEOUT);
assertTrue(response.isAcknowledged());
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
}

public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() {
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
var serviceName = randomAlphanumericOfLength(10);
var taskType = randomFrom(TaskType.values());
mockNoService(inferenceEndpointId, serviceName, taskType);
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);

var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();

action.masterOperation(
mock(Task.class),
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
ClusterState.EMPTY_STATE,
listener
);

var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("No service found for this inference endpoint"));
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
}

private void mockNoService(String inferenceEndpointId, String serviceName, TaskType taskType) {
doAnswer(invocationOnMock -> {
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
return Void.TYPE;
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.empty());
}

public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse() {
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
var serviceName = randomAlphanumericOfLength(10);
var taskType = randomFrom(TaskType.values());
var mockService = mock(InferenceService.class);
var mockModel = mock(Model.class);
mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);

var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
action.masterOperation(
mock(Task.class),
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
ClusterState.EMPTY_STATE,
listener
);

var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), containsString("Failed to stop model deployment"));
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
verify(mockService).stop(eq(mockModel), any());
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
}

public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() {
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
var serviceName = randomAlphanumericOfLength(10);
var taskType = randomFrom(TaskType.values());
var mockService = mock(InferenceService.class);
var mockModel = mock(Model.class);
mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
doAnswer(invocationOnMock -> {
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
listener.onResponse(true);
return Void.TYPE;
}).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());

var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
action.masterOperation(
mock(Task.class),
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
ClusterState.EMPTY_STATE,
listener
);

var response = listener.actionGet(TIMEOUT);
assertTrue(response.isAcknowledged());
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
verify(mockService).stop(eq(mockModel), any());
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
}

private void mockStopDeploymentFails(
String inferenceEndpointId,
String serviceName,
TaskType taskType,
InferenceService mockService,
Model mockModel
) {
doAnswer(invocationOnMock -> {
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
return Void.TYPE;
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
doAnswer(invocationOnMock -> {
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
listener.onFailure(new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR));
return Void.TYPE;
}).when(mockService).stop(eq(mockModel), any());
}

}
Loading