Skip to content

Commit

Permalink
Merge pull request #503 from parea-ai/PAI-724-auto-wrap-openai-client…
Browse files Browse the repository at this point in the history
…-classes

Pai 724 auto wrap openai client classes
  • Loading branch information
joschkabraun committed Feb 22, 2024
2 parents 16cafb1 + cc982a5 commit 20f73d4
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
30 changes: 28 additions & 2 deletions parea/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import AsyncIterable, Iterable

import httpx
import openai
from attrs import asdict, define, field
from cattrs import structure
from dotenv import load_dotenv
Expand Down Expand Up @@ -72,7 +73,7 @@ def __attrs_post_init__(self):
parea_logger.set_project_uuid(self.project_uuid)
if isinstance(self.cache, (RedisCache, InMemoryCache)):
parea_logger.set_redis_cache(self.cache)
_init_parea_wrapper(logger_all_possible, self.cache)
_init_parea_wrapper(logger_all_possible, self.cache, self)

def wrap_openai_client(self, client: "OpenAI") -> None:
"""Only necessary for instance client with OpenAI version >= 1.0.0"""
Expand Down Expand Up @@ -300,9 +301,34 @@ def _update_data_and_trace(self, data: Completion) -> Completion:
_initialized_parea_wrapper = False


def _init_parea_wrapper(log: Callable = None, cache: Cache = None):
def create_subclass_with_new_init(openai_client, parea_client: Parea):
"""Creates a subclass of the given openai_client to always wrap it with Parea at instantiation."""

def new_init(self, *args, **kwargs):
openai_client.__init__(self, *args, **kwargs)
parea_client.wrap_openai_client(self)

subclass = type(openai_client.__name__, (openai_client,), {"__init__": new_init})

return subclass


def _init_parea_wrapper(log: Callable = None, cache: Cache = None, parea_client: Parea = None):
global _initialized_parea_wrapper
if _initialized_parea_wrapper:
return

# always wrap the module-level client
OpenAIWrapper().init(log=log, cache=cache)

# for new versions of openai, we need to wrap the client classes at instantiation
if not openai.__version__.startswith("0."):
from openai import AsyncOpenAI, OpenAI
from openai.lib.azure import AsyncAzureOpenAI, AzureOpenAI

openai.OpenAI = create_subclass_with_new_init(OpenAI, parea_client)
openai.AzureOpenAI = create_subclass_with_new_init(AzureOpenAI, parea_client)
openai.AsyncOpenAI = create_subclass_with_new_init(AsyncOpenAI, parea_client)
openai.AsyncAzureOpenAI = create_subclass_with_new_init(AsyncAzureOpenAI, parea_client)

_initialized_parea_wrapper = True
11 changes: 10 additions & 1 deletion parea/wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,18 @@ def wrap_functions(self, module: Any, func_names: list[str]):

def _wrapped_func(self, original_func: Callable) -> Callable:
unwrapped_func = original_func
is_wrapped_attribute_name = "_is_already_wrapped_by_parea"
while hasattr(unwrapped_func, "__wrapped__"):
if getattr(unwrapped_func, is_wrapped_attribute_name, False):
return original_func
unwrapped_func = unwrapped_func.__wrapped__
return self._get_decorator(unwrapped_func, original_func)
if getattr(unwrapped_func, is_wrapped_attribute_name, False):
return original_func

wrapped_func = self._get_decorator(unwrapped_func, original_func)
setattr(wrapped_func, is_wrapped_attribute_name, True)

return wrapped_func

def _get_decorator(self, unwrapped_func: Callable, original_func: Callable):
if inspect.iscoroutinefunction(unwrapped_func):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.82"
version = "0.2.83"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down

0 comments on commit 20f73d4

Please sign in to comment.