From 405d73df837469286c386dea21d3205f8fbe1bb4 Mon Sep 17 00:00:00 2001 From: izellevy Date: Sun, 10 Dec 2023 21:12:46 +0200 Subject: [PATCH] Add last message query generator (#210) * Add last message query generator * Fix lint * Fix lint * Improve docs * Fix docstring * Take only user messages * Fix lint * Fix behavior * Fix sample messages * [config] Changed Anyscale config to use LastMessageQueryGenerator Since Anyscale don't support function calling for now * [config] Update anyscale.yaml - The new QueryGenerator doesn't need an LLM Otherwise it will error out --------- Co-authored-by: ilai --- config/anyscale.yaml | 12 +----- config/config.yaml | 2 +- .../chat_engine/query_generator/__init__.py | 1 + .../query_generator/last_message.py | 36 ++++++++++++++++ .../test_last_message_query_generator.py | 41 +++++++++++++++++++ 5 files changed, 80 insertions(+), 12 deletions(-) create mode 100644 src/canopy/chat_engine/query_generator/last_message.py create mode 100644 tests/unit/query_generators/test_last_message_query_generator.py diff --git a/config/anyscale.yaml b/config/anyscale.yaml index 816a0ffc..7f5f28ef 100644 --- a/config/anyscale.yaml +++ b/config/anyscale.yaml @@ -24,14 +24,4 @@ chat_engine: # The query builder is responsible for generating textual queries given user message history. # -------------------------------------------------------------------- query_builder: - type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator] - params: - prompt: *query_builder_prompt # The query builder's system prompt for calling the LLM - function_description: # A function description passed to the LLM's `function_calling` API - Query search engine for relevant information - - llm: # The LLM that the query builder will use to generate queries. 
- #Use OpenAI for function call for now - type: OpenAILLM - params: - model_name: gpt-3.5-turbo + type: LastMessageQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator] \ No newline at end of file diff --git a/config/config.yaml b/config/config.yaml index 0d3576cd..8fc78572 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -59,7 +59,7 @@ chat_engine: # The query builder is responsible for generating textual queries given user message history. # -------------------------------------------------------------------- query_builder: - type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator] + type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator] params: prompt: *query_builder_prompt # The query builder's system prompt for calling the LLM function_description: # A function description passed to the LLM's `function_calling` API diff --git a/src/canopy/chat_engine/query_generator/__init__.py b/src/canopy/chat_engine/query_generator/__init__.py index 13ffd0d0..9005d02b 100644 --- a/src/canopy/chat_engine/query_generator/__init__.py +++ b/src/canopy/chat_engine/query_generator/__init__.py @@ -1,2 +1,3 @@ from .base import QueryGenerator from .function_calling import FunctionCallingQueryGenerator +from .last_message import LastMessageQueryGenerator diff --git a/src/canopy/chat_engine/query_generator/last_message.py b/src/canopy/chat_engine/query_generator/last_message.py new file mode 100644 index 00000000..74e15661 --- /dev/null +++ b/src/canopy/chat_engine/query_generator/last_message.py @@ -0,0 +1,36 @@ +from typing import List + +from canopy.chat_engine.query_generator import QueryGenerator +from canopy.models.data_models import Messages, Query, Role + + +class LastMessageQueryGenerator(QueryGenerator): + """ + Returns the last message as a query without running any LLMs. This can be + considered as the most basic query generation. 
Please use other query generators
+    for more accurate results.
+    """
+
+    def generate(self,
+                 messages: Messages,
+                 max_prompt_tokens: int) -> List[Query]:
+        """
+        max_prompt_tokens is ignored since we do not consume any tokens when
+        generating the queries.
+        """
+
+        if len(messages) == 0:
+            raise ValueError("Passed chat history does not contain any messages. "
+                             "Please include at least one message in the history.")
+
+        last_message = messages[-1]
+
+        if last_message.role != Role.USER:
+            raise ValueError(f"Expected a user message, got role "
+                             f"{last_message.role}.")
+
+        return [Query(text=last_message.content)]
+
+    async def agenerate(self,
+                        messages: Messages,
+                        max_prompt_tokens: int) -> List[Query]:
+        return self.generate(messages, max_prompt_tokens)
diff --git a/tests/unit/query_generators/test_last_message_query_generator.py b/tests/unit/query_generators/test_last_message_query_generator.py
new file mode 100644
index 00000000..308c1b4d
--- /dev/null
+++ b/tests/unit/query_generators/test_last_message_query_generator.py
@@ -0,0 +1,41 @@
+import pytest
+
+from canopy.chat_engine.query_generator import LastMessageQueryGenerator
+from canopy.models.data_models import UserMessage, Query, AssistantMessage
+
+
+@pytest.fixture
+def sample_messages():
+    return [
+        UserMessage(content="What is photosynthesis?")
+    ]
+
+
+@pytest.fixture
+def query_generator():
+    return LastMessageQueryGenerator()
+
+
+def test_generate(query_generator, sample_messages):
+    expected = [Query(text=sample_messages[-1].content)]
+    actual = query_generator.generate(sample_messages, 0)
+    assert actual == expected
+
+
+@pytest.mark.asyncio
+async def test_agenerate(query_generator, sample_messages):
+    expected = [Query(text=sample_messages[-1].content)]
+    actual = await query_generator.agenerate(sample_messages, 0)
+    assert actual == expected
+
+
+def test_generate_fails_with_empty_history(query_generator):
+    with pytest.raises(ValueError):
+        query_generator.generate([], 0)
+
+
+def 
test_generate_fails_with_no_user_message(query_generator): + with pytest.raises(ValueError): + query_generator.generate([ + AssistantMessage(content="Hi! How can I help you?") + ], 0)