Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/parea-ai/parea-sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
joschkabraun committed Aug 29, 2023
2 parents 351df2f + 832f660 commit a639b23
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 17 deletions.
2 changes: 1 addition & 1 deletion parea/wrapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def init():
if _initialized_parea_wrapper:
return

OpenAIWrapper.init(default_logger)
OpenAIWrapper().init(default_logger)

_initialized_parea_wrapper = True

Expand Down
17 changes: 8 additions & 9 deletions parea/wrapper/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json

import openai
from openai.openai_object import OpenAIObject

from ..schemas.models import LLMInputs, ModelParams
from ..utils.trace_utils import trace_data
Expand Down Expand Up @@ -46,11 +47,10 @@
class OpenAIWrapper:
original_methods = {"ChatCompletion.create": openai.ChatCompletion.create, "ChatCompletion.acreate": openai.ChatCompletion.acreate}

@staticmethod
def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Any]):
def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Any]):
if response:
usage = response["usage"]
output = OpenAIWrapper._get_output(response)
output = self._get_output(response)
else:
output = None
usage = {}
Expand All @@ -72,8 +72,8 @@ def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], respon
),
)

model_rate = OpenAIWrapper.get_model_cost(model)
model_completion_rate = OpenAIWrapper.get_model_cost(model, is_completion=True)
model_rate = self.get_model_cost(model)
model_completion_rate = self.get_model_cost(model, is_completion=True)
completion_cost = model_completion_rate * (usage.get("completion_tokens", 0) / 1000)
prompt_cost = model_rate * (usage.get("prompt_tokens", 0) / 1000)
total_cost = sum([prompt_cost, completion_cost])
Expand All @@ -85,9 +85,8 @@ def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], respon
trace_data.get()[trace_id].cost = total_cost
trace_data.get()[trace_id].output = output

@staticmethod
def init(log: Callable):
Wrapper(resolver=OpenAIWrapper.resolver, log=log, module=openai, func_names=list(OpenAIWrapper.original_methods.keys()))
def init(self, log: Callable):
Wrapper(resolver=self.resolver, log=log, module=openai, func_names=list(OpenAIWrapper.original_methods.keys()))

@staticmethod
def _get_output(result) -> str:
Expand All @@ -101,7 +100,7 @@ def _get_output(result) -> str:
@staticmethod
def _format_function_call(response_message) -> str:
function_name = response_message["function_call"]["name"]
if isinstance(response_message["function_call"]["arguments"], openai.openai_object.OpenAIObject):
if isinstance(response_message["function_call"]["arguments"], OpenAIObject):
function_args = dict(response_message["function_call"]["arguments"])
else:
function_args = json.loads(response_message["function_call"]["arguments"])
Expand Down
10 changes: 3 additions & 7 deletions parea/wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,15 @@ def async_decorator(self, orig_func: Callable) -> Callable:
async def wrapper(*args, **kwargs):
trace_id, start_time = self._init_trace()
response = None
error = None
try:
response = await orig_func(*args, **kwargs)
return response
except Exception as e:
self._handle_error(trace_id, e)
error = str(e)
raise
finally:
self._cleanup_trace(trace_id, start_time, response, args, kwargs)
self._cleanup_trace(trace_id, start_time, error, response, args, kwargs)

return wrapper

Expand Down Expand Up @@ -114,8 +115,3 @@ def _cleanup_trace(self, trace_id: str, start_time: float, error: str, response:

self.log(trace_id)
trace_context.get().pop()

@staticmethod
def _handle_error(trace_id: str, e: Exception):
trace_data.get()[trace_id].error = str(e)
trace_data.get()[trace_id].status = "error"

0 comments on commit a639b23

Please sign in to comment.