diff --git a/examples/camel.py b/examples/camel.py index ec3e7362..bffa0507 100644 --- a/examples/camel.py +++ b/examples/camel.py @@ -1,4 +1,3 @@ -import asyncio from typing import List from pydantic import BaseModel, Field diff --git a/examples/poem.py b/examples/poem.py index ffb8c5a5..e8e50d07 100644 --- a/examples/poem.py +++ b/examples/poem.py @@ -18,7 +18,7 @@ class Topics(BaseModel): # We define a prompter that generates topics. topic_generator = curator.Prompter( - prompt_func=lambda: f"Generate 10 diverse topics that are suitable for writing poems about.", + prompt_func=lambda: "Generate 10 diverse topics that are suitable for writing poems about.", model_name="gpt-4o-mini", response_format=Topics, parse_func=lambda _, topics: [{"topic": t} for t in topics.topics_list], diff --git a/poetry.lock b/poetry.lock index 13174123..0abbbcff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,14 +1,14 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiofiles" -version = "24.1.0" +version = "23.2.1" description = "File support for asyncio." optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, - {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, + {file = "aiofiles-23.2.1-py3-none-any.whl", hash = "sha256:19297512c647d4b27a2cf7c34caa7e405c0d60b5560618a29a9fe027b18b0107"}, + {file = "aiofiles-23.2.1.tar.gz", hash = "sha256:84ec2218d8419404abcb9f0c02df3f34c6e0a68ed41072acfb1cef5cbc29051a"}, ] [[package]] @@ -3063,42 +3063,47 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tiktoken" -version = "0.8.0" +version = "0.7.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" optional = false -python-versions = ">=3.9" +python-versions = ">=3.8" files = [ - {file = "tiktoken-0.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b07e33283463089c81ef1467180e3e00ab00d46c2c4bbcef0acab5f771d6695e"}, - {file = "tiktoken-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9269348cb650726f44dd3bbb3f9110ac19a8dcc8f54949ad3ef652ca22a38e21"}, - {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e13f37bc4ef2d012731e93e0fef21dc3b7aea5bb9009618de9a4026844e560"}, - {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f13d13c981511331eac0d01a59b5df7c0d4060a8be1e378672822213da51e0a2"}, - {file = "tiktoken-0.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6b2ddbc79a22621ce8b1166afa9f9a888a664a579350dc7c09346a3b5de837d9"}, - {file = "tiktoken-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d8c2d0e5ba6453a290b86cd65fc51fedf247e1ba170191715b049dac1f628005"}, - {file = "tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1"}, - {file = "tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a"}, - {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d"}, - {file = 
"tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47"}, - {file = "tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419"}, - {file = "tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99"}, - {file = "tiktoken-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:881839cfeae051b3628d9823b2e56b5cc93a9e2efb435f4cf15f17dc45f21586"}, - {file = "tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fe9399bdc3f29d428f16a2f86c3c8ec20be3eac5f53693ce4980371c3245729b"}, - {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a58deb7075d5b69237a3ff4bb51a726670419db6ea62bdcd8bd80c78497d7ab"}, - {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2908c0d043a7d03ebd80347266b0e58440bdef5564f84f4d29fb235b5df3b04"}, - {file = "tiktoken-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:294440d21a2a51e12d4238e68a5972095534fe9878be57d905c476017bff99fc"}, - {file = "tiktoken-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d8f3192733ac4d77977432947d563d7e1b310b96497acd3c196c9bddb36ed9db"}, - {file = "tiktoken-0.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:02be1666096aff7da6cbd7cdaa8e7917bfed3467cd64b38b1f112e96d3b06a24"}, - {file = "tiktoken-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94ff53c5c74b535b2cbf431d907fc13c678bbd009ee633a2aca269a04389f9a"}, - {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b231f5e8982c245ee3065cd84a4712d64692348bc609d84467c57b4b72dcbc5"}, - {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4177faa809bd55f699e88c96d9bb4635d22e3f59d635ba6fd9ffedf7150b9953"}, - {file = "tiktoken-0.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5376b6f8dc4753cd81ead935c5f518fa0fbe7e133d9e25f648d8c4dabdd4bad7"}, - {file = "tiktoken-0.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:18228d624807d66c87acd8f25fc135665617cab220671eb65b50f5d70fa51f69"}, - {file = "tiktoken-0.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7e17807445f0cf1f25771c9d86496bd8b5c376f7419912519699f3cc4dc5c12e"}, - {file = "tiktoken-0.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:886f80bd339578bbdba6ed6d0567a0d5c6cfe198d9e587ba6c447654c65b8edc"}, - {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6adc8323016d7758d6de7313527f755b0fc6c72985b7d9291be5d96d73ecd1e1"}, - {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b591fb2b30d6a72121a80be24ec7a0e9eb51c5500ddc7e4c2496516dd5e3816b"}, - {file = "tiktoken-0.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:845287b9798e476b4d762c3ebda5102be87ca26e5d2c9854002825d60cdb815d"}, - {file = "tiktoken-0.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:1473cfe584252dc3fa62adceb5b1c763c1874e04511b197da4e6de51d6ce5a02"}, - {file = "tiktoken-0.8.0.tar.gz", hash = "sha256:9ccbb2740f24542534369c5635cfd9b2b3c2490754a78ac8831d99f89f94eeb2"}, + {file = "tiktoken-0.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485f3cc6aba7c6b6ce388ba634fbba656d9ee27f766216f45146beb4ac18b25f"}, + {file = "tiktoken-0.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:e54be9a2cd2f6d6ffa3517b064983fb695c9a9d8aa7d574d1ef3c3f931a99225"}, + {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79383a6e2c654c6040e5f8506f3750db9ddd71b550c724e673203b4f6b4b4590"}, + {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d4511c52caacf3c4981d1ae2df85908bd31853f33d30b345c8b6830763f769c"}, + {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:13c94efacdd3de9aff824a788353aa5749c0faee1fbe3816df365ea450b82311"}, + {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8e58c7eb29d2ab35a7a8929cbeea60216a4ccdf42efa8974d8e176d50c9a3df5"}, + {file = "tiktoken-0.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:21a20c3bd1dd3e55b91c1331bf25f4af522c525e771691adbc9a69336fa7f702"}, + {file = "tiktoken-0.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:10c7674f81e6e350fcbed7c09a65bca9356eaab27fb2dac65a1e440f2bcfe30f"}, + {file = "tiktoken-0.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:084cec29713bc9d4189a937f8a35dbdfa785bd1235a34c1124fe2323821ee93f"}, + {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:811229fde1652fedcca7c6dfe76724d0908775b353556d8a71ed74d866f73f7b"}, + {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b6e7dc2e7ad1b3757e8a24597415bafcfb454cebf9a33a01f2e6ba2e663992"}, + {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1063c5748be36344c7e18c7913c53e2cca116764c2080177e57d62c7ad4576d1"}, + {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:20295d21419bfcca092644f7e2f2138ff947a6eb8cfc732c09cc7d76988d4a89"}, + {file = "tiktoken-0.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:959d993749b083acc57a317cbc643fb85c014d055b2119b739487288f4e5d1cb"}, + {file = "tiktoken-0.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:71c55d066388c55a9c00f61d2c456a6086673ab7dec22dd739c23f77195b1908"}, + {file = "tiktoken-0.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09ed925bccaa8043e34c519fbb2f99110bd07c6fd67714793c21ac298e449410"}, + {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03c6c40ff1db0f48a7b4d2dafeae73a5607aacb472fa11f125e7baf9dce73704"}, + {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20b5c6af30e621b4aca094ee61777a44118f52d886dbe4f02b70dfe05c15350"}, + {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d427614c3e074004efa2f2411e16c826f9df427d3c70a54725cae860f09e4bf4"}, + {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8c46d7af7b8c6987fac9b9f61041b452afe92eb087d29c9ce54951280f899a97"}, + {file = "tiktoken-0.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:0bc603c30b9e371e7c4c7935aba02af5994a909fc3c0fe66e7004070858d3f8f"}, + {file = "tiktoken-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2398fecd38c921bcd68418675a6d155fad5f5e14c2e92fcf5fe566fa5485a858"}, + {file = "tiktoken-0.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8f5f6afb52fb8a7ea1c811e435e4188f2bef81b5e0f7a8635cc79b0eef0193d6"}, + {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:861f9ee616766d736be4147abac500732b505bf7013cfaf019b85892637f235e"}, + {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:54031f95c6939f6b78122c0aa03a93273a96365103793a22e1793ee86da31685"}, + {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:fffdcb319b614cf14f04d02a52e26b1d1ae14a570f90e9b55461a72672f7b13d"}, + {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c72baaeaefa03ff9ba9688624143c858d1f6b755bb85d456d59e529e17234769"}, + {file = "tiktoken-0.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:131b8aeb043a8f112aad9f46011dced25d62629091e51d9dc1adbf4a1cc6aa98"}, + {file = "tiktoken-0.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cabc6dc77460df44ec5b879e68692c63551ae4fae7460dd4ff17181df75f1db7"}, + {file = "tiktoken-0.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8d57f29171255f74c0aeacd0651e29aa47dff6f070cb9f35ebc14c82278f3b25"}, + {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ee92776fdbb3efa02a83f968c19d4997a55c8e9ce7be821ceee04a1d1ee149c"}, + {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e215292e99cb41fbc96988ef62ea63bb0ce1e15f2c147a61acc319f8b4cbe5bf"}, + {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a81bac94769cab437dd3ab0b8a4bc4e0f9cf6835bcaa88de71f39af1791727a"}, + {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d6d73ea93e91d5ca771256dfc9d1d29f5a554b83821a1dc0891987636e0ae226"}, + {file = "tiktoken-0.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:2bcb28ddf79ffa424f171dfeef9a4daff61a94c631ca6813f43967cb263b83b9"}, + {file = "tiktoken-0.7.0.tar.gz", hash = "sha256:1077266e949c24e0291f6c350433c6f0971365ece2b173a23bc3b9f9defef6b6"}, ] [package.dependencies] @@ -3597,4 +3602,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "3604f19ac9d9dd28454528f2623f2b638bbd985d12810f4d99934d2bd11a3294" +content-hash = "12084fbca319156982a26115e12c65410ac08b2e18678a27c0b807d5f375866c" diff --git a/pyproject.toml b/pyproject.toml index 0e622361..ad77fa72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "bespokelabs-curator" -version = "0.1.9post1" +version = "0.1.10" description = "Bespoke Labs Curator" authors = ["Bespoke Labs "] readme = "README.md" @@ -29,12 +29,12 @@ pandas = "2.2.2" xxhash = "^3.5.0" tqdm = "^4.67.0" matplotlib = "^3.9.2" -aiofiles = "^24.1.0" -tiktoken = "^0.8.0" nest-asyncio = "^1.6.0" rich = "^13.7.0" litellm = "^1.52.11" isort = "^5.13.2" +tiktoken = ">=0.7.0,<0.8.0" +aiofiles = ">=22.0,<24.0" [tool.poetry.group.dev.dependencies] black = "^24.2.0" diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 6025b957..abfbdad8 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -203,6 +203,8 @@ def _completions( ) fingerprint = xxh64(fingerprint_str.encode("utf-8")).hexdigest() + logger.debug(f"Curator Cache Fingerprint: {fingerprint}") + metadata_db_path = os.path.join(curator_cache_dir, "metadata.db") metadata_db = MetadataDB(metadata_db_path) diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 1b2254fb..dd2d095e 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -177,7 +177,7 @@ def create_dataset_files( working_dir: str, parse_func_hash: str, prompt_formatter: 
PromptFormatter, - ) -> None: + ) -> Dataset: """ Creates the request files if they don't already exist or use existing. A single request file (requests_0.jsonl) or multiple request files @@ -217,7 +217,7 @@ def create_dataset_files( return output_dataset error_help = ( - f"Please check your `parse_func` is returning a valid row (dict) " + "Please check your `parse_func` is returning a valid row (dict) " "or list of rows (list of dicts) and re-run. " "Dataset will be regenerated from cached LLM responses." ) @@ -314,9 +314,7 @@ def create_dataset_files( writer.finalize() - output_dataset = Dataset.from_file(dataset_file) - - return output_dataset + return Dataset.from_file(dataset_file) def parse_response_message( diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index e6289ed2..e200e505 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -3,9 +3,11 @@ import json import logging import os +import resource from dataclasses import dataclass import aiofiles +import glob import litellm from openai import AsyncOpenAI from openai.types import Batch @@ -27,6 +29,10 @@ MAX_REQUESTS_PER_BATCH = 50_000 MAX_BYTES_PER_BATCH = 200 * 1024 * 1024 +# NOTE(Ryan): This allows us to stay under the rate limit when submitting ~1,000 batches at a time +# When submitting >1,000 batches the batch submission and batch download operations get rate limited +MAX_CONCURRENT_BATCH_OPERATIONS = 100 + class OpenAIBatchRequestProcessor(BaseRequestProcessor): def __init__( @@ -55,6 +61,7 @@ def __init__( self.top_p: float | None = top_p self.presence_penalty: float | None = presence_penalty self.frequency_penalty: float | None = frequency_penalty + self._file_lock = asyncio.Lock() def get_rate_limits(self) -> dict: """ @@ -150,59 +157,91 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: return request - async def asubmit_batch(self, batch_file: str) -> dict: - async_client = AsyncOpenAI() - # Create a list to store API-specific requests - api_specific_requests = [] - - line_count = 0 - async with aiofiles.open(batch_file, "r") as file: - file_content = await file.read() - for line in file_content.splitlines(): - request = GenericRequest.model_validate_json(line) - api_specific_request = self.create_api_specific_request(request) - api_specific_requests.append(json.dumps(api_specific_request)) - line_count += 1 - - if line_count > MAX_REQUESTS_PER_BATCH: - raise ValueError( - f"Batch file {batch_file} contains {line_count:,} requests, " - f"which is more than the maximum of {MAX_REQUESTS_PER_BATCH:,} requests per batch that OpenAI supports. " - f"Preventing batch submission." 
- ) + async def asubmit_batch( + self, batch_file: str, semaphore: asyncio.Semaphore | None = None + ) -> dict: + async with semaphore or asyncio.Semaphore(): # Use provided semaphore or dummy one + async_client = AsyncOpenAI() + # Create a list to store API-specific requests + api_specific_requests = [] + + line_count = 0 + async with aiofiles.open(batch_file, "r") as file: + file_content = await file.read() + for line in file_content.splitlines(): + request = GenericRequest.model_validate_json(line) + api_specific_request = self.create_api_specific_request(request) + api_specific_requests.append(json.dumps(api_specific_request)) + line_count += 1 + + if line_count > MAX_REQUESTS_PER_BATCH: + raise ValueError( + f"Batch file {batch_file} contains {line_count:,} requests, " + f"which is more than the maximum of {MAX_REQUESTS_PER_BATCH:,} requests per batch that OpenAI supports. " + f"Preventing batch submission." + ) - # Join requests with newlines and encode to bytes for upload - file_content = "\n".join(api_specific_requests).encode() - file_content_size = len(file_content) - logger.debug( - f"Batch file content size: {file_content_size / (1024*1024):.2f} MB ({file_content_size:,} bytes)" - ) - if file_content_size > MAX_BYTES_PER_BATCH: - raise ValueError( - f"Batch file content size {file_content_size:,} bytes " - f"is greater than the maximum of {MAX_BYTES_PER_BATCH:,} bytes per batch that OpenAI supports. " - f"Please reduce your batch size or request content size (via prompt_func and response_format)." + # Join requests with newlines and encode to bytes for upload + file_content = "\n".join(api_specific_requests).encode() + file_content_size = len(file_content) + logger.debug( + f"Batch file content size: {file_content_size / (1024*1024):.2f} MB ({file_content_size:,} bytes)" ) + if file_content_size > MAX_BYTES_PER_BATCH: + raise ValueError( + f"Batch file content size {file_content_size:,} bytes " + f"is greater than the maximum of {MAX_BYTES_PER_BATCH:,} bytes per batch that OpenAI supports. " + f"Please reduce your batch size or request content size (via prompt_func and response_format)." + ) - # this let's you upload a file that is larger than 200MB and won't error, so we catch it above - batch_file_upload = await async_client.files.create(file=file_content, purpose="batch") + try: + # this let's you upload a file that is larger than 200MB and won't error, so we catch it above + batch_file_upload = await async_client.files.create( + file=file_content, purpose="batch" + ) + except Exception as e: + logger.error(f"Error uploading batch file: {e}") + raise e - logger.info(f"File uploaded: {batch_file_upload}") + # When submitting a file, sometimes the file is not ready immediately for status checking + await asyncio.sleep(1) - batch_object = await async_client.batches.create( - input_file_id=batch_file_upload.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={ - "request_file_name": batch_file - }, # for downloading the batch to similarly named responses file - ) - logger.info(f"Batch request submitted, received batch object: {batch_object}") - # Explicitly close the client. 
Otherwise we get something like - # future: > - await async_client.close() + try: + batch_file_upload = await async_client.files.wait_for_processing( + batch_file_upload.id + ) + except Exception as e: + logger.error(f"Error waiting for batch file to be processed: {e}") + raise e + + logger.info(f"File uploaded with id {batch_file_upload.id}") + try: + batch_object = await async_client.batches.create( + input_file_id=batch_file_upload.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "request_file_name": batch_file + }, # for downloading the batch to similarly named responses file + ) + logger.info(f"Batch submitted with id {batch_object.id}") + except Exception as e: + logger.error(f"Error submitting batch: {e}") + raise e + + # NOTE(Ryan): we can also store the request_file_name in this object here, instead of in the metadata during batch submission. + # If we create "generic batch object" maybe we can find a nice abstraction across other batch APIs (e.g. claude) + async with self._file_lock: + async with aiofiles.open( + f"{os.path.dirname(batch_file)}/batch_objects.jsonl", "a" + ) as f: + await f.write(json.dumps(batch_object.model_dump(), default=str) + "\n") + await f.flush() + + # Explicitly close the client + await async_client.close() - return batch_object + return batch_object def run( self, @@ -223,28 +262,47 @@ def run( Returns: Dataset: Completed dataset """ - requests_files = self.create_request_files(dataset, working_dir, prompt_formatter) + # Increase the number of open file descriptors to avoid "Too many open files" errors + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + desired_limit = min(1_000_000, hard) + logger.info( + f"Adjusting file descriptor limit from {soft} to {desired_limit} (hard limit: {hard})" + ) + resource.setrlimit(resource.RLIMIT_NOFILE, (desired_limit, hard)) + + requests_files = set(self.create_request_files(dataset, working_dir, prompt_formatter)) batch_objects_file = f"{working_dir}/batch_objects.jsonl" - # TODO(Ryan): we should have an easy way to cancel all batches in batch_objects.jsonl if the user realized they made a mistake if os.path.exists(batch_objects_file): - logger.warning( - f"Batch objects file already exists, skipping batch submission and resuming: {batch_objects_file}" - ) - else: - # upload requests files and submit batches - # asyncio gather preserves order - async def submit_all_batches(): - tasks = [self.asubmit_batch(requests_files[i]) for i in range(len(requests_files))] - return await asyncio.gather(*tasks) + # Read existing batch objects from file + already_submitted = 0 + with open(batch_objects_file, "r") as f: + for line in f: + batch_object = json.loads(line) + request_file_name = batch_object["metadata"]["request_file_name"] + logger.info(f"Batch job already submitted for request file {request_file_name}") + requests_files.remove(request_file_name) + already_submitted += 1 + + if len(requests_files) == 0: + logger.info( + f"All {already_submitted:,} batches completed, skipping batch submission" + ) + else: + logger.info( + f"{already_submitted:,} out of {len(requests_files)+already_submitted:,} batch jobs already submitted. Submitting the remaining {len(requests_files):,} batch jobs." 
+ ) - batch_objects = run_in_event_loop(submit_all_batches()) + # Limit concurrent batch submissions to 100 so avoid overwhelming the API + async def submit_all_batches(): + semaphore = asyncio.Semaphore(MAX_CONCURRENT_BATCH_OPERATIONS) + tasks = [ + self.asubmit_batch(requests_file, semaphore) for requests_file in requests_files + ] + return await asyncio.gather(*tasks) - with open(batch_objects_file, "w") as f: - # NOTE(Ryan): we can also store the request_file_name in this object here, instead of in the metadata during batch submission. Can find a nice abstraction across other batch APIs (e.g. claude) - for obj in batch_objects: - f.write(json.dumps(obj.model_dump(), default=str) + "\n") - logger.info(f"Batch objects written to {batch_objects_file}") + run_in_event_loop(submit_all_batches()) + logger.info(f"All batch objects submitted and written to {batch_objects_file}") # TODO(Ryan): Actually do accounting for tokens, so rate limits enforced locally. # NOTE(Ryan): Although this isn't really practical since the limits are for an entire day and an entire organization. Maybe skip this and just recognize what a rate limit error for batching looks like (need to try this on a low tier account). @@ -267,8 +325,6 @@ async def watch_batches(): prompt_formatter=prompt_formatter, ) await batch_watcher.watch() - # Explicitly close the client. Otherwise we get something like - # future: > await batch_watcher.close_client() run_in_event_loop(watch_batches()) @@ -328,6 +384,7 @@ def __init__( self.tracker.n_submitted_requests = n_submitted_requests self.remaining_batch_ids = set(self.batch_ids) self.prompt_formatter = prompt_formatter + self.semaphore = asyncio.Semaphore(MAX_CONCURRENT_BATCH_OPERATIONS) async def close_client(self): await self.client.close() @@ -341,52 +398,53 @@ async def check_batch_status(self, batch_id: str) -> Batch | None: Returns: Batch: The batch object. None if the batch has not returned yet. 
""" - batch = await self.client.batches.retrieve(batch_id) - assert batch.id == batch_id - - n_completed_requests = batch.request_counts.completed - n_failed_requests = batch.request_counts.failed - n_total_requests = batch.request_counts.total - - logger.debug( - f"Batch {batch.id} status: {batch.status} requests: " - f"{n_completed_requests}/{n_failed_requests}/{n_total_requests} " - "completed/failed/total" - ) + async with self.semaphore: + batch = await self.client.batches.retrieve(batch_id) + assert batch.id == batch_id + + n_completed_requests = batch.request_counts.completed + n_failed_requests = batch.request_counts.failed + n_total_requests = batch.request_counts.total + + logger.debug( + f"Batch {batch.id} status: {batch.status} requests: " + f"{n_completed_requests}/{n_failed_requests}/{n_total_requests} " + "completed/failed/total" + ) - batch_returned = False - if batch.status == "completed": - self.tracker.n_completed_batches += 1 - batch_returned = True - elif batch.status == "failed": - self.tracker.n_failed_batches += 1 - batch_returned = True - elif batch.status == "expired": - self.tracker.n_expired_batches += 1 - batch_returned = True - elif batch.status == "cancelled": - self.tracker.n_cancelled_batches += 1 - batch_returned = True - else: - if batch.status not in [ - "validating", - "finalizing", - "cancelling", - "in_progress", - ]: - logger.warning(f"Unknown batch status: {batch.status}") - - if batch_returned: - logger.info(f"Batch {batch.id} returned with status: {batch.status}") - self.tracker.n_returned_batches += 1 - self.tracker.n_completed_returned_requests += n_completed_requests - self.tracker.n_failed_returned_requests += n_failed_requests - self.remaining_batch_ids.remove(batch.id) - return batch - else: - self.tracker.n_completed_in_progress_requests += n_completed_requests - self.tracker.n_failed_in_progress_requests += n_failed_requests - return None + batch_returned = False + if batch.status == "completed": + self.tracker.n_completed_batches += 1 + batch_returned = True + elif batch.status == "failed": + self.tracker.n_failed_batches += 1 + batch_returned = True + elif batch.status == "expired": + self.tracker.n_expired_batches += 1 + batch_returned = True + elif batch.status == "cancelled": + self.tracker.n_cancelled_batches += 1 + batch_returned = True + else: + if batch.status not in [ + "validating", + "finalizing", + "cancelling", + "in_progress", + ]: + logger.warning(f"Unknown batch status: {batch.status}") + + if batch_returned: + logger.info(f"Batch {batch.id} returned with status: {batch.status}") + self.tracker.n_returned_batches += 1 + self.tracker.n_completed_returned_requests += n_completed_requests + self.tracker.n_failed_returned_requests += n_failed_requests + self.remaining_batch_ids.remove(batch.id) + return batch + else: + self.tracker.n_completed_in_progress_requests += n_completed_requests + self.tracker.n_failed_in_progress_requests += n_failed_requests + return None async def watch(self) -> None: """Monitor the status of batches until all are completed (includes successfully, failed, expired or cancelled).""" @@ -396,7 +454,26 @@ async def watch(self) -> None: desc="Completed OpenAI requests in batches", unit="request", ) - all_response_files = [] + + # resume from already downloaded batches + response_files_found = 0 + all_response_files = set(glob.glob(f"{self.working_dir}/responses_*.jsonl")) + for batch_id, request_file in self.batch_id_to_request_file_name.items(): + request_file_idx = 
request_file.split("/")[-1].split("_", 1)[1] + response_file = f"{self.working_dir}/responses_{request_file_idx}" + if response_file in all_response_files: + logger.info( + f"File {response_file} found for batch {batch_id}, skipping status check and download." + ) + self.remaining_batch_ids.remove(batch_id) + response_files_found += 1 + if response_files_found > 0: + logger.info( + f"Found {response_files_found} out of {self.tracker.n_submitted_batches} completed batches, resuming polling for the remaining {len(self.remaining_batch_ids)} batches." + ) + self.tracker.n_completed_batches += response_files_found # here we are assuming they are completed, but they also could have failed + self.tracker.n_returned_batches += response_files_found + all_response_files = list(all_response_files) # loop until all batches have been returned while self.remaining_batch_ids: @@ -452,95 +529,102 @@ async def download_batch_to_generic_responses_file(self, batch: Batch) -> str | Returns: str: Path to the downloaded result file. """ - if batch.status == "completed" and batch.output_file_id: - file_content = await self.client.files.content(batch.output_file_id) - elif batch.status == "failed" and batch.error_file_id: - file_content = await self.client.files.content(batch.error_file_id) - logger.warning(f"Batch {batch.id} failed\n. Errors will be parsed below.") - elif batch.status == "failed" and not batch.error_file_id: - errors = "\n".join([str(error) for error in batch.errors.data]) - logger.error( - f"Batch {batch.id} failed and likely failed validation. " - f"Batch errors: {errors}. " - f"Check https://platform.openai.com/batches/{batch.id} for more details." - ) - return None - elif batch.status == "cancelled" or batch.status == "expired": - logger.warning(f"Batch {batch.id} was cancelled or expired") - return None - - # Naming is consistent with the request file (e.g. 
requests_0.jsonl -> responses_0.jsonl) - request_file = self.batch_id_to_request_file_name[batch.id] - request_file_idx = request_file.split("/")[-1].split("_", 1)[1] - response_file = f"{self.working_dir}/responses_{request_file_idx}" - - generic_request_map = {} - request_creation_times = {} # Track creation times for requests - with open(request_file, "r") as f: - for line in f: - generic_request = GenericRequest.model_validate_json(line) - generic_request_map[generic_request.original_row_idx] = generic_request - request_creation_times[generic_request.original_row_idx] = datetime.datetime.now() - - with open(response_file, "w") as f: - for raw_response in file_content.text.splitlines(): - raw_response = json.loads(raw_response) - request_idx = int(raw_response["custom_id"]) - generic_request = generic_request_map[request_idx] - - # TODO(Ryan): Add more specific error handling - if raw_response["response"]["status_code"] != 200: - logger.warning( - f"Request {generic_request} failed with status code {raw_response['response']['status_code']}" - ) - generic_response = GenericResponse( - response_message=None, - response_errors=[ - f"Request {generic_request} failed with status code {raw_response['response']['status_code']}" - ], - raw_response=raw_response, - raw_request=None, - generic_request=generic_request, - created_at=request_creation_times[request_idx], - finished_at=datetime.datetime.now(), - token_usage=None, - response_cost=None, - ) - else: - response_body = raw_response["response"]["body"] - choices = response_body["choices"] - usage = response_body.get("usage", {}) - - token_usage = TokenUsage( - prompt_tokens=usage.get("prompt_tokens", 0), - completion_tokens=usage.get("completion_tokens", 0), - total_tokens=usage.get("total_tokens", 0), - ) - - # Calculate cost using litellm - cost = litellm.completion_cost( - model=generic_request.model, - prompt=str( - generic_request.messages - ), # Convert messages to string for cost calculation - completion=choices[0]["message"]["content"], + async with self.semaphore: + if batch.status == "completed" and batch.output_file_id: + file_content = await self.client.files.content(batch.output_file_id) + logger.info(f"Batch {batch.id} completed and downloaded") + elif batch.status == "failed" and batch.error_file_id: + file_content = await self.client.files.content(batch.error_file_id) + logger.warning(f"Batch {batch.id} failed\n. Errors will be parsed below.") + elif batch.status == "failed" and not batch.error_file_id: + errors = "\n".join([str(error) for error in batch.errors.data]) + logger.error( + f"Batch {batch.id} failed and likely failed validation. " + f"Batch errors: {errors}. " + f"Check https://platform.openai.com/batches/{batch.id} for more details." + ) + return None + elif batch.status == "cancelled" or batch.status == "expired": + logger.warning(f"Batch {batch.id} was cancelled or expired") + return None + + # Naming is consistent with the request file (e.g. 
requests_0.jsonl -> responses_0.jsonl) + request_file = self.batch_id_to_request_file_name[batch.id] + request_file_idx = request_file.split("/")[-1].split("_", 1)[1] + response_file = f"{self.working_dir}/responses_{request_file_idx}" + + generic_request_map = {} + request_creation_times = {} # Track creation times for requests + with open(request_file, "r") as f: + for line in f: + generic_request = GenericRequest.model_validate_json(line) + generic_request_map[generic_request.original_row_idx] = generic_request + request_creation_times[generic_request.original_row_idx] = ( + datetime.datetime.now() ) - response_message = choices[0]["message"]["content"] - response_message, response_errors = parse_response_message( - response_message, self.prompt_formatter.response_format - ) + with open(response_file, "w") as f: + for raw_response in file_content.text.splitlines(): + raw_response = json.loads(raw_response) + request_idx = int(raw_response["custom_id"]) + generic_request = generic_request_map[request_idx] - generic_response = GenericResponse( - response_message=response_message, - response_errors=response_errors, - raw_response=raw_response, - raw_request=None, - generic_request=generic_request, - created_at=request_creation_times[request_idx], - finished_at=datetime.datetime.now(), - token_usage=token_usage, - response_cost=cost, - ) - f.write(json.dumps(generic_response.model_dump(), default=str) + "\n") - return response_file + # TODO(Ryan): Add more specific error handling + if raw_response["response"]["status_code"] != 200: + logger.warning( + f"Request {generic_request} failed with status code {raw_response['response']['status_code']}" + ) + generic_response = GenericResponse( + response_message=None, + response_errors=[ + f"Request {generic_request} failed with status code {raw_response['response']['status_code']}" + ], + raw_response=raw_response, + raw_request=None, + generic_request=generic_request, + created_at=request_creation_times[request_idx], + finished_at=datetime.datetime.now(), + token_usage=None, + response_cost=None, + ) + else: + response_body = raw_response["response"]["body"] + choices = response_body["choices"] + usage = response_body.get("usage", {}) + + token_usage = TokenUsage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ) + + # Calculate cost using litellm + cost = litellm.completion_cost( + model=generic_request.model, + prompt=str( + generic_request.messages + ), # Convert messages to string for cost calculation + completion=choices[0]["message"]["content"], + ) + # Batch requests are 50% off + cost = cost * 0.5 + + response_message = choices[0]["message"]["content"] + response_message, response_errors = parse_response_message( + response_message, self.prompt_formatter.response_format + ) + + generic_response = GenericResponse( + response_message=response_message, + response_errors=response_errors, + raw_response=raw_response, + raw_request=None, + generic_request=generic_request, + created_at=request_creation_times[request_idx], + finished_at=datetime.datetime.now(), + token_usage=token_usage, + response_cost=cost, + ) + f.write(json.dumps(generic_response.model_dump(), default=str) + "\n") + logger.info(f"Batch {batch.id} written to {response_file}") + return response_file diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 
54368c8f..f387ab2d 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -194,7 +194,7 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: - Applies optional parameters (temperature, top_p, etc.) - Maintains compatibility with both chat and completion endpoints """ - request = { + request: dict[str, Any] = { "model": generic_request.model, "messages": generic_request.messages, }
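
The new MAX_CONCURRENT_BATCH_OPERATIONS constant caps how many batch submissions and status/download calls run at once; per the NOTE in the diff, submitting on the order of 1,000 batches without a cap trips OpenAI's rate limits. A minimal sketch of the shared-semaphore-plus-gather pattern follows (submit_one and the sleep are stand-ins for the real file upload and batches.create calls, not code from this repository):

import asyncio

MAX_CONCURRENT_BATCH_OPERATIONS = 100  # same cap the diff introduces

async def submit_one(request_file: str, semaphore: asyncio.Semaphore) -> str:
    # Hold the semaphore for the whole submission so that at most
    # MAX_CONCURRENT_BATCH_OPERATIONS submissions are in flight at once.
    async with semaphore:
        await asyncio.sleep(0.1)  # placeholder for upload + batch creation
        return f"submitted {request_file}"

async def submit_all(request_files: list[str]) -> list[str]:
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_BATCH_OPERATIONS)
    tasks = [submit_one(f, semaphore) for f in request_files]
    return await asyncio.gather(*tasks)  # gather preserves input order

if __name__ == "__main__":
    print(asyncio.run(submit_all([f"requests_{i}.jsonl" for i in range(5)])))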
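
asubmit_batch now appends each returned batch object to batch_objects.jsonl as soon as the batch is created, guarded by an asyncio.Lock so concurrent tasks cannot interleave partial lines. A sketch of that append pattern, assuming aiofiles is installed (the BatchObjectLog class name is hypothetical, not from the diff):

import asyncio
import json

import aiofiles

class BatchObjectLog:
    """Append-only JSONL log guarded by an asyncio.Lock so concurrent
    submission tasks never interleave partial lines."""

    def __init__(self, path: str):
        self.path = path
        self._lock = asyncio.Lock()

    async def append(self, record: dict) -> None:
        async with self._lock:
            async with aiofiles.open(self.path, "a") as f:
                await f.write(json.dumps(record, default=str) + "\n")
                await f.flush()

async def main():
    log = BatchObjectLog("batch_objects.jsonl")
    await asyncio.gather(*(log.append({"id": f"batch_{i}"}) for i in range(3)))

if __name__ == "__main__":
    asyncio.run(main())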
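
run() also raises the soft RLIMIT_NOFILE limit before creating request files, because opening hundreds of request/response files plus sockets can exceed the default per-process descriptor limit. A sketch of the same resource.setrlimit pattern (the 1,000,000 target matches the diff; the guard against lowering an already-higher soft limit is an extra precaution, not taken from the diff):

import resource

def raise_open_file_limit(target: int = 1_000_000) -> None:
    # The soft limit can only be raised up to the hard limit without privileges,
    # so clamp the requested value to the current hard limit.
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    desired = min(target, hard)
    if desired > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (desired, hard))

if __name__ == "__main__":
    raise_open_file_limit()
    print(resource.getrlimit(resource.RLIMIT_NOFILE))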
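
Resumption works by reading batch_objects.jsonl on the next run and dropping every request file whose batch was already submitted; only the remainder is resubmitted. Roughly, with the metadata field name taken from the diff and the helper name being illustrative:

import json
import os

def remaining_request_files(request_files: set[str], batch_objects_file: str) -> set[str]:
    # Each line of batch_objects.jsonl is one submitted Batch object whose
    # metadata records the request file it came from.
    remaining = set(request_files)
    if os.path.exists(batch_objects_file):
        with open(batch_objects_file, "r") as f:
            for line in f:
                batch_object = json.loads(line)
                remaining.discard(batch_object["metadata"]["request_file_name"])
    return remaining

# e.g. remaining_request_files({"requests_0.jsonl", "requests_1.jsonl"}, "batch_objects.jsonl")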
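
Per-response cost is estimated with litellm.completion_cost and then halved, reflecting the 50% discount OpenAI applies to Batch API usage. A sketch of that calculation (the helper name and arguments are illustrative, not the repository's API):

import litellm

def batch_response_cost(model: str, messages: list[dict], completion_text: str) -> float:
    # Estimate the synchronous price from prompt and completion text,
    # then apply the 50% batch discount, as the diff does.
    cost = litellm.completion_cost(
        model=model,
        prompt=str(messages),  # same string conversion used in the diff
        completion=completion_text,
    )
    return cost * 0.5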