
Commit

add rate limit cool down
RyanMarten committed Dec 14, 2024
1 parent 45f1fd6 commit 081ddff
Showing 3 changed files with 40 additions and 12 deletions.
@@ -25,6 +25,7 @@
DEFAULT_MAX_REQUESTS_PER_MINUTE = 100
DEFAULT_MAX_TOKENS_PER_MINUTE = 100_000
DEFAULT_MAX_RETRIES = 10
+SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10


@dataclass
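
For orientation: the cooldown logic added below leans on a status tracker that records the time of the last rate-limit error alongside capacity and error counters. A minimal sketch of the shape this diff assumes (field and method names are taken from the diff; the capacity accounting here is simplified and illustrative):

from dataclasses import dataclass


@dataclass
class StatusTracker:
    # Error counters referenced in this commit
    num_rate_limit_errors: int = 0
    num_api_errors: int = 0
    num_other_errors: int = 0
    # A default of 0 means "long ago", so the first request never pauses
    time_of_last_rate_limit_error: float = 0.0
    # Simplified stand-in for the real token/request accounting
    available_token_capacity: float = 100_000  # DEFAULT_MAX_TOKENS_PER_MINUTE

    def has_capacity(self, token_estimate: int) -> bool:
        return self.available_token_capacity >= token_estimate

    def consume_capacity(self, token_estimate: int) -> None:
        self.available_token_capacity -= token_estimate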
@@ -355,6 +356,19 @@ async def process_requests_from_file(
        while not status_tracker.has_capacity(token_estimate):
            await asyncio.sleep(0.1)

+        # Wait out the rate-limit cooldown if needed
+        seconds_since_rate_limit_error = (
+            time.time() - status_tracker.time_of_last_rate_limit_error
+        )
+        if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT:
+            remaining_seconds_to_pause = (
+                SECONDS_TO_PAUSE_ON_RATE_LIMIT - seconds_since_rate_limit_error
+            )
+            logger.warning(
+                f"Pausing to cool down for {int(remaining_seconds_to_pause)} seconds"
+            )
+            await asyncio.sleep(remaining_seconds_to_pause)
+
        # Consume capacity before making request
        status_tracker.consume_capacity(token_estimate)
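
Because every queued request re-checks the shared timestamp before consuming capacity, a single rate-limit error pauses all pending requests for up to SECONDS_TO_PAUSE_ON_RATE_LIMIT seconds. A quick standalone check of the arithmetic (a sketch, not code from the repository):

import time

SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10

# Suppose a rate-limit error was recorded 4 seconds ago
time_of_last_rate_limit_error = time.time() - 4

seconds_since_rate_limit_error = time.time() - time_of_last_rate_limit_error
if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT:
    remaining = SECONDS_TO_PAUSE_ON_RATE_LIMIT - seconds_since_rate_limit_error
    print(f"would pause for ~{remaining:.0f} seconds")  # prints ~6 seconds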

@@ -13,6 +13,7 @@
from bespokelabs.curator.request_processor.generic_request import GenericRequest
from bespokelabs.curator.request_processor.generic_response import TokenUsage, GenericResponse
from pydantic import BaseModel
+import time

logger = logging.getLogger(__name__)

@@ -236,18 +237,29 @@ async def call_single_request(
            GenericResponse: The response from LiteLLM
        """
        # Get response directly without extra logging
-        if request.generic_request.response_format:
-            response, completion_obj = await self.client.chat.completions.create_with_completion(
-                **request.api_specific_request,
-                response_model=request.prompt_formatter.response_format,
-                timeout=60.0,
-            )
-            response_message = (
-                response.model_dump() if hasattr(response, "model_dump") else response
-            )
-        else:
-            completion_obj = await litellm.acompletion(**request.api_specific_request, timeout=60.0)
-            response_message = completion_obj["choices"][0]["message"]["content"]
+        try:
+            if request.generic_request.response_format:
+                response, completion_obj = (
+                    await self.client.chat.completions.create_with_completion(
+                        **request.api_specific_request,
+                        response_model=request.prompt_formatter.response_format,
+                        timeout=60.0,
+                    )
+                )
+                response_message = (
+                    response.model_dump() if hasattr(response, "model_dump") else response
+                )
+            else:
+                completion_obj = await litellm.acompletion(
+                    **request.api_specific_request, timeout=60.0
+                )
+                response_message = completion_obj["choices"][0]["message"]["content"]
+        except litellm.RateLimitError as e:
+            status_tracker.time_of_last_rate_limit_error = time.time()
+            status_tracker.num_rate_limit_errors += 1
+            # handle_single_request_with_retries would otherwise double count this error
+            status_tracker.num_api_errors -= 1
+            raise e

        # Extract token usage
        usage = completion_obj.usage if hasattr(completion_obj, "usage") else {}
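
The num_api_errors decrement compensates for the retry wrapper named in the comment, which presumably counts every raised exception as an API error before retrying. A hypothetical sketch of that interplay (only the wrapper's name appears in this diff; the body below is an assumption, not the repository's implementation):

import asyncio

DEFAULT_MAX_RETRIES = 10


async def handle_single_request_with_retries(call, status_tracker):
    """Retry an async callable, counting each failure as an API error."""
    for attempt in range(DEFAULT_MAX_RETRIES):
        try:
            return await call()
        except Exception:
            # Counts every failure, including RateLimitError, which the
            # rate-limit handler already counted once; hence its decrement
            status_tracker.num_api_errors += 1
            await asyncio.sleep(2**attempt)  # hypothetical backoff
    raise RuntimeError("all retries exhausted")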
@@ -283,6 +283,8 @@ async def call_single_request(
            status_tracker.time_of_last_rate_limit_error = time.time()
            status_tracker.num_rate_limit_errors += 1
            status_tracker.num_api_errors -= 1
+            # handle_single_request_with_retries would otherwise double count this error
+            status_tracker.num_other_errors -= 1
            raise Exception(f"API error: {error}")

        if response_obj.status != 200:
