This repository was archived by the owner on Nov 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add last message query generator (#210)
* Add last message query generator * Fix lint * Fix lint * Improve docs * Fix docstring * Take only user messages * Fix lint * Fix behavior * Fix sample messages * [config] Changed Anyscale config to use LastMessageQueryGenerator Since Anyscale don't support function calling for now * [config] Update anyscale.yaml - The new QueryGenerator doesn't need an LLM Otherwise it will error out --------- Co-authored-by: ilai <[email protected]>
- Loading branch information
1 parent
b129a05
commit 405d73d
Showing
5 changed files
with
80 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .base import QueryGenerator | ||
from .function_calling import FunctionCallingQueryGenerator | ||
from .last_message import LastMessageQueryGenerator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import List | ||
|
||
from canopy.chat_engine.query_generator import QueryGenerator | ||
from canopy.models.data_models import Messages, Query, Role | ||
|
||
|
||
class LastMessageQueryGenerator(QueryGenerator): | ||
""" | ||
Returns the last message as a query without running any LLMs. This can be | ||
considered as the most basic query generation. Please use other query generators | ||
for more accurate results. | ||
""" | ||
|
||
def generate(self, | ||
messages: Messages, | ||
max_prompt_tokens: int) -> List[Query]: | ||
""" | ||
max_prompt_token is dismissed since we do not consume any token for | ||
generating the queries. | ||
""" | ||
|
||
if len(messages) == 0: | ||
raise ValueError("Passed chat history does not contain any messages. " | ||
"Please include at least one message in the history.") | ||
|
||
last_message = messages[-1] | ||
|
||
if last_message.role != Role.USER: | ||
raise ValueError(f"Expected a UserMessage, got {type(last_message)}.") | ||
|
||
return [Query(text=last_message.content)] | ||
|
||
async def agenerate(self, | ||
messages: Messages, | ||
max_prompt_tokens: int) -> List[Query]: | ||
return self.generate(messages, max_prompt_tokens) |
41 changes: 41 additions & 0 deletions
41
tests/unit/query_generators/test_last_message_query_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import pytest | ||
|
||
from canopy.chat_engine.query_generator import LastMessageQueryGenerator | ||
from canopy.models.data_models import UserMessage, Query, AssistantMessage | ||
|
||
|
||
@pytest.fixture | ||
def sample_messages(): | ||
return [ | ||
UserMessage(content="What is photosynthesis?") | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def query_generator(): | ||
return LastMessageQueryGenerator() | ||
|
||
|
||
def test_generate(query_generator, sample_messages): | ||
expected = [Query(text=sample_messages[-1].content)] | ||
actual = query_generator.generate(sample_messages, 0) | ||
assert actual == expected | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_agenerate(query_generator, sample_messages): | ||
expected = [Query(text=sample_messages[-1].content)] | ||
actual = await query_generator.agenerate(sample_messages, 0) | ||
assert actual == expected | ||
|
||
|
||
def test_generate_fails_with_empty_history(query_generator): | ||
with pytest.raises(ValueError): | ||
query_generator.generate([], 0) | ||
|
||
|
||
def test_generate_fails_with_no_user_message(query_generator): | ||
with pytest.raises(ValueError): | ||
query_generator.generate([ | ||
AssistantMessage(content="Hi! How can I help you?") | ||
], 0) |