diff --git a/parea/client.py b/parea/client.py index a7dce144..0a731be7 100644 --- a/parea/client.py +++ b/parea/client.py @@ -6,6 +6,7 @@ from attrs import asdict, define, field from cattrs import structure +from openai import OpenAI from parea.api_client import HTTPClient from parea.cache import InMemoryCache, RedisCache @@ -46,6 +47,9 @@ def __attrs_post_init__(self): parea_logger.set_redis_cache(self.cache) _init_parea_wrapper(logger_all_possible, self.cache) + def wrap_openai_client(self, client: OpenAI) -> None: + OpenAIWrapper().init(log=logger_all_possible, cache=self.cache, module_client=client) + def completion(self, data: Completion) -> CompletionResponse: parent_trace_id = get_current_trace_id() inference_id = gen_trace_id() diff --git a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py index bf8edda0..4632072a 100644 --- a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py +++ b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py @@ -1,8 +1,8 @@ import os from datetime import datetime -import openai from dotenv import load_dotenv +from openai import OpenAI from parea import Parea from parea.schemas.models import FeedbackRequest @@ -10,13 +10,14 @@ load_dotenv() -openai.api_key = os.getenv("OPENAI_API_KEY") +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) p = Parea(api_key=os.getenv("PAREA_API_KEY")) +p.wrap_openai_client(client) def call_llm(data: list[dict], model: str = "gpt-3.5-turbo", temperature: float = 0.0) -> str: - return openai.chat.completions.create(model=model, temperature=temperature, messages=data).choices[0].message.content + return client.chat.completions.create(model=model, temperature=temperature, messages=data).choices[0].message.content @trace diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index 47e8d13a..f6cacf67 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -122,13 +122,13 @@ class OpenAIWrapper: except openai.OpenAIError: original_methods = {} - def init(self, log: Callable, cache: Cache = None): + def init(self, log: Callable, cache: Cache = None, module_client=openai): Wrapper( resolver=self.resolver, gen_resolver=self.gen_resolver, agen_resolver=self.agen_resolver, log=log, - module=openai, + module=module_client, func_names=list(self.original_methods.keys()), cache=cache, convert_kwargs_to_cache_request=self.convert_kwargs_to_cache_request, diff --git a/pyproject.toml b/pyproject.toml index fe2e17d4..f1fb552d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "parea-ai" packages = [{ include = "parea" }] -version = "0.2.28" +version = "0.2.29" description = "Parea python sdk" readme = "README.md" authors = ["joel-parea-ai "]