Merge pull request #934 from parea-ai/PAI-1264-finished-experiment-even-on-trace-fail-flag-to-fail-on-first

PAI-1264: finished experiment even on trace fail; flag to fail on first
joschkabraun authored Jun 14, 2024
2 parents 94848cf + cecbc3e commit e3f0c63
Showing 4 changed files with 33 additions and 15 deletions.
@@ -36,9 +36,10 @@ def func(lang: str, framework: str) -> str:
 
 if __name__ == "__main__":
     p.experiment(
+        name="Hello World Example", # this is the name of the experiment
         data="Hello World Example", # this is the name of your Dataset in Parea (Dataset page)
         func=func,
-    ).run(name="hello-world-example")
+    ).run()
 
     # Or use a dataset using its ID instead of the name
     # p.experiment(
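The example change moves the experiment name from .run() into p.experiment(name=...), so an experiment is named once at creation. A minimal usage sketch of the new call shape, assuming a configured client and a Dataset named "Hello World Example" in the Parea UI; when .run() receives no name, the SDK presumably falls back to a generated run name (note the gen_random_name helper imported in experiment.py below):

# Sketch only: mirrors the new-style call from the diff above.
import os

from parea import Parea

p = Parea(api_key=os.getenv("PAREA_API_KEY"))  # assumes the usual env-var setup

def func(lang: str, framework: str) -> str:
    return f"Hello {framework} written in {lang}!"  # placeholder function under test

if __name__ == "__main__":
    p.experiment(
        name="Hello World Example",  # experiment name, set at creation time
        data="Hello World Example",  # Dataset name in Parea (Dataset page)
        func=func,
    ).run()  # run name omitted; presumably auto-generated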
parea/experiment/experiment.py (25 changes: 19 additions & 6 deletions)
@@ -19,7 +19,7 @@
 from parea.experiment.dvc import save_results_to_dvc_if_init
 from parea.helpers import duplicate_dicts, gen_random_name, is_logging_disabled
 from parea.schemas import EvaluationResult
-from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, FinishExperimentRequestSchema
+from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, ExperimentStatus, FinishExperimentRequestSchema
 from parea.utils.trace_utils import thread_ids_running_evals, trace_data
 from parea.utils.universal_encoder import json_dumps
 
@@ -138,13 +138,26 @@ def limit_concurrency_sync(sample):
         return func(_parea_target_field=target, **sample_copy)
 
     if inspect.iscoroutinefunction(func):
-        tasks = [limit_concurrency(sample) for sample in data]
+        tasks = [asyncio.ensure_future(limit_concurrency(sample)) for sample in data]
     else:
         executor = ThreadPoolExecutor(max_workers=n_workers)
         loop = asyncio.get_event_loop()
-        tasks = [loop.run_in_executor(executor, partial(limit_concurrency_sync, sample)) for sample in data]
-    for _task in tqdm_asyncio.as_completed(tasks, total=len_test_cases):
-        await _task
+        tasks = [asyncio.ensure_future(loop.run_in_executor(executor, partial(limit_concurrency_sync, sample))) for sample in data]
+
+    status = ExperimentStatus.COMPLETED
+    with tqdm(total=len(tasks), desc="Running samples", unit="sample") as pbar:
+        try:
+            for coro in asyncio.as_completed(tasks):
+                try:
+                    await coro
+                    pbar.update(1)
+                except Exception as e:
+                    print(f"\nExperiment stopped due to an error: {str(e)}\n")
+                    status = ExperimentStatus.FAILED
+                    for task in tasks:
+                        task.cancel()
+        except asyncio.CancelledError:
+            pass
 
     await asyncio.sleep(0.2)
     total_evals = len(thread_ids_running_evals.get())
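This hunk is the heart of the PR. Wrapping every coroutine in asyncio.ensure_future yields Task handles that can be cancelled the moment one sample raises; on modern Python (3.8+) CancelledError derives from BaseException, so it bypasses the inner except Exception and is swallowed by the outer except asyncio.CancelledError once as_completed starts yielding the cancelled tasks. A minimal standalone sketch of that fail-fast pattern (illustrative names, not Parea code):

import asyncio

async def worker(i: int) -> int:
    # Illustrative workload: sample 2 fails, later samples get cancelled.
    await asyncio.sleep(i * 0.1)
    if i == 2:
        raise ValueError(f"sample {i} failed")
    return i

async def run_all() -> None:
    # ensure_future gives us Task handles we can cancel on first failure.
    tasks = [asyncio.ensure_future(worker(i)) for i in range(5)]
    try:
        for coro in asyncio.as_completed(tasks):
            try:
                print("done:", await coro)
            except Exception as exc:  # CancelledError is not caught here
                print("stopping early:", exc)
                for task in tasks:
                    task.cancel()
    except asyncio.CancelledError:
        pass  # raised once as_completed yields an already-cancelled task

asyncio.run(run_all())

Since status is set to FAILED on the error path but execution continues, finish_experiment below always reports a terminal state instead of leaving the experiment dangling, which is what the PR title promises.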
@@ -162,7 +175,7 @@ def limit_concurrency_sync(sample):
     else:
         dataset_level_eval_results = []
 
-    experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid, FinishExperimentRequestSchema(dataset_level_stats=dataset_level_eval_results))
+    experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid, FinishExperimentRequestSchema(status=status, dataset_level_stats=dataset_level_eval_results))
     stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats)
     if dataset_level_eval_results:
         stat_name_to_avg_std.update(**{eval_result.name: eval_result.score for eval_result in dataset_level_eval_results})
parea/schemas/models.py (15 changes: 8 additions & 7 deletions)
@@ -329,8 +329,16 @@ class CreateTestCaseCollection(CreateTestCases):
     column_names: List[str] = field(factory=list)
 
 
+class ExperimentStatus(str, Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+
+
 @define
 class FinishExperimentRequestSchema:
+    status: ExperimentStatus
     dataset_level_stats: Optional[List[EvaluationResult]] = field(factory=list)
 
 
@@ -343,13 +351,6 @@ class ListExperimentUUIDsFilters:
     experiment_uuids: Optional[List[str]] = None
 
 
-class ExperimentStatus(str, Enum):
-    PENDING = "pending"
-    RUNNING = "running"
-    COMPLETED = "completed"
-    FAILED = "failed"
-
-
 class StatisticOperation(str, Enum):
     MEAN = "mean"
     MEDIAN = "median"
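The enum definition moved above FinishExperimentRequestSchema (first hunk) so the new status field can reference it at class-definition time; the second hunk removes the old copy. Mixing str into the enum means members serialize as their plain values, which is what lets the finish-experiment request body carry "completed" or "failed". A small standard-library illustration (not Parea code):

import json
from enum import Enum

class ExperimentStatus(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"

# str-mixin enum members are str instances, so json.dumps emits the raw value.
print(json.dumps({"status": ExperimentStatus.FAILED}))  # {"status": "failed"}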
parea/utils/trace_integrations/instructor.py (5 changes: 4 additions & 1 deletion)
@@ -71,8 +71,11 @@ def __call__(
         for key in ["max_retries", "response_model", "validation_context", "mode", "args"]:
             if kwargs.get(key):
                 metadata[key] = kwargs[key]
+        trace_name = "instructor"
+        if "response_model" in kwargs and kwargs["response_model"] and hasattr(kwargs["response_model"], "__name__"):
+            trace_name = kwargs["response_model"].__name__
         return trace(
-            name="instructor",
+            name=trace_name,
             overwrite_trace_id=trace_id,
             overwrite_inputs=inputs,
             metadata=metadata,
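Net effect of the instructor change: traces are labeled with the response model's class name when one is supplied, falling back to the generic "instructor". Restated as a tiny self-contained sketch; the helper pick_trace_name and the UserDetail class are made up for illustration:

class UserDetail:  # stands in for a pydantic response model
    name: str
    age: int

def pick_trace_name(**kwargs) -> str:
    # Same rule as the hunk above: prefer the response model's __name__.
    response_model = kwargs.get("response_model")
    if response_model and hasattr(response_model, "__name__"):
        return response_model.__name__
    return "instructor"

print(pick_trace_name(response_model=UserDetail))  # -> UserDetail
print(pick_trace_name())                           # -> instructor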
