Merge pull request #33158 from vespa-engine/glebashnik/openai-config-model-temp-maxtokens

Adds model, temperature and maxTokens as configurable parameters to the LLM client
bratseth authored Jan 24, 2025
2 parents 49344f5 + ea54f7e commit da1e0a0
Showing 8 changed files with 120 additions and 13 deletions.
8 changes: 7 additions & 1 deletion model-integration/abi-spec.json
@@ -32,6 +32,9 @@
"public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
"public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
"public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
"public ai.vespa.llm.clients.LlmClientConfig$Builder model(java.lang.String)",
"public ai.vespa.llm.clients.LlmClientConfig$Builder temperature(double)",
"public ai.vespa.llm.clients.LlmClientConfig$Builder maxTokens(int)",
"public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
"public final java.lang.String getDefMd5()",
"public final java.lang.String getDefName()",
@@ -70,7 +73,10 @@
"public static java.lang.String getDefNamespace()",
"public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
"public java.lang.String apiKeySecretName()",
"public java.lang.String endpoint()"
"public java.lang.String endpoint()",
"public java.lang.String model()",
"public double temperature()",
"public int maxTokens()"
],
"fields" : [
"public static final java.lang.String CONFIG_DEF_MD5",
27 changes: 23 additions & 4 deletions model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
@@ -9,40 +9,59 @@
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

/**
* A configurable OpenAI client.
*
* @author lesters
* @author lesters glebashnik
*/
@Beta
public class OpenAI extends ConfigurableLanguageModel {

private final OpenAiClient client;
private final Map<String, String> configOptions;

@Inject
public OpenAI(LlmClientConfig config, Secrets secretStore) {
super(config, secretStore);
client = new OpenAiClient();

configOptions = new HashMap<>();

if (!config.model().isBlank()) {
configOptions.put(OpenAiClient.OPTION_MODEL, config.model());
}

if (config.temperature() >= 0) {
configOptions.put(OpenAiClient.OPTION_TEMPERATURE, String.valueOf(config.temperature()));
}

if (config.maxTokens() >= 0) {
configOptions.put(OpenAiClient.OPTION_MAX_TOKENS, String.valueOf(config.maxTokens()));
}

}

@Override
public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
var combinedParameters = parameters.withDefaultOptions(configOptions::get);
setApiKey(parameters);
setEndpoint(parameters);
return client.complete(prompt, parameters);
return client.complete(prompt, combinedParameters);
}

@Override
public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
InferenceParameters parameters,
Consumer<Completion> consumer) {
var combinedParameters = parameters.withDefaultOptions(configOptions::get);
setApiKey(parameters);
setEndpoint(parameters);
return client.completeAsync(prompt, parameters, consumer);
return client.completeAsync(prompt, combinedParameters, consumer);
}

}
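Not part of the commit, but to illustrate the combination logic above: a minimal sketch of how config-derived defaults interact with request-time parameters (the apiKey variable is assumed; imports and the SecretsProvider usage mirror the test further down):

    var config = new LlmClientConfig.Builder().model("gpt-4o").maxTokens(50).build();
    var openai = new OpenAI(config, new SecretsProvider().get());

    // Request options win; config-derived options fill in what the request leaves unset.
    var options = Map.of("model", "gpt-4o-mini");
    var completions = openai.complete(StringPrompt.from("Hello"),
            new InferenceParameters(apiKey, options::get));
    // Effective options: model = gpt-4o-mini (from the request), maxTokens = 50 (from config).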
@@ -6,3 +6,17 @@ apiKeySecretName string default=""

# Endpoint for LLM client - if not set reverts to default for client
endpoint string default=""

# LLM APIs often provide several models, e.g. gpt-4o and gpt-4o-mini from OpenAI.
# Specify a model to use by default when not overwritten by inference parameters.
# Empty means use default for the API.
model string default=""

# Sampling temperature, between 0 and 1.
# Higher values make the output more random, while lower values are more focused and deterministic.
# -1 means use default for the API.
temperature double default=-1

# The maximum number of tokens that can be generated in the completion.
# -1 means use default for the API.
maxTokens int default=-1
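As a usage sketch (not part of this diff), the same defaults can also be set programmatically through the generated config builder; the builder methods match the abi-spec additions above:

    var config = new LlmClientConfig.Builder()
            .model("gpt-4o-mini")  // default model when no "model" inference parameter is given
            .temperature(0.2)      // 0 to 1; -1 means use the default for the API
            .maxTokens(100)        // -1 means use the default for the API
            .build();

In a deployed application these values would normally come from the service configuration rather than a hand-built config instance.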
@@ -3,11 +3,12 @@

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.container.jdisc.SecretStoreProvider;
import com.yahoo.container.jdisc.SecretsProvider;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.Map;

public class OpenAITest {
@@ -33,4 +34,21 @@ public void testOpenAIGeneration() {
future.join();
}

@Test
@Disabled
public void testComplete() {
var config = new LlmClientConfig.Builder().maxTokens(10).build();
var openai = new OpenAI(config, new SecretsProvider().get());
var options = Map.of(
"model", "gpt-4o-mini"
);
var prompt = StringPrompt.from("Explain why ducks are better than cats in 20 words.");
var completions = openai.complete(prompt, new InferenceParameters(apiKey, options::get));
assertFalse(completions.isEmpty());

// A token is smaller than a word.
// Splitting by space is a poor tokenizer but it is good enough for this test.
assertTrue(completions.get(0).text().split(" ").length <= 10);
}

}
10 changes: 8 additions & 2 deletions vespajlib/abi-spec.json
@@ -4140,14 +4140,16 @@
"methods" : [
"public void <init>(java.util.function.Function)",
"public void <init>(java.lang.String, java.util.function.Function)",
"public void <init>(java.lang.String, java.lang.String, java.util.function.Function)",
"public void setApiKey(java.lang.String)",
"public java.util.Optional getApiKey()",
"public void setEndpoint(java.lang.String)",
"public java.util.Optional getEndpoint()",
"public java.util.Optional get(java.lang.String)",
"public java.util.Optional getDouble(java.lang.String)",
"public java.util.Optional getInt(java.lang.String)",
"public void ifPresent(java.lang.String, java.util.function.Consumer)"
"public void ifPresent(java.lang.String, java.util.function.Consumer)",
"public ai.vespa.llm.InferenceParameters withDefaultOptions(java.util.function.Function)"
],
"fields" : [ ]
},
@@ -4190,7 +4192,11 @@
"public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
"public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
],
"fields" : [ ]
"fields" : [
"public static final java.lang.String OPTION_MODEL",
"public static final java.lang.String OPTION_TEMPERATURE",
"public static final java.lang.String OPTION_MAX_TOKENS"
]
},
"ai.vespa.llm.completion.Completion$FinishReason" : {
"superClass" : "java.lang.Enum",
18 changes: 16 additions & 2 deletions vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
@@ -24,11 +24,16 @@ public class InferenceParameters {
private final Function<String, String> options;

public InferenceParameters(Function<String, String> options) {
this(null, options);
this(null, null, options);
}

public InferenceParameters(String apiKey, Function<String, String> options) {
this(apiKey, null, options);
}

public InferenceParameters(String apiKey, String endpoint, Function<String, String> options) {
this.apiKey = apiKey;
this.endpoint = endpoint;
this.options = Objects.requireNonNull(options);
}

@@ -71,6 +76,15 @@ public Optional<Integer> getInt(String option) {
public void ifPresent(String option, Consumer<String> func) {
get(option).ifPresent(func);
}

// Creates a new InferenceParameters object with default values for the options:
// a value from the given default options is used whenever the corresponding value
// in the current options is null.
public InferenceParameters withDefaultOptions(Function<String, String> defaultOptions) {
Function<String, String> prependedOptions = key -> {
var afterValue = options.apply(key);
return afterValue != null ? afterValue : defaultOptions.apply(key);
};
return new InferenceParameters(apiKey, endpoint, prependedOptions);
}
}

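For illustration, a minimal sketch of the fallback semantics (option names chosen for the example):

    var request = new InferenceParameters(Map.of("model", "gpt-4o-mini")::get);
    var combined = request.withDefaultOptions(Map.of("model", "gpt-4o", "maxTokens", "50")::get);

    combined.get("model");     // Optional.of("gpt-4o-mini") - the request value wins
    combined.get("maxTokens"); // Optional.of("50") - filled in from the defaults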
@@ -47,9 +47,10 @@ public class OpenAiClient implements LanguageModel {
private static final int MAX_RETRIES = 3;
private static final long RETRY_DELAY_MS = 250;

private static final String OPTION_MODEL = "model";
private static final String OPTION_TEMPERATURE = "temperature";
private static final String OPTION_MAX_TOKENS = "maxTokens";
// These are public so that they can be used to set corresponding InferenceParameters outside of this class.
public static final String OPTION_MODEL = "model";
public static final String OPTION_TEMPERATURE = "temperature";
public static final String OPTION_MAX_TOKENS = "maxTokens";

private final HttpClient httpClient;

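With the option keys public, callers can set the corresponding InferenceParameters without hard-coding strings; a small sketch (the apiKey variable is assumed):

    var options = Map.of(
            OpenAiClient.OPTION_MODEL, "gpt-4o-mini",
            OpenAiClient.OPTION_TEMPERATURE, "0.3",
            OpenAiClient.OPTION_MAX_TOKENS, "100");
    var parameters = new InferenceParameters(apiKey, options::get);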
29 changes: 29 additions & 0 deletions vespajlib/src/test/java/ai/vespa/llm/InferenceParametersTest.java
@@ -0,0 +1,29 @@
package ai.vespa.llm;


import org.junit.jupiter.api.Test;

import java.util.Map;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
* @author glebashnik
*/
public class InferenceParametersTest {
@Test
public void testWithDefaultOptions() {
var params = new InferenceParameters("testapikey", "testhost", Map.of("a", "a1", "b", "b1")::get);
var combinedParams = params.withDefaultOptions(Map.of("b", "b2", "c", "c2", "d", "d2")::get);

assertEquals(combinedParams.getApiKey(), Optional.of("testapikey"));
assertEquals(combinedParams.getEndpoint(), Optional.of("testhost"));
assertEquals(combinedParams.get("a"), Optional.of("a1"));
assertEquals(combinedParams.get("b"), Optional.of("b1"));
assertEquals(combinedParams.get("c"), Optional.of("c2"));
assertEquals(combinedParams.get("d"), Optional.of("d2"));
assertTrue(combinedParams.get("e").isEmpty());
}
}