Rename prompt.py to llm.py. Simplify prompt_formatter and add test.
madiator committed Dec 12, 2024
1 parent 23c74cb commit 5a4a4d7
Showing 9 changed files with 140 additions and 49 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -74,8 +74,8 @@ class Poems(BaseModel):
poems_list: List[Poem] = Field(description="A list of poems.")


# We define a Prompter that generates poems which gets applied to the topics dataset.
poet = curator.Prompter(
# We define an `LLM` object that generates poems which gets applied to the topics dataset.
poet = curator.LLM(
# `prompt_func` takes a row of the dataset as input.
# `row` is a dictionary with a single key 'topic' in this case.
prompt_func=lambda row: f"Write two poems about {row['topic']}.",
@@ -97,7 +97,7 @@ print(poem.to_pandas())
# 2 Beauty of Bespoke Labs's Curator library In whispers of design and crafted grace,\nBesp...
# 3 Beauty of Bespoke Labs's Curator library In the hushed breath of parchment and ink,\nBe...
```
Note that `topics` can be created with `curator.Prompter` as well,
Note that `topics` can be created with `curator.LLM` as well,
and we can scale this up to create tens of thousands of diverse poems.
You can see a more detailed example in the [examples/poem.py](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples/poem.py) file,
and other examples in the [examples](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples) directory.
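
As a rough illustration of the "`topics` can be created with `curator.LLM` as well" note above, here is a minimal sketch of that pattern. The `Topics` model, the prompt wording, and the `parse_func` are illustrative assumptions and are not part of this commit:

```python
from typing import List

from bespokelabs import curator
from pydantic import BaseModel, Field


class Topics(BaseModel):
    topics_list: List[str] = Field(description="A list of topics.")


# Hypothetical topic generator: a zero-argument prompt_func, with parse_func
# turning the structured response into rows with a 'topic' column.
topic_generator = curator.LLM(
    prompt_func=lambda: "Generate 10 diverse topics suitable for writing poems about.",
    model_name="gpt-4o-mini",
    response_format=Topics,
    parse_func=lambda _, topics: [{"topic": t} for t in topics.topics_list],
)

topics = topic_generator()
# `poet` is the LLM object defined in the README snippet above; applying it
# to `topics` produces the poems dataset.
poems = poet(topics)
```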
(diff for another changed file; filename not shown)
@@ -39,8 +39,8 @@ class Poems(BaseModel):
poems_list: List[Poem] = Field(description="A list of poems.")
# We define a Prompter that generates poems which gets applied to the topics dataset.
poet = curator.Prompter(
# We define an LLM object that generates poems which gets applied to the topics dataset.
poet = curator.LLM(
# prompt_func takes a row of the dataset as input.
# row is a dictionary with a single key 'topic' in this case.
prompt_func=lambda row: f"Write two poems about {row['topic']}.",
2 changes: 1 addition & 1 deletion examples/poem.py
@@ -35,7 +35,7 @@ class Poems(BaseModel):
poems_list: List[str] = Field(description="A list of poems.")


# We define a prompter that generates poems which gets applied to the topics dataset.
# We define an `LLM` object that generates poems which gets applied to the topics dataset.
poet = curator.LLM(
# The prompt_func takes a row of the dataset as input.
# The row is a dictionary with a single key 'topic' in this case.
2 changes: 1 addition & 1 deletion src/bespokelabs/curator/__init__.py
@@ -1,2 +1,2 @@
from .dataset import Dataset
from .prompter.prompter import LLM
from .prompter.llm import LLM
File renamed without changes.
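
Since the module under `prompter/` is renamed, internal imports change while the package-level entry point stays the same. A minimal before/after sketch, based on the import updates shown in this commit:

```python
# Internal import path before this commit:
#   from bespokelabs.curator.prompter.prompter import LLM
# Internal import path after this commit:
#   from bespokelabs.curator.prompter.llm import LLM

# User-facing code keeps using the package-level export, which is unaffected:
from bespokelabs import curator

prompter = curator.LLM(
    prompt_func=lambda: "Say '1'. Do not explain.",
    model_name="gpt-4o-mini",
)
```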
76 changes: 42 additions & 34 deletions src/bespokelabs/curator/prompter/prompt_formatter.py
@@ -1,3 +1,4 @@
import dataclasses
import inspect
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union

@@ -6,44 +7,48 @@
from bespokelabs.curator.request_processor.generic_request import GenericRequest

T = TypeVar("T")
_DictOrBaseModel = Union[Dict[str, Any], BaseModel]


def _validate_messages(messages: list[dict]) -> None:
    """Validates that messages conform to the expected chat format.

    Args:
        messages: A list of message dictionaries to validate.

    Raises:
        ValueError: If messages don't meet the required format:
            - Must be a list of dictionaries
            - Each message must have 'role' and 'content' keys
            - Role must be one of: 'system', 'user', 'assistant'
    """
    valid_roles = {'system', 'user', 'assistant'}

    for msg in messages:
        if not isinstance(msg, dict):
            raise ValueError(
                "In the return value (a list) of the prompt_func, each "
                "message must be a dictionary")

        if 'role' not in msg or 'content' not in msg:
            raise ValueError(
                "In the return value (a list) of the prompt_func, each "
                "message must contain 'role' and 'content' keys")

        if msg['role'] not in valid_roles:
            raise ValueError(f"In the return value (a list) of the prompt_func, "
                             f"each message role must be one of: {', '.join(sorted(valid_roles))}")


@dataclasses.dataclass
class PromptFormatter:
model_name: str
prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]]
parse_func: Optional[
Callable[
[
Union[Dict[str, Any], BaseModel],
Union[Dict[str, Any], BaseModel],
],
T,
]
] = None
prompt_func: Callable[[_DictOrBaseModel], Dict[str, str]]
parse_func: Optional[Callable[[_DictOrBaseModel, _DictOrBaseModel], T]] = None
response_format: Optional[Type[BaseModel]] = None

def __init__(
self,
model_name: str,
prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]],
parse_func: Optional[
Callable[
[
Union[Dict[str, Any], BaseModel],
Union[Dict[str, Any], BaseModel],
],
T,
]
] = None,
response_format: Optional[Type[BaseModel]] = None,
):
self.model_name = model_name
self.prompt_func = prompt_func
self.parse_func = parse_func
self.response_format = response_format

def create_generic_request(self, row: Dict[str, Any] | BaseModel, idx: int) -> GenericRequest:
"""Format the request object based off Prompter attributes."""
def create_generic_request(self, row: _DictOrBaseModel, idx: int) -> GenericRequest:
"""Format the request object based off of `LLM` attributes."""
sig = inspect.signature(self.prompt_func)
if len(sig.parameters) == 0:
prompts = self.prompt_func()
@@ -54,9 +59,12 @@ def create_generic_request(self, row: Dict[str, Any] | BaseModel, idx: int) -> GenericRequest:

if isinstance(prompts, str):
messages = [{"role": "user", "content": prompts}]
else:
# TODO(Ryan): Add validation here
elif isinstance(prompts, list):
_validate_messages(prompts)
messages = prompts
else:
raise ValueError(
"The return value of the prompt_func must be a list of dictionaries.")

# Convert BaseModel to dict for serialization
if isinstance(row, BaseModel):
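
In user-facing terms, the simplified formatter means a `prompt_func` may return either a plain string or a chat-style list of message dicts, and anything else now raises a `ValueError`. A short sketch of the list form (the model name, prompt text, and the `question` column are placeholder assumptions):

```python
from bespokelabs import curator

# Hypothetical example: prompt_func returns a message list, which the new
# _validate_messages check requires to be dicts with 'role' and 'content'
# keys and roles limited to 'system', 'user', or 'assistant'.
qa = curator.LLM(
    prompt_func=lambda row: [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": f"Answer briefly: {row['question']}"},
    ],
    model_name="gpt-4o-mini",
)
```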
83 changes: 83 additions & 0 deletions src/bespokelabs/curator/prompter/prompt_formatter_test.py
@@ -0,0 +1,83 @@
import pytest
from pydantic import BaseModel

from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter, _validate_messages


def test_validate_messages_valid():
    """Tests that valid message formats pass validation."""
    valid_messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"}
    ]
    # Should not raise any exceptions
    _validate_messages(valid_messages)


def test_validate_messages_invalid_format():
    """Tests that invalid message formats raise appropriate errors."""
    # Test non-dict message
    with pytest.raises(ValueError, match="must be a dictionary"):
        _validate_messages([["role", "content"]])

    # Test missing required keys
    with pytest.raises(ValueError, match="must contain 'role' and 'content' keys"):
        _validate_messages([{"role": "user"}])

    # Test invalid role
    with pytest.raises(ValueError, match="must be one of: assistant, system, user"):
        _validate_messages([{"role": "invalid", "content": "test"}])


class TestResponse(BaseModel):
    text: str


def test_prompt_formatter_create_generic_request():
    """Tests that PromptFormatter correctly creates GenericRequest objects."""
    # Test with string prompt
    formatter = PromptFormatter(
        model_name="test-model",
        prompt_func=lambda x: "Hello",
        response_format=TestResponse
    )
    request = formatter.create_generic_request({"input": "test"}, 0)

    assert request.model == "test-model"
    assert request.messages == [{"role": "user", "content": "Hello"}]
    assert request.original_row == {"input": "test"}
    assert request.original_row_idx == 0
    assert request.response_format is not None

    # Test with message list prompt
    formatter = PromptFormatter(
        model_name="test-model",
        prompt_func=lambda x: [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Hi"}
        ]
    )
    request = formatter.create_generic_request({"input": "test"}, 1)

    assert len(request.messages) == 2
    assert request.messages[0]["role"] == "system"
    assert request.messages[1]["role"] == "user"
    assert request.original_row_idx == 1


def test_prompt_formatter_invalid_prompt_func():
    """Tests that PromptFormatter raises errors for invalid prompt functions."""
    # Test prompt function with too many parameters
    with pytest.raises(ValueError, match="must have 0 or 1 arguments"):
        PromptFormatter(
            model_name="test",
            prompt_func=lambda x, y: "test"
        ).create_generic_request({}, 0)

    # Test invalid prompt function return type
    with pytest.raises(ValueError, match="must be a list of dictionaries"):
        PromptFormatter(
            model_name="test",
            prompt_func=lambda x: {"invalid": "format"}
        ).create_generic_request({}, 0)
(diff for another changed file; filename not shown)
@@ -13,7 +13,7 @@

from bespokelabs.curator.dataset import Dataset
from bespokelabs.curator.request_processor.base_request_processor import BaseRequestProcessor
from bespokelabs.curator.prompter.prompter import PromptFormatter
from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.generic_request import GenericRequest
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
from bespokelabs.curator.request_processor.generic_response import GenericResponse
14 changes: 7 additions & 7 deletions tests/test_caching.py
@@ -1,6 +1,6 @@
from datasets import Dataset

from bespokelabs.curator import LLM
from bespokelabs import curator


def test_same_value_caching(tmp_path):
@@ -13,7 +13,7 @@ def test_same_value_caching(tmp_path):
def prompt_func():
return f"Say '1'. Do not explain."

prompter = LLM(
prompter = curator.LLM(
prompt_func=prompt_func,
model_name="gpt-4o-mini",
)
@@ -36,7 +36,7 @@ def test_different_values_caching(tmp_path):
def prompt_func():
return f"Say '{x}'. Do not explain."

prompter = LLM(
prompter = curator.LLM(
prompt_func=prompt_func,
model_name="gpt-4o-mini",
)
@@ -52,7 +52,7 @@ def prompt_func():
def test_same_dataset_caching(tmp_path):
"""Test that using the same dataset multiple times uses cache."""
dataset = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}])
prompter = LLM(
prompter = curator.LLM(
prompt_func=lambda x: x["instruction"],
model_name="gpt-4o-mini",
)
@@ -72,7 +72,7 @@ def test_different_dataset_caching(tmp_path):
"""Test that using different datasets creates different cache entries."""
dataset1 = Dataset.from_list([{"instruction": "Say '1'. Do not explain."}])
dataset2 = Dataset.from_list([{"instruction": "Say '2'. Do not explain."}])
prompter = LLM(
prompter = curator.LLM(
prompt_func=lambda x: x["instruction"],
model_name="gpt-4o-mini",
)
@@ -97,7 +97,7 @@ def value_generator():
def prompt_func():
return f"Say '{value_generator()}'. Do not explain."

prompter = LLM(
prompter = curator.LLM(
prompt_func=prompt_func,
model_name="gpt-4o-mini",
)
@@ -123,7 +123,7 @@ def test_function_hash_dir_change():
import tempfile
from pathlib import Path

from bespokelabs.curator.prompter.prompter import _get_function_hash
from bespokelabs.curator.prompter.llm import _get_function_hash

# Set up logging to write to a file in the current directory
debug_log = Path("function_debug.log")
