Refactor Prompter to support class-based approach #173

Open · wants to merge 6 commits into dev
examples/class_based_prompter.py: 83 additions, 0 deletions
@@ -0,0 +1,83 @@
"""Example of using the class-based Prompter approach."""

from typing import Dict, Any, Optional
from pydantic import BaseModel

from bespokelabs.curator import Prompter
from bespokelabs.curator.prompter.base_prompter import BasePrompter


class ResponseFormat(BaseModel):
"""Example response format."""

answer: str
confidence: float


class MathPrompter(BasePrompter):
"""Example custom prompter for math problems."""

def prompt_func(self, row: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
"""Generate prompts for math problems.

Args:
row: Optional dictionary containing 'question' key.
If None, generates a default prompt.

Returns:
Dict containing user_prompt and system_prompt.
"""
if row is None:
return {
"user_prompt": "What is 2 + 2?",
"system_prompt": "You are a math tutor. Provide clear, step-by-step solutions.",
}

return {
"user_prompt": row["question"],
"system_prompt": "You are a math tutor. Provide clear, step-by-step solutions.",
}

def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> ResponseFormat:
"""Parse LLM response into structured format.

Args:
row: Input row that generated the response
response: Raw response from the LLM

Returns:
ResponseFormat containing answer and confidence
"""
# Extract answer and add confidence score
return ResponseFormat(
answer=response["message"],
confidence=0.95 if "step" in response["message"].lower() else 0.7,
)


def main():
"""Example usage of class-based prompter."""
# Create instance of custom prompter
math_prompter = MathPrompter(
model_name="gpt-4o-mini",
response_format=ResponseFormat,
)

# Single completion without input
result = math_prompter()
print(f"Single completion result: {result}")

# Process multiple questions
questions = [
{"question": "What is 5 * 3?"},
{"question": "Solve x + 2 = 7"},
]
results = math_prompter(questions)
print("\nBatch processing results:")
for q, r in zip(questions, results):
print(f"Q: {q['question']}")
print(f"A: {r.answer} (confidence: {r.confidence})")


if __name__ == "__main__":
main()
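The example above drives the prompter in the default (non-batch) mode. Since BasePrompter's constructor (introduced in the next file of this diff) also accepts batch and batch_size, the same subclass should be constructible for batch processing; a minimal sketch, with the caveat that whether the batch path accepts the same list-of-dicts input used above is an assumption here:

```python
# Hedged sketch: batch-mode construction of the same MathPrompter.
# batch/batch_size come from BasePrompter.__init__; the list-of-dicts
# input mirrors the `questions` list above and is assumed, not verified.
batch_prompter = MathPrompter(
    model_name="gpt-4o-mini",
    response_format=ResponseFormat,
    batch=True,
    batch_size=2,
)
batch_results = batch_prompter(
    [{"question": "What is 7 - 4?"}, {"question": "Factor x**2 - 1"}]
)
```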
src/bespokelabs/curator/prompter/base_prompter.py: 119 additions, 0 deletions
@@ -0,0 +1,119 @@
"""Base class for Prompter implementations.

This module provides the abstract base class for implementing custom prompters.
The BasePrompter class defines the interface that all prompter implementations
must follow.

Example:
Creating a custom prompter:
```python
class CustomPrompter(BasePrompter):
def prompt_func(self, row: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
# Generate prompts for your specific use case
if row is None:
return {
"user_prompt": "Default prompt",
"system_prompt": "System instructions",
}
return {
"user_prompt": f"Process input: {row['data']}",
"system_prompt": "System instructions",
}

def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> Any:
# Optional: Override to customize response parsing
return response

# Usage
prompter = CustomPrompter(
model_name="gpt-4",
response_format=MyResponseFormat,
)
result = prompter(dataset) # Process dataset
single_result = prompter() # Single completion
```

For simpler use cases where you don't need a full class implementation,
you can use the function-based approach with the Prompter class directly.
See the Prompter class documentation for details.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type, TypeVar, Union

from pydantic import BaseModel

T = TypeVar("T")


class BasePrompter(ABC):
"""Abstract base class for prompter implementations.

This class defines the interface for prompter implementations. Subclasses must
implement prompt_func and may optionally override parse_func.
"""

def __init__(
self,
model_name: str,
response_format: Optional[Type[BaseModel]] = None,
batch: bool = False,
batch_size: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
):
"""Initialize a BasePrompter.

Args:
model_name (str): The name of the LLM to use
response_format (Optional[Type[BaseModel]]): A Pydantic model specifying the
response format from the LLM.
batch (bool): Whether to use batch processing
batch_size (Optional[int]): The size of the batch to use, only used if batch is True
temperature (Optional[float]): The temperature to use for the LLM, only used if batch is False
top_p (Optional[float]): The top_p to use for the LLM, only used if batch is False
presence_penalty (Optional[float]): The presence_penalty to use for the LLM, only used if batch is False
frequency_penalty (Optional[float]): The frequency_penalty to use for the LLM, only used if batch is False
"""
self.model_name = model_name
self.response_format = response_format
self.batch_mode = batch
self.batch_size = batch_size
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty

@abstractmethod
def prompt_func(
self,
row: Optional[Union[Dict[str, Any], BaseModel]] = None,
) -> Dict[str, str]:
"""Override this method to define how prompts are generated.

Args:
row (Optional[Union[Dict[str, Any], BaseModel]]): The input row to generate a prompt for.
If None, generate a prompt without input data.

Returns:
Dict[str, str]: A dictionary containing the prompt components (e.g., user_prompt, system_prompt).
"""
pass

def parse_func(
self,
row: Union[Dict[str, Any], BaseModel],
response: Union[Dict[str, Any], BaseModel],
) -> T:
"""Override this method to define how responses are parsed.

Args:
row (Union[Dict[str, Any], BaseModel]): The input row that generated the response.
response (Union[Dict[str, Any], BaseModel]): The response from the LLM.

Returns:
T: The parsed response in the desired format.
"""
return response
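As the class docstring above notes, only prompt_func is abstract; parse_func defaults to returning the response unchanged. A minimal sketch of a subclass that relies on that default (the EchoPrompter name and its prompts are illustrative, not part of this PR):

```python
from typing import Any, Dict, Optional


class EchoPrompter(BasePrompter):
    """Hypothetical subclass that overrides only the abstract prompt_func."""

    def prompt_func(self, row: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
        # Fall back to a fixed prompt when no input row is supplied.
        user = row["text"] if row is not None else "Say hello."
        return {
            "user_prompt": user,
            "system_prompt": "You are a helpful assistant.",
        }


# parse_func is inherited from BasePrompter, so raw responses pass through.
prompter = EchoPrompter(model_name="gpt-4o-mini")
```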
src/bespokelabs/curator/prompter/prompter.py: 121 additions, 10 deletions
@@ -1,4 +1,63 @@
"""Curator: Bespoke Labs Synthetic Data Generation Library."""
"""Curator: Bespoke Labs Synthetic Data Generation Library.

This module provides the Prompter class for interacting with LLMs. It supports
both function-based and class-based approaches to defining prompt generation
and response parsing logic.

Examples:
Function-based approach (simple use cases):
```python
from bespokelabs.curator import Prompter

# Define prompt and parse functions
def prompt_func(row=None):
if row is None:
return {
"user_prompt": "Default prompt",
"system_prompt": "You are a helpful assistant.",
}
return {
"user_prompt": f"Process: {row['data']}",
"system_prompt": "You are a helpful assistant.",
}

def parse_func(row, response):
return {"result": response.message}

# Create prompter instance
prompter = Prompter(
model_name="gpt-4",
prompt_func=prompt_func,
parse_func=parse_func,
response_format=MyResponseFormat,
)

# Use the prompter
result = prompter(dataset) # Process dataset
single_result = prompter() # Single completion
```

Class-based approach (complex use cases):
```python
from bespokelabs.curator.prompter.base_prompter import BasePrompter

class CustomPrompter(BasePrompter):
def prompt_func(self, row=None):
# Your custom prompt generation logic
return {
"user_prompt": "...",
"system_prompt": "...",
}

def parse_func(self, row, response):
# Your custom response parsing logic
return response

prompter = CustomPrompter(model_name="gpt-4")
```

For more examples, see the examples/ directory in the repository.
"""

import inspect
import logging
@@ -13,6 +72,7 @@
 from xxhash import xxh64

 from bespokelabs.curator.db import MetadataDB
+from bespokelabs.curator.prompter.base_prompter import BasePrompter
 from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
 from bespokelabs.curator.request_processor.base_request_processor import BaseRequestProcessor
 from bespokelabs.curator.request_processor.openai_batch_request_processor import (
@@ -28,13 +88,45 @@
 logger = logging.getLogger(__name__)


-class Prompter:
-    """Interface for prompting LLMs."""
+class Prompter(BasePrompter):
+    """Interface for prompting LLMs.
+
+    This class supports both function-based and class-based approaches:
+
+    Function-based:
+        prompter = Prompter(
+            model_name="gpt-4",
+            prompt_func=lambda row: f"Process {row}",
+            parse_func=lambda row, response: response,
+        )
+
+    Class-based:
+        class CustomPrompter(BasePrompter):
+            def prompt_func(self, row):
+                return f"Process {row}"
+
+            def parse_func(self, row, response):
+                return response
+
+        prompter = CustomPrompter(model_name="gpt-4")
+    """
+
+    _prompt_func: Optional[
+        Callable[
+            [Optional[Union[Dict[str, Any], BaseModel]]],
+            Dict[str, str],
+        ]
+    ]
+    _parse_func: Optional[
+        Callable[
+            [Union[Dict[str, Any], BaseModel], Union[Dict[str, Any], BaseModel]],
+            T,
+        ]
+    ]

     def __init__(
         self,
         model_name: str,
-        prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]],
+        prompt_func: Optional[Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]]] = None,
         parse_func: Optional[
             Callable[
                 [
@@ -69,11 +161,29 @@ def __init__(
             presence_penalty (Optional[float]): The presence_penalty to use for the LLM, only used if batch is False
             frequency_penalty (Optional[float]): The frequency_penalty to use for the LLM, only used if batch is False
         """
-        prompt_sig = inspect.signature(prompt_func)
-        if len(prompt_sig.parameters) > 1:
-            raise ValueError(
-                f"prompt_func must take one argument or less, got {len(prompt_sig.parameters)}"
-            )
+        # Call parent class constructor first
+        super().__init__(
+            model_name=model_name,
+            response_format=response_format,
+            batch=batch,
+            batch_size=batch_size,
+            temperature=temperature,
+            top_p=top_p,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+        )
+
+        # Store the provided functions
+        self._prompt_func = prompt_func
+        self._parse_func = parse_func
+
+        # Validate function signatures if provided
+        if prompt_func is not None:
+            prompt_sig = inspect.signature(prompt_func)
+            if len(prompt_sig.parameters) > 1:
+                raise ValueError(
+                    f"prompt_func must take one argument or less, got {len(prompt_sig.parameters)}"
+                )

         if parse_func is not None:
             parse_sig = inspect.signature(parse_func)
@@ -83,8 +193,9 @@
                 )

         self.prompt_formatter = PromptFormatter(
-            model_name, prompt_func, parse_func, response_format
+            model_name, self.prompt_func, self.parse_func, response_format
         )
+
         self.batch_mode = batch
         if batch:
             if batch_size is None:
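Taken together, the constructor changes keep the function-based path intact: prompt_func and parse_func are now optional, and signature validation runs only for functions actually supplied. A sketch of the unchanged call pattern (model name and prompt text are illustrative):

```python
from bespokelabs.curator import Prompter


def prompt_func(row=None):
    # Same dict shape as the module docstring examples above.
    if row is None:
        return {
            "user_prompt": "What is 2 + 2?",
            "system_prompt": "You are a math tutor.",
        }
    return {
        "user_prompt": row["question"],
        "system_prompt": "You are a math tutor.",
    }


prompter = Prompter(model_name="gpt-4o-mini", prompt_func=prompt_func)
result = prompter()  # single completion; no dataset required
```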