Merge pull request #348 from parea-ai/PAI-583-experiment-in-typescript-easy-auto-evals-py-sdk

sdk code updates
jalexanderII committed Jan 29, 2024
2 parents 8518848 + 7453cef commit f737aa5
Showing 5 changed files with 115 additions and 99 deletions.
17 changes: 16 additions & 1 deletion parea/client.py
@@ -1,8 +1,9 @@
from typing import Callable
from typing import Callable, Dict

import asyncio
import os
import time
from collections.abc import Iterable

from attrs import asdict, define, field
from cattrs import structure
@@ -11,6 +12,7 @@
from parea.api_client import HTTPClient
from parea.cache import InMemoryCache, RedisCache
from parea.cache.cache import Cache
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
from parea.helpers import gen_trace_id
from parea.parea_logger import parea_logger
from parea.schemas.models import (
@@ -60,13 +62,17 @@ def completion(self, data: Completion) -> CompletionResponse:
data.inference_id = inference_id
data.parent_trace_id = parent_trace_id or inference_id

if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
data.experiment_uuid = experiment_uuid

r = self._client.request(
"POST",
COMPLETION_ENDPOINT,
data=asdict(data),
)
if parent_trace_id:
trace_data.get()[parent_trace_id].children.append(inference_id)
trace_data.get()[parent_trace_id].experiment_uuid = experiment_uuid
logger_record_log(parent_trace_id)
return structure(r.json(), CompletionResponse)

@@ -76,13 +82,17 @@ async def acompletion(self, data: Completion) -> CompletionResponse:
data.inference_id = inference_id
data.parent_trace_id = parent_trace_id or inference_id

if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
data.experiment_uuid = experiment_uuid

r = await self._client.request_async(
"POST",
COMPLETION_ENDPOINT,
data=asdict(data),
)
if parent_trace_id:
trace_data.get()[parent_trace_id].children.append(inference_id)
trace_data.get()[parent_trace_id].experiment_uuid = experiment_uuid
logger_record_log(parent_trace_id)
return structure(r.json(), CompletionResponse)
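Both completion() and acompletion() now read the active experiment from the process environment. Below is a minimal sketch of that pattern, not the SDK itself: the literal value of PAREA_OS_ENV_EXPERIMENT_UUID is an assumption (the real constant lives in parea/constants.py), and FakeCompletion is a hypothetical stand-in for the Completion model.

import os

PAREA_OS_ENV_EXPERIMENT_UUID = "PAREA_EXPERIMENT_UUID"  # assumed key; see parea/constants.py

class FakeCompletion:  # hypothetical stand-in for parea.schemas.models.Completion
    experiment_uuid = None

def tag_with_experiment(data: FakeCompletion) -> FakeCompletion:
    # mirrors the walrus check in completion()/acompletion(): tag the request
    # only when a running experiment has exported its UUID
    if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
        data.experiment_uuid = experiment_uuid
    return data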

@@ -162,6 +172,11 @@ async def afinish_experiment(self, experiment_uuid: str) -> ExperimentSchema:
)
return structure(r.json(), ExperimentStatsSchema)

def experiment(self, name: str, data: Iterable[dict], func: Callable):
from parea import Experiment

return Experiment(name=name, data=data, func=func, p=self)


_initialized_parea_wrapper = False

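The new Parea.experiment() helper wires an Experiment to the client that created it. A usage sketch consistent with the signatures in this diff; the API-key lookup and the greet() task function are placeholders:

import os
from parea import Parea

p = Parea(api_key=os.environ["PAREA_API_KEY"])

def greet(name: str) -> str:  # placeholder task function
    return f"Hello, {name}!"

# each dict in data is passed to func as keyword arguments (cf. async_wrapper below)
exp = p.experiment(name="greeting-demo", data=[{"name": "Ada"}], func=greet)
exp.run()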
31 changes: 13 additions & 18 deletions parea/experiment/experiment.py
@@ -1,37 +1,37 @@
from typing import Callable, Dict, Iterable, List
from typing import Callable, Dict, List

import asyncio
import inspect
import json
import os
from collections.abc import Iterable

from attrs import define, field
from dotenv import load_dotenv
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from parea.client import Parea
from parea import Parea
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, TraceStatsSchema
from parea.utils.trace_utils import thread_ids_running_evals


def calculate_avg_as_string(values: List[float]) -> str:
def calculate_avg_as_string(values: list[float]) -> str:
if not values:
return "N/A"
values = [x for x in values if x is not None]
avg = sum(values) / len(values)
return f"{avg:.2f}"


def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> Dict[str, str]:
trace_stats: List[TraceStatsSchema] = experiment_stats.parent_trace_stats
def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> dict[str, str]:
trace_stats: list[TraceStatsSchema] = experiment_stats.parent_trace_stats
latency_values = [trace_stat.latency for trace_stat in trace_stats]
input_tokens_values = [trace_stat.input_tokens for trace_stat in trace_stats]
output_tokens_values = [trace_stat.output_tokens for trace_stat in trace_stats]
total_tokens_values = [trace_stat.total_tokens for trace_stat in trace_stats]
cost_values = [trace_stat.cost for trace_stat in trace_stats]
score_name_to_values: Dict[str, List[float]] = {}
score_name_to_values: dict[str, list[float]] = {}
for trace_stat in trace_stats:
if trace_stat.scores:
for score in trace_stat.scores:
@@ -51,14 +51,8 @@ def async_wrapper(fn, **kwargs):
return asyncio.run(fn(**kwargs))


async def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema:
async def experiment(name: str, data: Iterable[dict], func: Callable, p: Parea) -> ExperimentStatsSchema:
"""Creates an experiment and runs the function on the data iterator."""
load_dotenv()

if not (parea_api_key := os.getenv("PAREA_API_KEY")):
raise ValueError("Please set the PAREA_API_KEY environment variable")
p = Parea(api_key=parea_api_key)

experiment_schema: ExperimentSchema = p.create_experiment(CreateExperimentRequest(name=name))
experiment_uuid = experiment_schema.uuid
os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid
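This is the handshake that makes the client-side tagging in parea/client.py work: register the experiment with the backend, then export its UUID so every nested completion()/acompletion() in the process tags its logs. A condensed sketch using only names shown in this diff:

import os
from parea import Parea
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
from parea.schemas.models import CreateExperimentRequest

p = Parea(api_key=os.environ["PAREA_API_KEY"])
# 1) create the experiment server-side, 2) export its UUID process-wide
experiment_schema = p.create_experiment(CreateExperimentRequest(name="my-experiment"))
os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_schema.uuid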
@@ -98,14 +92,15 @@ async def limit_concurrency(data_input):

@define
class Experiment:
name: str = field(init=True)
data: Iterable[Dict] = field(init=True)
func: Callable = field(init=True)
name: str = field()
data: Iterable[dict] = field()
func: Callable = field()
experiment_stats: ExperimentStatsSchema = field(init=False, default=None)
p: Parea = field(default=None)

def __attrs_post_init__(self):
global _experiments
_experiments.append(self)

def run(self):
self.experiment_stats = asyncio.run(experiment(self.name, self.data, self.func))
self.experiment_stats = asyncio.run(experiment(self.name, self.data, self.func, self.p))
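Experiment instances self-register via __attrs_post_init__, which is what lets a runner later iterate every experiment defined in a module. A minimal sketch of that pattern; Registered is a hypothetical stand-in for Experiment:

from attrs import define, field

_experiments = []  # module-level registry, as in parea/experiment/experiment.py

@define
class Registered:  # hypothetical stand-in for Experiment
    name: str = field()

    def __attrs_post_init__(self):
        # every new instance adds itself to the registry
        _experiments.append(self)

Registered(name="a")
Registered(name="b")
assert [e.name for e in _experiments] == ["a", "b"]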
15 changes: 12 additions & 3 deletions parea/schemas/models.py
@@ -1,4 +1,6 @@
from typing import Any, List, Optional
from typing import Any, Optional

from enum import Enum

from attrs import define, field, validators

@@ -22,6 +24,7 @@ class Completion:
log_omit_inputs: bool = False
log_omit_outputs: bool = False
log_omit: bool = False
experiment_uuid: Optional[str] = None
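Because the new field defaults to None, existing Completion construction sites are untouched; client.py simply assigns it after the fact. A back-compat sketch with a hypothetical stand-in class:

from typing import Optional
from attrs import define

@define
class CompletionLike:  # hypothetical stand-in for Completion
    log_omit: bool = False
    experiment_uuid: Optional[str] = None  # the new field; opt-in by default

c = CompletionLike()            # older call sites keep working unchanged
c.experiment_uuid = "uuid-123"  # assigned later, e.g. by Parea.completion()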


@define
@@ -152,9 +155,15 @@ class TraceStatsSchema:
output_tokens: Optional[int] = 0
total_tokens: Optional[int] = 0
cost: Optional[float] = None
scores: Optional[List[EvaluationScoreSchema]] = field(factory=list)
scores: Optional[list[EvaluationScoreSchema]] = field(factory=list)


@define
class ExperimentStatsSchema:
parent_trace_stats: List[TraceStatsSchema] = field(factory=list)
parent_trace_stats: list[TraceStatsSchema] = field(factory=list)


class UpdateTraceScenario(str, Enum):
RESULT: str = "result"
ERROR: str = "error"
CHAIN: str = "chain"
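Since UpdateTraceScenario subclasses str, its members compare equal to plain strings and serialize without special handling. A quick sketch of that behavior (members written without annotations for brevity):

from enum import Enum

class UpdateTraceScenarioSketch(str, Enum):  # mirrors UpdateTraceScenario above
    RESULT = "result"
    ERROR = "error"
    CHAIN = "chain"

assert UpdateTraceScenarioSketch.RESULT == "result"      # str comparison works
assert UpdateTraceScenarioSketch.CHAIN.value == "chain"  # .value is the raw string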
149 changes: 73 additions & 76 deletions parea/utils/trace_utils.py
@@ -13,7 +13,7 @@
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
from parea.helpers import gen_trace_id, to_date_and_time_string
from parea.parea_logger import parea_logger
from parea.schemas.models import NamedEvaluationScore, TraceLog
from parea.schemas.models import NamedEvaluationScore, TraceLog, UpdateLog, UpdateTraceScenario
from parea.utils.universal_encoder import json_dumps

logger = logging.getLogger()
@@ -29,6 +29,11 @@
thread_ids_running_evals = contextvars.ContextVar("thread_ids_running_evals", default=[])


def log_in_thread(target_func: Callable, data: dict[str, Any]):
logging_thread = threading.Thread(target=target_func, kwargs=data)
logging_thread.start()


def merge(old, new):
if isinstance(old, dict) and isinstance(new, dict):
return dict(ChainMap(new, old))
@@ -37,6 +42,26 @@ def merge(old, new):
return new


def check_multiple_return_values(func) -> bool:
specs = inspect.getfullargspec(func)
try:
r = specs.annotations.get("return")
if r and r.__origin__ == tuple:
return len(r.__args__) > 1
except Exception:
return False


def make_output(result, islist) -> str:
if islist:
json_list = [json_dumps(r) for r in result]
return json_dumps(json_list)
elif isinstance(result, str):
return result
else:
return json_dumps(result)


def get_current_trace_id() -> str:
stack = trace_context.get()
if stack:
@@ -53,6 +78,19 @@ def trace_insert(data: dict[str, Any]):
current_trace_data.__setattr__(key, merge(existing_value, new_value) if existing_value else new_value)


def fill_trace_data(trace_id: str, data: dict[str, Any], scenario: UpdateTraceScenario):
if scenario == UpdateTraceScenario.RESULT:
trace_data.get()[trace_id].output = make_output(data["result"], data.get("output_as_list"))
trace_data.get()[trace_id].status = "success"
trace_data.get()[trace_id].evaluation_metric_names = data.get("eval_funcs_names")
elif scenario == UpdateTraceScenario.ERROR:
trace_data.get()[trace_id].error = data["error"]
trace_data.get()[trace_id].status = "error"
elif scenario == UpdateTraceScenario.CHAIN:
trace_data.get()[trace_id].parent_trace_id = data["parent_trace_id"]
trace_data.get()[data["parent_trace_id"]].children.append(trace_id)


def trace(
name: Optional[str] = None,
tags: Optional[list[str]] = None,
@@ -94,8 +132,7 @@ def init_trace(func_name, args, kwargs, func) -> tuple[str, float]:
)
parent_trace_id = trace_context.get()[-2] if len(trace_context.get()) > 1 else None
if parent_trace_id:
trace_data.get()[trace_id].parent_trace_id = parent_trace_id
trace_data.get()[parent_trace_id].children.append(trace_id)
fill_trace_data(trace_id, {"parent_trace_id": parent_trace_id}, UpdateTraceScenario.CHAIN)

return trace_id, start_time
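init_trace now delegates the parent/child bookkeeping to fill_trace_data with the CHAIN scenario instead of mutating trace_data inline. A self-contained sketch of that dispatch, with plain dicts standing in for the TraceLog entries (the plain strings here compare equal to the real enum members, since UpdateTraceScenario subclasses str):

# plain-dict stand-ins for the TraceLog entries kept in trace_data
traces = {"parent": {"children": []}, "child": {}}

def fill(trace_id: str, data: dict, scenario: str):
    # same shape as fill_trace_data above, minus the contextvar indirection
    if scenario == "error":
        traces[trace_id]["error"] = data["error"]
        traces[trace_id]["status"] = "error"
    elif scenario == "chain":
        traces[trace_id]["parent_trace_id"] = data["parent_trace_id"]
        traces[data["parent_trace_id"]]["children"].append(trace_id)

fill("child", {"parent_trace_id": "parent"}, "chain")
assert traces["parent"]["children"] == ["child"]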

@@ -113,13 +150,10 @@ async def async_wrapper(*args, **kwargs):
output_as_list = check_multiple_return_values(func)
try:
result = await func(*args, **kwargs)
trace_data.get()[trace_id].output = make_output(result, output_as_list)
trace_data.get()[trace_id].status = "success"
trace_data.get()[trace_id].evaluation_metric_names = eval_funcs_names
fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
except Exception as e:
logger.exception(f"Error occurred in function {func.__name__}, {e}")
trace_data.get()[trace_id].error = str(e)
trace_data.get()[trace_id].status = "error"
fill_trace_data(trace_id, {"error": str(e)}, UpdateTraceScenario.ERROR)
raise e
finally:
cleanup_trace(trace_id, start_time)
@@ -131,13 +165,10 @@ def wrapper(*args, **kwargs):
output_as_list = check_multiple_return_values(func)
try:
result = func(*args, **kwargs)
trace_data.get()[trace_id].output = make_output(result, output_as_list)
trace_data.get()[trace_id].status = "success"
trace_data.get()[trace_id].evaluation_metric_names = eval_funcs_names
fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
except Exception as e:
logger.exception(f"Error occurred in function {func.__name__}, {e}")
trace_data.get()[trace_id].error = str(e)
trace_data.get()[trace_id].status = "error"
fill_trace_data(trace_id, {"error": str(e)}, UpdateTraceScenario.ERROR)
raise e
finally:
cleanup_trace(trace_id, start_time)
@@ -156,77 +187,43 @@ def wrapper(*args, **kwargs):
return decorator


def check_multiple_return_values(func) -> bool:
specs = inspect.getfullargspec(func)
try:
r = specs.annotations.get("return", None)
if r and r.__origin__ == tuple:
return len(r.__args__) > 1
except Exception:
return False
def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
data = trace_data.get()[trace_id]
parea_logger.default_log(data=data)

if eval_funcs and data.status == "success":
thread_ids_running_evals.get().append(trace_id)
if access_output_of_func:
try:
output = json.loads(data.output)
output = access_output_of_func(output)
output_for_eval_metrics = json_dumps(output)
except Exception as e:
logger.exception(f"Error accessing output of func with output: {data.output}. Error: {e}", exc_info=e)
return
else:
output_for_eval_metrics = data.output

def make_output(result, islist) -> str:
if islist:
json_list = [json_dumps(r) for r in result]
return json_dumps(json_list)
elif isinstance(result, str):
return result
else:
return json_dumps(result)
data.output = output_for_eval_metrics
scores = []

for func in eval_funcs:
try:
scores.append(NamedEvaluationScore(name=func.__name__, score=func(data)))
except Exception as e:
logger.exception(f"Error occurred calling evaluation function '{func.__name__}', {e}", exc_info=e)

def logger_record_log(trace_id: str):
logging_thread = threading.Thread(
target=parea_logger.record_log,
kwargs={"data": trace_data.get()[trace_id]},
)
logging_thread.start()
parea_logger.update_log(data=UpdateLog(trace_id=trace_id, field_name_to_value_map={"scores": scores}))
thread_ids_running_evals.get().remove(trace_id)


def logger_all_possible(trace_id: str):
logging_thread = threading.Thread(
target=parea_logger.default_log,
kwargs={"data": trace_data.get()[trace_id]},
)
logging_thread.start()
def logger_record_log(trace_id: str):
log_in_thread(parea_logger.record_log, {"data": trace_data.get()[trace_id]})


def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
data = trace_data.get()[trace_id]
thread_ids_running_evals.get().append(trace_id)
try:
if eval_funcs and data.status == "success":
if access_output_of_func:
output = json.loads(data.output)
output = access_output_of_func(output)
output_for_eval_metrics = json_dumps(output)
else:
output_for_eval_metrics = data.output

data.output_for_eval_metrics = output_for_eval_metrics
output_old = data.output
data.output = data.output_for_eval_metrics
data.scores = []

for func in eval_funcs:
try:
score = func(data)
data.scores.append(NamedEvaluationScore(name=func.__name__, score=score))
except Exception as e:
logger.exception(f"Error occurred calling evaluation function '{func.__name__}', {e}", exc_info=e)

data.output = output_old
except Exception as e:
logger.exception(f"Error occurred in when trying to evaluate output, {e}", exc_info=e)
finally:
thread_ids_running_evals.get().remove(trace_id)
parea_logger.default_log(data=data)
def logger_all_possible(trace_id: str):
log_in_thread(parea_logger.default_log, {"data": trace_data.get()[trace_id]})


def thread_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
logging_thread = threading.Thread(
target=call_eval_funcs_then_log,
kwargs={"trace_id": trace_id, "eval_funcs": eval_funcs, "access_output_of_func": access_output_of_func},
)
logging_thread.start()
log_in_thread(call_eval_funcs_then_log, {"trace_id": trace_id, "eval_funcs": eval_funcs, "access_output_of_func": access_output_of_func})
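logger_record_log, logger_all_possible, and thread_eval_funcs_then_log now share one fire-and-forget helper instead of three hand-rolled threads. A self-contained sketch of log_in_thread's behavior; fake_log is a placeholder for the real parea_logger callables:

import threading
from typing import Any, Callable

def log_in_thread(target_func: Callable, data: dict[str, Any]):
    # run the logging callable on a background thread so the traced
    # function can return without waiting on logging I/O
    logging_thread = threading.Thread(target=target_func, kwargs=data)
    logging_thread.start()

def fake_log(data: dict):  # placeholder for, e.g., parea_logger.record_log
    print("logged:", data)

log_in_thread(fake_log, {"data": {"trace_id": "t1", "status": "success"}})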
