Skip to content

Commit

Permalink
fix: apply black formatting to prompter files
Browse files Browse the repository at this point in the history
- Remove duplicate initialization code in prompter.py
- Fix long type hints formatting
- Improve method signature formatting in base_prompter.py
  • Loading branch information
devin-ai-integration[bot] committed Nov 24, 2024
1 parent c426e94 commit 5a10b6b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 47 deletions.
15 changes: 11 additions & 4 deletions src/bespokelabs/curator/prompter/base_prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def parse_func(self, row: Dict[str, Any], response: Dict[str, Any]) -> Any:

from pydantic import BaseModel

T = TypeVar('T')
T = TypeVar("T")


class BasePrompter(ABC):
"""Abstract base class for prompter implementations.
Expand Down Expand Up @@ -86,7 +87,10 @@ def __init__(
self.frequency_penalty = frequency_penalty

@abstractmethod
def prompt_func(self, row: Optional[Union[Dict[str, Any], BaseModel]] = None) -> Dict[str, str]:
def prompt_func(
self,
row: Optional[Union[Dict[str, Any], BaseModel]] = None,
) -> Dict[str, str]:
"""Override this method to define how prompts are generated.
Args:
Expand All @@ -98,8 +102,11 @@ def prompt_func(self, row: Optional[Union[Dict[str, Any], BaseModel]] = None) ->
"""
pass

def parse_func(self, row: Union[Dict[str, Any], BaseModel],
response: Union[Dict[str, Any], BaseModel]) -> T:
def parse_func(
self,
row: Union[Dict[str, Any], BaseModel],
response: Union[Dict[str, Any], BaseModel],
) -> T:
"""Override this method to define how responses are parsed.
Args:
Expand Down
50 changes: 7 additions & 43 deletions src/bespokelabs/curator/prompter/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,15 @@ def parse_func(self, row, response):
prompter = CustomPrompter(model_name="gpt-4")
"""

_prompt_func: Optional[Callable[[Optional[Union[Dict[str, Any], BaseModel]]], Dict[str, str]]]
_prompt_func: Optional[
Callable[
[Optional[Union[Dict[str, Any], BaseModel]]],
Dict[str, str],
]
]
_parse_func: Optional[
Callable[
[
Union[Dict[str, Any], BaseModel],
Union[Dict[str, Any], BaseModel],
],
[Union[Dict[str, Any], BaseModel], Union[Dict[str, Any], BaseModel]],
T,
]
]
Expand Down Expand Up @@ -222,44 +224,6 @@ def __init__(
frequency_penalty=frequency_penalty,
)

if parse_func is not None:
parse_sig = inspect.signature(parse_func)
if len(parse_sig.parameters) != 2:
raise ValueError(
f"parse_func must take exactly 2 arguments, got {len(parse_sig.parameters)}"
)

self.prompt_formatter = PromptFormatter(
model_name, self.prompt_func, self.parse_func, response_format
)
self.batch_mode = batch
if batch:
if batch_size is None:
batch_size = 1_000
logger.info(
f"batch=True but no batch_size provided, using default batch_size of {batch_size:,}"
)
self._request_processor = OpenAIBatchRequestProcessor(
model=model_name,
batch_size=batch_size,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)
else:
if batch_size is not None:
logger.warning(
f"Prompter argument `batch_size` {batch_size} is ignored because `batch` is False"
)
self._request_processor = OpenAIOnlineRequestProcessor(
model=model_name,
temperature=temperature,
top_p=top_p,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
)

def __call__(self, dataset: Optional[Iterable] = None, working_dir: str = None) -> Dataset:
"""
Run completions on a dataset.
Expand Down

0 comments on commit 5a10b6b

Please sign in to comment.