From 44b890ff1e952de4ab5182e8173e1a9ad270ef5d Mon Sep 17 00:00:00 2001
From: Gleb Sizov
Date: Wed, 22 Jan 2025 22:42:50 +0100
Subject: [PATCH 1/2] Adds model, temperature and maxTokens as configurable
 parameters to OpenAI component.

---
 .../java/ai/vespa/llm/clients/OpenAI.java     | 27 ++++++++++++++---
 .../configdefinitions/llm-client.def          | 14 +++++++++
 .../java/ai/vespa/llm/clients/OpenAITest.java | 20 ++++++++++++-
 .../ai/vespa/llm/InferenceParameters.java     | 18 ++++++++++--
 .../vespa/llm/client/openai/OpenAiClient.java |  7 +++--
 .../ai/vespa/llm/InferenceParametersTest.java | 29 +++++++++++++++++++
 6 files changed, 105 insertions(+), 10 deletions(-)
 create mode 100644 vespajlib/src/test/java/ai/vespa/llm/InferenceParametersTest.java

diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
index 28af52babdf2..bf2be8ff61f5 100644
--- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
@@ -9,40 +9,59 @@ import com.yahoo.api.annotations.Beta;
 import com.yahoo.component.annotation.Inject;
 
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.Consumer;
 
 /**
  * A configurable OpenAI client.
  *
- * @author lesters
+ * @author lesters glebashnik
  */
 @Beta
 public class OpenAI extends ConfigurableLanguageModel {
-
     private final OpenAiClient client;
+    private final Map<String, String> configOptions;
 
     @Inject
     public OpenAI(LlmClientConfig config, Secrets secretStore) {
         super(config, secretStore);
         client = new OpenAiClient();
+
+        configOptions = new HashMap<>();
+
+        if (!config.model().isBlank()) {
+            configOptions.put(OpenAiClient.OPTION_MODEL, config.model());
+        }
+
+        if (config.temperature() >= 0) {
+            configOptions.put(OpenAiClient.OPTION_TEMPERATURE, String.valueOf(config.temperature()));
+        }
+
+        if (config.maxTokens() >= 0) {
+            configOptions.put(OpenAiClient.OPTION_MAX_TOKENS, String.valueOf(config.maxTokens()));
+        }
+
     }
 
     @Override
     public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
         setApiKey(parameters);
         setEndpoint(parameters);
-        return client.complete(prompt, parameters);
+        var combinedParameters = parameters.withDefaultOptions(configOptions::get);
+        return client.complete(prompt, combinedParameters);
     }
 
     @Override
     public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
                                                                     InferenceParameters parameters,
                                                                     Consumer<Completion> consumer) {
         setApiKey(parameters);
         setEndpoint(parameters);
-        return client.completeAsync(prompt, parameters, consumer);
+        var combinedParameters = parameters.withDefaultOptions(configOptions::get);
+        return client.completeAsync(prompt, combinedParameters, consumer);
     }
 }
diff --git a/model-integration/src/main/resources/configdefinitions/llm-client.def b/model-integration/src/main/resources/configdefinitions/llm-client.def
index 0866459166ab..dda0e2faf3bf 100755
--- a/model-integration/src/main/resources/configdefinitions/llm-client.def
+++ b/model-integration/src/main/resources/configdefinitions/llm-client.def
@@ -6,3 +6,17 @@ apiKeySecretName string default=""
 
 # Endpoint for LLM client - if not set reverts to default for client
 endpoint string default=""
+
+# LLM APIs often provide several models, e.g. gpt-4o and gpt-4o-mini from OpenAI.
+# Specify a model to use by default when it is not overridden by inference parameters.
+# Empty means use the default for the API.
+model string default=""
+
+# Sampling temperature, between 0 and 1.
+# Higher values make the output more random, while lower values make it more focused and deterministic.
+# -1 means use the default for the API.
+temperature double default=-1
+
+# The maximum number of tokens that can be generated in the completion.
+# -1 means use the default for the API.
+maxTokens int default=-1
\ No newline at end of file
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java
index 63f99876de9e..2b7b53c8d7d0 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java
@@ -3,11 +3,12 @@
 
 import ai.vespa.llm.InferenceParameters;
 import ai.vespa.llm.completion.StringPrompt;
-import com.yahoo.container.jdisc.SecretStoreProvider;
 import com.yahoo.container.jdisc.SecretsProvider;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 import java.util.Map;
 
 public class OpenAITest {
@@ -33,4 +34,21 @@ public void testOpenAIGeneration() {
         future.join();
     }
 
+    @Test
+    @Disabled
+    public void testComplete() {
+        var config = new LlmClientConfig.Builder().maxTokens(10).build();
+        var openai = new OpenAI(config, new SecretsProvider().get());
+        var options = Map.of(
+                "model", "gpt-4o-mini"
+        );
+        var prompt = StringPrompt.from("Explain why ducks are better than cats, in 20 words.");
+        var completions = openai.complete(prompt, new InferenceParameters(apiKey, options::get));
+        assertFalse(completions.isEmpty());
+
+        // A token is smaller than a word.
+        // Splitting on spaces is a crude tokenizer, but it is good enough for this test.
+        assertTrue(completions.get(0).text().split(" ").length <= 10);
+    }
+
 }
diff --git a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
index a942e5090e58..910317c78891 100755
--- a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
@@ -24,11 +24,16 @@ public class InferenceParameters {
     private final Function<String, String> options;
 
     public InferenceParameters(Function<String, String> options) {
-        this(null, options);
+        this(null, null, options);
     }
 
     public InferenceParameters(String apiKey, Function<String, String> options) {
+        this(apiKey, null, options);
+    }
+
+    public InferenceParameters(String apiKey, String endpoint, Function<String, String> options) {
         this.apiKey = apiKey;
+        this.endpoint = endpoint;
         this.options = Objects.requireNonNull(options);
     }
 
@@ -71,6 +76,15 @@ public Optional<Integer> getInt(String option) {
     public void ifPresent(String option, Consumer<String> func) {
         get(option).ifPresent(func);
     }
-
+
+    // Creates a new InferenceParameters object whose options fall back to the given defaults:
+    // a value from defaultOptions is used whenever the current options return null for a key.
+    public InferenceParameters withDefaultOptions(Function<String, String> defaultOptions) {
+        Function<String, String> combinedOptions = key -> {
+            var value = options.apply(key);
+            return value != null ? value : defaultOptions.apply(key);
+        };
+        return new InferenceParameters(apiKey, endpoint, combinedOptions);
+    }
 }
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
index cc22ee696842..c842229b7192 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
@@ -47,9 +47,10 @@ public class OpenAiClient implements LanguageModel {
 
     private static final int MAX_RETRIES = 3;
     private static final long RETRY_DELAY_MS = 250;
 
-    private static final String OPTION_MODEL = "model";
-    private static final String OPTION_TEMPERATURE = "temperature";
-    private static final String OPTION_MAX_TOKENS = "maxTokens";
+    // These are public so that they can be used to set the corresponding InferenceParameters outside this class.
+    public static final String OPTION_MODEL = "model";
+    public static final String OPTION_TEMPERATURE = "temperature";
+    public static final String OPTION_MAX_TOKENS = "maxTokens";
 
     private final HttpClient httpClient;
diff --git a/vespajlib/src/test/java/ai/vespa/llm/InferenceParametersTest.java b/vespajlib/src/test/java/ai/vespa/llm/InferenceParametersTest.java
new file mode 100644
index 000000000000..90fddb35026e
--- /dev/null
+++ b/vespajlib/src/test/java/ai/vespa/llm/InferenceParametersTest.java
@@ -0,0 +1,29 @@
+package ai.vespa.llm;
+
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+import java.util.Optional;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * @author glebashnik
+ */
+public class InferenceParametersTest {
+    @Test
+    public void testWithDefaultOptions() {
+        var params = new InferenceParameters("testapikey", "testhost", Map.of("a", "a1", "b", "b1")::get);
+        var combinedParams = params.withDefaultOptions(Map.of("b", "b2", "c", "c2", "d", "d2")::get);
+
+        assertEquals(Optional.of("testapikey"), combinedParams.getApiKey());
+        assertEquals(Optional.of("testhost"), combinedParams.getEndpoint());
+        assertEquals(Optional.of("a1"), combinedParams.get("a"));
+        assertEquals(Optional.of("b1"), combinedParams.get("b"));
+        assertEquals(Optional.of("c2"), combinedParams.get("c"));
+        assertEquals(Optional.of("d2"), combinedParams.get("d"));
+        assertTrue(combinedParams.get("e").isEmpty());
+    }
+}

From ea54f7e962699c0a7d09515637a64f20c1bcb767 Mon Sep 17 00:00:00 2001
From: Gleb Sizov
Date: Thu, 23 Jan 2025 13:01:13 +0100
Subject: [PATCH 2/2] Updated abi spec for LLM client config

---
 model-integration/abi-spec.json |  8 +++++++-
 vespajlib/abi-spec.json         | 10 ++++++++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json
index c151eac86850..e0222fdb265b 100644
--- a/model-integration/abi-spec.json
+++ b/model-integration/abi-spec.json
@@ -32,6 +32,9 @@
       "public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
       "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
       "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
+      "public ai.vespa.llm.clients.LlmClientConfig$Builder model(java.lang.String)",
+      "public ai.vespa.llm.clients.LlmClientConfig$Builder temperature(double)",
+      "public ai.vespa.llm.clients.LlmClientConfig$Builder maxTokens(int)",
       "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
       "public final java.lang.String getDefMd5()",
       "public final java.lang.String getDefName()",
@@ -70,7 +73,10 @@
       "public static java.lang.String getDefNamespace()",
       "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
       "public java.lang.String apiKeySecretName()",
-      "public java.lang.String endpoint()"
+      "public java.lang.String endpoint()",
+      "public java.lang.String model()",
+      "public double temperature()",
+      "public int maxTokens()"
     ],
     "fields" : [
       "public static final java.lang.String CONFIG_DEF_MD5",
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index de020bf85471..efd1cee3c5ee 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -4139,6 +4139,7 @@
     "methods" : [
       "public void <init>(java.util.function.Function)",
       "public void <init>(java.lang.String, java.util.function.Function)",
+      "public void <init>(java.lang.String, java.lang.String, java.util.function.Function)",
       "public void setApiKey(java.lang.String)",
       "public java.util.Optional getApiKey()",
       "public void setEndpoint(java.lang.String)",
@@ -4146,7 +4147,8 @@
       "public java.util.Optional get(java.lang.String)",
       "public java.util.Optional getDouble(java.lang.String)",
       "public java.util.Optional getInt(java.lang.String)",
-      "public void ifPresent(java.lang.String, java.util.function.Consumer)"
+      "public void ifPresent(java.lang.String, java.util.function.Consumer)",
+      "public ai.vespa.llm.InferenceParameters withDefaultOptions(java.util.function.Function)"
     ],
     "fields" : [ ]
   },
@@ -4189,7 +4191,11 @@
       "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
       "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
     ],
-    "fields" : [ ]
+    "fields" : [
+      "public static final java.lang.String OPTION_MODEL",
+      "public static final java.lang.String OPTION_TEMPERATURE",
+      "public static final java.lang.String OPTION_MAX_TOKENS"
+    ]
   },
   "ai.vespa.llm.completion.Completion$FinishReason" : {
     "superClass" : "java.lang.Enum",