Commit
Merge pull request #484 from parea-ai/re-build-tracelog-is-ill-formatted
serialize metadata before sending across the wire
jalexanderII authored Feb 20, 2024
2 parents ec1d7ec + 830fcaf commit d1fefed
Showing 8 changed files with 66 additions and 33 deletions.
60 changes: 36 additions & 24 deletions parea/api_client.py
@@ -17,32 +17,44 @@


def retry_on_502(func: Callable[..., Any]) -> Callable[..., Any]:
-    if asyncio.iscoroutinefunction(func):
+    """
+    A decorator to retry a function or coroutine on encountering a 502 error.
+    Parameters:
+    - func: The function or coroutine to be decorated.
+    Returns:
+    - A wrapper function that incorporates retry logic.
+    """
+
+    @wraps(func)
+    async def async_wrapper(*args, **kwargs):
+        for retry in range(MAX_RETRIES):
+            try:
+                return await func(*args, **kwargs)
+            except (httpx.HTTPStatusError, httpx.ConnectError) as e:
+                if not _should_retry(e, retry):
+                    raise
+                await asyncio.sleep(BACKOFF_FACTOR * (2**retry))
+
+    @wraps(func)
+    def sync_wrapper(*args, **kwargs):
+        for retry in range(MAX_RETRIES):
+            try:
+                return func(*args, **kwargs)
+            except (httpx.HTTPStatusError, httpx.ConnectError) as e:
+                if not _should_retry(e, retry):
+                    raise
+                time.sleep(BACKOFF_FACTOR * (2**retry))
+
+    def _should_retry(error, current_retry):
+        """Determines if the function should retry on error."""
+        is_502_error = isinstance(error, httpx.HTTPStatusError) and error.response.status_code == 502
+        is_last_retry = current_retry == MAX_RETRIES - 1
+        return not is_last_retry and (isinstance(error, httpx.ConnectError) or is_502_error)

-        @wraps(func)
-        async def wrapper(*args, **kwargs):
-            for retry in range(MAX_RETRIES):
-                try:
-                    return await func(*args, **kwargs)
-                except (httpx.HTTPStatusError, httpx.ConnectError) as e:
-                    if (isinstance(e, httpx.HTTPStatusError) and e.response.status_code != 502) or retry == MAX_RETRIES - 1:
-                        raise
-                    await asyncio.sleep(BACKOFF_FACTOR * (2**retry))
-
-        return wrapper
-    else:
-
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            for retry in range(MAX_RETRIES):
-                try:
-                    return func(*args, **kwargs)
-                except (httpx.HTTPStatusError, httpx.ConnectError) as e:
-                    if (isinstance(e, httpx.HTTPStatusError) and e.response.status_code != 502) or retry == MAX_RETRIES - 1:
-                        raise
-                    time.sleep(BACKOFF_FACTOR * (2**retry))
-
-        return wrapper
+    if asyncio.iscoroutinefunction(func):
+        return async_wrapper
+    return sync_wrapper


class HTTPClient:
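For context, a brief usage sketch of the refactored decorator (not part of this commit): it is applied to one sync and one async httpx call. The function names and URL are made up, and the import path assumes retry_on_502 remains accessible in parea.api_client.

import asyncio

import httpx

from parea.api_client import retry_on_502


@retry_on_502
def fetch_status(url: str) -> int:
    # raise_for_status() turns a 502 response into httpx.HTTPStatusError,
    # which sync_wrapper catches and retries with exponential backoff.
    resp = httpx.get(url)
    resp.raise_for_status()
    return resp.status_code


@retry_on_502
async def afetch_status(url: str) -> int:
    # iscoroutinefunction() detects the coroutine, so async_wrapper is returned instead.
    async with httpx.AsyncClient() as client:
        resp = await client.get(url)
        resp.raise_for_status()
        return resp.status_code


if __name__ == "__main__":
    print(fetch_status("https://example.com"))
    print(asyncio.run(afetch_status("https://example.com")))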
3 changes: 2 additions & 1 deletion parea/client.py
@@ -16,7 +16,7 @@
from parea.cache.cache import Cache
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
from parea.experiment.datasets import create_test_cases, create_test_collection
-from parea.helpers import gen_trace_id
+from parea.helpers import gen_trace_id, serialize_metadata_values
from parea.parea_logger import parea_logger
from parea.schemas.models import (
Completion,
@@ -274,6 +274,7 @@ def experiment(self, data: Union[str, int, Iterable[dict]], func: Callable, n_tr
return Experiment(data=data, func=func, p=self, n_trials=n_trials, metadata=metadata)

def _update_data_and_trace(self, data: Completion) -> Completion:
+data = serialize_metadata_values(data)
inference_id = gen_trace_id()
data.inference_id = inference_id
data.project_uuid = self._project.uuid
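A hedged sketch of what this client-side change means for callers: non-string metadata values on a Completion are now JSON-serialized inside _update_data_and_trace before the request goes out, so nested metadata no longer needs to be pre-stringified by the caller. The model, prompt, and metadata below are illustrative, and the LLMInputs/Message/Role import paths and field names are assumptions based on the wider SDK rather than on this diff.

import os

from parea import Parea
from parea.schemas.models import Completion, LLMInputs, Message, Role

p = Parea(api_key=os.getenv("PAREA_API_KEY"))

# Nested, non-string metadata values are serialized by the SDK before sending.
p.completion(
    data=Completion(
        llm_configuration=LLMInputs(
            model="gpt-3.5-turbo-0125",
            messages=[Message(role=Role.user, content="Say hello.")],
        ),
        metadata={"run": 3, "config": {"temperature": 0.0}},
    )
)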
4 changes: 2 additions & 2 deletions parea/cookbook/tracing_with_agent.py
@@ -12,7 +12,7 @@
p = Parea(api_key=os.getenv("PAREA_API_KEY"))

# Parea SDK makes it easy to use different LLMs with the same apis structure and standardized request/response schemas.
LLM_OPTIONS = [("gpt-3.5-turbo", "openai"), ("gpt-4", "openai"), ("claude-instant-1", "anthropic"), ("claude-2", "anthropic")]
LLM_OPTIONS = [("gpt-3.5-turbo-0125", "openai"), ("gpt-4-0125-preview", "openai"), ("claude-instant-1", "anthropic"), ("claude-2.1", "anthropic")]
LIMIT = 1


@@ -61,7 +61,7 @@ def generate_tasks(main_objective: str, expounded_initial_task: list[dict[str, s
task_expansion = dump_task(expounded_initial_task)
prompt = [
Message(
-role=Role.system,
+role=Role.user,
content=(
f"You are an AI who creates tasks based on the following MAIN OBJECTIVE: {main_objective}\n"
f"Create tasks pertaining directly to your previous research here:\n"
20 changes: 18 additions & 2 deletions parea/helpers.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Optional, Union

import csv
import random
@@ -10,7 +10,8 @@
from attr import asdict, fields_dict

from parea.constants import ADJECTIVES, NOUNS
-from parea.schemas.models import TraceLog
+from parea.schemas.models import Completion, TraceLog, UpdateLog
+from parea.utils.universal_encoder import json_dumps


def gen_trace_id() -> str:
@@ -54,3 +55,18 @@ def calculate_avg_as_string(values: list[Optional[float]]) -> str:

def duplicate_dicts(data: Iterable[dict], n: int) -> Iterable[dict]:
return [deepcopy(item) for item in data for _ in range(n)]


+def serialize_metadata_values(log_data: Union[TraceLog, UpdateLog, Completion]) -> Union[TraceLog, UpdateLog, Completion]:
+    def serialize_values(metadata: dict[str, Any]) -> dict[str, str]:
+        return {k: json_dumps(v) for k, v in metadata.items()}
+
+    if isinstance(log_data, UpdateLog) and log_data.field_name_to_value_map:
+        if "metadata" in log_data.field_name_to_value_map:
+            serialized_values = serialize_values(log_data.field_name_to_value_map["metadata"])
+            log_data.field_name_to_value_map["metadata"] = serialized_values
+    elif log_data.metadata:
+        serialized_values = serialize_values(log_data.metadata)
+        log_data.metadata = serialized_values
+
+    return log_data
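As a quick illustration of the helper's effect (not from the diff), the inner serialize_values comprehension turns every metadata value into a JSON string. Plain json.dumps stands in for parea's universal-encoder json_dumps here, and the metadata contents are invented.

from json import dumps as json_dumps
from typing import Any


def serialize_values(metadata: dict[str, Any]) -> dict[str, str]:
    # Each value becomes a JSON string, so nested dicts/lists survive the trip over the wire.
    return {k: json_dumps(v) for k, v in metadata.items()}


print(serialize_values({"run": 3, "tags": ["eval", "prod"], "config": {"temp": 0.2}}))
# {'run': '3', 'tags': '["eval", "prod"]', 'config': '{"temp": 0.2}'}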
5 changes: 5 additions & 0 deletions parea/parea_logger.py
@@ -8,6 +8,7 @@
from parea.api_client import HTTPClient
from parea.cache.redis import RedisCache
from parea.constants import PAREA_OS_ENV_EXPERIMENT_UUID
+from parea.helpers import serialize_metadata_values
from parea.schemas.log import TraceIntegrations
from parea.schemas.models import TraceLog, UpdateLog
from parea.utils.universal_encoder import json_dumps
@@ -32,13 +33,15 @@ def set_project_uuid(self, project_uuid: str) -> None:
self._project_uuid = project_uuid

def update_log(self, data: UpdateLog) -> None:
+data = serialize_metadata_values(data)
self._client.request(
"PUT",
LOG_ENDPOINT,
data=asdict(data),
)

def record_log(self, data: TraceLog) -> None:
+data = serialize_metadata_values(data)
data.project_uuid = self._project_uuid
self._client.request(
"POST",
@@ -47,6 +50,7 @@ def record_log(self, data: TraceLog) -> None:
)

async def arecord_log(self, data: TraceLog) -> None:
+data = serialize_metadata_values(data)
data.project_uuid = self._project_uuid
await self._client.request_async(
"POST",
@@ -55,6 +59,7 @@ async def arecord_log(self, data: TraceLog) -> None:
)

def write_log(self, data: TraceLog) -> None:
+data = serialize_metadata_values(data)
self._redis_cache.log(data)

def default_log(self, data: TraceLog) -> None:
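For context on why serialize_metadata_values special-cases UpdateLog: in update_log the metadata arrives nested under field_name_to_value_map rather than as a top-level attribute. A hedged sketch follows, assuming UpdateLog also carries a trace_id field and using the module-level parea_logger singleton; the values are made up.

from parea.parea_logger import parea_logger
from parea.schemas.models import UpdateLog

# The nested "metadata" dict is JSON-serialized before the PUT request is issued.
parea_logger.update_log(
    UpdateLog(
        trace_id="00000000-0000-0000-0000-000000000000",
        field_name_to_value_map={"metadata": {"attempt": 2, "flags": ["reviewed"]}},
    )
)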
2 changes: 1 addition & 1 deletion parea/schemas/models.py
@@ -104,7 +104,7 @@ class TraceLog(Log):
output_for_eval_metrics: Optional[str] = None
evaluation_metric_names: Optional[list[str]] = field(factory=list)
scores: Optional[list[NamedEvaluationScore]] = field(factory=list)
-apply_eval_frac: Optional[float] = 1.0
+apply_eval_frac: float = 1.0
feedback_score: Optional[float] = None

# info filled from decorator
3 changes: 1 addition & 2 deletions parea/utils/trace_utils.py
@@ -9,7 +9,6 @@
import time
from collections import ChainMap
from collections.abc import AsyncGenerator, AsyncIterator, Generator, Iterator
-from copy import deepcopy
from functools import wraps
from random import random

@@ -123,7 +122,7 @@ def trace(
eval_funcs_names: Optional[list[str]] = None,
eval_funcs: Optional[list[Callable]] = None,
access_output_of_func: Optional[Callable] = None,
-apply_eval_frac: Optional[float] = 1.0,
+apply_eval_frac: float = 1.0,
):
def init_trace(func_name, _parea_target_field, args, kwargs, func) -> tuple[str, float]:
start_time = time.time()
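The signature change above makes apply_eval_frac a plain float defaulting to 1.0 rather than Optional[float]. A minimal sketch of how the parameter is used with the trace decorator, assuming the usual log-object signature for eval functions; the eval itself is a toy.

from parea.utils.trace_utils import trace


def brevity_eval(log) -> float:
    # Toy score: shorter outputs score higher; log.output is assumed to be a string or None.
    return 1.0 - min(len(log.output or ""), 100) / 100


@trace(eval_funcs=[brevity_eval], apply_eval_frac=0.5)  # evaluate roughly half of the traces
def summarize(text: str) -> str:
    return text[:50]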
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.78"
version = "0.2.79"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
