diff --git a/src/bespokelabs/curator/prompter/base_prompter.py b/src/bespokelabs/curator/prompter/base_prompter.py index 5ecae89e..f965e7d0 100644 --- a/src/bespokelabs/curator/prompter/base_prompter.py +++ b/src/bespokelabs/curator/prompter/base_prompter.py @@ -43,7 +43,8 @@ def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> Any: from pydantic import BaseModel -T = TypeVar('T') +T = TypeVar("T") + class BasePrompter(ABC): """Abstract base class for prompter implementations. @@ -86,7 +87,10 @@ def __init__( self.frequency_penalty = frequency_penalty @abstractmethod - def prompt_func(self, row: Optional[Union[Dict[str, Any], BaseModel]] = None) -> Dict[str, str]: + def prompt_func( + self, + row: Optional[Union[Dict[str, Any], BaseModel]] = None, + ) -> Dict[str, str]: """Override this method to define how prompts are generated. Args: @@ -98,8 +102,11 @@ def prompt_func(self, row: Optional[Union[Dict[str, Any], BaseModel]] = None) -> """ pass - def parse_func(self, row: Union[Dict[str, Any], BaseModel], - response: Union[Dict[str, Any], BaseModel]) -> T: + def parse_func( + self, + row: Union[Dict[str, Any], BaseModel], + response: Union[Dict[str, Any], BaseModel], + ) -> T: """Override this method to define how responses are parsed. Args: diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 6080013e..a9c5e27d 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -110,13 +110,15 @@ def parse_func(self, row, response): prompter = CustomPrompter(model_name="gpt-4") """ - _prompt_func: Optional[Callable[[Optional[Union[Dict[str, Any], BaseModel]]], Dict[str, str]]] + _prompt_func: Optional[ + Callable[ + [Optional[Union[Dict[str, Any], BaseModel]]], + Dict[str, str], + ] + ] _parse_func: Optional[ Callable[ - [ - Union[Dict[str, Any], BaseModel], - Union[Dict[str, Any], BaseModel], - ], + [Union[Dict[str, Any], BaseModel], Union[Dict[str, Any], BaseModel]], T, ] ] @@ -222,44 +224,6 @@ def __init__( frequency_penalty=frequency_penalty, ) - if parse_func is not None: - parse_sig = inspect.signature(parse_func) - if len(parse_sig.parameters) != 2: - raise ValueError( - f"parse_func must take exactly 2 arguments, got {len(parse_sig.parameters)}" - ) - - self.prompt_formatter = PromptFormatter( - model_name, self.prompt_func, self.parse_func, response_format - ) - self.batch_mode = batch - if batch: - if batch_size is None: - batch_size = 1_000 - logger.info( - f"batch=True but no batch_size provided, using default batch_size of {batch_size:,}" - ) - self._request_processor = OpenAIBatchRequestProcessor( - model=model_name, - batch_size=batch_size, - temperature=temperature, - top_p=top_p, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - ) - else: - if batch_size is not None: - logger.warning( - f"Prompter argument `batch_size` {batch_size} is ignored because `batch` is False" - ) - self._request_processor = OpenAIOnlineRequestProcessor( - model=model_name, - temperature=temperature, - top_p=top_p, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - ) - def __call__(self, dataset: Optional[Iterable] = None, working_dir: str = None) -> Dataset: """ Run completions on a dataset.