diff --git a/parea/client.py b/parea/client.py
index 2eb592a0..a62c76a1 100644
--- a/parea/client.py
+++ b/parea/client.py
@@ -1,8 +1,9 @@
-from typing import Callable
+from typing import Callable, Dict

 import asyncio
 import os
 import time
+from collections.abc import Iterable

 from attrs import asdict, define, field
 from cattrs import structure
@@ -11,6 +12,7 @@
 from parea.api_client import HTTPClient
 from parea.cache import InMemoryCache, RedisCache
 from parea.cache.cache import Cache
+from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
 from parea.helpers import gen_trace_id
 from parea.parea_logger import parea_logger
 from parea.schemas.models import (
@@ -60,6 +62,9 @@ def completion(self, data: Completion) -> CompletionResponse:
         data.inference_id = inference_id
         data.parent_trace_id = parent_trace_id or inference_id

+        if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
+            data.experiment_uuid = experiment_uuid
+
         r = self._client.request(
             "POST",
             COMPLETION_ENDPOINT,
@@ -67,6 +72,7 @@ def completion(self, data: Completion) -> CompletionResponse:
         )
         if parent_trace_id:
             trace_data.get()[parent_trace_id].children.append(inference_id)
+            trace_data.get()[parent_trace_id].experiment_uuid = experiment_uuid
             logger_record_log(parent_trace_id)
         return structure(r.json(), CompletionResponse)
@@ -76,6 +82,9 @@ async def acompletion(self, data: Completion) -> CompletionResponse:
         data.inference_id = inference_id
         data.parent_trace_id = parent_trace_id or inference_id

+        if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None):
+            data.experiment_uuid = experiment_uuid
+
         r = await self._client.request_async(
             "POST",
             COMPLETION_ENDPOINT,
@@ -83,6 +92,7 @@ async def acompletion(self, data: Completion) -> CompletionResponse:
         )
         if parent_trace_id:
             trace_data.get()[parent_trace_id].children.append(inference_id)
+            trace_data.get()[parent_trace_id].experiment_uuid = experiment_uuid
             logger_record_log(parent_trace_id)
         return structure(r.json(), CompletionResponse)
@@ -162,6 +172,11 @@ async def afinish_experiment(self, experiment_uuid: str) -> ExperimentSchema:
         )
         return structure(r.json(), ExperimentStatsSchema)

+    def experiment(self, name: str, data: Iterable[dict], func: Callable):
+        from parea import Experiment
+
+        return Experiment(name=name, data=data, func=func, p=self)
+


 _initialized_parea_wrapper = False
diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py
index a5a3f498..a589779d 100644
--- a/parea/experiment/experiment.py
+++ b/parea/experiment/experiment.py
@@ -1,22 +1,22 @@
-from typing import Callable, Dict, Iterable, List
+from typing import Callable, Dict, List

 import asyncio
 import inspect
 import json
 import os
+from collections.abc import Iterable

 from attrs import define, field
-from dotenv import load_dotenv
 from tqdm import tqdm
 from tqdm.asyncio import tqdm_asyncio

-from parea.client import Parea
+from parea import Parea
 from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
 from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, TraceStatsSchema
 from parea.utils.trace_utils import thread_ids_running_evals


-def calculate_avg_as_string(values: List[float]) -> str:
+def calculate_avg_as_string(values: list[float]) -> str:
     if not values:
         return "N/A"
     values = [x for x in values if x is not None]
@@ -24,14 +24,14 @@ def calculate_avg_as_string(values: List[float]) -> str:
     return f"{avg:.2f}"


-def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> Dict[str, str]:
-    trace_stats: List[TraceStatsSchema] = experiment_stats.parent_trace_stats
+def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> dict[str, str]:
+    trace_stats: list[TraceStatsSchema] = experiment_stats.parent_trace_stats
     latency_values = [trace_stat.latency for trace_stat in trace_stats]
     input_tokens_values = [trace_stat.input_tokens for trace_stat in trace_stats]
     output_tokens_values = [trace_stat.output_tokens for trace_stat in trace_stats]
     total_tokens_values = [trace_stat.total_tokens for trace_stat in trace_stats]
     cost_values = [trace_stat.cost for trace_stat in trace_stats]
-    score_name_to_values: Dict[str, List[float]] = {}
+    score_name_to_values: dict[str, list[float]] = {}
     for trace_stat in trace_stats:
         if trace_stat.scores:
             for score in trace_stat.scores:
@@ -51,14 +51,8 @@ def async_wrapper(fn, **kwargs):
     return asyncio.run(fn(**kwargs))


-async def experiment(name: str, data: Iterable[Dict], func: Callable) -> ExperimentStatsSchema:
+async def experiment(name: str, data: Iterable[dict], func: Callable, p: Parea) -> ExperimentStatsSchema:
     """Creates an experiment and runs the function on the data iterator."""
-    load_dotenv()
-
-    if not (parea_api_key := os.getenv("PAREA_API_KEY")):
-        raise ValueError("Please set the PAREA_API_KEY environment variable")
-    p = Parea(api_key=parea_api_key)
-
     experiment_schema: ExperimentSchema = p.create_experiment(CreateExperimentRequest(name=name))
     experiment_uuid = experiment_schema.uuid
     os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid
@@ -98,14 +92,15 @@ async def limit_concurrency(data_input):

 @define
 class Experiment:
-    name: str = field(init=True)
-    data: Iterable[Dict] = field(init=True)
-    func: Callable = field(init=True)
+    name: str = field()
+    data: Iterable[dict] = field()
+    func: Callable = field()
     experiment_stats: ExperimentStatsSchema = field(init=False, default=None)
+    p: Parea = field(default=None)

     def __attrs_post_init__(self):
         global _experiments
         _experiments.append(self)

     def run(self):
-        self.experiment_stats = asyncio.run(experiment(self.name, self.data, self.func))
+        self.experiment_stats = asyncio.run(experiment(self.name, self.data, self.func, self.p))
diff --git a/parea/schemas/models.py b/parea/schemas/models.py
index 52b9d4c8..5ae75e8a 100644
--- a/parea/schemas/models.py
+++ b/parea/schemas/models.py
@@ -1,4 +1,6 @@
-from typing import Any, List, Optional
+from typing import Any, Optional
+
+from enum import Enum

 from attrs import define, field, validators
@@ -22,6 +24,7 @@ class Completion:
     log_omit_inputs: bool = False
     log_omit_outputs: bool = False
     log_omit: bool = False
+    experiment_uuid: Optional[str] = None


 @define
@@ -152,9 +155,15 @@ class TraceStatsSchema:
     output_tokens: Optional[int] = 0
     total_tokens: Optional[int] = 0
     cost: Optional[float] = None
-    scores: Optional[List[EvaluationScoreSchema]] = field(factory=list)
+    scores: Optional[list[EvaluationScoreSchema]] = field(factory=list)


 @define
 class ExperimentStatsSchema:
-    parent_trace_stats: List[TraceStatsSchema] = field(factory=list)
+    parent_trace_stats: list[TraceStatsSchema] = field(factory=list)
+
+
+class UpdateTraceScenario(str, Enum):
+    RESULT: str = "result"
+    ERROR: str = "error"
+    CHAIN: str = "chain"
diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py
index 2ca63e37..71ad914b 100644
--- a/parea/utils/trace_utils.py
+++ b/parea/utils/trace_utils.py
@@ -13,7 +13,7 @@
 from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
 from parea.helpers import gen_trace_id, to_date_and_time_string
 from parea.parea_logger import parea_logger
-from parea.schemas.models import NamedEvaluationScore, TraceLog
+from parea.schemas.models import NamedEvaluationScore, TraceLog, UpdateLog, UpdateTraceScenario
 from parea.utils.universal_encoder import json_dumps

 logger = logging.getLogger()
@@ -29,6 +29,11 @@
 thread_ids_running_evals = contextvars.ContextVar("thread_ids_running_evals", default=[])


+def log_in_thread(target_func: Callable, data: dict[str, Any]):
+    logging_thread = threading.Thread(target=target_func, kwargs=data)
+    logging_thread.start()
+
+
 def merge(old, new):
     if isinstance(old, dict) and isinstance(new, dict):
         return dict(ChainMap(new, old))
@@ -37,6 +42,26 @@ def merge(old, new):
     return new


+def check_multiple_return_values(func) -> bool:
+    specs = inspect.getfullargspec(func)
+    try:
+        r = specs.annotations.get("return")
+        if r and r.__origin__ == tuple:
+            return len(r.__args__) > 1
+    except Exception:
+        return False
+
+
+def make_output(result, islist) -> str:
+    if islist:
+        json_list = [json_dumps(r) for r in result]
+        return json_dumps(json_list)
+    elif isinstance(result, str):
+        return result
+    else:
+        return json_dumps(result)
+
+
 def get_current_trace_id() -> str:
     stack = trace_context.get()
     if stack:
@@ -53,6 +78,19 @@ def trace_insert(data: dict[str, Any]):
         current_trace_data.__setattr__(key, merge(existing_value, new_value) if existing_value else new_value)


+def fill_trace_data(trace_id: str, data: dict[str, Any], scenario: UpdateTraceScenario):
+    if scenario == UpdateTraceScenario.RESULT:
+        trace_data.get()[trace_id].output = make_output(data["result"], data.get("output_as_list"))
+        trace_data.get()[trace_id].status = "success"
+        trace_data.get()[trace_id].evaluation_metric_names = data.get("eval_funcs_names")
+    elif scenario == UpdateTraceScenario.ERROR:
+        trace_data.get()[trace_id].error = data["error"]
+        trace_data.get()[trace_id].status = "error"
+    elif scenario == UpdateTraceScenario.CHAIN:
+        trace_data.get()[trace_id].parent_trace_id = data["parent_trace_id"]
+        trace_data.get()[data["parent_trace_id"]].children.append(trace_id)
+
+
 def trace(
     name: Optional[str] = None,
     tags: Optional[list[str]] = None,
@@ -94,8 +132,7 @@ def init_trace(func_name, args, kwargs, func) -> tuple[str, float]:
         )
         parent_trace_id = trace_context.get()[-2] if len(trace_context.get()) > 1 else None
         if parent_trace_id:
-            trace_data.get()[trace_id].parent_trace_id = parent_trace_id
-            trace_data.get()[parent_trace_id].children.append(trace_id)
+            fill_trace_data(trace_id, {"parent_trace_id": parent_trace_id}, UpdateTraceScenario.CHAIN)

         return trace_id, start_time
@@ -113,13 +150,10 @@ async def async_wrapper(*args, **kwargs):
             output_as_list = check_multiple_return_values(func)
             try:
                 result = await func(*args, **kwargs)
-                trace_data.get()[trace_id].output = make_output(result, output_as_list)
-                trace_data.get()[trace_id].status = "success"
-                trace_data.get()[trace_id].evaluation_metric_names = eval_funcs_names
+                fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
             except Exception as e:
                 logger.exception(f"Error occurred in function {func.__name__}, {e}")
-                trace_data.get()[trace_id].error = str(e)
-                trace_data.get()[trace_id].status = "error"
+                fill_trace_data(trace_id, {"error": str(e)}, UpdateTraceScenario.ERROR)
                 raise e
             finally:
                 cleanup_trace(trace_id, start_time)
@@ -131,13 +165,10 @@ def wrapper(*args, **kwargs):
             output_as_list = check_multiple_return_values(func)
             try:
                 result = func(*args, **kwargs)
-                trace_data.get()[trace_id].output = make_output(result, output_as_list)
-                trace_data.get()[trace_id].status = "success"
-                trace_data.get()[trace_id].evaluation_metric_names = eval_funcs_names
+                fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
             except Exception as e:
                 logger.exception(f"Error occurred in function {func.__name__}, {e}")
-                trace_data.get()[trace_id].error = str(e)
-                trace_data.get()[trace_id].status = "error"
+                fill_trace_data(trace_id, {"error": str(e)}, UpdateTraceScenario.ERROR)
                 raise e
             finally:
                 cleanup_trace(trace_id, start_time)
@@ -156,77 +187,43 @@ def wrapper(*args, **kwargs):
     return decorator


-def check_multiple_return_values(func) -> bool:
-    specs = inspect.getfullargspec(func)
-    try:
-        r = specs.annotations.get("return", None)
-        if r and r.__origin__ == tuple:
-            return len(r.__args__) > 1
-    except Exception:
-        return False
+def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
+    data = trace_data.get()[trace_id]
+    parea_logger.default_log(data=data)
+    if eval_funcs and data.status == "success":
+        thread_ids_running_evals.get().append(trace_id)
+        if access_output_of_func:
+            try:
+                output = json.loads(data.output)
+                output = access_output_of_func(output)
+                output_for_eval_metrics = json_dumps(output)
+            except Exception as e:
+                logger.exception(f"Error accessing output of func with output: {data.output}. Error: {e}", exc_info=e)
+                return
+        else:
+            output_for_eval_metrics = data.output


-def make_output(result, islist) -> str:
-    if islist:
-        json_list = [json_dumps(r) for r in result]
-        return json_dumps(json_list)
-    elif isinstance(result, str):
-        return result
-    else:
-        return json_dumps(result)
+        data.output = output_for_eval_metrics
+        scores = []
+        for func in eval_funcs:
+            try:
+                scores.append(NamedEvaluationScore(name=func.__name__, score=func(data)))
+            except Exception as e:
+                logger.exception(f"Error occurred calling evaluation function '{func.__name__}', {e}", exc_info=e)


-def logger_record_log(trace_id: str):
-    logging_thread = threading.Thread(
-        target=parea_logger.record_log,
-        kwargs={"data": trace_data.get()[trace_id]},
-    )
-    logging_thread.start()
+        parea_logger.update_log(data=UpdateLog(trace_id=trace_id, field_name_to_value_map={"scores": scores}))
+        thread_ids_running_evals.get().remove(trace_id)


-def logger_all_possible(trace_id: str):
-    logging_thread = threading.Thread(
-        target=parea_logger.default_log,
-        kwargs={"data": trace_data.get()[trace_id]},
-    )
-    logging_thread.start()
+def logger_record_log(trace_id: str):
+    log_in_thread(parea_logger.record_log, {"data": trace_data.get()[trace_id]})


-def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
-    data = trace_data.get()[trace_id]
-    thread_ids_running_evals.get().append(trace_id)
-    try:
-        if eval_funcs and data.status == "success":
-            if access_output_of_func:
-                output = json.loads(data.output)
-                output = access_output_of_func(output)
-                output_for_eval_metrics = json_dumps(output)
-            else:
-                output_for_eval_metrics = data.output
-
-            data.output_for_eval_metrics = output_for_eval_metrics
-            output_old = data.output
-            data.output = data.output_for_eval_metrics
-            data.scores = []
-
-            for func in eval_funcs:
-                try:
-                    score = func(data)
-                    data.scores.append(NamedEvaluationScore(name=func.__name__, score=score))
-                except Exception as e:
-                    logger.exception(f"Error occurred calling evaluation function '{func.__name__}', {e}", exc_info=e)
-
-            data.output = output_old
-    except Exception as e:
-        logger.exception(f"Error occurred in when trying to evaluate output, {e}", exc_info=e)
-    finally:
-        thread_ids_running_evals.get().remove(trace_id)
-        parea_logger.default_log(data=data)
+def logger_all_possible(trace_id: str):
+    log_in_thread(parea_logger.default_log, {"data": trace_data.get()[trace_id]})


 def thread_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, access_output_of_func: Callable = None):
-    logging_thread = threading.Thread(
-        target=call_eval_funcs_then_log,
-        kwargs={"trace_id": trace_id, "eval_funcs": eval_funcs, "access_output_of_func": access_output_of_func},
-    )
-    logging_thread.start()
+    log_in_thread(call_eval_funcs_then_log, {"trace_id": trace_id, "eval_funcs": eval_funcs, "access_output_of_func": access_output_of_func})
diff --git a/pyproject.toml b/pyproject.toml
index e9cc4568..5dc2b15f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.38a0"
+version = "0.2.39"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai "]
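
Usage sketch: after this change, experiment() no longer calls load_dotenv() or reads PAREA_API_KEY itself; the caller constructs the Parea client and threads it through, either via the new Parea.experiment(...) helper or the new p field on Experiment. A minimal, hedged example of that flow follows; the dataset and answer_question function are hypothetical placeholders, and each data dict is assumed to be expanded into keyword arguments for func by the existing experiment() loop.

    import os

    from parea import Parea

    # The caller now owns API-key handling (the load_dotenv()/PAREA_API_KEY lookup
    # was removed from experiment() in this patch).
    p = Parea(api_key=os.getenv("PAREA_API_KEY"))

    # Hypothetical data iterator and function under test.
    data = [{"question": "What is 2 + 2?"}, {"question": "Name a prime number."}]

    def answer_question(question: str) -> str:
        return f"Answering: {question}"

    # Parea.experiment() builds an Experiment bound to this client (p=self); run() drives
    # the async experiment() coroutine, which creates the experiment via create_experiment(),
    # exports its UUID through PAREA_OS_ENV_EXPERIMENT_UUID so completion()/acompletion()
    # attach it to each Completion, and stores the resulting ExperimentStatsSchema on
    # exp.experiment_stats.
    exp = p.experiment(name="demo-experiment", data=data, func=answer_question)
    exp.run()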