Merge pull request #924 from parea-ai/PAI-1225-get-scores-from-trace-id-endpoint-sdkp

Pai 1225 get scores from trace id endpoint sdkp
jalexanderII committed Jun 5, 2024
2 parents 2d1d331 + 2e35f91 commit 9f21a61
Showing 5 changed files with 150 additions and 10 deletions.
42 changes: 38 additions & 4 deletions parea/client.py
@@ -16,6 +16,7 @@
from parea.experiment.datasets import create_test_cases, create_test_collection
from parea.helpers import gen_trace_id, serialize_metadata_values, structure_trace_log_from_api, structure_trace_logs_from_api
from parea.parea_logger import parea_logger
from parea.schemas import EvaluationResult
from parea.schemas.models import (
Completion,
CompletionResponse,
@@ -33,6 +34,7 @@
TestCaseCollection,
TraceLog,
TraceLogFilters,
TraceLogTree,
UseDeployedPrompt,
UseDeployedPromptResponse,
)
@@ -397,11 +399,30 @@ def _update_data_and_trace(self, data: Completion) -> Completion:

return data

def get_trace_log(self, trace_id: str) -> TraceLog:
def get_trace_log(self, trace_id: str) -> TraceLogTree:
response = self._client.request("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
return structure_trace_log_from_api(response.json())

async def aget_trace_log(self, trace_id: str) -> TraceLog:
def get_trace_log_scores(self, trace_id: str, check_context: bool = True) -> List[EvaluationResult]:
"""
Get the scores from the trace log. If the scores are not present in the trace log, fetch them from the DB.
Args:
trace_id: The trace id to get the scores for.
check_context: If True, will check the context for the scores first before fetching from the DB.
Returns: A list of EvaluationResult objects.
"""
# try to get trace_id scores from context
if check_context:
if scores := (trace_data.get()[trace_id].scores or []):
print("Scores from context", scores)
return scores

response = self._client.request("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
tree: TraceLogTree = structure_trace_log_from_api(response.json())
return extract_scores(tree)

async def aget_trace_log(self, trace_id: str) -> TraceLogTree:
response = await self._client.request_async("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
return structure_trace_log_from_api(response.json())

@@ -413,11 +434,11 @@ async def alist_experiments(self, filter_conditions: Optional[ListExperimentUUID
response = await self._client.request_async("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
return structure(response.json(), List[ExperimentWithPinnedStatsSchema])

def get_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLog]:
def get_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLogTree]:
response = self._client.request("POST", GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid), data=asdict(filters))
return structure_trace_logs_from_api(response.json())

async def aget_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLog]:
async def aget_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLogTree]:
response = await self._client.request_async("POST", GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid), data=asdict(filters))
return structure_trace_logs_from_api(response.json())

@@ -446,3 +467,16 @@ def new_init(self, *args, **kwargs):
subclass = type(openai_client.__name__, (openai_client,), {"__init__": new_init})

return subclass


def extract_scores(tree: TraceLogTree) -> List[EvaluationResult]:
scores: List[EvaluationResult] = []

def traverse(node: TraceLogTree):
if node.scores:
scores.extend(node.scores or [])
for child in node.children_logs:
traverse(child)

traverse(tree)
return scores
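
A minimal usage sketch of the new get_trace_log_scores method shown above; the trace id ("my-trace-id") is a placeholder, and a PAREA_API_KEY environment variable is assumed to be set.

import os

from parea import Parea

p = Parea(api_key=os.getenv("PAREA_API_KEY"))

# Scores come from the in-process trace context when available; otherwise the
# client fetches the trace log tree from the API and collects scores across it.
scores = p.get_trace_log_scores("my-trace-id")
for score in scores:
    print(score.name, score.score)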
106 changes: 106 additions & 0 deletions parea/cookbook/openai/dynamic_few_shot_injection_with_evals.py
@@ -0,0 +1,106 @@
from typing import List, Optional

import os
import random

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel

from parea import Parea, get_current_trace_id, trace, trace_insert
from parea.schemas import Log, TestCase

load_dotenv()

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
p = Parea(api_key=os.getenv("PAREA_API_KEY"))
p.wrap_openai_client(client)

NUM_INTERACTIONS = 5


class Person(BaseModel):
name: str
email: str


class Email(BaseModel):
contact: Person
email_sent: str


mock_DB: dict[str, Email] = {}


def call_llm(messages: List[dict], model: str = "gpt-4o", temperature: float = 0.0) -> str:
return client.chat.completions.create(model=model, temperature=temperature, messages=messages).choices[0].message.content


def eval_func(log: Log) -> float:
return random.uniform(0, 1)


# Imitate collecting few shot examples from prod based on user feedback
@trace(eval_funcs=[eval_func])
def email_writer(main_objective: str, contact: Person, few_shot_examples: Optional[List[str]] = None) -> str:
trace_insert({"end_user_identifier": contact.name, "metadata": {"has_few_shot_examples": bool(few_shot_examples)}})

few_shot_examples_prompt = ("\nHere are some examples of good emails\n" + "\n".join(few_shot_examples)) if few_shot_examples else ""
messages = [
{
"role": "system",
"content": f"You are an AI who performs an email writing task based on the following objective: {main_objective}",
},
{
"role": "user",
"content": f"""
Your email is from: {contact.model_dump()}
{few_shot_examples_prompt if few_shot_examples else ""}
Email:
""",
},
]
response = call_llm(messages)
trace_id = get_current_trace_id()
# insert into mock_DB
mock_DB[trace_id] = Email(contact=contact, email_sent=response)
return response


def mimic_prod(few_shot_limit: int = 3):
contact = Person(name="John Doe", email="[email protected]")
dataset = p.get_collection("Good_Email_Examples")
selected_few_shot_examples = None
if dataset:
testcases: list[TestCase] = list(dataset.test_cases.values())
few_shot_examples = [case.inputs["email"] for case in testcases if case.inputs["user"] == contact.name]
# This is simply taking most recent n examples. You can imagine adding additional logic to the dataset
# that allows you to rank the examples based on some criteria
selected_few_shot_examples = few_shot_examples[-few_shot_limit:] if few_shot_examples else None
for interaction in range(NUM_INTERACTIONS):
email = email_writer("Convincing email to gym to cancel membership early.", contact, selected_few_shot_examples)
print(email)


def add_good_email_example_to_dataset(user_name, email):
# Note: if the test case collection doesn't exist, we will create a new collection with the provided name and data
p.add_test_cases([{"user": user_name, "email": email}], name="Good_Email_Examples")


def mimic_prod_checking_eval_scores():
# imagine the trace_id of the email is stored in state in the UI, so when the user provides feedback, we can use it
trace_ids = mock_DB.keys()
for trace_id in trace_ids:
scores = p.get_trace_log_scores(trace_id)
for score in scores:
if score.name == "eval_func" and score.score >= 0.5:
add_good_email_example_to_dataset(mock_DB[trace_id].contact.name, mock_DB[trace_id].email_sent)
break


if __name__ == "__main__":
mimic_prod()
mimic_prod_checking_eval_scores()
# future llm calls will now have few-shot examples from the feedback collection
mimic_prod()
print("Done")
8 changes: 4 additions & 4 deletions parea/helpers.py
@@ -11,7 +11,7 @@
from cattrs import GenConverter

from parea.constants import ADJECTIVES, NOUNS, TURN_OFF_PAREA_LOGGING
from parea.schemas.models import Completion, TraceLog, UpdateLog
from parea.schemas.models import Completion, TraceLog, TraceLogTree, UpdateLog
from parea.utils.universal_encoder import json_dumps


@@ -81,7 +81,7 @@ def timezone_aware_now() -> datetime:
return datetime.now(pytz.utc)


def structure_trace_log_from_api(d: dict) -> TraceLog:
def structure_trace_log_from_api(d: dict) -> TraceLogTree:
def structure_union_type(obj: Any, cl: type) -> Any:
if isinstance(obj, str):
return obj
@@ -92,10 +92,10 @@ def structure_union_type(obj: Any, cl: type) -> Any:

converter = GenConverter()
converter.register_structure_hook(Union[str, Dict[str, str], None], structure_union_type)
return converter.structure(d, TraceLog)
return converter.structure(d, TraceLogTree)


def structure_trace_logs_from_api(data: List[dict]) -> List[TraceLog]:
def structure_trace_logs_from_api(data: List[dict]) -> List[TraceLogTree]:
return [structure_trace_log_from_api(d) for d in data]


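
The helpers change only swaps the structuring target from TraceLog to TraceLogTree; cattrs' GenConverter recurses into nested attrs classes on its own. A self-contained sketch of that behaviour, using stand-in classes rather than the real parea schemas:

from typing import List, Optional

from attrs import define, field
from cattrs import GenConverter


@define
class Leaf:  # stand-in for a flat TraceLog
    trace_id: str = ""


@define
class Tree(Leaf):  # stand-in for TraceLogTree
    children_logs: Optional[List[Leaf]] = field(factory=list)


converter = GenConverter()
nested = {"trace_id": "root", "children_logs": [{"trace_id": "child"}]}
tree = converter.structure(nested, Tree)
print(tree.children_logs[0].trace_id)  # -> "child"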
2 changes: 1 addition & 1 deletion parea/schemas/models.py
@@ -147,7 +147,7 @@ class TraceLog(EvaluatedLog):

@define
class TraceLogTree(TraceLog):
children: Optional[List[TraceLog]] = field(factory=list)
children_logs: Optional[List[TraceLog]] = field(factory=list)


@define
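
The renamed children_logs field is what the new extract_scores helper in parea/client.py walks. A self-contained sketch of that depth-first collection, again with stand-in classes instead of the real TraceLogTree and EvaluationResult:

from typing import List, Optional

from attrs import define, field


@define
class Score:  # stand-in for parea.schemas.EvaluationResult
    name: str
    score: float


@define
class Node:  # stand-in for TraceLogTree
    scores: List[Score] = field(factory=list)
    children_logs: Optional[List["Node"]] = field(factory=list)


def collect_scores(tree: Node) -> List[Score]:
    # Depth-first walk: gather scores on the root and on every nested child log.
    collected: List[Score] = list(tree.scores)
    for child in tree.children_logs or []:
        collected.extend(collect_scores(child))
    return collected


root = Node(scores=[Score("eval_func", 0.9)], children_logs=[Node(scores=[Score("eval_func", 0.4)])])
print([s.score for s in collect_scores(root)])  # -> [0.9, 0.4]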
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.167a0"
version = "0.2.167"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
