Merge pull request #195 from bespokelabsai/ryanm/delete-successful-batches

Delete input and output files for successful batches
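The two new options are exposed on `Prompter`, so callers can opt out of cleanup for successful batches or keep failed batches' files around for debugging. A minimal usage sketch follows; only `delete_successful_batch_files` and `delete_failed_batch_files` come from this PR, while the remaining argument names are assumptions for illustration:

```python
from bespokelabs import curator

# Hedged sketch: argument names other than the two new flags are assumed,
# not taken verbatim from this diff.
prompter = curator.Prompter(
    model_name="gpt-4o-mini",                    # assumed argument name
    prompt_func=lambda row: row["instruction"],  # assumed argument name
    batch=True,                                  # assumed batch-mode toggle
    batch_size=1_000,
    delete_successful_batch_files=True,  # new: clean up input/output files (default)
    delete_failed_batch_files=False,     # new: keep failed batches' files for debugging (default)
)
```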
RyanMarten authored Dec 4, 2024
2 parents 84eae2d + 919c9f7 commit 44d5b72
Showing 2 changed files with 46 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/bespokelabs/curator/prompter/prompter.py
@@ -51,6 +51,8 @@ def __init__(
         top_p: Optional[float] = None,
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
+        delete_successful_batch_files: bool = True,
+        delete_failed_batch_files: bool = False,  # To allow users to debug failed batches
     ):
         """Initialize a Prompter.
@@ -99,6 +101,8 @@ def __init__(
                 top_p=top_p,
                 presence_penalty=presence_penalty,
                 frequency_penalty=frequency_penalty,
+                delete_successful_batch_files=delete_successful_batch_files,
+                delete_failed_batch_files=delete_failed_batch_files,
             )
         else:
             if batch_size is not None:
42 changes: 42 additions & 0 deletions (second changed file)
@@ -40,6 +40,8 @@ def __init__(
         self,
         batch_size: int,
         model: str,
+        delete_successful_batch_files: bool,
+        delete_failed_batch_files: bool,
         temperature: float | None = None,
         top_p: float | None = None,
         check_interval: int = 10,
@@ -63,6 +65,8 @@ def __init__(
         self.presence_penalty: float | None = presence_penalty
         self.frequency_penalty: float | None = frequency_penalty
         self._file_lock = asyncio.Lock()
+        self.delete_successful_batch_files: bool = delete_successful_batch_files
+        self.delete_failed_batch_files: bool = delete_failed_batch_files
 
     def get_rate_limits(self) -> dict:
         """
@@ -324,6 +328,8 @@ async def watch_batches():
                 check_interval=self.check_interval,
                 n_submitted_requests=n_submitted_requests,
                 prompt_formatter=prompt_formatter,
+                delete_successful_batch_files=self.delete_successful_batch_files,
+                delete_failed_batch_files=self.delete_failed_batch_files,
             )
             await batch_watcher.watch()
             await batch_watcher.close_client()
@@ -362,6 +368,8 @@ def __init__(
         check_interval: int,
         prompt_formatter: PromptFormatter,
         n_submitted_requests: int,
+        delete_successful_batch_files: bool,
+        delete_failed_batch_files: bool,
     ) -> None:
         """Initialize BatchWatcher with batch objects file and check interval.
@@ -386,6 +394,8 @@ def __init__(
         self.remaining_batch_ids = set(self.batch_ids)
         self.prompt_formatter = prompt_formatter
         self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_BATCH_OPERATIONS)
+        self.delete_successful_batch_files = delete_successful_batch_files
+        self.delete_failed_batch_files = delete_failed_batch_files
 
     async def close_client(self):
         await self.client.close()
@@ -521,9 +531,28 @@ async def watch(self) -> None:
                 "Please check the logs above and https://platform.openai.com/batches for errors."
             )
 
+    async def delete_file(self, file_id: str, semaphore: asyncio.Semaphore):
+        """
+        Delete a file by its ID.
+
+        Args:
+            file_id (str): The ID of the file to delete.
+            semaphore (asyncio.Semaphore): Semaphore to limit concurrent operations.
+        """
+        async with semaphore:
+            delete_response = await self.client.files.delete(file_id)
+            if delete_response.deleted:
+                logger.info(f"Deleted file {file_id}")
+            else:
+                logger.warning(f"Failed to delete file {file_id}")
+
     async def download_batch_to_generic_responses_file(self, batch: Batch) -> str | None:
         """Download the result of a completed batch to file.
 
+        To prevent an accumulation of files, we delete the batch input and output files.
+        Without this, the 100 GB limit for files would be reached very quickly.
+        The user can control this behavior with delete_successful_batch_files
+        and delete_failed_batch_files.
+
         Args:
             batch: The batch object to download results from.
@@ -537,16 +566,23 @@ async def download_batch_to_generic_responses_file(self, batch: Batch) -> str |
         elif batch.status == "failed" and batch.error_file_id:
             file_content = await self.client.files.content(batch.error_file_id)
             logger.warning(f"Batch {batch.id} failed.\nErrors will be parsed below.")
+            if self.delete_failed_batch_files:
+                await self.delete_file(batch.input_file_id, self.semaphore)
+                await self.delete_file(batch.error_file_id, self.semaphore)
         elif batch.status == "failed" and not batch.error_file_id:
             errors = "\n".join([str(error) for error in batch.errors.data])
             logger.error(
                 f"Batch {batch.id} failed and likely failed validation. "
                 f"Batch errors: {errors}. "
                 f"Check https://platform.openai.com/batches/{batch.id} for more details."
             )
+            if self.delete_failed_batch_files:
+                await self.delete_file(batch.input_file_id, self.semaphore)
             return None
         elif batch.status == "cancelled" or batch.status == "expired":
             logger.warning(f"Batch {batch.id} was cancelled or expired")
+            if self.delete_failed_batch_files:
+                await self.delete_file(batch.input_file_id, self.semaphore)
             return None
 
         # Naming is consistent with the request file (e.g. requests_0.jsonl -> responses_0.jsonl)
@@ -627,5 +663,11 @@ async def download_batch_to_generic_responses_file(self, batch: Batch) -> str |
                     response_cost=cost,
                 )
                 f.write(json.dumps(generic_response.model_dump(), default=str) + "\n")
+
         logger.info(f"Batch {batch.id} written to {response_file}")
+
+        if self.delete_successful_batch_files:
+            await self.delete_file(batch.input_file_id, self.semaphore)
+            await self.delete_file(batch.output_file_id, self.semaphore)
+
         return response_file
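Note that every cleanup call funnels through the watcher's shared `asyncio.Semaphore(MAX_CONCURRENT_BATCH_OPERATIONS)`, so deleting files for many finished batches at once cannot flood the Files API. A minimal standalone sketch of that throttling pattern, where the `client` and the constant's value are stand-ins rather than the actual SDK objects used above:

```python
import asyncio

MAX_CONCURRENT_BATCH_OPERATIONS = 100  # stand-in value for illustration

async def delete_files_with_limit(client, file_ids: list[str]) -> None:
    """Delete many files, but never with more than N requests in flight."""
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_BATCH_OPERATIONS)

    async def delete_one(file_id: str) -> None:
        async with semaphore:  # acquire a slot before issuing the request
            await client.files.delete(file_id)

    # Launch all deletions at once; the semaphore, not gather, does the throttling.
    await asyncio.gather(*(delete_one(fid) for fid in file_ids))
```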
