diff --git a/parea/api_client.py b/parea/api_client.py
index c26ddb9b..cefed5d4 100644
--- a/parea/api_client.py
+++ b/parea/api_client.py
@@ -1,7 +1,43 @@
-from typing import Any, Optional
+from typing import Any, Callable, Optional
+
+import asyncio
+import time
+from functools import wraps
 
 import httpx
 
+MAX_RETRIES = 5
+BACKOFF_FACTOR = 0.5
+
+
+def retry_on_502(func: Callable[..., Any]) -> Callable[..., Any]:
+    if asyncio.iscoroutinefunction(func):
+
+        @wraps(func)
+        async def wrapper(*args, **kwargs):
+            for retry in range(MAX_RETRIES):
+                try:
+                    return await func(*args, **kwargs)
+                except httpx.HTTPStatusError as e:
+                    if e.response.status_code != 502 or retry == MAX_RETRIES - 1:
+                        raise
+                    await asyncio.sleep(BACKOFF_FACTOR * (2**retry))
+
+        return wrapper
+    else:
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            for retry in range(MAX_RETRIES):
+                try:
+                    return func(*args, **kwargs)
+                except httpx.HTTPStatusError as e:
+                    if e.response.status_code != 502 or retry == MAX_RETRIES - 1:
+                        raise
+                    time.sleep(BACKOFF_FACTOR * (2**retry))
+
+        return wrapper
+
 
 class HTTPClient:
     _instance = None
@@ -18,6 +54,7 @@ def __new__(cls, *args, **kwargs):
     def set_api_key(self, api_key: str):
         self.api_key = api_key
 
+    @retry_on_502
     def request(
         self,
         method: str,
@@ -34,6 +71,7 @@ def request(
         response.raise_for_status()
         return response
 
+    @retry_on_502
     async def request_async(
         self,
         method: str,
diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py
index 010aa935..a5a3f498 100644
--- a/parea/experiment/experiment.py
+++ b/parea/experiment/experiment.py
@@ -4,15 +4,16 @@
 import inspect
 import json
 import os
-import time
 
 from attrs import define, field
 from dotenv import load_dotenv
 from tqdm import tqdm
+from tqdm.asyncio import tqdm_asyncio
 
 from parea.client import Parea
 from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
 from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, TraceStatsSchema
+from parea.utils.trace_utils import thread_ids_running_evals
 
 
 def calculate_avg_as_string(values: List[float]) -> str:
@@ -50,7 +51,7 @@ def async_wrapper(fn, **kwargs):
     return asyncio.run(fn(**kwargs))
 
 
-def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema:
+async def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema:
     """Creates an experiment and runs the function on the data iterator."""
     load_dotenv()
 
@@ -62,12 +63,29 @@ def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentSta
     experiment_uuid = experiment_schema.uuid
     os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid
 
-    for data_input in tqdm(data):
-        if inspect.iscoroutinefunction(func):
-            asyncio.run(func(**data_input))
-        else:
+    max_parallel_calls = 10
+    sem = asyncio.Semaphore(max_parallel_calls)
+
+    async def limit_concurrency(data_input):
+        async with sem:
+            return await func(**data_input)
+
+    if inspect.iscoroutinefunction(func):
+        tasks = [limit_concurrency(data_input) for data_input in data]
+        for result in tqdm_asyncio(tasks):
+            await result
+    else:
+        for data_input in tqdm(data):
             func(**data_input)
-    time.sleep(5)  # wait for any evaluation to finish which is executed in the background
+
+    total_evals = len(thread_ids_running_evals.get())
+    with tqdm(total=total_evals, dynamic_ncols=True) as pbar:
+        while thread_ids_running_evals.get():
+            pbar.set_description(f"Waiting for evaluations to finish")
+            pbar.update(total_evals - len(thread_ids_running_evals.get()))
+            total_evals = len(thread_ids_running_evals.get())
+            await asyncio.sleep(0.5)
+
     experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid)
     stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats)
     print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}\n\n")
@@ -90,4 +108,4 @@ def __attrs_post_init__(self):
         _experiments.append(self)
 
     def run(self):
-        self.experiment_stats = experiment(self.name, self.data, self.func)
+        self.experiment_stats = asyncio.run(experiment(self.name, self.data, self.func))
diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py
index 17193f43..2ca63e37 100644
--- a/parea/utils/trace_utils.py
+++ b/parea/utils/trace_utils.py
@@ -25,6 +25,9 @@
 # A dictionary to hold trace data for each trace
 trace_data = contextvars.ContextVar("trace_data", default={})
 
+# Context variable to maintain running evals in thread
+thread_ids_running_evals = contextvars.ContextVar("thread_ids_running_evals", default=[])
+
 
 def merge(old, new):
     if isinstance(old, dict) and isinstance(new, dict):
@@ -191,6 +194,7 @@ def logger_all_possible(trace_id: str):
 
 def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
     data = trace_data.get()[trace_id]
+    thread_ids_running_evals.get().append(trace_id)
     try:
         if eval_funcs and data.status == "success":
             if access_output_of_func:
@@ -215,6 +219,8 @@ def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, a
             data.output = output_old
     except Exception as e:
         logger.exception(f"Error occurred in when trying to evaluate output, {e}", exc_info=e)
+    finally:
+        thread_ids_running_evals.get().remove(trace_id)
 
     parea_logger.default_log(data=data)
diff --git a/pyproject.toml b/pyproject.toml
index d3d700c5..402f9fe6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.34"
+version = "0.2.35"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai "]
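
Note on the retry behavior added in parea/api_client.py: retry_on_502 retries only when the server answers 502 Bad Gateway, sleeping BACKOFF_FACTOR * 2**retry seconds between attempts and re-raising any other HTTP error immediately. A minimal usage sketch, assuming a hypothetical fetch_stats helper and URL that are not part of this diff:

    import httpx

    from parea.api_client import retry_on_502


    # Hypothetical helper, for illustration only; not part of this diff.
    @retry_on_502
    def fetch_stats(url: str) -> httpx.Response:
        response = httpx.get(url)
        # raise_for_status() turns a 502 reply into httpx.HTTPStatusError,
        # which the decorator catches and retries.
        response.raise_for_status()
        return response


    # With MAX_RETRIES = 5 and BACKOFF_FACTOR = 0.5, a persistent 502 is
    # retried after sleeps of 0.5s, 1s, 2s, and 4s; the fifth failure
    # re-raises the HTTPStatusError to the caller.
    fetch_stats("https://example.com/api/stats")

The decorator checks asyncio.iscoroutinefunction at decoration time, so the same annotation works on HTTPClient.request and HTTPClient.request_async without blocking the event loop in the async case.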