From e9f29e7bc2d990d6ade24ec98c7be2cb3f936dc3 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Mon, 28 Aug 2023 21:33:23 -0700 Subject: [PATCH] fix: fix typo --- parea/wrapper/__init__.py | 2 +- parea/wrapper/openai.py | 17 ++++++++--------- parea/wrapper/wrapper.py | 10 +++------- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/parea/wrapper/__init__.py b/parea/wrapper/__init__.py index eb6d319b..87996299 100644 --- a/parea/wrapper/__init__.py +++ b/parea/wrapper/__init__.py @@ -9,7 +9,7 @@ def init(): if _initialized_parea_wrapper: return - OpenAIWrapper.init(default_logger) + OpenAIWrapper().init(default_logger) _initialized_parea_wrapper = True diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index d5157923..969527b5 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -3,6 +3,7 @@ import json import openai +from openai.openai_object import OpenAIObject from ..schemas.models import LLMInputs, ModelParams from ..utils.trace_utils import trace_data @@ -46,11 +47,10 @@ class OpenAIWrapper: original_methods = {"ChatCompletion.create": openai.ChatCompletion.create, "ChatCompletion.acreate": openai.ChatCompletion.acreate} - @staticmethod - def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Any]): + def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Any]): if response: usage = response["usage"] - output = OpenAIWrapper._get_output(response) + output = self._get_output(response) else: output = None usage = {} @@ -72,8 +72,8 @@ def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], respon ), ) - model_rate = OpenAIWrapper.get_model_cost(model) - model_completion_rate = OpenAIWrapper.get_model_cost(model, is_completion=True) + model_rate = self.get_model_cost(model) + model_completion_rate = self.get_model_cost(model, is_completion=True) completion_cost = model_completion_rate * (usage.get("completion_tokens", 0) / 1000) prompt_cost = model_rate * (usage.get("prompt_tokens", 0) / 1000) total_cost = sum([prompt_cost, completion_cost]) @@ -85,9 +85,8 @@ def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], respon trace_data.get()[trace_id].cost = total_cost trace_data.get()[trace_id].output = output - @staticmethod - def init(log: Callable): - Wrapper(resolver=OpenAIWrapper.resolver, log=log, module=openai, func_names=list(OpenAIWrapper.original_methods.keys())) + def init(self, log: Callable): + Wrapper(resolver=self.resolver, log=log, module=openai, func_names=list(OpenAIWrapper.original_methods.keys())) @staticmethod def _get_output(result) -> str: @@ -101,7 +100,7 @@ def _get_output(result) -> str: @staticmethod def _format_function_call(response_message) -> str: function_name = response_message["function_call"]["name"] - if isinstance(response_message["function_call"]["arguments"], openai.openai_object.OpenAIObject): + if isinstance(response_message["function_call"]["arguments"], OpenAIObject): function_args = dict(response_message["function_call"]["arguments"]) else: function_args = json.loads(response_message["function_call"]["arguments"]) diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index 1f68358e..25213b19 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -72,14 +72,15 @@ def async_decorator(self, orig_func: Callable) -> Callable: async def wrapper(*args, **kwargs): trace_id, start_time = self._init_trace() response = None + error = None try: response = await orig_func(*args, **kwargs) return response except Exception as e: - self._handle_error(trace_id, e) + error = str(e) raise finally: - self._cleanup_trace(trace_id, start_time, response, args, kwargs) + self._cleanup_trace(trace_id, start_time, error, response, args, kwargs) return wrapper @@ -114,8 +115,3 @@ def _cleanup_trace(self, trace_id: str, start_time: float, error: str, response: self.log(trace_id) trace_context.get().pop() - - @staticmethod - def _handle_error(trace_id: str, e: Exception): - trace_data.get()[trace_id].error = str(e) - trace_data.get()[trace_id].status = "error"