diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py
index ea9184e8..398bf0b5 100644
--- a/src/bespokelabs/curator/request_processor/base_request_processor.py
+++ b/src/bespokelabs/curator/request_processor/base_request_processor.py
@@ -6,8 +6,7 @@
 import resource
 from abc import ABC, abstractmethod
 from math import ceil
-from pathlib import Path
-from typing import Optional, List
+from typing import Optional
 
 import aiofiles
 import pyarrow
@@ -77,51 +76,6 @@ def run(
         """
         pass
 
-    def _verify_existing_request_files(self, working_dir: str, dataset: Optional[Dataset]) -> List[int]:
-        """
-        Verify integrity of the cache (each request file has associated metadata, and the number of rows is correct),
-        and return the indices of request files that need to be regenerated (so that no work is repeated).
-
-        Args:
-            working_dir (str): Working directory where cache files are expected to be (requests.jsonl, metadata.json)
-            dataset (Optional[Dataset]): The dataset that we want to create requests from
-
-        Returns:
-            List[int]: Indices of missing files
-        """
-
-        if self.batch_size is not None and dataset is not None:
-            expected_num_files = ceil(len(dataset) / self.batch_size)
-        else:
-            expected_num_files = 1
-
-        try:
-            incomplete_files = []
-            for i in range(expected_num_files):
-                req_f = os.path.join(working_dir, f"requests_{i}.jsonl")
-                meta_f = os.path.join(working_dir, f"metadata_{i}.json")
-                if not os.path.exists(req_f) or not os.path.exists(meta_f):
-                    incomplete_files.append(i)
-                else:
-                    with open(req_f, "r") as f:
-                        data = f.read()
-                        num_jobs = len(data.splitlines())
-
-                    with open(meta_f, "r") as f:
-                        metadata = json.load(f)
-
-                    expected_num_jobs = metadata["num_jobs"]
-                    if num_jobs != expected_num_jobs:
-                        incomplete_files.append(i)
-
-            logger.info(f"Cache missing {len(incomplete_files)} complete request files - regenerating missing ones.")
-            return incomplete_files
-
-        except:
-            logger.info("Cache verification failed for unexpected reasons - regenerating all request files.")
-            incomplete_files = list(range(expected_num_files))
-            return incomplete_files
-
     def create_request_files(
         self,
         dataset: Optional[Dataset],
@@ -142,9 +96,7 @@
         request_files = glob.glob(f"{working_dir}/requests_*.jsonl")
 
         # By default use existing requests in working_dir
-        incomplete_files = self._verify_existing_request_files(working_dir, dataset)
-
-        if len(incomplete_files) == 0:
+        if len(request_files) > 0:
            logger.info(f"Using cached requests. {CACHE_MSG}")
            # count existing jobs in file and print first job
            with open(request_files[0], "r") as f:
@@ -168,23 +120,15 @@
         request_file = f"{working_dir}/requests_0.jsonl"
         request_files = [request_file]
 
-        metadata_file = f"{working_dir}/metadata_0.json"
-        metadata_files = [metadata_file]
-
         if dataset is None:
             with open(request_file, "w") as f:
                 generic_request = prompt_formatter.create_generic_request(dict(), 0)
                 f.write(json.dumps(generic_request.model_dump(), default=str) + "\n")
-
-            metadata_dict = {"num_jobs": 1}
-            with open(metadata_file, "w") as f:
-                f.write(json.dumps(metadata_dict, indent=4) + "\n")
             return request_files
 
         if self.batch_size:
             num_batches = ceil(len(dataset) / self.batch_size)
             request_files = [f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)]
-            metadata_files = [f"{working_dir}/metadata_{i}.json" for i in range(num_batches)]
 
             async def create_all_request_files():
                 tasks = [
@@ -192,16 +136,15 @@
                     self.acreate_request_file(
                         dataset,
                         prompt_formatter,
                         request_files[i],
-                        metadata_files[i],
                         start_idx=i * self.batch_size,
                     )
-                    for i in range(num_batches) if i in incomplete_files
+                    for i in range(num_batches)
                 ]
                 await asyncio.gather(*tasks)
 
             run_in_event_loop(create_all_request_files())
         else:
-            run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file, metadata_file))
+            run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file))
 
         return request_files
@@ -211,9 +154,8 @@ async def acreate_request_file(
         self,
         dataset: Dataset,
         prompt_formatter: PromptFormatter,
         request_file: str,
-        metadata_file: str,
         start_idx: int = 0,
-    ) -> None:
+    ) -> str:
         if self.batch_size is not None:
             end_idx = min(start_idx + self.batch_size, len(dataset))
             dataset = dataset.select(range(start_idx, end_idx))
@@ -227,13 +169,7 @@
                 # Get the generic request from the map function
                 request = prompt_formatter.create_generic_request(dataset_row, dataset_row_idx)
                 await f.write(json.dumps(request.model_dump(), default=str) + "\n")
-
-        num_requests = end_idx - start_idx
-        metadata_dict = {"num_jobs": num_requests}
-        async with aiofiles.open(metadata_file, "w") as f:
-            await f.write(json.dumps(metadata_dict, indent=4) + "\n")
-
-        logger.info(f"Wrote {num_requests} requests to {request_file}.")
+        logger.info(f"Wrote {end_idx - start_idx} requests to {request_file}.")
 
     def attempt_loading_cached_dataset(
         self, working_dir: str, parse_func_hash