
Merge pull request #101 from bespokelabsai/trung/async-cleanup
Explicitly close AsyncClient to avoid "asyncio event loop is closed" errors
CharlieJCJ authored Nov 14, 2024
2 parents 3fc7b30 + 25c1ffe commit 4de0c76
Showing 1 changed file with 19 additions and 11 deletions.
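
The root cause: AsyncOpenAI wraps an httpx.AsyncClient, and if nothing closes that client before the event loop shuts down, its cleanup is deferred to garbage collection, by which point asyncio.run() has already closed the loop. A minimal, self-contained sketch of the failure mode and fix, using httpx directly (the URL and function name are illustrative, not from this repo):

```python
import asyncio
import httpx

async def fetch(close_explicitly: bool) -> int:
    client = httpx.AsyncClient()
    response = await client.get("https://example.com")
    if close_explicitly:
        # Shut down the connection pool while the loop is still running.
        await client.aclose()
    return response.status_code

# With close_explicitly=False, the client is finalized only at garbage
# collection, after asyncio.run() has torn the loop down -- which is what
# surfaces as the "Event loop is closed" task errors quoted in the diff below.
print(asyncio.run(fetch(close_explicitly=True)))
```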
```diff
@@ -134,6 +134,7 @@ def create_api_specific_request(
         return request

     async def asubmit_batch(self, batch_file: str) -> dict:
+        async_client = AsyncOpenAI()
         # Create a list to store API-specific requests
         api_specific_requests = []

@@ -147,13 +148,13 @@ async def asubmit_batch(self, batch_file: str) -> dict:
         # Join requests with newlines and encode to bytes for upload
         file_content = "\n".join(api_specific_requests).encode()

-        batch_file_upload = await self.async_client.files.create(
+        batch_file_upload = await async_client.files.create(
             file=file_content, purpose="batch"
         )

         logger.info(f"File uploaded: {batch_file_upload}")

-        batch_object = await self.async_client.batches.create(
+        batch_object = await async_client.batches.create(
             input_file_id=batch_file_upload.id,
             endpoint="/v1/chat/completions",
             completion_window="24h",
@@ -164,6 +165,9 @@ async def asubmit_batch(self, batch_file: str) -> dict:
         logger.info(
             f"Batch request submitted, received batch object: {batch_object}"
         )
+        # Explicitly close the client. Otherwise we get something like
+        # future: <Task finished name='Task-46' coro=<AsyncClient.aclose() done ... >>
+        await async_client.close()

         return batch_object

@@ -198,8 +202,6 @@ def run(
             )
         else:
             # upload requests files and submit batches
-            self.async_client = AsyncOpenAI()
-
             # asyncio gather preserves order
             async def submit_all_batches():
                 tasks = [
@@ -226,18 +228,21 @@ async def submit_all_batches():
             # TODO(Ryan): This creates responses_0.jsonl, responses_1.jsonl, etc. errors named same way? or errors_0.jsonl, errors_1.jsonl?
             # TODO(Ryan): retries, resubmits on lagging batches - need to study this a little closer
             # TODO(Ryan): likely can add some logic for smarter check_interval based on batch size and if the batch has started or not, fine to do a dumb ping for now
-            batch_watcher = BatchWatcher(
-                working_dir, check_interval=self.check_interval
-            )
-
             # NOTE(Ryan): If we allow for multiple heterogeneous requests per dataset row, we will need to update this.
             total_requests = 1 if dataset is None else len(dataset)

-            run_in_event_loop(
-                batch_watcher.watch(
+            async def watch_batches():
+                batch_watcher = BatchWatcher(
+                    working_dir, check_interval=self.check_interval
+                )
+                await batch_watcher.watch(
                     prompt_formatter.response_format, total_requests
                 )
-            )
+                # Explicitly close the client. Otherwise we get something like
+                # future: <Task finished name='Task-46' coro=<AsyncClient.aclose() done ... >>
+                await batch_watcher.close_client()
+
+            run_in_event_loop(watch_batches())

             dataset = self.create_dataset_files(
                 working_dir, parse_func_hash, prompt_formatter
@@ -266,6 +271,9 @@ def __init__(self, working_dir: str, check_interval) -> None:
         self.check_interval = check_interval
         self.working_dir = working_dir

+    async def close_client(self):
+        await self.client.close()
+
     async def check_batch_status(self, batch_id: str) -> tuple[str, str]:
         """Check the status of a batch by its ID.
```
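Both halves of the diff apply the same pattern: create the client inside the coroutine that uses it and close it before the coroutine returns, so setup and teardown both run on the same, still-live loop. A sketch of that shape for the submit path, assuming openai>=1.0 and OPENAI_API_KEY in the environment (the try/finally is an addition here; the committed code closes inline):

```python
import asyncio
from openai import AsyncOpenAI

async def asubmit_batch(file_content: bytes):
    # Created inside the coroutine, so the client binds to the running loop.
    async_client = AsyncOpenAI()
    try:
        batch_file_upload = await async_client.files.create(
            file=file_content, purpose="batch"
        )
        return await async_client.batches.create(
            input_file_id=batch_file_upload.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )
    finally:
        # Close while the loop is alive; otherwise httpx schedules
        # AsyncClient.aclose() against a loop that no longer exists.
        await async_client.close()

# asyncio.run(asubmit_batch(b"..."))  # payload elided; see the diff above
```

The watcher side does the same thing one level up: BatchWatcher's client is now closed via close_client() inside watch_batches(), before run_in_event_loop() returns and the loop is discarded.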
