From 533eb6b1d1b919af9f98c7b280757f196998f47e Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 23 Jan 2024 17:21:29 -0500 Subject: [PATCH 1/6] feat: optimize experiments for async --- parea/experiment/experiment.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index 010aa935..2a5c1249 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -9,6 +9,7 @@ from attrs import define, field from dotenv import load_dotenv from tqdm import tqdm +from tqdm.asyncio import tqdm_asyncio from parea.client import Parea from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID @@ -50,7 +51,7 @@ def async_wrapper(fn, **kwargs): return asyncio.run(fn(**kwargs)) -def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema: +async def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema: """Creates an experiment and runs the function on the data iterator.""" load_dotenv() @@ -62,11 +63,20 @@ def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentSta experiment_uuid = experiment_schema.uuid os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid - for data_input in tqdm(data): - if inspect.iscoroutinefunction(func): - asyncio.run(func(**data_input)) - else: + sem = asyncio.Semaphore(10) + + async def limit_concurrency(data_input): + async with sem: + return await func(**data_input) + + if inspect.iscoroutinefunction(func): + tasks = [limit_concurrency(data_input) for data_input in data] + for result in tqdm_asyncio(tasks): + await result + else: + for data_input in tqdm(data): func(**data_input) + time.sleep(5) # wait for any evaluation to finish which is executed in the background experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid) stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) @@ -90,4 +100,4 @@ def __attrs_post_init__(self): _experiments.append(self) def run(self): - self.experiment_stats = experiment(self.name, self.data, self.func) + self.experiment_stats = asyncio.run(experiment(self.name, self.data, self.func)) From f4a185b51fed0551b9f92aa2d7ff0bd1b3fab676 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Wed, 24 Jan 2024 12:59:39 -0500 Subject: [PATCH 2/6] feat: add progress bar for eval funcs which are executed in the background --- parea/experiment/experiment.py | 11 +++++++++-- parea/utils/trace_utils.py | 6 ++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index 2a5c1249..d67307f2 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -4,7 +4,6 @@ import inspect import json import os -import time from attrs import define, field from dotenv import load_dotenv @@ -14,6 +13,7 @@ from parea.client import Parea from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, TraceStatsSchema +from parea.utils.trace_utils import thread_ids_running_evals def calculate_avg_as_string(values: List[float]) -> str: @@ -77,7 +77,14 @@ async def limit_concurrency(data_input): for data_input in tqdm(data): func(**data_input) - time.sleep(5) # wait for any evaluation to finish which is executed in the background + total_evals = len(thread_ids_running_evals.get()) + with tqdm(total=total_evals, dynamic_ncols=True) as pbar: + while thread_ids_running_evals.get(): + pbar.set_description(f"Waiting for evaluations to finish") + pbar.update(total_evals - len(thread_ids_running_evals.get())) + total_evals = len(thread_ids_running_evals.get()) + await asyncio.sleep(0.5) + experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid) stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}\n\n") diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py index 17193f43..2ca63e37 100644 --- a/parea/utils/trace_utils.py +++ b/parea/utils/trace_utils.py @@ -25,6 +25,9 @@ # A dictionary to hold trace data for each trace trace_data = contextvars.ContextVar("trace_data", default={}) +# Context variable to maintain running evals in thread +thread_ids_running_evals = contextvars.ContextVar("thread_ids_running_evals", default=[]) + def merge(old, new): if isinstance(old, dict) and isinstance(new, dict): @@ -191,6 +194,7 @@ def logger_all_possible(trace_id: str): def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None): data = trace_data.get()[trace_id] + thread_ids_running_evals.get().append(trace_id) try: if eval_funcs and data.status == "success": if access_output_of_func: @@ -215,6 +219,8 @@ def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, a data.output = output_old except Exception as e: logger.exception(f"Error occurred in when trying to evaluate output, {e}", exc_info=e) + finally: + thread_ids_running_evals.get().remove(trace_id) parea_logger.default_log(data=data) From 8e80c83303ed7c9e6d59d6f7ab15a70358d88e09 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Wed, 24 Jan 2024 13:32:30 -0500 Subject: [PATCH 3/6] feat: add retry for bad gateway errors --- parea/api_client.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/parea/api_client.py b/parea/api_client.py index c26ddb9b..82cd5445 100644 --- a/parea/api_client.py +++ b/parea/api_client.py @@ -1,7 +1,39 @@ -from typing import Any, Optional +import asyncio +import time +from functools import wraps +from typing import Any, Optional, Callable import httpx +MAX_RETRIES = 5 +BACKOFF_FACTOR = 0.5 + + +def retry_on_502(func: Callable[..., Any]) -> Callable[..., Any]: + if asyncio.iscoroutinefunction(func): + @wraps(func) + async def wrapper(*args, **kwargs): + for retry in range(MAX_RETRIES): + try: + return await func(*args, **kwargs) + except httpx.HTTPStatusError as e: + if e.response.status_code != 502 or retry == MAX_RETRIES - 1: + raise + await asyncio.sleep(BACKOFF_FACTOR * (2 ** retry)) + return wrapper + else: + @wraps(func) + def wrapper(*args, **kwargs): + for retry in range(MAX_RETRIES): + try: + return func(*args, **kwargs) + except httpx.HTTPStatusError as e: + if e.response.status_code != 502 or retry == MAX_RETRIES - 1: + raise + time.sleep(BACKOFF_FACTOR * (2 ** retry)) + return wrapper + + class HTTPClient: _instance = None @@ -18,6 +50,7 @@ def __new__(cls, *args, **kwargs): def set_api_key(self, api_key: str): self.api_key = api_key + @retry_on_502 def request( self, method: str, @@ -34,6 +67,7 @@ def request( response.raise_for_status() return response + @retry_on_502 async def request_async( self, method: str, From 475b64b5f80cd867f63411dcdf5643d80ccf3cf8 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Wed, 24 Jan 2024 17:03:52 -0500 Subject: [PATCH 4/6] feat: parallelize sync functions in experiments --- parea/experiment/experiment.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index d67307f2..34b3f832 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -1,3 +1,5 @@ +from concurrent.futures import ThreadPoolExecutor +from functools import partial from typing import Callable, Dict, Iterable, List import asyncio @@ -63,7 +65,9 @@ async def experiment(name: str, data: Iterable[Dict], func: Callable) -> Experim experiment_uuid = experiment_schema.uuid os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid - sem = asyncio.Semaphore(10) + max_parallel_calls = 10 + executor = ThreadPoolExecutor(max_workers=max_parallel_calls) + sem = asyncio.Semaphore(max_parallel_calls) async def limit_concurrency(data_input): async with sem: @@ -74,8 +78,10 @@ async def limit_concurrency(data_input): for result in tqdm_asyncio(tasks): await result else: - for data_input in tqdm(data): - func(**data_input) + loop = asyncio.get_event_loop() + tasks = [loop.run_in_executor(executor, partial(func, **data_input)) for data_input in data] + for future in tqdm_asyncio.as_completed(tasks): + await future total_evals = len(thread_ids_running_evals.get()) with tqdm(total=total_evals, dynamic_ncols=True) as pbar: From 295589ae0fe48902ffd472d13f970487d14f4734 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Wed, 24 Jan 2024 18:16:01 -0500 Subject: [PATCH 5/6] feat: revert sync changes --- parea/experiment/experiment.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index 34b3f832..a5a3f498 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -1,5 +1,3 @@ -from concurrent.futures import ThreadPoolExecutor -from functools import partial from typing import Callable, Dict, Iterable, List import asyncio @@ -66,7 +64,6 @@ async def experiment(name: str, data: Iterable[Dict], func: Callable) -> Experim os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid max_parallel_calls = 10 - executor = ThreadPoolExecutor(max_workers=max_parallel_calls) sem = asyncio.Semaphore(max_parallel_calls) async def limit_concurrency(data_input): @@ -78,10 +75,8 @@ async def limit_concurrency(data_input): for result in tqdm_asyncio(tasks): await result else: - loop = asyncio.get_event_loop() - tasks = [loop.run_in_executor(executor, partial(func, **data_input)) for data_input in data] - for future in tqdm_asyncio.as_completed(tasks): - await future + for data_input in tqdm(data): + func(**data_input) total_evals = len(thread_ids_running_evals.get()) with tqdm(total=total_evals, dynamic_ncols=True) as pbar: From feda0f282b4de18634e1664cde95e9a4deb0cc43 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Thu, 25 Jan 2024 12:16:29 -0500 Subject: [PATCH 6/6] feat: bump version --- parea/api_client.py | 12 ++++++++---- pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/parea/api_client.py b/parea/api_client.py index 82cd5445..cefed5d4 100644 --- a/parea/api_client.py +++ b/parea/api_client.py @@ -1,7 +1,8 @@ +from typing import Any, Callable, Optional + import asyncio import time from functools import wraps -from typing import Any, Optional, Callable import httpx @@ -11,6 +12,7 @@ def retry_on_502(func: Callable[..., Any]) -> Callable[..., Any]: if asyncio.iscoroutinefunction(func): + @wraps(func) async def wrapper(*args, **kwargs): for retry in range(MAX_RETRIES): @@ -19,9 +21,11 @@ async def wrapper(*args, **kwargs): except httpx.HTTPStatusError as e: if e.response.status_code != 502 or retry == MAX_RETRIES - 1: raise - await asyncio.sleep(BACKOFF_FACTOR * (2 ** retry)) + await asyncio.sleep(BACKOFF_FACTOR * (2**retry)) + return wrapper else: + @wraps(func) def wrapper(*args, **kwargs): for retry in range(MAX_RETRIES): @@ -30,9 +34,9 @@ def wrapper(*args, **kwargs): except httpx.HTTPStatusError as e: if e.response.status_code != 502 or retry == MAX_RETRIES - 1: raise - time.sleep(BACKOFF_FACTOR * (2 ** retry)) - return wrapper + time.sleep(BACKOFF_FACTOR * (2**retry)) + return wrapper class HTTPClient: diff --git a/pyproject.toml b/pyproject.toml index d3d700c5..402f9fe6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "parea-ai" packages = [{ include = "parea" }] -version = "0.2.34" +version = "0.2.35" description = "Parea python sdk" readme = "README.md" authors = ["joel-parea-ai "]