Skip to content

Commit

Permalink
Merge pull request #223 from parea-ai/PAI-408-fix-json-serializer-for…
Browse files Browse the repository at this point in the history
…-trace-decorator

feat: add universial json serializer
  • Loading branch information
jalexanderII committed Nov 16, 2023
2 parents c1137c2 + 74edd73 commit dac7a72
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
22 changes: 10 additions & 12 deletions parea/utils/trace_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional

import contextvars
import inspect
Expand All @@ -9,11 +9,10 @@
from collections import ChainMap
from functools import wraps

from attrs import asdict

from parea.helpers import gen_trace_id, to_date_and_time_string
from parea.parea_logger import parea_logger
from parea.schemas.models import CompletionResponse, NamedEvaluationScore, TraceLog
from parea.schemas.models import NamedEvaluationScore, TraceLog
from parea.utils.universal_encoder import json_dumps

logger = logging.getLogger()

Expand Down Expand Up @@ -100,8 +99,7 @@ async def async_wrapper(*args, **kwargs):
output_as_list = check_multiple_return_values(func)
try:
result = await func(*args, **kwargs)
output = make_output(result, output_as_list)
trace_data.get()[trace_id].output = output if isinstance(output, str) else json.dumps(output)
trace_data.get()[trace_id].output = make_output(result, output_as_list)
trace_data.get()[trace_id].status = "success"
trace_data.get()[trace_id].evaluation_metric_names = eval_funcs_names
except Exception as e:
Expand All @@ -119,8 +117,7 @@ def wrapper(*args, **kwargs):
output_as_list = check_multiple_return_values(func)
try:
result = func(*args, **kwargs)
output = make_output(result, output_as_list)
trace_data.get()[trace_id].output = output if isinstance(output, str) else json.dumps(output)
trace_data.get()[trace_id].output = make_output(result, output_as_list)
trace_data.get()[trace_id].status = "success"
trace_data.get()[trace_id].evaluation_metric_names = eval_funcs_names
except Exception as e:
Expand Down Expand Up @@ -155,11 +152,12 @@ def check_multiple_return_values(func) -> bool:
return False


def make_output(result, islist) -> Union[list[Any], Any]:
def make_output(result, islist) -> str:
if islist:
return [asdict(r) if isinstance(r, CompletionResponse) else r for r in result]
json_list = [json_dumps(r) for r in result]
return json_dumps(json_list)
else:
return asdict(result) if isinstance(result, CompletionResponse) else result
return json_dumps(result)


def logger_record_log(trace_id: str):
Expand Down Expand Up @@ -187,7 +185,7 @@ def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, a
if access_output_of_func:
output = json.loads(data.output)
output = access_output_of_func(output)
output_for_eval_metrics = json.dumps(output)
output_for_eval_metrics = json_dumps(output)
else:
output_for_eval_metrics = data.output
data.output_for_eval_metrics = output_for_eval_metrics
Expand Down
35 changes: 35 additions & 0 deletions parea/utils/universal_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any

import dataclasses
import datetime
import json

import attrs


def is_dataclass_instance(obj):
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


def is_attrs_instance(obj):
return attrs.has(obj)


class UniversalEncoder(json.JSONEncoder):
def default(self, obj: Any):
if isinstance(obj, str):
return obj
elif is_dataclass_instance(obj):
return dataclasses.asdict(obj)
elif is_attrs_instance(obj):
return attrs.asdict(obj)
elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return obj.isoformat()
elif isinstance(obj, datetime.timedelta):
return obj.total_seconds()
else:
return super().default(obj)


def json_dumps(obj, **kwargs):
return json.dumps(obj, cls=UniversalEncoder, **kwargs)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.14"
version = "0.2.15"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down

0 comments on commit dac7a72

Please sign in to comment.