Commit 879016e

Merge pull request #256 from bespokelabsai/ryanm/rate-limit-cool-down
Cool down when hitting rate limit with online processors
2 parents: 3096a29 + 9a3da98

File tree

5 files changed: +46 -17 lines changed

src/bespokelabs/curator/request_processor/base_online_request_processor.py

Lines changed: 18 additions & 2 deletions
@@ -25,6 +25,7 @@
 DEFAULT_MAX_REQUESTS_PER_MINUTE = 100
 DEFAULT_MAX_TOKENS_PER_MINUTE = 100_000
 DEFAULT_MAX_RETRIES = 10
+SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10
 
 
 @dataclass
@@ -46,7 +47,9 @@ class StatusTracker:
     max_tokens_per_minute: int = 0
     pbar: tqdm = field(default=None)
     response_cost: float = 0
-    time_of_last_rate_limit_error: float = field(default=None)
+    time_of_last_rate_limit_error: float = field(
+        default=time.time() - SECONDS_TO_PAUSE_ON_RATE_LIMIT
+    )
 
     def __str__(self):
         return (
@@ -254,7 +257,7 @@ async def process_requests_from_file(
         completed_request_ids = set()
         if os.path.exists(save_filepath):
             if resume:
-                logger.debug(f"Resuming progress from existing file: {save_filepath}")
+                logger.info(f"Resuming progress by reading existing file: {save_filepath}")
                 logger.debug(
                     f"Removing all failed requests from {save_filepath} so they can be retried"
                 )
@@ -355,6 +358,19 @@ async def process_requests_from_file(
                 while not status_tracker.has_capacity(token_estimate):
                     await asyncio.sleep(0.1)
 
+                # Wait for rate limits cool down if needed
+                seconds_since_rate_limit_error = (
+                    time.time() - status_tracker.time_of_last_rate_limit_error
+                )
+                if seconds_since_rate_limit_error < SECONDS_TO_PAUSE_ON_RATE_LIMIT:
+                    remaining_seconds_to_pause = (
+                        SECONDS_TO_PAUSE_ON_RATE_LIMIT - seconds_since_rate_limit_error
+                    )
+                    await asyncio.sleep(remaining_seconds_to_pause)
+                    logger.warn(
+                        f"Pausing to cool down for {int(remaining_seconds_to_pause)} seconds"
+                    )
+
                 # Consume capacity before making request
                 status_tracker.consume_capacity(token_estimate)
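
This hunk is the core of the change: before consuming rate-limit capacity, the request loop checks how long ago the last rate-limit error was recorded and sleeps off whatever remains of a fixed cool-down window. Backdating the field's default by SECONDS_TO_PAUSE_ON_RATE_LIMIT means a fresh run is already past the window and never pauses at startup. A minimal, self-contained sketch of the same gate; Tracker and wait_for_cool_down are illustrative names, not the project's API:

import asyncio
import time

SECONDS_TO_PAUSE_ON_RATE_LIMIT = 10


class Tracker:
    def __init__(self):
        # Backdated default: a fresh tracker is already "cooled down".
        self.time_of_last_rate_limit_error = time.time() - SECONDS_TO_PAUSE_ON_RATE_LIMIT


async def wait_for_cool_down(tracker: Tracker) -> None:
    elapsed = time.time() - tracker.time_of_last_rate_limit_error
    if elapsed < SECONDS_TO_PAUSE_ON_RATE_LIMIT:
        remaining = SECONDS_TO_PAUSE_ON_RATE_LIMIT - elapsed
        await asyncio.sleep(remaining)
        print(f"Paused to cool down for {int(remaining)} seconds")


async def main() -> None:
    tracker = Tracker()
    await wait_for_cool_down(tracker)  # returns immediately on a fresh tracker
    tracker.time_of_last_rate_limit_error = time.time()  # simulate a 429
    await wait_for_cool_down(tracker)  # sleeps roughly 10 seconds


asyncio.run(main())

Note that in the committed code the warning is logged after the sleep finishes, so the message reports a pause that has already happened.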

src/bespokelabs/curator/request_processor/base_request_processor.py

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ def create_request_files(
             return request_files
 
         # Create new requests file
+        logger.info(f"Preparing request file(s) in {working_dir}")
         request_file = f"{working_dir}/requests_0.jsonl"
         request_files = [request_file]

src/bespokelabs/curator/request_processor/litellm_online_request_processor.py

Lines changed: 24 additions & 12 deletions
@@ -13,6 +13,7 @@
 from bespokelabs.curator.request_processor.generic_request import GenericRequest
 from bespokelabs.curator.request_processor.generic_response import TokenUsage, GenericResponse
 from pydantic import BaseModel
+import time
 
 logger = logging.getLogger(__name__)
 
@@ -236,18 +237,29 @@ async def call_single_request(
             GenericResponse: The response from LiteLLM
         """
         # Get response directly without extra logging
-        if request.generic_request.response_format:
-            response, completion_obj = await self.client.chat.completions.create_with_completion(
-                **request.api_specific_request,
-                response_model=request.prompt_formatter.response_format,
-                timeout=60.0,
-            )
-            response_message = (
-                response.model_dump() if hasattr(response, "model_dump") else response
-            )
-        else:
-            completion_obj = await litellm.acompletion(**request.api_specific_request, timeout=60.0)
-            response_message = completion_obj["choices"][0]["message"]["content"]
+        try:
+            if request.generic_request.response_format:
+                response, completion_obj = (
+                    await self.client.chat.completions.create_with_completion(
+                        **request.api_specific_request,
+                        response_model=request.prompt_formatter.response_format,
+                        timeout=60.0,
+                    )
+                )
+                response_message = (
+                    response.model_dump() if hasattr(response, "model_dump") else response
+                )
+            else:
+                completion_obj = await litellm.acompletion(
+                    **request.api_specific_request, timeout=60.0
+                )
+                response_message = completion_obj["choices"][0]["message"]["content"]
+        except litellm.RateLimitError as e:
+            status_tracker.time_of_last_rate_limit_error = time.time()
+            status_tracker.num_rate_limit_errors += 1
+            # because handle_single_request_with_retries will double count otherwise
+            status_tracker.num_api_errors -= 1
+            raise e
 
         # Extract token usage
         usage = completion_obj.usage if hasattr(completion_obj, "usage") else {}
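
Both call paths, structured output through create_with_completion and plain litellm.acompletion, now share one try/except so that a litellm.RateLimitError stamps the shared status tracker before propagating to the retry logic. A hedged sketch of that accounting pattern, using a stand-in tracker and exception class rather than the real StatusTracker and litellm.RateLimitError:

import time
from dataclasses import dataclass


@dataclass
class StatusTracker:
    num_rate_limit_errors: int = 0
    num_api_errors: int = 0
    time_of_last_rate_limit_error: float = 0.0


class RateLimitError(Exception):
    """Stand-in for litellm.RateLimitError."""


def call_api() -> str:
    raise RateLimitError("429 Too Many Requests")


tracker = StatusTracker()
try:
    call_api()
except RateLimitError:
    tracker.time_of_last_rate_limit_error = time.time()
    tracker.num_rate_limit_errors += 1
    # The retry wrapper counts every raised exception as an API error,
    # so pre-decrement here to keep the totals from double counting.
    tracker.num_api_errors -= 1
    # The real code re-raises so retries still happen; omitted here so
    # the sketch runs to completion.

print(tracker.num_rate_limit_errors, tracker.num_api_errors)  # 1 -1; the wrapper's +1 nets to 0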

src/bespokelabs/curator/request_processor/openai_batch_request_processor.py

Lines changed: 1 addition & 3 deletions
@@ -200,9 +200,7 @@ def generic_response_file_from_responses(
             )
             generic_response = GenericResponse(
                 response_message=None,
-                response_errors=[
-                    f"Request {generic_request} failed with status code {raw_response['response']['status_code']}"
-                ],
+                response_errors=[raw_response["response"]["status_code"]],
                 raw_response=raw_response,
                 raw_request=None,
                 generic_request=generic_request,
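
The failed-request record now stores just the numeric status code instead of an f-string that embedded the entire generic_request. A tiny illustration of the new field contents; the raw_response shape is inferred from the keys used in the diff:

raw_response = {"response": {"status_code": 429}}  # shape inferred from the diff
response_errors = [raw_response["response"]["status_code"]]
print(response_errors)  # [429], rather than one long formatted message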

src/bespokelabs/curator/request_processor/openai_online_request_processor.py

Lines changed: 2 additions & 0 deletions
@@ -283,6 +283,8 @@ async def call_single_request(
                 status_tracker.time_of_last_rate_limit_error = time.time()
                 status_tracker.num_rate_limit_errors += 1
                 status_tracker.num_api_errors -= 1
+                # because handle_single_request_with_retries will double count otherwise
+                status_tracker.num_other_errors -= 1
                 raise Exception(f"API error: {error}")
 
             if response_obj.status != 200:
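
The OpenAI processor already backed out num_api_errors when stamping a rate-limit error; this hunk backs out num_other_errors too, for the reason the comment gives. A sketch of the double-counting problem, with a hypothetical wrapper standing in for handle_single_request_with_retries, whose actual counting logic this diff does not show:

from dataclasses import dataclass


@dataclass
class StatusTracker:
    num_rate_limit_errors: int = 0
    num_other_errors: int = 0


def handle_with_retries(tracker: StatusTracker, call) -> None:
    try:
        call()
    except Exception:
        # Assumed behavior: the wrapper bumps a generic counter on any failure.
        tracker.num_other_errors += 1


def call_single_request(tracker: StatusTracker) -> None:
    tracker.num_rate_limit_errors += 1
    tracker.num_other_errors -= 1  # pre-decrement: the wrapper will add it back
    raise Exception("API error: rate limit reached")


tracker = StatusTracker()
handle_with_retries(tracker, lambda: call_single_request(tracker))
print(tracker.num_rate_limit_errors, tracker.num_other_errors)  # 1 0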
