diff --git a/examples/class_based_prompter.py b/examples/class_based_prompter.py
new file mode 100644
index 00000000..abbe456a
--- /dev/null
+++ b/examples/class_based_prompter.py
@@ -0,0 +1,96 @@
+"""Example of the class-based Prompter approach.
+
+MathPrompter subclasses Prompter, overriding the prompt_func/parse_func
+hooks defined on BasePrompter, so instances remain directly callable.
+"""
+
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel
+
+from bespokelabs.curator import Prompter
+
+
+class ResponseFormat(BaseModel):
+    """Example response format."""
+
+    answer: str
+    confidence: float
+
+
+class MathPrompter(Prompter):
+    """Example custom prompter for math problems."""
+
+    def prompt_func(self, row: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
+        """Generate prompts for math problems.
+
+        Args:
+            row: Optional dictionary containing a 'question' key.
+                If None, a default prompt is generated.
+
+        Returns:
+            Dict containing user_prompt and system_prompt.
+        """
+        if row is None:
+            return {
+                "user_prompt": "What is 2 + 2?",
+                "system_prompt": "You are a math tutor. Provide clear, step-by-step solutions.",
+            }
+
+        return {
+            "user_prompt": row["question"],
+            "system_prompt": "You are a math tutor. Provide clear, step-by-step solutions.",
+        }
+
+    def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> ResponseFormat:
+        """Parse the LLM response into a structured format.
+
+        Args:
+            row: Input row that generated the response.
+            response: Raw response from the LLM.
+
+        Returns:
+            ResponseFormat containing the answer and a confidence score.
+        """
+        # Heuristic confidence: step-by-step answers score higher.
+        return ResponseFormat(
+            answer=response["message"],
+            confidence=0.95 if "step" in response["message"].lower() else 0.7,
+        )
+
+
+def main():
+    """Example usage of a class-based prompter."""
+    # Create an instance of the custom prompter
+    math_prompter = MathPrompter(
+        model_name="gpt-4o-mini",
+        response_format=ResponseFormat,
+    )
+
+    # Single completion without input
+    result = math_prompter()
+    print(f"Single completion result: {result}")
+
+    # Process multiple questions
+    questions = [
+        {"question": "What is 5 * 3?"},
+        {"question": "Solve x + 2 = 7"},
+    ]
+    results = math_prompter(questions)
+    print("\nBatch processing results:")
+    for q, r in zip(questions, results):
+        print(f"Q: {q['question']}")
+        print(f"A: {r.answer} (confidence: {r.confidence})")
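+
+    # Batch mode (an illustrative sketch; the batch_size value is arbitrary).
+    batch_prompter = MathPrompter(
+        model_name="gpt-4o-mini",
+        response_format=ResponseFormat,
+        batch=True,
+        batch_size=2,
+    )
+    print(f"\nConfigured batch prompter (batch_size={batch_prompter.batch_size})")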
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/bespokelabs/curator/prompter/base_prompter.py b/src/bespokelabs/curator/prompter/base_prompter.py
new file mode 100644
index 00000000..f965e7d0
--- /dev/null
+++ b/src/bespokelabs/curator/prompter/base_prompter.py
@@ -0,0 +1,156 @@
+"""Base class for Prompter implementations.
+
+This module provides the base class for implementing custom prompters.
+The BasePrompter class defines the hooks that every prompter implementation
+must provide.
+
+Example:
+    Creating a custom prompter by subclassing Prompter (the runnable
+    implementation of this interface) and overriding its hooks:
+    ```python
+    class CustomPrompter(Prompter):
+        def prompt_func(self, row: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
+            # Generate prompts for your specific use case
+            if row is None:
+                return {
+                    "user_prompt": "Default prompt",
+                    "system_prompt": "System instructions",
+                }
+            return {
+                "user_prompt": f"Process input: {row['data']}",
+                "system_prompt": "System instructions",
+            }
+
+        def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> Any:
+            # Optional: override to customize response parsing
+            return response
+
+    # Usage
+    prompter = CustomPrompter(
+        model_name="gpt-4",
+        response_format=MyResponseFormat,
+    )
+    result = prompter(dataset)  # Process dataset
+    single_result = prompter()  # Single completion
+    ```
+
+For simpler use cases where you don't need a full class implementation,
+you can use the function-based approach with the Prompter class directly.
+See the Prompter class documentation for details.
+"""
+
+import inspect
+from abc import ABC
+from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union
+
+from pydantic import BaseModel
+
+T = TypeVar("T")
+
+
+class BasePrompter(ABC):
+    """Base class for prompter implementations.
+
+    Subclasses override prompt_func and may optionally override parse_func;
+    both hooks can also be supplied as callables at construction time.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        response_format: Optional[Type[BaseModel]] = None,
+        prompt_func: Optional[Callable[..., Dict[str, str]]] = None,
+        parse_func: Optional[Callable[..., Any]] = None,
+        batch: bool = False,
+        batch_size: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+    ):
+        """Initialize a BasePrompter.
+
+        Args:
+            model_name (str): The name of the LLM to use.
+            response_format (Optional[Type[BaseModel]]): A Pydantic model specifying the
+                response format from the LLM.
+            prompt_func (Optional[Callable]): A callable that generates prompts. If
+                provided, it takes precedence over the prompt_func method.
+            parse_func (Optional[Callable]): A callable that parses responses. If
+                provided, it takes precedence over the parse_func method.
+            batch (bool): Whether to use batch processing.
+            batch_size (Optional[int]): The size of the batch to use, only used if batch is True.
+            temperature (Optional[float]): The temperature to use for the LLM, only used if batch is False.
+            top_p (Optional[float]): The top_p to use for the LLM, only used if batch is False.
+            presence_penalty (Optional[float]): The presence_penalty to use for the LLM, only used if batch is False.
+            frequency_penalty (Optional[float]): The frequency_penalty to use for the LLM, only used if batch is False.
+        """
+        self.model_name = model_name
+        self.response_format = response_format
+        self.batch_mode = batch
+        self.batch_size = batch_size
+        self.temperature = temperature
+        self.top_p = top_p
+        self.presence_penalty = presence_penalty
+        self.frequency_penalty = frequency_penalty
+
+        # Callables supplied at construction time shadow the class methods,
+        # so function-based and class-based usage share one code path.
+        if prompt_func is not None:
+            self.prompt_func = prompt_func
+        if parse_func is not None:
+            self.parse_func = parse_func
+
+        # Fail fast on incompatible signatures. Bound methods drop `self`,
+        # so plain functions and overridden methods are checked identically.
+        prompt_sig = inspect.signature(self.prompt_func)
+        if len(prompt_sig.parameters) > 1:
+            raise ValueError(
+                f"prompt_func must take one argument or less, got {len(prompt_sig.parameters)}"
+            )
+        parse_sig = inspect.signature(self.parse_func)
+        if len(parse_sig.parameters) != 2:
+            raise ValueError(
+                f"parse_func must take exactly 2 arguments, got {len(parse_sig.parameters)}"
+            )
+
+    def prompt_func(
+        self,
+        row: Optional[Union[Dict[str, Any], BaseModel]] = None,
+    ) -> Dict[str, str]:
+        """Override this method to define how prompts are generated.
+
+        Args:
+            row (Optional[Union[Dict[str, Any], BaseModel]]): The input row to generate a prompt for.
+                If None, generate a prompt without input data.
+
+        Returns:
+            Dict[str, str]: A dictionary containing the prompt components (e.g., user_prompt, system_prompt).
+        """
+        raise NotImplementedError(
+            "Subclasses must implement prompt_func, or a prompt_func callable "
+            "must be passed to __init__."
+        )
+
+    def parse_func(
+        self,
+        row: Union[Dict[str, Any], BaseModel],
+        response: Union[Dict[str, Any], BaseModel],
+    ) -> T:
+        """Override this method to define how responses are parsed.
+
+        Args:
+            row (Union[Dict[str, Any], BaseModel]): The input row that generated the response.
+            response (Union[Dict[str, Any], BaseModel]): The response from the LLM.
+
+        Returns:
+            T: The parsed response in the desired format.
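+
+        Example:
+            A minimal override that pulls one field out of the raw response
+            (a sketch; assumes the response dict has a "message" key):
+
+                def parse_func(self, row, response):
+                    return {"answer": response["message"]}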
+        """
+        return response
diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py
index f0b5e9c9..a9c5e27d 100644
--- a/src/bespokelabs/curator/prompter/prompter.py
+++ b/src/bespokelabs/curator/prompter/prompter.py
@@ -1,4 +1,73 @@
-"""Curator: Bespoke Labs Synthetic Data Generation Library."""
+"""Curator: Bespoke Labs Synthetic Data Generation Library.
+
+This module provides the Prompter class for interacting with LLMs. It supports
+both function-based and class-based approaches to defining prompt generation
+and response parsing logic.
+
+Examples:
+    Function-based approach (simple use cases):
+    ```python
+    from bespokelabs.curator import Prompter
+
+    # Define prompt and parse functions
+    def prompt_func(row=None):
+        if row is None:
+            return {
+                "user_prompt": "Default prompt",
+                "system_prompt": "You are a helpful assistant.",
+            }
+        return {
+            "user_prompt": f"Process: {row['data']}",
+            "system_prompt": "You are a helpful assistant.",
+        }
+
+    def parse_func(row, response):
+        return {"result": response.message}
+
+    # Create prompter instance
+    prompter = Prompter(
+        model_name="gpt-4",
+        prompt_func=prompt_func,
+        parse_func=parse_func,
+        response_format=MyResponseFormat,
+    )
+
+    # Use the prompter
+    result = prompter(dataset)  # Process dataset
+    single_result = prompter()  # Single completion
+    ```
+
+    Class-based approach (complex use cases):
+    ```python
+    from bespokelabs.curator import Prompter
+
+    class CustomPrompter(Prompter):
+        def prompt_func(self, row=None):
+            # Your custom prompt generation logic
+            return {
+                "user_prompt": "...",
+                "system_prompt": "...",
+            }
+
+        def parse_func(self, row, response):
+            # Your custom response parsing logic
+            return response
+
+    prompter = CustomPrompter(model_name="gpt-4")
+    ```
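+
+    Generation parameters (a sketch; these apply only when batch=False):
+    ```python
+    prompter = Prompter(
+        model_name="gpt-4",
+        prompt_func=prompt_func,
+        temperature=0.2,
+        top_p=0.9,
+    )
+    ```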
+
+For more examples, see the examples/ directory in the repository.
+"""
 
 import inspect
 import logging
@@ -13,6 +82,7 @@
 from xxhash import xxh64
 
 from bespokelabs.curator.db import MetadataDB
+from bespokelabs.curator.prompter.base_prompter import BasePrompter
 from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
 from bespokelabs.curator.request_processor.base_request_processor import BaseRequestProcessor
 from bespokelabs.curator.request_processor.openai_batch_request_processor import (
@@ -28,13 +98,46 @@
 logger = logging.getLogger(__name__)
 
 
-class Prompter:
-    """Interface for prompting LLMs."""
+class Prompter(BasePrompter):
+    """Interface for prompting LLMs.
+
+    This class supports both function-based and class-based approaches:
+
+    Function-based:
+        prompter = Prompter(
+            model_name="gpt-4",
+            prompt_func=lambda row: {"user_prompt": f"Process {row}"},
+            parse_func=lambda row, response: response,
+        )
+
+    Class-based:
+        class CustomPrompter(Prompter):
+            def prompt_func(self, row=None):
+                return {"user_prompt": f"Process {row}"}
+
+            def parse_func(self, row, response):
+                return response
+
+        prompter = CustomPrompter(model_name="gpt-4")
+    """
+
+    _prompt_func: Optional[
+        Callable[
+            [Optional[Union[Dict[str, Any], BaseModel]]],
+            Dict[str, str],
+        ]
+    ]
+    _parse_func: Optional[
+        Callable[
+            [Union[Dict[str, Any], BaseModel], Union[Dict[str, Any], BaseModel]],
+            T,
+        ]
+    ]
 
     def __init__(
         self,
         model_name: str,
-        prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]],
+        prompt_func: Optional[Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]]] = None,
         parse_func: Optional[
             Callable[
                 [
@@ -69,11 +172,24 @@ def __init__(
             presence_penalty (Optional[float]): The presence_penalty to use for the LLM, only used if batch is False
             frequency_penalty (Optional[float]): The frequency_penalty to use for the LLM, only used if batch is False
         """
-        prompt_sig = inspect.signature(prompt_func)
-        if len(prompt_sig.parameters) > 1:
-            raise ValueError(
-                f"prompt_func must take one argument or less, got {len(prompt_sig.parameters)}"
-            )
+        # Call the parent constructor first; provided callables shadow the
+        # prompt_func/parse_func methods there, and both signatures are validated.
+        super().__init__(
+            model_name=model_name,
+            response_format=response_format,
+            prompt_func=prompt_func,
+            parse_func=parse_func,
+            batch=batch,
+            batch_size=batch_size,
+            temperature=temperature,
+            top_p=top_p,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+        )
+
+        # Keep direct references to the provided callables (may be None).
+        self._prompt_func = prompt_func
+        self._parse_func = parse_func
 
         if parse_func is not None:
             parse_sig = inspect.signature(parse_func)
@@ -83,8 +199,8 @@ def __init__(
                 )
 
         self.prompt_formatter = PromptFormatter(
-            model_name, prompt_func, parse_func, response_format
+            model_name, self.prompt_func, self.parse_func, response_format
        )
 
         if batch:
             if batch_size is None:
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index f1c327cd..ca9760f4 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -6,6 +6,7 @@
 from pydantic import BaseModel
 
 from bespokelabs.curator import Prompter
+from bespokelabs.curator.prompter.base_prompter import BasePrompter
 
 
 class MockResponseFormat(BaseModel):
@@ -134,3 +135,130 @@
     assert isinstance(result, MockResponseFormat)
     assert hasattr(result, "message")
     assert hasattr(result, "confidence")
+
+
+class TestClassBasedPrompter:
+    """Test cases for the class-based Prompter implementation."""
+
+    class CustomPrompter(Prompter):
+        """Test prompter that subclasses Prompter and overrides both hooks."""
+
+        def prompt_func(self, row=None):
+            if row is None:
+                return {
+                    "user_prompt": "Write a test message",
+                    "system_prompt": "You are a helpful assistant.",
+                }
+            return {
+                "user_prompt": f"Context: {row['context']} Answer this question: {row['question']}",
+                "system_prompt": "You are a helpful assistant.",
+            }
+
+        def parse_func(self, row, response):
+            # Custom parsing that adds a prefix to the message
+            if isinstance(response, MockResponseFormat):
+                return MockResponseFormat(
+                    message=f"Parsed: {response.message}", confidence=response.confidence
+                )
+            return response
+
+    @pytest.fixture
+    def class_based_prompter(self):
+        """Create a class-based Prompter instance for testing."""
+        return self.CustomPrompter(
+            model_name="gpt-4o-mini",
+            response_format=MockResponseFormat,
+        )
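+
+    def test_prompt_func_directly(self):
+        """Sanity-check the prompt hook without any LLM call (sketch)."""
+        prompter = self.CustomPrompter(
+            model_name="gpt-4o-mini",
+            response_format=MockResponseFormat,
+        )
+        prompts = prompter.prompt_func({"context": "ctx", "question": "Q?"})
+        assert prompts["user_prompt"] == "Context: ctx Answer this question: Q?"
+        assert "system_prompt" in prompts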
+
+    def test_class_based_completion(self, class_based_prompter, tmp_path):
+        """Test that class-based prompter processes a dataset correctly."""
+        test_data = {
+            "context": ["Test context 1", "Test context 2"],
+            "question": ["What is 1+1?", "What is 2+2?"],
+        }
+        dataset = Dataset.from_dict(test_data)
+
+        os.environ["BELLA_CACHE_DIR"] = str(tmp_path)
+        result_dataset = class_based_prompter(dataset)
+        result_dataset = result_dataset.to_huggingface()
+
+        assert len(result_dataset) == len(dataset)
+        assert "message" in result_dataset.column_names
+        assert "confidence" in result_dataset.column_names
+        # Verify our custom parse_func was applied
+        assert all(msg.startswith("Parsed: ") for msg in result_dataset["message"])
+
+    def test_class_based_single_completion(self, class_based_prompter):
+        """Test that class-based prompter works for single completions."""
+        result = class_based_prompter()
+
+        assert isinstance(result, MockResponseFormat)
+        assert hasattr(result, "message")
+        assert hasattr(result, "confidence")
+        assert result.message.startswith("Parsed: ")
+
+    def test_mixed_approach(self):
+        """Test using a class-based prompter with a function-based parse_func."""
+
+        def custom_parse_func(row, response):
+            if isinstance(response, MockResponseFormat):
+                return MockResponseFormat(
+                    message=f"Function parsed: {response.message}", confidence=response.confidence
+                )
+            return response
+
+        prompter = self.CustomPrompter(
+            model_name="gpt-4o-mini",
+            response_format=MockResponseFormat,
+            parse_func=custom_parse_func,  # Overrides the class-based parse_func
+        )
+
+        result = prompter()
+        assert isinstance(result, MockResponseFormat)
+        assert result.message.startswith("Function parsed: ")
+
+    def test_invalid_prompt_func(self):
+        """Test that an invalid prompt_func signature raises ValueError."""
+
+        class InvalidPrompter(BasePrompter):
+            def prompt_func(self, row, extra_arg):  # Invalid: too many parameters
+                return {"user_prompt": f"Process {row}"}
+
+        with pytest.raises(ValueError) as exc_info:
+            InvalidPrompter(model_name="gpt-4o-mini")
+        assert "prompt_func must take one argument or less" in str(exc_info.value)
+
+    def test_invalid_parse_func(self):
+        """Test that an invalid parse_func signature raises ValueError."""
+
+        class InvalidParsePrompter(BasePrompter):
+            def prompt_func(self, row=None):
+                return {"user_prompt": "test"}
+
+            def parse_func(self, response):  # Invalid: too few parameters
+                return response
+
+        with pytest.raises(ValueError) as exc_info:
+            InvalidParsePrompter(model_name="gpt-4o-mini")
+        assert "parse_func must take exactly 2 arguments" in str(exc_info.value)
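+
+    def test_default_parse_func_passthrough(self):
+        """The base parse_func returns the raw response unchanged (sketch)."""
+
+        class PassthroughPrompter(BasePrompter):
+            def prompt_func(self, row=None):
+                return {"user_prompt": "test"}
+
+        prompter = PassthroughPrompter(model_name="gpt-4o-mini")
+        response = {"message": "hello"}
+        assert prompter.parse_func({}, response) is response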