Glebashnik/feed field generator #32842

Merged
merged 28 commits into from
Jan 10, 2025
Commits
28 commits
b954ab2
wip to support generate indexing statement
glebashnik Oct 28, 2024
9461c66
Initial support for generate in indexing language
glebashnik Oct 29, 2024
98073fe
draft generate expression
glebashnik Nov 5, 2024
7a293db
Refactoring. LocalLLM and OpenAI implement Generator.
glebashnik Nov 5, 2024
82ff3a9
Fix build
lesters Nov 14, 2024
baccc99
Remove dependency on document in linguistics
lesters Nov 14, 2024
ac9357b
Fix tests
lesters Nov 14, 2024
8d91388
Set input and output types for generate expression
lesters Nov 14, 2024
8e56319
Merge branch 'master' into glebashnik/feed-field-generator
lesters Nov 14, 2024
420b943
Update GenerateExpression after merge with master branch
lesters Nov 14, 2024
cc0218a
Wire in generators to indexing processor
lesters Nov 20, 2024
261bfd6
Resolve conflicts with master
lesters Nov 20, 2024
69baaa4
Use Prompt instead of String in Generators
lesters Nov 21, 2024
0f0d6a9
Merge branch 'master' into glebashnik/feed-field-generator
lesters Nov 29, 2024
f4869ec
ConfigurableLanguageModel implements Generator
lesters Nov 29, 2024
40b4a19
Renamed Generator interface to TextGenerator, added support for array…
glebashnik Jan 3, 2025
7790a0a
Improved input/output type inference for generate expression.
glebashnik Jan 3, 2025
e75e566
Improved input/output type error messages
glebashnik Jan 3, 2025
46e3add
Descriptive (future-proof) names for text generator component classes.
glebashnik Jan 3, 2025
fb7aed2
Added prompt template to LanguageModelTextGenerator
glebashnik Jan 3, 2025
a5d466e
Added max length to LanguageModelTextGenerator
glebashnik Jan 6, 2025
9d87bbe
Added tests with a tiny LLM, fixed issue with inference parameters.
glebashnik Jan 6, 2025
74b89fb
Merging with master
glebashnik Jan 6, 2025
63b12fa
Fixed forgotten changes in jj and abi-spec after Generator to TextGen…
glebashnik Jan 6, 2025
69d0653
Renamed Generator to TextGenerator in IndexingParser.ccc
glebashnik Jan 6, 2025
85d7bc5
Added LanguageModelTextGenerator to PlatformBundles
glebashnik Jan 7, 2025
0ebd24f
Removed question in TODO
glebashnik Jan 10, 2025
c705878
Merged with master, resolved conflicts in indexing processor.
glebashnik Jan 10, 2025
@@ -82,7 +82,7 @@ public DataType setOutputType(DataType outputType, VerificationContext context)
throw new VerificationException(this, "Generate expression requires either a string or array<string> output type, but got "
+ outputType.getName());

super.setOutputType(null, outputType, null, context); // todo: Why not set actualOutput to outputType?
super.setOutputType(null, outputType, null, context); // TODO: Why not set actualOutput to outputType?
return outputType; // return input type the same as output type: string or array<string>
Contributor Author
@glebashnik Jan 6, 2025

Input-output type inference is straightforward: the input type is the same as the output type.
String-to-array conversion is handled outside, using a split expression (see the testGeneratorWithStringInputArrayOutput test).

}

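A side note on the type contract described in the comment above: the following is a minimal illustrative sketch, not code from this PR, assuming the com.yahoo.document DataType/ArrayDataType API. It shows the symmetric rule — the generate expression accepts string or array<string>, and its output type mirrors its input type.

import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;

// Illustrative sketch only (simplified; the real code throws VerificationException).
// The generate expression accepts string or array<string> and returns the same type.
// Turning a single string result into an array is done outside with a split expression.
class GenerateTypeContractSketch {
    static DataType verify(DataType type) {
        boolean isString = type == DataType.STRING;
        boolean isStringArray = type instanceof ArrayDataType array
                && array.getNestedType() == DataType.STRING;
        if (!isString && !isStringArray)
            throw new IllegalArgumentException(
                    "Generate expression requires a string or array<string> type, but got " + type.getName());
        return type; // same type in and out
    }
}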
@@ -8,7 +8,6 @@
import ai.vespa.llm.completion.Prompt;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.process.TextGenerator;
import de.kherud.llama.LlamaModel;
import de.kherud.llama.ModelParameters;

@@ -3,11 +3,15 @@
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.language.process.TextGenerator;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.logging.Logger;

/**
@@ -18,17 +22,50 @@
public class LanguageModelTextGenerator extends AbstractComponent implements TextGenerator {
Contributor Author

The naming is a bit verbose but quite precise.
I also considered LMTextGenerator.

private static final Logger logger = Logger.getLogger(LanguageModelTextGenerator.class.getName());
private final LanguageModel languageModel;

// Template usually contains {input} placeholder for the dynamic part of the prompt, replaced with the actual input.
// Templates without {input} are possible, which will ignore the input, making the prompt static.
// TODO: Consider not allowing templates without {input} to avoid costly errors when users forget to include {input}.
private static final String DEFAULT_PROMPT_TEMPLATE = "{input}";
private final String promptTemplate;

@Inject
public LanguageModelTextGenerator(TextGeneratorConfig config, ComponentRegistry<LanguageModel> languageModels) {
public LanguageModelTextGenerator(LanguageModelTextGeneratorConfig config, ComponentRegistry<LanguageModel> languageModels) {
this.languageModel = LanguageModelUtils.findLanguageModel(config.providerId(), languageModels, logger);
this.promptTemplate = loadPromptTemplate(config);
}

private String loadPromptTemplate(LanguageModelTextGeneratorConfig config) {
if (config.promptTemplate() != null && !config.promptTemplate().isEmpty()) {
return config.promptTemplate();
} else if (config.promptTemplateFile().isPresent()) {
Path path = config.promptTemplateFile().get();

try {
String promptTemplate = new String(Files.readAllBytes(path));

if (!promptTemplate.isEmpty()) { // TODO: Consider throwing an exception if the template is empty.
return promptTemplate;
}
} catch (IOException e) {
throw new IllegalArgumentException("Could not read prompt template file: " + path, e);
}
}

return DEFAULT_PROMPT_TEMPLATE;
}

@Override
public String generate(Prompt prompt, Context context) {
var finalPrompt = buildPrompt(prompt);
var options = new InferenceParameters(s -> "");
var completions = languageModel.complete(prompt, options);
var completions = languageModel.complete(finalPrompt, options);
var firstCompletion = completions.get(0);
return firstCompletion.text();
}

private Prompt buildPrompt(Prompt inputPrompt) {
String finalPrompt = promptTemplate.replace("{input}", inputPrompt.asString());
return StringPrompt.from(finalPrompt);
}
}
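For orientation, here is a hedged usage sketch of the generator above. It mirrors the tests further down and is not code from this PR; the summarization template used in the comments is a hypothetical example.

import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.language.process.TextGenerator;

// Hypothetical caller; "generator" is a LanguageModelTextGenerator built from a config
// and a ComponentRegistry of LanguageModels, as in the tests below.
class TextGeneratorUsageSketch {
    static String summarize(TextGenerator generator, String documentText) {
        var context = new TextGenerator.Context("schema.indexing");
        // With promptTemplate "Summarize: {input}", buildPrompt sends
        // "Summarize: <documentText>" to the language model and the first completion's text is returned.
        return generator.generate(StringPrompt.from(documentText), context);
    }
}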
@@ -1,11 +1,11 @@
# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package=ai.vespa.llm.generation

# The external LLM provider - the id of a LanguageModel component
providerId string default=""
# Id of a LanguageModel component specified in services.xml, e.g. OpenAI, LocalLLM
providerId string

# The default prompt to use if not overridden in query
prompt string default=""
# Prompt template, e.g. "Extract named entities from this text, output one entity per line: {input}"
promptTemplate string default=""

# The default prompt template file to use if not overridden in query. Above prompt has precedence if it is set.
promptTemplate path optional
# Prompt template in a file, above promptTemplate has precedence if it is set.
promptTemplateFile path optional
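To illustrate the precedence rule stated above (the inline promptTemplate wins over promptTemplateFile when both are set), here is a hedged sketch of building the config. The builder methods follow those used in the tests below; the provider id, template text, and file path are hypothetical.

import java.util.Optional;

import ai.vespa.llm.generation.LanguageModelTextGeneratorConfig;
import com.yahoo.config.FileReference;

// Sketch only: with both template fields set, loadPromptTemplate(...) returns the
// inline template and never reads the file.
class GeneratorConfigSketch {
    static LanguageModelTextGeneratorConfig build() {
        return new LanguageModelTextGeneratorConfig.Builder()
                .providerId("openai") // id of a LanguageModel component in services.xml
                .promptTemplate("Extract named entities, one per line: {input}") // wins if non-empty
                .promptTemplateFile(Optional.of(new FileReference("prompts/entities.txt"))) // hypothetical path
                .build();
    }
}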
@@ -10,35 +10,174 @@

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

import com.yahoo.config.FileReference;
import com.yahoo.language.process.TextGenerator;
import org.junit.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;


public class LanguageModelTextGeneratorTest {

@Test
public void testGenerate() {
LanguageModel languageModel1 = new RepeaterMockLanguageModel(1);
LanguageModel languageModel2 = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("mock1", languageModel1, "mock2", languageModel2);
LanguageModel languageModel = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("languageModel", languageModel);

var config = new LanguageModelTextGeneratorConfig.Builder().providerId("languageModel").build();
var generator = createGenerator(config, languageModels);
var context = new TextGenerator.Context("schema.indexing");
var result = generator.generate(StringPrompt.from("hello"), context);
assertEquals("hello hello", result);
}

@Test
public void testGenerateWithTwoLanguageModel() {
LanguageModel languageModel1 = new RepeaterMockLanguageModel(2);
LanguageModel languageModel2 = new RepeaterMockLanguageModel(3);
var languageModels = Map.of("languageModel1", languageModel1, "languageModel2", languageModel2);

var config1 = new TextGeneratorConfig.Builder().providerId("mock1").build();
var config1 = new LanguageModelTextGeneratorConfig.Builder().providerId("languageModel1").build();
var generator1 = createGenerator(config1, languageModels);
var context = new TextGenerator.Context("schema.indexing");
var result1 = generator1.generate(StringPrompt.from("hello"), context);
assertEquals("hello", result1);
assertEquals("hello hello", result1);

var config2 = new TextGeneratorConfig.Builder().providerId("mock2").build();
var generator2 = createGenerator(config2, Map.of("mock1", languageModel1, "mock2", languageModel2));
var config2 = new LanguageModelTextGeneratorConfig.Builder().providerId("languageModel2").build();
var generator2 = createGenerator(config2, languageModels);
var result2 = generator2.generate(StringPrompt.from("hello"), context);
assertEquals("hello hello hello", result2);
}

@Test
public void testGenerateWithPromptTemplate() {
LanguageModel languageModel = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("languageModel", languageModel);

var config = new LanguageModelTextGeneratorConfig.Builder()
.providerId("languageModel")
.promptTemplate("hello {input}")
.build();
var generator = createGenerator(config, languageModels);
var context = new TextGenerator.Context("schema.indexing");

var result1 = generator.generate(StringPrompt.from("world"), context);
assertEquals("hello world hello world", result1);

var result2 = generator.generate(StringPrompt.from("there"), context);
assertEquals("hello there hello there", result2);
}

@Test
public void testGenerateWithEmptyPromptTemplate() {
LanguageModel languageModel = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("languageModel", languageModel);

var config = new LanguageModelTextGeneratorConfig.Builder()
.providerId("languageModel")
.promptTemplate("")
.build();
var generator = createGenerator(config, languageModels);
var context = new TextGenerator.Context("schema.indexing");

var result1 = generator.generate(StringPrompt.from("world"), context);
assertEquals("world world", result1);
}

@Test
public void testGenerateWithStaticPromptTemplate() {
LanguageModel languageModel = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("languageModel", languageModel);

var config = new LanguageModelTextGeneratorConfig.Builder()
.providerId("languageModel")
.promptTemplate("hello")
.build();
var generator = createGenerator(config, languageModels);
var context = new TextGenerator.Context("schema.indexing");

var result1 = generator.generate(StringPrompt.from("world"), context);
assertEquals("hello hello", result1);

var result2 = generator.generate(StringPrompt.from("there"), context);
assertEquals("hello hello", result2);
}

@Test
public void testGenerateWithPromptTemplateFile() {
LanguageModel languageModel = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("languageModel", languageModel);

var config = new LanguageModelTextGeneratorConfig.Builder()
.providerId("languageModel")
.promptTemplateFile(Optional.of(new FileReference("src/test/prompts/prompt_with_input.txt")))
.build();

var generator = createGenerator(config, languageModels);
var context = new TextGenerator.Context("schema.indexing");

var result1 = generator.generate(StringPrompt.from("world"), context);
assertEquals("hello world hello world", result1);
}

@Test
public void testGenerateWithEmptyTemplateFile() {
LanguageModel languageModel = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("languageModel", languageModel);

var config = new LanguageModelTextGeneratorConfig.Builder()
.providerId("languageModel")
.promptTemplateFile(Optional.of(new FileReference("src/test/prompts/empty_prompt.txt")))
.build();

var generator = createGenerator(config, languageModels);
var context = new TextGenerator.Context("schema.indexing");

var result1 = generator.generate(StringPrompt.from("world"), context);
assertEquals("world world", result1);
}

@Test
public void testGenerateWithMissingTemplateFile() {
LanguageModel languageModel = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("languageModel", languageModel);

var config = new LanguageModelTextGeneratorConfig.Builder()
.providerId("languageModel")
.promptTemplateFile(Optional.of(new FileReference("src/test/prompts/missing_prompt.txt")))
.build();

try {
createGenerator(config, languageModels);
}
catch (IllegalArgumentException e) {
assertEquals("Could not read prompt template file: src/test/prompts/missing_prompt.txt", e.getMessage());
}
}

@Test
public void testGenerateWithPromptTemplateOverridesPromptTemplateFile() {
LanguageModel languageModel = new RepeaterMockLanguageModel(2);
var languageModels = Map.of("languageModel", languageModel);

var config = new LanguageModelTextGeneratorConfig.Builder()
.providerId("languageModel")
.promptTemplate("bye {input}")
.promptTemplateFile(Optional.of(new FileReference("src/test/prompts/prompt_with_input.txt")))
.build();

var generator = createGenerator(config, languageModels);
var context = new TextGenerator.Context("schema.indexing");

var result1 = generator.generate(StringPrompt.from("world"), context);
assertEquals("bye world bye world", result1);
}

private static LanguageModelTextGenerator createGenerator(TextGeneratorConfig config, Map<String, LanguageModel> languageModels) {
private static LanguageModelTextGenerator createGenerator(LanguageModelTextGeneratorConfig config, Map<String, LanguageModel> languageModels) {
ComponentRegistry<LanguageModel> languageModelsRegistry = new ComponentRegistry<>();
languageModels.forEach((key, value) -> languageModelsRegistry.register(ComponentId.fromString(key), value));
languageModelsRegistry.freeze();
Empty file: model-integration/src/test/prompts/empty_prompt.txt
1 change: 1 addition & 0 deletions model-integration/src/test/prompts/prompt_with_input.txt
@@ -0,0 +1 @@
hello {input}