Revert "Add metadata dict + cache verification"
RyanMarten authored Dec 16, 2024
1 parent 0470c6e · commit fecf33f
Showing 1 changed file with 6 additions and 70 deletions.
@@ -6,8 +6,7 @@
import resource
from abc import ABC, abstractmethod
from math import ceil
from pathlib import Path
from typing import Optional, List
from typing import Optional

import aiofiles
import pyarrow
@@ -77,51 +76,6 @@ def run(
"""
pass

def _verify_existing_request_files(self, working_dir: str, dataset: Optional[Dataset]) -> List[int]:
"""
Verify integrity of the cache (each request file has associated metadata, and the number of rows is correct),
and return the indices of request files that need to be regenerated (so that no work is repeated).
Args:
working_dir (str): Working directory where cache files are expected to be (requests.jsonl, metadata.json)
dataset (Optional[Dataset]): The dataset that we want to create requests from
Returns:
List[int]: Indices of missing files
"""

if self.batch_size is not None and dataset is not None:
expected_num_files = ceil(len(dataset) / self.batch_size)
else:
expected_num_files = 1

try:
incomplete_files = []
for i in range(expected_num_files):
req_f = os.path.join(working_dir, f"requests_{i}.jsonl")
meta_f = os.path.join(working_dir, f"metadata_{i}.json")
if not os.path.exists(req_f) or not os.path.exists(meta_f):
incomplete_files.append(i)
else:
with open(req_f, "r") as f:
data = f.read()
num_jobs = len(data.splitlines())

with open(meta_f, "r") as f:
metadata = json.load(f)

expected_num_jobs = metadata["num_jobs"]
if num_jobs != expected_num_jobs:
incomplete_files.append(i)

logger.info(f"Cache missing {len(incomplete_files)} complete request files - regenerating missing ones.")
return incomplete_files

except:
logger.info("Cache verification failed for unexpected reasons - regenerating all request files.")
incomplete_files = list(range(expected_num_files))
return incomplete_files

def create_request_files(
self,
dataset: Optional[Dataset],
@@ -142,9 +96,7 @@ def create_request_files(
request_files = glob.glob(f"{working_dir}/requests_*.jsonl")

# By default use existing requests in working_dir
incomplete_files = self._verify_existing_request_files(working_dir, dataset)

if len(incomplete_files) == 0:
if len(request_files) > 0:
logger.info(f"Using cached requests. {CACHE_MSG}")
# count existing jobs in file and print first job
with open(request_files[0], "r") as f:
@@ -168,40 +120,31 @@
request_file = f"{working_dir}/requests_0.jsonl"
request_files = [request_file]

metadata_file = f"{working_dir}/metadata_0.json"
metadata_files = [metadata_file]

if dataset is None:
with open(request_file, "w") as f:
generic_request = prompt_formatter.create_generic_request(dict(), 0)
f.write(json.dumps(generic_request.model_dump(), default=str) + "\n")

metadata_dict = {"num_jobs": 1}
with open(metadata_file, "w") as f:
f.write(json.dumps(metadata_dict, indent=4) + "\n")
return request_files

if self.batch_size:
num_batches = ceil(len(dataset) / self.batch_size)
request_files = [f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)]
metadata_files = [f"{working_dir}/metadata_{i}.json" for i in range(num_batches)]

async def create_all_request_files():
tasks = [
self.acreate_request_file(
dataset,
prompt_formatter,
request_files[i],
metadata_files[i],
start_idx=i * self.batch_size,
)
for i in range(num_batches) if i in incomplete_files
for i in range(num_batches)
]
await asyncio.gather(*tasks)

run_in_event_loop(create_all_request_files())
else:
run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file, metadata_file))
run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file))

return request_files

@@ -211,9 +154,8 @@ async def acreate_request_file(
dataset: Dataset,
prompt_formatter: PromptFormatter,
request_file: str,
metadata_file: str,
start_idx: int = 0,
) -> None:
) -> str:
if self.batch_size is not None:
end_idx = min(start_idx + self.batch_size, len(dataset))
dataset = dataset.select(range(start_idx, end_idx))
Expand All @@ -227,13 +169,7 @@ async def acreate_request_file(
# Get the generic request from the map function
request = prompt_formatter.create_generic_request(dataset_row, dataset_row_idx)
await f.write(json.dumps(request.model_dump(), default=str) + "\n")

num_requests = end_idx - start_idx
metadata_dict = {"num_jobs": num_requests}
async with aiofiles.open(metadata_file, "w") as f:
await f.write(json.dumps(metadata_dict, indent=4) + "\n")

logger.info(f"Wrote {num_requests} requests to {request_file}.")
logger.info(f"Wrote {end_idx - start_idx} requests to {request_file}.")

def attempt_loading_cached_dataset(
self, working_dir: str, parse_func_hash: str
