
Commit 7cfbfef

fix batch pbar on resume
RyanMarten committed Dec 4, 2024
1 parent 84eae2d commit 7cfbfef
Showing 1 changed file with 33 additions and 23 deletions.
@@ -383,6 +383,7 @@ def __init__(
         self.tracker = BatchStatusTracker()
         self.tracker.n_submitted_batches = len(self.batch_ids)
         self.tracker.n_submitted_requests = n_submitted_requests
+        self.pbar = None
         self.remaining_batch_ids = set(self.batch_ids)
         self.prompt_formatter = prompt_formatter
         self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_BATCH_OPERATIONS)
@@ -447,33 +448,42 @@ async def check_batch_status(self, batch_id: str) -> Batch | None:
             self.tracker.n_failed_in_progress_requests += n_failed_requests
         return None
 
+    async def resume(self, batch_id: str, all_response_files: set[str]):
+        request_file = self.batch_id_to_request_file_name[batch_id]
+        request_file_idx = request_file.split("/")[-1].split("_", 1)[1]
+        response_file = f"{self.working_dir}/responses_{request_file_idx}"
+        if response_file in all_response_files:
+            logger.info(
+                f"File {response_file} found for batch {batch_id}, skipping status check and download."
+            )
+            self.remaining_batch_ids.remove(batch_id)
+            # Update progress bar based on completed responses found in files
+            async with aiofiles.open(response_file, "r") as f:
+                n_completed = 0
+                async for _ in f:
+                    n_completed += 1
+            self.tracker.n_completed_returned_requests += n_completed
+            self.tracker.n_returned_batches += 1
+            self.tracker.n_completed_batches += 1  # assuming completed instead of failed
+
     async def watch(self) -> None:
         """Monitor the status of batches until all are completed (includes successfully, failed, expired or cancelled)."""
         # progress bar for completed requests
-        pbar = tqdm(
+        self.pbar = tqdm(
             total=self.tracker.n_submitted_requests,
             desc="Completed OpenAI requests in batches",
             unit="request",
         )
 
         # resume from already downloaded batches
-        response_files_found = 0
         all_response_files = set(glob.glob(f"{self.working_dir}/responses_*.jsonl"))
-        for batch_id, request_file in self.batch_id_to_request_file_name.items():
-            request_file_idx = request_file.split("/")[-1].split("_", 1)[1]
-            response_file = f"{self.working_dir}/responses_{request_file_idx}"
-            if response_file in all_response_files:
-                logger.info(
-                    f"File {response_file} found for batch {batch_id}, skipping status check and download."
-                )
-                self.remaining_batch_ids.remove(batch_id)
-                response_files_found += 1
-        if response_files_found > 0:
+        tasks = [self.resume(batch_id, all_response_files) for batch_id in self.remaining_batch_ids]
+        await asyncio.gather(*tasks)
+
+        if self.tracker.n_returned_batches > 0:
             logger.info(
-                f"Found {response_files_found} out of {self.tracker.n_submitted_batches} completed batches, resuming polling for the remaining {len(self.remaining_batch_ids)} batches."
+                f"Found {self.tracker.n_returned_batches} out of {self.tracker.n_submitted_batches} completed batches, resuming polling for the remaining {len(self.remaining_batch_ids)} batches."
             )
-        self.tracker.n_completed_batches += response_files_found  # here we are assuming they are completed, but they also could have failed
-        self.tracker.n_returned_batches += response_files_found
         all_response_files = list(all_response_files)
 
         # loop until all batches have been returned
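
The new resume path fans out one coroutine per batch and recovers progress by counting lines in each batch's JSONL response file (one line per returned request). A minimal standalone sketch of that pattern, with hypothetical file names and helpers, not the library's API:

    import asyncio

    import aiofiles  # async file I/O, as used in the diff


    async def count_completed(response_file: str) -> int:
        # One JSONL line per returned request, so line count == completed requests.
        n = 0
        async with aiofiles.open(response_file, "r") as f:
            async for _ in f:
                n += 1
        return n


    async def resume_progress(response_files: list[str]) -> int:
        # Check all previously downloaded response files concurrently,
        # mirroring the commit's asyncio.gather over per-batch resume() calls.
        counts = await asyncio.gather(*(count_completed(p) for p in response_files))
        return sum(counts)


    # Example: asyncio.run(resume_progress(["responses_0.jsonl", "responses_1.jsonl"]))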
@@ -491,12 +501,12 @@ async def watch(self) -> None:
             batches_to_download = filter(None, batches_to_download)
 
             # update progress bar
-            pbar.n = 0
-            pbar.n += self.tracker.n_completed_returned_requests
-            pbar.n += self.tracker.n_failed_returned_requests
-            pbar.n += self.tracker.n_completed_in_progress_requests
-            pbar.n += self.tracker.n_failed_in_progress_requests
-            pbar.refresh()
+            self.pbar.n = 0
+            self.pbar.n += self.tracker.n_completed_returned_requests
+            self.pbar.n += self.tracker.n_failed_returned_requests
+            self.pbar.n += self.tracker.n_completed_in_progress_requests
+            self.pbar.n += self.tracker.n_failed_in_progress_requests
+            self.pbar.refresh()
 
             download_tasks = [
                 self.download_batch_to_generic_responses_file(batch)
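
Because the counts are recomputed wholesale from the tracker on every poll, the bar is driven by assigning tqdm's `n` attribute and calling `refresh()`, rather than accumulating deltas with `update()`. A small sketch of that idiom, with hypothetical numbers:

    from tqdm import tqdm

    pbar = tqdm(total=100, desc="Completed OpenAI requests in batches", unit="request")

    # Reset and recompute the absolute position each poll so re-derived
    # totals never double-count.
    for completed, failed in [(10, 0), (45, 2), (98, 2)]:
        pbar.n = 0
        pbar.n += completed
        pbar.n += failed
        pbar.refresh()

    pbar.close()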
@@ -508,12 +518,12 @@ async def watch(self) -> None:
             if self.tracker.n_returned_batches < self.tracker.n_submitted_batches:
                 logger.debug(
                     f"Batches returned: {self.tracker.n_returned_batches}/{self.tracker.n_submitted_batches} "
-                    f"Requests completed: {pbar.n}/{self.tracker.n_submitted_requests}"
+                    f"Requests completed: {self.pbar.n}/{self.tracker.n_submitted_requests}"
                 )
                 logger.debug(f"Sleeping for {self.check_interval} seconds...")
                 await asyncio.sleep(self.check_interval)
 
-        pbar.close()
+        self.pbar.close()
         response_files = filter(None, all_response_files)
         if self.tracker.n_completed_batches == 0 or not response_files:
             raise RuntimeError(
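
For reference, the request-to-response file mapping that `resume` relies on: the index suffix of `requests_<idx>.jsonl` is reused to locate `responses_<idx>.jsonl` in the same working directory. A tiny illustration with a hypothetical path:

    request_file = "/tmp/curator/requests_0.jsonl"  # hypothetical path
    request_file_idx = request_file.split("/")[-1].split("_", 1)[1]  # -> "0.jsonl"
    response_file = f"/tmp/curator/responses_{request_file_idx}"
    assert response_file == "/tmp/curator/responses_0.jsonl"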
