From 345e06445b0dc66d3d6b108a743f7cca7152b6be Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Thu, 28 Dec 2023 14:59:34 -0500 Subject: [PATCH 01/10] feat: add experiments --- parea/benchmark.py | 76 +++++++++++++++++---- parea/client.py | 35 +++++++++- parea/constants.py | 1 + parea/schemas/models.py | 35 +++++++++- parea/utils/trace_integrations/langchain.py | 5 ++ parea/utils/trace_utils.py | 3 + parea/wrapper/wrapper.py | 3 + 7 files changed, 142 insertions(+), 16 deletions(-) create mode 100644 parea/constants.py diff --git a/parea/benchmark.py b/parea/benchmark.py index caea82a1..ae17ecb0 100644 --- a/parea/benchmark.py +++ b/parea/benchmark.py @@ -4,17 +4,21 @@ import csv import importlib import inspect +import json import os import sys import time from importlib import util +from math import sqrt +from typing import Dict, List -from attr import asdict, fields_dict from tqdm import tqdm +from parea import Parea from parea.cache.redis import RedisCache +from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID from parea.helpers import write_trace_logs_to_csv -from parea.schemas.models import TraceLog +from parea.schemas.models import TraceLog, CreateExperimentRequest, Experiment, ExperimentStatsSchema, TraceStatsSchema def load_from_path(module_path, attr_name): @@ -49,12 +53,45 @@ def async_wrapper(fn, **kwargs): return asyncio.run(fn(**kwargs)) +def calculate_avg_std_as_string(values: List[float]) -> str: + if not values: + return "N/A" + values = [x for x in values if x is not None] + avg = sum(values) / len(values) + std = sqrt(sum((x - avg) ** 2 for x in values) / len(values)) + return f"{avg:.2f} ± {std:.2f}" + + +def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> Dict[str, str]: + trace_stats: List[TraceStatsSchema] = experiment_stats.parent_trace_stats + latency_values = [trace_stat.latency for trace_stat in trace_stats] + input_tokens_values = [trace_stat.input_tokens for trace_stat in trace_stats] + output_tokens_values = [trace_stat.output_tokens for trace_stat in trace_stats] + total_tokens_values = [trace_stat.total_tokens for trace_stat in trace_stats] + cost_values = [trace_stat.cost for trace_stat in trace_stats] + score_name_to_values: Dict[str, List[float]] = {} + for trace_stat in trace_stats: + if trace_stat.scores: + for score in trace_stat.scores: + score_name_to_values[score.name] = score_name_to_values.get(score.name, []) + [score.score] + + return { + "latency": calculate_avg_std_as_string(latency_values), + "input_tokens": calculate_avg_std_as_string(input_tokens_values), + "output_tokens": calculate_avg_std_as_string(output_tokens_values), + "total_tokens": calculate_avg_std_as_string(total_tokens_values), + "cost": calculate_avg_std_as_string(cost_values), + **{score_name: calculate_avg_std_as_string(score_values) for score_name, score_values in score_name_to_values.items()} + } + + def run_benchmark(args): parser = argparse.ArgumentParser() + parser.add_argument("--name", help="Name of the experiment", type=str, required=True) parser.add_argument("--func", help="Function to test e.g., path/to/my_code.py:argument_chain", type=str, required=True) parser.add_argument("--csv_path", help="Path to the input CSV file", type=str, required=True) - parser.add_argument("--redis_host", help="Redis host", type=str, default=os.getenv("REDIS_HOST", "localhost")) - parser.add_argument("--redis_port", help="Redis port", type=int, default=int(os.getenv("REDIS_PORT", 6379))) + parser.add_argument("--redis_host", help="Redis host", type=str, 
default=None) + parser.add_argument("--redis_port", help="Redis port", type=int, default=None) parser.add_argument("--redis_password", help="Redis password", type=str, default=None) parsed_args = parser.parse_args(args) @@ -62,8 +99,16 @@ def run_benchmark(args): data_inputs = read_input_file(parsed_args.csv_path) - redis_logs_key = f"parea-trace-logs-{int(time.time())}" - os.putenv("_parea_redis_logs_key", redis_logs_key) + if parea_api_key := os.getenv("PAREA_API_KEY") is None: + raise ValueError("Please set the PAREA_API_KEY environment variable") + p = Parea(api_key=parea_api_key) + + experiment: Experiment = p.create_experiment(CreateExperimentRequest(name=parsed_args.name)) + os.putenv(PAREA_OS_ENV_EXPERIMENT_UUID, experiment.uuid) + + if is_using_redis := parsed_args.redis_host and parsed_args.redis_port: + redis_logs_key = f"parea-trace-logs-{int(time.time())}" + os.putenv("_parea_redis_logs_key", redis_logs_key) with concurrent.futures.ProcessPoolExecutor() as executor: if inspect.iscoroutinefunction(fn): @@ -72,12 +117,15 @@ def run_benchmark(args): futures = [executor.submit(fn, **data_input) for data_input in data_inputs] for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): pass - print(f"Done with {len(futures)} inputs") - - redis_cache = RedisCache(key_logs=redis_logs_key, host=args.redis_host, port=args.redis_port, password=args.redis_password) - # write to csv - path_csv = f"trace_logs-{int(time.time())}.csv" - trace_logs: list[TraceLog] = redis_cache.read_logs() - write_trace_logs_to_csv(path_csv, trace_logs) - print(f"Wrote CSV of results to: {path_csv}") + experiment_stats: ExperimentStatsSchema = p.get_experiment_stats(experiment.uuid) + stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) + print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}") + + if is_using_redis: + redis_cache = RedisCache(key_logs=redis_logs_key, host=args.redis_host, port=args.redis_port, password=args.redis_password) + # write to csv + path_csv = f"trace_logs-{int(time.time())}.csv" + trace_logs: list[TraceLog] = redis_cache.read_logs() + write_trace_logs_to_csv(path_csv, trace_logs) + print(f"Wrote CSV of results to: {path_csv}") diff --git a/parea/client.py b/parea/client.py index 2387740e..4fe20fd0 100644 --- a/parea/client.py +++ b/parea/client.py @@ -11,13 +11,16 @@ from parea.cache.cache import Cache from parea.helpers import gen_trace_id from parea.parea_logger import parea_logger -from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, UseDeployedPromptResponse +from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, \ + UseDeployedPromptResponse, CreateExperimentRequest, Experiment, ExperimentStatsSchema from parea.utils.trace_utils import get_current_trace_id, logger_all_possible, logger_record_log, trace_data from parea.wrapper import OpenAIWrapper COMPLETION_ENDPOINT = "/completion" DEPLOYED_PROMPT_ENDPOINT = "/deployed-prompt" RECORD_FEEDBACK_ENDPOINT = "/feedback" +EXPERIMENT_ENDPOINT = "/experiment" +EXPERIMENT_STATS_ENDPOINT = "/experiment-stats/{experiment_uuid}" @define @@ -98,6 +101,36 @@ async def arecord_feedback(self, data: FeedbackRequest) -> None: data=asdict(data), ) + def create_experiment(self, data: CreateExperimentRequest) -> Experiment: + r = self._client.request( + "POST", + EXPERIMENT_ENDPOINT, + data=asdict(data), + ) + return Experiment(**r.json()) + + async def acreate_experiment(self, data: 
CreateExperimentRequest) -> Experiment: + r = await self._client.request_async( + "POST", + EXPERIMENT_ENDPOINT, + data=asdict(data), + ) + return Experiment(**r.json()) + + def get_experiment_stats(self, experiment_uuid: str) -> ExperimentStatsSchema: + r = self._client.request( + "GET", + EXPERIMENT_STATS_ENDPOINT.format(experiment_uuid=experiment_uuid), + ) + return ExperimentStatsSchema(**r.json()) + + async def aget_experiment_stats(self, experiment_uuid: str) -> ExperimentStatsSchema: + r = await self._client.request_async( + "GET", + EXPERIMENT_STATS_ENDPOINT.format(experiment_uuid=experiment_uuid), + ) + return ExperimentStatsSchema(**r.json()) + _initialized_parea_wrapper = False diff --git a/parea/constants.py b/parea/constants.py new file mode 100644 index 00000000..04ff8615 --- /dev/null +++ b/parea/constants.py @@ -0,0 +1 @@ +PAREA_OS_ENV_EXPERIMENT_UUID = "_PAREA_EXPERIMENT_UUID" diff --git a/parea/schemas/models.py b/parea/schemas/models.py index f0dfc09a..332b83c2 100644 --- a/parea/schemas/models.py +++ b/parea/schemas/models.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, List from attrs import define, field, validators @@ -109,6 +109,7 @@ class TraceLog(Log): end_user_identifier: Optional[str] = None metadata: Optional[dict[str, Any]] = None tags: Optional[list[str]] = None + experiment_uuid: Optional[str] = None @define @@ -125,3 +126,35 @@ class CacheRequest: class UpdateLog: trace_id: str field_name_to_value_map: dict[str, Any] + + +@define +class CreateExperimentRequest: + name: str + + +@define +class Experiment(CreateExperimentRequest): + uuid: str + created_at: str + + +@define +class EvaluationScoreSchema(NamedEvaluationScore): + id: Optional[int] = None + + +@define +class TraceStatsSchema: + trace_id: str + latency: Optional[float] = 0.0 + input_tokens: Optional[int] = 0 + output_tokens: Optional[int] = 0 + total_tokens: Optional[int] = 0 + cost: Optional[float] = None + scores: Optional[List[EvaluationScoreSchema]] = None + + +@define +class ExperimentStatsSchema: + parent_trace_stats: List[TraceStatsSchema] diff --git a/parea/utils/trace_integrations/langchain.py b/parea/utils/trace_integrations/langchain.py index 3ed571cb..e1074008 100644 --- a/parea/utils/trace_integrations/langchain.py +++ b/parea/utils/trace_integrations/langchain.py @@ -1,3 +1,4 @@ +import os from typing import Union from uuid import UUID @@ -5,6 +6,7 @@ from langchain_core.tracers import BaseTracer from langchain_core.tracers.schemas import ChainRun, LLMRun, Run, ToolRun +from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID from parea.parea_logger import parea_logger from parea.schemas.log import TraceIntegrations @@ -15,6 +17,9 @@ class PareaAILangchainTracer(BaseTracer): def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None: self.parent_trace_id = run.id # using .dict() since langchain Run class currently set to Pydantic v1 + data = run.dict() + if experiment_uuid := os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None): + data["experiment_uuid"] = experiment_uuid parea_logger.record_vendor_log(run.dict(), TraceIntegrations.LANGCHAIN) def get_parent_trace_id(self) -> UUID: diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py index 35256923..e02cc2b9 100644 --- a/parea/utils/trace_utils.py +++ b/parea/utils/trace_utils.py @@ -1,3 +1,4 @@ +import os from typing import Any, Callable, Optional import contextvars @@ -9,6 +10,7 @@ from collections import ChainMap from functools import wraps +from parea.constants import 
PAREA_OS_ENV_EXPERIMENT_UUID from parea.helpers import gen_trace_id, to_date_and_time_string from parea.parea_logger import parea_logger from parea.schemas.models import NamedEvaluationScore, TraceLog @@ -79,6 +81,7 @@ def init_trace(func_name, args, kwargs, func) -> tuple[str, float]: target=target, tags=tags, inputs=inputs, + experiment_uuid=os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None) ) parent_trace_id = trace_context.get()[-2] if len(trace_context.get()) > 1 else None if parent_trace_id: diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index a6a94bd8..70e6de77 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -1,3 +1,4 @@ +import os from typing import Any, Callable, List, Tuple import functools @@ -7,6 +8,7 @@ from uuid import uuid4 from parea.cache.cache import Cache +from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID from parea.helpers import date_and_time_string_to_timestamp from parea.schemas.models import TraceLog from parea.utils.trace_utils import call_eval_funcs_then_log, to_date_and_time_string, trace_context, trace_data @@ -70,6 +72,7 @@ def _init_trace(self) -> tuple[str, float]: target=None, tags=None, inputs={}, + experiment_uuid=os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None) ) parent_trace_id = trace_context.get()[-2] if len(trace_context.get()) > 1 else None From 2fbc360c62fdc7550c3ae0ec8c65f8ecd6e313cb Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Thu, 4 Jan 2024 19:02:27 -0500 Subject: [PATCH 02/10] feat: use cattrs --- parea/__init__.py | 4 ++++ parea/benchmark.py | 8 ++++++-- parea/client.py | 19 ++++++++++--------- parea/schemas/models.py | 14 +++++++------- parea/utils/trace_utils.py | 2 ++ 5 files changed, 29 insertions(+), 18 deletions(-) diff --git a/parea/__init__.py b/parea/__init__.py index 1fc6c484..3e7a0a67 100644 --- a/parea/__init__.py +++ b/parea/__init__.py @@ -32,3 +32,7 @@ def main(): run_benchmark(args[1:]) else: print(f"Unknown command: '{args[0]}'") + + +if __name__ == "__main__": + main() diff --git a/parea/benchmark.py b/parea/benchmark.py index ae17ecb0..2dd922b2 100644 --- a/parea/benchmark.py +++ b/parea/benchmark.py @@ -12,9 +12,10 @@ from math import sqrt from typing import Dict, List +from dotenv import load_dotenv from tqdm import tqdm -from parea import Parea +from parea.client import Parea from parea.cache.redis import RedisCache from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID from parea.helpers import write_trace_logs_to_csv @@ -95,11 +96,13 @@ def run_benchmark(args): parser.add_argument("--redis_password", help="Redis password", type=str, default=None) parsed_args = parser.parse_args(args) + load_dotenv() + fn = load_from_path(*parsed_args.func.rsplit(":", 1)) data_inputs = read_input_file(parsed_args.csv_path) - if parea_api_key := os.getenv("PAREA_API_KEY") is None: + if not (parea_api_key := os.getenv("PAREA_API_KEY")): raise ValueError("Please set the PAREA_API_KEY environment variable") p = Parea(api_key=parea_api_key) @@ -118,6 +121,7 @@ def run_benchmark(args): for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): pass + time.sleep(5) # wait for all trace logs to be written to DB experiment_stats: ExperimentStatsSchema = p.get_experiment_stats(experiment.uuid) stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}") diff --git a/parea/client.py b/parea/client.py index 4fe20fd0..f5b06d92 100644 --- a/parea/client.py +++ b/parea/client.py @@ -5,6 +5,7 @@ import 
time from attrs import asdict, define, field +from cattrs import structure from parea.api_client import HTTPClient from parea.cache import InMemoryCache, RedisCache @@ -20,7 +21,7 @@ DEPLOYED_PROMPT_ENDPOINT = "/deployed-prompt" RECORD_FEEDBACK_ENDPOINT = "/feedback" EXPERIMENT_ENDPOINT = "/experiment" -EXPERIMENT_STATS_ENDPOINT = "/experiment-stats/{experiment_uuid}" +EXPERIMENT_STATS_ENDPOINT = "/experiment/stats/{experiment_uuid}" @define @@ -51,7 +52,7 @@ def completion(self, data: Completion) -> CompletionResponse: if parent_trace_id: trace_data.get()[parent_trace_id].children.append(inference_id) logger_record_log(parent_trace_id) - return CompletionResponse(**r.json()) + return structure(r.json(), CompletionResponse) async def acompletion(self, data: Completion) -> CompletionResponse: parent_trace_id = get_current_trace_id() @@ -67,7 +68,7 @@ async def acompletion(self, data: Completion) -> CompletionResponse: if parent_trace_id: trace_data.get()[parent_trace_id].children.append(inference_id) logger_record_log(parent_trace_id) - return CompletionResponse(**r.json()) + return structure(r.json(), CompletionResponse) def get_prompt(self, data: UseDeployedPrompt) -> UseDeployedPromptResponse: r = self._client.request( @@ -75,7 +76,7 @@ def get_prompt(self, data: UseDeployedPrompt) -> UseDeployedPromptResponse: DEPLOYED_PROMPT_ENDPOINT, data=asdict(data), ) - return UseDeployedPromptResponse(**r.json()) + return structure(r.json(), UseDeployedPromptResponse) async def aget_prompt(self, data: UseDeployedPrompt) -> UseDeployedPromptResponse: r = await self._client.request_async( @@ -83,7 +84,7 @@ async def aget_prompt(self, data: UseDeployedPrompt) -> UseDeployedPromptRespons DEPLOYED_PROMPT_ENDPOINT, data=asdict(data), ) - return UseDeployedPromptResponse(**r.json()) + return structure(r.json(), UseDeployedPromptResponse) def record_feedback(self, data: FeedbackRequest) -> None: time.sleep(2) # give logs time to update @@ -107,7 +108,7 @@ def create_experiment(self, data: CreateExperimentRequest) -> Experiment: EXPERIMENT_ENDPOINT, data=asdict(data), ) - return Experiment(**r.json()) + return structure(r.json(), Experiment) async def acreate_experiment(self, data: CreateExperimentRequest) -> Experiment: r = await self._client.request_async( @@ -115,21 +116,21 @@ async def acreate_experiment(self, data: CreateExperimentRequest) -> Experiment: EXPERIMENT_ENDPOINT, data=asdict(data), ) - return Experiment(**r.json()) + return structure(r.json(), Experiment) def get_experiment_stats(self, experiment_uuid: str) -> ExperimentStatsSchema: r = self._client.request( "GET", EXPERIMENT_STATS_ENDPOINT.format(experiment_uuid=experiment_uuid), ) - return ExperimentStatsSchema(**r.json()) + return structure(r.json(), ExperimentStatsSchema) async def aget_experiment_stats(self, experiment_uuid: str) -> ExperimentStatsSchema: r = await self._client.request_async( "GET", EXPERIMENT_STATS_ENDPOINT.format(experiment_uuid=experiment_uuid), ) - return ExperimentStatsSchema(**r.json()) + return structure(r.json(), ExperimentStatsSchema) _initialized_parea_wrapper = False diff --git a/parea/schemas/models.py b/parea/schemas/models.py index 332b83c2..8e494ed1 100644 --- a/parea/schemas/models.py +++ b/parea/schemas/models.py @@ -16,7 +16,7 @@ class Completion: deployment_id: Optional[str] = None name: Optional[str] = None metadata: Optional[dict] = None - tags: Optional[list[str]] = None + tags: Optional[list[str]] = field(factory=list) target: Optional[str] = None cache: bool = True log_omit_inputs: bool = False 
@@ -96,8 +96,8 @@ class TraceLog(Log): deployment_id: Optional[str] = None cache_hit: bool = False output_for_eval_metrics: Optional[str] = None - evaluation_metric_names: Optional[list[str]] = None - scores: Optional[list[NamedEvaluationScore]] = None + evaluation_metric_names: Optional[list[str]] = field(factory=list) + scores: Optional[list[NamedEvaluationScore]] = field(factory=list) feedback_score: Optional[float] = None # info filled from decorator @@ -108,13 +108,13 @@ class TraceLog(Log): end_timestamp: Optional[str] = None end_user_identifier: Optional[str] = None metadata: Optional[dict[str, Any]] = None - tags: Optional[list[str]] = None + tags: Optional[list[str]] = field(factory=list) experiment_uuid: Optional[str] = None @define class TraceLogTree(TraceLog): - children: Optional[list[TraceLog]] = None + children: Optional[list[TraceLog]] = field(factory=list) @define @@ -152,9 +152,9 @@ class TraceStatsSchema: output_tokens: Optional[int] = 0 total_tokens: Optional[int] = 0 cost: Optional[float] = None - scores: Optional[List[EvaluationScoreSchema]] = None + scores: Optional[List[EvaluationScoreSchema]] = field(factory=list) @define class ExperimentStatsSchema: - parent_trace_stats: List[TraceStatsSchema] + parent_trace_stats: List[TraceStatsSchema] = field(factory=list) diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py index e02cc2b9..1a1149e9 100644 --- a/parea/utils/trace_utils.py +++ b/parea/utils/trace_utils.py @@ -161,6 +161,8 @@ def make_output(result, islist) -> str: if islist: json_list = [json_dumps(r) for r in result] return json_dumps(json_list) + elif isinstance(result, str): + return result else: return json_dumps(result) From cd556a19e1775ee3dfc0c46c99d811a2a5d2ebe4 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Thu, 4 Jan 2024 19:43:49 -0500 Subject: [PATCH 03/10] feat: expose as experiment function --- parea/__init__.py | 6 +- parea/benchmark.py | 135 --------------------------------- parea/client.py | 10 +-- parea/experiment/__init__.py | 0 parea/experiment/cli.py | 61 +++++++++++++++ parea/experiment/experiment.py | 75 ++++++++++++++++++ parea/schemas/models.py | 2 +- 7 files changed, 145 insertions(+), 144 deletions(-) delete mode 100644 parea/benchmark.py create mode 100644 parea/experiment/__init__.py create mode 100644 parea/experiment/cli.py create mode 100644 parea/experiment/experiment.py diff --git a/parea/__init__.py b/parea/__init__.py index 3e7a0a67..7980b954 100644 --- a/parea/__init__.py +++ b/parea/__init__.py @@ -11,7 +11,7 @@ import sys from importlib import metadata as importlib_metadata -from parea.benchmark import run_benchmark +from parea.experiment.cli import experiment from parea.cache import InMemoryCache, RedisCache from parea.client import Parea, init @@ -28,8 +28,8 @@ def get_version() -> str: def main(): args = sys.argv[1:] - if args[0] == "benchmark": - run_benchmark(args[1:]) + if args[0] == "experiment": + experiment(args[1:]) else: print(f"Unknown command: '{args[0]}'") diff --git a/parea/benchmark.py b/parea/benchmark.py deleted file mode 100644 index 2dd922b2..00000000 --- a/parea/benchmark.py +++ /dev/null @@ -1,135 +0,0 @@ -import argparse -import asyncio -import concurrent -import csv -import importlib -import inspect -import json -import os -import sys -import time -from importlib import util -from math import sqrt -from typing import Dict, List - -from dotenv import load_dotenv -from tqdm import tqdm - -from parea.client import Parea -from parea.cache.redis import RedisCache -from parea.constants 
import PAREA_OS_ENV_EXPERIMENT_UUID -from parea.helpers import write_trace_logs_to_csv -from parea.schemas.models import TraceLog, CreateExperimentRequest, Experiment, ExperimentStatsSchema, TraceStatsSchema - - -def load_from_path(module_path, attr_name): - # Ensure the directory of user-provided script is in the system path - dir_name = os.path.dirname(module_path) - if dir_name not in sys.path: - sys.path.insert(0, dir_name) - - module_name = os.path.basename(module_path) - # Add .py extension back in to allow import correctly - module_path_with_ext = f"{module_path}.py" - - spec = importlib.util.spec_from_file_location(module_name, module_path_with_ext) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - if spec.name not in sys.modules: - sys.modules[spec.name] = module - - fn = getattr(module, attr_name) - return fn - - -def read_input_file(file_path) -> list[dict]: - with open(file_path) as file: - reader = csv.DictReader(file) - inputs = list(reader) - return inputs - - -def async_wrapper(fn, **kwargs): - return asyncio.run(fn(**kwargs)) - - -def calculate_avg_std_as_string(values: List[float]) -> str: - if not values: - return "N/A" - values = [x for x in values if x is not None] - avg = sum(values) / len(values) - std = sqrt(sum((x - avg) ** 2 for x in values) / len(values)) - return f"{avg:.2f} ± {std:.2f}" - - -def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> Dict[str, str]: - trace_stats: List[TraceStatsSchema] = experiment_stats.parent_trace_stats - latency_values = [trace_stat.latency for trace_stat in trace_stats] - input_tokens_values = [trace_stat.input_tokens for trace_stat in trace_stats] - output_tokens_values = [trace_stat.output_tokens for trace_stat in trace_stats] - total_tokens_values = [trace_stat.total_tokens for trace_stat in trace_stats] - cost_values = [trace_stat.cost for trace_stat in trace_stats] - score_name_to_values: Dict[str, List[float]] = {} - for trace_stat in trace_stats: - if trace_stat.scores: - for score in trace_stat.scores: - score_name_to_values[score.name] = score_name_to_values.get(score.name, []) + [score.score] - - return { - "latency": calculate_avg_std_as_string(latency_values), - "input_tokens": calculate_avg_std_as_string(input_tokens_values), - "output_tokens": calculate_avg_std_as_string(output_tokens_values), - "total_tokens": calculate_avg_std_as_string(total_tokens_values), - "cost": calculate_avg_std_as_string(cost_values), - **{score_name: calculate_avg_std_as_string(score_values) for score_name, score_values in score_name_to_values.items()} - } - - -def run_benchmark(args): - parser = argparse.ArgumentParser() - parser.add_argument("--name", help="Name of the experiment", type=str, required=True) - parser.add_argument("--func", help="Function to test e.g., path/to/my_code.py:argument_chain", type=str, required=True) - parser.add_argument("--csv_path", help="Path to the input CSV file", type=str, required=True) - parser.add_argument("--redis_host", help="Redis host", type=str, default=None) - parser.add_argument("--redis_port", help="Redis port", type=int, default=None) - parser.add_argument("--redis_password", help="Redis password", type=str, default=None) - parsed_args = parser.parse_args(args) - - load_dotenv() - - fn = load_from_path(*parsed_args.func.rsplit(":", 1)) - - data_inputs = read_input_file(parsed_args.csv_path) - - if not (parea_api_key := os.getenv("PAREA_API_KEY")): - raise ValueError("Please set the PAREA_API_KEY environment variable") - p = 
Parea(api_key=parea_api_key) - - experiment: Experiment = p.create_experiment(CreateExperimentRequest(name=parsed_args.name)) - os.putenv(PAREA_OS_ENV_EXPERIMENT_UUID, experiment.uuid) - - if is_using_redis := parsed_args.redis_host and parsed_args.redis_port: - redis_logs_key = f"parea-trace-logs-{int(time.time())}" - os.putenv("_parea_redis_logs_key", redis_logs_key) - - with concurrent.futures.ProcessPoolExecutor() as executor: - if inspect.iscoroutinefunction(fn): - futures = [executor.submit(async_wrapper, fn, **data_input) for data_input in data_inputs] - else: - futures = [executor.submit(fn, **data_input) for data_input in data_inputs] - for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): - pass - - time.sleep(5) # wait for all trace logs to be written to DB - experiment_stats: ExperimentStatsSchema = p.get_experiment_stats(experiment.uuid) - stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) - print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}") - - if is_using_redis: - redis_cache = RedisCache(key_logs=redis_logs_key, host=args.redis_host, port=args.redis_port, password=args.redis_password) - # write to csv - path_csv = f"trace_logs-{int(time.time())}.csv" - trace_logs: list[TraceLog] = redis_cache.read_logs() - write_trace_logs_to_csv(path_csv, trace_logs) - print(f"Wrote CSV of results to: {path_csv}") diff --git a/parea/client.py b/parea/client.py index f5b06d92..5f512ca9 100644 --- a/parea/client.py +++ b/parea/client.py @@ -13,7 +13,7 @@ from parea.helpers import gen_trace_id from parea.parea_logger import parea_logger from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, \ - UseDeployedPromptResponse, CreateExperimentRequest, Experiment, ExperimentStatsSchema + UseDeployedPromptResponse, CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema from parea.utils.trace_utils import get_current_trace_id, logger_all_possible, logger_record_log, trace_data from parea.wrapper import OpenAIWrapper @@ -102,21 +102,21 @@ async def arecord_feedback(self, data: FeedbackRequest) -> None: data=asdict(data), ) - def create_experiment(self, data: CreateExperimentRequest) -> Experiment: + def create_experiment(self, data: CreateExperimentRequest) -> ExperimentSchema: r = self._client.request( "POST", EXPERIMENT_ENDPOINT, data=asdict(data), ) - return structure(r.json(), Experiment) + return structure(r.json(), ExperimentSchema) - async def acreate_experiment(self, data: CreateExperimentRequest) -> Experiment: + async def acreate_experiment(self, data: CreateExperimentRequest) -> ExperimentSchema: r = await self._client.request_async( "POST", EXPERIMENT_ENDPOINT, data=asdict(data), ) - return structure(r.json(), Experiment) + return structure(r.json(), ExperimentSchema) def get_experiment_stats(self, experiment_uuid: str) -> ExperimentStatsSchema: r = self._client.request( diff --git a/parea/experiment/__init__.py b/parea/experiment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/parea/experiment/cli.py b/parea/experiment/cli.py new file mode 100644 index 00000000..b2cae30f --- /dev/null +++ b/parea/experiment/cli.py @@ -0,0 +1,61 @@ +import argparse +import csv +import importlib +import os +import sys +import traceback +from importlib import util + + +from .experiment import experiment as experiment_orig + + +def load_from_path(module_path, attr_name): + # Ensure the directory of user-provided script is in the system path + dir_name = 
os.path.dirname(module_path) + if dir_name not in sys.path: + sys.path.insert(0, dir_name) + + module_name = os.path.basename(module_path) + # Add .py extension back in to allow import correctly + module_path_with_ext = f"{module_path}.py" + + spec = importlib.util.spec_from_file_location(module_name, module_path_with_ext) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if spec.name not in sys.modules: + sys.modules[spec.name] = module + + fn = getattr(module, attr_name) + return fn + + +def read_input_file(file_path) -> list[dict]: + with open(file_path) as file: + reader = csv.DictReader(file) + inputs = list(reader) + return inputs + + +def experiment(args): + parser = argparse.ArgumentParser() + parser.add_argument("--name", help="Name of the experiment", type=str, required=True) + parser.add_argument("--func", help="Function to test e.g., path/to/my_code.py:argument_chain", type=str, required=True) + parser.add_argument("--csv_path", help="Path to the input CSV file", type=str, required=True) + parsed_args = parser.parse_args(args) + + try: + func = load_from_path(*parsed_args.func.rsplit(":", 1)) + except Exception as e: + print(f"Error loading function: {e}\n", file=sys.stderr) + traceback.print_exc() + sys.exit(1) + try: + data = read_input_file(parsed_args.csv_path) + except Exception as e: + print(f"Error reading input file: {e}\n", file=sys.stderr) + traceback.print_exc() + sys.exit(1) + + experiment_orig(name=parsed_args.name, func=func, data=data) diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py new file mode 100644 index 00000000..6dd5838a --- /dev/null +++ b/parea/experiment/experiment.py @@ -0,0 +1,75 @@ +import asyncio +import concurrent +import inspect +import json +import os +import time +from math import sqrt +from typing import Dict, List, Callable, Iterator + +from dotenv import load_dotenv +from tqdm import tqdm + +from parea.client import Parea +from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID +from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, TraceStatsSchema + + +def calculate_avg_std_as_string(values: List[float]) -> str: + if not values: + return "N/A" + values = [x for x in values if x is not None] + avg = sum(values) / len(values) + std = sqrt(sum((x - avg) ** 2 for x in values) / len(values)) + return f"{avg:.2f} ± {std:.2f}" + + +def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> Dict[str, str]: + trace_stats: List[TraceStatsSchema] = experiment_stats.parent_trace_stats + latency_values = [trace_stat.latency for trace_stat in trace_stats] + input_tokens_values = [trace_stat.input_tokens for trace_stat in trace_stats] + output_tokens_values = [trace_stat.output_tokens for trace_stat in trace_stats] + total_tokens_values = [trace_stat.total_tokens for trace_stat in trace_stats] + cost_values = [trace_stat.cost for trace_stat in trace_stats] + score_name_to_values: Dict[str, List[float]] = {} + for trace_stat in trace_stats: + if trace_stat.scores: + for score in trace_stat.scores: + score_name_to_values[score.name] = score_name_to_values.get(score.name, []) + [score.score] + + return { + "latency": calculate_avg_std_as_string(latency_values), + "input_tokens": calculate_avg_std_as_string(input_tokens_values), + "output_tokens": calculate_avg_std_as_string(output_tokens_values), + "total_tokens": calculate_avg_std_as_string(total_tokens_values), + "cost": calculate_avg_std_as_string(cost_values), + 
**{score_name: calculate_avg_std_as_string(score_values) for score_name, score_values in score_name_to_values.items()} + } + + +def async_wrapper(fn, **kwargs): + return asyncio.run(fn(**kwargs)) + + +def experiment(name: str, data: Iterator, func: Callable): + load_dotenv() + + if not (parea_api_key := os.getenv("PAREA_API_KEY")): + raise ValueError("Please set the PAREA_API_KEY environment variable") + p = Parea(api_key=parea_api_key) + + experiment_schema: ExperimentSchema = p.create_experiment(CreateExperimentRequest(name=name)) + os.putenv(PAREA_OS_ENV_EXPERIMENT_UUID, experiment_schema.uuid) + + with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: + if inspect.iscoroutinefunction(func): + futures = [executor.submit(async_wrapper, func, **data_input) for data_input in data] + else: + futures = [executor.submit(func, **data_input) for data_input in data] + for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): + pass + + time.sleep(5) # wait for all trace logs to be written to DB + experiment_stats: ExperimentStatsSchema = p.get_experiment_stats(experiment_schema.uuid) + stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) + print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}") diff --git a/parea/schemas/models.py b/parea/schemas/models.py index 8e494ed1..661d9468 100644 --- a/parea/schemas/models.py +++ b/parea/schemas/models.py @@ -134,7 +134,7 @@ class CreateExperimentRequest: @define -class Experiment(CreateExperimentRequest): +class ExperimentSchema(CreateExperimentRequest): uuid: str created_at: str From 8d42dc692183f770f6d0af53f6b0180082d54df8 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Fri, 5 Jan 2024 14:41:33 -0500 Subject: [PATCH 04/10] feat: use max cpu cores --- parea/client.py | 2 +- parea/experiment/experiment.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parea/client.py b/parea/client.py index 5f512ca9..2cf9ca0a 100644 --- a/parea/client.py +++ b/parea/client.py @@ -21,7 +21,7 @@ DEPLOYED_PROMPT_ENDPOINT = "/deployed-prompt" RECORD_FEEDBACK_ENDPOINT = "/feedback" EXPERIMENT_ENDPOINT = "/experiment" -EXPERIMENT_STATS_ENDPOINT = "/experiment/stats/{experiment_uuid}" +EXPERIMENT_STATS_ENDPOINT = "/experiment/{experiment_uuid}/stats" @define diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index 6dd5838a..f638aeda 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -61,7 +61,7 @@ def experiment(name: str, data: Iterator, func: Callable): experiment_schema: ExperimentSchema = p.create_experiment(CreateExperimentRequest(name=name)) os.putenv(PAREA_OS_ENV_EXPERIMENT_UUID, experiment_schema.uuid) - with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: + with concurrent.futures.ProcessPoolExecutor(max_workers=os.cpu_count() or 1) as executor: if inspect.iscoroutinefunction(func): futures = [executor.submit(async_wrapper, func, **data_input) for data_input in data] else: From d9197de4199fbfa5f5226fd5a03754d51ff8e326 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Sat, 6 Jan 2024 21:18:11 -0500 Subject: [PATCH 05/10] feat: add experiment uuid --- parea/experiment/experiment.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index f638aeda..72d575db 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -59,7 +59,8 @@ def experiment(name: str, data: Iterator, func: 
Callable): p = Parea(api_key=parea_api_key) experiment_schema: ExperimentSchema = p.create_experiment(CreateExperimentRequest(name=name)) - os.putenv(PAREA_OS_ENV_EXPERIMENT_UUID, experiment_schema.uuid) + experiment_uuid = experiment_schema.uuid + os.putenv(PAREA_OS_ENV_EXPERIMENT_UUID, experiment_uuid) with concurrent.futures.ProcessPoolExecutor(max_workers=os.cpu_count() or 1) as executor: if inspect.iscoroutinefunction(func): @@ -70,6 +71,7 @@ def experiment(name: str, data: Iterator, func: Callable): pass time.sleep(5) # wait for all trace logs to be written to DB - experiment_stats: ExperimentStatsSchema = p.get_experiment_stats(experiment_schema.uuid) + experiment_stats: ExperimentStatsSchema = p.get_experiment_stats(experiment_uuid) stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) - print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}") + print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}\n\n") + print(f"View experiment & its traces at: https://app.parea.ai/experiments/{experiment_uuid}\n") From c24d3089ca05489a7f7ebe8fbd94c4f7d0e30adf Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Sat, 6 Jan 2024 22:22:33 -0500 Subject: [PATCH 06/10] feat: add cattrs dependency --- poetry.lock | 67 +++++++++++++++++++++++++++++++++----------------- pyproject.toml | 1 + 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/poetry.lock b/poetry.lock index 29db8de4..a05e1868 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -612,6 +612,31 @@ files = [ {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, ] +[[package]] +name = "cattrs" +version = "23.2.3" +description = "Composable complex class support for attrs and dataclasses." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "cattrs-23.2.3-py3-none-any.whl", hash = "sha256:0341994d94971052e9ee70662542699a3162ea1e0c62f7ce1b4a57f563685108"}, + {file = "cattrs-23.2.3.tar.gz", hash = "sha256:a934090d95abaa9e911dac357e3a8699e0b4b14f8529bcc7d2b1ad9d51672b9f"}, +] + +[package.dependencies] +attrs = ">=23.1.0" +exceptiongroup = {version = ">=1.1.1", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.1.0,<4.6.3 || >4.6.3", markers = "python_version < \"3.11\""} + +[package.extras] +bson = ["pymongo (>=4.4.0)"] +cbor2 = ["cbor2 (>=5.4.6)"] +msgpack = ["msgpack (>=1.0.5)"] +orjson = ["orjson (>=3.9.2)"] +pyyaml = ["pyyaml (>=6.0)"] +tomlkit = ["tomlkit (>=0.11.8)"] +ujson = ["ujson (>=5.7.0)"] + [[package]] name = "cerberus" version = "1.3.5" @@ -4946,6 +4971,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4953,8 +4979,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = 
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4971,6 +5004,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4978,6 +5012,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -5657,51 +5692,37 @@ python-versions = ">=3.6" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d92f81886165cb14d7b067ef37e142256f1c6a90a65cd156b063a43da1708cfd"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, - {file = 
"ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b5edda50e5e9e15e54a6a8a0070302b00c518a9d32accc2346ad6c984aacd279"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:7048c338b6c86627afb27faecf418768acb6331fc24cfa56c93e8c9780f815fa"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win_amd64.whl", hash = 
"sha256:1758ce7d8e1a29d23de54a16ae867abd370f01b5a69e1a3ba75223eaa3ca1a1b"}, {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3fcc54cb0c8b811ff66082de1680b4b14cf8a81dce0d4fbf665c2265a81e07a1"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:665f58bfd29b167039f714c6998178d27ccd83984084c286110ef26b230f259f"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux2014_aarch64.whl", hash = 
"sha256:9eb5dee2772b0f704ca2e45b1713e4e5198c18f515b52743576d196348f374d3"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"}, {file = "ruamel.yaml.clib-0.2.8.tar.gz", hash = "sha256:beb2e0404003de9a4cab9753a8805a8fe9320ee6673136ed7f04255fe60bb512"}, @@ -5922,7 +5943,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} typing-extensions = ">=4.2.0" [package.extras] @@ -7195,4 +7216,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "bd83b05cfd6d0d53d8a84f86f65e6101ce98d59e5ff6d82079ed411d99adc2d6" +content-hash = "000930533d7e75e4d52b33d8e3ea64e35c7491a94f20b6001d12e184c77e9e2c" diff --git a/pyproject.toml b/pyproject.toml index 4cb351f6..792728bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ contextvars = "^2.4" openai = "*" redis = "^5.0.1" pysbd = "^0.3.4" +cattrs = ">=22.1.0" [tool.poetry.dev-dependencies] bandit = "^1.7.1" From 2159a39e3dc607d8e306c2e37371dfa45d6735a1 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Sat, 6 Jan 2024 22:25:40 -0500 Subject: [PATCH 07/10] style --- parea/__init__.py | 2 +- parea/client.py | 12 ++++++++++-- parea/experiment/cli.py | 1 - parea/experiment/experiment.py | 5 +++-- parea/schemas/models.py | 2 +- parea/utils/trace_integrations/langchain.py | 2 +- parea/utils/trace_utils.py | 4 ++-- parea/wrapper/wrapper.py | 6 +++--- 8 files changed, 21 insertions(+), 13 deletions(-) diff --git a/parea/__init__.py b/parea/__init__.py index 7980b954..4204a5dc 100644 --- a/parea/__init__.py +++ b/parea/__init__.py @@ -11,9 +11,9 @@ import sys from importlib import metadata as importlib_metadata -from parea.experiment.cli import experiment from parea.cache import InMemoryCache, RedisCache from parea.client import Parea, init +from parea.experiment.cli import experiment def get_version() -> str: diff --git a/parea/client.py b/parea/client.py index 2cf9ca0a..49db8d85 100644 --- a/parea/client.py +++ b/parea/client.py @@ -12,8 +12,16 @@ from parea.cache.cache import Cache from parea.helpers import gen_trace_id from parea.parea_logger import parea_logger -from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, \ - UseDeployedPromptResponse, CreateExperimentRequest, 
ExperimentSchema, ExperimentStatsSchema +from parea.schemas.models import ( + Completion, + CompletionResponse, + CreateExperimentRequest, + ExperimentSchema, + ExperimentStatsSchema, + FeedbackRequest, + UseDeployedPrompt, + UseDeployedPromptResponse, +) from parea.utils.trace_utils import get_current_trace_id, logger_all_possible, logger_record_log, trace_data from parea.wrapper import OpenAIWrapper diff --git a/parea/experiment/cli.py b/parea/experiment/cli.py index b2cae30f..515daf93 100644 --- a/parea/experiment/cli.py +++ b/parea/experiment/cli.py @@ -6,7 +6,6 @@ import traceback from importlib import util - from .experiment import experiment as experiment_orig diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index 72d575db..76fe16d0 100644 --- a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -1,3 +1,5 @@ +from typing import Callable, Dict, Iterator, List + import asyncio import concurrent import inspect @@ -5,7 +7,6 @@ import os import time from math import sqrt -from typing import Dict, List, Callable, Iterator from dotenv import load_dotenv from tqdm import tqdm @@ -43,7 +44,7 @@ def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> "output_tokens": calculate_avg_std_as_string(output_tokens_values), "total_tokens": calculate_avg_std_as_string(total_tokens_values), "cost": calculate_avg_std_as_string(cost_values), - **{score_name: calculate_avg_std_as_string(score_values) for score_name, score_values in score_name_to_values.items()} + **{score_name: calculate_avg_std_as_string(score_values) for score_name, score_values in score_name_to_values.items()}, } diff --git a/parea/schemas/models.py b/parea/schemas/models.py index 661d9468..52b9d4c8 100644 --- a/parea/schemas/models.py +++ b/parea/schemas/models.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List +from typing import Any, List, Optional from attrs import define, field, validators diff --git a/parea/utils/trace_integrations/langchain.py b/parea/utils/trace_integrations/langchain.py index e1074008..3a1b34e9 100644 --- a/parea/utils/trace_integrations/langchain.py +++ b/parea/utils/trace_integrations/langchain.py @@ -1,6 +1,6 @@ -import os from typing import Union +import os from uuid import UUID from langchain_core.tracers import BaseTracer diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py index 1a1149e9..b6f85fb4 100644 --- a/parea/utils/trace_utils.py +++ b/parea/utils/trace_utils.py @@ -1,10 +1,10 @@ -import os from typing import Any, Callable, Optional import contextvars import inspect import json import logging +import os import threading import time from collections import ChainMap @@ -81,7 +81,7 @@ def init_trace(func_name, args, kwargs, func) -> tuple[str, float]: target=target, tags=tags, inputs=inputs, - experiment_uuid=os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None) + experiment_uuid=os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None), ) parent_trace_id = trace_context.get()[-2] if len(trace_context.get()) > 1 else None if parent_trace_id: diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index 70e6de77..029a943b 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -1,8 +1,8 @@ -import os -from typing import Any, Callable, List, Tuple +from typing import Any, Callable import functools import inspect +import os import time from collections.abc import AsyncIterator, Iterator from uuid import uuid4 @@ -72,7 +72,7 @@ def _init_trace(self) -> tuple[str, float]: target=None, tags=None, 
inputs={}, - experiment_uuid=os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None) + experiment_uuid=os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None), ) parent_trace_id = trace_context.get()[-2] if len(trace_context.get()) > 1 else None From c40706f4aa3de43f421c7bc64581fa2b2c578598 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Sun, 7 Jan 2024 18:16:10 -0500 Subject: [PATCH 08/10] feat: expose as class --- parea/__init__.py | 5 +-- parea/experiment/cli.py | 28 +++++---------- parea/experiment/experiment.py | 63 +++++++++++++++++++++------------- parea/utils/trace_utils.py | 2 +- 4 files changed, 51 insertions(+), 47 deletions(-) diff --git a/parea/__init__.py b/parea/__init__.py index 4204a5dc..921ea025 100644 --- a/parea/__init__.py +++ b/parea/__init__.py @@ -13,7 +13,8 @@ from parea.cache import InMemoryCache, RedisCache from parea.client import Parea, init -from parea.experiment.cli import experiment +from parea.experiment.experiment import Experiment +from parea.experiment.cli import experiment as experiment_cli def get_version() -> str: @@ -29,7 +30,7 @@ def get_version() -> str: def main(): args = sys.argv[1:] if args[0] == "experiment": - experiment(args[1:]) + experiment_cli(args[1:]) else: print(f"Unknown command: '{args[0]}'") diff --git a/parea/experiment/cli.py b/parea/experiment/cli.py index 515daf93..6c9e701e 100644 --- a/parea/experiment/cli.py +++ b/parea/experiment/cli.py @@ -6,29 +6,23 @@ import traceback from importlib import util -from .experiment import experiment as experiment_orig +from .experiment import _experiments -def load_from_path(module_path, attr_name): +def load_from_path(module_path): # Ensure the directory of user-provided script is in the system path dir_name = os.path.dirname(module_path) if dir_name not in sys.path: sys.path.insert(0, dir_name) module_name = os.path.basename(module_path) - # Add .py extension back in to allow import correctly - module_path_with_ext = f"{module_path}.py" - - spec = importlib.util.spec_from_file_location(module_name, module_path_with_ext) + spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) if spec.name not in sys.modules: sys.modules[spec.name] = module - fn = getattr(module, attr_name) - return fn - def read_input_file(file_path) -> list[dict]: with open(file_path) as file: @@ -39,22 +33,16 @@ def read_input_file(file_path) -> list[dict]: def experiment(args): parser = argparse.ArgumentParser() - parser.add_argument("--name", help="Name of the experiment", type=str, required=True) - parser.add_argument("--func", help="Function to test e.g., path/to/my_code.py:argument_chain", type=str, required=True) - parser.add_argument("--csv_path", help="Path to the input CSV file", type=str, required=True) + parser.add_argument("file", help="Path to the experiment", type=str) + parsed_args = parser.parse_args(args) try: - func = load_from_path(*parsed_args.func.rsplit(":", 1)) + load_from_path(parsed_args.file) except Exception as e: print(f"Error loading function: {e}\n", file=sys.stderr) traceback.print_exc() sys.exit(1) - try: - data = read_input_file(parsed_args.csv_path) - except Exception as e: - print(f"Error reading input file: {e}\n", file=sys.stderr) - traceback.print_exc() - sys.exit(1) - experiment_orig(name=parsed_args.name, func=func, data=data) + for _experiment in _experiments: + _experiment.run() diff --git a/parea/experiment/experiment.py b/parea/experiment/experiment.py index 76fe16d0..4e2ee4e6 100644 --- 
a/parea/experiment/experiment.py +++ b/parea/experiment/experiment.py @@ -1,13 +1,12 @@ from typing import Callable, Dict, Iterator, List import asyncio -import concurrent import inspect import json import os import time -from math import sqrt +from attrs import define, field from dotenv import load_dotenv from tqdm import tqdm @@ -16,13 +15,12 @@ from parea.schemas.models import CreateExperimentRequest, ExperimentSchema, ExperimentStatsSchema, TraceStatsSchema -def calculate_avg_std_as_string(values: List[float]) -> str: +def calculate_avg_as_string(values: List[float]) -> str: if not values: return "N/A" values = [x for x in values if x is not None] avg = sum(values) / len(values) - std = sqrt(sum((x - avg) ** 2 for x in values) / len(values)) - return f"{avg:.2f} ± {std:.2f}" + return f"{avg:.2f}" def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> Dict[str, str]: @@ -39,12 +37,12 @@ def calculate_avg_std_for_experiment(experiment_stats: ExperimentStatsSchema) -> score_name_to_values[score.name] = score_name_to_values.get(score.name, []) + [score.score] return { - "latency": calculate_avg_std_as_string(latency_values), - "input_tokens": calculate_avg_std_as_string(input_tokens_values), - "output_tokens": calculate_avg_std_as_string(output_tokens_values), - "total_tokens": calculate_avg_std_as_string(total_tokens_values), - "cost": calculate_avg_std_as_string(cost_values), - **{score_name: calculate_avg_std_as_string(score_values) for score_name, score_values in score_name_to_values.items()}, + "latency": calculate_avg_as_string(latency_values), + "input_tokens": calculate_avg_as_string(input_tokens_values), + "output_tokens": calculate_avg_as_string(output_tokens_values), + "total_tokens": calculate_avg_as_string(total_tokens_values), + "cost": calculate_avg_as_string(cost_values), + **{score_name: calculate_avg_as_string(score_values) for score_name, score_values in score_name_to_values.items()}, } @@ -52,7 +50,8 @@ def async_wrapper(fn, **kwargs): return asyncio.run(fn(**kwargs)) -def experiment(name: str, data: Iterator, func: Callable): +def experiment(name: str, data: Iterator, func: Callable) -> ExperimentStatsSchema: + """Creates an experiment and runs the function on the data iterator.""" load_dotenv() if not (parea_api_key := os.getenv("PAREA_API_KEY")): @@ -61,18 +60,34 @@ def experiment(name: str, data: Iterator, func: Callable): experiment_schema: ExperimentSchema = p.create_experiment(CreateExperimentRequest(name=name)) experiment_uuid = experiment_schema.uuid - os.putenv(PAREA_OS_ENV_EXPERIMENT_UUID, experiment_uuid) + os.environ[PAREA_OS_ENV_EXPERIMENT_UUID] = experiment_uuid - with concurrent.futures.ProcessPoolExecutor(max_workers=os.cpu_count() or 1) as executor: + for data_input in tqdm(data): if inspect.iscoroutinefunction(func): - futures = [executor.submit(async_wrapper, func, **data_input) for data_input in data] + asyncio.run(func(**data_input)) else: - futures = [executor.submit(func, **data_input) for data_input in data] - for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): - pass - - time.sleep(5) # wait for all trace logs to be written to DB - experiment_stats: ExperimentStatsSchema = p.get_experiment_stats(experiment_uuid) - stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) - print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}\n\n") - print(f"View experiment & its traces at: https://app.parea.ai/experiments/{experiment_uuid}\n") + func(**data_input) + 
time.sleep(5) # wait for all trace logs to be written to DB + experiment_stats: ExperimentStatsSchema = p.get_experiment_stats(experiment_uuid) + stat_name_to_avg_std = calculate_avg_std_for_experiment(experiment_stats) + print(f"Experiment stats:\n{json.dumps(stat_name_to_avg_std, indent=2)}\n\n") + print(f"View experiment & its traces at: https://app.parea.ai/experiments/{experiment_uuid}\n") + return experiment_stats + + +_experiments = [] + + +@define +class Experiment: + name: str = field(init=True) + data: Iterator[Dict] = field(init=True) + func: Callable = field(init=True) + experiment_stats: ExperimentStatsSchema = field(init=False, default=None) + + def __attrs_post_init__(self): + global _experiments + _experiments.append(self) + + def run(self): + self.experiment_stats = experiment(self.name, self.data, self.func) diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py index b6f85fb4..ea467388 100644 --- a/parea/utils/trace_utils.py +++ b/parea/utils/trace_utils.py @@ -81,7 +81,7 @@ def init_trace(func_name, args, kwargs, func) -> tuple[str, float]: target=target, tags=tags, inputs=inputs, - experiment_uuid=os.getenv(PAREA_OS_ENV_EXPERIMENT_UUID, None), + experiment_uuid=os.environ.get(PAREA_OS_ENV_EXPERIMENT_UUID, None), ) parent_trace_id = trace_context.get()[-2] if len(trace_context.get()) > 1 else None if parent_trace_id: From a52ed1925f5476c30ab276dd227c52f4d5d74275 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Sun, 7 Jan 2024 18:16:53 -0500 Subject: [PATCH 09/10] style --- parea/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parea/__init__.py b/parea/__init__.py index 921ea025..cfc4dc40 100644 --- a/parea/__init__.py +++ b/parea/__init__.py @@ -13,8 +13,8 @@ from parea.cache import InMemoryCache, RedisCache from parea.client import Parea, init -from parea.experiment.experiment import Experiment from parea.experiment.cli import experiment as experiment_cli +from parea.experiment.experiment import Experiment def get_version() -> str: From 2ad29d8beb29f2ae1f7634196b78f2d023a0b536 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Sun, 7 Jan 2024 18:29:52 -0500 Subject: [PATCH 10/10] chore: bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 792728bd..29202c2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "parea-ai" packages = [{ include = "parea" }] -version = "0.2.26a0" +version = "0.2.26" description = "Parea python sdk" readme = "README.md" authors = ["joel-parea-ai "]
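
Usage sketch: the patches above replace the CSV-driven benchmark CLI with an Experiment class that self-registers on instantiation and is executed by the experiment subcommand, which imports the given file and calls .run() on every registered experiment. The snippet below is a minimal, hypothetical example of that flow; the function generate_greeting, its inline data, and the exact console command are illustrative assumptions, while Experiment(name=..., data=..., func=...) and the PAREA_API_KEY requirement come directly from the diffs.

# example_experiment.py — minimal sketch assuming the Experiment class and the
# experiment CLI introduced in the patches above; the function and data below
# are illustrative only, not part of the SDK.
from parea import Experiment


def generate_greeting(name: str) -> str:
    # In practice this would call an LLM through the Parea-wrapped client so
    # that latency/token/cost stats are recorded against the experiment UUID.
    return f"Hello {name}"


# Instantiation appends this experiment to the module-level _experiments list.
# Running the experiment subcommand on this file (e.g. `parea experiment
# example_experiment.py`, assuming the package's console entry point, with
# PAREA_API_KEY set or provided via a .env file) imports the module, calls
# .run() on each registered experiment, prints the averaged stats, and links
# to https://app.parea.ai/experiments/<experiment_uuid>.
Experiment(
    name="greeting-experiment",
    data=[{"name": "Foo"}, {"name": "Bar"}],
    func=generate_greeting,
)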