Skip to content

Commit

Permalink
Merge pull request #935 from parea-ai/PAI-1264-finished-experiment-ev…
Browse files Browse the repository at this point in the history
…en-on-trace-fail-flag-to-fail-on-first

Pai 1264 finished experiment even on trace fail flag to fail on first
  • Loading branch information
joschkabraun committed Jun 14, 2024
2 parents e3f0c63 + 84a9654 commit b0741ad
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
3 changes: 3 additions & 0 deletions parea/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def experiment(
metadata: Optional[Dict[str, str]] = None,
dataset_level_evals: Optional[List[Callable]] = None,
n_workers: int = 10,
stop_on_error: bool = False,
):
"""
:param data: If your dataset is defined locally it should be an iterable of k/v
Expand All @@ -363,6 +364,7 @@ def experiment(
:param metadata: Optional metadata to attach to the experiment.
:param dataset_level_evals: Optional list of functions to run on the dataset level. Each function should accept a list of EvaluatedLog objects and return a float or an EvaluationResult object
:param n_workers: The number of workers to use for running the experiment.
:param stop_on_error: If True, the experiment will stop on the first exception. If False, the experiment will continue running the remaining samples.
"""
from parea import Experiment

Expand All @@ -375,6 +377,7 @@ def experiment(
metadata=metadata,
dataset_level_evals=dataset_level_evals,
n_workers=n_workers,
stop_on_error=stop_on_error,
)

def _update_data_and_trace(self, data: Completion) -> Completion:
Expand Down
27 changes: 22 additions & 5 deletions parea/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ async def experiment(
n_trials: int = 1,
dataset_level_evals: Optional[List[Callable]] = None,
n_workers: int = 10,
stop_on_error: bool = False,
) -> ExperimentStatsSchema:
"""Creates an experiment and runs the function on the data iterator.
param experiment_name: The name of the experiment. Used to organize experiments within a project.
Expand All @@ -107,6 +108,7 @@ async def experiment(
param dataset_level_evals: A list of functions to run on the dataset level. Each function should accept a list of EvaluatedLogs and return a float or a
EvaluationResult. If a float is returned, the name of the function will be used as the name of the evaluation.
param n_workers: The number of workers to use for running the experiment.
param stop_on_error: If True, the experiment will stop on the first exception. If False, the experiment will continue running the remaining samples.
"""
if isinstance(data, (str, int)):
print(f"Fetching test collection: {data}")
Expand Down Expand Up @@ -152,10 +154,13 @@ def limit_concurrency_sync(sample):
await coro
pbar.update(1)
except Exception as e:
print(f"\nExperiment stopped due to an error: {str(e)}\n")
status = ExperimentStatus.FAILED
for task in tasks:
task.cancel()
if stop_on_error:
print(f"\nExperiment stopped due to an error: {str(e)}\n")
for task in tasks:
task.cancel()
else:
pbar.update(1)
except asyncio.CancelledError:
pass

Expand Down Expand Up @@ -220,6 +225,7 @@ class Experiment:
n_workers: int = field(default=10)
# The number of times to run the experiment on the same data.
n_trials: int = field(default=1)
stop_on_error: bool = field(default=False)

def __attrs_post_init__(self):
global _experiments
Expand Down Expand Up @@ -253,7 +259,18 @@ def run(self, run_name: Optional[str] = None) -> None:
experiment_schema: ExperimentSchema = self.p.create_experiment(CreateExperimentRequest(name=self.experiment_name, run_name=self.run_name, metadata=self.metadata))
self.experiment_uuid = experiment_schema.uuid
self.experiment_stats = asyncio.run(
experiment(self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers)
experiment(
self.experiment_name,
self.run_name,
self.data,
self.func,
self.p,
self.experiment_uuid,
self.n_trials,
self.dataset_level_evals,
self.n_workers,
self.stop_on_error,
)
)
except Exception as e:
import traceback
Expand All @@ -277,7 +294,7 @@ async def arun(self, run_name: Optional[str] = None) -> None:
)
self.experiment_uuid = experiment_schema.uuid
self.experiment_stats = await experiment(
self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers
self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers, self.stop_on_error
)
except Exception as e:
import traceback
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.175"
version = "0.2.176"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down

0 comments on commit b0741ad

Please sign in to comment.