Skip to content
This repository was archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
Add last message query generator (#210)
Browse files Browse the repository at this point in the history
* Add last message query generator

* Fix lint

* Fix lint

* Improve docs

* Fix docstring

* Take only user messages

* Fix lint

* Fix behavior

* Fix sample messages

* [config] Changed Anyscale config to use LastMessageQueryGenerator

Since Anyscale doesn't support function calling for now

* [config] Update anyscale.yaml - The new QueryGenerator doesn't need an LLM

Otherwise it will error out

---------

Co-authored-by: ilai <[email protected]>
  • Loading branch information
izellevy and igiloh-pinecone authored Dec 10, 2023
1 parent b129a05 commit 405d73d
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 12 deletions.
12 changes: 1 addition & 11 deletions config/anyscale.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,4 @@ chat_engine:
# The query builder is responsible for generating textual queries given user message history.
# --------------------------------------------------------------------
query_builder:
type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator]
params:
prompt: *query_builder_prompt # The query builder's system prompt for calling the LLM
function_description: # A function description passed to the LLM's `function_calling` API
Query search engine for relevant information

llm: # The LLM that the query builder will use to generate queries.
#Use OpenAI for function call for now
type: OpenAILLM
params:
model_name: gpt-3.5-turbo
type: LastMessageQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
2 changes: 1 addition & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ chat_engine:
# The query builder is responsible for generating textual queries given user message history.
# --------------------------------------------------------------------
query_builder:
type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator]
type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
params:
prompt: *query_builder_prompt # The query builder's system prompt for calling the LLM
function_description: # A function description passed to the LLM's `function_calling` API
Expand Down
1 change: 1 addition & 0 deletions src/canopy/chat_engine/query_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .base import QueryGenerator
from .function_calling import FunctionCallingQueryGenerator
from .last_message import LastMessageQueryGenerator
36 changes: 36 additions & 0 deletions src/canopy/chat_engine/query_generator/last_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import List

from canopy.chat_engine.query_generator import QueryGenerator
from canopy.models.data_models import Messages, Query, Role


class LastMessageQueryGenerator(QueryGenerator):
    """
    Returns the last user message as a query without running any LLMs.

    This is the most basic query generation strategy - the text of the last
    message in the history is passed through verbatim as the search query.
    Please use other query generators for more accurate results.
    """

    def generate(self,
                 messages: Messages,
                 max_prompt_tokens: int) -> List[Query]:
        """
        Generate a single query from the last message in the chat history.

        Args:
            messages: The chat history. Must be non-empty, and its last
                message must have been sent by the user.
            max_prompt_tokens: Ignored - no tokens are consumed since no
                LLM is called to generate the queries.

        Returns:
            A single-element list containing the last message's content
            wrapped as a Query.

        Raises:
            ValueError: If the history is empty, or the last message was
                not sent by the user.
        """

        if len(messages) == 0:
            raise ValueError("Passed chat history does not contain any messages. "
                             "Please include at least one message in the history.")

        last_message = messages[-1]

        if last_message.role != Role.USER:
            # Report the message's role rather than its Python type: the check
            # is on `role`, and `type()` may be the same class for every
            # message, which made the original error message uninformative.
            raise ValueError(f"Expected the last message to be sent by the "
                             f"user, got role {last_message.role}.")

        return [Query(text=last_message.content)]

    async def agenerate(self,
                        messages: Messages,
                        max_prompt_tokens: int) -> List[Query]:
        """Async variant; simply delegates to the synchronous `generate()`."""
        return self.generate(messages, max_prompt_tokens)
41 changes: 41 additions & 0 deletions tests/unit/query_generators/test_last_message_query_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from canopy.chat_engine.query_generator import LastMessageQueryGenerator
from canopy.models.data_models import UserMessage, Query, AssistantMessage


@pytest.fixture
def sample_messages():
    """A minimal chat history ending in a single user message."""
    history = [UserMessage(content="What is photosynthesis?")]
    return history


@pytest.fixture
def query_generator():
    """A fresh LastMessageQueryGenerator instance for each test."""
    generator = LastMessageQueryGenerator()
    return generator


def test_generate(query_generator, sample_messages):
    """The generator returns the last message's text as the sole query."""
    result = query_generator.generate(sample_messages, 0)
    assert result == [Query(text=sample_messages[-1].content)]


@pytest.mark.asyncio
async def test_agenerate(query_generator, sample_messages):
    """The async variant behaves exactly like the synchronous one."""
    result = await query_generator.agenerate(sample_messages, 0)
    assert result == [Query(text=sample_messages[-1].content)]


def test_generate_fails_with_empty_history(query_generator):
    """An empty chat history must be rejected with a ValueError."""
    empty_history = []
    with pytest.raises(ValueError):
        query_generator.generate(empty_history, 0)


def test_generate_fails_with_no_user_message(query_generator):
    """A history whose last message is not from the user raises ValueError."""
    history = [AssistantMessage(content="Hi! How can I help you?")]
    with pytest.raises(ValueError):
        query_generator.generate(history, 0)

0 comments on commit 405d73d

Please sign in to comment.