Commit
Merge pull request #342 from parea-ai/feat-async-experiments
feat: optimize experiments for async
joschkabraun committed Jan 25, 2024
2 parents 2f6b04b + feda0f2 commit 3f38f32
Showing 4 changed files with 72 additions and 10 deletions.
40 changes: 39 additions & 1 deletion parea/api_client.py
@@ -1,7 +1,43 @@
from typing import Any, Optional
from typing import Any, Callable, Optional

import asyncio
import time
from functools import wraps

import httpx

MAX_RETRIES = 5
BACKOFF_FACTOR = 0.5


def retry_on_502(func: Callable[..., Any]) -> Callable[..., Any]:
if asyncio.iscoroutinefunction(func):

@wraps(func)
async def wrapper(*args, **kwargs):
for retry in range(MAX_RETRIES):
try:
return await func(*args, **kwargs)
except httpx.HTTPStatusError as e:
if e.response.status_code != 502 or retry == MAX_RETRIES - 1:
raise
await asyncio.sleep(BACKOFF_FACTOR * (2**retry))

return wrapper
else:

@wraps(func)
def wrapper(*args, **kwargs):
for retry in range(MAX_RETRIES):
try:
return func(*args, **kwargs)
except httpx.HTTPStatusError as e:
if e.response.status_code != 502 or retry == MAX_RETRIES - 1:
raise
time.sleep(BACKOFF_FACTOR * (2**retry))

return wrapper


class HTTPClient:
_instance = None
@@ -18,6 +54,7 @@ def __new__(cls, *args, **kwargs):
def set_api_key(self, api_key: str):
self.api_key = api_key

@retry_on_502
def request(
self,
method: str,
@@ -34,6 +71,7 @@ def request(
response.raise_for_status()
return response

@retry_on_502
async def request_async(
self,
method: str,
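The new decorator retries both the sync and async request paths when the server answers with a 502, backing off exponentially (0.5 s, 1 s, 2 s, 4 s) and re-raising after the fifth attempt. Below is a minimal sketch of how it can be exercised outside the SDK, assuming the parea package from this commit is installed; the flaky_call function, the attempt counter, and the hand-built httpx request/response are illustrative, not part of the diff.

import httpx

from parea.api_client import retry_on_502  # module path shown in this diff

attempts = {"count": 0}

@retry_on_502
def flaky_call() -> str:
    # Fails with a 502 on the first two attempts, then succeeds; the decorator
    # should absorb the two failures and return "ok" on the third attempt.
    attempts["count"] += 1
    if attempts["count"] < 3:
        request = httpx.Request("GET", "https://example.invalid/health")
        response = httpx.Response(502, request=request)
        raise httpx.HTTPStatusError("Bad Gateway", request=request, response=response)
    return "ok"

print(flaky_call())  # prints "ok" after roughly 1.5 s of backoff (0.5 s + 1.0 s)

A non-502 status, or a 502 on the final attempt, re-raises the original HTTPStatusError unchanged.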
34 changes: 26 additions & 8 deletions parea/experiment/experiment.py
@@ -4,15 +4,16 @@
import inspect
import json
import os
import time

from attrs import define, field
from dotenv import load_dotenv
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from parea.client import Parea
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, TraceStatsSchema
from parea.utils.trace_utils import thread_ids_running_evals


def calculate_avg_as_string(values: List[float]) -> str:
@@ -50,7 +51,7 @@ def async_wrapper(fn, **kwargs):
return asyncio.run(fn(**kwargs))


def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema:
async def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema:
"""Creates an experiment and runs the function on the data iterator."""
load_dotenv()

@@ -62,12 +63,29 @@ def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema:
experiment_uuid = experiment_schema.uuid
os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid

for data_input in tqdm(data):
if inspect.iscoroutinefunction(func):
asyncio.run(func(**data_input))
else:
max_parallel_calls = 10
sem = asyncio.Semaphore(max_parallel_calls)

async def limit_concurrency(data_input):
async with sem:
return await func(**data_input)

if inspect.iscoroutinefunction(func):
tasks = [limit_concurrency(data_input) for data_input in data]
for result in tqdm_asyncio(tasks):
await result
else:
for data_input in tqdm(data):
func(**data_input)
time.sleep(5) # wait for any evaluation to finish which is executed in the background

total_evals = len(thread_ids_running_evals.get())
with tqdm(total=total_evals, dynamic_ncols=True) as pbar:
while thread_ids_running_evals.get():
pbar.set_description(f"Waiting for evaluations to finish")
pbar.update(total_evals - len(thread_ids_running_evals.get()))
total_evals = len(thread_ids_running_evals.get())
await asyncio.sleep(0.5)

experiment_stats: ExperimentStatsSchema = p.finish_experiment(experiment_uuid)
stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats)
print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}\n\n")
@@ -90,4 +108,4 @@ def __attrs_post_init__(self):
_experiments.append(self)

def run(self):
self.experiment_stats = experiment(self.name, self.data, self.func)
self.experiment_stats = asyncio.run(experiment(self.name, self.data, self.func))
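For async target functions the experiment now schedules one coroutine per data point but caps concurrency with an asyncio.Semaphore (10 by default). The snippet below is a self-contained sketch of that bounded-concurrency pattern; the dummy worker and the use of plain asyncio.gather instead of the tqdm progress wrapper are assumptions made for illustration.

import asyncio

async def worker(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for an LLM call or other I/O-bound work
    return i * i

async def run_all(inputs, max_parallel_calls: int = 10):
    sem = asyncio.Semaphore(max_parallel_calls)

    async def limit_concurrency(i: int) -> int:
        # At most max_parallel_calls workers hold the semaphore at once;
        # the remaining coroutines wait here until a slot frees up.
        async with sem:
            return await worker(i)

    return await asyncio.gather(*(limit_concurrency(i) for i in inputs))

print(asyncio.run(run_all(range(25))))  # 25 tasks, never more than 10 in flight

If a progress bar is wanted, tqdm_asyncio.gather from tqdm.asyncio can stand in for asyncio.gather here.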
6 changes: 6 additions & 0 deletions parea/utils/trace_utils.py
@@ -25,6 +25,9 @@
# A dictionary to hold trace data for each trace
trace_data = contextvars.ContextVar("trace_data", default={})

# Context variable to maintain running evals in thread
thread_ids_running_evals = contextvars.ContextVar("thread_ids_running_evals", default=[])


def merge(old, new):
if isinstance(old, dict) and isinstance(new, dict):
@@ -191,6 +194,7 @@ def logger_all_possible(trace_id: str):

def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
data = trace_data.get()[trace_id]
thread_ids_running_evals.get().append(trace_id)
try:
if eval_funcs and data.status == "success":
if access_output_of_func:
@@ -215,6 +219,8 @@ def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
data.output = output_old
except Exception as e:
logger.exception(f"Error occurred in when trying to evaluate output, {e}", exc_info=e)
finally:
thread_ids_running_evals.get().remove(trace_id)
parea_logger.default_log(data=data)


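thread_ids_running_evals lets the experiment loop see how many background evaluations are still in flight: each eval registers its trace id before running and removes it in a finally block, and the experiment polls the list until it drains. Below is a rough, self-contained sketch of that register/poll pattern; the names running and evaluate, the thread setup, and the sleep durations are invented for illustration and are not parea APIs.

import contextvars
import threading
import time

# Stand-in for thread_ids_running_evals. Because the list is only mutated
# (never re-set), every thread sees the same shared default object.
running = contextvars.ContextVar("running", default=[])

def evaluate(trace_id: str) -> None:
    running.get().append(trace_id)       # register before the eval starts
    try:
        time.sleep(0.2)                  # stand-in for the eval functions
    finally:
        running.get().remove(trace_id)   # always unregister, even on error

threads = [threading.Thread(target=evaluate, args=(f"trace-{i}",)) for i in range(3)]
for t in threads:
    t.start()
time.sleep(0.05)                         # give the background threads a moment to register
while running.get():                     # the experiment loop polls this list
    time.sleep(0.05)
print("all evaluations finished")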
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.34"
version = "0.2.35"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]