Merge pull request #257 from GeorgiosSmyrnis/add_metadata
Add metadata dict + cache verification
RyanMarten authored Dec 16, 2024
2 parents a76ea42 + 59566c4 commit 0470c6e
Showing 1 changed file with 70 additions and 6 deletions.
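
In outline, the change pairs each requests_{i}.jsonl cache file with a metadata_{i}.json file recording how many requests that batch should contain; cache verification then compares the line count of each request file against that number and regenerates only the batches that fail. Below is a minimal sketch of that per-batch check, with a hypothetical helper name and cache directory; only the file naming and the "num_jobs" key are taken from the diff that follows.

import json
import os

# Hypothetical cache layout after this change (path and helper name are illustrative):
#   <working_dir>/requests_0.jsonl  - one JSON request per line
#   <working_dir>/metadata_0.json   - {"num_jobs": <expected line count>}

def batch_is_complete(working_dir: str, i: int) -> bool:
    """Return True if request file i exists and its line count matches its metadata."""
    req_f = os.path.join(working_dir, f"requests_{i}.jsonl")
    meta_f = os.path.join(working_dir, f"metadata_{i}.json")
    if not (os.path.exists(req_f) and os.path.exists(meta_f)):
        return False
    with open(req_f, "r") as f:
        num_jobs = len(f.read().splitlines())
    with open(meta_f, "r") as f:
        metadata = json.load(f)
    return num_jobs == metadata["num_jobs"]

The _verify_existing_request_files method added below applies this kind of check per batch and returns the failing indices, which create_request_files then uses to skip work that is already on disk.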
@@ -6,7 +6,8 @@
import resource
from abc import ABC, abstractmethod
from math import ceil
from typing import Optional
from pathlib import Path
from typing import Optional, List

import aiofiles
import pyarrow
@@ -76,6 +77,51 @@ def run(
"""
pass

    def _verify_existing_request_files(self, working_dir: str, dataset: Optional[Dataset]) -> List[int]:
        """
        Verify the integrity of the cache (each request file has associated metadata and the correct number of rows),
        and return the indices of request files that need to be regenerated (so that no work is repeated).

        Args:
            working_dir (str): Working directory where cache files are expected to be (requests_{i}.jsonl, metadata_{i}.json)
            dataset (Optional[Dataset]): The dataset that we want to create requests from

        Returns:
            List[int]: Indices of request files that are missing or incomplete
        """

        if self.batch_size is not None and dataset is not None:
            expected_num_files = ceil(len(dataset) / self.batch_size)
        else:
            expected_num_files = 1

        try:
            incomplete_files = []
            for i in range(expected_num_files):
                req_f = os.path.join(working_dir, f"requests_{i}.jsonl")
                meta_f = os.path.join(working_dir, f"metadata_{i}.json")
                if not os.path.exists(req_f) or not os.path.exists(meta_f):
                    incomplete_files.append(i)
                else:
                    with open(req_f, "r") as f:
                        data = f.read()
                    num_jobs = len(data.splitlines())

                    with open(meta_f, "r") as f:
                        metadata = json.load(f)

                    expected_num_jobs = metadata["num_jobs"]
                    if num_jobs != expected_num_jobs:
                        incomplete_files.append(i)

            logger.info(f"Cache missing {len(incomplete_files)} complete request files - regenerating missing ones.")
            return incomplete_files

        except Exception:
            logger.info("Cache verification failed for unexpected reasons - regenerating all request files.")
            incomplete_files = list(range(expected_num_files))
            return incomplete_files

    def create_request_files(
        self,
        dataset: Optional[Dataset],
@@ -96,7 +142,9 @@ def create_request_files(
        request_files = glob.glob(f"{working_dir}/requests_*.jsonl")

        # By default use existing requests in working_dir
        if len(request_files) > 0:
        incomplete_files = self._verify_existing_request_files(working_dir, dataset)

        if len(incomplete_files) == 0:
            logger.info(f"Using cached requests. {CACHE_MSG}")
            # count existing jobs in file and print first job
            with open(request_files[0], "r") as f:
@@ -120,31 +168,40 @@
        request_file = f"{working_dir}/requests_0.jsonl"
        request_files = [request_file]

        metadata_file = f"{working_dir}/metadata_0.json"
        metadata_files = [metadata_file]

        if dataset is None:
            with open(request_file, "w") as f:
                generic_request = prompt_formatter.create_generic_request(dict(), 0)
                f.write(json.dumps(generic_request.model_dump(), default=str) + "\n")

            metadata_dict = {"num_jobs": 1}
            with open(metadata_file, "w") as f:
                f.write(json.dumps(metadata_dict, indent=4) + "\n")
            return request_files

        if self.batch_size:
            num_batches = ceil(len(dataset) / self.batch_size)
            request_files = [f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches)]
            metadata_files = [f"{working_dir}/metadata_{i}.json" for i in range(num_batches)]

            async def create_all_request_files():
                tasks = [
                    self.acreate_request_file(
                        dataset,
                        prompt_formatter,
                        request_files[i],
                        metadata_files[i],
                        start_idx=i * self.batch_size,
                    )
                    for i in range(num_batches)
                    for i in range(num_batches) if i in incomplete_files
                ]
                await asyncio.gather(*tasks)

            run_in_event_loop(create_all_request_files())
        else:
            run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file))
            run_in_event_loop(self.acreate_request_file(dataset, prompt_formatter, request_file, metadata_file))

        return request_files

@@ -154,8 +211,9 @@ async def acreate_request_file(
        dataset: Dataset,
        prompt_formatter: PromptFormatter,
        request_file: str,
        metadata_file: str,
        start_idx: int = 0,
    ) -> str:
    ) -> None:
        if self.batch_size is not None:
            end_idx = min(start_idx + self.batch_size, len(dataset))
            dataset = dataset.select(range(start_idx, end_idx))
@@ -169,7 +227,13 @@ async def acreate_request_file(
                # Get the generic request from the map function
                request = prompt_formatter.create_generic_request(dataset_row, dataset_row_idx)
                await f.write(json.dumps(request.model_dump(), default=str) + "\n")
        logger.info(f"Wrote {end_idx - start_idx} requests to {request_file}.")

        num_requests = end_idx - start_idx
        metadata_dict = {"num_jobs": num_requests}
        async with aiofiles.open(metadata_file, "w") as f:
            await f.write(json.dumps(metadata_dict, indent=4) + "\n")

        logger.info(f"Wrote {num_requests} requests to {request_file}.")

    def attempt_loading_cached_dataset(
        self, working_dir: str, parse_func_hash: str
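As a rough usage sketch (all values and the temporary directory are illustrative, and batch_is_complete is the helper sketched above, standing in for the per-file check in _verify_existing_request_files): a batch whose requests_{i}.jsonl line count disagrees with the num_jobs recorded in metadata_{i}.json is flagged, so only that batch is rewritten on the next run.

import json
import os
import tempfile
from math import ceil

with tempfile.TemporaryDirectory() as working_dir:
    dataset_len, batch_size = 5, 3
    expected_num_files = ceil(dataset_len / batch_size)  # 2 batches

    # Batch 0 is complete; batch 1 simulates a write cut short (1 line written, 2 claimed).
    for i, (written, claimed) in enumerate([(3, 3), (1, 2)]):
        with open(os.path.join(working_dir, f"requests_{i}.jsonl"), "w") as f:
            f.writelines(json.dumps({"row": r}) + "\n" for r in range(written))
        with open(os.path.join(working_dir, f"metadata_{i}.json"), "w") as f:
            json.dump({"num_jobs": claimed}, f)

    incomplete = [i for i in range(expected_num_files) if not batch_is_complete(working_dir, i)]
    print(incomplete)  # -> [1]; only batch 1 would be regenerated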
