Merge pull request #898 from parea-ai/PAI-1165-docs-sub-step-testing
feat: substep testing
joschkabraun authored May 23, 2024
2 parents b450d72 + 6ecc3c1 commit ff205ab
Showing 5 changed files with 93 additions and 5 deletions.
72 changes: 72 additions & 0 deletions parea/cookbook/experiment_test_substeps.py
@@ -0,0 +1,72 @@
from typing import Union

import json
import os

from dotenv import load_dotenv

from parea import Parea, trace
from parea.evals.general.levenshtein import levenshtein_distance
from parea.schemas import Log

load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))


# evaluation function for the substep
def eval_choose_greeting(log: Log) -> Union[float, None]:
    if not (target := log.target):
        return None

    target_substep = json.loads(target)["substep"]  # log.target is a string
    output = log.output
    return levenshtein_distance(target_substep, output)


# sub-step
@trace(eval_funcs=[eval_choose_greeting])
def choose_greeting(name: str) -> str:
    return "Hello"


# end-to-end evaluation function
def eval_greet(log: Log) -> Union[float, None]:
    if not (target := log.target):
        return None

    target_overall = json.loads(target)["overall"]
    output = log.output
    return levenshtein_distance(target_overall, output)


@trace(eval_funcs=[eval_greet])
def greet(name: str) -> str:
    greeting = choose_greeting(name)
    return f"{greeting} {name}"


data = [
    {
        "name": "Foo",
        "target": {
            "overall": "Hi Foo",
            "substep": "Hi",
        },
    },
    {
        "name": "Bar",
        "target": {
            "overall": "Hello Bar",
            "substep": "Hello",
        },
    },
]


if __name__ == "__main__":
    p.experiment(
        name="greeting",
        data=data,
        func=greet,
    ).run()
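The sub-step eval above scores with Levenshtein distance; a stricter exact-match variant following the same pattern could look like the sketch below. The function name eval_choose_greeting_exact is hypothetical; Log, json, and the "substep" target key come from the example above.

# Hypothetical exact-match variant of the sub-step eval above (a sketch, not part of this commit):
def eval_choose_greeting_exact(log: Log) -> Union[float, None]:
    if not (target := log.target):
        return None  # no target was propagated to this sub-step

    target_substep = json.loads(target)["substep"]  # log.target is a JSON string
    return 1.0 if log.output == target_substep else 0.0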
5 changes: 4 additions & 1 deletion parea/evals/general/levenshtein.py
@@ -10,7 +10,10 @@ def levenshtein(log: Log) -> Union[float, None]:
     if (target := log.target) is None:
         return None

-    output, target = str(output), str(target)
+    return levenshtein_distance(str(output), str(target))
+
+
+def levenshtein_distance(output: str, target: str) -> float:
     max_len = max(len(x) for x in [output, target])

     score = 1
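The refactor above exposes levenshtein_distance as a standalone helper so evals (like the cookbook's sub-step eval) can call it directly on strings. A minimal usage sketch, assuming the score stays normalized with 1.0 meaning an exact match:

from parea.evals.general.levenshtein import levenshtein_distance

# Assumed behavior: a normalized score in [0, 1]; identical strings score 1.0.
print(levenshtein_distance("Hello", "Hello"))  # expected: 1.0
print(levenshtein_distance("Hello", "Hi"))     # lower score for a partial match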
12 changes: 11 additions & 1 deletion parea/experiment/experiment.py
@@ -172,12 +172,22 @@ def limit_concurrency_sync(sample):
 _experiments = []


+def data_converter(data: Union[str, int, Iterable[dict]]) -> Union[str, int, Iterable[dict]]:
+    if isinstance(data, (str, int)):
+        return data
+    else:
+        for sample in data:
+            if "target" in sample and isinstance(sample["target"], dict):
+                sample["target"] = json_dumps(sample["target"])
+        return data
+
+
 @define
 class Experiment:
     # If your dataset is defined locally it should be an iterable of k/v
     # pairs matching the expected inputs of your function. To reference a dataset you
     # have saved on Parea, use the dataset name as a string or the id as an int.
-    data: Union[str, int, Iterable[dict]]
+    data: Union[str, int, Iterable[dict]] = field(converter=data_converter)
     # The function to run. This function should accept inputs that match the keys of the data field.
     func: Callable = field()
     experiment_stats: ExperimentStatsSchema = field(init=False, default=None)
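The converter above lets Experiment.data carry dict targets (as in the new cookbook file) by JSON-serializing them before the run. A rough sketch of its effect, assuming json_dumps behaves like the standard json.dumps:

# Sketch of the converter's effect on a locally defined sample (assumes json_dumps ~ json.dumps):
samples = [{"name": "Foo", "target": {"overall": "Hi Foo", "substep": "Hi"}}]
converted = data_converter(samples)
# converted[0]["target"] is now roughly the string '{"overall": "Hi Foo", "substep": "Hi"}',
# which eval functions can parse back with json.loads(log.target).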
7 changes: 5 additions & 2 deletions parea/utils/trace_utils.py
@@ -192,6 +192,9 @@ def init_trace(func_name, _parea_target_field, args, kwargs, func) -> Tuple[str,
         execution_order = counters[root_trace_id]
         counters[root_trace_id] += 1

+        parent_trace_id = new_trace_context[-2] if len(new_trace_context) > 1 else None
+        parent_target = trace_data.get()[parent_trace_id].target if parent_trace_id else None
+
         trace_data.get()[trace_id] = TraceLog(
             trace_id=trace_id,
             parent_trace_id=trace_id,
@@ -201,7 +204,7 @@ def init_trace(func_name, _parea_target_field, args, kwargs, func) -> Tuple[str,
             end_user_identifier=end_user_identifier,
             session_id=session_id,
             metadata=metadata,
-            target=_parea_target_field,
+            target=_parea_target_field or parent_target,
             tags=tags,
             inputs={} if log_omit_inputs else inputs,
             experiment_uuid=os.environ.get(PAREA_OS_ENV_EXPERIMENT_UUID, None),
@@ -210,7 +213,7 @@ def init_trace(func_name, _parea_target_field, args, kwargs, func) -> Tuple[str,
             depth=depth,
             execution_order=execution_order,
         )
-        parent_trace_id = new_trace_context[-2] if len(new_trace_context) > 1 else None
+
         if parent_trace_id:
             fill_trace_data(trace_id, {"parent_trace_id": parent_trace_id}, UpdateTraceScenario.CHAIN)
     except Exception as e:
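With the fallback above, a sub-step trace that was not given its own target inherits the parent trace's target, which is what lets the sub-step eval in the cookbook read json.loads(log.target)["substep"]. A minimal illustration of the fallback expression with plain values (no tracing involved):

# Plain-value illustration of the new fallback (values taken from the cookbook example):
_parea_target_field = None  # the sub-step was not given an explicit target
parent_target = '{"overall": "Hi Foo", "substep": "Hi"}'  # the parent's JSON-serialized target
target = _parea_target_field or parent_target  # the sub-step inherits the parent's target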
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.158"
+version = "0.2.159"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai <[email protected]>"]
