From 74edd7322e6dde5846c9061ab59fa69d3dd59343 Mon Sep 17 00:00:00 2001 From: Joel Alexander Date: Thu, 16 Nov 2023 17:08:41 -0500 Subject: [PATCH] feat: add universal json serializer --- parea/utils/trace_utils.py | 22 +++++++++----------- parea/utils/universal_encoder.py | 35 ++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 3 files changed, 46 insertions(+), 13 deletions(-) create mode 100644 parea/utils/universal_encoder.py diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py index e3ad0212..05a477bb 100644 --- a/parea/utils/trace_utils.py +++ b/parea/utils/trace_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional import contextvars import inspect @@ -9,11 +9,10 @@ from collections import ChainMap from functools import wraps -from attrs import asdict - from parea.helpers import gen_trace_id, to_date_and_time_string from parea.parea_logger import parea_logger -from parea.schemas.models import CompletionResponse, NamedEvaluationScore, TraceLog +from parea.schemas.models import NamedEvaluationScore, TraceLog +from parea.utils.universal_encoder import json_dumps logger = logging.getLogger() @@ -100,8 +99,7 @@ async def async_wrapper(*args, **kwargs): output_as_list = check_multiple_return_values(func) try: result = await func(*args, **kwargs) - output = make_output(result, output_as_list) - trace_data.get()[trace_id].output = output if isinstance(output, str) else json.dumps(output) + trace_data.get()[trace_id].output = make_output(result, output_as_list) trace_data.get()[trace_id].status = "success" trace_data.get()[trace_id].evaluation_metric_names = eval_funcs_names except Exception as e: @@ -119,8 +117,7 @@ def wrapper(*args, **kwargs): output_as_list = check_multiple_return_values(func) try: result = func(*args, **kwargs) - output = make_output(result, output_as_list) - trace_data.get()[trace_id].output = output if isinstance(output, str) else json.dumps(output) + 
trace_data.get()[trace_id].output = make_output(result, output_as_list) trace_data.get()[trace_id].status = "success" trace_data.get()[trace_id].evaluation_metric_names = eval_funcs_names except Exception as e: @@ -155,11 +152,12 @@ def check_multiple_return_values(func) -> bool: return False -def make_output(result, islist) -> Union[list[Any], Any]: +def make_output(result, islist) -> str: if islist: - return [asdict(r) if isinstance(r, CompletionResponse) else r for r in result] + json_list = [json_dumps(r) for r in result] + return json_dumps(json_list) else: - return asdict(result) if isinstance(result, CompletionResponse) else result + return json_dumps(result) def logger_record_log(trace_id: str): @@ -187,7 +185,7 @@ def call_eval_funcs_then_log(trace_id: str, eval_funcs: list[Callable] = None, a if access_output_of_func: output = json.loads(data.output) output = access_output_of_func(output) - output_for_eval_metrics = json.dumps(output) + output_for_eval_metrics = json_dumps(output) else: output_for_eval_metrics = data.output data.output_for_eval_metrics = output_for_eval_metrics diff --git a/parea/utils/universal_encoder.py b/parea/utils/universal_encoder.py new file mode 100644 index 00000000..dd54e9fa --- /dev/null +++ b/parea/utils/universal_encoder.py @@ -0,0 +1,35 @@ +from typing import Any + +import dataclasses +import datetime +import json + +import attrs + + +def is_dataclass_instance(obj): + return dataclasses.is_dataclass(obj) and not isinstance(obj, type) + + +def is_attrs_instance(obj): + return attrs.has(obj) + + +class UniversalEncoder(json.JSONEncoder): + def default(self, obj: Any): + if isinstance(obj, str): + return obj + elif is_dataclass_instance(obj): + return dataclasses.asdict(obj) + elif is_attrs_instance(obj): + return attrs.asdict(obj) + elif isinstance(obj, (datetime.datetime, datetime.date, datetime.time)): + return obj.isoformat() + elif isinstance(obj, datetime.timedelta): + return obj.total_seconds() + else: + return 
super().default(obj) + + +def json_dumps(obj, **kwargs): + return json.dumps(obj, cls=UniversalEncoder, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 37a758f0..680db576 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "parea-ai" packages = [{ include = "parea" }] -version = "0.2.14" +version = "0.2.15" description = "Parea python sdk" readme = "README.md" authors = ["joel-parea-ai "]