Commit

Merge branch 'dev' into ryanm/anthropic-batches
RyanMarten committed Dec 18, 2024
2 parents 69109c9 + 5e8482a commit be55ca0
Showing 20 changed files with 332 additions and 156 deletions.
8 changes: 8 additions & 0 deletions bespoke-dataset-viewer/package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions bespoke-dataset-viewer/package.json
@@ -37,6 +37,7 @@
},
"devDependencies": {
"@types/node": "^20",
"@types/prismjs": "^1.26.5",
"@types/react": "^18",
"@types/react-dom": "^18",
"eslint": "^8",
2 changes: 1 addition & 1 deletion build_pkg.py
@@ -81,7 +81,7 @@ def nextjs_build():
def run_pytest():
print("Running pytest")
try:
run_command("pytest", cwd="tests")
run_command("pytest")
except subprocess.CalledProcessError:
print("Pytest failed. Aborting build.")
sys.exit(1)
6 changes: 0 additions & 6 deletions package-lock.json

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "bespokelabs-curator"
version = "0.1.11"
version = "0.1.12"
description = "Bespoke Labs Curator"
authors = ["Bespoke Labs <[email protected]>"]
readme = "README.md"
@@ -200,9 +200,7 @@ def generic_response_file_from_responses(
)
generic_response = GenericResponse(
response_message=None,
-        response_errors=[
-            f"Request {generic_request} failed with status code {raw_response['response']['status_code']}"
-        ],
+        response_errors=[raw_response["response"]["status_code"]],
raw_response=raw_response,
raw_request=None,
generic_request=generic_request,
2 changes: 1 addition & 1 deletion src/bespokelabs/curator/llm/llm.py
@@ -58,7 +58,7 @@ def __init__(
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
max_retries: Optional[int] = None,
-    require_all_responses: Optional[bool] = None,
+    require_all_responses: Optional[bool] = True,
):
"""Initialize a LLM.
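With this change `require_all_responses` defaults to `True`, so a run now raises when any request ultimately fails to produce a response instead of silently returning a partial dataset. A minimal sketch of opting out, assuming the constructor also accepts `model_name` and `prompt_func` as in the project README (the exact signature here is an assumption, not taken from this diff):

```python
from bespokelabs import curator  # top-level import assumed from the README

# Hypothetical usage: tolerate missing responses instead of failing the run.
recipe_generator = curator.LLM(
    model_name="gpt-4o-mini",  # any supported model name
    prompt_func=lambda: "Write a short recipe.",  # zero-argument prompt_func assumed
    require_all_responses=False,  # the new default is True
)
dataset = recipe_generator()  # returns whichever responses succeeded
```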
58 changes: 57 additions & 1 deletion src/bespokelabs/curator/llm/prompt_formatter.py
@@ -1,13 +1,16 @@
import dataclasses
import inspect
+import json
+import logging
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union

-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError

from bespokelabs.curator.request_processor.generic_request import GenericRequest

T = TypeVar("T")
_DictOrBaseModel = Union[Dict[str, Any], BaseModel]
+logger = logging.getLogger(__name__)


def _validate_messages(messages: list[dict]) -> None:
@@ -82,3 +85,56 @@ def create_generic_request(self, row: _DictOrBaseModel, idx: int) -> GenericRequest:
self.response_format.model_json_schema() if self.response_format else None
),
)

+    def response_to_response_format(self, response_message: str | dict) -> Optional[dict | str]:
+        """
+        Converts a response message to a specified Pydantic model format.
+
+        This method takes a response message (either as a string or dict) and validates/converts it
+        according to the provided Pydantic model format. If the response message is a string,
+        it first attempts to parse it as JSON. The resulting dict is then used to construct
+        an instance of the specified Pydantic model.
+
+        Args:
+            response_message (str | dict): The response message to convert, either as a JSON string
+                or a dictionary.
+            response_format (Optional[BaseModel]): The Pydantic model class that defines the
+                expected format of the response.
+
+        Returns:
+            Optional[dict | str]: The validated response message as a Pydantic model instance.
+
+        Raises:
+            json.JSONDecodeError: If the response_message is a string but cannot be parsed as valid JSON.
+            ValidationError: If the parsed response does not match the schema defined by response_format.
+        """
+        # Response message is a string, which is converted to a dict
+        # The dict is then used to construct the response_format Pydantic model
+        if self.response_format is None:
+            return response_message
+
+        try:
+            # First try to parse the response message as JSON
+            if isinstance(response_message, str):
+                try:
+                    response_dict = json.loads(response_message)
+                except json.JSONDecodeError as e:
+                    logger.warning(
+                        f"Failed to parse response message as JSON: {response_message}. "
+                        f"The model likely returned an invalid JSON format."
+                    )
+                    raise e
+            else:
+                response_dict = response_message
+
+            # Then construct the Pydantic model from the parsed dict
+            response_message = self.response_format(**response_dict)
+            return response_message
+
+        except ValidationError as e:
+            schema_str = json.dumps(self.response_format.model_json_schema(), indent=2)
+            logger.warning(
+                f"Pydantic failed to parse response message {response_message} with `response_format` {schema_str}. "
+                f"The model likely returned a JSON that does not match the schema of the `response_format`."
+            )
+            raise e
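Downstream, the online processor calls this method after each response so that schema mismatches raise and trigger a retry (see the `handle_single_request_with_retries` hunk below). A standalone sketch of the same parse-then-validate pattern, using a hypothetical `Haiku` model rather than anything from this repo (pydantic v2 assumed):

```python
import json
from pydantic import BaseModel, ValidationError

class Haiku(BaseModel):  # hypothetical response_format
    lines: list[str]

def to_response_format(message: str | dict, response_format: type[BaseModel] | None):
    if response_format is None:
        return message
    # Parse string payloads as JSON first, mirroring the method above.
    data = json.loads(message) if isinstance(message, str) else message
    return response_format(**data)  # raises ValidationError on schema mismatch

print(to_response_format('{"lines": ["a", "b", "c"]}', Haiku))
try:
    to_response_format('{"verses": 3}', Haiku)  # wrong key -> ValidationError
except ValidationError as e:
    print("schema mismatch:", e.error_count(), "error(s)")
```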
@@ -26,6 +26,8 @@
DEFAULT_MAX_REQUESTS_PER_MINUTE = 100
DEFAULT_MAX_TOKENS_PER_MINUTE = 100_000
DEFAULT_MAX_RETRIES = 10
+SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10
+DEFAULT_REQUEST_TIMEOUT = 10 * 60  # 10 minutes


@dataclass
@@ -69,6 +71,7 @@ def __init__(
self.max_retries = DEFAULT_MAX_RETRIES
else:
self.max_retries = max_retries
+        self.timeout = DEFAULT_REQUEST_TIMEOUT

@property
def max_requests_per_minute(self) -> int:
@@ -127,6 +130,11 @@ def run(
parse_func_hash: str,
prompt_formatter: PromptFormatter,
) -> Dataset:
+        # load from already completed dataset
+        output_dataset = self.attempt_loading_cached_dataset(working_dir, parse_func_hash)
+        if output_dataset is not None:
+            return output_dataset

"""Run completions using the online API with async processing."""
logger.info(f"Running {self.__class__.__name__} completions with model: {self.model}")

@@ -147,7 +155,6 @@
self.process_requests_from_file(
generic_request_filepath=request_file,
save_filepath=response_file,
-                max_attempts=self.max_retries,
resume=True,
)
)
@@ -158,7 +165,6 @@ async def process_requests_from_file(
self,
generic_request_filepath: str,
save_filepath: str,
-        max_attempts: int,
resume: bool,
resume_no_retry: bool = False,
) -> None:
@@ -182,7 +188,7 @@ async def process_requests_from_file(
completed_request_ids = set()
if os.path.exists(save_filepath):
if resume:
logger.debug(f"Resuming progress from existing file: {save_filepath}")
logger.info(f"Resuming progress by reading existing file: {save_filepath}")
logger.debug(
f"Removing all failed requests from {save_filepath} so they can be retried"
)
@@ -200,6 +206,11 @@ async def process_requests_from_file(
f"{response.response_errors}, removing from output and will retry"
)
num_previously_failed_requests += 1
+                        if response.response_message is None:
+                            logger.debug(
+                                f"Request {response.generic_request.original_row_idx} previously failed due to no response, removing from output and will retry"
+                            )
+                            num_previously_failed_requests += 1
else:
completed_request_ids.add(response.generic_request.original_row_idx)
output_file.write(line)
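On resume, the loop above rewrites the response file, keeping only rows that completed successfully and collecting their ids so they are not re-submitted. A compact sketch of that filtering step, assuming the jsonl field names shown in the diff (`response_errors`, `response_message`, `generic_request.original_row_idx`):

```python
import json

def load_completed_ids(save_filepath: str) -> set[int]:
    """Keep successful responses in place; failed rows drop out and get retried."""
    completed: set[int] = set()
    kept_lines: list[str] = []
    with open(save_filepath) as f:
        for line in f:
            response = json.loads(line)
            failed = response.get("response_errors") or response.get("response_message") is None
            if not failed:
                completed.add(response["generic_request"]["original_row_idx"])
                kept_lines.append(line)
    with open(save_filepath, "w") as f:
        f.writelines(kept_lines)
    return completed
```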
@@ -273,7 +284,7 @@ async def process_requests_from_file(
task_id=status_tracker.num_tasks_started,
generic_request=generic_request,
api_specific_request=self.create_api_specific_request(generic_request),
-                    attempts_left=max_attempts,
+                    attempts_left=self.max_retries,
prompt_formatter=self.prompt_formatter,
)

@@ -283,6 +294,19 @@ async def process_requests_from_file(
while not status_tracker.has_capacity(token_estimate):
await asyncio.sleep(0.1)

+                    # Wait for rate limits cool down if needed
+                    seconds_since_rate_limit_error = (
+                        time.time() - status_tracker.time_of_last_rate_limit_error
+                    )
+                    if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT:
+                        remaining_seconds_to_pause = (
+                            SECONDS_TO_PAUSE_ON_RATE_LIMIT - seconds_since_rate_limit_error
+                        )
+                        await asyncio.sleep(remaining_seconds_to_pause)
+                        logger.warn(
+                            f"Pausing to cool down for {int(remaining_seconds_to_pause)} seconds"
+                        )

# Consume capacity before making request
status_tracker.consume_capacity(token_estimate)
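This cooldown gate makes every in-flight task sleep whenever a rate-limit error was seen within the last `SECONDS_TO_PAUSE_ON_RATE_LIMIT` seconds. A stripped-down sketch of the pattern; only the constant and the field name come from the diff, the `Tracker` class is an illustrative stand-in for the real status tracker:

```python
import asyncio
import time

SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10  # matches the constant added above

class Tracker:  # illustrative stand-in for the real status tracker
    time_of_last_rate_limit_error: float = 0.0

async def wait_out_rate_limit(tracker: Tracker) -> None:
    since_error = time.time() - tracker.time_of_last_rate_limit_error
    if since_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT:
        # Sleep off the remainder of the cooldown window before sending.
        await asyncio.sleep(SECONDS_TO_PAUSE_ON_RATE_LIMIT - since_error)

async def main() -> None:
    tracker = Tracker()
    tracker.time_of_last_rate_limit_error = time.time() - 7  # error seen 7s ago
    await wait_out_rate_limit(tracker)  # sleeps ~3s, then returns

asyncio.run(main())
```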

@@ -313,10 +337,10 @@ async def process_requests_from_file(
token_estimate = self.estimate_total_tokens(
retry_request.generic_request.messages
)
-                attempt_number = 1 + self.max_retries - retry_request.attempts_left
-                logger.info(
-                    f"Processing retry for request {retry_request.task_id} "
-                    f"(attempt #{attempt_number} of {self.max_retries}). "
+                attempt_number = self.max_retries - retry_request.attempts_left
+                logger.debug(
+                    f"Retrying request {retry_request.task_id} "
+                    f"(attempt #{attempt_number} of {self.max_retries})"
f"Previous errors: {retry_request.result}"
)

@@ -381,6 +405,9 @@ async def handle_single_request_with_retries(
status_tracker=status_tracker,
)

+            # Allows us to retry on responses that don't match the response format
+            self.prompt_formatter.response_to_response_format(generic_response.response_message)

# Save response in the base class
await self.append_generic_response(generic_response, save_filepath)

@@ -389,18 +416,15 @@ async def handle_single_request_with_retries(
status_tracker.pbar.update(1)

except Exception as e:
-            logger.warning(
-                f"Request {request.task_id} failed with Exception {e}, attempts left {request.attempts_left}"
-            )
status_tracker.num_other_errors += 1
request.result.append(e)

if request.attempts_left > 0:
request.attempts_left -= 1
-                # Add retry queue logging
-                logger.info(
-                    f"Adding request {request.task_id} to retry queue. Will retry in next available slot. "
-                    f"Attempts remaining: {request.attempts_left}"
+                logger.warning(
+                    f"Encountered '{e.__class__.__name__}: {e}' during attempt "
+                    f"{self.max_retries - request.attempts_left} of {self.max_retries} "
+                    f"while processing request {request.task_id}"
)
retry_queue.put_nowait(request)
else:
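The failure path above decrements `attempts_left` and requeues the request; requests that exhaust their attempts fall through to the `else:` branch, which the diff view cuts off here. A self-contained sketch of this queue-based retry loop, with a toy request type standing in for the real API request object:

```python
import asyncio
from dataclasses import dataclass, field

MAX_RETRIES = 3  # stand-in for self.max_retries

@dataclass
class ToyRequest:  # stand-in for the real API request object
    task_id: int
    attempts_left: int = MAX_RETRIES
    result: list = field(default_factory=list)

async def flaky_call(req: ToyRequest) -> str:
    if req.attempts_left > 1:  # fail until the final attempt
        raise RuntimeError("transient error")
    return f"response for task {req.task_id}"

async def process_with_retries(req: ToyRequest) -> None:
    retry_queue: asyncio.Queue = asyncio.Queue()
    retry_queue.put_nowait(req)
    while not retry_queue.empty():
        request = await retry_queue.get()
        try:
            print(await flaky_call(request))
        except Exception as e:
            request.result.append(e)
            if request.attempts_left > 0:
                request.attempts_left -= 1
                retry_queue.put_nowait(request)  # retry in a later slot
            else:
                print(f"task {request.task_id} failed permanently: {request.result}")

asyncio.run(process_with_retries(ToyRequest(task_id=0)))
```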