@@ -40,6 +40,8 @@ def __init__(
40
40
self ,
41
41
batch_size : int ,
42
42
model : str ,
43
+ delete_successful_batch_files : bool ,
44
+ delete_failed_batch_files : bool ,
43
45
temperature : float | None = None ,
44
46
top_p : float | None = None ,
45
47
check_interval : int = 10 ,
@@ -63,6 +65,8 @@ def __init__(
63
65
self .presence_penalty : float | None = presence_penalty
64
66
self .frequency_penalty : float | None = frequency_penalty
65
67
self ._file_lock = asyncio .Lock ()
68
+ self .delete_successful_batch_files : bool = delete_successful_batch_files
69
+ self .delete_failed_batch_files : bool = delete_failed_batch_files
66
70
67
71
def get_rate_limits (self ) -> dict :
68
72
"""
@@ -324,6 +328,8 @@ async def watch_batches():
324
328
check_interval = self .check_interval ,
325
329
n_submitted_requests = n_submitted_requests ,
326
330
prompt_formatter = prompt_formatter ,
331
+ delete_successful_batch_files = self .delete_successful_batch_files ,
332
+ delete_failed_batch_files = self .delete_failed_batch_files ,
327
333
)
328
334
await batch_watcher .watch ()
329
335
await batch_watcher .close_client ()
@@ -362,6 +368,8 @@ def __init__(
362
368
check_interval : int ,
363
369
prompt_formatter : PromptFormatter ,
364
370
n_submitted_requests : int ,
371
+ delete_successful_batch_files : bool ,
372
+ delete_failed_batch_files : bool ,
365
373
) -> None :
366
374
"""Initialize BatchWatcher with batch objects file and check interval.
367
375
@@ -386,6 +394,8 @@ def __init__(
386
394
self .remaining_batch_ids = set (self .batch_ids )
387
395
self .prompt_formatter = prompt_formatter
388
396
self .semaphore = asyncio .Semaphore (MAX_CONCURRENT_BATCH_OPERATIONS )
397
+ self .delete_successful_batch_files = delete_successful_batch_files
398
+ self .delete_failed_batch_files = delete_failed_batch_files
389
399
390
400
async def close_client (self ):
391
401
await self .client .close ()
@@ -521,9 +531,28 @@ async def watch(self) -> None:
521
531
"Please check the logs above and https://platform.openai.com/batches for errors."
522
532
)
523
533
534
async def delete_file(self, file_id: str, semaphore: asyncio.Semaphore):
    """
    Ask the API client to remove a single file and log the outcome.

    A request that completes but reports the file as not deleted is logged
    as a warning rather than raised.

    Args:
        file_id (str): Identifier of the file to remove.
        semaphore (asyncio.Semaphore): Acquired before the delete request is
            issued, bounding how many such operations run concurrently.
    """
    async with semaphore:
        outcome = await self.client.files.delete(file_id)
        # Guard clause: report the unsuccessful case first, then fall through to success.
        if not outcome.deleted:
            logger.warning(f"Failed to delete file { file_id } ")
            return
        logger.info(f"Deleted file { file_id } ")
548
+
524
549
async def download_batch_to_generic_responses_file (self , batch : Batch ) -> str | None :
525
550
"""Download the result of a completed batch to file.
526
551
552
+ To prevent an accumulation of files, we delete the batch input and output files.
553
+ Without this, the 100GB limit for files will be reached very quickly.
554
+ The user can control this behavior with `delete_successful_batch_files` and `delete_failed_batch_files`.
555
+
527
556
Args:
528
557
batch: The batch object to download results from.
529
558
@@ -537,16 +566,23 @@ async def download_batch_to_generic_responses_file(self, batch: Batch) -> str |
537
566
elif batch .status == "failed" and batch .error_file_id :
538
567
file_content = await self .client .files .content (batch .error_file_id )
539
568
logger .warning (f"Batch { batch .id } failed\n . Errors will be parsed below." )
569
+ if self .delete_failed_batch_files :
570
+ await self .delete_file (batch .input_file_id , self .semaphore )
571
+ await self .delete_file (batch .error_file_id , self .semaphore )
540
572
elif batch .status == "failed" and not batch .error_file_id :
541
573
errors = "\n " .join ([str (error ) for error in batch .errors .data ])
542
574
logger .error (
543
575
f"Batch { batch .id } failed and likely failed validation. "
544
576
f"Batch errors: { errors } . "
545
577
f"Check https://platform.openai.com/batches/{ batch .id } for more details."
546
578
)
579
+ if self .delete_failed_batch_files :
580
+ await self .delete_file (batch .input_file_id , self .semaphore )
547
581
return None
548
582
elif batch .status == "cancelled" or batch .status == "expired" :
549
583
logger .warning (f"Batch { batch .id } was cancelled or expired" )
584
+ if self .delete_failed_batch_files :
585
+ await self .delete_file (batch .input_file_id , self .semaphore )
550
586
return None
551
587
552
588
# Naming is consistent with the request file (e.g. requests_0.jsonl -> responses_0.jsonl)
@@ -627,5 +663,11 @@ async def download_batch_to_generic_responses_file(self, batch: Batch) -> str |
627
663
response_cost = cost ,
628
664
)
629
665
f .write (json .dumps (generic_response .model_dump (), default = str ) + "\n " )
666
+
630
667
logger .info (f"Batch { batch .id } written to { response_file } " )
668
+
669
+ if self .delete_successful_batch_files :
670
+ await self .delete_file (batch .input_file_id , self .semaphore )
671
+ await self .delete_file (batch .output_file_id , self .semaphore )
672
+
631
673
return response_file
0 commit comments