Skip to content

Commit

Permalink
Merge pull request #460 from parea-ai/fix-remove-data-from-sample
Browse files Browse the repository at this point in the history
fix: copy dicts in experiments
  • Loading branch information
joschkabraun committed Feb 15, 2024
2 parents ed51034 + 58eed0b commit 4e6f2e2
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions parea/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
from collections import defaultdict
from collections.abc import Iterable
from copy import deepcopy

from attrs import define, field
from tqdm import tqdm
Expand Down Expand Up @@ -84,15 +85,19 @@ async def experiment(name: str, data: Union[str, Iterable[dict]], func: Callable

async def limit_concurrency(sample):
async with sem:
return await func(_parea_target_field=sample.pop("target", None), **sample)
sample_copy = deepcopy(sample)
target = sample_copy.pop("target", None)
return await func(_parea_target_field=target, **sample_copy)

if inspect.iscoroutinefunction(func):
tasks = [limit_concurrency(sample) for sample in data]
for result in tqdm_asyncio(tasks, total=len_test_cases):
await result
else:
for sample in tqdm(data, total=len_test_cases):
func(_parea_target_field=sample.pop("target", None), **sample)
sample_copy = deepcopy(sample)
target = sample_copy.pop("target", None)
func(_parea_target_field=target, **sample_copy)

total_evals = len(thread_ids_running_evals.get())
with tqdm(total=total_evals, dynamic_ncols=True) as pbar:
Expand Down

0 comments on commit 4e6f2e2

Please sign in to comment.