
Commit

Merge pull request #251 from bespokelabsai/ryanm/raise-on-failed-requests

Option to raise on failed responses
RyanMarten authored Dec 13, 2024
2 parents ec7f965 + 19089cb commit 092d2d2
Showing 8 changed files with 218 additions and 114 deletions.
14 changes: 14 additions & 0 deletions src/bespokelabs/curator/file_utilities.py
@@ -0,0 +1,14 @@
# https://stackoverflow.com/questions/845058/how-to-get-the-line-count-of-a-large-file-cheaply-in-python
# https://stackoverflow.com/a/68385697
def _file_gen(reader):
b = reader(1024 * 1024)
while b:
yield b
b = reader(1024 * 1024)


# Instead of counting lines each time, we could store a metadata file that records the number of requests in each file
def count_lines(filename):
    with open(filename, "rb") as f:
        f_gen = _file_gen(f.raw.read)
        return sum(buf.count(b"\n") for buf in f_gen)
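A minimal usage sketch for the helper above (the file name is hypothetical):

from bespokelabs.curator.file_utilities import count_lines

# Count the newline-terminated request lines in a JSONL batch file.
n_requests = count_lines("requests_0.jsonl")
print(f"{n_requests} requests found")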
83 changes: 83 additions & 0 deletions src/bespokelabs/curator/llm/llm.py
@@ -37,6 +37,79 @@
class LLM:
"""Interface for prompting LLMs."""

def __init__(
self,
model_name: str,
prompt_func: Callable[[Union[Dict[str, Any], BaseModel]], Dict[str, str]],
parse_func: Optional[
Callable[
[
_DictOrBaseModel,
_DictOrBaseModel,
],
T,
]
] = None,
response_format: Optional[Type[BaseModel]] = None,
backend: Optional[str] = None,
max_requests_per_minute: Optional[int] = None,
max_tokens_per_minute: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
max_retries: Optional[int] = None,
require_all_responses: Optional[bool] = None,
):
"""Initialize a LLM.
Args:
model_name: The name of the LLM to use
prompt_func: A function that takes a single row
and returns either a string (assumed to be a user prompt) or messages list
parse_func: A function that takes the input row and
response object and returns the parsed output
response_format: A Pydantic model specifying the
response format from the LLM.
backend: The backend to use ("openai" or "litellm"). If None, will be auto-determined
max_requests_per_minute: Maximum requests per minute (not supported in batch mode)
max_tokens_per_minute: Maximum tokens per minute (not supported in batch mode)
temperature: The temperature to use for the LLM
top_p: The top_p to use for the LLM
presence_penalty: The presence_penalty to use for the LLM
frequency_penalty: The frequency_penalty to use for the LLM
max_retries: The maximum number of retries to use for the LLM. If 0, will only try a request once.
require_all_responses: Whether to raise an error if any request fails or is missing a response
"""
self.prompt_formatter = PromptFormatter(
model_name, prompt_func, parse_func, response_format
)

# Initialize context manager state
self._batch_config = None
self._original_request_processor = None

# Store model parameters
self.temperature = temperature
self.top_p = top_p
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.model_name = model_name

# Auto-determine backend if not specified
if backend is not None:
self.backend = backend
else:
self.backend = self._determine_backend(model_name, response_format)

# Initialize request processor
self._setup_request_processor(
max_requests_per_minute=max_requests_per_minute,
max_tokens_per_minute=max_tokens_per_minute,
max_retries=max_retries,
require_all_responses=require_all_responses,
)

@staticmethod
def _determine_backend(
model_name: str, response_format: Optional[Type[BaseModel]] = None
@@ -91,6 +164,8 @@ def __init__(
top_p: Optional[float] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
max_retries: Optional[int] = None,
require_all_responses: Optional[bool] = None,
):
"""Initialize a LLM.
@@ -112,6 +187,8 @@ def __init__(
top_p: The top_p to use for the LLM, only used if batch is False
presence_penalty: The presence_penalty to use for the LLM, only used if batch is False
frequency_penalty: The frequency_penalty to use for the LLM, only used if batch is False
max_retries: The maximum number of retries to use for the LLM
require_all_responses: Whether to raise an error if any request fails or is missing a response
"""
self.prompt_formatter = PromptFormatter(
model_name, prompt_func, parse_func, response_format
@@ -147,6 +224,8 @@ def __init__(
frequency_penalty=frequency_penalty,
delete_successful_batch_files=delete_successful_batch_files,
delete_failed_batch_files=delete_failed_batch_files,
max_retries=max_retries,
require_all_responses=require_all_responses,
)
else:
if batch_size is not None:
@@ -161,6 +240,8 @@
frequency_penalty=frequency_penalty,
max_requests_per_minute=max_requests_per_minute,
max_tokens_per_minute=max_tokens_per_minute,
max_retries=max_retries,
require_all_responses=require_all_responses,
)
elif self.backend == "litellm":
if batch:
@@ -175,6 +256,8 @@
frequency_penalty=frequency_penalty,
max_requests_per_minute=max_requests_per_minute,
max_tokens_per_minute=max_tokens_per_minute,
max_retries=max_retries,
require_all_responses=require_all_responses,
)
else:
raise ValueError(f"Unknown backend: {self.backend}")
@@ -22,8 +22,9 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

DEFAULT_REQUESTS_PER_MINUTE = 100
DEFAULT_TOKENS_PER_MINUTE = 100_000
DEFAULT_MAX_REQUESTS_PER_MINUTE = 100
DEFAULT_MAX_TOKENS_PER_MINUTE = 100_000
DEFAULT_MAX_RETRIES = 10


@dataclass
@@ -124,53 +125,58 @@ def __init__(
frequency_penalty: Optional[float] = None,
max_requests_per_minute: Optional[int] = None,
max_tokens_per_minute: Optional[int] = None,
require_all_responses: Optional[bool] = None,
max_retries: Optional[int] = None,
):
super().__init__(batch_size=None)
super().__init__(batch_size=None, require_all_responses=require_all_responses)
self.model: str = model
self.temperature: float | None = temperature
self.top_p: float | None = top_p
self.presence_penalty: float | None = presence_penalty
self.frequency_penalty: float | None = frequency_penalty
self.prompt_formatter: Optional[PromptFormatter] = None
self.max_requests_per_minute: Optional[int] = max_requests_per_minute
self.max_tokens_per_minute: Optional[int] = max_tokens_per_minute
self.DEFAULT_MAX_REQUESTS_PER_MINUTE = DEFAULT_REQUESTS_PER_MINUTE
self.DEFAULT_MAX_TOKENS_PER_MINUTE = DEFAULT_TOKENS_PER_MINUTE

def get_rate_limit(self, name, header_value):
"""Uses manual values if set, otherwise uses headers if available, and if not available uses defaults."""
manual_value = getattr(self, name)
default_value = getattr(self, f"DEFAULT_{name.upper()}")
if manual_value is not None:
logger.info(f"Manually set {name} to {manual_value}")
return manual_value
elif header_value != 0:
logger.info(f"Automatically set {name} to {header_value}")
return header_value
self.manual_max_requests_per_minute: Optional[int] = max_requests_per_minute
self.manual_max_tokens_per_minute: Optional[int] = max_tokens_per_minute
if max_retries is None:
self.max_retries = DEFAULT_MAX_RETRIES
else:
self.max_retries = max_retries

@property
def max_requests_per_minute(self) -> int:
if self.manual_max_requests_per_minute:
logger.info(
f"Manually set max_requests_per_minute to {self.manual_max_requests_per_minute}"
)
return self.manual_max_requests_per_minute
elif self.header_based_max_requests_per_minute:
logger.info(
f"Automatically set max_requests_per_minute to {self.header_based_max_requests_per_minute}"
)
return self.header_based_max_requests_per_minute
else:
logger.warning(
f"No manual {name} set, and headers based detection failed, using default value of {default_value}"
f"No manual max_requests_per_minute set, and headers based detection failed, using default value of {DEFAULT_MAX_REQUESTS_PER_MINUTE}"
)
return default_value

def get_rate_limits(self) -> dict:
"""Get rate limits for the API. Returns a dictionary with max_requests_per_minute and max_tokens_per_minute"""

# Get values from headers
header_based_rate_limits = self.get_header_based_rate_limits()
header_tpm = header_based_rate_limits["max_tokens_per_minute"]
header_rpm = header_based_rate_limits["max_requests_per_minute"]

# Determine final rate limit
tpm = self.get_rate_limit("max_tokens_per_minute", header_tpm)
rpm = self.get_rate_limit("max_requests_per_minute", header_rpm)

return {"max_requests_per_minute": rpm, "max_tokens_per_minute": tpm}
return DEFAULT_MAX_REQUESTS_PER_MINUTE

@abstractmethod
def get_header_based_rate_limits(self) -> dict:
"""Get rate limits for the API from headers. Returns a dictionary with max_requests_per_minute and max_tokens_per_minute"""
pass
@property
def max_tokens_per_minute(self) -> int:
if self.manual_max_tokens_per_minute:
logger.info(
f"Manually set max_tokens_per_minute to {self.manual_max_tokens_per_minute}"
)
return self.manual_max_tokens_per_minute
elif self.header_based_max_tokens_per_minute:
logger.info(
f"Automatically set max_tokens_per_minute to {self.header_based_max_tokens_per_minute}"
)
return self.header_based_max_tokens_per_minute
else:
logger.warning(
f"No manual max_tokens_per_minute set, and headers based detection failed, using default value of {DEFAULT_MAX_TOKENS_PER_MINUTE}"
)
return DEFAULT_MAX_TOKENS_PER_MINUTE
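Both properties above follow the same precedence rule; here is a standalone sketch of that rule (the helper name and the numbers in the assertions are illustrative, not part of the PR):

from typing import Optional

def resolve_limit(manual: Optional[int], header_based: Optional[int], default: int) -> int:
    # Precedence: a manually supplied value wins, then a header-detected value, then the default.
    if manual:
        return manual
    if header_based:
        return header_based
    return default

assert resolve_limit(200, 5_000, 100) == 200     # manual override wins
assert resolve_limit(None, 5_000, 100) == 5_000  # header-detected value used next
assert resolve_limit(None, 0, 100) == 100        # a header value of 0 means detection failed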

@abstractmethod
def estimate_total_tokens(self, messages: list) -> int:
@@ -213,7 +219,7 @@ def run(
self.process_requests_from_file(
generic_request_filepath=request_file,
save_filepath=response_file,
max_attempts=5,
max_attempts=self.max_retries,
resume=True,
)
)
@@ -235,9 +241,8 @@ async def process_requests_from_file(
status_tracker = StatusTracker()

# Get rate limits
rate_limits = self.get_rate_limits()
status_tracker.max_requests_per_minute = rate_limits["max_requests_per_minute"]
status_tracker.max_tokens_per_minute = rate_limits["max_tokens_per_minute"]
status_tracker.max_requests_per_minute = self.max_requests_per_minute
status_tracker.max_tokens_per_minute = self.max_tokens_per_minute

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(
@@ -380,10 +385,10 @@ async def process_requests_from_file(
token_estimate = self.estimate_total_tokens(
retry_request.generic_request.messages
)
attempt_number = 6 - retry_request.attempts_left
attempt_number = 1 + self.max_retries - retry_request.attempts_left
logger.info(
f"Processing retry for request {retry_request.task_id} "
f"(attempt #{attempt_number} of 5). "
f"(attempt #{attempt_number} of {self.max_retries}). "
f"Previous errors: {retry_request.result}"
)
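A quick numeric check of the attempt arithmetic above (the values are hypothetical):

max_retries = 10     # DEFAULT_MAX_RETRIES
attempts_left = 8    # a request that has already failed twice
attempt_number = 1 + max_retries - attempts_left
assert attempt_number == 3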

@@ -472,7 +477,7 @@ async def handle_single_request_with_retries(
retry_queue.put_nowait(request)
else:
logger.error(
f"Request {request.task_id} failed permanently after exhausting all 5 retry attempts. "
f"Request {request.task_id} failed permanently after exhausting all {self.max_retries} retry attempts. "
f"Errors: {[str(e) for e in request.result]}"
)
generic_response = GenericResponse(
46 changes: 29 additions & 17 deletions src/bespokelabs/curator/request_processor/base_request_processor.py
@@ -14,6 +14,7 @@
from datasets.arrow_writer import ArrowWriter
from pydantic import BaseModel, ValidationError

from bespokelabs.curator.file_utilities import count_lines
from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
from bespokelabs.curator.request_processor.generic_request import GenericRequest
@@ -29,8 +30,9 @@ class BaseRequestProcessor(ABC):
Base class for all request processors.
"""

def __init__(self, batch_size: Optional[int] = None):
def __init__(self, batch_size: Optional[int] = None, require_all_responses: bool = True):
self.batch_size = batch_size
self.require_all_responses = require_all_responses
# Increase the number of open file descriptors to avoid "Too many open files" errors
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
desired_limit = min(10_000_000, hard)
@@ -39,16 +41,6 @@ def __init__(self, batch_size: Optional[int] = None):
)
resource.setrlimit(resource.RLIMIT_NOFILE, (desired_limit, hard))

@abstractmethod
def get_rate_limits(self) -> dict:
"""
Returns the rate limits for the API.
Returns:
dict: A dictionary containing the rate limit information.
"""
pass

@abstractmethod
def create_api_specific_request(self, generic_request: GenericRequest) -> dict:
"""
@@ -216,9 +208,6 @@ def create_dataset_files(
Returns:
Dataset: Completed dataset
"""
total_responses_count = 0
failed_responses_count = 0

responses_files = glob.glob(f"{working_dir}/responses_*.jsonl")
if len(responses_files) == 0:
raise ValueError(f"No responses files found in {working_dir}")
@@ -230,6 +219,8 @@
)

# Process all response files
total_responses_count = 0
failed_responses_count = 0
dataset_file = f"{working_dir}/{parse_func_hash}.arrow"
with ArrowWriter(path=dataset_file) as writer:
for responses_file in responses_files:
@@ -319,14 +310,35 @@ def create_dataset_files(

writer.write(row)

logger.info(f"Read {total_responses_count} responses, {failed_responses_count} failed")
logger.info("Finalizing writer")
writer.finalize()

logger.info(f"Read {total_responses_count} responses.")
if failed_responses_count == total_responses_count:
os.remove(dataset_file)
raise ValueError("All requests failed")

logger.info("Finalizing writer")
if failed_responses_count > 0:
logger.warning(f"{failed_responses_count} requests failed.")
if self.require_all_responses:
os.remove(dataset_file)
raise ValueError(f"Some requests failed and require_all_responses is True")

writer.finalize()
# Check that the number of responses matches the number of requests
request_files = glob.glob(f"{working_dir}/requests_*.jsonl")
n_requests = 0
for request_file in request_files:
n_requests += count_lines(request_file)

if n_requests != total_responses_count:
logger.warning(
f"{n_requests - total_responses_count} requests do not have responses. n_requests is {n_requests} and n_responses is {total_responses_count}"
)
if self.require_all_responses:
os.remove(dataset_file)
raise ValueError(
    "Some requests do not have responses and require_all_responses is True."
)

return Dataset.from_file(dataset_file)
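A condensed, standalone sketch of the validation flow added above (illustrative only, not code from the PR):

def enforce_response_completeness(
    n_requests: int, n_responses: int, n_failed: int, require_all_responses: bool
) -> None:
    # A run in which every response failed always raises.
    if n_failed == n_responses:
        raise ValueError("All requests failed")
    # Partial failures and missing responses raise only when require_all_responses is True.
    if require_all_responses and n_failed > 0:
        raise ValueError("Some requests failed and require_all_responses is True")
    if require_all_responses and n_requests != n_responses:
        raise ValueError("Some requests do not have responses and require_all_responses is True")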

