From 936150cb5986fc07d492de7535bb9f0bf2a749aa Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 10 Apr 2025 14:14:59 +0100 Subject: [PATCH 01/12] Basic v2 classes --- .../CohereCompletionRequestManager.java | 2 +- .../request/CohereCompletionRequest.java | 89 +++++++++++++++++ .../CohereCompletionRequestEntity.java | 2 +- .../CohereCompletionRequest.java | 3 +- .../v2/CohereCompletionRequestEntity.java | 47 +++++++++ .../v2/CohereEmbeddingsRequestEntity.java | 86 ++++++++++++++++ .../cohere/request/v2/CohereRequest.java | 27 +++++ .../request/v2/CohereRerankRequest.java | 99 +++++++++++++++++++ .../request/v2/CohereRerankRequestEntity.java | 79 +++++++++++++++ .../request/v2/CohereV2EmbeddingsRequest.java | 92 +++++++++++++++++ .../CohereCompletionRequestEntityTests.java | 1 - .../request/CohereCompletionRequestTests.java | 1 - 12 files changed, 522 insertions(+), 6 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequest.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/{completion => }/CohereCompletionRequestEntity.java (99%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/{completion => v2}/CohereCompletionRequest.java (96%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java index 2c6b4beb80c5b..3c63031f633a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java @@ -18,7 +18,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import org.elasticsearch.xpack.inference.services.cohere.request.completion.CohereCompletionRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereCompletionRequest; import org.elasticsearch.xpack.inference.services.cohere.response.CohereCompletionResponseEntity; import java.util.Objects; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequest.java new file mode 100644 index 0000000000000..b52601396570e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequest.java @@ -0,0 +1,89 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class CohereCompletionRequest extends CohereRequest { + private final CohereAccount account; + private final List input; + private final String modelId; + private final String inferenceEntityId; + private final boolean stream; + + public CohereCompletionRequest(List input, CohereCompletionModel model, boolean stream) { + Objects.requireNonNull(model); + + this.account = CohereAccount.of(model, CohereCompletionRequest::buildDefaultUri); + this.input = Objects.requireNonNull(input); + this.modelId = model.getServiceSettings().modelId(); + this.inferenceEntityId = model.getInferenceEntityId(); + this.stream = stream; + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new CohereCompletionRequestEntity(input, modelId, isStreaming())).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public boolean isStreaming() { + return stream; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + // no truncation + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // no truncation + return null; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntity.java similarity index 99% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntity.java index 7ab8d6753a0c4..5d61f19cfe683 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.cohere.request.completion; +package org.elasticsearch.xpack.inference.services.cohere.request; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequest.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequest.java index b477295afbc09..2b0d5dfa508a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequest.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.cohere.request.completion; +package org.elasticsearch.xpack.inference.services.cohere.request.v2; import org.apache.http.client.methods.HttpPost; import org.apache.http.client.utils.URIBuilder; @@ -15,7 +15,6 @@ import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import java.net.URI; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequestEntity.java new file mode 100644 index 0000000000000..6d4149792df5e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequestEntity.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record CohereCompletionRequestEntity(List input, @Nullable String model, boolean stream) implements ToXContentObject { + + private static final String MESSAGE_FIELD = "message"; + private static final String MODEL = "model"; + private static final String STREAM = "stream"; + + public CohereCompletionRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(input.get(0)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + // we only allow one input for completion, so always get the first one + builder.field(MESSAGE_FIELD, input.get(0)); + if (model != null) { + builder.field(MODEL, model); + } + + if (stream) { + builder.field(STREAM, true); + } + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..cd0fd41188f04 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereEmbeddingsRequestEntity.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; + +public record CohereEmbeddingsRequestEntity( + List input, + InputType inputType, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable String model, + @Nullable CohereEmbeddingType embeddingType +) implements ToXContentObject { + + private static final String SEARCH_DOCUMENT = "search_document"; + private static final String SEARCH_QUERY = "search_query"; + private static final String CLUSTERING = "clustering"; + private static final String CLASSIFICATION = "classification"; + private static final String TEXTS_FIELD = "texts"; + public static final String INPUT_TYPE_FIELD = "input_type"; + static final String EMBEDDING_TYPES_FIELD = "embedding_types"; + + public CohereEmbeddingsRequestEntity { + Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEXTS_FIELD, input); + if (model != null) { + builder.field(CohereServiceSettings.OLD_MODEL_ID_FIELD, model); + } + + // prefer the root level inputType over task settings input type + if (InputType.isSpecified(inputType)) { + builder.field(INPUT_TYPE_FIELD, convertToString(inputType)); + } else if (InputType.isSpecified(taskSettings.getInputType())) { + builder.field(INPUT_TYPE_FIELD, convertToString(taskSettings.getInputType())); + } + + if (embeddingType != null) { + builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); + } + + if (taskSettings.getTruncation() != null) { + builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); + } + + builder.endObject(); + return builder; + } + + // default for testing + public static String convertToString(InputType inputType) { + return switch (inputType) { + case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT; + case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY; + case CLASSIFICATION -> CLASSIFICATION; + case CLUSTERING -> CLUSTERING; + default -> { + assert false : invalidInputTypeMessage(inputType); + yield null; + } + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRequest.java new file mode 100644 index 0000000000000..6620cbc03b700 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRequest.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public abstract class CohereRequest implements Request { + + public static void decorateWithAuthHeader(HttpPost request, CohereAccount account) { + request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + request.setHeader(createAuthBearerHeader(account.apiKey())); + request.setHeader(CohereUtils.createRequestSourceHeader()); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequest.java new file mode 100644 index 0000000000000..acb0adec4478c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequest.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class CohereRerankRequest extends CohereRequest { + + private final CohereAccount account; + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final CohereRerankTaskSettings taskSettings; + private final String model; + private final String inferenceEntityId; + + public CohereRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankModel model + ) { + Objects.requireNonNull(model); + + this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri); + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; + taskSettings = model.getTaskSettings(); + this.model = model.getServiceSettings().modelId(); + inferenceEntityId = model.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model)) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + return this; // TODO? + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequestEntity.java new file mode 100644 index 0000000000000..f75fc96b6a964 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequestEntity.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record CohereRerankRequestEntity( + String model, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankTaskSettings taskSettings +) implements ToXContentObject { + + private static final String DOCUMENTS_FIELD = "documents"; + private static final String QUERY_FIELD = "query"; + private static final String MODEL_FIELD = "model"; + + public CohereRerankRequestEntity { + Objects.requireNonNull(query); + Objects.requireNonNull(documents); + Objects.requireNonNull(taskSettings); + } + + public CohereRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankTaskSettings taskSettings, + String model + ) { + this(model, query, input, returnDocuments, topN, taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_FIELD, model); + builder.field(QUERY_FIELD, query); + builder.field(DOCUMENTS_FIELD, documents); + + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments); + } else if (taskSettings.getDoesReturnDocuments() != null) { + builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); + } + + // prefer the root level top_n over task settings + if (topN != null) { + builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, topN); + } else if (taskSettings.getTopNDocumentsOnly() != null) { + builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); + } + + if (taskSettings.getMaxChunksPerDoc() != null) { + builder.field(CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, taskSettings.getMaxChunksPerDoc()); + } + + builder.endObject(); + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java new file mode 100644 index 0000000000000..78d67b5937ec2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class CohereV2EmbeddingsRequest extends CohereRequest { + + private final CohereAccount account; + private final List input; + private final InputType inputType; + private final CohereEmbeddingsTaskSettings taskSettings; + private final String model; + private final CohereEmbeddingType embeddingType; + private final String inferenceEntityId; + + public CohereV2EmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { + Objects.requireNonNull(embeddingsModel); + + account = CohereAccount.of(embeddingsModel, CohereV2EmbeddingsRequest::buildDefaultUri); + this.input = Objects.requireNonNull(input); + this.inputType = inputType; + taskSettings = embeddingsModel.getTaskSettings(); + model = embeddingsModel.getServiceSettings().getCommonSettings().modelId(); + embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); + inferenceEntityId = embeddingsModel.getInferenceEntityId(); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new CohereEmbeddingsRequestEntity(input, inputType, taskSettings, model, embeddingType)) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java index 39247b6e93e77..fa775c5b809b6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.services.cohere.request.completion.CohereCompletionRequestEntity; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java index 930480af50fb4..67e2b92def246 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; -import org.elasticsearch.xpack.inference.services.cohere.request.completion.CohereCompletionRequest; import java.io.IOException; import java.util.List; From 14cc105d8fcce6d46bdfa5990828de01dd525b2e Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 23 Apr 2025 12:06:57 +0100 Subject: [PATCH 02/12] Add v2 classes --- .../CohereCompletionRequestManager.java | 4 +- .../CohereEmbeddingsRequestManager.java | 4 +- .../cohere/CohereRerankRequestManager.java | 4 +- .../services/cohere/CohereService.java | 14 +- .../completion/CohereCompletionModel.java | 10 +- .../embeddings/CohereEmbeddingsModel.java | 9 +- .../request/CohereCompletionRequest.java | 89 ----- .../CohereCompletionRequestEntity.java | 47 --- .../request/CohereEmbeddingsRequest.java | 91 ----- .../CohereEmbeddingsRequestEntity.java | 86 ----- .../cohere/request/CohereRequest.java | 65 +++- .../cohere/request/CohereRerankRequest.java | 98 ------ .../services/cohere/request/CohereUtils.java | 30 ++ .../request/v1/CohereV1CompletionRequest.java | 58 ++++ .../request/v1/CohereV1EmbeddingsRequest.java | 82 +++++ .../CohereV1RerankRequest.java} | 65 ++-- .../request/v2/CohereCompletionRequest.java | 90 ----- .../v2/CohereCompletionRequestEntity.java | 47 --- .../v2/CohereEmbeddingsRequestEntity.java | 86 ----- .../cohere/request/v2/CohereRequest.java | 27 -- .../request/v2/CohereRerankRequest.java | 99 ------ .../request/v2/CohereV2CompletionRequest.java | 54 +++ .../request/v2/CohereV2EmbeddingsRequest.java | 79 ++--- .../CohereV2RerankRequest.java} | 65 ++-- .../cohere/rerank/CohereRerankModel.java | 12 +- .../rerank/CohereRerankServiceSettings.java | 4 - .../action/CohereEmbeddingsActionTests.java | 3 +- .../CohereCompletionModelTests.java | 12 +- .../CohereEmbeddingsModelTests.java | 7 - .../CohereCompletionRequestEntityTests.java | 52 --- .../CohereEmbeddingsRequestEntityTests.java | 130 ------- .../cohere/request/CohereRequestTests.java | 1 - .../cohere/request/CohereUtilsTests.java | 16 + .../v1/CohereV1CompletionRequestTests.java | 124 +++++++ .../CohereV1EmbeddingsRequestTests.java} | 123 ++++++- .../CohereV1RerankRequestEntityTests.java} | 41 +-- .../CohereV2CompletionRequestTests.java} | 57 +-- .../v2/CohereV2EmbeddingsRequestTests.java | 324 ++++++++++++++++++ .../v2/CohereV2RerankRequestEntityTests.java | 124 +++++++ 39 files changed, 1168 insertions(+), 1165 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequest.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequest.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/{v2/CohereRerankRequestEntity.java => v1/CohereV1RerankRequest.java} (53%) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequest.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereEmbeddingsRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRequest.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/{CohereRerankRequestEntity.java => v2/CohereV2RerankRequest.java} (50%) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/{CohereEmbeddingsRequestTests.java => v1/CohereV1EmbeddingsRequestTests.java} (71%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/{CohereRerankRequestEntityTests.java => v1/CohereV1RerankRequestEntityTests.java} (66%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/{CohereCompletionRequestTests.java => v2/CohereV2CompletionRequestTests.java} (52%) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestEntityTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java index 3c63031f633a6..535f1fdd95a3d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java @@ -18,7 +18,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereCompletionRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest; import org.elasticsearch.xpack.inference.services.cohere.response.CohereCompletionResponseEntity; import java.util.Objects; @@ -55,7 +55,7 @@ public void execute( var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); var inputs = chatCompletionInput.getInputs(); var stream = chatCompletionInput.stream(); - CohereCompletionRequest request = new CohereCompletionRequest(inputs, model, stream); + CohereV1CompletionRequest request = new CohereV1CompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java index e721c3e46cecf..09843ca1ce283 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java @@ -19,7 +19,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1EmbeddingsRequest; import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity; import java.util.List; @@ -56,7 +56,7 @@ public void execute( List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); - CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(docsInput, inputType, model); + CohereV1EmbeddingsRequest request = new CohereV1EmbeddingsRequest(docsInput, inputType, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java index 134aab77530e1..cbf606c0db678 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java @@ -17,7 +17,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereRerankRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1RerankRequest; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity; @@ -51,7 +51,7 @@ public void execute( ActionListener listener ) { var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - CohereRerankRequest request = new CohereRerankRequest( + CohereV1RerankRequest request = new CohereV1RerankRequest( rerankInput.getQuery(), rerankInput.getChunks(), rerankInput.getReturnDocuments(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index bf6a0bd03122b..b4b565ef70560 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -166,24 +166,14 @@ private static CohereModel createModel( return switch (taskType) { case TEXT_EMBEDDING -> new CohereEmbeddingsModel( inferenceEntityId, - taskType, - NAME, serviceSettings, taskSettings, chunkingSettings, secretSettings, context ); - case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); - case COMPLETION -> new CohereCompletionModel( - inferenceEntityId, - taskType, - NAME, - serviceSettings, - taskSettings, - secretSettings, - context - ); + case RERANK -> new CohereRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); + case COMPLETION -> new CohereCompletionModel(inferenceEntityId, serviceSettings, secretSettings, context); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java index a31a6ae290fea..3dfcdae778a50 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -26,17 +27,12 @@ public class CohereCompletionModel extends CohereModel { public CohereCompletionModel( String modelId, - TaskType taskType, - String service, Map serviceSettings, - Map taskSettings, @Nullable Map secrets, ConfigurationParseContext context ) { this( modelId, - taskType, - service, CohereCompletionServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, DefaultSecretSettings.fromMap(secrets) @@ -46,14 +42,12 @@ public CohereCompletionModel( // should only be used for testing CohereCompletionModel( String modelId, - TaskType taskType, - String service, CohereCompletionServiceSettings serviceSettings, TaskSettings taskSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(modelId, TaskType.COMPLETION, CohereService.NAME, serviceSettings, taskSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index 2edd365e66311..e8c2400ea4f1c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -29,8 +30,6 @@ public static CohereEmbeddingsModel of(CohereEmbeddingsModel model, Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, @@ -39,8 +38,6 @@ public CohereEmbeddingsModel( ) { this( inferenceId, - taskType, - service, CohereEmbeddingsServiceSettings.fromMap(serviceSettings, context), CohereEmbeddingsTaskSettings.fromMap(taskSettings), chunkingSettings, @@ -51,15 +48,13 @@ public CohereEmbeddingsModel( // should only be used for testing CohereEmbeddingsModel( String modelId, - TaskType taskType, - String service, CohereEmbeddingsServiceSettings serviceSettings, CohereEmbeddingsTaskSettings taskSettings, ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings), + new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, CohereService.NAME, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings.getCommonSettings() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequest.java deleted file mode 100644 index b52601396570e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequest.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -public class CohereCompletionRequest extends CohereRequest { - private final CohereAccount account; - private final List input; - private final String modelId; - private final String inferenceEntityId; - private final boolean stream; - - public CohereCompletionRequest(List input, CohereCompletionModel model, boolean stream) { - Objects.requireNonNull(model); - - this.account = CohereAccount.of(model, CohereCompletionRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.modelId = model.getServiceSettings().modelId(); - this.inferenceEntityId = model.getInferenceEntityId(); - this.stream = stream; - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereCompletionRequestEntity(input, modelId, isStreaming())).getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public boolean isStreaming() { - return stream; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - // no truncation - return this; - } - - @Override - public boolean[] getTruncationInfo() { - // no truncation - return null; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntity.java deleted file mode 100644 index 5d61f19cfe683..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntity.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public record CohereCompletionRequestEntity(List input, @Nullable String model, boolean stream) implements ToXContentObject { - - private static final String MESSAGE_FIELD = "message"; - private static final String MODEL = "model"; - private static final String STREAM = "stream"; - - public CohereCompletionRequestEntity { - Objects.requireNonNull(input); - Objects.requireNonNull(input.get(0)); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - // we only allow one input for completion, so always get the first one - builder.field(MESSAGE_FIELD, input.get(0)); - if (model != null) { - builder.field(MODEL, model); - } - - if (stream) { - builder.field(STREAM, true); - } - - builder.endObject(); - - return builder; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequest.java deleted file mode 100644 index 7ce218c3a8fe8..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequest.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -public class CohereEmbeddingsRequest extends CohereRequest { - - private final CohereAccount account; - private final List input; - private final InputType inputType; - private final CohereEmbeddingsTaskSettings taskSettings; - private final String model; - private final CohereEmbeddingType embeddingType; - private final String inferenceEntityId; - - public CohereEmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { - Objects.requireNonNull(embeddingsModel); - - account = CohereAccount.of(embeddingsModel, CohereEmbeddingsRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.inputType = inputType; - taskSettings = embeddingsModel.getTaskSettings(); - model = embeddingsModel.getServiceSettings().getCommonSettings().modelId(); - embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); - inferenceEntityId = embeddingsModel.getInferenceEntityId(); - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereEmbeddingsRequestEntity(input, inputType, taskSettings, model, embeddingType)) - .getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - return this; - } - - @Override - public boolean[] getTruncationInfo() { - return null; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntity.java deleted file mode 100644 index e4de77cd56edd..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntity.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; - -public record CohereEmbeddingsRequestEntity( - List input, - InputType inputType, - CohereEmbeddingsTaskSettings taskSettings, - @Nullable String model, - @Nullable CohereEmbeddingType embeddingType -) implements ToXContentObject { - - private static final String SEARCH_DOCUMENT = "search_document"; - private static final String SEARCH_QUERY = "search_query"; - private static final String CLUSTERING = "clustering"; - private static final String CLASSIFICATION = "classification"; - private static final String TEXTS_FIELD = "texts"; - public static final String INPUT_TYPE_FIELD = "input_type"; - static final String EMBEDDING_TYPES_FIELD = "embedding_types"; - - public CohereEmbeddingsRequestEntity { - Objects.requireNonNull(input); - Objects.requireNonNull(taskSettings); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(TEXTS_FIELD, input); - if (model != null) { - builder.field(CohereServiceSettings.OLD_MODEL_ID_FIELD, model); - } - - // prefer the root level inputType over task settings input type - if (InputType.isSpecified(inputType)) { - builder.field(INPUT_TYPE_FIELD, convertToString(inputType)); - } else if (InputType.isSpecified(taskSettings.getInputType())) { - builder.field(INPUT_TYPE_FIELD, convertToString(taskSettings.getInputType())); - } - - if (embeddingType != null) { - builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); - } - - if (taskSettings.getTruncation() != null) { - builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); - } - - builder.endObject(); - return builder; - } - - // default for testing - public static String convertToString(InputType inputType) { - return switch (inputType) { - case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT; - case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY; - case CLASSIFICATION -> CLASSIFICATION; - case CLUSTERING -> CLUSTERING; - default -> { - assert false : invalidInputTypeMessage(inputType); - yield null; - } - }; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java index fda1661d02472..f0d90f10945b3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java @@ -9,13 +9,22 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; -public abstract class CohereRequest implements Request { +public abstract class CohereRequest implements Request, ToXContentObject { public static void decorateWithAuthHeader(HttpPost request, CohereAccount account) { request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); @@ -23,4 +32,58 @@ public static void decorateWithAuthHeader(HttpPost request, CohereAccount accoun request.setHeader(CohereUtils.createRequestSourceHeader()); } + protected final CohereAccount account; + private final String inferenceEntityId; + private final String modelId; + private final boolean stream; + + protected CohereRequest(CohereAccount account, String inferenceEntityId, @Nullable String modelId, boolean stream) { + this.account = account; + this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); + this.modelId = modelId; // model is optional in the v1 api + this.stream = stream; + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(this).getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public boolean isStreaming() { + return stream; + } + + @Override + public URI getURI() { + return account.uri(); + } + + public String getModelId() { + return modelId; + } + + @Override + public Request truncate() { + // no truncation + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // no truncation + return null; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequest.java deleted file mode 100644 index ed2a7ea97925e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequest.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -public class CohereRerankRequest extends CohereRequest { - - private final CohereAccount account; - private final String query; - private final List input; - private final Boolean returnDocuments; - private final Integer topN; - private final CohereRerankTaskSettings taskSettings; - private final String model; - private final String inferenceEntityId; - - public CohereRerankRequest( - String query, - List input, - @Nullable Boolean returnDocuments, - @Nullable Integer topN, - CohereRerankModel model - ) { - Objects.requireNonNull(model); - - this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.query = Objects.requireNonNull(query); - this.returnDocuments = returnDocuments; - this.topN = topN; - taskSettings = model.getTaskSettings(); - this.model = model.getServiceSettings().modelId(); - inferenceEntityId = model.getInferenceEntityId(); - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model)) - .getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - return this; // TODO? - } - - @Override - public boolean[] getTruncationInfo() { - return null; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java index 6eef2c67f5af0..f512444c6d6a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java @@ -9,19 +9,49 @@ import org.apache.http.Header; import org.apache.http.message.BasicHeader; +import org.elasticsearch.inference.InputType; + +import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; public class CohereUtils { public static final String HOST = "api.cohere.ai"; public static final String VERSION_1 = "v1"; + public static final String VERSION_2 = "v2"; public static final String CHAT_PATH = "chat"; public static final String EMBEDDINGS_PATH = "embed"; public static final String RERANK_PATH = "rerank"; public static final String REQUEST_SOURCE_HEADER = "Request-Source"; public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch"; + public static final String CLUSTERING = "clustering"; + public static final String CLASSIFICATION = "classification"; + public static final String DOCUMENTS_FIELD = "documents"; + public static final String EMBEDDING_TYPES_FIELD = "embedding_types"; + public static final String INPUT_TYPE_FIELD = "input_type"; + public static final String MESSAGE_FIELD = "message"; + public static final String MODEL_FIELD = "model"; + public static final String QUERY_FIELD = "query"; + public static final String SEARCH_DOCUMENT = "search_document"; + public static final String SEARCH_QUERY = "search_query"; + public static final String TEXTS_FIELD = "texts"; + public static final String STREAM_FIELD = "stream"; + public static Header createRequestSourceHeader() { return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); } + public static String inputTypeToString(InputType inputType) { + return switch (inputType) { + case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT; + case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY; + case CLASSIFICATION -> CLASSIFICATION; + case CLUSTERING -> CLUSTERING; + default -> { + assert false : invalidInputTypeMessage(inputType); + yield null; + } + }; + } + private CohereUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java new file mode 100644 index 0000000000000..a657149ff5e6c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v1; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Objects; + +public class CohereV1CompletionRequest extends CohereRequest { + private final List input; + + public CohereV1CompletionRequest(List input, CohereCompletionModel model, boolean stream) { + super( + CohereAccount.of(model, CohereV1CompletionRequest::buildDefaultUri), + model.getInferenceEntityId(), + model.getServiceSettings().modelId(), + stream + ); + + this.input = Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + // we only allow one input for completion, so always get the first one + builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst()); + if (getModelId() != null) { + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + } + if (isStreaming()) { + builder.field(CohereUtils.STREAM_FIELD, true); + } + builder.endObject(); + return builder; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java new file mode 100644 index 0000000000000..7482c68c095db --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v1; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Objects; + +public class CohereV1EmbeddingsRequest extends CohereRequest { + + private final List input; + private final InputType inputType; + private final CohereEmbeddingsTaskSettings taskSettings; + private final CohereEmbeddingType embeddingType; + + public CohereV1EmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { + super( + CohereAccount.of(embeddingsModel, CohereV1EmbeddingsRequest::buildDefaultUri), + embeddingsModel.getInferenceEntityId(), + embeddingsModel.getServiceSettings().getCommonSettings().modelId(), + false + ); + + this.input = Objects.requireNonNull(input); + this.inputType = inputType; + taskSettings = embeddingsModel.getTaskSettings(); + embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH) + .build(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CohereUtils.TEXTS_FIELD, input); + if (getModelId() != null) { + builder.field(CohereServiceSettings.OLD_MODEL_ID_FIELD, getModelId()); + } + + // prefer the root level inputType over task settings input type + if (InputType.isSpecified(inputType)) { + builder.field(CohereUtils.INPUT_TYPE_FIELD, CohereUtils.inputTypeToString(inputType)); + } else if (InputType.isSpecified(taskSettings.getInputType())) { + builder.field(CohereUtils.INPUT_TYPE_FIELD, CohereUtils.inputTypeToString(taskSettings.getInputType())); + } + + if (embeddingType != null) { + builder.field(CohereUtils.EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); + } + + if (taskSettings.getTruncation() != null) { + builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); + } + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java similarity index 53% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java index f75fc96b6a964..389a810b73f7f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java @@ -5,54 +5,66 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.cohere.request.v2; +package org.elasticsearch.xpack.inference.services.cohere.request.v1; +import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.util.List; import java.util.Objects; -public record CohereRerankRequestEntity( - String model, - String query, - List documents, - @Nullable Boolean returnDocuments, - @Nullable Integer topN, - CohereRerankTaskSettings taskSettings -) implements ToXContentObject { +public class CohereV1RerankRequest extends CohereRequest { - private static final String DOCUMENTS_FIELD = "documents"; - private static final String QUERY_FIELD = "query"; - private static final String MODEL_FIELD = "model"; + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final CohereRerankTaskSettings taskSettings; - public CohereRerankRequestEntity { - Objects.requireNonNull(query); - Objects.requireNonNull(documents); - Objects.requireNonNull(taskSettings); - } - - public CohereRerankRequestEntity( + public CohereV1RerankRequest( String query, List input, @Nullable Boolean returnDocuments, @Nullable Integer topN, - CohereRerankTaskSettings taskSettings, - String model + CohereRerankModel model ) { - this(model, query, input, returnDocuments, topN, taskSettings); + super( + CohereAccount.of(model, CohereV1RerankRequest::buildDefaultUri), + model.getInferenceEntityId(), + model.getServiceSettings().modelId(), + false + ); + + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; + taskSettings = model.getTaskSettings(); + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH) + .build(); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(MODEL_FIELD, model); - builder.field(QUERY_FIELD, query); - builder.field(DOCUMENTS_FIELD, documents); + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + builder.field(CohereUtils.QUERY_FIELD, query); + builder.field(CohereUtils.DOCUMENTS_FIELD, input); // prefer the root level return_documents over task settings if (returnDocuments != null) { @@ -75,5 +87,4 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequest.java deleted file mode 100644 index 2b0d5dfa508a6..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequest.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request.v2; - -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -public class CohereCompletionRequest extends CohereRequest { - private final CohereAccount account; - private final List input; - private final String modelId; - private final String inferenceEntityId; - private final boolean stream; - - public CohereCompletionRequest(List input, CohereCompletionModel model, boolean stream) { - Objects.requireNonNull(model); - - this.account = CohereAccount.of(model, CohereCompletionRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.modelId = model.getServiceSettings().modelId(); - this.inferenceEntityId = model.getInferenceEntityId(); - this.stream = stream; - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereCompletionRequestEntity(input, modelId, isStreaming())).getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public boolean isStreaming() { - return stream; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - // no truncation - return this; - } - - @Override - public boolean[] getTruncationInfo() { - // no truncation - return null; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequestEntity.java deleted file mode 100644 index 6d4149792df5e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereCompletionRequestEntity.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request.v2; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public record CohereCompletionRequestEntity(List input, @Nullable String model, boolean stream) implements ToXContentObject { - - private static final String MESSAGE_FIELD = "message"; - private static final String MODEL = "model"; - private static final String STREAM = "stream"; - - public CohereCompletionRequestEntity { - Objects.requireNonNull(input); - Objects.requireNonNull(input.get(0)); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - // we only allow one input for completion, so always get the first one - builder.field(MESSAGE_FIELD, input.get(0)); - if (model != null) { - builder.field(MODEL, model); - } - - if (stream) { - builder.field(STREAM, true); - } - - builder.endObject(); - - return builder; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereEmbeddingsRequestEntity.java deleted file mode 100644 index cd0fd41188f04..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereEmbeddingsRequestEntity.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request.v2; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; - -public record CohereEmbeddingsRequestEntity( - List input, - InputType inputType, - CohereEmbeddingsTaskSettings taskSettings, - @Nullable String model, - @Nullable CohereEmbeddingType embeddingType -) implements ToXContentObject { - - private static final String SEARCH_DOCUMENT = "search_document"; - private static final String SEARCH_QUERY = "search_query"; - private static final String CLUSTERING = "clustering"; - private static final String CLASSIFICATION = "classification"; - private static final String TEXTS_FIELD = "texts"; - public static final String INPUT_TYPE_FIELD = "input_type"; - static final String EMBEDDING_TYPES_FIELD = "embedding_types"; - - public CohereEmbeddingsRequestEntity { - Objects.requireNonNull(input); - Objects.requireNonNull(taskSettings); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(TEXTS_FIELD, input); - if (model != null) { - builder.field(CohereServiceSettings.OLD_MODEL_ID_FIELD, model); - } - - // prefer the root level inputType over task settings input type - if (InputType.isSpecified(inputType)) { - builder.field(INPUT_TYPE_FIELD, convertToString(inputType)); - } else if (InputType.isSpecified(taskSettings.getInputType())) { - builder.field(INPUT_TYPE_FIELD, convertToString(taskSettings.getInputType())); - } - - if (embeddingType != null) { - builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); - } - - if (taskSettings.getTruncation() != null) { - builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); - } - - builder.endObject(); - return builder; - } - - // default for testing - public static String convertToString(InputType inputType) { - return switch (inputType) { - case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT; - case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY; - case CLASSIFICATION -> CLASSIFICATION; - case CLUSTERING -> CLUSTERING; - default -> { - assert false : invalidInputTypeMessage(inputType); - yield null; - } - }; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRequest.java deleted file mode 100644 index 6620cbc03b700..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRequest.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request.v2; - -import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.HttpPost; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; - -import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; - -public abstract class CohereRequest implements Request { - - public static void decorateWithAuthHeader(HttpPost request, CohereAccount account) { - request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); - request.setHeader(createAuthBearerHeader(account.apiKey())); - request.setHeader(CohereUtils.createRequestSourceHeader()); - } - -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequest.java deleted file mode 100644 index acb0adec4478c..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereRerankRequest.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request.v2; - -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -public class CohereRerankRequest extends CohereRequest { - - private final CohereAccount account; - private final String query; - private final List input; - private final Boolean returnDocuments; - private final Integer topN; - private final CohereRerankTaskSettings taskSettings; - private final String model; - private final String inferenceEntityId; - - public CohereRerankRequest( - String query, - List input, - @Nullable Boolean returnDocuments, - @Nullable Integer topN, - CohereRerankModel model - ) { - Objects.requireNonNull(model); - - this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.query = Objects.requireNonNull(query); - this.returnDocuments = returnDocuments; - this.topN = topN; - taskSettings = model.getTaskSettings(); - this.model = model.getServiceSettings().modelId(); - inferenceEntityId = model.getInferenceEntityId(); - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model)) - .getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - return this; // TODO? - } - - @Override - public boolean[] getTruncationInfo() { - return null; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java new file mode 100644 index 0000000000000..733222548ec1e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Objects; + +public class CohereV2CompletionRequest extends CohereRequest { + private final List input; + + public CohereV2CompletionRequest(List input, CohereCompletionModel model, boolean stream) { + super( + CohereAccount.of(model, CohereV2CompletionRequest::buildDefaultUri), + model.getInferenceEntityId(), + Objects.requireNonNull(model.getServiceSettings().modelId()), + stream + ); + + this.input = Objects.requireNonNull(input); + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_2, CohereUtils.CHAT_PATH) + .build(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + // we only allow one input for completion, so always get the first one + builder.field(CohereUtils.MESSAGE_FIELD, input.getFirst()); + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + builder.field(CohereUtils.STREAM_FIELD, isStreaming()); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java index 78d67b5937ec2..f196e5759a384 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java @@ -7,86 +7,65 @@ package org.elasticsearch.xpack.inference.services.cohere.request.v2; -import org.apache.http.client.methods.HttpPost; import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InputType; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.inputTypeToString; + public class CohereV2EmbeddingsRequest extends CohereRequest { - private final CohereAccount account; private final List input; private final InputType inputType; private final CohereEmbeddingsTaskSettings taskSettings; - private final String model; private final CohereEmbeddingType embeddingType; - private final String inferenceEntityId; public CohereV2EmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { - Objects.requireNonNull(embeddingsModel); + super( + CohereAccount.of(embeddingsModel, CohereV2EmbeddingsRequest::buildDefaultUri), + embeddingsModel.getInferenceEntityId(), + Objects.requireNonNull(embeddingsModel.getServiceSettings().getCommonSettings().modelId()), + false + ); - account = CohereAccount.of(embeddingsModel, CohereV2EmbeddingsRequest::buildDefaultUri); this.input = Objects.requireNonNull(input); - this.inputType = inputType; + this.inputType = Objects.requireNonNull(inputType); // inputType is required in v2 taskSettings = embeddingsModel.getTaskSettings(); - model = embeddingsModel.getServiceSettings().getCommonSettings().modelId(); embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); - inferenceEntityId = embeddingsModel.getInferenceEntityId(); - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereEmbeddingsRequestEntity(input, inputType, taskSettings, model, embeddingType)) - .getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - return this; - } - - @Override - public boolean[] getTruncationInfo() { - return null; } public static URI buildDefaultUri() throws URISyntaxException { return new URIBuilder().setScheme("https") .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH) + .setPathSegments(CohereUtils.VERSION_2, CohereUtils.EMBEDDINGS_PATH) .build(); } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + builder.field(CohereUtils.TEXTS_FIELD, input); + // TODO merge input type from task settings InputType.isSpecified(inputType) + builder.field(CohereUtils.INPUT_TYPE_FIELD, inputTypeToString(inputType)); + builder.field(CohereUtils.EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); + if (taskSettings.getTruncation() != null) { + builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); + } + builder.endObject(); + return builder; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java similarity index 50% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java index ddddb9fa314e5..50b7339e575af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java @@ -5,54 +5,70 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.cohere.request; +package org.elasticsearch.xpack.inference.services.cohere.request.v2; +import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.util.List; import java.util.Objects; -public record CohereRerankRequestEntity( - String model, - String query, - List documents, - @Nullable Boolean returnDocuments, - @Nullable Integer topN, - CohereRerankTaskSettings taskSettings -) implements ToXContentObject { +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.DOCUMENTS_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.QUERY_FIELD; - private static final String DOCUMENTS_FIELD = "documents"; - private static final String QUERY_FIELD = "query"; - private static final String MODEL_FIELD = "model"; +public class CohereV2RerankRequest extends CohereRequest { - public CohereRerankRequestEntity { - Objects.requireNonNull(query); - Objects.requireNonNull(documents); - Objects.requireNonNull(taskSettings); - } + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final CohereRerankTaskSettings taskSettings; - public CohereRerankRequestEntity( + public CohereV2RerankRequest( String query, List input, @Nullable Boolean returnDocuments, @Nullable Integer topN, - CohereRerankTaskSettings taskSettings, - String model + CohereRerankModel model ) { - this(model, query, input, returnDocuments, topN, taskSettings); + super( + CohereAccount.of(model, CohereV2RerankRequest::buildDefaultUri), + model.getInferenceEntityId(), + Objects.requireNonNull(model.getServiceSettings().modelId()), + false + ); + + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; + taskSettings = model.getTaskSettings(); + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_2, CohereUtils.RERANK_PATH) + .build(); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(MODEL_FIELD, model); + builder.field(MODEL_FIELD, getModelId()); builder.field(QUERY_FIELD, query); - builder.field(DOCUMENTS_FIELD, documents); + builder.field(DOCUMENTS_FIELD, input); // prefer the root level return_documents over task settings if (returnDocuments != null) { @@ -75,5 +91,4 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java index ca853a2d28909..ca51c3f94f074 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -28,8 +29,6 @@ public static CohereRerankModel of(CohereRerankModel model, Map public CohereRerankModel( String modelId, - TaskType taskType, - String service, Map serviceSettings, Map taskSettings, @Nullable Map secrets, @@ -37,25 +36,20 @@ public CohereRerankModel( ) { this( modelId, - taskType, - service, CohereRerankServiceSettings.fromMap(serviceSettings, context), CohereRerankTaskSettings.fromMap(taskSettings), DefaultSecretSettings.fromMap(secrets) ); } - // should only be used for testing - CohereRerankModel( + public CohereRerankModel( String modelId, - TaskType taskType, - String service, CohereRerankServiceSettings serviceSettings, CohereRerankTaskSettings taskSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(modelId, TaskType.RERANK, CohereService.NAME, serviceSettings, taskSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java index 78178466f9f3a..aac5cd705c13c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.rerank; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ValidationException; @@ -44,8 +42,6 @@ public class CohereRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings { public static final String NAME = "cohere_rerank_service_settings"; - private static final Logger logger = LogManager.getLogger(CohereRerankServiceSettings.class); - public static CohereRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index 13d5191577d4c..f5a387f0f00b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -53,7 +53,6 @@ import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.services.cohere.request.CohereEmbeddingsRequestEntity.convertToString; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -146,7 +145,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var requestMap = entityAsMap(webServer.requests().get(0).getBody()); if (inputType != null && inputType != InputType.UNSPECIFIED) { - var cohereInputType = convertToString(inputType); + var cohereInputType = CohereUtils.inputTypeToString(inputType); MatcherAssert.assertThat( requestMap, is( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java index b9fc7ee7b9952..9261170f2dc27 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -24,23 +23,20 @@ public class CohereCompletionModelTests extends ESTestCase { public void testCreateModel_AlwaysWithEmptyTaskSettings() { var model = new CohereCompletionModel( - "model", - TaskType.COMPLETION, - "service", - new HashMap<>(Map.of()), - new HashMap<>(Map.of("model", "overridden model")), + "inference_id", + new HashMap<>(Map.of("model", "cohere completion model")), null, ConfigurationParseContext.PERSISTENT ); + assertThat(model.uri(), is("TODO")); assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(model.getServiceSettings().modelId(), is("cohere completion model")); } public static CohereCompletionModel createModel(String url, String apiKey, @Nullable String model) { return new CohereCompletionModel( "id", - TaskType.COMPLETION, - "service", new CohereCompletionServiceSettings(url, model, null), EmptyTaskSettings.INSTANCE, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index b3e264fdf1ab7..294d33f7d9fac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -121,8 +120,6 @@ public static CohereEmbeddingsModel createModel( ) { return new CohereEmbeddingsModel( "id", - TaskType.TEXT_EMBEDDING, - "service", new CohereEmbeddingsServiceSettings( new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) @@ -144,8 +141,6 @@ public static CohereEmbeddingsModel createModel( ) { return new CohereEmbeddingsModel( "id", - TaskType.TEXT_EMBEDDING, - "service", new CohereEmbeddingsServiceSettings( new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) @@ -168,8 +163,6 @@ public static CohereEmbeddingsModel createModel( ) { return new CohereEmbeddingsModel( "id", - TaskType.TEXT_EMBEDDING, - "service", new CohereEmbeddingsServiceSettings( new CohereServiceSettings(url, similarityMeasure, dimensions, tokenLimit, model, null), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java deleted file mode 100644 index fa775c5b809b6..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; - -import java.io.IOException; -import java.util.List; - -import static org.hamcrest.CoreMatchers.is; - -public class CohereCompletionRequestEntityTests extends ESTestCase { - - public void testXContent_WritesAllFields() throws IOException { - var entity = new CohereCompletionRequestEntity(List.of("some input"), "model", false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"message":"some input","model":"model"}""")); - } - - public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException { - var entity = new CohereCompletionRequestEntity(List.of("some input"), null, false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"message":"some input"}""")); - } - - public void testXContent_ThrowsIfInputIsNull() { - expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(null, null, false)); - } - - public void testXContent_ThrowsIfMessageInInputIsNull() { - expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(List.of((String) null), null, false)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntityTests.java deleted file mode 100644 index 30a01422f6f30..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntityTests.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; -import org.hamcrest.MatcherAssert; - -import java.io.IOException; -import java.util.List; - -import static org.hamcrest.CoreMatchers.is; - -public class CohereEmbeddingsRequestEntityTests extends ESTestCase { - public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - InputType.INTERNAL_INGEST, - new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), - "model", - CohereEmbeddingType.FLOAT - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); - } - - public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - null, - new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), - "model", - CohereEmbeddingType.INT8 - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); - } - - public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - InputType.INTERNAL_SEARCH, - new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), - "model", - CohereEmbeddingType.BYTE - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); - } - - public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - InputType.SEARCH, - new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), - "model", - CohereEmbeddingType.BINARY - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); - } - - public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - null, - new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), - "model", - CohereEmbeddingType.BIT - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); - } - - public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { - var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), null, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"]}""")); - } - - public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { - var thrownException = expectThrows( - AssertionError.class, - () -> CohereEmbeddingsRequestEntity.convertToString(InputType.UNSPECIFIED) - ); - MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequestTests.java index 604509afdbd7d..81106764474cf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequestTests.java @@ -32,5 +32,4 @@ public void testDecorateWithAuthHeader() { assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer abc")); assertThat(request.getFirstHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtilsTests.java index ad5c9c4e80330..ef2b29bbe9a2a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtilsTests.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.inference.services.cohere.request; +import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; +import org.hamcrest.MatcherAssert; import static org.hamcrest.Matchers.is; @@ -20,4 +22,18 @@ public void testCreateRequestSourceHeader() { assertThat(requestSourceHeader.getValue(), is("unspecified:elasticsearch")); } + public void testInputTypeToString() { + assertThat(CohereUtils.inputTypeToString(InputType.INGEST), is("search_document")); + assertThat(CohereUtils.inputTypeToString(InputType.INTERNAL_INGEST), is("search_document")); + assertThat(CohereUtils.inputTypeToString(InputType.SEARCH), is("search_query")); + assertThat(CohereUtils.inputTypeToString(InputType.INTERNAL_SEARCH), is("search_query")); + assertThat(CohereUtils.inputTypeToString(InputType.CLASSIFICATION), is("classification")); + assertThat(CohereUtils.inputTypeToString(InputType.CLUSTERING), is("clustering")); + assertThat(InputType.values().length, is(7)); // includes unspecified. Fail if new values are added + } + + public void testInputTypeToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows(AssertionError.class, () -> CohereUtils.inputTypeToString(InputType.UNSPECIFIED)); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java new file mode 100644 index 0000000000000..76e7e57b0ebbb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v1; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class CohereV1CompletionRequestTests extends ESTestCase { + + public void testCreateRequest_UrlDefined() throws IOException { + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null), false); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(Map.of("message", "abc"))); + } + + public void testCreateRequest_ModelDefined() throws IOException { + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); + } + + public void testDefaultUrl() throws IOException { + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel(null, "secret", null), false); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("TODO")); + } + + public void testTruncate_ReturnsSameInstance() { + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); + var truncatedRequest = request.truncate(); + + assertThat(truncatedRequest, sameInstance(request)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); + + assertNull(request.getTruncationInfo()); + } + + public void testXContent_WritesAllFields() throws IOException { + var request = new CohereV1CompletionRequest( + List.of("some input"), + CohereCompletionModelTests.createModel("url", "secret", "model"), + false + ); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + request.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"message":"some input","model":"model"}""")); + } + + public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException { + var request = new CohereV1CompletionRequest( + List.of("some input"), + CohereCompletionModelTests.createModel("url", "secret", null), + false + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + request.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"message":"some input"}""")); + } + + public void testXContent_ThrowsIfInputIsNull() { + expectThrows( + NullPointerException.class, + () -> new CohereV1CompletionRequest(null, CohereCompletionModelTests.createModel("url", "secret", null), false) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java similarity index 71% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java index 508c81bb940cd..245fe69ddea6a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java @@ -5,12 +5,15 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.cohere.request; +package org.elasticsearch.xpack.inference.services.cohere.request.v1; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; @@ -18,6 +21,8 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import java.io.IOException; @@ -25,17 +30,16 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.elasticsearch.xpack.inference.services.cohere.request.CohereEmbeddingsRequestEntity.convertToString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class CohereEmbeddingsRequestTests extends ESTestCase { +public class CohereV1EmbeddingsRequestTests extends ESTestCase { public void testCreateRequest_UrlDefined() throws IOException { var inputType = InputTypeTests.randomWithNull(); var request = createRequest( List.of("abc"), inputType, - CohereEmbeddingsModelTests.createModel("url", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, null, null) + CohereEmbeddingsModelTests.createModel(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, null, null) ); var httpRequest = request.createHttpRequest(); @@ -43,7 +47,7 @@ public void testCreateRequest_UrlDefined() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("TODO should be full url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -277,16 +281,117 @@ public void testCreateRequest_TruncateNone() throws IOException { validateInputType(requestMap, null, inputType); } - public static CohereEmbeddingsRequest createRequest(List input, InputType inputType, CohereEmbeddingsModel model) { - return new CohereEmbeddingsRequest(input, inputType, model); + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = createRequest( + List.of("abc"), + InputType.INTERNAL_INGEST, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), + CohereEmbeddingType.FLOAT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); + } + + public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { + var entity = createRequest( + List.of("abc"), + null, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.INT8 + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + } + + public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { + var entity = createRequest( + List.of("abc"), + InputType.INTERNAL_SEARCH, + new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), + CohereEmbeddingType.BYTE + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + } + + public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { + var entity = createRequest( + List.of("abc"), + InputType.SEARCH, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.BINARY + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + } + + public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { + var entity = createRequest( + List.of("abc"), + null, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.BIT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var entity = createRequest(List.of("abc"), null, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"]}""")); + } + + public static CohereV1EmbeddingsRequest createRequest(List input, InputType inputType, CohereEmbeddingsModel model) { + return new CohereV1EmbeddingsRequest(input, inputType, model); + } + + public static CohereV1EmbeddingsRequest createRequest( + List input, + InputType inputType, + CohereEmbeddingsTaskSettings taskSettings, + CohereEmbeddingType embeddingType + ) { + var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, null, embeddingType); + return new CohereV1EmbeddingsRequest(input, inputType, model); } private void validateInputType(Map requestMap, InputType taskSettingsInputType, InputType requestInputType) { if (InputType.isSpecified(requestInputType)) { - var convertedInputType = convertToString(requestInputType); + var convertedInputType = CohereUtils.inputTypeToString(requestInputType); assertThat(requestMap.get("input_type"), is(convertedInputType)); } else if (InputType.isSpecified(taskSettingsInputType)) { - var convertedInputType = convertToString(taskSettingsInputType); + var convertedInputType = CohereUtils.inputTypeToString(taskSettingsInputType); assertThat(requestMap.get("input_type"), is(convertedInputType)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestEntityTests.java similarity index 66% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestEntityTests.java index 7c0fa143a56db..2deda6deb9e81 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestEntityTests.java @@ -5,91 +5,92 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.cohere.request; +package org.elasticsearch.xpack.inference.services.cohere.request.v1; import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import java.io.IOException; import java.util.List; -import static org.hamcrest.CoreMatchers.is; - -public class CohereRerankRequestEntityTests extends ESTestCase { +public class CohereV1RerankRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { - var entity = new CohereRerankRequestEntity( + var entity = new CohereV1RerankRequest( "query", List.of("abc"), Boolean.TRUE, 22, - new CohereRerankTaskSettings(null, null, 3), - "model" + createModel("model", new CohereRerankTaskSettings(null, null, 3)) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, is(""" + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}""")); } public void testXContent_WritesMinimalFields() throws IOException { - var entity = new CohereRerankRequestEntity( + var entity = new CohereV1RerankRequest( "query", List.of("abc"), null, null, - new CohereRerankTaskSettings(null, null, null), - "model" + createModel("model", new CohereRerankTaskSettings(null, null, null)) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, is(""" + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" {"model":"model","query":"query","documents":["abc"]}""")); } public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException { - var entity = new CohereRerankRequestEntity( + var entity = new CohereV1RerankRequest( "query", List.of("abc"), Boolean.FALSE, 99, - new CohereRerankTaskSettings(33, Boolean.TRUE, null), - "model" + createModel("model", new CohereRerankTaskSettings(33, Boolean.TRUE, null)) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, is(""" + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}""")); } public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOException { - var entity = new CohereRerankRequestEntity( + var entity = new CohereV1RerankRequest( "query", List.of("abc"), null, null, - new CohereRerankTaskSettings(33, Boolean.TRUE, null), - "model" + createModel("model", new CohereRerankTaskSettings(33, Boolean.TRUE, null)) ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); - MatcherAssert.assertThat(xContentResult, is(""" + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}""")); } + + private CohereRerankModel createModel(String modelId, CohereRerankTaskSettings taskSettings) { + return new CohereRerankModel("inference_id", new CohereRerankServiceSettings("uri", modelId, null), taskSettings, null); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java similarity index 52% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java index 67e2b92def246..2e2d687498366 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java @@ -5,13 +5,19 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.cohere.request; +package org.elasticsearch.xpack.inference.services.cohere.request.v2; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest; +import org.hamcrest.CoreMatchers; import java.io.IOException; import java.util.List; @@ -20,12 +26,15 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.sameInstance; -public class CohereCompletionRequestTests extends ESTestCase { +public class CohereV2CompletionRequestTests extends ESTestCase { - public void testCreateRequest_UrlDefined() throws IOException { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null), false); + public void testCreateRequest() throws IOException { + var request = new CohereV2CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel(null, "secret", "required model id"), + false + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -38,36 +47,30 @@ public void testCreateRequest_UrlDefined() throws IOException { assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, is(Map.of("message", "abc"))); + assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id"))); } - public void testCreateRequest_ModelDefined() throws IOException { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); + public void testDefaultUrl() throws IOException { + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel(null, "secret", null), false); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); - var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is("url")); - assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); - assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); - assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); - - var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); + assertThat(httpPost.getURI().toString(), is("TODO")); } - public void testTruncate_ReturnsSameInstance() { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); - var truncatedRequest = request.truncate(); - - assertThat(truncatedRequest, sameInstance(request)); - } - - public void testTruncationInfo_ReturnsNull() { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); - - assertNull(request.getTruncationInfo()); + public void testXContents() throws IOException { + var request = new CohereV2CompletionRequest( + List.of("some input"), + CohereCompletionModelTests.createModel(null, "secret", "model"), + false + ); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + request.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"message":"some input","model":"model"}""")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java new file mode 100644 index 0000000000000..35ad0b78cbd70 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java @@ -0,0 +1,324 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1EmbeddingsRequest; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CohereV2EmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest() throws IOException { + var inputType = InputTypeTests.randomWithNull(); + var request = createRequest( + List.of("abc"), + inputType, + CohereEmbeddingsModelTests.createModel( + null, + "secret", + CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + null, + null, + "model id", + null + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("TODO should be full url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("float"))); + MatcherAssert.assertThat(requestMap.get("model"), is(List.of("model id"))); + MatcherAssert.assertThat(requestMap.get("truncate"), is("start")); + validateInputType(requestMap, null, inputType); + } + + public void testCreateRequest_WithTaskSettingsInputType() throws IOException { + var inputType = InputTypeTests.randomWithoutUnspecified(); + var request = createRequest( + List.of("abc"), + null, + CohereEmbeddingsModelTests.createModel( + "url", + "secret", + new CohereEmbeddingsTaskSettings(inputType, CohereTruncation.END), + null, + null, + null, + null + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + validateInputType(requestMap, inputType, null); + } + + public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() throws IOException { + var inputType = InputTypeTests.randomWithNull(); + var request = createRequest( + List.of("abc"), + inputType, + CohereEmbeddingsModelTests.createModel( + "url", + "secret", + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), + null, + null, + "model", + CohereEmbeddingType.INT8 + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("model"), is("model")); + MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("int8"))); + MatcherAssert.assertThat(requestMap.get("truncate"), is("end")); + validateInputType(requestMap, null, inputType); + } + + public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() throws IOException { + var inputType = InputTypeTests.randomWithNull(); + var request = createRequest( + List.of("abc"), + inputType, + CohereEmbeddingsModelTests.createModel( + "url", + "secret", + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), + null, + null, + "model", + CohereEmbeddingType.BIT + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("model"), is("model")); + MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("binary"))); + MatcherAssert.assertThat(requestMap.get("truncate"), is("end")); + validateInputType(requestMap, null, inputType); + } + + public void testCreateRequest_TruncateNone() throws IOException { + var inputType = InputTypeTests.randomWithNull(); + var request = createRequest( + List.of("abc"), + inputType, + CohereEmbeddingsModelTests.createModel( + "url", + "secret", + new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), + null, + null, + null, + null + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("float"))); + MatcherAssert.assertThat(requestMap.get("truncate"), is("none")); + validateInputType(requestMap, null, inputType); + } + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = createRequest( + List.of("abc"), + InputType.INTERNAL_INGEST, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), + CohereEmbeddingType.FLOAT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); + } + + public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { + var entity = createRequest( + List.of("abc"), + null, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.INT8 + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + } + + public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { + var entity = createRequest( + List.of("abc"), + InputType.INTERNAL_SEARCH, + new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), + CohereEmbeddingType.BYTE + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + } + + public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { + var entity = createRequest( + List.of("abc"), + InputType.SEARCH, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.BINARY + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + } + + public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { + var entity = createRequest( + List.of("abc"), + null, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.BIT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var entity = createRequest(List.of("abc"), null, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"]}""")); + } + + public static CohereV2EmbeddingsRequest createRequest(List input, InputType inputType, CohereEmbeddingsModel model) { + return new CohereV2EmbeddingsRequest(input, inputType, model); + } + + public static CohereV1EmbeddingsRequest createRequest( + List input, + InputType inputType, + CohereEmbeddingsTaskSettings taskSettings, + CohereEmbeddingType embeddingType + ) { + var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, null, embeddingType); + return new CohereV1EmbeddingsRequest(input, inputType, model); + } + + private void validateInputType(Map requestMap, InputType taskSettingsInputType, InputType requestInputType) { + if (InputType.isSpecified(requestInputType)) { + var convertedInputType = CohereUtils.inputTypeToString(requestInputType); + assertThat(requestMap.get("input_type"), is(convertedInputType)); + } else if (InputType.isSpecified(taskSettingsInputType)) { + var convertedInputType = CohereUtils.inputTypeToString(taskSettingsInputType); + assertThat(requestMap.get("input_type"), is(convertedInputType)); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestEntityTests.java new file mode 100644 index 0000000000000..c927640ce992d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestEntityTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CohereV2RerankRequestEntityTests extends ESTestCase { + public void testUrl() throws IOException { + var request = new CohereV2RerankRequest( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + createModel("model", null, new CohereRerankTaskSettings(null, null, 3)) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("TOOD")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + } + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new CohereV2RerankRequest( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + createModel("model", "uri", new CohereRerankTaskSettings(null, null, 3)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}""")); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new CohereV2RerankRequest( + "query", + List.of("abc"), + null, + null, + createModel("model", "uri", new CohereRerankTaskSettings(null, null, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"]}""")); + } + + public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException { + var entity = new CohereV2RerankRequest( + "query", + List.of("abc"), + Boolean.FALSE, + 99, + createModel("model", "uri", new CohereRerankTaskSettings(33, Boolean.TRUE, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}""")); + } + + public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOException { + var entity = new CohereV2RerankRequest( + "query", + List.of("abc"), + null, + null, + createModel("model", "uri", new CohereRerankTaskSettings(33, Boolean.TRUE, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}""")); + } + + private CohereRerankModel createModel(String modelId, String uri, CohereRerankTaskSettings taskSettings) { + return new CohereRerankModel("inference_id", new CohereRerankServiceSettings(uri, modelId, null), taskSettings, null); + } +} From d5c80bb2d81c72909c967bfa295260a9a3868f0a Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 1 May 2025 15:27:53 +0100 Subject: [PATCH 03/12] fix the tests --- .../http/sender/QueryAndDocsInputs.java | 4 - .../CohereCompletionRequestManager.java | 62 ---------------- .../CohereEmbeddingsRequestManager.java | 63 ---------------- .../services/cohere/CohereModel.java | 17 ++++- .../CohereRateLimitServiceSettings.java | 3 + .../cohere/CohereRerankRequestManager.java | 64 ---------------- .../cohere/action/CohereActionCreator.java | 74 ++++++++++++++++--- .../completion/CohereCompletionModel.java | 6 -- .../request/v2/CohereV2EmbeddingsRequest.java | 12 +-- .../cohere/rerank/CohereRerankModel.java | 6 -- .../rerank/CohereRerankServiceSettings.java | 1 + .../action/CohereCompletionActionTests.java | 9 +-- .../action/CohereEmbeddingsActionTests.java | 9 +-- .../CohereCompletionModelTests.java | 3 +- .../v1/CohereV1CompletionRequestTests.java | 22 ++++-- .../v1/CohereV1EmbeddingsRequestTests.java | 16 ++-- ...s.java => CohereV1RerankRequestTests.java} | 41 +++++++++- .../v2/CohereV2CompletionRequestTests.java | 17 +++-- .../v2/CohereV2EmbeddingsRequestTests.java | 60 +++++++-------- ...s.java => CohereV2RerankRequestTests.java} | 15 +++- 20 files changed, 212 insertions(+), 292 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/{CohereV1RerankRequestEntityTests.java => CohereV1RerankRequestTests.java} (68%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/{CohereV2RerankRequestEntityTests.java => CohereV2RerankRequestTests.java} (90%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index a2526a2a293eb..45e178a93f27d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -27,10 +27,6 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { private final Boolean returnDocuments; private final Integer topN; - public QueryAndDocsInputs(String query, List chunks) { - this(query, chunks, null, null, false); - } - public QueryAndDocsInputs( String query, List chunks, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java deleted file mode 100644 index 535f1fdd95a3d..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest; -import org.elasticsearch.xpack.inference.services.cohere.response.CohereCompletionResponseEntity; - -import java.util.Objects; -import java.util.function.Supplier; - -public class CohereCompletionRequestManager extends CohereRequestManager { - - private static final Logger logger = LogManager.getLogger(CohereCompletionRequestManager.class); - - private static final ResponseHandler HANDLER = createCompletionHandler(); - - private static ResponseHandler createCompletionHandler() { - return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse, true); - } - - public static CohereCompletionRequestManager of(CohereCompletionModel model, ThreadPool threadPool) { - return new CohereCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final CohereCompletionModel model; - - private CohereCompletionRequestManager(CohereCompletionModel model, ThreadPool threadPool) { - super(threadPool, model); - this.model = Objects.requireNonNull(model); - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); - var inputs = chatCompletionInput.getInputs(); - var stream = chatCompletionInput.stream(); - CohereV1CompletionRequest request = new CohereV1CompletionRequest(inputs, model, stream); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java deleted file mode 100644 index 09843ca1ce283..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1EmbeddingsRequest; -import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity; - -import java.util.List; -import java.util.Objects; -import java.util.function.Supplier; - -public class CohereEmbeddingsRequestManager extends CohereRequestManager { - private static final Logger logger = LogManager.getLogger(CohereEmbeddingsRequestManager.class); - private static final ResponseHandler HANDLER = createEmbeddingsHandler(); - - private static ResponseHandler createEmbeddingsHandler() { - return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse, false); - } - - public static CohereEmbeddingsRequestManager of(CohereEmbeddingsModel model, ThreadPool threadPool) { - return new CohereEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final CohereEmbeddingsModel model; - - private CohereEmbeddingsRequestManager(CohereEmbeddingsModel model, ThreadPool threadPool) { - super(threadPool, model); - this.model = Objects.requireNonNull(model); - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); - InputType inputType = input.getInputType(); - - CohereV1EmbeddingsRequest request = new CohereV1EmbeddingsRequest(docsInput, inputType, model); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java index 5cb52bdb7f405..ca24d3e405c90 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java @@ -9,21 +9,22 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.net.URI; import java.util.Map; import java.util.Objects; -public abstract class CohereModel extends Model { +public abstract class CohereModel extends RateLimitGroupingModel { private final SecureString apiKey; private final CohereRateLimitServiceSettings rateLimitServiceSettings; @@ -63,5 +64,15 @@ public CohereRateLimitServiceSettings rateLimitServiceSettings() { public abstract ExecutableAction accept(CohereActionVisitor creator, Map taskSettings); - public abstract URI uri(); + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); + } + + public int rateLimitGroupingHash() { + return apiKey().hashCode(); + } + + public URI uri() { + return rateLimitServiceSettings.uri(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java index 2607359c54c32..365f7e8ae4acb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java @@ -9,7 +9,10 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.net.URI; + public interface CohereRateLimitServiceSettings { RateLimitSettings rateLimitSettings(); + URI uri(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java deleted file mode 100644 index cbf606c0db678..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; -import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1RerankRequest; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; -import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity; - -import java.util.Objects; -import java.util.function.Supplier; - -public class CohereRerankRequestManager extends CohereRequestManager { - private static final Logger logger = LogManager.getLogger(CohereRerankRequestManager.class); - private static final ResponseHandler HANDLER = createCohereResponseHandler(); - - private static ResponseHandler createCohereResponseHandler() { - return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response), false); - } - - public static CohereRerankRequestManager of(CohereRerankModel model, ThreadPool threadPool) { - return new CohereRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final CohereRerankModel model; - - private CohereRerankRequestManager(CohereRerankModel model, ThreadPool threadPool) { - super(threadPool, model); - this.model = model; - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - CohereV1RerankRequest request = new CohereV1RerankRequest( - rerankInput.getQuery(), - rerankInput.getChunks(), - rerankInput.getReturnDocuments(), - rerankInput.getTopN(), - model - ); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index 83fbc5a8ad6e9..f16fd70e64d60 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -10,14 +10,23 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.cohere.CohereCompletionRequestManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereEmbeddingsRequestManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereRerankRequestManager; +import org.elasticsearch.xpack.inference.services.cohere.CohereResponseHandler; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1EmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1RerankRequest; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.response.CohereCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity; import java.util.Map; import java.util.Objects; @@ -28,12 +37,30 @@ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the cohere model type. */ public class CohereActionCreator implements CohereActionVisitor { + + private static final ResponseHandler EMBEDDINGS_HANDLER = new CohereResponseHandler( + "cohere text embedding", + CohereEmbeddingsResponseEntity::fromResponse, + false + ); + + private static final ResponseHandler RERANK_HANDLER = new CohereResponseHandler( + "cohere rerank", + (request, response) -> CohereRankedResponseEntity.fromResponse(response), + false + ); + + private static final ResponseHandler COMPLETION_HANDLER = new CohereResponseHandler( + "cohere completion", + CohereCompletionResponseEntity::fromResponse, + true + ); + private static final String COMPLETION_ERROR_PREFIX = "Cohere completion"; private final Sender sender; private final ServiceComponents serviceComponents; public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { - // TODO Batching - accept a class that can handle batching this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); } @@ -42,15 +69,37 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { public ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings) { var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings"); - // TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager - var requestCreator = CohereEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool()); + var requestCreator = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + EMBEDDINGS_HANDLER, + (inferenceInputs -> new CohereV1EmbeddingsRequest( + inferenceInputs.getStringInputs(), + inferenceInputs.getInputType(), + overriddenModel + )), + EmbeddingsInput.class + ); return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); } @Override public ExecutableAction create(CohereRerankModel model, Map taskSettings) { var overriddenModel = CohereRerankModel.of(model, taskSettings); - var requestCreator = CohereRerankRequestManager.of(overriddenModel, serviceComponents.threadPool()); + var requestCreator = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + (inferenceInputs -> new CohereV1RerankRequest( + inferenceInputs.getQuery(), + inferenceInputs.getChunks(), + inferenceInputs.getReturnDocuments(), + inferenceInputs.getTopN(), + overriddenModel + )), + QueryAndDocsInputs.class + ); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere rerank"); return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); } @@ -58,8 +107,15 @@ public ExecutableAction create(CohereRerankModel model, Map task @Override public ExecutableAction create(CohereCompletionModel model, Map taskSettings) { // no overridden model as task settings are always empty for cohere completion model - var requestManager = CohereCompletionRequestManager.of(model, serviceComponents.threadPool()); + var requestCreator = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + (completionInput) -> new CohereV1CompletionRequest(completionInput.getInputs(), model, completionInput.stream()), + ChatCompletionInput.class + ); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); - return new SingleInputSenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); + return new SingleInputSenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java index 3dfcdae778a50..120964393fd6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import java.net.URI; import java.util.Map; public class CohereCompletionModel extends CohereModel { @@ -73,9 +72,4 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); } - - @Override - public URI uri() { - return getServiceSettings().uri(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java index f196e5759a384..1b743eaa7224a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java @@ -24,8 +24,6 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.inputTypeToString; - public class CohereV2EmbeddingsRequest extends CohereRequest { private final List input; @@ -57,10 +55,14 @@ public static URI buildDefaultUri() throws URISyntaxException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CohereUtils.MODEL_FIELD, getModelId()); builder.field(CohereUtils.TEXTS_FIELD, input); - // TODO merge input type from task settings InputType.isSpecified(inputType) - builder.field(CohereUtils.INPUT_TYPE_FIELD, inputTypeToString(inputType)); + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + // prefer the root level inputType over task settings input type + if (InputType.isSpecified(inputType)) { + builder.field(CohereUtils.INPUT_TYPE_FIELD, CohereUtils.inputTypeToString(inputType)); + } else if (InputType.isSpecified(taskSettings.getInputType())) { + builder.field(CohereUtils.INPUT_TYPE_FIELD, CohereUtils.inputTypeToString(taskSettings.getInputType())); + } builder.field(CohereUtils.EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); if (taskSettings.getTruncation() != null) { builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java index ca51c3f94f074..2244afc135582 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import java.net.URI; import java.util.Map; public class CohereRerankModel extends CohereModel { @@ -89,9 +88,4 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); } - - @Override - public URI uri() { - return getServiceSettings().uri(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java index aac5cd705c13c..874fc5dc906ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java @@ -104,6 +104,7 @@ public CohereRerankServiceSettings(StreamInput in) throws IOException { } } + @Override public URI uri() { return uri; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java index 016207a4835dd..6fbb16c607e85 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java @@ -24,13 +24,11 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereCompletionRequestManager; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import org.junit.After; @@ -44,9 +42,9 @@ import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -341,9 +339,8 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc } private ExecutableAction createAction(String url, String apiKey, @Nullable String modelName, Sender sender) { + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); var model = CohereCompletionModelTests.createModel(url, apiKey, modelName); - var requestManager = CohereCompletionRequestManager.of(model, threadPool); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere completion"); - return new SingleInputSenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage, "Cohere completion"); + return actionCreator.create(model, Map.of()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index f5a387f0f00b3..7adba4fb4b787 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -24,14 +24,12 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; @@ -50,9 +48,9 @@ import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -353,10 +351,9 @@ private ExecutableAction createAction( @Nullable CohereEmbeddingType embeddingType, Sender sender ) { + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings"); - var requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool); - return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + return actionCreator.create(model, null); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java index 9261170f2dc27..fefd4323d0f38 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java @@ -24,12 +24,11 @@ public class CohereCompletionModelTests extends ESTestCase { public void testCreateModel_AlwaysWithEmptyTaskSettings() { var model = new CohereCompletionModel( "inference_id", - new HashMap<>(Map.of("model", "cohere completion model")), + new HashMap<>(Map.of("model_id", "cohere completion model")), null, ConfigurationParseContext.PERSISTENT ); - assertThat(model.uri(), is("TODO")); assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); assertThat(model.getServiceSettings().modelId(), is("cohere completion model")); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java index 76e7e57b0ebbb..b47ed2f9aabf4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java @@ -47,7 +47,11 @@ public void testCreateRequest_UrlDefined() throws IOException { } public void testCreateRequest_ModelDefined() throws IOException { - var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); + var request = new CohereV1CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("url", "secret", "model"), + false + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -63,25 +67,33 @@ public void testCreateRequest_ModelDefined() throws IOException { assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); } - public void testDefaultUrl() throws IOException { + public void testDefaultUrl() { var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel(null, "secret", null), false); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is("TODO")); + assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/chat")); } public void testTruncate_ReturnsSameInstance() { - var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); + var request = new CohereV1CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("url", "secret", "model"), + false + ); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } public void testTruncationInfo_ReturnsNull() { - var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); + var request = new CohereV1CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("url", "secret", "model"), + false + ); assertNull(request.getTruncationInfo()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java index 245fe69ddea6a..ecc126d44baea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java @@ -34,7 +34,7 @@ import static org.hamcrest.Matchers.is; public class CohereV1EmbeddingsRequestTests extends ESTestCase { - public void testCreateRequest_UrlDefined() throws IOException { + public void testCreateRequest() throws IOException { var inputType = InputTypeTests.randomWithNull(); var request = createRequest( List.of("abc"), @@ -47,7 +47,7 @@ public void testCreateRequest_UrlDefined() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("TODO should be full url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -283,6 +283,7 @@ public void testCreateRequest_TruncateNone() throws IOException { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { var entity = createRequest( + "model", List.of("abc"), InputType.INTERNAL_INGEST, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), @@ -299,6 +300,7 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { var entity = createRequest( + "model", List.of("abc"), null, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), @@ -315,6 +317,7 @@ public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { var entity = createRequest( + "model", List.of("abc"), InputType.INTERNAL_SEARCH, new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), @@ -331,6 +334,7 @@ public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() thr public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { var entity = createRequest( + "model", List.of("abc"), InputType.SEARCH, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), @@ -347,6 +351,7 @@ public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() thr public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { var entity = createRequest( + "model", List.of("abc"), null, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), @@ -362,14 +367,14 @@ public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws } public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { - var entity = createRequest(List.of("abc"), null, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null); + var entity = createRequest(null, List.of("abc"), null, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"]}""")); + {"texts":["abc"],"embedding_types":["float"]}""")); } public static CohereV1EmbeddingsRequest createRequest(List input, InputType inputType, CohereEmbeddingsModel model) { @@ -377,12 +382,13 @@ public static CohereV1EmbeddingsRequest createRequest(List input, InputT } public static CohereV1EmbeddingsRequest createRequest( + String modelId, List input, InputType inputType, CohereEmbeddingsTaskSettings taskSettings, CohereEmbeddingType embeddingType ) { - var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, null, embeddingType); + var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, modelId, embeddingType); return new CohereV1EmbeddingsRequest(input, inputType, model); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java similarity index 68% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java index 2deda6deb9e81..a9d73c78af9e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java @@ -7,21 +7,53 @@ package org.elasticsearch.xpack.inference.services.cohere.request.v1; +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import java.io.IOException; +import java.net.URI; import java.util.List; -public class CohereV1RerankRequestEntityTests extends ESTestCase { +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CohereV1RerankRequestTests extends ESTestCase { + public void testRequest() { + var request = new CohereV1RerankRequest( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + createModel("model", new CohereRerankTaskSettings(null, null, 3)) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/rerank")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + } + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { var entity = new CohereV1RerankRequest( "query", @@ -91,6 +123,11 @@ public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOExcep } private CohereRerankModel createModel(String modelId, CohereRerankTaskSettings taskSettings) { - return new CohereRerankModel("inference_id", new CohereRerankServiceSettings("uri", modelId, null), taskSettings, null); + return new CohereRerankModel( + "inference_id", + new CohereRerankServiceSettings((URI) null, modelId, null), + taskSettings, + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java index 2e2d687498366..3c87c2fafd3ea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java @@ -16,7 +16,6 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; -import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest; import org.hamcrest.CoreMatchers; import java.io.IOException; @@ -41,23 +40,27 @@ public void testCreateRequest() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/chat")); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id"))); + assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id", "stream", false))); } - public void testDefaultUrl() throws IOException { - var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel(null, "secret", null), false); + public void testDefaultUrl() { + var request = new CohereV2CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel(null, "secret", "model id"), + false + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is("TODO")); + assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/chat")); } public void testXContents() throws IOException { @@ -71,6 +74,6 @@ public void testXContents() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, CoreMatchers.is(""" - {"message":"some input","model":"model"}""")); + {"message":"some input","model":"model","stream":false}""")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java index 35ad0b78cbd70..bcf9f8b4fc3dc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; -import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1EmbeddingsRequest; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; @@ -36,14 +35,14 @@ public class CohereV2EmbeddingsRequestTests extends ESTestCase { public void testCreateRequest() throws IOException { - var inputType = InputTypeTests.randomWithNull(); + var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( null, "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, + new CohereEmbeddingsTaskSettings(inputType, CohereTruncation.START), null, null, "model id", @@ -56,7 +55,7 @@ public void testCreateRequest() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("TODO should be full url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -67,23 +66,23 @@ public void testCreateRequest() throws IOException { var requestMap = entityAsMap(httpPost.getEntity().getContent()); MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("float"))); - MatcherAssert.assertThat(requestMap.get("model"), is(List.of("model id"))); + MatcherAssert.assertThat(requestMap.get("model"), is("model id")); MatcherAssert.assertThat(requestMap.get("truncate"), is("start")); - validateInputType(requestMap, null, inputType); + validateInputType(requestMap, inputType, inputType); } public void testCreateRequest_WithTaskSettingsInputType() throws IOException { var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), - null, + InputType.UNSPECIFIED, CohereEmbeddingsModelTests.createModel( "url", "secret", new CohereEmbeddingsTaskSettings(inputType, CohereTruncation.END), null, null, - null, + "cohere model", null ) ); @@ -98,7 +97,7 @@ public void testCreateRequest_WithTaskSettingsInputType() throws IOException { } public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() throws IOException { - var inputType = InputTypeTests.randomWithNull(); + var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), inputType, @@ -135,7 +134,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th } public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() throws IOException { - var inputType = InputTypeTests.randomWithNull(); + var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), inputType, @@ -172,7 +171,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr } public void testCreateRequest_TruncateNone() throws IOException { - var inputType = InputTypeTests.randomWithNull(); + var inputType = InputTypeTests.randomWithoutUnspecified(); var request = createRequest( List.of("abc"), inputType, @@ -182,7 +181,7 @@ public void testCreateRequest_TruncateNone() throws IOException { new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), null, null, - null, + "cohere model", null ) ); @@ -209,6 +208,7 @@ public void testCreateRequest_TruncateNone() throws IOException { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { var entity = createRequest( + "cohere model", List.of("abc"), InputType.INTERNAL_INGEST, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), @@ -220,13 +220,14 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException String xContentResult = Strings.toString(builder); MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); + {"texts":["abc"],"model":"cohere model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); } public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { var entity = createRequest( + "cohere model", List.of("abc"), - null, + InputType.INGEST, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), CohereEmbeddingType.INT8 ); @@ -236,11 +237,12 @@ public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() String xContentResult = Strings.toString(builder); MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + {"texts":["abc"],"model":"cohere model","input_type":"search_document","embedding_types":["int8"],"truncate":"none"}""")); } public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { var entity = createRequest( + "cohere model", List.of("abc"), InputType.INTERNAL_SEARCH, new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), @@ -252,11 +254,12 @@ public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() thr String xContentResult = Strings.toString(builder); MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); } public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { var entity = createRequest( + "cohere model", List.of("abc"), InputType.SEARCH, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), @@ -268,13 +271,14 @@ public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() thr String xContentResult = Strings.toString(builder); MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); } public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { var entity = createRequest( + "cohere model", List.of("abc"), - null, + InputType.SEARCH, new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), CohereEmbeddingType.BIT ); @@ -284,32 +288,22 @@ public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws String xContentResult = Strings.toString(builder); MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); - } - - public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { - var entity = createRequest(List.of("abc"), null, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" - {"texts":["abc"]}""")); + {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); } public static CohereV2EmbeddingsRequest createRequest(List input, InputType inputType, CohereEmbeddingsModel model) { return new CohereV2EmbeddingsRequest(input, inputType, model); } - public static CohereV1EmbeddingsRequest createRequest( + public static CohereV2EmbeddingsRequest createRequest( + String modelId, List input, InputType inputType, CohereEmbeddingsTaskSettings taskSettings, CohereEmbeddingType embeddingType ) { - var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, null, embeddingType); - return new CohereV1EmbeddingsRequest(input, inputType, model); + var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, modelId, embeddingType); + return new CohereV2EmbeddingsRequest(input, inputType, model); } private void validateInputType(Map requestMap, InputType taskSettingsInputType, InputType requestInputType) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java similarity index 90% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java index c927640ce992d..722a32f172d82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java @@ -10,6 +10,7 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; @@ -18,6 +19,7 @@ import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; @@ -27,8 +29,8 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class CohereV2RerankRequestEntityTests extends ESTestCase { - public void testUrl() throws IOException { +public class CohereV2RerankRequestTests extends ESTestCase { + public void testUrl() { var request = new CohereV2RerankRequest( "query", List.of("abc"), @@ -41,7 +43,7 @@ public void testUrl() throws IOException { MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("TOOD")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/rerank")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -119,6 +121,11 @@ public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOExcep } private CohereRerankModel createModel(String modelId, String uri, CohereRerankTaskSettings taskSettings) { - return new CohereRerankModel("inference_id", new CohereRerankServiceSettings(uri, modelId, null), taskSettings, null); + return new CohereRerankModel( + "inference_id", + new CohereRerankServiceSettings(uri, modelId, null), + taskSettings, + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); } } From fcbbfa095aa9d5fa31ed16ed6e9b4fdc618736f6 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 27 May 2025 15:31:23 +0100 Subject: [PATCH 04/12] start upgrade test --- muted-tests.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/muted-tests.yml b/muted-tests.yml index b619f99409ef6..75c407ad4a095 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -182,8 +182,6 @@ tests: - class: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT method: test {yaml=analysis-common/40_token_filters/stemmer_override file access} issue: https://github.com/elastic/elasticsearch/issues/121625 -- class: org.elasticsearch.xpack.application.CohereServiceUpgradeIT - issue: https://github.com/elastic/elasticsearch/issues/121537 - class: org.elasticsearch.test.rest.ClientYamlTestSuiteIT method: test {yaml=snapshot.delete/10_basic/Delete a snapshot asynchronously} issue: https://github.com/elastic/elasticsearch/issues/122102 From d37dd1ffd0fcf7f79df1779e089c7715e2a1e775 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 19 Jun 2025 14:44:25 +0100 Subject: [PATCH 05/12] Upgrade test --- .../application/CohereServiceUpgradeIT.java | 38 +++++++++++++++++-- .../services/cohere/CohereAccount.java | 30 ++++++++++----- .../services/cohere/CohereModel.java | 2 +- .../embeddings/CohereEmbeddingsModel.java | 2 +- .../cohere/request/CohereRequest.java | 28 +++++++++++++- .../request/v1/CohereV1CompletionRequest.java | 18 ++------- .../request/v1/CohereV1EmbeddingsRequest.java | 13 ++----- .../request/v1/CohereV1RerankRequest.java | 18 ++------- .../request/v2/CohereV2CompletionRequest.java | 22 +++-------- .../request/v2/CohereV2EmbeddingsRequest.java | 13 ++----- .../request/v2/CohereV2RerankRequest.java | 18 ++------- .../v1/CohereV1CompletionRequestTests.java | 16 ++++---- .../v1/CohereV1EmbeddingsRequestTests.java | 21 +++++----- .../v2/CohereV2CompletionRequestTests.java | 14 +++++++ .../v2/CohereV2EmbeddingsRequestTests.java | 11 +++--- .../action/HuggingFaceActionCreatorTests.java | 6 ++- 16 files changed, 150 insertions(+), 120 deletions(-) diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java index 7851a093a95cb..f1becc026a932 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockRequest; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; @@ -36,10 +37,15 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase { // TODO: replace with proper test features private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0"; private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0"; + private static final String V2_API = "gte_v8.19.0"; private static MockWebServer cohereEmbeddingsServer; private static MockWebServer cohereRerankServer; + private enum ApiVersion { + V1, V2 + } + public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) { super(upgradedNodes); } @@ -62,15 +68,18 @@ public static void shutdown() { @SuppressWarnings("unchecked") public void testCohereEmbeddings() throws IOException { var embeddingsSupported = oldClusterHasFeature(COHERE_EMBEDDINGS_ADDED_TEST_FEATURE); - String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models"; assumeTrue("Cohere embedding service supported", embeddingsSupported); + String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models"; + ApiVersion oldClusterApiVersion = oldClusterHasFeature(V2_API) ? ApiVersion.V2 : ApiVersion.V1; + final String oldClusterIdInt8 = "old-cluster-embeddings-int8"; final String oldClusterIdFloat = "old-cluster-embeddings-float"; var testTaskType = TaskType.TEXT_EMBEDDING; if (isOldCluster()) { + // queue a response as PUT will call the service cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType); @@ -128,13 +137,17 @@ public void testCohereEmbeddings() throws IOException { // Inference on old cluster models assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE); + assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion); assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT); + assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", oldClusterApiVersion); { final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte"; + // new endpoints use the V2 API cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType); + assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2); configs = (List>) get(testTaskType, upgradedClusterIdByte).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); @@ -146,26 +159,31 @@ public void testCohereEmbeddings() throws IOException { { final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8"; + // new endpoints use the V2 API cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType); + assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2); configs = (List>) get(testTaskType, upgradedClusterIdInt8).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8); + assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2); delete(upgradedClusterIdInt8); } { final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float"; cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat())); put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType); + assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2); configs = (List>) get(testTaskType, upgradedClusterIdFloat).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("embedding_type", "float")); assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT); + assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2); delete(upgradedClusterIdFloat); } @@ -174,6 +192,17 @@ public void testCohereEmbeddings() throws IOException { } } + private void assertVersionInPath(MockRequest request, String endpoint, ApiVersion apiVersion) { + switch (apiVersion) { + case V2: + assertEquals("/v2/" + endpoint, request.getUri().getPath()); + break; + case V1: + assertEquals("/v1/" + endpoint, request.getUri().getPath()); + break; + } + } + void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException { switch (type) { case INT8: @@ -191,9 +220,11 @@ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) thro @SuppressWarnings("unchecked") public void testRerank() throws IOException { var rerankSupported = oldClusterHasFeature(COHERE_RERANK_ADDED_TEST_FEATURE); - String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models"; assumeTrue("Cohere rerank service supported", rerankSupported); + String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models"; + ApiVersion oldClusterApiVersion = oldClusterHasFeature(V2_API) ? ApiVersion.V2 : ApiVersion.V1; + final String oldClusterId = "old-cluster-rerank"; final String upgradedClusterId = "upgraded-cluster-rerank"; @@ -216,7 +247,6 @@ public void testRerank() throws IOException { assertThat(taskSettings, hasEntry("top_n", 3)); assertRerank(oldClusterId); - } else if (isUpgradedCluster()) { // check old cluster model var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); @@ -227,6 +257,7 @@ public void testRerank() throws IOException { assertThat(taskSettings, hasEntry("top_n", 3)); assertRerank(oldClusterId); + assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", oldClusterApiVersion); // New endpoint cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse())); @@ -235,6 +266,7 @@ public void testRerank() throws IOException { assertThat(configs, hasSize(1)); assertRerank(upgradedClusterId); + assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", ApiVersion.V2); delete(oldClusterId); delete(upgradedClusterId); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java index 297e918cac307..869357ef8fb17 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java @@ -7,25 +7,35 @@ package org.elasticsearch.xpack.inference.services.cohere; -import org.elasticsearch.common.CheckedSupplier; +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import java.net.URI; import java.net.URISyntaxException; import java.util.Objects; -import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; - -public record CohereAccount(URI uri, SecureString apiKey) { - - public static CohereAccount of(CohereModel model, CheckedSupplier uriBuilder) { - var uri = buildUri(model.uri(), "Cohere", uriBuilder); - - return new CohereAccount(uri, model.apiKey()); +public record CohereAccount(URI baseUri, SecureString apiKey) { + + public static CohereAccount of(CohereModel model) { + try { + var uri = model.baseUri() != null ? model.baseUri() : new URIBuilder().setScheme("https").setHost(CohereUtils.HOST).build(); + return new CohereAccount(uri, model.apiKey()); + } catch (URISyntaxException e) { + // using bad request here so that potentially sensitive URL information does not get logged + throw new ElasticsearchStatusException( + Strings.format("Failed to construct %s URL", CohereService.NAME), + RestStatus.BAD_REQUEST, + e + ); + } } public CohereAccount { - Objects.requireNonNull(uri); + Objects.requireNonNull(baseUri); Objects.requireNonNull(apiKey); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java index ca24d3e405c90..116655a598809 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java @@ -72,7 +72,7 @@ public int rateLimitGroupingHash() { return apiKey().hashCode(); } - public URI uri() { + public URI baseUri() { return rateLimitServiceSettings.uri(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index e8c2400ea4f1c..525674cc9b2ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -90,7 +90,7 @@ public ExecutableAction accept(CohereActionVisitor visitor, Map } @Override - public URI uri() { + public URI baseUri() { return getServiceSettings().getCommonSettings().uri(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java index f0d90f10945b3..ae351976545a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java @@ -9,17 +9,23 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import java.net.URI; +import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; @@ -46,7 +52,7 @@ protected CohereRequest(CohereAccount account, String inferenceEntityId, @Nullab @Override public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); + HttpPost httpPost = new HttpPost(getURI()); ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(this).getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); @@ -68,7 +74,25 @@ public boolean isStreaming() { @Override public URI getURI() { - return account.uri(); + return buildUri(account.baseUri()); + } + + /** + * Returns the URL path segments. + * @return List of segments that make up the path of the request. + */ + protected abstract List pathSegments(); + + private URI buildUri(URI baseUri) { + try { + return new URIBuilder(baseUri).setPathSegments(pathSegments()).build(); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + Strings.format("Failed to construct %s URL", CohereService.NAME), + RestStatus.BAD_REQUEST, + e + ); + } } public String getModelId() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java index a657149ff5e6c..4fa4552dcd94d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.request.v1; -import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; @@ -15,8 +14,6 @@ import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; import java.util.List; import java.util.Objects; @@ -24,12 +21,7 @@ public class CohereV1CompletionRequest extends CohereRequest { private final List input; public CohereV1CompletionRequest(List input, CohereCompletionModel model, boolean stream) { - super( - CohereAccount.of(model, CohereV1CompletionRequest::buildDefaultUri), - model.getInferenceEntityId(), - model.getServiceSettings().modelId(), - stream - ); + super(CohereAccount.of(model), model.getInferenceEntityId(), model.getServiceSettings().modelId(), stream); this.input = Objects.requireNonNull(input); } @@ -49,10 +41,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH) - .build(); + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java index 7482c68c095db..7c2d0b1fbf3f5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.request.v1; -import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; @@ -20,8 +19,6 @@ import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; import java.util.List; import java.util.Objects; @@ -34,7 +31,7 @@ public class CohereV1EmbeddingsRequest extends CohereRequest { public CohereV1EmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { super( - CohereAccount.of(embeddingsModel, CohereV1EmbeddingsRequest::buildDefaultUri), + CohereAccount.of(embeddingsModel), embeddingsModel.getInferenceEntityId(), embeddingsModel.getServiceSettings().getCommonSettings().modelId(), false @@ -46,11 +43,9 @@ public CohereV1EmbeddingsRequest(List input, InputType inputType, Cohere embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); } - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH) - .build(); + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java index 389a810b73f7f..70b34368eda61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.request.v1; -import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; @@ -17,8 +16,6 @@ import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; import java.util.List; import java.util.Objects; @@ -37,12 +34,7 @@ public CohereV1RerankRequest( @Nullable Integer topN, CohereRerankModel model ) { - super( - CohereAccount.of(model, CohereV1RerankRequest::buildDefaultUri), - model.getInferenceEntityId(), - model.getServiceSettings().modelId(), - false - ); + super(CohereAccount.of(model), model.getInferenceEntityId(), model.getServiceSettings().modelId(), false); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); @@ -51,11 +43,9 @@ public CohereV1RerankRequest( taskSettings = model.getTaskSettings(); } - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH) - .build(); + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java index 733222548ec1e..028c4a0d486c0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.request.v2; -import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; @@ -15,8 +14,6 @@ import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; import java.util.List; import java.util.Objects; @@ -24,23 +21,11 @@ public class CohereV2CompletionRequest extends CohereRequest { private final List input; public CohereV2CompletionRequest(List input, CohereCompletionModel model, boolean stream) { - super( - CohereAccount.of(model, CohereV2CompletionRequest::buildDefaultUri), - model.getInferenceEntityId(), - Objects.requireNonNull(model.getServiceSettings().modelId()), - stream - ); + super(CohereAccount.of(model), model.getInferenceEntityId(), Objects.requireNonNull(model.getServiceSettings().modelId()), stream); this.input = Objects.requireNonNull(input); } - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_2, CohereUtils.CHAT_PATH) - .build(); - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -51,4 +36,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_2, CohereUtils.CHAT_PATH); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java index 1b743eaa7224a..cb60469735dc3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.request.v2; -import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; @@ -19,8 +18,6 @@ import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; import java.util.List; import java.util.Objects; @@ -33,7 +30,7 @@ public class CohereV2EmbeddingsRequest extends CohereRequest { public CohereV2EmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { super( - CohereAccount.of(embeddingsModel, CohereV2EmbeddingsRequest::buildDefaultUri), + CohereAccount.of(embeddingsModel), embeddingsModel.getInferenceEntityId(), Objects.requireNonNull(embeddingsModel.getServiceSettings().getCommonSettings().modelId()), false @@ -45,11 +42,9 @@ public CohereV2EmbeddingsRequest(List input, InputType inputType, Cohere embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); } - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_2, CohereUtils.EMBEDDINGS_PATH) - .build(); + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_2, CohereUtils.EMBEDDINGS_PATH); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java index 50b7339e575af..941e191bc1447 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.services.cohere.request.v2; -import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; @@ -17,8 +16,6 @@ import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; import java.util.List; import java.util.Objects; @@ -41,12 +38,7 @@ public CohereV2RerankRequest( @Nullable Integer topN, CohereRerankModel model ) { - super( - CohereAccount.of(model, CohereV2RerankRequest::buildDefaultUri), - model.getInferenceEntityId(), - Objects.requireNonNull(model.getServiceSettings().modelId()), - false - ); + super(CohereAccount.of(model), model.getInferenceEntityId(), Objects.requireNonNull(model.getServiceSettings().modelId()), false); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); @@ -55,11 +47,9 @@ public CohereV2RerankRequest( taskSettings = model.getTaskSettings(); } - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_2, CohereUtils.RERANK_PATH) - .build(); + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_2, CohereUtils.RERANK_PATH); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java index b47ed2f9aabf4..1f444ed3e8ce2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java @@ -30,14 +30,18 @@ public class CohereV1CompletionRequestTests extends ESTestCase { public void testCreateRequest_UrlDefined() throws IOException { - var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null), false); + var request = new CohereV1CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("http://localhost", "secret", null), + false + ); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getURI().toString(), is("http://localhost/v1/chat")); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); @@ -47,18 +51,14 @@ public void testCreateRequest_UrlDefined() throws IOException { } public void testCreateRequest_ModelDefined() throws IOException { - var request = new CohereV1CompletionRequest( - List.of("abc"), - CohereCompletionModelTests.createModel("url", "secret", "model"), - false - ); + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel(null, "secret", "model"), false); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is("url")); + assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/chat")); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java index ecc126d44baea..18af39004b8eb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java @@ -67,7 +67,7 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + "http://localhost:8080", "secret", new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), null, @@ -82,7 +82,7 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("http://localhost:8080/v1/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -104,7 +104,7 @@ public void testCreateRequest_WithTaskSettingsInputType() throws IOException { List.of("abc"), null, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(inputType, CohereTruncation.END), null, @@ -119,7 +119,6 @@ public void testCreateRequest_WithTaskSettingsInputType() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -141,7 +140,7 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti List.of("abc"), requestInputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(taskSettingInputType, CohereTruncation.END), null, @@ -156,7 +155,7 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -177,7 +176,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), null, @@ -192,7 +191,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -214,7 +213,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), null, @@ -229,7 +228,6 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -251,7 +249,7 @@ public void testCreateRequest_TruncateNone() throws IOException { List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), null, @@ -266,7 +264,6 @@ public void testCreateRequest_TruncateNone() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java index 3c87c2fafd3ea..2fb51ca8ca457 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java @@ -63,6 +63,20 @@ public void testDefaultUrl() { assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/chat")); } + public void testOverriddenUrl() { + var request = new CohereV2CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("http://localhost", "secret", "model id"), + false + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("http://localhost/v2/chat")); + } + public void testXContents() throws IOException { var request = new CohereV2CompletionRequest( List.of("some input"), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java index bcf9f8b4fc3dc..a7e009d63a903 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java @@ -102,7 +102,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + "http://localhost", "secret", new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), null, @@ -117,7 +117,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("http://localhost/v2/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -139,7 +139,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), null, @@ -154,7 +154,6 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -176,7 +175,7 @@ public void testCreateRequest_TruncateNone() throws IOException { List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), null, @@ -191,7 +190,7 @@ public void testCreateRequest_TruncateNone() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index f5d700016bf81..b14cfcd14ec43 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -345,7 +345,11 @@ public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOE var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new QueryAndDocsInputs("popular name", List.of("Luke")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new QueryAndDocsInputs("popular name", List.of("Luke"), null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat( From 92a373b2873d525c6f22ed9caeda932d64675b29 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Fri, 20 Jun 2025 16:19:39 +0100 Subject: [PATCH 06/12] Fix the tests --- .../org/elasticsearch/TransportVersions.java | 3 +- .../application/CohereServiceUpgradeIT.java | 51 +++++++++- .../xpack/inference/InferenceFeatures.java | 4 +- .../http/sender/QueryAndDocsInputs.java | 4 + .../services/cohere/CohereModel.java | 1 + .../CohereRateLimitServiceSettings.java | 2 + .../services/cohere/CohereService.java | 3 +- .../cohere/CohereServiceSettings.java | 98 ++++++++++++++++--- .../cohere/action/CohereActionCreator.java | 68 +++++++++---- .../CohereCompletionServiceSettings.java | 55 +++++++++-- .../CohereEmbeddingsServiceSettings.java | 6 ++ .../request/v2/CohereV2EmbeddingsRequest.java | 3 +- .../rerank/CohereRerankServiceSettings.java | 57 +++++++++-- .../TextSimilarityTestPlugin.java | 3 +- .../cohere/CohereServiceSettingsTests.java | 82 ++++++++++++++-- .../services/cohere/CohereServiceTests.java | 91 +++++++++++++++-- .../action/CohereActionCreatorTests.java | 64 +----------- .../action/CohereCompletionActionTests.java | 63 +----------- .../action/CohereEmbeddingsActionTests.java | 45 ++++----- .../CohereCompletionModelTests.java | 3 +- .../CohereCompletionServiceSettingsTests.java | 31 +++++- .../CohereEmbeddingsModelTests.java | 55 ++++++++++- .../CohereEmbeddingsServiceSettingsTests.java | 91 +++++++++++++---- .../v1/CohereV1RerankRequestTests.java | 3 +- .../v2/CohereV2RerankRequestTests.java | 3 +- .../CohereRerankServiceSettingsTests.java | 22 ++++- 26 files changed, 658 insertions(+), 253 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 78aa977c3342e..fb1fa6a9b0362 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -202,7 +202,7 @@ static TransportVersion def(int id) { public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55); public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56); - + public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_57); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -311,6 +311,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00); public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_105_0_00); + public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_106_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java index f1becc026a932..9782d4881ac61 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java @@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name; +import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.http.MockRequest; @@ -25,6 +26,7 @@ import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasSize; @@ -37,13 +39,14 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase { // TODO: replace with proper test features private static final String COHERE_EMBEDDINGS_ADDED_TEST_FEATURE = "gte_v8.13.0"; private static final String COHERE_RERANK_ADDED_TEST_FEATURE = "gte_v8.14.0"; - private static final String V2_API = "gte_v8.19.0"; + private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2"; private static MockWebServer cohereEmbeddingsServer; private static MockWebServer cohereRerankServer; private enum ApiVersion { - V1, V2 + V1, + V2 } public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) { @@ -71,7 +74,7 @@ public void testCohereEmbeddings() throws IOException { assumeTrue("Cohere embedding service supported", embeddingsSupported); String oldClusterEndpointIdentifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models"; - ApiVersion oldClusterApiVersion = oldClusterHasFeature(V2_API) ? ApiVersion.V2 : ApiVersion.V1; + ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1; final String oldClusterIdInt8 = "old-cluster-embeddings-int8"; final String oldClusterIdFloat = "old-cluster-embeddings-float"; @@ -186,6 +189,26 @@ public void testCohereEmbeddings() throws IOException { assertVersionInPath(cohereEmbeddingsServer.requests().getLast(), "embed", ApiVersion.V2); delete(upgradedClusterIdFloat); } + { + // new endpoints use the V2 API which require the model to be set + final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id"; + var jsonBody = Strings.format(""" + { + "service": "cohere", + "service_settings": { + "url": "%s", + "api_key": "XXXX", + "embedding_type": "int8" + } + } + """, getUrl(cohereEmbeddingsServer)); + + var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType)); + assertThat( + e.getMessage(), + containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.") + ); + } delete(oldClusterIdFloat); delete(oldClusterIdInt8); @@ -223,7 +246,7 @@ public void testRerank() throws IOException { assumeTrue("Cohere rerank service supported", rerankSupported); String old_cluster_endpoint_identifier = oldClusterHasFeature(MODELS_RENAMED_TO_ENDPOINTS_FEATURE) ? "endpoints" : "models"; - ApiVersion oldClusterApiVersion = oldClusterHasFeature(V2_API) ? ApiVersion.V2 : ApiVersion.V1; + ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1; final String oldClusterId = "old-cluster-rerank"; final String upgradedClusterId = "upgraded-cluster-rerank"; @@ -268,6 +291,26 @@ public void testRerank() throws IOException { assertRerank(upgradedClusterId); assertVersionInPath(cohereRerankServer.requests().getLast(), "rerank", ApiVersion.V2); + { + // new endpoints use the V2 API which require the model_id to be set + final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id"; + var jsonBody = Strings.format(""" + { + "service": "cohere", + "service_settings": { + "url": "%s", + "api_key": "XXXX" + } + } + """, getUrl(cohereEmbeddingsServer)); + + var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType)); + assertThat( + e.getMessage(), + containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.") + ); + } + delete(oldClusterId); delete(upgradedClusterId); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index ba1694d472181..2bc481cc484d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -37,6 +37,7 @@ public class InferenceFeatures implements FeatureSpecification { "test_rule_retriever.with_indices_that_dont_return_rank_docs" ); private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter"); + private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2"); @Override public Set getTestFeatures() { @@ -64,7 +65,8 @@ public Set getTestFeatures() { SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG, SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER, SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS, - SEMANTIC_TEXT_INDEX_OPTIONS + SEMANTIC_TEXT_INDEX_OPTIONS, + COHERE_V2_API ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 45e178a93f27d..850c96160dc44 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -41,6 +41,10 @@ public QueryAndDocsInputs( this.topN = topN; } + public QueryAndDocsInputs(String query, List chunks) { + this(query, chunks, null, null, false); + } + public String getQuery() { return query; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java index 116655a598809..2457fb31c9c6e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java @@ -25,6 +25,7 @@ import java.util.Objects; public abstract class CohereModel extends RateLimitGroupingModel { + private final SecureString apiKey; private final CohereRateLimitServiceSettings rateLimitServiceSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java index 365f7e8ae4acb..5b9fa3376a4f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java @@ -14,5 +14,7 @@ public interface CohereRateLimitServiceSettings { RateLimitSettings rateLimitSettings(); + CohereServiceSettings.CohereApiVersion apiVersion(); + URI uri(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index b4b565ef70560..c2f1221763165 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -314,7 +314,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { embeddingSize, serviceSettings.getCommonSettings().maxInputTokens(), serviceSettings.getCommonSettings().modelId(), - serviceSettings.getCommonSettings().rateLimitSettings() + serviceSettings.getCommonSettings().rateLimitSettings(), + serviceSettings.getCommonSettings().apiVersion() ), serviceSettings.getEmbeddingType() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java index a1943b339a561..9dc7f212b4d45 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -20,11 +20,14 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; import java.net.URI; +import java.util.EnumSet; +import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -43,6 +46,18 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser public static final String NAME = "cohere_service_settings"; public static final String OLD_MODEL_ID_FIELD = "model"; public static final String MODEL_ID = "model_id"; + public static final String API_VERSION = "api_version"; + public static final String MODEL_REQUIRED_FOR_V2_API = "The [service_settings.model_id] field is required for the Cohere V2 API."; + + public enum CohereApiVersion { + V1, + V2; + + public static CohereApiVersion fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + } + private static final Logger logger = LogManager.getLogger(CohereServiceSettings.class); // Production key rate limits for all endpoints: https://docs.cohere.com/docs/going-live#production-key-specifications // 10K requests a minute @@ -72,11 +87,53 @@ public static CohereServiceSettings fromMap(Map map, Configurati logger.info("The cohere [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead."); } + var resolvedModelId = modelId(oldModelId, modelId); + var apiVersion = apiVersionFromMap(map, context, validationException); + if (apiVersion == CohereApiVersion.V2) { + if (resolvedModelId == null) { + validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API); + } + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, modelId(oldModelId, modelId), rateLimitSettings); + return new CohereServiceSettings( + uri, + similarity, + dims, + maxInputTokens, + modelId(oldModelId, modelId), + rateLimitSettings, + apiVersion + ); + } + + public static CohereApiVersion apiVersionFromMap( + Map map, + ConfigurationParseContext context, + ValidationException validationException + ) { + return switch (context) { + case REQUEST -> CohereApiVersion.V2; // new endpoints all use the V2 API. + case PERSISTENT -> { + var apiVersion = ServiceUtils.extractOptionalEnum( + map, + API_VERSION, + ModelConfigurations.SERVICE_SETTINGS, + CohereApiVersion::fromString, + EnumSet.allOf(CohereApiVersion.class), + validationException + ); + + if (apiVersion == null) { + yield CohereApiVersion.V1; // If the API version is not persisted then it must be V1 + } else { + yield apiVersion; + } + } + }; } private static String modelId(@Nullable String model, @Nullable String modelId) { @@ -89,6 +146,7 @@ private static String modelId(@Nullable String model, @Nullable String modelId) private final Integer maxInputTokens; private final String modelId; private final RateLimitSettings rateLimitSettings; + private final CohereApiVersion apiVersion; public CohereServiceSettings( @Nullable URI uri, @@ -96,7 +154,8 @@ public CohereServiceSettings( @Nullable Integer dimensions, @Nullable Integer maxInputTokens, @Nullable String modelId, - @Nullable RateLimitSettings rateLimitSettings + @Nullable RateLimitSettings rateLimitSettings, + CohereApiVersion apiVersion ) { this.uri = uri; this.similarity = similarity; @@ -104,6 +163,7 @@ public CohereServiceSettings( this.maxInputTokens = maxInputTokens; this.modelId = modelId; this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.apiVersion = apiVersion; } public CohereServiceSettings( @@ -112,9 +172,10 @@ public CohereServiceSettings( @Nullable Integer dimensions, @Nullable Integer maxInputTokens, @Nullable String modelId, - @Nullable RateLimitSettings rateLimitSettings + @Nullable RateLimitSettings rateLimitSettings, + CohereApiVersion apiVersion ) { - this(createOptionalUri(url), similarity, dimensions, maxInputTokens, modelId, rateLimitSettings); + this(createOptionalUri(url), similarity, dimensions, maxInputTokens, modelId, rateLimitSettings, apiVersion); } public CohereServiceSettings(StreamInput in) throws IOException { @@ -129,11 +190,17 @@ public CohereServiceSettings(StreamInput in) throws IOException { } else { rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; } + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION) + || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) { + this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class); + } else { + this.apiVersion = CohereServiceSettings.CohereApiVersion.V1; + } } // should only be used for testing, public because it's accessed outside of the package - public CohereServiceSettings() { - this((URI) null, null, null, null, null, null); + public CohereServiceSettings(CohereApiVersion apiVersion) { + this((URI) null, null, null, null, null, null, apiVersion); } @Override @@ -141,6 +208,11 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + @Override + public CohereApiVersion apiVersion() { + return apiVersion; + } + public URI uri() { return uri; } @@ -172,15 +244,14 @@ public String getWriteableName() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - toXContentFragment(builder, params); - builder.endObject(); return builder; } public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { - return toXContentFragmentOfExposedFields(builder, params); + toXContentFragmentOfExposedFields(builder, params); + return builder.field(API_VERSION, apiVersion); // API version is persisted but not exposed to the user } @Override @@ -222,6 +293,10 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { rateLimitSettings.writeTo(out); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION) + || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) { + out.writeEnum(apiVersion); + } } @Override @@ -234,11 +309,12 @@ public boolean equals(Object o) { && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) && Objects.equals(modelId, that.modelId) - && Objects.equals(rateLimitSettings, that.rateLimitSettings); + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && apiVersion == that.apiVersion; } @Override public int hashCode() { - return Objects.hash(uri, similarity, dimensions, maxInputTokens, modelId, rateLimitSettings); + return Objects.hash(uri, similarity, dimensions, maxInputTokens, modelId, rateLimitSettings, apiVersion); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index f16fd70e64d60..f195ed8db7a74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.cohere.action; +import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; @@ -16,13 +17,18 @@ import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.cohere.CohereResponseHandler; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest; import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1EmbeddingsRequest; import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1RerankRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v2.CohereV2CompletionRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v2.CohereV2EmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v2.CohereV2RerankRequest; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.response.CohereCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity; @@ -30,6 +36,7 @@ import java.util.Map; import java.util.Objects; +import java.util.function.Function; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; @@ -68,54 +75,81 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { @Override public ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings) { var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings); + + Function requestCreator = inferenceInputs -> { + var requestInputType = InputType.isSpecified(inferenceInputs.getInputType()) + ? inferenceInputs.getInputType() + : overriddenModel.getTaskSettings().getInputType(); + + return switch (overriddenModel.getServiceSettings().getCommonSettings().apiVersion()) { + case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); + case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); + }; + }; + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings"); - var requestCreator = new GenericRequestManager<>( + var requestManager = new GenericRequestManager<>( serviceComponents.threadPool(), model, EMBEDDINGS_HANDLER, - (inferenceInputs -> new CohereV1EmbeddingsRequest( - inferenceInputs.getStringInputs(), - inferenceInputs.getInputType(), - overriddenModel - )), + requestCreator, EmbeddingsInput.class ); - return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); } @Override public ExecutableAction create(CohereRerankModel model, Map taskSettings) { var overriddenModel = CohereRerankModel.of(model, taskSettings); - var requestCreator = new GenericRequestManager<>( - serviceComponents.threadPool(), - overriddenModel, - RERANK_HANDLER, - (inferenceInputs -> new CohereV1RerankRequest( + + Function requestCreator = inferenceInputs -> switch (overriddenModel.getServiceSettings() + .apiVersion()) { + case V1 -> new CohereV1RerankRequest( inferenceInputs.getQuery(), inferenceInputs.getChunks(), inferenceInputs.getReturnDocuments(), inferenceInputs.getTopN(), overriddenModel - )), + ); + case V2 -> new CohereV2RerankRequest( + inferenceInputs.getQuery(), + inferenceInputs.getChunks(), + inferenceInputs.getReturnDocuments(), + inferenceInputs.getTopN(), + overriddenModel + ); + }; + + var requestManager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + requestCreator, QueryAndDocsInputs.class ); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere rerank"); - return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); } @Override public ExecutableAction create(CohereCompletionModel model, Map taskSettings) { // no overridden model as task settings are always empty for cohere completion model - var requestCreator = new GenericRequestManager<>( + + Function requestCreator = completionInput -> switch (model.getServiceSettings().apiVersion()) { + case V1 -> new CohereV1CompletionRequest(completionInput.getInputs(), model, completionInput.stream()); + case V2 -> new CohereV2CompletionRequest(completionInput.getInputs(), model, completionInput.stream()); + }; + + var requestManager = new GenericRequestManager<>( serviceComponents.threadPool(), model, COMPLETION_HANDLER, - (completionInput) -> new CohereV1CompletionRequest(completionInput.getInputs(), model, completionInput.stream()), + requestCreator, ChatCompletionInput.class ); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); - return new SingleInputSenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); + return new SingleInputSenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java index be241f3aaa7fc..efe58ed19a008 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -32,6 +33,9 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.API_VERSION; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.MODEL_REQUIRED_FOR_V2_API; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.apiVersionFromMap; public class CohereCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings { @@ -54,34 +58,56 @@ public static CohereCompletionServiceSettings fromMap(Map map, C context ); String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var apiVersion = apiVersionFromMap(map, context, validationException); + if (apiVersion == CohereServiceSettings.CohereApiVersion.V2) { + if (modelId == null) { + validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API); + } + } if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new CohereCompletionServiceSettings(uri, modelId, rateLimitSettings); + return new CohereCompletionServiceSettings(uri, modelId, rateLimitSettings, apiVersion); } private final URI uri; - private final String modelId; - private final RateLimitSettings rateLimitSettings; - - public CohereCompletionServiceSettings(@Nullable URI uri, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { + private final CohereServiceSettings.CohereApiVersion apiVersion; + + public CohereCompletionServiceSettings( + @Nullable URI uri, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings, + CohereServiceSettings.CohereApiVersion apiVersion + ) { this.uri = uri; this.modelId = modelId; this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.apiVersion = apiVersion; } - public CohereCompletionServiceSettings(@Nullable String url, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { - this(createOptionalUri(url), modelId, rateLimitSettings); + public CohereCompletionServiceSettings( + @Nullable String url, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings, + CohereServiceSettings.CohereApiVersion apiVersion + ) { + this(createOptionalUri(url), modelId, rateLimitSettings, apiVersion); } public CohereCompletionServiceSettings(StreamInput in) throws IOException { uri = createOptionalUri(in.readOptionalString()); modelId = in.readOptionalString(); rateLimitSettings = new RateLimitSettings(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION) + || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) { + this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class); + } else { + this.apiVersion = CohereServiceSettings.CohereApiVersion.V1; + } } @Override @@ -89,6 +115,11 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + @Override + public CohereServiceSettings.CohereApiVersion apiVersion() { + return apiVersion; + } + public URI uri() { return uri; } @@ -102,6 +133,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); + builder.field(API_VERSION, apiVersion); // API version is persisted but not exposed to the user builder.endObject(); return builder; @@ -123,6 +155,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(uriToWrite); out.writeOptionalString(modelId); rateLimitSettings.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION) + || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) { + out.writeEnum(apiVersion); + } } @Override @@ -146,11 +182,12 @@ public boolean equals(Object object) { CohereCompletionServiceSettings that = (CohereCompletionServiceSettings) object; return Objects.equals(uri, that.uri) && Objects.equals(modelId, that.modelId) - && Objects.equals(rateLimitSettings, that.rateLimitSettings); + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && apiVersion == that.apiVersion; } @Override public int hashCode() { - return Objects.hash(uri, modelId, rateLimitSettings); + return Objects.hash(uri, modelId, rateLimitSettings, apiVersion); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java index b25b9fc8fd351..11cd6c2bcd75d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java @@ -9,6 +9,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -183,6 +184,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(CohereEmbeddingType.translateToVersion(embeddingType, out.getTransportVersion())); } + @Override + public String toString() { + return Strings.toString(this); + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java index cb60469735dc3..6fb8eb5bec7b8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.List; import java.util.Objects; +import java.util.Optional; public class CohereV2EmbeddingsRequest extends CohereRequest { @@ -37,7 +38,7 @@ public CohereV2EmbeddingsRequest(List input, InputType inputType, Cohere ); this.input = Objects.requireNonNull(input); - this.inputType = Objects.requireNonNull(inputType); // inputType is required in v2 + this.inputType = Optional.ofNullable(inputType).orElse(InputType.SEARCH); // inputType is required in v2 taskSettings = embeddingsModel.getTaskSettings(); embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java index 874fc5dc906ff..a17fff7f165cf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -37,7 +38,10 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.API_VERSION; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.MODEL_REQUIRED_FOR_V2_API; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.apiVersionFromMap; public class CohereRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings { public static final String NAME = "cohere_rerank_service_settings"; @@ -62,27 +66,44 @@ public static CohereRerankServiceSettings fromMap(Map map, Confi context ); + var apiVersion = apiVersionFromMap(map, context, validationException); + if (apiVersion == CohereServiceSettings.CohereApiVersion.V2) { + if (modelId == null) { + validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API); + } + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new CohereRerankServiceSettings(uri, modelId, rateLimitSettings); + return new CohereRerankServiceSettings(uri, modelId, rateLimitSettings, apiVersion); } private final URI uri; - private final String modelId; - private final RateLimitSettings rateLimitSettings; - - public CohereRerankServiceSettings(@Nullable URI uri, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { + private final CohereServiceSettings.CohereApiVersion apiVersion; + + public CohereRerankServiceSettings( + @Nullable URI uri, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings, + CohereServiceSettings.CohereApiVersion apiVersion + ) { this.uri = uri; this.modelId = modelId; this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.apiVersion = apiVersion; } - public CohereRerankServiceSettings(@Nullable String url, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { - this(createOptionalUri(url), modelId, rateLimitSettings); + public CohereRerankServiceSettings( + @Nullable String url, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings, + CohereServiceSettings.CohereApiVersion apiVersion + ) { + this(createOptionalUri(url), modelId, rateLimitSettings, apiVersion); } public CohereRerankServiceSettings(StreamInput in) throws IOException { @@ -102,6 +123,13 @@ public CohereRerankServiceSettings(StreamInput in) throws IOException { } else { this.rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; } + + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION) + || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) { + this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class); + } else { + this.apiVersion = CohereServiceSettings.CohereApiVersion.V1; + } } @Override @@ -119,6 +147,11 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + @Override + public CohereServiceSettings.CohereApiVersion apiVersion() { + return apiVersion; + } + @Override public String getWriteableName() { return NAME; @@ -129,6 +162,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); + builder.field(API_VERSION, apiVersion); // API version is persisted but not exposed to the user builder.endObject(); return builder; @@ -172,6 +206,10 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { rateLimitSettings.writeTo(out); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION) + || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) { + out.writeEnum(apiVersion); + } } @Override @@ -181,11 +219,12 @@ public boolean equals(Object object) { CohereRerankServiceSettings that = (CohereRerankServiceSettings) object; return Objects.equals(uri, that.uri) && Objects.equals(modelId, that.modelId) - && Objects.equals(rateLimitSettings, that.rateLimitSettings); + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && apiVersion == that.apiVersion; } @Override public int hashCode() { - return Objects.hash(uri, modelId, rateLimitSettings); + return Objects.hash(uri, modelId, rateLimitSettings, apiVersion); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index dc0e2cc10501d..f8563aebe0764 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; @@ -132,7 +133,7 @@ private void handleGetInferenceModelActionRequ request.getInferenceEntityId(), request.getTaskType(), CohereService.NAME, - new CohereRerankServiceSettings("uri", "model", null), + new CohereRerankServiceSettings("uri", "model", null, CohereServiceSettings.CohereApiVersion.V2), topN == null ? new EmptyTaskSettings() : new CohereRerankTaskSettings(topN, null, null) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java index f4dad7546c8a2..0ce016956cda0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java @@ -61,7 +61,8 @@ private static CohereServiceSettings createRandom(String url) { dims, maxInputTokens, model, - RateLimitSettingsTests.createRandom() + RateLimitSettingsTests.createRandom(), + randomFrom(CohereServiceSettings.CohereApiVersion.values()) ); } @@ -91,7 +92,17 @@ public void testFromMap() { MatcherAssert.assertThat( serviceSettings, - is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model, null)) + is( + new CohereServiceSettings( + ServiceUtils.createUri(url), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + model, + null, + CohereServiceSettings.CohereApiVersion.V2 + ) + ) ); } @@ -130,7 +141,8 @@ public void testFromMap_WithRateLimit() { dims, maxInputTokens, model, - new RateLimitSettings(3) + new RateLimitSettings(3), + CohereServiceSettings.CohereApiVersion.V2 ) ) ); @@ -154,7 +166,9 @@ public void testFromMap_WhenUsingModelId() { ServiceFields.MAX_INPUT_TOKENS, maxInputTokens, CohereServiceSettings.MODEL_ID, - model + model, + CohereServiceSettings.API_VERSION, + CohereServiceSettings.CohereApiVersion.V1.toString() ) ), ConfigurationParseContext.PERSISTENT @@ -162,7 +176,41 @@ public void testFromMap_WhenUsingModelId() { MatcherAssert.assertThat( serviceSettings, - is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model, null)) + is( + new CohereServiceSettings( + ServiceUtils.createUri(url), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + model, + null, + CohereServiceSettings.CohereApiVersion.V1 + ) + ) + ); + } + + public void testFromMap_MissingModelId() { + var e = expectThrows( + ValidationException.class, + () -> CohereServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 1536, + ServiceFields.MAX_INPUT_TOKENS, + 512 + ) + ), + ConfigurationParseContext.REQUEST + ) + ); + + assertThat( + e.validationErrors().getFirst(), + containsString("The [service_settings.model_id] field is required for the Cohere V2 API.") ); } @@ -194,7 +242,17 @@ public void testFromMap_PrefersModelId_OverModel() { MatcherAssert.assertThat( serviceSettings, - is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model, null)) + is( + new CohereServiceSettings( + ServiceUtils.createUri(url), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + model, + null, + CohereServiceSettings.CohereApiVersion.V1 + ) + ) ); } @@ -255,14 +313,22 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { } public void testXContent_WritesModelId() throws IOException { - var entity = new CohereServiceSettings((String) null, null, null, null, "modelId", new RateLimitSettings(1)); + var entity = new CohereServiceSettings( + (String) null, + null, + null, + null, + "modelId", + new RateLimitSettings(1), + CohereServiceSettings.CohereApiVersion.V2 + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" - {"model_id":"modelId","rate_limit":{"requests_per_minute":1}}""")); + {"model_id":"modelId","rate_limit":{"requests_per_minute":1},"api_version":"V2"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index fabf87151644b..2f2b43dc5f2e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -263,8 +263,11 @@ private static ActionListener getModelListenerForException(Class excep public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createCohereService()) { + var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); + serviceSettings.put(CohereServiceSettings.MODEL_ID, "foo"); + var config = getRequestConfigMap( - CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), + serviceSettings, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ); @@ -318,8 +321,10 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var secretSettingsMap = getSecretSettingsMap("secret"); secretSettingsMap.put("extra_key", "value"); + var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); + serviceSettings.put(CohereServiceSettings.MODEL_ID, "foo"); var config = getRequestConfigMap( - CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), + serviceSettings, getTaskSettingsMapEmpty(), secretSettingsMap ); @@ -343,11 +348,13 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, (e) -> fail("Model parsing should have succeeded " + e.getMessage())); + var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); + serviceSettings.put(CohereServiceSettings.MODEL_ID, "foo"); service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null), + serviceSettings, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret") ), @@ -953,7 +960,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, 1024, 1024, - null, + "coheremodel", null ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1127,7 +1134,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs } } - public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { + public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v1API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { @@ -1166,7 +1173,8 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec 1024, 1024, "model", - null + null, + CohereServiceSettings.CohereApiVersion.V1 ); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( @@ -1201,6 +1209,73 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec } } + public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v2API() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings(null, null), + 1024, + 1024, + "model", + null, + CohereServiceSettings.CohereApiVersion.V2 + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float"), "input_type", "search_query")) + ); + } + } + public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { var model = CohereEmbeddingsModelTests.createModel( getUrl(webServer), @@ -1316,7 +1391,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("float"))) + is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("float"), "input_type", "search_query")) ); } } @@ -1415,7 +1490,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("int8"))) + is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("int8"), "input_type", "search_query")) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index b56a19c0af0f1..88d26d5d7eef1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -94,7 +94,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { }, "meta": { "api_version": { - "version": "1" + "version": "2" }, "billed_units": { "input_tokens": 1 @@ -209,67 +209,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); - } - } - - public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var sender = createSender(senderFactory)) { - sender.start(); - - String responseJson = """ - { - "response_id": "some id", - "text": "result", - "generation_id": "some id", - "chat_history": [ - { - "role": "USER", - "message": "input" - }, - { - "role": "CHATBOT", - "message": "result" - } - ], - "finish_reason": "COMPLETE", - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 4, - "output_tokens": 191 - }, - "tokens": { - "input_tokens": 70, - "output_tokens": 191 - } - } - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", null); - var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model, Map.of()); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc"))); + assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false))); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java index 6fbb16c607e85..78b8b7bdeaf3e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java @@ -132,68 +132,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); - } - } - - public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); - - String responseJson = """ - { - "response_id": "some id", - "text": "result", - "generation_id": "some id", - "chat_history": [ - { - "role": "USER", - "message": "input" - }, - { - "role": "CHATBOT", - "message": "result" - } - ], - "finish_reason": "COMPLETE", - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 4, - "output_tokens": 191 - }, - "tokens": { - "input_tokens": 70, - "output_tokens": 191 - } - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var action = createAction(getUrl(webServer), "secret", null, sender); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - assertThat( - webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER), - equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE) - ); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc"))); + assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false))); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index 7adba4fb4b787..b9f7c30478624 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -54,6 +54,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -123,7 +124,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { ); PlainActionFuture listener = new PlainActionFuture<>(); - var inputType = InputTypeTests.randomWithNull(); + InputType inputType = InputTypeTests.randomWithNull(); action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -142,31 +143,25 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - if (inputType != null && inputType != InputType.UNSPECIFIED) { - var cohereInputType = CohereUtils.inputTypeToString(inputType); - MatcherAssert.assertThat( - requestMap, - is( - Map.of( - "texts", - List.of("abc"), - "model", - "model", - "input_type", - cohereInputType, - "embedding_types", - List.of("float"), - "truncate", - "start" - ) + var expectedInputType = InputType.isSpecified(inputType) ? inputType : InputType.SEARCH; + var cohereInputType = CohereUtils.inputTypeToString(expectedInputType); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + cohereInputType, + "embedding_types", + List.of("float"), + "truncate", + "start" ) - ); - } else { - MatcherAssert.assertThat( - requestMap, - is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float"), "truncate", "start")) - ); - } + ) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java index fefd4323d0f38..6ae12b96741ed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.util.HashMap; @@ -36,7 +37,7 @@ public void testCreateModel_AlwaysWithEmptyTaskSettings() { public static CohereCompletionModel createModel(String url, String apiKey, @Nullable String model) { return new CohereCompletionModel( "id", - new CohereCompletionServiceSettings(url, model, null), + new CohereCompletionServiceSettings(url, model, null, CohereServiceSettings.CohereApiVersion.V2), EmptyTaskSettings.INSTANCE, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java index ed8bc90d32140..06ebdd158b92c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; @@ -27,7 +28,12 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { public static CohereCompletionServiceSettings createRandom() { - return new CohereCompletionServiceSettings(randomAlphaOfLength(8), randomAlphaOfLength(8), RateLimitSettingsTests.createRandom()); + return new CohereCompletionServiceSettings( + randomAlphaOfLength(8), + randomAlphaOfLength(8), + RateLimitSettingsTests.createRandom(), + randomFrom(CohereServiceSettings.CohereApiVersion.values()) + ); } public void testFromMap_WithRateLimitSettingsNull() { @@ -39,7 +45,7 @@ public void testFromMap_WithRateLimitSettingsNull() { ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null))); + assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null, CohereServiceSettings.CohereApiVersion.V1))); } public void testFromMap_WithRateLimitSettings() { @@ -61,18 +67,33 @@ public void testFromMap_WithRateLimitSettings() { ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute)))); + assertThat( + serviceSettings, + is( + new CohereCompletionServiceSettings( + url, + model, + new RateLimitSettings(requestsPerMinute), + CohereServiceSettings.CohereApiVersion.V1 + ) + ) + ); } public void testToXContent_WritesAllValues() throws IOException { - var serviceSettings = new CohereCompletionServiceSettings("url", "model", new RateLimitSettings(3)); + var serviceSettings = new CohereCompletionServiceSettings( + "url", + "model", + new RateLimitSettings(3), + CohereServiceSettings.CohereApiVersion.V1 + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" - {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}""")); + {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3},"api_version":"V1"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index 294d33f7d9fac..aa19fcdf791ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -121,7 +121,15 @@ public static CohereEmbeddingsModel createModel( return new CohereEmbeddingsModel( "id", new CohereEmbeddingsServiceSettings( - new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null), + new CohereServiceSettings( + url, + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit, + model, + null, + CohereServiceSettings.CohereApiVersion.V2 + ), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) ), taskSettings, @@ -136,13 +144,44 @@ public static CohereEmbeddingsModel createModel( CohereEmbeddingsTaskSettings taskSettings, @Nullable Integer tokenLimit, @Nullable Integer dimensions, - @Nullable String model, + String model, @Nullable CohereEmbeddingType embeddingType + ) + { + return createModel( + url, + apiKey, + taskSettings, + tokenLimit, + dimensions, + model, + embeddingType, + CohereServiceSettings.CohereApiVersion.V2 + ); + } + + public static CohereEmbeddingsModel createModel( + String url, + String apiKey, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model, + @Nullable CohereEmbeddingType embeddingType, + CohereServiceSettings.CohereApiVersion apiVersion ) { return new CohereEmbeddingsModel( "id", new CohereEmbeddingsServiceSettings( - new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null), + new CohereServiceSettings( + url, + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit, + model, + null, + apiVersion + ), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) ), taskSettings, @@ -164,7 +203,15 @@ public static CohereEmbeddingsModel createModel( return new CohereEmbeddingsModel( "id", new CohereEmbeddingsServiceSettings( - new CohereServiceSettings(url, similarityMeasure, dimensions, tokenLimit, model, null), + new CohereServiceSettings( + url, + similarityMeasure, + dimensions, + tokenLimit, + model, + null, + CohereServiceSettings.CohereApiVersion.V2 + ), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) ), taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java index 544676cfa7cc7..b033bfa0db6e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java @@ -81,7 +81,8 @@ public void testFromMap() { dims, maxInputTokens, model, - null + null, + CohereServiceSettings.CohereApiVersion.V1 ), CohereEmbeddingType.BYTE ) @@ -125,7 +126,8 @@ public void testFromMap_WithModelId() { dims, maxInputTokens, model, - null + null, + CohereServiceSettings.CohereApiVersion.V2 ), CohereEmbeddingType.INT8 ) @@ -155,7 +157,9 @@ public void testFromMap_PrefersModelId_OverModel() { CohereServiceSettings.MODEL_ID, model, CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, - CohereEmbeddingType.BYTE.toString() + CohereEmbeddingType.BYTE.toString(), + CohereServiceSettings.API_VERSION, + CohereServiceSettings.CohereApiVersion.V1.toString() ) ), ConfigurationParseContext.PERSISTENT @@ -171,7 +175,8 @@ public void testFromMap_PrefersModelId_OverModel() { dims, maxInputTokens, model, - null + null, + CohereServiceSettings.CohereApiVersion.V1 ), CohereEmbeddingType.BYTE ) @@ -188,7 +193,7 @@ public void testFromMap_EmptyEmbeddingType_ThrowsError() { var thrownException = expectThrows( ValidationException.class, () -> CohereEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "")), + new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "", CohereServiceSettings.MODEL_ID, "model")), ConfigurationParseContext.REQUEST ) ); @@ -208,7 +213,7 @@ public void testFromMap_InvalidEmbeddingType_ThrowsError_ForRequest() { var thrownException = expectThrows( ValidationException.class, () -> CohereEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "abc")), + new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "abc", CohereServiceSettings.MODEL_ID, "model")), ConfigurationParseContext.REQUEST ) ); @@ -265,7 +270,12 @@ public void testFromMap_ConvertsElementTypeByte_ToCohereEmbeddingTypeByte() { new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, DenseVectorFieldMapper.ElementType.BYTE.toString())), ConfigurationParseContext.PERSISTENT ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.BYTE)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(CohereServiceSettings.CohereApiVersion.V1), + CohereEmbeddingType.BYTE + ) + ) ); } @@ -275,7 +285,12 @@ public void testFromMap_ConvertsElementTypeFloat_ToCohereEmbeddingTypeFloat() { new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, DenseVectorFieldMapper.ElementType.FLOAT.toString())), ConfigurationParseContext.PERSISTENT ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.FLOAT)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(CohereServiceSettings.CohereApiVersion.V1), + CohereEmbeddingType.FLOAT + ) + ) ); } @@ -283,29 +298,58 @@ public void testFromMap_ConvertsInt8_ToCohereEmbeddingTypeInt8() { assertThat( CohereEmbeddingsServiceSettings.fromMap( new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.INT8.toString())), - ConfigurationParseContext.REQUEST + ConfigurationParseContext.PERSISTENT ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.INT8)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(CohereServiceSettings.CohereApiVersion.V1), + CohereEmbeddingType.INT8 + ) + ) ); } public void testFromMap_ConvertsBit_ToCohereEmbeddingTypeBit() { assertThat( CohereEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.BIT.toString())), + new HashMap<>( + Map.of( + CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, + CohereEmbeddingType.BIT.toString(), + CohereServiceSettings.MODEL_ID, + "model" + ) + ), ConfigurationParseContext.REQUEST ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.BIT)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings((String) null, null, null, null, "model", null, CohereServiceSettings.CohereApiVersion.V2), + CohereEmbeddingType.BIT + ) + ) ); } public void testFromMap_PreservesEmbeddingTypeFloat() { assertThat( CohereEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.FLOAT.toString())), + new HashMap<>( + Map.of( + CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, + CohereEmbeddingType.FLOAT.toString(), + CohereServiceSettings.MODEL_ID, + "model" + ) + ), ConfigurationParseContext.REQUEST ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.FLOAT)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings((String) null, null, null, null, "model", null, CohereServiceSettings.CohereApiVersion.V2), + CohereEmbeddingType.FLOAT + ) + ) ); } @@ -315,7 +359,12 @@ public void testFromMap_PersistentReadsInt8() { new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "int8")), ConfigurationParseContext.PERSISTENT ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.INT8)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(CohereServiceSettings.CohereApiVersion.V1), + CohereEmbeddingType.INT8 + ) + ) ); } @@ -331,7 +380,15 @@ public void testFromCohereOrDenseVectorEnumValues() { public void testToXContent_WritesAllValues() throws IOException { var serviceSettings = new CohereEmbeddingsServiceSettings( - new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)), + new CohereServiceSettings( + "url", + SimilarityMeasure.COSINE, + 5, + 10, + "model_id", + new RateLimitSettings(3), + CohereServiceSettings.CohereApiVersion.V2 + ), CohereEmbeddingType.INT8 ); @@ -340,7 +397,7 @@ public void testToXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """ - "rate_limit":{"requests_per_minute":3},"embedding_type":"byte"}""")); + "rate_limit":{"requests_per_minute":3},"api_version":"V2","embedding_type":"byte"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java index a9d73c78af9e5..62ff4d599f6e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; @@ -125,7 +126,7 @@ public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOExcep private CohereRerankModel createModel(String modelId, CohereRerankTaskSettings taskSettings) { return new CohereRerankModel( "inference_id", - new CohereRerankServiceSettings((URI) null, modelId, null), + new CohereRerankServiceSettings((URI) null, modelId, null, CohereServiceSettings.CohereApiVersion.V2), taskSettings, new DefaultSecretSettings(new SecureString("secret".toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java index 722a32f172d82..34cf019b6010a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; @@ -123,7 +124,7 @@ public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOExcep private CohereRerankModel createModel(String modelId, String uri, CohereRerankTaskSettings taskSettings) { return new CohereRerankModel( "inference_id", - new CohereRerankServiceSettings(uri, modelId, null), + new CohereRerankServiceSettings(uri, modelId, null, CohereServiceSettings.CohereApiVersion.V2), taskSettings, new DefaultSecretSettings(new SecureString("secret".toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java index e3401b74017f2..27a9fc38f392c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java @@ -36,7 +36,8 @@ public static CohereRerankServiceSettings createRandom(@Nullable RateLimitSettin return new CohereRerankServiceSettings( randomFrom(new String[] { null, Strings.format("http://%s.com", randomAlphaOfLength(8)) }), randomFrom(new String[] { null, randomAlphaOfLength(10) }), - rateLimitSettings + rateLimitSettings, + CohereServiceSettings.CohereApiVersion.V2 ); } @@ -44,7 +45,7 @@ public void testToXContent_WritesAllValues() throws IOException { var url = "http://www.abc.com"; var model = "model"; - var serviceSettings = new CohereRerankServiceSettings(url, model, null); + var serviceSettings = new CohereRerankServiceSettings(url, model, null, CohereServiceSettings.CohereApiVersion.V2); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); @@ -56,7 +57,8 @@ public void testToXContent_WritesAllValues() throws IOException { "model_id":"model", "rate_limit": { "requests_per_minute": 10000 - } + }, + "api_version": "V2" } """)); } @@ -80,7 +82,19 @@ protected CohereRerankServiceSettings mutateInstance(CohereRerankServiceSettings protected CohereRerankServiceSettings mutateInstanceForVersion(CohereRerankServiceSettings instance, TransportVersion version) { if (version.before(TransportVersions.V_8_15_0)) { // We always default to the same rate limit settings, if a node is on a version before rate limits were introduced - return new CohereRerankServiceSettings(instance.uri(), instance.modelId(), CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS); + return new CohereRerankServiceSettings( + instance.uri(), + instance.modelId(), + CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS, + CohereServiceSettings.CohereApiVersion.V1 + ); + } else if (version.before(TransportVersions.ML_INFERENCE_COHERE_API_VERSION)) { + return new CohereRerankServiceSettings( + instance.uri(), + instance.modelId(), + instance.rateLimitSettings(), + CohereServiceSettings.CohereApiVersion.V1 + ); } return instance; } From 8860a6160258c799ec09e3ed0b47158f9870255f Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 23 Jun 2025 22:45:47 +0100 Subject: [PATCH 07/12] Update docs/changelog/129884.yaml --- docs/changelog/129884.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/129884.yaml diff --git a/docs/changelog/129884.yaml b/docs/changelog/129884.yaml new file mode 100644 index 0000000000000..a3ae373f2dbd0 --- /dev/null +++ b/docs/changelog/129884.yaml @@ -0,0 +1,5 @@ +pr: 129884 +summary: Move to the Cohere V2 API for new inference endpoints +area: Machine Learning +type: enhancement +issues: [] From fe3fed6fdfa9c7a35ac1dbb3349c6630c96077bc Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 23 Jun 2025 21:54:44 +0000 Subject: [PATCH 08/12] [CI] Auto commit changes from spotless --- .../cohere/action/CohereActionCreator.java | 1 - .../services/cohere/CohereServiceTests.java | 18 +++--------------- .../action/CohereEmbeddingsActionTests.java | 1 - .../embeddings/CohereEmbeddingsModelTests.java | 13 ++----------- 4 files changed, 5 insertions(+), 28 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index f195ed8db7a74..777ddc348bda6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.cohere.CohereResponseHandler; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 2f2b43dc5f2e8..52e4f904a4de0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -266,11 +266,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); serviceSettings.put(CohereServiceSettings.MODEL_ID, "foo"); - var config = getRequestConfigMap( - serviceSettings, - getTaskSettingsMapEmpty(), - getSecretSettingsMap("secret") - ); + var config = getRequestConfigMap(serviceSettings, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")); config.put("extra_key", "value"); var failureListener = getModelListenerForException( @@ -323,11 +319,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); serviceSettings.put(CohereServiceSettings.MODEL_ID, "foo"); - var config = getRequestConfigMap( - serviceSettings, - getTaskSettingsMapEmpty(), - secretSettingsMap - ); + var config = getRequestConfigMap(serviceSettings, getTaskSettingsMapEmpty(), secretSettingsMap); var failureListener = getModelListenerForException( ElasticsearchStatusException.class, @@ -353,11 +345,7 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - serviceSettings, - getTaskSettingsMapEmpty(), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(serviceSettings, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")), modelListener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index b9f7c30478624..05d69bae4903e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -54,7 +54,6 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index aa19fcdf791ba..fd380b8fd973d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -146,8 +146,7 @@ public static CohereEmbeddingsModel createModel( @Nullable Integer dimensions, String model, @Nullable CohereEmbeddingType embeddingType - ) - { + ) { return createModel( url, apiKey, @@ -173,15 +172,7 @@ public static CohereEmbeddingsModel createModel( return new CohereEmbeddingsModel( "id", new CohereEmbeddingsServiceSettings( - new CohereServiceSettings( - url, - SimilarityMeasure.DOT_PRODUCT, - dimensions, - tokenLimit, - model, - null, - apiVersion - ), + new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null, apiVersion), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) ), taskSettings, From 7ef5119cbee8dc0dad45eb4090a19655cb9934b5 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 24 Jun 2025 16:36:22 +0100 Subject: [PATCH 09/12] Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java Co-authored-by: Pat Whelan --- .../xpack/inference/services/cohere/CohereServiceSettings.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java index 9dc7f212b4d45..94d7b0a6dda74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -104,7 +104,7 @@ public static CohereServiceSettings fromMap(Map map, Configurati similarity, dims, maxInputTokens, - modelId(oldModelId, modelId), + resolvedModelId, rateLimitSettings, apiVersion ); From 2b2eac58eb73d9531b949a836f76e27cbb998e83 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 24 Jun 2025 15:46:17 +0000 Subject: [PATCH 10/12] [CI] Auto commit changes from spotless --- .../services/cohere/CohereServiceSettings.java | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java index 94d7b0a6dda74..1767653fd1a51 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -99,15 +99,7 @@ public static CohereServiceSettings fromMap(Map map, Configurati throw validationException; } - return new CohereServiceSettings( - uri, - similarity, - dims, - maxInputTokens, - resolvedModelId, - rateLimitSettings, - apiVersion - ); + return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, resolvedModelId, rateLimitSettings, apiVersion); } public static CohereApiVersion apiVersionFromMap( From 90f25c131ca1d84d6aa33f3e9cef29bc4f711c21 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 24 Jun 2025 18:49:35 +0000 Subject: [PATCH 11/12] [CI] Auto commit changes from spotless --- server/src/main/java/org/elasticsearch/TransportVersions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index cb6034adbb861..0f8e54a947d2c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -319,7 +319,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = def(9_106_0_00); public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS = def(9_107_0_00); public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00); - public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_109_0_00); + public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_109_0_00); /* * STOP! READ THIS FIRST! No, really, From 954cd0d83f53066e46086d526714e3917ba6d996 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 24 Jun 2025 20:16:02 +0000 Subject: [PATCH 12/12] [CI] Auto commit changes from spotless --- server/src/main/java/org/elasticsearch/TransportVersions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index a9716ab5bbfc0..1bc23a24fc5ab 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -321,7 +321,7 @@ static TransportVersion def(int id) { public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS = def(9_107_0_00); public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00); public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_109_00_0); - public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_110_0_00); + public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_110_0_00); /* * STOP! READ THIS FIRST! No, really,