Skip to content

Commit

Permalink
Merge pull request #256 from bespokelabsai/ryanm/rate-limit-cool-down
Browse files Browse the repository at this point in the history
Cool down when hitting rate limit with online processors
  • Loading branch information
RyanMarten authored Dec 14, 2024
2 parents 3096a29 + 9a3da98 commit 879016e
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DEFAULT_MAX_REQUESTS_PER_MINUTE = 100
DEFAULT_MAX_TOKENS_PER_MINUTE = 100_000
DEFAULT_MAX_RETRIES = 10
SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10


@dataclass
Expand All @@ -46,7 +47,9 @@ class StatusTracker:
max_tokens_per_minute: int = 0
pbar: tqdm = field(default=None)
response_cost: float = 0
time_of_last_rate_limit_error: float = field(default=None)
time_of_last_rate_limit_error: float = field(
default=time.time() - SECONDS_TO_PAUSE_ON_RATE_LIMIT
)

def __str__(self):
return (
Expand Down Expand Up @@ -254,7 +257,7 @@ async def process_requests_from_file(
completed_request_ids = set()
if os.path.exists(save_filepath):
if resume:
logger.debug(f"Resuming progress from existing file: {save_filepath}")
logger.info(f"Resuming progress by reading existing file: {save_filepath}")
logger.debug(
f"Removing all failed requests from {save_filepath} so they can be retried"
)
Expand Down Expand Up @@ -355,6 +358,19 @@ async def process_requests_from_file(
while not status_tracker.has_capacity(token_estimate):
await asyncio.sleep(0.1)

# Wait for rate-limit cool-down if needed
seconds_since_rate_limit_error = (
time.time() - status_tracker.time_of_last_rate_limit_error
)
if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT:
remaining_seconds_to_pause = (
SECONDS_TO_PAUSE_ON_RATE_LIMIT - seconds_since_rate_limit_error
)
await asyncio.sleep(remaining_seconds_to_pause)
logger.warn(
f"Pausing to cool down for {int(remaining_seconds_to_pause)} seconds"
)

# Consume capacity before making request
status_tracker.consume_capacity(token_estimate)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def create_request_files(
return request_files

# Create new requests file
logger.info(f"Preparing request file(s) in {working_dir}")
request_file = f"{working_dir}/requests_0.jsonl"
request_files = [request_file]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from bespokelabs.curator.request_processor.generic_request import GenericRequest
from bespokelabs.curator.request_processor.generic_response import TokenUsage, GenericResponse
from pydantic import BaseModel
import time

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -236,18 +237,29 @@ async def call_single_request(
GenericResponse: The response from LiteLLM
"""
# Get response directly without extra logging
if request.generic_request.response_format:
response, completion_obj = await self.client.chat.completions.create_with_completion(
**request.api_specific_request,
response_model=request.prompt_formatter.response_format,
timeout=60.0,
)
response_message = (
response.model_dump() if hasattr(response, "model_dump") else response
)
else:
completion_obj = await litellm.acompletion(**request.api_specific_request, timeout=60.0)
response_message = completion_obj["choices"][0]["message"]["content"]
try:
if request.generic_request.response_format:
response, completion_obj = (
await self.client.chat.completions.create_with_completion(
**request.api_specific_request,
response_model=request.prompt_formatter.response_format,
timeout=60.0,
)
)
response_message = (
response.model_dump() if hasattr(response, "model_dump") else response
)
else:
completion_obj = await litellm.acompletion(
**request.api_specific_request, timeout=60.0
)
response_message = completion_obj["choices"][0]["message"]["content"]
except litellm.RateLimitError as e:
status_tracker.time_of_last_rate_limit_error = time.time()
status_tracker.num_rate_limit_errors += 1
# because handle_single_request_with_retries will double count otherwise
status_tracker.num_api_errors -= 1
raise e

# Extract token usage
usage = completion_obj.usage if hasattr(completion_obj, "usage") else {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,7 @@ def generic_response_file_from_responses(
)
generic_response = GenericResponse(
response_message=None,
response_errors=[
f"Request {generic_request} failed with status code {raw_response['response']['status_code']}"
],
response_errors=[raw_response["response"]["status_code"]],
raw_response=raw_response,
raw_request=None,
generic_request=generic_request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ async def call_single_request(
status_tracker.time_of_last_rate_limit_error = time.time()
status_tracker.num_rate_limit_errors += 1
status_tracker.num_api_errors -= 1
# because handle_single_request_with_retries will double count otherwise
status_tracker.num_other_errors -= 1
raise Exception(f"API error: {error}")

if response_obj.status != 200:
Expand Down

0 comments on commit 879016e

Please sign in to comment.