From e623468e5ac4aa2dc808f8b5fafe38e17b7f9c4a Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 24 Nov 2024 05:31:37 +0000 Subject: [PATCH 1/6] Refactor Prompter to support class-based approach - Add BasePrompter abstract base class - Modify Prompter to inherit from BasePrompter - Support both function-based and class-based approaches - Maintain backward compatibility - Add comprehensive documentation with examples --- .../curator/prompter/base_prompter.py | 74 +++++++++++++ src/bespokelabs/curator/prompter/prompter.py | 104 ++++++++++++++++-- 2 files changed, 170 insertions(+), 8 deletions(-) create mode 100644 src/bespokelabs/curator/prompter/base_prompter.py diff --git a/src/bespokelabs/curator/prompter/base_prompter.py b/src/bespokelabs/curator/prompter/base_prompter.py new file mode 100644 index 00000000..d266892a --- /dev/null +++ b/src/bespokelabs/curator/prompter/base_prompter.py @@ -0,0 +1,74 @@ +"""Base class for Prompter implementations.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Type, TypeVar, Union + +from pydantic import BaseModel + +T = TypeVar('T') + +class BasePrompter(ABC): + """Abstract base class for prompter implementations. + + This class defines the interface for prompter implementations. Subclasses must + implement prompt_func and may optionally override parse_func. + """ + + def __init__( + self, + model_name: str, + response_format: Optional[Type[BaseModel]] = None, + batch: bool = False, + batch_size: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + ): + """Initialize a BasePrompter. + + Args: + model_name (str): The name of the LLM to use + response_format (Optional[Type[BaseModel]]): A Pydantic model specifying the + response format from the LLM. + batch (bool): Whether to use batch processing + batch_size (Optional[int]): The size of the batch to use, only used if batch is True + temperature (Optional[float]): The temperature to use for the LLM, only used if batch is False + top_p (Optional[float]): The top_p to use for the LLM, only used if batch is False + presence_penalty (Optional[float]): The presence_penalty to use for the LLM, only used if batch is False + frequency_penalty (Optional[float]): The frequency_penalty to use for the LLM, only used if batch is False + """ + self.model_name = model_name + self.response_format = response_format + self.batch_mode = batch + self.batch_size = batch_size + self.temperature = temperature + self.top_p = top_p + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + + @abstractmethod + def prompt_func(self, row: Optional[Union[Dict[str, Any], BaseModel]] = None) -> Dict[str, str]: + """Override this method to define how prompts are generated. + + Args: + row (Optional[Union[Dict[str, Any], BaseModel]]): The input row to generate a prompt for. + If None, generate a prompt without input data. + + Returns: + Dict[str, str]: A dictionary containing the prompt components (e.g., user_prompt, system_prompt). + """ + pass + + def parse_func(self, row: Union[Dict[str, Any], BaseModel], + response: Union[Dict[str, Any], BaseModel]) -> T: + """Override this method to define how responses are parsed. + + Args: + row (Union[Dict[str, Any], BaseModel]): The input row that generated the response. 
+ response (Union[Dict[str, Any], BaseModel]): The response from the LLM. + + Returns: + T: The parsed response in the desired format. + """ + return response diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index f0b5e9c9..9466516d 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -13,6 +13,7 @@ from xxhash import xxh64 from bespokelabs.curator.db import MetadataDB +from bespokelabs.curator.prompter.base_prompter import BasePrompter from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.base_request_processor import BaseRequestProcessor from bespokelabs.curator.request_processor.openai_batch_request_processor import ( @@ -28,13 +29,43 @@ logger = logging.getLogger(__name__) -class Prompter: - """Interface for prompting LLMs.""" +class Prompter(BasePrompter): + """Interface for prompting LLMs. + + This class supports both function-based and class-based approaches: + + Function-based: + prompter = Prompter( + model_name="gpt-4", + prompt_func=lambda row: f"Process {row}", + parse_func=lambda row, response: response + ) + + Class-based: + class CustomPrompter(BasePrompter): + def prompt_func(self, row): + return f"Process {row}" + def parse_func(self, row, response): + return response + + prompter = CustomPrompter(model_name="gpt-4") + """ + + _prompt_func: Optional[Callable[[Optional[Union[Dict[str, Any], BaseModel]]], Dict[str, str]]] + _parse_func: Optional[ + Callable[ + [ + Union[Dict[str, Any], BaseModel], + Union[Dict[str, Any], BaseModel], + ], + T, + ] + ] def __init__( self, model_name: str, - prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]], + prompt_func: Optional[Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]]] = None, parse_func: Optional[ Callable[ [ @@ -69,10 +100,67 @@ def __init__( presence_penalty (Optional[float]): The presence_penalty to use for the LLM, only used if batch is False frequency_penalty (Optional[float]): The frequency_penalty to use for the LLM, only used if batch is False """ - prompt_sig = inspect.signature(prompt_func) - if len(prompt_sig.parameters) > 1: - raise ValueError( - f"prompt_func must take one argument or less, got {len(prompt_sig.parameters)}" + # Call parent class constructor first + super().__init__( + model_name=model_name, + response_format=response_format, + batch=batch, + batch_size=batch_size, + temperature=temperature, + top_p=top_p, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + ) + + # Store the provided functions + self._prompt_func = prompt_func + self._parse_func = parse_func + + # Validate function signatures if provided + if prompt_func is not None: + prompt_sig = inspect.signature(prompt_func) + if len(prompt_sig.parameters) > 1: + raise ValueError( + f"prompt_func must take one argument or less, got {len(prompt_sig.parameters)}" + ) + + if parse_func is not None: + parse_sig = inspect.signature(parse_func) + if len(parse_sig.parameters) != 2: + raise ValueError( + f"parse_func must take exactly 2 arguments, got {len(parse_sig.parameters)}" + ) + + self.prompt_formatter = PromptFormatter( + model_name, self.prompt_func, self.parse_func, response_format + ) + + self.batch_mode = batch + if batch: + if batch_size is None: + batch_size = 1_000 + logger.info( + f"batch=True but no batch_size provided, using default batch_size of {batch_size:,}" + ) + self._request_processor = 
OpenAIBatchRequestProcessor( + model=model_name, + batch_size=batch_size, + temperature=temperature, + top_p=top_p, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + ) + else: + if batch_size is not None: + logger.warning( + f"Prompter argument `batch_size` {batch_size} is ignored because `batch` is False" + ) + self._request_processor = OpenAIOnlineRequestProcessor( + model=model_name, + temperature=temperature, + top_p=top_p, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, ) if parse_func is not None: @@ -83,7 +171,7 @@ def __init__( ) self.prompt_formatter = PromptFormatter( - model_name, prompt_func, parse_func, response_format + model_name, self.prompt_func, self.parse_func, response_format ) self.batch_mode = batch if batch: From 2ecd983f946159b082c8782b381b4a776142a18f Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 24 Nov 2024 05:32:43 +0000 Subject: [PATCH 2/6] Add comprehensive tests for class-based Prompter - Add TestClassBasedPrompter test suite - Test class-based implementation - Test mixed approach (class with function override) - Test invalid method signatures - Verify custom parsing behavior - Ensure backward compatibility --- tests/test_prompt.py | 106 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/tests/test_prompt.py b/tests/test_prompt.py index f1c327cd..1baf081b 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from bespokelabs.curator import Prompter +from bespokelabs.curator.prompter.base_prompter import BasePrompter class MockResponseFormat(BaseModel): @@ -134,3 +135,108 @@ def simple_prompt_func(): assert isinstance(result, MockResponseFormat) assert hasattr(result, "message") assert hasattr(result, "confidence") + + +class TestClassBasedPrompter: + """Test cases for class-based Prompter implementation.""" + + class CustomPrompter(BasePrompter): + """Test prompter implementation using class-based approach.""" + + def prompt_func(self, row=None): + if row is None: + return { + "user_prompt": "Write a test message", + "system_prompt": "You are a helpful assistant.", + } + return { + "user_prompt": f"Context: {row['context']} Answer this question: {row['question']}", + "system_prompt": "You are a helpful assistant.", + } + + def parse_func(self, row, response): + # Custom parsing that adds a prefix to the message + if isinstance(response, MockResponseFormat): + return MockResponseFormat( + message=f"Parsed: {response.message}", + confidence=response.confidence + ) + return response + + @pytest.fixture + def class_based_prompter(self): + """Create a class-based Prompter instance for testing.""" + return self.CustomPrompter( + model_name="gpt-4o-mini", + response_format=MockResponseFormat, + ) + + def test_class_based_completion(self, class_based_prompter, tmp_path): + """Test that class-based prompter processes a dataset correctly.""" + test_data = { + "context": ["Test context 1", "Test context 2"], + "question": ["What is 1+1?", "What is 2+2?"], + } + dataset = Dataset.from_dict(test_data) + + os.environ["BELLA_CACHE_DIR"] = str(tmp_path) + result_dataset = class_based_prompter(dataset) + result_dataset = result_dataset.to_huggingface() + + assert len(result_dataset) == len(dataset) + assert "message" in result_dataset.column_names + assert "confidence" in result_dataset.column_names + # Verify our custom parse_func was applied 
+ assert all(msg.startswith("Parsed: ") for msg in result_dataset["message"]) + + def test_class_based_single_completion(self, class_based_prompter): + """Test that class-based prompter works for single completions.""" + result = class_based_prompter() + + assert isinstance(result, MockResponseFormat) + assert hasattr(result, "message") + assert hasattr(result, "confidence") + assert result.message.startswith("Parsed: ") + + def test_mixed_approach(self): + """Test using class-based prompter with function-based parse_func.""" + def custom_parse_func(row, response): + if isinstance(response, MockResponseFormat): + return MockResponseFormat( + message=f"Function parsed: {response.message}", + confidence=response.confidence + ) + return response + + prompter = self.CustomPrompter( + model_name="gpt-4o-mini", + response_format=MockResponseFormat, + parse_func=custom_parse_func # Override class-based parse_func + ) + + result = prompter() + assert isinstance(result, MockResponseFormat) + assert result.message.startswith("Function parsed: ") + + def test_invalid_prompt_func(self): + """Test that invalid prompt_func signature raises ValueError.""" + class InvalidPrompter(BasePrompter): + def prompt_func(self, row, extra_arg): # Invalid: too many parameters + return {"prompt": f"Process {row}"} + + with pytest.raises(ValueError) as exc_info: + InvalidPrompter(model_name="gpt-4o-mini") + assert "prompt_func must take one argument or less" in str(exc_info.value) + + def test_invalid_parse_func(self): + """Test that invalid parse_func signature raises ValueError.""" + class InvalidParsePrompter(BasePrompter): + def prompt_func(self, row=None): + return {"prompt": "test"} + + def parse_func(self, response): # Invalid: too few parameters + return response + + with pytest.raises(ValueError) as exc_info: + InvalidParsePrompter(model_name="gpt-4o-mini") + assert "parse_func must take exactly 2 arguments" in str(exc_info.value) From 5d880e55fa346b18b5ce0c6e3b7720bda65fe254 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 24 Nov 2024 05:33:57 +0000 Subject: [PATCH 3/6] Add example of class-based Prompter implementation - Add comprehensive example showing class-based approach - Include both single and batch processing examples - Demonstrate custom prompt generation and response parsing - Show proper type hints and documentation --- examples/class_based_prompter.py | 82 ++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 examples/class_based_prompter.py diff --git a/examples/class_based_prompter.py b/examples/class_based_prompter.py new file mode 100644 index 00000000..f4d6a4fa --- /dev/null +++ b/examples/class_based_prompter.py @@ -0,0 +1,82 @@ +"""Example of using the class-based Prompter approach.""" + +from typing import Dict, Any, Optional +from pydantic import BaseModel + +from bespokelabs.curator import Prompter +from bespokelabs.curator.prompter.base_prompter import BasePrompter + + +class ResponseFormat(BaseModel): + """Example response format.""" + answer: str + confidence: float + + +class MathPrompter(BasePrompter): + """Example custom prompter for math problems.""" + + def prompt_func(self, row: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + """Generate prompts for math problems. + + Args: + row: Optional dictionary containing 'question' key. + If None, generates a default prompt. + + Returns: + Dict containing user_prompt and system_prompt. 
+ """ + if row is None: + return { + "user_prompt": "What is 2 + 2?", + "system_prompt": "You are a math tutor. Provide clear, step-by-step solutions.", + } + + return { + "user_prompt": row["question"], + "system_prompt": "You are a math tutor. Provide clear, step-by-step solutions.", + } + + def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> ResponseFormat: + """Parse LLM response into structured format. + + Args: + row: Input row that generated the response + response: Raw response from the LLM + + Returns: + ResponseFormat containing answer and confidence + """ + # Extract answer and add confidence score + return ResponseFormat( + answer=response["message"], + confidence=0.95 if "step" in response["message"].lower() else 0.7 + ) + + +def main(): + """Example usage of class-based prompter.""" + # Create instance of custom prompter + math_prompter = MathPrompter( + model_name="gpt-4o-mini", + response_format=ResponseFormat, + ) + + # Single completion without input + result = math_prompter() + print(f"Single completion result: {result}") + + # Process multiple questions + questions = [ + {"question": "What is 5 * 3?"}, + {"question": "Solve x + 2 = 7"}, + ] + results = math_prompter(questions) + print("\nBatch processing results:") + for q, r in zip(questions, results): + print(f"Q: {q['question']}") + print(f"A: {r.answer} (confidence: {r.confidence})") + + +if __name__ == "__main__": + main() From c426e945b34951799e73c842ec83637fb293a7f7 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 24 Nov 2024 05:34:10 +0000 Subject: [PATCH 4/6] Update documentation for class-based Prompter - Add comprehensive examples in docstrings - Document both function-based and class-based approaches - Include detailed usage examples - Improve method documentation --- .../curator/prompter/base_prompter.py | 40 +++++++++++- src/bespokelabs/curator/prompter/prompter.py | 61 ++++++++++++++++++- 2 files changed, 99 insertions(+), 2 deletions(-) diff --git a/src/bespokelabs/curator/prompter/base_prompter.py b/src/bespokelabs/curator/prompter/base_prompter.py index d266892a..5ecae89e 100644 --- a/src/bespokelabs/curator/prompter/base_prompter.py +++ b/src/bespokelabs/curator/prompter/base_prompter.py @@ -1,4 +1,42 @@ -"""Base class for Prompter implementations.""" +"""Base class for Prompter implementations. + +This module provides the abstract base class for implementing custom prompters. +The BasePrompter class defines the interface that all prompter implementations +must follow. + +Example: + Creating a custom prompter: + ```python + class CustomPrompter(BasePrompter): + def prompt_func(self, row: Optional[Dict[str, Any]] = None) -> Dict[str, str]: + # Generate prompts for your specific use case + if row is None: + return { + "user_prompt": "Default prompt", + "system_prompt": "System instructions", + } + return { + "user_prompt": f"Process input: {row['data']}", + "system_prompt": "System instructions", + } + + def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> Any: + # Optional: Override to customize response parsing + return response + + # Usage + prompter = CustomPrompter( + model_name="gpt-4", + response_format=MyResponseFormat, + ) + result = prompter(dataset) # Process dataset + single_result = prompter() # Single completion + ``` + +For simpler use cases where you don't need a full class implementation, +you can use the function-based approach with the Prompter class directly. 
+See the Prompter class documentation for details. +""" from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Type, TypeVar, Union diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 9466516d..6080013e 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -1,4 +1,63 @@ -"""Curator: Bespoke Labs Synthetic Data Generation Library.""" +"""Curator: Bespoke Labs Synthetic Data Generation Library. + +This module provides the Prompter class for interacting with LLMs. It supports +both function-based and class-based approaches to defining prompt generation +and response parsing logic. + +Examples: + Function-based approach (simple use cases): + ```python + from bespokelabs.curator import Prompter + + # Define prompt and parse functions + def prompt_func(row=None): + if row is None: + return { + "user_prompt": "Default prompt", + "system_prompt": "You are a helpful assistant.", + } + return { + "user_prompt": f"Process: {row['data']}", + "system_prompt": "You are a helpful assistant.", + } + + def parse_func(row, response): + return {"result": response.message} + + # Create prompter instance + prompter = Prompter( + model_name="gpt-4", + prompt_func=prompt_func, + parse_func=parse_func, + response_format=MyResponseFormat, + ) + + # Use the prompter + result = prompter(dataset) # Process dataset + single_result = prompter() # Single completion + ``` + + Class-based approach (complex use cases): + ```python + from bespokelabs.curator.prompter.base_prompter import BasePrompter + + class CustomPrompter(BasePrompter): + def prompt_func(self, row=None): + # Your custom prompt generation logic + return { + "user_prompt": "...", + "system_prompt": "...", + } + + def parse_func(self, row, response): + # Your custom response parsing logic + return response + + prompter = CustomPrompter(model_name="gpt-4") + ``` + +For more examples, see the examples/ directory in the repository. +""" import inspect import logging From 5a10b6b1b27ff387e8ba549eb152137738181f2f Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 24 Nov 2024 05:43:01 +0000 Subject: [PATCH 5/6] fix: apply black formatting to prompter files - Remove duplicate initialization code in prompter.py - Fix long type hints formatting - Improve method signature formatting in base_prompter.py --- .../curator/prompter/base_prompter.py | 15 ++++-- src/bespokelabs/curator/prompter/prompter.py | 50 +++---------------- 2 files changed, 18 insertions(+), 47 deletions(-) diff --git a/src/bespokelabs/curator/prompter/base_prompter.py b/src/bespokelabs/curator/prompter/base_prompter.py index 5ecae89e..f965e7d0 100644 --- a/src/bespokelabs/curator/prompter/base_prompter.py +++ b/src/bespokelabs/curator/prompter/base_prompter.py @@ -43,7 +43,8 @@ def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> Any: from pydantic import BaseModel -T = TypeVar('T') +T = TypeVar("T") + class BasePrompter(ABC): """Abstract base class for prompter implementations. @@ -86,7 +87,10 @@ def __init__( self.frequency_penalty = frequency_penalty @abstractmethod - def prompt_func(self, row: Optional[Union[Dict[str, Any], BaseModel]] = None) -> Dict[str, str]: + def prompt_func( + self, + row: Optional[Union[Dict[str, Any], BaseModel]] = None, + ) -> Dict[str, str]: """Override this method to define how prompts are generated. 
Args: @@ -98,8 +102,11 @@ def prompt_func(self, row: Optional[Union[Dict[str, Any], BaseModel]] = None) -> """ pass - def parse_func(self, row: Union[Dict[str, Any], BaseModel], - response: Union[Dict[str, Any], BaseModel]) -> T: + def parse_func( + self, + row: Union[Dict[str, Any], BaseModel], + response: Union[Dict[str, Any], BaseModel], + ) -> T: """Override this method to define how responses are parsed. Args: diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 6080013e..a9c5e27d 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -110,13 +110,15 @@ def parse_func(self, row, response): prompter = CustomPrompter(model_name="gpt-4") """ - _prompt_func: Optional[Callable[[Optional[Union[Dict[str, Any], BaseModel]]], Dict[str, str]]] + _prompt_func: Optional[ + Callable[ + [Optional[Union[Dict[str, Any], BaseModel]]], + Dict[str, str], + ] + ] _parse_func: Optional[ Callable[ - [ - Union[Dict[str, Any], BaseModel], - Union[Dict[str, Any], BaseModel], - ], + [Union[Dict[str, Any], BaseModel], Union[Dict[str, Any], BaseModel]], T, ] ] @@ -222,44 +224,6 @@ def __init__( frequency_penalty=frequency_penalty, ) - if parse_func is not None: - parse_sig = inspect.signature(parse_func) - if len(parse_sig.parameters) != 2: - raise ValueError( - f"parse_func must take exactly 2 arguments, got {len(parse_sig.parameters)}" - ) - - self.prompt_formatter = PromptFormatter( - model_name, self.prompt_func, self.parse_func, response_format - ) - self.batch_mode = batch - if batch: - if batch_size is None: - batch_size = 1_000 - logger.info( - f"batch=True but no batch_size provided, using default batch_size of {batch_size:,}" - ) - self._request_processor = OpenAIBatchRequestProcessor( - model=model_name, - batch_size=batch_size, - temperature=temperature, - top_p=top_p, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - ) - else: - if batch_size is not None: - logger.warning( - f"Prompter argument `batch_size` {batch_size} is ignored because `batch` is False" - ) - self._request_processor = OpenAIOnlineRequestProcessor( - model=model_name, - temperature=temperature, - top_p=top_p, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - ) - def __call__(self, dataset: Optional[Iterable] = None, working_dir: str = None) -> Dataset: """ Run completions on a dataset. 
From c341eacade457332757fc160ade9192f33033075 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 24 Nov 2024 05:45:09 +0000 Subject: [PATCH 6/6] style: apply black formatting to example and test files --- examples/class_based_prompter.py | 3 ++- tests/test_prompt.py | 11 ++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/class_based_prompter.py b/examples/class_based_prompter.py index f4d6a4fa..abbe456a 100644 --- a/examples/class_based_prompter.py +++ b/examples/class_based_prompter.py @@ -9,6 +9,7 @@ class ResponseFormat(BaseModel): """Example response format.""" + answer: str confidence: float @@ -50,7 +51,7 @@ def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> ResponseF # Extract answer and add confidence score return ResponseFormat( answer=response["message"], - confidence=0.95 if "step" in response["message"].lower() else 0.7 + confidence=0.95 if "step" in response["message"].lower() else 0.7, ) diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 1baf081b..ca9760f4 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -158,8 +158,7 @@ def parse_func(self, row, response): # Custom parsing that adds a prefix to the message if isinstance(response, MockResponseFormat): return MockResponseFormat( - message=f"Parsed: {response.message}", - confidence=response.confidence + message=f"Parsed: {response.message}", confidence=response.confidence ) return response @@ -200,18 +199,18 @@ def test_class_based_single_completion(self, class_based_prompter): def test_mixed_approach(self): """Test using class-based prompter with function-based parse_func.""" + def custom_parse_func(row, response): if isinstance(response, MockResponseFormat): return MockResponseFormat( - message=f"Function parsed: {response.message}", - confidence=response.confidence + message=f"Function parsed: {response.message}", confidence=response.confidence ) return response prompter = self.CustomPrompter( model_name="gpt-4o-mini", response_format=MockResponseFormat, - parse_func=custom_parse_func # Override class-based parse_func + parse_func=custom_parse_func, # Override class-based parse_func ) result = prompter() @@ -220,6 +219,7 @@ def custom_parse_func(row, response): def test_invalid_prompt_func(self): """Test that invalid prompt_func signature raises ValueError.""" + class InvalidPrompter(BasePrompter): def prompt_func(self, row, extra_arg): # Invalid: too many parameters return {"prompt": f"Process {row}"} @@ -230,6 +230,7 @@ def prompt_func(self, row, extra_arg): # Invalid: too many parameters def test_invalid_parse_func(self): """Test that invalid parse_func signature raises ValueError.""" + class InvalidParsePrompter(BasePrompter): def prompt_func(self, row=None): return {"prompt": "test"}
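
The series stops short of demonstrating the batch path it wires up, so a
minimal sketch follows. It is an illustration under assumptions, not part of
the patches: summarize_row and its "text" field are invented for the example,
and running it presumes the package layout above plus OpenAI credentials.
Per PATCH 1 (and the cleanup in PATCH 5), batch=True routes requests through
OpenAIBatchRequestProcessor, and an omitted batch_size falls back to a
logged default of 1,000.

    """Hedged sketch: exercising the batch path added in this series."""

    from bespokelabs.curator import Prompter


    def summarize_row(row=None):
        # Illustrative prompt_func: at most one parameter (enforced by the
        # signature check in PATCH 1), returning the user/system prompt
        # dict shape used throughout the series.
        if row is None:
            return {
                "user_prompt": "Summarize: (no input provided)",
                "system_prompt": "You are a concise assistant.",
            }
        return {
            "user_prompt": f"Summarize: {row['text']}",
            "system_prompt": "You are a concise assistant.",
        }


    prompter = Prompter(
        model_name="gpt-4o-mini",
        prompt_func=summarize_row,
        batch=True,      # selects OpenAIBatchRequestProcessor
        batch_size=500,  # omit to get the logged default of 1_000
    )

Leaving parse_func unset keeps BasePrompter's pass-through default, so each
row of the result carries the raw (or response_format-typed) response.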