diff --git a/parea/client.py b/parea/client.py index f40e1d52..81242d17 100644 --- a/parea/client.py +++ b/parea/client.py @@ -352,6 +352,7 @@ def experiment( metadata: Optional[Dict[str, str]] = None, dataset_level_evals: Optional[List[Callable]] = None, n_workers: int = 10, + stop_on_error: bool = False, ): """ :param data: If your dataset is defined locally it should be an iterable of k/v @@ -363,6 +364,7 @@ def experiment( :param metadata: Optional metadata to attach to the experiment. :param dataset_level_evals: Optional list of functions to run on the dataset level. Each function should accept a list of EvaluatedLog objects and return a float or an EvaluationResult object :param n_workers: The number of workers to use for running the experiment. + :param stop_on_error: If True, the experiment will stop on the first exception. If False, the experiment will continue running the remaining samples. """ from parea import Experiment @@ -375,6 +377,7 @@ def experiment( metadata=metadata, dataset_level_evals=dataset_level_evals, n_workers=n_workers, + stop_on_error=stop_on_error, ) def _update_data_and_trace(self, data: Completion) -> Completion: diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index 952731f9..d93ccda4 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -93,6 +93,7 @@ async def experiment( n_trials: int = 1, dataset_level_evals: Optional[List[Callable]] = None, n_workers: int = 10, + stop_on_error: bool = False, ) -> ExperimentStatsSchema: """Creates an experiment and runs the function on the data iterator. param experiment_name: The name of the experiment. Used to organize experiments within a project. @@ -107,6 +108,7 @@ async def experiment( param dataset_level_evals: A list of functions to run on the dataset level. Each function should accept a list of EvaluatedLogs and return a float or a EvaluationResult. If a float is returned, the name of the function will be used as the name of the evaluation. param n_workers: The number of workers to use for running the experiment. + param stop_on_error: If True, the experiment will stop running if an exception is raised. """ if isinstance(data, (str, int)): print(f"Fetching test collection: {data}") @@ -152,10 +154,13 @@ def limit_concurrency_sync(sample): await coro pbar.update(1) except Exception as e: - print(f"\nExperiment stopped due to an error: {str(e)}\n") status = ExperimentStatus.FAILED - for task in tasks: - task.cancel() + if stop_on_error: + print(f"\nExperiment stopped due to an error: {str(e)}\n") + for task in tasks: + task.cancel() + else: + pbar.update(1) except asyncio.CancelledError: pass @@ -220,6 +225,7 @@ class Experiment: n_workers: int = field(default=10) # The number of times to run the experiment on the same data. n_trials: int = field(default=1) + stop_on_error: bool = field(default=False) def __attrs_post_init__(self): global _experiments @@ -253,7 +259,18 @@ def run(self, run_name: Optional[str] = None) -> None: experiment_schema: ExperimentSchema = self.p.create_experiment(CreateExperimentRequest(name=self.experiment_name, run_name=self.run_name, metadata=self.metadata)) self.experiment_uuid = experiment_schema.uuid self.experiment_stats = asyncio.run( - experiment(self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers) + experiment( + self.experiment_name, + self.run_name, + self.data, + self.func, + self.p, + self.experiment_uuid, + self.n_trials, + self.dataset_level_evals, + self.n_workers, + self.stop_on_error, + ) ) except Exception as e: import traceback @@ -277,7 +294,7 @@ async def arun(self, run_name: Optional[str] = None) -> None: ) self.experiment_uuid = experiment_schema.uuid self.experiment_stats = await experiment( - self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers + self.experiment_name, self.run_name, self.data, self.func, self.p, self.experiment_uuid, self.n_trials, self.dataset_level_evals, self.n_workers, self.stop_on_error ) except Exception as e: import traceback diff --git a/pyproject.toml b/pyproject.toml index 8b606f7b..547a782c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "parea-ai" packages = [{ include = "parea" }] -version = "0.2.175" +version = "0.2.176" description = "Parea python sdk" readme = "README.md" authors = ["joel-parea-ai "]