Skip to content

Commit 0fd6bae

Browse files
committed
Merge remote-tracking branch 'upstream/main' into index-pipeline-tracking
* upstream/main:
  Mark watcher NotMultiProjectCapable and replace deprecated multi-project methods (elastic#131313)
  Enable force inference endpoint deleting for invalid models and after stopping model deployment fails (elastic#129090)
  [ML] Remove SageMaker Elastic updates (elastic#131301)
  Refactor AsyncSearchErrorTraceIT to use assertBusy (elastic#131328)
2 parents f8b7217 + 38b7bfc commit 0fd6bae

File tree

25 files changed

+589
-233
lines changed

25 files changed

+589
-233
lines changed

docs/changelog/129090.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 129090
2+
summary: Enable force inference endpoint deleting for invalid models and after stopping
3+
model deployment fails
4+
area: Machine Learning
5+
type: enhancement
6+
issues: []

x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchErrorTraceIT.java

Lines changed: 84 additions & 86 deletions
Large diffs are not rendered by default.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.common.Strings;
2424
import org.elasticsearch.common.util.concurrent.EsExecutors;
2525
import org.elasticsearch.inference.InferenceServiceRegistry;
26+
import org.elasticsearch.inference.Model;
2627
import org.elasticsearch.inference.UnparsedModel;
2728
import org.elasticsearch.injection.guice.Inject;
2829
import org.elasticsearch.rest.RestStatus;
@@ -128,10 +129,38 @@ private void doExecuteForked(
128129
}
129130

130131
var service = serviceRegistry.getService(unparsedModel.service());
132+
Model model;
131133
if (service.isPresent()) {
132-
var model = service.get()
133-
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
134-
service.get().stop(model, listener);
134+
try {
135+
model = service.get()
136+
.parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
137+
} catch (Exception e) {
138+
if (request.isForceDelete()) {
139+
listener.onResponse(true);
140+
return;
141+
} else {
142+
listener.onFailure(
143+
new ElasticsearchStatusException(
144+
Strings.format(
145+
"Failed to parse model configuration for inference endpoint [%s]",
146+
request.getInferenceEndpointId()
147+
),
148+
RestStatus.INTERNAL_SERVER_ERROR,
149+
e
150+
)
151+
);
152+
return;
153+
}
154+
}
155+
service.get().stop(model, listener.delegateResponse((l, e) -> {
156+
if (request.isForceDelete()) {
157+
l.onResponse(true);
158+
} else {
159+
l.onFailure(e);
160+
}
161+
}));
162+
} else if (request.isForceDelete()) {
163+
listener.onResponse(true);
135164
} else {
136165
listener.onFailure(
137166
new ElasticsearchStatusException(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public SageMakerModel override(Map<String, Object> taskSettingsOverride) {
116116
getConfigurations(),
117117
getSecrets(),
118118
serviceSettings,
119-
taskSettings.updatedTaskSettings(taskSettingsOverride),
119+
taskSettings.override(taskSettingsOverride),
120120
awsSecretSettings
121121
);
122122
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,21 @@ public boolean isEmpty() {
7171
@Override
7272
public SageMakerTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
7373
var validationException = new ValidationException();
74-
7574
var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException);
75+
validationException.throwIfValidationErrorsExist();
76+
77+
return override(updateTaskSettings);
78+
}
7679

80+
public SageMakerTaskSettings override(Map<String, Object> newSettings) {
81+
var validationException = new ValidationException();
82+
var updateTaskSettings = fromMap(newSettings, apiTaskSettings.override(newSettings), validationException);
7783
validationException.throwIfValidationErrorsExist();
7884

85+
return override(updateTaskSettings);
86+
}
87+
88+
private SageMakerTaskSettings override(SageMakerTaskSettings updateTaskSettings) {
7989
var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP)
8090
? apiTaskSettings
8191
: updateTaskSettings.apiTaskSettings();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,8 @@ default boolean isFragment() {
6868

6969
@Override
7070
SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings);
71+
72+
default SageMakerStoredTaskSchema override(Map<String, Object> newSettings) {
73+
return updatedTaskSettings(newSettings);
74+
}
7175
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,6 @@ default SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest re
8888

8989
@Override
9090
default SageMakerElasticTaskSettings apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
91-
if (taskSettings != null && (taskSettings.isEmpty() == false)) {
92-
validationException.addValidationError(
93-
InferenceAction.Request.TASK_SETTINGS.getPreferredName()
94-
+ " is only supported during the inference request and cannot be stored in the inference endpoint."
95-
);
96-
}
9791
return SageMakerElasticTaskSettings.empty();
9892
}
9993

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
import org.elasticsearch.TransportVersion;
1111
import org.elasticsearch.TransportVersions;
12+
import org.elasticsearch.common.ValidationException;
1213
import org.elasticsearch.common.io.stream.StreamInput;
1314
import org.elasticsearch.common.io.stream.StreamOutput;
1415
import org.elasticsearch.core.Nullable;
1516
import org.elasticsearch.xcontent.XContentBuilder;
17+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1618
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
1719

1820
import java.io.IOException;
@@ -40,6 +42,16 @@ public boolean isEmpty() {
4042

4143
@Override
4244
public SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings) {
45+
var validationException = new ValidationException();
46+
validationException.addValidationError(
47+
InferenceAction.Request.TASK_SETTINGS.getPreferredName()
48+
+ " is only supported during the inference request and cannot be stored in the inference endpoint."
49+
);
50+
throw validationException;
51+
}
52+
53+
@Override
54+
public SageMakerStoredTaskSchema override(Map<String, Object> newSettings) {
4355
return new SageMakerElasticTaskSettings(newSettings);
4456
}
4557

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import org.elasticsearch.core.TimeValue;
1818
import org.elasticsearch.inference.InferenceService;
1919
import org.elasticsearch.inference.InferenceServiceRegistry;
20+
import org.elasticsearch.inference.Model;
2021
import org.elasticsearch.inference.TaskType;
2122
import org.elasticsearch.inference.UnparsedModel;
23+
import org.elasticsearch.rest.RestStatus;
2224
import org.elasticsearch.tasks.Task;
2325
import org.elasticsearch.test.ESTestCase;
2426
import org.elasticsearch.threadpool.ThreadPool;
@@ -32,11 +34,17 @@
3234
import java.util.Optional;
3335

3436
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
37+
import static org.hamcrest.Matchers.containsString;
3538
import static org.hamcrest.Matchers.is;
3639
import static org.mockito.ArgumentMatchers.any;
3740
import static org.mockito.ArgumentMatchers.anyString;
41+
import static org.mockito.ArgumentMatchers.eq;
3842
import static org.mockito.Mockito.doAnswer;
43+
import static org.mockito.Mockito.doReturn;
44+
import static org.mockito.Mockito.doThrow;
3945
import static org.mockito.Mockito.mock;
46+
import static org.mockito.Mockito.verify;
47+
import static org.mockito.Mockito.verifyNoMoreInteractions;
4048
import static org.mockito.Mockito.when;
4149

4250
public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
@@ -130,4 +138,213 @@ public void testDeletesDefaultEndpoint_WhenForceIsTrue() {
130138

131139
assertTrue(response.isAcknowledged());
132140
}
141+
142+
public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() {
143+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
144+
var serviceName = randomAlphanumericOfLength(10);
145+
var taskType = randomFrom(TaskType.values());
146+
var mockService = mock(InferenceService.class);
147+
mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
148+
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
149+
150+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
151+
action.masterOperation(
152+
mock(Task.class),
153+
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
154+
ClusterState.EMPTY_STATE,
155+
listener
156+
);
157+
158+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
159+
assertThat(exception.getMessage(), containsString("Failed to parse model configuration for inference endpoint"));
160+
161+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
162+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
163+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
164+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
165+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
166+
}
167+
168+
public void testDeletesUnparsableEndpoint_WhenForceIsTrue() {
169+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
170+
var serviceName = randomAlphanumericOfLength(10);
171+
var taskType = randomFrom(TaskType.values());
172+
var mockService = mock(InferenceService.class);
173+
mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
174+
doAnswer(invocationOnMock -> {
175+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
176+
listener.onResponse(true);
177+
return Void.TYPE;
178+
}).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
179+
180+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
181+
182+
action.masterOperation(
183+
mock(Task.class),
184+
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
185+
ClusterState.EMPTY_STATE,
186+
listener
187+
);
188+
189+
var response = listener.actionGet(TIMEOUT);
190+
assertTrue(response.isAcknowledged());
191+
192+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
193+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
194+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
195+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
196+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
197+
}
198+
199+
private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) {
200+
doAnswer(invocationOnMock -> {
201+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
202+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
203+
return Void.TYPE;
204+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
205+
doThrow(new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR)).when(mockService)
206+
.parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
207+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
208+
}
209+
210+
public void testDeletesEndpointWithNoService_WhenForceIsTrue() {
211+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
212+
var serviceName = randomAlphanumericOfLength(10);
213+
var taskType = randomFrom(TaskType.values());
214+
mockNoService(inferenceEndpointId, serviceName, taskType);
215+
doAnswer(invocationOnMock -> {
216+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
217+
listener.onResponse(true);
218+
return Void.TYPE;
219+
}).when(mockModelRegistry).deleteModel(anyString(), any());
220+
221+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
222+
223+
action.masterOperation(
224+
mock(Task.class),
225+
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
226+
ClusterState.EMPTY_STATE,
227+
listener
228+
);
229+
230+
var response = listener.actionGet(TIMEOUT);
231+
assertTrue(response.isAcknowledged());
232+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
233+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
234+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
235+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
236+
}
237+
238+
public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() {
239+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
240+
var serviceName = randomAlphanumericOfLength(10);
241+
var taskType = randomFrom(TaskType.values());
242+
mockNoService(inferenceEndpointId, serviceName, taskType);
243+
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
244+
245+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
246+
247+
action.masterOperation(
248+
mock(Task.class),
249+
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
250+
ClusterState.EMPTY_STATE,
251+
listener
252+
);
253+
254+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
255+
assertThat(exception.getMessage(), containsString("No service found for this inference endpoint"));
256+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
257+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
258+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
259+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
260+
}
261+
262+
private void mockNoService(String inferenceEndpointId, String serviceName, TaskType taskType) {
263+
doAnswer(invocationOnMock -> {
264+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
265+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
266+
return Void.TYPE;
267+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
268+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.empty());
269+
}
270+
271+
public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse() {
272+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
273+
var serviceName = randomAlphanumericOfLength(10);
274+
var taskType = randomFrom(TaskType.values());
275+
var mockService = mock(InferenceService.class);
276+
var mockModel = mock(Model.class);
277+
mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
278+
when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
279+
280+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
281+
action.masterOperation(
282+
mock(Task.class),
283+
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
284+
ClusterState.EMPTY_STATE,
285+
listener
286+
);
287+
288+
var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
289+
assertThat(exception.getMessage(), containsString("Failed to stop model deployment"));
290+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
291+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
292+
verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
293+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
294+
verify(mockService).stop(eq(mockModel), any());
295+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
296+
}
297+
298+
public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() {
299+
var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
300+
var serviceName = randomAlphanumericOfLength(10);
301+
var taskType = randomFrom(TaskType.values());
302+
var mockService = mock(InferenceService.class);
303+
var mockModel = mock(Model.class);
304+
mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
305+
doAnswer(invocationOnMock -> {
306+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
307+
listener.onResponse(true);
308+
return Void.TYPE;
309+
}).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
310+
311+
var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
312+
action.masterOperation(
313+
mock(Task.class),
314+
new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
315+
ClusterState.EMPTY_STATE,
316+
listener
317+
);
318+
319+
var response = listener.actionGet(TIMEOUT);
320+
assertTrue(response.isAcknowledged());
321+
verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
322+
verify(mockInferenceServiceRegistry).getService(eq(serviceName));
323+
verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
324+
verify(mockService).stop(eq(mockModel), any());
325+
verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
326+
verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
327+
}
328+
329+
private void mockStopDeploymentFails(
330+
String inferenceEndpointId,
331+
String serviceName,
332+
TaskType taskType,
333+
InferenceService mockService,
334+
Model mockModel
335+
) {
336+
doAnswer(invocationOnMock -> {
337+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
338+
listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
339+
return Void.TYPE;
340+
}).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
341+
when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
342+
doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
343+
doAnswer(invocationOnMock -> {
344+
ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
345+
listener.onFailure(new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR));
346+
return Void.TYPE;
347+
}).when(mockService).stop(eq(mockModel), any());
348+
}
349+
133350
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public final void testWithUnknownApiTaskSettings() {
119119
}
120120
}
121121

122-
public final void testUpdate() throws IOException {
122+
public void testUpdate() throws IOException {
123123
var taskSettings = randomApiTaskSettings();
124124
if (taskSettings != SageMakerStoredTaskSchema.NO_OP) {
125125
var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);

0 commit comments

Comments
 (0)