-
Notifications
You must be signed in to change notification settings - Fork 65
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rename prompt.py to llm.py. Simplify prompt_formatter and add test.
- Loading branch information
Showing
9 changed files
with
140 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .dataset import Dataset | ||
from .prompter.prompter import LLM | ||
from .prompter.llm import LLM |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import pytest | ||
from pydantic import BaseModel | ||
|
||
from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter, _validate_messages | ||
|
||
|
||
def test_validate_messages_valid(): | ||
"""Tests that valid message formats pass validation.""" | ||
valid_messages = [ | ||
{"role": "system", "content": "You are a helpful assistant"}, | ||
{"role": "user", "content": "Hello"}, | ||
{"role": "assistant", "content": "Hi there!"} | ||
] | ||
# Should not raise any exceptions | ||
_validate_messages(valid_messages) | ||
|
||
|
||
def test_validate_messages_invalid_format(): | ||
"""Tests that invalid message formats raise appropriate errors.""" | ||
# Test non-dict message | ||
with pytest.raises(ValueError, match="must be a dictionary"): | ||
_validate_messages([["role", "content"]]) | ||
|
||
# Test missing required keys | ||
with pytest.raises(ValueError, match="must contain 'role' and 'content' keys"): | ||
_validate_messages([{"role": "user"}]) | ||
|
||
# Test invalid role | ||
with pytest.raises(ValueError, match="must be one of: assistant, system, user"): | ||
_validate_messages([{"role": "invalid", "content": "test"}]) | ||
|
||
|
||
class TestResponse(BaseModel): | ||
text: str | ||
|
||
|
||
def test_prompt_formatter_create_generic_request(): | ||
"""Tests that PromptFormatter correctly creates GenericRequest objects.""" | ||
# Test with string prompt | ||
formatter = PromptFormatter( | ||
model_name="test-model", | ||
prompt_func=lambda x: "Hello", | ||
response_format=TestResponse | ||
) | ||
request = formatter.create_generic_request({"input": "test"}, 0) | ||
|
||
assert request.model == "test-model" | ||
assert request.messages == [{"role": "user", "content": "Hello"}] | ||
assert request.original_row == {"input": "test"} | ||
assert request.original_row_idx == 0 | ||
assert request.response_format is not None | ||
|
||
# Test with message list prompt | ||
formatter = PromptFormatter( | ||
model_name="test-model", | ||
prompt_func=lambda x: [ | ||
{"role": "system", "content": "You are helpful"}, | ||
{"role": "user", "content": "Hi"} | ||
] | ||
) | ||
request = formatter.create_generic_request({"input": "test"}, 1) | ||
|
||
assert len(request.messages) == 2 | ||
assert request.messages[0]["role"] == "system" | ||
assert request.messages[1]["role"] == "user" | ||
assert request.original_row_idx == 1 | ||
|
||
|
||
def test_prompt_formatter_invalid_prompt_func(): | ||
"""Tests that PromptFormatter raises errors for invalid prompt functions.""" | ||
# Test prompt function with too many parameters | ||
with pytest.raises(ValueError, match="must have 0 or 1 arguments"): | ||
PromptFormatter( | ||
model_name="test", | ||
prompt_func=lambda x, y: "test" | ||
).create_generic_request({}, 0) | ||
|
||
# Test invalid prompt function return type | ||
with pytest.raises(ValueError, match="must be a list of dictionaries"): | ||
PromptFormatter( | ||
model_name="test", | ||
prompt_func=lambda x: {"invalid": "format"} | ||
).create_generic_request({}, 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters