Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ryanm/quick fix expired batches #228

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ line-length = 100

[tool.pytest.ini_options]
# pytest-asyncio: async fixtures default to a function-scoped event loop.
asyncio_default_fixture_loop_scope = "function"
# pytest-asyncio "auto" mode: async test functions are collected and run as
# asyncio tests without needing a per-test @pytest.mark.asyncio marker.
asyncio_mode = "auto"
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,18 @@ async def submit_batch(self, requests: list[dict], metadata: dict) -> Batch:

return batch_object

def get_submitted_batch_ids(self) -> set[str]:
    """Return the IDs of all batches recorded in the submitted-batches file.

    Each line of ``self.submitted_batch_objects_file`` is a JSON-serialized
    batch object with an ``"id"`` field.

    Returns:
        set[str]: The recorded batch IDs, or an empty set if the tracking
        file does not exist yet.
    """
    if not os.path.exists(self.submitted_batch_objects_file):
        return set()
    with open(self.submitted_batch_objects_file, "r") as f:
        # Skip blank lines so a trailing newline or empty line in the JSONL
        # file doesn't crash json.loads("").
        return {json.loads(line)["id"] for line in f if line.strip()}

def get_downloaded_batch_ids(self) -> set[str]:
    """Return the IDs of all batches recorded in the downloaded-batches file.

    Each line of ``self.downloaded_batch_objects_file`` is a JSON-serialized
    batch object with an ``"id"`` field.

    Returns:
        set[str]: The recorded batch IDs, or an empty set if the tracking
        file does not exist yet.
    """
    if not os.path.exists(self.downloaded_batch_objects_file):
        return set()
    with open(self.downloaded_batch_objects_file, "r") as f:
        # Skip blank lines so a trailing newline or empty line in the JSONL
        # file doesn't crash json.loads("").
        return {json.loads(line)["id"] for line in f if line.strip()}

async def cancel_batches(self):
if not os.path.exists(self.submitted_batch_objects_file):
logger.warning("No batches to be cancelled, but cancel_batches=True.")
Expand Down Expand Up @@ -736,6 +748,10 @@ async def track_already_submitted_batches(self):
)
batch_object = await self.retrieve_batch(batch_object.id)

if batch_object.status in ["cancelled", "expired"]:
# don't mark as submitted, will submit a new batch for this request file
continue

# Edge case where the batch is still validating, and we need to know the total number of requests
if batch_object.status == "validating":
n_requests = len(open(request_file_name, "r").readlines())
Expand All @@ -745,9 +761,6 @@ async def track_already_submitted_batches(self):

if request_file_name in self.tracker.unsubmitted_request_files:
self.tracker.mark_as_submitted(request_file_name, batch_object, n_requests)
else:
# batch objects if not unsubmitted, should be downloaded
assert batch_object.id in self.tracker.downloaded_batches

if self.tracker.n_submitted_batches > 0:
logger.info(
Expand Down
168 changes: 168 additions & 0 deletions tests/batch/mock_test_resume_cancelled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import pytest
import time
import os
from unittest.mock import AsyncMock, patch
from openai.types import Batch, FileObject
from openai.types.batch import Errors
from openai.types.batch_request_counts import BatchRequestCounts
from tests.helpers import run_script
from tests.helpers import prepare_test_cache
from bespokelabs.curator.request_processor.openai_batch_request_processor import BatchManager

"""
USAGE:
pytest -s tests/batch/mock_test_resume_cancelled.py
"""

# https://platform.openai.com/docs/api-reference/batch/create
# https://platform.openai.com/docs/api-reference/batch/create
def create_mock_batch(
    batch_id: str,
    request_file_name: str,
    status: str = "in_progress",
    total_requests: int = 1,
) -> Batch:
    """Build a mock OpenAI Batch object for tests.

    The output file id and the completed request count are only populated
    when ``status`` is ``"completed"``, mirroring the real API response.
    """
    is_completed = status == "completed"
    counts = BatchRequestCounts(
        completed=total_requests if is_completed else 0,
        failed=0,
        total=total_requests,
    )
    return Batch(
        id=batch_id,
        object="batch",
        endpoint="/v1/chat/completions",
        completion_window="24h",
        status=status,
        input_file_id="file-123",
        output_file_id="file-456" if is_completed else None,
        error_file_id=None,
        errors=None,
        request_counts=counts,
        metadata={"request_file_name": request_file_name},
        created_at=1234567890,
        in_progress_at=None,
        finalizing_at=None,
        completed_at=None,
        failed_at=None,
        expired_at=None,
        expires_at=None,
        cancelling_at=None,
        cancelled_at=None,
    )


@pytest.mark.cache_dir(
    os.path.expanduser("~/.cache/curator-tests/mock-test-batch-resume-cancelled")
)
@pytest.mark.usefixtures("prepare_test_cache")
@patch("openai.AsyncOpenAI")
async def test_batch_resume(mock_openai):
    """Mocked variant of the resume-after-cancellation test.

    Configures an AsyncOpenAI mock whose batch retrieval first reports
    ``in_progress`` and then ``cancelled``, runs the simple_batch script
    once, and (in the currently commented-out second half) would verify
    that a cancelled batch gets resubmitted on a second run.

    NOTE(review): the script is launched in a separate process via
    run_script, so the @patch("openai.AsyncOpenAI") applied in this test
    process presumably cannot affect the child process -- confirm the mock
    actually takes effect before relying on this test.
    """
    # Setup mock responses
    mock_client = AsyncMock()
    mock_openai.return_value = mock_client

    # Mock batch creation
    mock_client.batches.create.return_value = create_mock_batch(
        "batch_" + "a" * 32, request_file_name="requests.jsonl"
    )

    # Mock file creation
    mock_client.files.create.return_value = FileObject(
        id="file-123",
        bytes=1000,
        created_at=1234567890,
        filename="test.jsonl",
        object="file",
        purpose="batch",
        status="processed",
    )

    # Mock file processing wait
    mock_client.files.wait_for_processing.return_value = FileObject(
        id="file-123",
        bytes=1000,
        created_at=1234567890,
        filename="test.jsonl",
        object="file",
        purpose="batch",
        status="processed",
    )

    # Setup batch retrieval sequence: the first poll sees the batch still
    # running, the second poll sees it cancelled.
    mock_batch_sequence = [
        create_mock_batch(
            "batch_" + "a" * 32, "requests.jsonl", status="in_progress"
        ),  # First check
        create_mock_batch(
            "batch_" + "a" * 32, "requests.jsonl", status="cancelled"
        ),  # After cancellation
    ]
    mock_client.batches.retrieve.side_effect = mock_batch_sequence

    # Two requests with batch size 1 => two batches are submitted.
    script = [
        "python",
        "tests/batch/simple_batch.py",
        "--log-level",
        "DEBUG",
        "--n-requests",
        "2",
        "--batch-size",
        "1",
        "--batch-check-interval",
        "10",
    ]

    env = os.environ.copy()

    print("FIRST RUN")
    # Kill the first run as soon as one batch is downloaded, leaving the
    # other batch submitted but not downloaded.
    stop_line_pattern = r"Marked batch ID batch_[a-f0-9]{32} as downloaded"
    output1, _ = run_script(script, stop_line_pattern, env=env)
    print(output1)

    # Small delay to ensure files are written
    time.sleep(1)

    # TODO(review): second half of this test (cancel + resume + assertions)
    # is still disabled below.
    # cache_dir = os.getenv("CURATOR_CACHE_DIR")
    # child_folder = os.listdir(cache_dir)[0]
    # working_dir = os.path.join(cache_dir, child_folder)
    # print(f"CANCELLING BATCHES in {working_dir}")
    # batch_manager = BatchManager(
    #     working_dir,
    #     delete_successful_batch_files=True,
    #     delete_failed_batch_files=True,
    # )
    # submitted_batch_ids = batch_manager.get_submitted_batch_ids()
    # downloaded_batch_ids = batch_manager.get_downloaded_batch_ids()
    # not_downloaded_batch_id = list(submitted_batch_ids - downloaded_batch_ids)[0]
    # print(f"Submitted batch IDs: {submitted_batch_ids}")
    # print(f"Downloaded batch IDs: {downloaded_batch_ids}")
    # print(f"Not downloaded batch ID: {not_downloaded_batch_id}")

    # Mock batch cancellation
    # mock_client.batches.cancel.return_value = None

    # # Reset batch retrieval sequence for second run
    # mock_batch_sequence = [
    #     create_mock_batch(
    #         "batch_" + "a" * 32, "requests.jsonl", status="cancelled"
    #     ),  # Initial check
    #     create_mock_batch(
    #         "batch_" + "b" * 32, "requests.jsonl", status="completed", total_requests=2
    #     ),  # New batch
    # ]
    # mock_client.batches.retrieve.side_effect = mock_batch_sequence

    # batch_manager.cancel_batch(not_downloaded_batch_id)
    # batch_object = batch_manager.retrieve_batch(not_downloaded_batch_id)
    # # takes a while for the batch to be cancelled
    # while batch_object.status != "cancelled":
    #     time.sleep(10)
    #     batch_object = batch_manager.retrieve_batch(not_downloaded_batch_id)

    # # Second run should process the remaining batch, and resubmit the cancelled batch
    # print("SECOND RUN")
    # output2, _ = run_script(script, env=env)
    # print(output2)

    # # checks
    # assert "1 out of 2 batches already downloaded." in output2
    # assert "0 out of 1 remaining batches are already submitted." in output2
69 changes: 69 additions & 0 deletions tests/batch/test_resume_cancelled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
import time
import os
from tests.helpers import run_script
from tests.helpers import prepare_test_cache
from bespokelabs.curator.request_processor.openai_batch_request_processor import BatchManager

"""
USAGE:
pytest -s tests/batch/test_resume_cancelled.py
"""


@pytest.mark.cache_dir(os.path.expanduser("~/.cache/curator-tests/test-batch-resume-cancelled"))
@pytest.mark.usefixtures("prepare_test_cache")
def test_batch_resume():
    """End-to-end check that a cancelled batch is resubmitted on resume.

    The first run submits two single-request batches and is stopped as soon
    as one batch has been downloaded.  The remaining submitted-but-not-
    downloaded batch is then cancelled through the BatchManager.  The second
    run must reuse the downloaded batch from cache and resubmit the
    cancelled one.
    """
    script = [
        "python",
        "tests/batch/simple_batch.py",
        "--log-level",
        "DEBUG",
        "--n-requests",
        "2",
        "--batch-size",
        "1",
        "--batch-check-interval",
        "10",
    ]

    env = os.environ.copy()

    print("FIRST RUN")
    # Stop the script right after the first batch download so the second
    # batch is left submitted but not downloaded.
    stop_line_pattern = r"Marked batch ID batch_[a-f0-9]{32} as downloaded"
    output1, _ = run_script(script, stop_line_pattern, env=env)
    print(output1)

    # Small delay to ensure the tracking files are flushed to disk.
    time.sleep(1)

    cache_dir = os.getenv("CURATOR_CACHE_DIR")
    child_folder = os.listdir(cache_dir)[0]
    working_dir = os.path.join(cache_dir, child_folder)
    print(f"CANCELLING BATCHES in {working_dir}")
    batch_manager = BatchManager(
        working_dir,
        delete_successful_batch_files=True,
        delete_failed_batch_files=True,
    )
    submitted_batch_ids = batch_manager.get_submitted_batch_ids()
    downloaded_batch_ids = batch_manager.get_downloaded_batch_ids()
    not_downloaded = submitted_batch_ids - downloaded_batch_ids
    # Fail with a clear message instead of an IndexError if the first run
    # did not leave a pending batch behind.
    assert not_downloaded, "expected at least one submitted batch that was not downloaded"
    not_downloaded_batch_id = next(iter(not_downloaded))
    print(f"Submitted batch IDs: {submitted_batch_ids}")
    print(f"Downloaded batch IDs: {downloaded_batch_ids}")
    print(f"Not downloaded batch ID: {not_downloaded_batch_id}")
    batch_manager.cancel_batch(not_downloaded_batch_id)
    batch_object = batch_manager.retrieve_batch(not_downloaded_batch_id)
    # Cancellation is asynchronous on the server side; poll with a deadline
    # so a stuck batch fails the test instead of hanging it forever.
    deadline = time.time() + 15 * 60
    while batch_object.status != "cancelled":
        if time.time() > deadline:
            raise TimeoutError(
                f"Batch {not_downloaded_batch_id} was not cancelled within 15 minutes "
                f"(last status: {batch_object.status})"
            )
        time.sleep(10)
        batch_object = batch_manager.retrieve_batch(not_downloaded_batch_id)

    # Second run should process the remaining batch, and resubmit the cancelled batch
    print("SECOND RUN")
    output2, _ = run_script(script, env=env)
    print(output2)

    # checks
    assert "1 out of 2 batches already downloaded." in output2
    assert "0 out of 1 remaining batches are already submitted." in output2
Loading