From ddae91442b86407a94f4e9d5215d77e847df0ae8 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Fri, 6 Oct 2023 17:35:46 +0200 Subject: [PATCH 01/37] feat: add redis cache --- parea/__init__.py | 3 +- parea/cache/__init__.py | 0 parea/cache/cache.py | 68 ++++++++++++ parea/cache/redis.py | 72 +++++++++++++ parea/client.py | 23 ++++ .../tracing_with_open_ai_endpoint_directly.py | 16 +-- parea/schemas/models.py | 5 + parea/wrapper/__init__.py | 17 +-- parea/wrapper/openai.py | 100 ++++++++++++++---- parea/wrapper/wrapper.py | 44 ++++++-- 10 files changed, 293 insertions(+), 55 deletions(-) create mode 100644 parea/cache/__init__.py create mode 100644 parea/cache/cache.py create mode 100644 parea/cache/redis.py diff --git a/parea/__init__.py b/parea/__init__.py index 32997f23..9facc5c8 100644 --- a/parea/__init__.py +++ b/parea/__init__.py @@ -11,8 +11,7 @@ from importlib import metadata as importlib_metadata -import parea.wrapper # noqa: F401 -from parea.client import Parea +from parea.client import Parea, init def get_version() -> str: diff --git a/parea/cache/__init__.py b/parea/cache/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/parea/cache/cache.py b/parea/cache/cache.py new file mode 100644 index 00000000..388d2867 --- /dev/null +++ b/parea/cache/cache.py @@ -0,0 +1,68 @@ +from abc import ABC +from typing import Optional + +from parea.schemas.models import CacheRequest, TraceLog + + +class Cache(ABC): + def get(self, key: CacheRequest) -> Optional[TraceLog]: + """ + Get a normal response from the cache. + + Args: + key (CacheRequest): The cache key. + + Returns: + The cached response, or None if the key was not found. + """ + raise NotImplementedError + + async def aget(self, key: CacheRequest) -> Optional[TraceLog]: + """ + Get a normal response from the cache. + + Args: + key (CacheRequest): The cache key. + + Returns: + The cached response, or None if the key was not found. + """ + raise NotImplementedError + + def set(self, key: CacheRequest, value: TraceLog): + """ + Set a normal response in the cache. + + Args: + key (CacheRequest): The cache key. + value (TraceLog): The response to cache. + """ + raise NotImplementedError + + async def aset(self, key: CacheRequest, value: TraceLog): + """ + Set a normal response in the cache. + + Args: + key (CacheRequest): The cache key. + value (TraceLog): The response to cache. + """ + raise NotImplementedError + + def invalidate(self, key: CacheRequest): + """ + Invalidate a key in the cache. + + Args: + key (CacheRequest): The cache key. + """ + raise NotImplementedError + + async def ainvalidate(self, key: CacheRequest): + """ + Invalidate a key in the cache. + + Args: + key (CacheRequest): The cache key. + """ + raise NotImplementedError diff --git a/parea/cache/redis.py b/parea/cache/redis.py new file mode 100644 index 00000000..da307786 --- /dev/null +++ b/parea/cache/redis.py @@ -0,0 +1,72 @@ +import json +import logging +import os +from typing import Optional + +import redis +from attr import asdict + +from parea.cache.cache import Cache +from parea.schemas.models import TraceLog, CacheRequest + +logger = logging.getLogger() + + +class RedisLRUCache(Cache): + """A Redis-based LRU cache for caching both normal and streaming responses.""" + + def __init__( + self, + host: str = os.getenv('REDIS_HOST', 'localhost'), + port: int = int(os.getenv('REDIS_PORT', 6379)), + password: str = os.getenv('REDIS_PASSWORT', None), + ttl=3600 * 6 + ): + """ + Initialize the cache. 
+ + Args: + ttl (int): The default TTL for cache entries, in seconds. + """ + self.r = redis.Redis( + host=host, + port=port, + password=password, + ) + self.ttl = ttl + + def get(self, key: CacheRequest) -> Optional[TraceLog]: + try: + result = self.r.get(json.dumps(asdict(key))) + if result is not None: + return TraceLog(**json.loads(result)) + except redis.RedisError as e: + logger.error(f"Error getting key {key} from cache: {e}") + return None + + async def aget(self, key: CacheRequest) -> Optional[TraceLog]: + return self.get(key) + + def set(self, key: CacheRequest, value: TraceLog): + try: + self.r.set(json.dumps(asdict(key)), json.dumps(asdict(value)), ex=self.ttl) + except redis.RedisError as e: + logger.error(f"Error setting key {key} in cache: {e}") + + async def aset(self, key: CacheRequest, value: TraceLog): + self.set(key, value) + + def invalidate(self, key: CacheRequest): + """ + Invalidate a key in the cache. + + Args: + key (str): The cache key. + """ + try: + self.r.delete(json.dumps(asdict(key))) + except redis.RedisError as e: + logger.error(f"Error invalidating key {key} from cache: {e}") + + async def ainvalidate(self, key: CacheRequest): + self.invalidate(key) diff --git a/parea/client.py b/parea/client.py index 2772ff3d..209fbaf4 100644 --- a/parea/client.py +++ b/parea/client.py @@ -1,13 +1,18 @@ import asyncio +import os import time +from typing import Callable from attrs import asdict, define, field from parea.api_client import HTTPClient +from parea.cache.cache import Cache +from parea.cache.redis import RedisLRUCache from parea.helpers import gen_trace_id from parea.parea_logger import parea_logger from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, UseDeployedPromptResponse from parea.utils.trace_utils import default_logger, get_current_trace_id, trace_data +from parea.wrapper import OpenAIWrapper COMPLETION_ENDPOINT = "/completion" DEPLOYED_PROMPT_ENDPOINT = "/deployed-prompt" @@ -18,10 +23,13 @@ class Parea: api_key: str = field(init=True, default="") _client: HTTPClient = field(init=False, default=HTTPClient()) + cache: Cache = field(init=True, default=RedisLRUCache()) def __attrs_post_init__(self): self._client.set_api_key(self.api_key) parea_logger.set_client(self._client) + log = default_logger if self.api_key else lambda *args, **kwargs: None + _init_parea_wrapper(log, self.cache) def completion(self, data: Completion) -> CompletionResponse: inference_id = gen_trace_id() @@ -80,3 +88,18 @@ async def arecord_feedback(self, data: FeedbackRequest) -> None: RECORD_FEEDBACK_ENDPOINT, data=asdict(data), ) + + +_initialized_parea_wrapper = False + + +def init(api_key: str = os.getenv("PAREA_API_KEY"), cache: Cache = RedisLRUCache()) -> None: + Parea(api_key=api_key, cache=cache) + + +def _init_parea_wrapper(log: Callable = None, cache: Cache = None): + global _initialized_parea_wrapper + if _initialized_parea_wrapper: + return + OpenAIWrapper().init(log=log, cache=cache) + _initialized_parea_wrapper = True diff --git a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py index 4fc914e4..4674955b 100644 --- a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py +++ b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py @@ -26,7 +26,7 @@ def argumentor(query: str, additional_description: str = "") -> str: { "role": "system", "content": f"""You are a debater making an argument on a topic. {additional_description}. 
- The current time is {datetime.now()}""", + The current time is {datetime.now().strftime("%Y-%m-%d")}""", }, {"role": "user", "content": f"The discussion topic is {query}"}, ] @@ -55,7 +55,7 @@ def refiner(query: str, additional_description: str, argument: str, criticism: s { "role": "system", "content": f"""You are a debater making an argument on a topic. {additional_description}. - The current time is {datetime.now()}""", + The current time is {datetime.now().strftime("%Y-%m-%d")}""", }, {"role": "user", "content": f"""The discussion topic is {query}"""}, {"role": "assistant", "content": argument}, @@ -83,9 +83,9 @@ def argument_chain(query: str, additional_description: str = "") -> tuple[str, s additional_description="Provide a concise, few sentence argument on why sparkling wine is good for you.", ) print(result) - p.record_feedback( - FeedbackRequest( - trace_id=trace_id, - score=0.7, # 0.0 (bad) to 1.0 (good) - ) - ) + # # p.record_feedback( + # # FeedbackRequest( + # # trace_id=trace_id, + # # score=0.7, # 0.0 (bad) to 1.0 (good) + # # ) + # # ) diff --git a/parea/schemas/models.py b/parea/schemas/models.py index b9b9d7c1..089ace63 100644 --- a/parea/schemas/models.py +++ b/parea/schemas/models.py @@ -146,3 +146,8 @@ class TraceLog: @define class TraceLogTree(TraceLog): children: Optional[list[TraceLog]] = None + + +@define +class CacheRequest: + configuration: LLMInputs = LLMInputs() diff --git a/parea/wrapper/__init__.py b/parea/wrapper/__init__.py index 87996299..764c505a 100644 --- a/parea/wrapper/__init__.py +++ b/parea/wrapper/__init__.py @@ -1,17 +1,2 @@ -from parea.utils.trace_utils import default_logger from parea.wrapper.openai import OpenAIWrapper - -_initialized_parea_wrapper = False - - -def init(): - global _initialized_parea_wrapper - if _initialized_parea_wrapper: - return - - OpenAIWrapper().init(default_logger) - - _initialized_parea_wrapper = True - - -init() +from parea.wrapper.wrapper import Wrapper diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index 969527b5..f52264c9 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -4,8 +4,10 @@ import openai from openai.openai_object import OpenAIObject +from openai.util import convert_to_openai_object -from ..schemas.models import LLMInputs, ModelParams +from ..cache.cache import Cache +from ..schemas.models import CacheRequest, LLMInputs, ModelParams, TraceLog from ..utils.trace_utils import trace_data from .wrapper import Wrapper @@ -45,7 +47,21 @@ class OpenAIWrapper: - original_methods = {"ChatCompletion.create": openai.ChatCompletion.create, "ChatCompletion.acreate": openai.ChatCompletion.acreate} + original_methods = { + "ChatCompletion.create": openai.ChatCompletion.create, + "ChatCompletion.acreate": openai.ChatCompletion.acreate + } + + def init(self, log: Callable, cache: Cache = None): + Wrapper( + resolver=self.resolver, + log=log, + module=openai, + func_names=list(self.original_methods.keys()), + cache=cache, + convert_kwargs_to_cache_request=self.convert_kwargs_to_cache_request, + convert_cache_to_response=self.convert_cache_to_response, + ) def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Any]): if response: @@ -55,22 +71,8 @@ def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], output = None usage = {} - model = kwargs.get("model", None) - - llm_inputs = LLMInputs( - model=model, - provider="openai", - messages=kwargs.get("messages", None), - functions=kwargs.get("functions", None), - 
function_call=kwargs.get("function_call", None), - model_params=ModelParams( - temp=kwargs.get("temperature", 1.0), - max_length=kwargs.get("max_tokens", None), - top_p=kwargs.get("top_p", 1.0), - frequency_penalty=kwargs.get("frequency_penalty", 0.0), - presence_penalty=kwargs.get("presence_penalty", 0.0), - ), - ) + llm_configuration = self._kwargs_to_llm_configuration(kwargs) + model = llm_configuration.model model_rate = self.get_model_cost(model) model_completion_rate = self.get_model_cost(model, is_completion=True) @@ -78,15 +80,29 @@ def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], prompt_cost = model_rate * (usage.get("prompt_tokens", 0) / 1000) total_cost = sum([prompt_cost, completion_cost]) - trace_data.get()[trace_id].configuration = llm_inputs + trace_data.get()[trace_id].configuration = llm_configuration trace_data.get()[trace_id].input_tokens = usage.get("prompt_tokens", 0) trace_data.get()[trace_id].output_tokens = usage.get("completion_tokens", 0) trace_data.get()[trace_id].total_tokens = usage.get("total_tokens", 0) trace_data.get()[trace_id].cost = total_cost trace_data.get()[trace_id].output = output - def init(self, log: Callable): - Wrapper(resolver=self.resolver, log=log, module=openai, func_names=list(OpenAIWrapper.original_methods.keys())) + @staticmethod + def _kwargs_to_llm_configuration(kwargs): + return LLMInputs( + model=kwargs.get("model", None), + provider="openai", + messages=kwargs.get("messages", None), + functions=kwargs.get("functions", None), + function_call=kwargs.get("function_call", None), + model_params=ModelParams( + temp=kwargs.get("temperature", 1.0), + max_length=kwargs.get("max_tokens", None), + top_p=kwargs.get("top_p", 1.0), + frequency_penalty=kwargs.get("frequency_penalty", 0.0), + presence_penalty=kwargs.get("presence_penalty", 0.0), + ), + ) @staticmethod def _get_output(result) -> str: @@ -104,7 +120,7 @@ def _format_function_call(response_message) -> str: function_args = dict(response_message["function_call"]["arguments"]) else: function_args = json.loads(response_message["function_call"]["arguments"]) - return f'```{json.dumps({"name": function_name, "arguments": function_args}, indent=4)}```' + return json.dumps({"name": function_name, "arguments": function_args}, indent=4) @staticmethod def get_model_cost(model_name: str, is_completion: bool = False) -> float: @@ -119,3 +135,43 @@ def get_model_cost(model_name: str, is_completion: bool = False) -> float: raise ValueError(msg) return cost + + @staticmethod + def convert_kwargs_to_cache_request(_args: Sequence[Any], kwargs: Dict[str, Any]) -> CacheRequest: + return CacheRequest( + configuration=OpenAIWrapper._kwargs_to_llm_configuration(kwargs), + ) + + @staticmethod + def convert_cache_to_response(cache_response: TraceLog) -> OpenAIObject: + content = cache_response.output + try: + function_call = json.loads(content) + message = { + "role": "assistant", + "content": None, + "function_call": function_call, + } + except json.JSONDecodeError: + message = { + "role": "assistant", + "content": content, + } + + return convert_to_openai_object( + { + "object": "chat.completion", + "model": cache_response.configuration["model"], + "choices": [ + { + "index": 0, + "message": message, + } + ], + "usage": { + "prompt_tokens": cache_response.input_tokens, + "completion_tokens": cache_response.output_tokens, + "total_tokens": cache_response.total_tokens, + }, + } + ) diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index f1e52400..780eebd2 100644 --- 
a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -5,15 +5,28 @@ import time from uuid import uuid4 +from parea.cache.cache import Cache from parea.schemas.models import TraceLog -from parea.utils.trace_utils import default_logger, to_date_and_time_string, trace_context, trace_data +from parea.utils.trace_utils import to_date_and_time_string, trace_context, trace_data class Wrapper: - def __init__(self, module: Any, func_names: List[str], resolver: Callable, log: Callable = default_logger) -> None: + def __init__( + self, + module: Any, + func_names: List[str], + resolver: Callable, + cache: Cache, + convert_kwargs_to_cache_request: Callable, + convert_cache_to_response: Callable, + log: Callable, + ) -> None: self.resolver = resolver self.log = log self.wrap_functions(module, func_names) + self.cache = cache + self.convert_kwargs_to_cache_request = convert_kwargs_to_cache_request + self.convert_cache_to_response = convert_cache_to_response def wrap_functions(self, module: Any, func_names: List[str]): for func_name in func_names: @@ -73,14 +86,21 @@ async def wrapper(*args, **kwargs): trace_id, start_time = self._init_trace() response = None error = None + cache_hit = False try: - response = await orig_func(*args, **kwargs) + if self.cache: + cache_result = await self.cache.aget(self.convert_kwargs_to_cache_request(args, kwargs)) + if cache_result is not None: + response = self.convert_cache_to_response(cache_result) + cache_hit = True + if response is None: + response = await orig_func(*args, **kwargs) return response except Exception as e: error = str(e) raise finally: - self._cleanup_trace(trace_id, start_time, error, response, args, kwargs) + self._cleanup_trace(trace_id, start_time, error, response, cache_hit, args, kwargs) return wrapper @@ -89,21 +109,29 @@ def wrapper(*args, **kwargs): trace_id, start_time = self._init_trace() response = None error = None + cache_hit = False try: - response = orig_func(*args, **kwargs) + if self.cache: + cache_result = self.cache.get(self.convert_kwargs_to_cache_request(args, kwargs)) + if cache_result is not None: + response = self.convert_cache_to_response(cache_result) + cache_hit = True + if response is None: + response = orig_func(*args, **kwargs) return response except Exception as e: error = str(e) raise e finally: - self._cleanup_trace(trace_id, start_time, error, response, args, kwargs) + self._cleanup_trace(trace_id, start_time, error, response, cache_hit, args, kwargs) return wrapper - def _cleanup_trace(self, trace_id: str, start_time: float, error: str, response: Any, args, kwargs): + def _cleanup_trace(self, trace_id: str, start_time: float, error: str, response: Any, cache_hit, args, kwargs): end_time = time.time() trace_data.get()[trace_id].end_timestamp = to_date_and_time_string(end_time) trace_data.get()[trace_id].latency = end_time - start_time + trace_data.get()[trace_id].cache_hit = cache_hit if error: trace_data.get()[trace_id].error = error @@ -113,5 +141,7 @@ def _cleanup_trace(self, trace_id: str, start_time: float, error: str, response: self.resolver(trace_id, args, kwargs, response) + if not error and self.cache: + self.cache.set(self.convert_kwargs_to_cache_request(args, kwargs), trace_data.get()[trace_id]) self.log(trace_id) trace_context.get().pop() From 99ce13ae1732346b276bd32cc08ccf6d493c33a6 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Fri, 6 Oct 2023 17:36:25 +0200 Subject: [PATCH 02/37] style --- parea/cache/cache.py | 3 ++- parea/cache/redis.py | 11 ++++------- parea/client.py | 3 ++- 
parea/wrapper/openai.py | 5 +---- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/parea/cache/cache.py b/parea/cache/cache.py index 388d2867..55f60ea4 100644 --- a/parea/cache/cache.py +++ b/parea/cache/cache.py @@ -1,6 +1,7 @@ -from abc import ABC from typing import Optional +from abc import ABC + from parea.schemas.models import CacheRequest, TraceLog diff --git a/parea/cache/redis.py b/parea/cache/redis.py index da307786..d803908f 100644 --- a/parea/cache/redis.py +++ b/parea/cache/redis.py @@ -1,13 +1,14 @@ +from typing import Optional + import json import logging import os -from typing import Optional import redis from attr import asdict from parea.cache.cache import Cache -from parea.schemas.models import TraceLog, CacheRequest +from parea.schemas.models import CacheRequest, TraceLog logger = logging.getLogger() @@ -16,11 +17,7 @@ class RedisLRUCache(Cache): """A Redis-based LRU cache for caching both normal and streaming responses.""" def __init__( - self, - host: str = os.getenv('REDIS_HOST', 'localhost'), - port: int = int(os.getenv('REDIS_PORT', 6379)), - password: str = os.getenv('REDIS_PASSWORT', None), - ttl=3600 * 6 + self, host: str = os.getenv("REDIS_HOST", "localhost"), port: int = int(os.getenv("REDIS_PORT", 6379)), password: str = os.getenv("REDIS_PASSWORT", None), ttl=3600 * 6 ): """ Initialize the cache. diff --git a/parea/client.py b/parea/client.py index 209fbaf4..79c6b598 100644 --- a/parea/client.py +++ b/parea/client.py @@ -1,7 +1,8 @@ +from typing import Callable + import asyncio import os import time -from typing import Callable from attrs import asdict, define, field diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index f52264c9..49e7d2ab 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -47,10 +47,7 @@ class OpenAIWrapper: - original_methods = { - "ChatCompletion.create": openai.ChatCompletion.create, - "ChatCompletion.acreate": openai.ChatCompletion.acreate - } + original_methods = {"ChatCompletion.create": openai.ChatCompletion.create, "ChatCompletion.acreate": openai.ChatCompletion.acreate} def init(self, log: Callable, cache: Cache = None): Wrapper( From e37aacdd55fe03707b3c9c1e5ba631a69e568ee1 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Fri, 6 Oct 2023 18:53:08 +0200 Subject: [PATCH 03/37] feat: execute user function via cli --- .../tracing_with_open_ai_endpoint_directly.py | 2 +- parea/tester.py | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 parea/tester.py diff --git a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py index 4674955b..8161ed7b 100644 --- a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py +++ b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py @@ -12,7 +12,7 @@ openai.api_key = os.getenv("OPENAI_API_KEY") -p = Parea(api_key=os.getenv("DEV_API_KEY")) +p = Parea(api_key=os.getenv("PAREA_API_KEY")) def call_llm(data: list[dict], model: str = "gpt-3.5-turbo", temperature: float = 0.0) -> str: diff --git a/parea/tester.py b/parea/tester.py new file mode 100644 index 00000000..8ed84b70 --- /dev/null +++ b/parea/tester.py @@ -0,0 +1,43 @@ +import concurrent +import os +import csv +import argparse +import sys +from importlib import machinery, util + + +def load_from_path(path_to_module, attr_name): + module_name = os.path.basename(path_to_module) + loader = machinery.SourceFileLoader(module_name, path_to_module) + spec = util.spec_from_file_location(module_name, 
path_to_module, loader=loader) + module = util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + fn = getattr(module, attr_name) + return fn + + +def read_input_file(file_path): + with open(file_path, 'r') as file: + reader = csv.reader(file) + inputs = list(reader) + return inputs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--user_func", help="User function to test e.g., path/to/user_code.py:argument_chain", type=str) + parser.add_argument("--inputs", help="Path to the input CSV file", type=str) + args = parser.parse_args() + + fn = load_from_path(*args.user_func.rsplit(":", 1)) + + data_inputs = read_input_file(args.inputs) + + with concurrent.futures.ThreadPoolExecutor() as executor: + results = list(executor.map(fn, data_inputs)) + + for i, result in enumerate(results): + print(f'input: {data_inputs[i]}') + print(f'result: {result}') + print() From b2b37b701c149d0b4b8d238de0b34466231b475c Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Fri, 6 Oct 2023 22:21:08 +0200 Subject: [PATCH 04/37] feat: execute user function via cli in parallel with traces working --- parea/tester.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/parea/tester.py b/parea/tester.py index 8ed84b70..70a78433 100644 --- a/parea/tester.py +++ b/parea/tester.py @@ -1,4 +1,5 @@ import concurrent +import importlib import os import csv import argparse @@ -6,17 +7,28 @@ from importlib import machinery, util -def load_from_path(path_to_module, attr_name): - module_name = os.path.basename(path_to_module) - loader = machinery.SourceFileLoader(module_name, path_to_module) - spec = util.spec_from_file_location(module_name, path_to_module, loader=loader) - module = util.module_from_spec(spec) - sys.modules[spec.name] = module +def load_from_path(module_path, attr_name): + # Ensure the directory of user-provided script is in the system path + dir_name = os.path.dirname(module_path) + if dir_name not in sys.path: + sys.path.insert(0, dir_name) + + module_name = os.path.basename(module_path) + # Add .py extension back in to allow import correctly + module_path_with_ext = f"{module_path}.py" + + spec = importlib.util.spec_from_file_location(module_name, module_path_with_ext) + module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) + + if spec.name not in sys.modules: + sys.modules[spec.name] = module + fn = getattr(module, attr_name) return fn + def read_input_file(file_path): with open(file_path, 'r') as file: reader = csv.reader(file) @@ -34,7 +46,7 @@ def read_input_file(file_path): data_inputs = read_input_file(args.inputs) - with concurrent.futures.ThreadPoolExecutor() as executor: + with concurrent.futures.ProcessPoolExecutor() as executor: results = list(executor.map(fn, data_inputs)) for i, result in enumerate(results): From c819533ee762c67d69da50a6196e99b5c4322a42 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Sun, 8 Oct 2023 08:50:30 +0200 Subject: [PATCH 05/37] feat: write logs to csv --- parea/api_client.py | 2 +- parea/cache/redis.py | 36 +++++++++++++++++-- parea/client.py | 19 +++++----- .../tracing_with_open_ai_endpoint_directly.py | 1 - parea/parea_logger.py | 14 ++++++++ parea/tester.py | 30 ++++++++++++---- parea/utils/trace_utils.py | 30 ++++++++++------ parea/wrapper/wrapper.py | 1 + 8 files changed, 104 insertions(+), 29 deletions(-) diff --git a/parea/api_client.py b/parea/api_client.py index 1546a5e2..da41d8d4 100644 --- a/parea/api_client.py 
+++ b/parea/api_client.py @@ -5,7 +5,7 @@ class HTTPClient: _instance = None - base_url = "https://optimus-prompt-backend.vercel.app/api/parea/v1" + base_url = "https://parea-ai-backend-e2adf7624bcb3980.onporter.run/api/parea/v1" api_key = None def __new__(cls, *args, **kwargs): diff --git a/parea/cache/redis.py b/parea/cache/redis.py index d803908f..75f7947a 100644 --- a/parea/cache/redis.py +++ b/parea/cache/redis.py @@ -1,4 +1,6 @@ -from typing import Optional +import time +import uuid +from typing import Optional, List import json import logging @@ -13,11 +15,19 @@ logger = logging.getLogger() +def is_uuid(value: str) -> bool: + try: + uuid.UUID(value) + except ValueError: + return False + return True + + class RedisLRUCache(Cache): """A Redis-based LRU cache for caching both normal and streaming responses.""" def __init__( - self, host: str = os.getenv("REDIS_HOST", "localhost"), port: int = int(os.getenv("REDIS_PORT", 6379)), password: str = os.getenv("REDIS_PASSWORT", None), ttl=3600 * 6 + self, key_logs: str = os.getenv('_redis_logs_key', f'trace-logs-{time.time()}'), host: str = os.getenv("REDIS_HOST", "localhost"), port: int = int(os.getenv("REDIS_PORT", 6379)), password: str = os.getenv("REDIS_PASSWORT", None), ttl=3600 * 6 ): """ Initialize the cache. @@ -31,6 +41,7 @@ def __init__( password=password, ) self.ttl = ttl + self.key_logs = key_logs def get(self, key: CacheRequest) -> Optional[TraceLog]: try: @@ -67,3 +78,24 @@ def invalidate(self, key: CacheRequest): async def ainvalidate(self, key: CacheRequest): self.invalidate(key) + + def log(self, value: TraceLog): + try: + prev_logs = self.r.hget(self.key_logs, value.trace_id) + log_dict = asdict(value) + if prev_logs: + log_dict = {**json.loads(prev_logs), **log_dict} + self.r.hset(self.key_logs, value.trace_id, json.dumps(log_dict)) + except redis.RedisError as e: + logger.error(f"Error adding to list in cache: {e}") + + def read_logs(self) -> List[TraceLog]: + try: + trace_logs_raw = self.r.hgetall(self.key_logs) + trace_logs = [] + for trace_log_raw in trace_logs_raw.values(): + trace_logs.append(TraceLog(**json.loads(trace_log_raw))) + return trace_logs + except redis.RedisError as e: + logger.error(f"Error reading list from cache: {e}") + return [] diff --git a/parea/client.py b/parea/client.py index 79c6b598..a4a9f961 100644 --- a/parea/client.py +++ b/parea/client.py @@ -10,10 +10,11 @@ from parea.cache.cache import Cache from parea.cache.redis import RedisLRUCache from parea.helpers import gen_trace_id -from parea.parea_logger import parea_logger -from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, UseDeployedPromptResponse -from parea.utils.trace_utils import default_logger, get_current_trace_id, trace_data +from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, \ + UseDeployedPromptResponse +from parea.utils.trace_utils import get_current_trace_id, trace_data, logger_record_log, logger_all_possible from parea.wrapper import OpenAIWrapper +from parea.parea_logger import parea_logger COMPLETION_ENDPOINT = "/completion" DEPLOYED_PROMPT_ENDPOINT = "/deployed-prompt" @@ -28,9 +29,11 @@ class Parea: def __attrs_post_init__(self): self._client.set_api_key(self.api_key) - parea_logger.set_client(self._client) - log = default_logger if self.api_key else lambda *args, **kwargs: None - _init_parea_wrapper(log, self.cache) + if self.api_key: + parea_logger.set_client(self._client) + if isinstance(self.cache, RedisLRUCache): + 
parea_logger.set_redis_lru_cache(self.cache) + _init_parea_wrapper(logger_all_possible, self.cache) def completion(self, data: Completion) -> CompletionResponse: inference_id = gen_trace_id() @@ -42,7 +45,7 @@ def completion(self, data: Completion) -> CompletionResponse: ) if parent_trace_id := get_current_trace_id(): trace_data.get()[parent_trace_id].children.append(inference_id) - default_logger(parent_trace_id) + logger_record_log(parent_trace_id) return CompletionResponse(**r.json()) async def acompletion(self, data: Completion) -> CompletionResponse: @@ -55,7 +58,7 @@ async def acompletion(self, data: Completion) -> CompletionResponse: ) if parent_trace_id := get_current_trace_id(): trace_data.get()[parent_trace_id].children.append(inference_id) - default_logger(parent_trace_id) + logger_record_log(parent_trace_id) return CompletionResponse(**r.json()) def get_prompt(self, data: UseDeployedPrompt) -> UseDeployedPromptResponse: diff --git a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py index 8161ed7b..eb34b2a4 100644 --- a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py +++ b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py @@ -5,7 +5,6 @@ from dotenv import load_dotenv from parea import Parea -from parea.schemas.models import FeedbackRequest from parea.utils.trace_utils import get_current_trace_id, trace load_dotenv() diff --git a/parea/parea_logger.py b/parea/parea_logger.py index e6f70c34..b8e34b35 100644 --- a/parea/parea_logger.py +++ b/parea/parea_logger.py @@ -1,6 +1,7 @@ from attrs import asdict, define, field from parea.api_client import HTTPClient +from parea.cache.redis import RedisLRUCache from parea.schemas.models import TraceLog LOG_ENDPOINT = "/trace_log" @@ -9,10 +10,14 @@ @define class PareaLogger: _client: HTTPClient = field(init=False) + _redis_lru_cache: RedisLRUCache = field(init=False) def set_client(self, client: HTTPClient) -> None: self._client = client + def set_redis_lru_cache(self, cache: RedisLRUCache) -> None: + self._redis_lru_cache = cache + def record_log(self, data: TraceLog) -> None: self._client.request( "POST", @@ -27,5 +32,14 @@ async def arecord_log(self, data: TraceLog) -> None: data=asdict(data), ) + def write_log(self, data: TraceLog) -> None: + self._redis_lru_cache.log(data) + + def default_log(self, data: TraceLog) -> None: + if self._redis_lru_cache: + self.write_log(data) + if self._client: + self.record_log(data) + parea_logger = PareaLogger() diff --git a/parea/tester.py b/parea/tester.py index 70a78433..3637c2ed 100644 --- a/parea/tester.py +++ b/parea/tester.py @@ -4,7 +4,14 @@ import csv import argparse import sys -from importlib import machinery, util +import time +from importlib import util +from typing import List + +from attr import fields_dict, asdict + +from parea.cache.redis import RedisLRUCache +from parea.schemas.models import TraceLog def load_from_path(module_path, attr_name): @@ -28,7 +35,6 @@ def load_from_path(module_path, attr_name): return fn - def read_input_file(file_path): with open(file_path, 'r') as file: reader = csv.reader(file) @@ -46,10 +52,22 @@ def read_input_file(file_path): data_inputs = read_input_file(args.inputs) + redis_logs_key = f'parea-trace-logs-{int(time.time())}' + os.putenv('_redis_logs_key', redis_logs_key) + with concurrent.futures.ProcessPoolExecutor() as executor: results = list(executor.map(fn, data_inputs)) - for i, result in enumerate(results): - print(f'input: {data_inputs[i]}') - print(f'result: 
{result}') - print() + redis_cache = RedisLRUCache(key_logs=redis_logs_key) + + trace_logs: List[TraceLog] = redis_cache.read_logs() + + # write to csv + with open(f'trace_logs-{int(time.time())}.csv', 'w', newline='') as file: + # write header + columns = fields_dict(TraceLog).keys() + writer = csv.DictWriter(file, fieldnames=columns) + writer.writeheader() + # write rows + for trace_log in trace_logs: + writer.writerow(asdict(trace_log)) diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py index 840c19c9..246c7aec 100644 --- a/parea/utils/trace_utils.py +++ b/parea/utils/trace_utils.py @@ -1,10 +1,10 @@ -from typing import Any, List, Optional, Tuple, Union +import threading +from typing import Any, Optional, Union import contextvars import inspect import json import logging -import threading import time from collections import ChainMap from functools import wraps @@ -87,7 +87,7 @@ def cleanup_trace(trace_id, start_time): end_time = time.time() trace_data.get()[trace_id].end_timestamp = to_date_and_time_string(end_time) trace_data.get()[trace_id].latency = end_time - start_time - default_logger(trace_id) + logger_all_possible(trace_id) trace_context.get().pop() def decorator(func): @@ -138,14 +138,6 @@ def wrapper(*args, **kwargs): return decorator -def default_logger(trace_id: str): - logging_thread = threading.Thread( - target=parea_logger.record_log, - kwargs={"data": trace_data.get()[trace_id]}, - ) - logging_thread.start() - - def check_multiple_return_values(func) -> bool: specs = inspect.getfullargspec(func) try: @@ -161,3 +153,19 @@ def make_output(result, islist) -> Union[list[Any], Any]: return [asdict(r) if isinstance(r, CompletionResponse) else r for r in result] else: return asdict(result) if isinstance(result, CompletionResponse) else result + + +def logger_record_log(trace_id: str): + logging_thread = threading.Thread( + target=parea_logger.record_log, + kwargs={"data": trace_data.get()[trace_id]}, + ) + logging_thread.start() + + +def logger_all_possible(trace_id: str): + logging_thread = threading.Thread( + target=parea_logger.default_log, + kwargs={"data": trace_data.get()[trace_id]}, + ) + logging_thread.start() diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index 780eebd2..5159ff9b 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -143,5 +143,6 @@ def _cleanup_trace(self, trace_id: str, start_time: float, error: str, response: if not error and self.cache: self.cache.set(self.convert_kwargs_to_cache_request(args, kwargs), trace_data.get()[trace_id]) + self.log(trace_id) trace_context.get().pop() From 7b1a8958ca7eb920e160f649b010fb1d3a3e657d Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Mon, 9 Oct 2023 10:28:27 +0200 Subject: [PATCH 06/37] feat: add progress bar --- parea/tester.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/parea/tester.py b/parea/tester.py index 3637c2ed..ed266978 100644 --- a/parea/tester.py +++ b/parea/tester.py @@ -9,6 +9,7 @@ from typing import List from attr import fields_dict, asdict +from tqdm import tqdm from parea.cache.redis import RedisLRUCache from parea.schemas.models import TraceLog @@ -56,18 +57,24 @@ def read_input_file(file_path): os.putenv('_redis_logs_key', redis_logs_key) with concurrent.futures.ProcessPoolExecutor() as executor: - results = list(executor.map(fn, data_inputs)) - - redis_cache = RedisLRUCache(key_logs=redis_logs_key) - - trace_logs: List[TraceLog] = redis_cache.read_logs() - - # write to csv - 
with open(f'trace_logs-{int(time.time())}.csv', 'w', newline='') as file: - # write header - columns = fields_dict(TraceLog).keys() - writer = csv.DictWriter(file, fieldnames=columns) - writer.writeheader() - # write rows - for trace_log in trace_logs: - writer.writerow(asdict(trace_log)) + futures = [executor.submit(fn, data_input) for data_input in data_inputs] + for f in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): + pass + print(f"Done with {len(futures)} inputs") + + redis_cache = RedisLRUCache(key_logs=redis_logs_key) + + trace_logs: List[TraceLog] = redis_cache.read_logs() + + # write to csv + path_csv = f'trace_logs-{int(time.time())}.csv' + with open(path_csv, 'w', newline='') as file: + # write header + columns = fields_dict(TraceLog).keys() + writer = csv.DictWriter(file, fieldnames=columns) + writer.writeheader() + # write rows + for trace_log in trace_logs: + writer.writerow(asdict(trace_log)) + + print(f'Wrote CSV of results to: {path_csv}') From 04521db6614e4ded71b33b8a2dff7b1236b63d82 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Mon, 9 Oct 2023 14:52:23 +0200 Subject: [PATCH 07/37] feat: enable streaming --- parea/wrapper/openai.py | 60 +++++++++++++++++++++++++++++++++++----- parea/wrapper/wrapper.py | 52 ++++++++++++++++++++++++---------- 2 files changed, 91 insertions(+), 21 deletions(-) diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index 49e7d2ab..4dda6555 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Optional, Sequence +from collections import defaultdict +from typing import Any, Callable, Dict, Optional, Sequence, Iterator import json @@ -52,6 +53,8 @@ class OpenAIWrapper: def init(self, log: Callable, cache: Cache = None): Wrapper( resolver=self.resolver, + gen_resolver=self.gen_resolver, + agen_resolver=self.agen_resolver, log=log, module=openai, func_names=list(self.original_methods.keys()), @@ -60,19 +63,20 @@ def init(self, log: Callable, cache: Cache = None): convert_cache_to_response=self.convert_cache_to_response, ) - def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Any]): + @staticmethod + def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Any]) -> Optional[Any]: if response: usage = response["usage"] - output = self._get_output(response) + output = OpenAIWrapper._get_output(response) else: output = None usage = {} - llm_configuration = self._kwargs_to_llm_configuration(kwargs) + llm_configuration = OpenAIWrapper._kwargs_to_llm_configuration(kwargs) model = llm_configuration.model - model_rate = self.get_model_cost(model) - model_completion_rate = self.get_model_cost(model, is_completion=True) + model_rate = OpenAIWrapper.get_model_cost(model) + model_completion_rate = OpenAIWrapper.get_model_cost(model, is_completion=True) completion_cost = model_completion_rate * (usage.get("completion_tokens", 0) / 1000) prompt_cost = model_rate * (usage.get("prompt_tokens", 0) / 1000) total_cost = sum([prompt_cost, completion_cost]) @@ -83,6 +87,39 @@ def resolver(self, trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], trace_data.get()[trace_id].total_tokens = usage.get("total_tokens", 0) trace_data.get()[trace_id].cost = total_cost trace_data.get()[trace_id].output = output + return response + + @staticmethod + def gen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Iterator[Any], final_log) -> Iterator[Any]: + 
llm_configuration = OpenAIWrapper._kwargs_to_llm_configuration(kwargs) + trace_data.get()[trace_id].configuration = llm_configuration + + message = defaultdict(str) + for chunk in response: + update_dict = chunk.choices[0].delta._previous + for key, val in update_dict.items(): + message[key] += val + yield chunk + + trace_data.get()[trace_id].output = OpenAIWrapper._get_output(message) + + final_log() + + @staticmethod + async def agen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Iterator[Any], final_log) -> Iterator[Any]: + llm_configuration = OpenAIWrapper._kwargs_to_llm_configuration(kwargs) + trace_data.get()[trace_id].configuration = llm_configuration + + message = defaultdict(str) + async for chunk in response: + update_dict = chunk.choices[0].delta._previous + for key, val in update_dict.items(): + message[key] += val + yield chunk + + trace_data.get()[trace_id].output = OpenAIWrapper._get_output(message) + + final_log() @staticmethod def _kwargs_to_llm_configuration(kwargs): @@ -102,7 +139,16 @@ def _kwargs_to_llm_configuration(kwargs): ) @staticmethod - def _get_output(result) -> str: + def _get_output(result: Any) -> str: + if not isinstance(result, OpenAIObject): + result = convert_to_openai_object({ + "choices": [ + { + "index": 0, + "message": result, + } + ] + }) response_message = result.choices[0].message if response_message.get("function_call", None): completion = OpenAIWrapper._format_function_call(response_message) diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index 5159ff9b..c1dc1350 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Tuple, Iterator import functools import inspect @@ -16,12 +16,16 @@ def __init__( module: Any, func_names: List[str], resolver: Callable, + gen_resolver: Callable, + agen_resolver: Callable, cache: Cache, convert_kwargs_to_cache_request: Callable, convert_cache_to_response: Callable, log: Callable, ) -> None: self.resolver = resolver + self.gen_resolver = gen_resolver + self.agen_resolver = agen_resolver self.log = log self.wrap_functions(module, func_names) self.cache = cache @@ -95,12 +99,11 @@ async def wrapper(*args, **kwargs): cache_hit = True if response is None: response = await orig_func(*args, **kwargs) - return response except Exception as e: error = str(e) raise finally: - self._cleanup_trace(trace_id, start_time, error, response, cache_hit, args, kwargs) + return await self._acleanup_trace(trace_id, start_time, error, cache_hit, args, kwargs, response) return wrapper @@ -118,19 +121,15 @@ def wrapper(*args, **kwargs): cache_hit = True if response is None: response = orig_func(*args, **kwargs) - return response except Exception as e: error = str(e) raise e finally: - self._cleanup_trace(trace_id, start_time, error, response, cache_hit, args, kwargs) + return self._cleanup_trace(trace_id, start_time, error, cache_hit, args, kwargs, response) return wrapper - def _cleanup_trace(self, trace_id: str, start_time: float, error: str, response: Any, cache_hit, args, kwargs): - end_time = time.time() - trace_data.get()[trace_id].end_timestamp = to_date_and_time_string(end_time) - trace_data.get()[trace_id].latency = end_time - start_time + def _cleanup_trace_core(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response): trace_data.get()[trace_id].cache_hit = cache_hit if error: @@ -139,10 +138,35 @@ def _cleanup_trace(self, trace_id: 
str, start_time: float, error: str, response: else: trace_data.get()[trace_id].status = "success" - self.resolver(trace_id, args, kwargs, response) + def final_log(): + end_time = time.time() + trace_data.get()[trace_id].end_timestamp = to_date_and_time_string(end_time) + trace_data.get()[trace_id].latency = end_time - start_time - if not error and self.cache: - self.cache.set(self.convert_kwargs_to_cache_request(args, kwargs), trace_data.get()[trace_id]) + if not error and self.cache: + self.cache.set(self.convert_kwargs_to_cache_request(args, kwargs), trace_data.get()[trace_id]) - self.log(trace_id) - trace_context.get().pop() + self.log(trace_id) + trace_context.get().pop() + + return final_log + + def _cleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response): + final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs, response) + + if isinstance(response, Iterator): + return self.gen_resolver(trace_id, args, kwargs, response, final_log) + else: + self.resolver(trace_id, args, kwargs, response) + final_log() + return response + + async def _acleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response): + final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs, response) + + if isinstance(response, Iterator): + return await self.agen_resolver(trace_id, args, kwargs, response, final_log) + else: + self.resolver(trace_id, args, kwargs, response) + final_log() + return response From 706fdd0f50ded697027b409f3e5234ee7775ece3 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Mon, 9 Oct 2023 20:10:31 +0200 Subject: [PATCH 08/37] feat: enable async --- parea/wrapper/wrapper.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index c1dc1350..62348cf8 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Tuple, Iterator +from typing import Any, Callable, List, Tuple, Iterator, AsyncIterator import functools import inspect @@ -91,9 +91,10 @@ async def wrapper(*args, **kwargs): response = None error = None cache_hit = False + cache_key = self.convert_kwargs_to_cache_request(args, kwargs) try: if self.cache: - cache_result = await self.cache.aget(self.convert_kwargs_to_cache_request(args, kwargs)) + cache_result = await self.cache.aget(cache_key) if cache_result is not None: response = self.convert_cache_to_response(cache_result) cache_hit = True @@ -101,6 +102,7 @@ async def wrapper(*args, **kwargs): response = await orig_func(*args, **kwargs) except Exception as e: error = str(e) + await self.cache.ainvalidate(cache_key) raise finally: return await self._acleanup_trace(trace_id, start_time, error, cache_hit, args, kwargs, response) @@ -113,9 +115,10 @@ def wrapper(*args, **kwargs): response = None error = None cache_hit = False + cache_key = self.convert_kwargs_to_cache_request(args, kwargs) try: if self.cache: - cache_result = self.cache.get(self.convert_kwargs_to_cache_request(args, kwargs)) + cache_result = self.cache.get(cache_key) if cache_result is not None: response = self.convert_cache_to_response(cache_result) cache_hit = True @@ -123,6 +126,8 @@ def wrapper(*args, **kwargs): response = orig_func(*args, **kwargs) except Exception as e: error = str(e) + if self.cache: + self.cache.invalidate(cache_key) raise e finally: return self._cleanup_trace(trace_id, start_time, error, cache_hit, 
args, kwargs, response) @@ -164,8 +169,8 @@ def _cleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit async def _acleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response): final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs, response) - if isinstance(response, Iterator): - return await self.agen_resolver(trace_id, args, kwargs, response, final_log) + if isinstance(response, AsyncIterator): + return self.agen_resolver(trace_id, args, kwargs, response, final_log) else: self.resolver(trace_id, args, kwargs, response) final_log() From 4bc90a8c3eeab8c1424fa1f5dbb4d601081491e9 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Mon, 9 Oct 2023 20:12:58 +0200 Subject: [PATCH 09/37] style --- parea/cache/redis.py | 13 +++++++++---- parea/client.py | 7 +++---- parea/tester.py | 23 ++++++++++++----------- parea/utils/trace_utils.py | 2 +- parea/wrapper/openai.py | 22 ++++++++++++---------- parea/wrapper/wrapper.py | 2 +- 6 files changed, 38 insertions(+), 31 deletions(-) diff --git a/parea/cache/redis.py b/parea/cache/redis.py index 75f7947a..97963df9 100644 --- a/parea/cache/redis.py +++ b/parea/cache/redis.py @@ -1,10 +1,10 @@ -import time -import uuid -from typing import Optional, List +from typing import List, Optional import json import logging import os +import time +import uuid import redis from attr import asdict @@ -27,7 +27,12 @@ class RedisLRUCache(Cache): """A Redis-based LRU cache for caching both normal and streaming responses.""" def __init__( - self, key_logs: str = os.getenv('_redis_logs_key', f'trace-logs-{time.time()}'), host: str = os.getenv("REDIS_HOST", "localhost"), port: int = int(os.getenv("REDIS_PORT", 6379)), password: str = os.getenv("REDIS_PASSWORT", None), ttl=3600 * 6 + self, + key_logs: str = os.getenv("_redis_logs_key", f"trace-logs-{time.time()}"), + host: str = os.getenv("REDIS_HOST", "localhost"), + port: int = int(os.getenv("REDIS_PORT", 6379)), + password: str = os.getenv("REDIS_PASSWORT", None), + ttl=3600 * 6, ): """ Initialize the cache. 
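
A note on the mechanics behind the cache being restyled here: RedisLRUCache keys every entry by the JSON-serialized CacheRequest, so two OpenAI calls with an identical configuration resolve to the same Redis entry. A minimal sketch of that lookup path — illustrative only, not part of this patch — assuming a reachable Redis configured via the REDIS_* environment variables:

    import json

    from attr import asdict

    from parea.cache.redis import RedisLRUCache
    from parea.schemas.models import CacheRequest, LLMInputs

    # Build the same key the OpenAI wrapper derives from its call kwargs
    # (see patches 01 and 07); unset LLMInputs fields keep their defaults.
    key = CacheRequest(
        configuration=LLMInputs(
            model="gpt-3.5-turbo",
            provider="openai",
            messages=[{"role": "user", "content": "Hello"}],
        )
    )
    cache = RedisLRUCache()          # host/port/password read from the environment
    print(json.dumps(asdict(key)))   # this JSON string is the literal Redis key
    hit = cache.get(key)             # returns a TraceLog on a hit, None on a miss
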
diff --git a/parea/client.py b/parea/client.py index a4a9f961..2df2bdd4 100644 --- a/parea/client.py +++ b/parea/client.py @@ -10,11 +10,10 @@ from parea.cache.cache import Cache from parea.cache.redis import RedisLRUCache from parea.helpers import gen_trace_id -from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, \ - UseDeployedPromptResponse -from parea.utils.trace_utils import get_current_trace_id, trace_data, logger_record_log, logger_all_possible -from parea.wrapper import OpenAIWrapper from parea.parea_logger import parea_logger +from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, UseDeployedPromptResponse +from parea.utils.trace_utils import get_current_trace_id, logger_all_possible, logger_record_log, trace_data +from parea.wrapper import OpenAIWrapper COMPLETION_ENDPOINT = "/completion" DEPLOYED_PROMPT_ENDPOINT = "/deployed-prompt" diff --git a/parea/tester.py b/parea/tester.py index ed266978..5dc0e969 100644 --- a/parea/tester.py +++ b/parea/tester.py @@ -1,14 +1,15 @@ +from typing import List + +import argparse import concurrent +import csv import importlib import os -import csv -import argparse import sys import time from importlib import util -from typing import List -from attr import fields_dict, asdict +from attr import asdict, fields_dict from tqdm import tqdm from parea.cache.redis import RedisLRUCache @@ -37,7 +38,7 @@ def load_from_path(module_path, attr_name): def read_input_file(file_path): - with open(file_path, 'r') as file: + with open(file_path) as file: reader = csv.reader(file) inputs = list(reader) return inputs @@ -53,8 +54,8 @@ def read_input_file(file_path): data_inputs = read_input_file(args.inputs) - redis_logs_key = f'parea-trace-logs-{int(time.time())}' - os.putenv('_redis_logs_key', redis_logs_key) + redis_logs_key = f"parea-trace-logs-{int(time.time())}" + os.putenv("_redis_logs_key", redis_logs_key) with concurrent.futures.ProcessPoolExecutor() as executor: futures = [executor.submit(fn, data_input) for data_input in data_inputs] @@ -64,11 +65,11 @@ def read_input_file(file_path): redis_cache = RedisLRUCache(key_logs=redis_logs_key) - trace_logs: List[TraceLog] = redis_cache.read_logs() + trace_logs: list[TraceLog] = redis_cache.read_logs() # write to csv - path_csv = f'trace_logs-{int(time.time())}.csv' - with open(path_csv, 'w', newline='') as file: + path_csv = f"trace_logs-{int(time.time())}.csv" + with open(path_csv, "w", newline="") as file: # write header columns = fields_dict(TraceLog).keys() writer = csv.DictWriter(file, fieldnames=columns) @@ -77,4 +78,4 @@ def read_input_file(file_path): for trace_log in trace_logs: writer.writerow(asdict(trace_log)) - print(f'Wrote CSV of results to: {path_csv}') + print(f"Wrote CSV of results to: {path_csv}") diff --git a/parea/utils/trace_utils.py b/parea/utils/trace_utils.py index 246c7aec..d44f89b4 100644 --- a/parea/utils/trace_utils.py +++ b/parea/utils/trace_utils.py @@ -1,10 +1,10 @@ -import threading from typing import Any, Optional, Union import contextvars import inspect import json import logging +import threading import time from collections import ChainMap from functools import wraps diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index 4dda6555..a365a871 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -1,7 +1,7 @@ -from collections import defaultdict -from typing import Any, Callable, Dict, Optional, Sequence, Iterator +from typing import Any, Callable, Dict, 
Iterator, Optional, Sequence import json +from collections import defaultdict import openai from openai.openai_object import OpenAIObject @@ -141,14 +141,16 @@ def _kwargs_to_llm_configuration(kwargs): @staticmethod def _get_output(result: Any) -> str: if not isinstance(result, OpenAIObject): - result = convert_to_openai_object({ - "choices": [ - { - "index": 0, - "message": result, - } - ] - }) + result = convert_to_openai_object( + { + "choices": [ + { + "index": 0, + "message": result, + } + ] + } + ) response_message = result.choices[0].message if response_message.get("function_call", None): completion = OpenAIWrapper._format_function_call(response_message) diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index 62348cf8..715f1612 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Tuple, Iterator, AsyncIterator +from typing import Any, AsyncIterator, Callable, Iterator, List, Tuple import functools import inspect From 3b65d9f872301f6bce110fc3dec8f100f1c9668e Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Mon, 9 Oct 2023 21:05:14 +0200 Subject: [PATCH 10/37] fix: recreate iterable when streaming --- parea/wrapper/openai.py | 32 ++++++++++++++++++++++++++++---- parea/wrapper/wrapper.py | 9 ++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index a365a871..6785e8f4 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterator, Optional, Sequence +from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Union, AsyncIterator import json from collections import defaultdict @@ -61,6 +61,7 @@ def init(self, log: Callable, cache: Cache = None): cache=cache, convert_kwargs_to_cache_request=self.convert_kwargs_to_cache_request, convert_cache_to_response=self.convert_cache_to_response, + aconvert_cache_to_response=self.aconvert_cache_to_response, ) @staticmethod @@ -140,7 +141,7 @@ def _kwargs_to_llm_configuration(kwargs): @staticmethod def _get_output(result: Any) -> str: - if not isinstance(result, OpenAIObject): + if not isinstance(result, OpenAIObject) and isinstance(result, dict): result = convert_to_openai_object( { "choices": [ @@ -188,7 +189,7 @@ def convert_kwargs_to_cache_request(_args: Sequence[Any], kwargs: Dict[str, Any] ) @staticmethod - def convert_cache_to_response(cache_response: TraceLog) -> OpenAIObject: + def _convert_cache_to_response(_args: Sequence[Any], kwargs: Dict[str, Any], cache_response: TraceLog) -> OpenAIObject: content = cache_response.output try: function_call = json.loads(content) @@ -203,6 +204,8 @@ def convert_cache_to_response(cache_response: TraceLog) -> OpenAIObject: "content": content, } + message_field = 'delta' if kwargs.get('stream', False) else 'message' + return convert_to_openai_object( { "object": "chat.completion", @@ -210,7 +213,7 @@ def convert_cache_to_response(cache_response: TraceLog) -> OpenAIObject: "choices": [ { "index": 0, - "message": message, + message_field: message, } ], "usage": { @@ -220,3 +223,24 @@ def convert_cache_to_response(cache_response: TraceLog) -> OpenAIObject: }, } ) + + @staticmethod + def convert_cache_to_response(_args: Sequence[Any], kwargs: Dict[str, Any], cache_response: TraceLog) -> Union[OpenAIObject, Iterator[OpenAIObject]]: + response = OpenAIWrapper._convert_cache_to_response(_args, kwargs, cache_response) + if kwargs.get("stream", False): + return iter([response]) + 
else: + return response + + @staticmethod + def aconvert_cache_to_response(_args: Sequence[Any], kwargs: Dict[str, Any], cache_response: TraceLog) -> Union[OpenAIObject, AsyncIterator[OpenAIObject]]: + response = OpenAIWrapper._convert_cache_to_response(_args, kwargs, cache_response) + if kwargs.get("stream", False): + def aiterator(iterable): + async def gen(): + for item in iterable: + yield item + return gen() + return aiterator([response]) + else: + return response diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index 715f1612..0dee680d 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -21,6 +21,7 @@ def __init__( cache: Cache, convert_kwargs_to_cache_request: Callable, convert_cache_to_response: Callable, + aconvert_cache_to_response: Callable, log: Callable, ) -> None: self.resolver = resolver @@ -31,6 +32,7 @@ def __init__( self.cache = cache self.convert_kwargs_to_cache_request = convert_kwargs_to_cache_request self.convert_cache_to_response = convert_cache_to_response + self.aconvert_cache_to_response = aconvert_cache_to_response def wrap_functions(self, module: Any, func_names: List[str]): for func_name in func_names: @@ -96,13 +98,14 @@ async def wrapper(*args, **kwargs): if self.cache: cache_result = await self.cache.aget(cache_key) if cache_result is not None: - response = self.convert_cache_to_response(cache_result) + response = self.aconvert_cache_to_response(args, kwargs, cache_result) cache_hit = True if response is None: response = await orig_func(*args, **kwargs) except Exception as e: error = str(e) - await self.cache.ainvalidate(cache_key) + if self.cache: + await self.cache.ainvalidate(cache_key) raise finally: return await self._acleanup_trace(trace_id, start_time, error, cache_hit, args, kwargs, response) @@ -120,7 +123,7 @@ def wrapper(*args, **kwargs): if self.cache: cache_result = self.cache.get(cache_key) if cache_result is not None: - response = self.convert_cache_to_response(cache_result) + response = self.convert_cache_to_response(args, kwargs, cache_result) cache_hit = True if response is None: response = orig_func(*args, **kwargs) From c495280b5b462fab59e0b7d650b9c6e78be39365 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Mon, 9 Oct 2023 21:36:27 +0200 Subject: [PATCH 11/37] fix: avoid nonetype issue --- parea/wrapper/openai.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index 6785e8f4..3153591a 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -99,7 +99,8 @@ def gen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], re for chunk in response: update_dict = chunk.choices[0].delta._previous for key, val in update_dict.items(): - message[key] += val + if isinstance(val, str): + message[key] += val yield chunk trace_data.get()[trace_id].output = OpenAIWrapper._get_output(message) @@ -115,7 +116,8 @@ async def agen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, A async for chunk in response: update_dict = chunk.choices[0].delta._previous for key, val in update_dict.items(): - message[key] += val + if isinstance(val, str): + message[key] += val yield chunk trace_data.get()[trace_id].output = OpenAIWrapper._get_output(message) From a51adb2e292c10f22bf05589fbdf64a8dea10973 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 08:42:19 +0200 Subject: [PATCH 12/37] fix: non string response types from oai --- parea/wrapper/openai.py | 10 +++++----- parea/wrapper/wrapper.py | 6 
+++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index 3153591a..8c101ab1 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -99,8 +99,8 @@ def gen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], re for chunk in response: update_dict = chunk.choices[0].delta._previous for key, val in update_dict.items(): - if isinstance(val, str): - message[key] += val + if val is not None: + message[key] += str(val) yield chunk trace_data.get()[trace_id].output = OpenAIWrapper._get_output(message) @@ -108,7 +108,7 @@ def gen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], re final_log() @staticmethod - async def agen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Iterator[Any], final_log) -> Iterator[Any]: + async def agen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: AsyncIterator[Any], final_log) -> AsyncIterator[Any]: llm_configuration = OpenAIWrapper._kwargs_to_llm_configuration(kwargs) trace_data.get()[trace_id].configuration = llm_configuration @@ -116,8 +116,8 @@ async def agen_resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, A async for chunk in response: update_dict = chunk.choices[0].delta._previous for key, val in update_dict.items(): - if isinstance(val, str): - message[key] += val + if val is not None: + message[key] += str(val) yield chunk trace_data.get()[trace_id].output = OpenAIWrapper._get_output(message) diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index 0dee680d..a32d124c 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -137,7 +137,7 @@ def wrapper(*args, **kwargs): return wrapper - def _cleanup_trace_core(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response): + def _cleanup_trace_core(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs): trace_data.get()[trace_id].cache_hit = cache_hit if error: @@ -160,7 +160,7 @@ def final_log(): return final_log def _cleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response): - final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs, response) + final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs) if isinstance(response, Iterator): return self.gen_resolver(trace_id, args, kwargs, response, final_log) @@ -170,7 +170,7 @@ def _cleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit return response async def _acleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response): - final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs, response) + final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs) if isinstance(response, AsyncIterator): return self.agen_resolver(trace_id, args, kwargs, response, final_log) From 6a89c20bf64158b98c629d5f5e2ca9c71172b515 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 08:58:54 +0200 Subject: [PATCH 13/37] fix: reconstruction of cached response --- parea/wrapper/openai.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py index 8c101ab1..7220c684 100644 --- a/parea/wrapper/openai.py +++ b/parea/wrapper/openai.py @@ -193,18 +193,16 @@ def convert_kwargs_to_cache_request(_args: 
Sequence[Any], kwargs: Dict[str, Any] @staticmethod def _convert_cache_to_response(_args: Sequence[Any], kwargs: Dict[str, Any], cache_response: TraceLog) -> OpenAIObject: content = cache_response.output + message = {"role": "assistant"} try: function_call = json.loads(content) - message = { - "role": "assistant", - "content": None, - "function_call": function_call, - } + if isinstance(function_call, dict) and "name" in function_call and "arguments" in function_call and len(function_call) == 2: + message["function_call"] = function_call + message["content"] = None + else: + message["content"] = content except json.JSONDecodeError: - message = { - "role": "assistant", - "content": content, - } + message["content"] = content message_field = 'delta' if kwargs.get('stream', False) else 'message' From 8bc16c417b3f20cdf9f22119db53a862ceb8e1dd Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 09:08:20 +0200 Subject: [PATCH 14/37] fix: make executable without setting api key --- parea/parea_logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parea/parea_logger.py b/parea/parea_logger.py index b8e34b35..0fdd8940 100644 --- a/parea/parea_logger.py +++ b/parea/parea_logger.py @@ -9,8 +9,8 @@ @define class PareaLogger: - _client: HTTPClient = field(init=False) - _redis_lru_cache: RedisLRUCache = field(init=False) + _client: HTTPClient = field(init=False, default=None) + _redis_lru_cache: RedisLRUCache = field(init=False, default=None) def set_client(self, client: HTTPClient) -> None: self._client = client From 6277814cf1899d0ebab6aa9c93a6265263c2575c Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 09:30:01 +0200 Subject: [PATCH 15/37] refactor --- parea/cache/redis.py | 6 +++--- parea/client.py | 8 ++++---- parea/parea_logger.py | 6 +++--- parea/tester.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/parea/cache/redis.py b/parea/cache/redis.py index 97963df9..4e81d657 100644 --- a/parea/cache/redis.py +++ b/parea/cache/redis.py @@ -23,12 +23,12 @@ def is_uuid(value: str) -> bool: return True -class RedisLRUCache(Cache): - """A Redis-based LRU cache for caching both normal and streaming responses.""" +class RedisCache(Cache): + """A Redis-based cache for caching LLM responses.""" def __init__( self, - key_logs: str = os.getenv("_redis_logs_key", f"trace-logs-{time.time()}"), + key_logs: str = os.getenv("_parea_redis_logs_key", f"trace-logs-{time.time()}"), host: str = os.getenv("REDIS_HOST", "localhost"), port: int = int(os.getenv("REDIS_PORT", 6379)), password: str = os.getenv("REDIS_PASSWORT", None), diff --git a/parea/client.py b/parea/client.py index 2df2bdd4..cec37da6 100644 --- a/parea/client.py +++ b/parea/client.py @@ -8,7 +8,7 @@ from parea.api_client import HTTPClient from parea.cache.cache import Cache -from parea.cache.redis import RedisLRUCache +from parea.cache.redis import RedisCache from parea.helpers import gen_trace_id from parea.parea_logger import parea_logger from parea.schemas.models import Completion, CompletionResponse, FeedbackRequest, UseDeployedPrompt, UseDeployedPromptResponse @@ -24,13 +24,13 @@ class Parea: api_key: str = field(init=True, default="") _client: HTTPClient = field(init=False, default=HTTPClient()) - cache: Cache = field(init=True, default=RedisLRUCache()) + cache: Cache = field(init=True, default=RedisCache()) def __attrs_post_init__(self): self._client.set_api_key(self.api_key) if self.api_key: parea_logger.set_client(self._client) - if isinstance(self.cache, 
RedisLRUCache): + if isinstance(self.cache, RedisCache): parea_logger.set_redis_lru_cache(self.cache) _init_parea_wrapper(logger_all_possible, self.cache) @@ -96,7 +96,7 @@ async def arecord_feedback(self, data: FeedbackRequest) -> None: _initialized_parea_wrapper = False -def init(api_key: str = os.getenv("PAREA_API_KEY"), cache: Cache = RedisLRUCache()) -> None: +def init(api_key: str = os.getenv("PAREA_API_KEY"), cache: Cache = RedisCache()) -> None: Parea(api_key=api_key, cache=cache) diff --git a/parea/parea_logger.py b/parea/parea_logger.py index 0fdd8940..796f4999 100644 --- a/parea/parea_logger.py +++ b/parea/parea_logger.py @@ -1,7 +1,7 @@ from attrs import asdict, define, field from parea.api_client import HTTPClient -from parea.cache.redis import RedisLRUCache +from parea.cache.redis import RedisCache from parea.schemas.models import TraceLog LOG_ENDPOINT = "/trace_log" @@ -10,12 +10,12 @@ @define class PareaLogger: _client: HTTPClient = field(init=False, default=None) - _redis_lru_cache: RedisLRUCache = field(init=False, default=None) + _redis_lru_cache: RedisCache = field(init=False, default=None) def set_client(self, client: HTTPClient) -> None: self._client = client - def set_redis_lru_cache(self, cache: RedisLRUCache) -> None: + def set_redis_lru_cache(self, cache: RedisCache) -> None: self._redis_lru_cache = cache def record_log(self, data: TraceLog) -> None: diff --git a/parea/tester.py b/parea/tester.py index 5dc0e969..2f9acbc9 100644 --- a/parea/tester.py +++ b/parea/tester.py @@ -12,7 +12,7 @@ from attr import asdict, fields_dict from tqdm import tqdm -from parea.cache.redis import RedisLRUCache +from parea.cache.redis import RedisCache from parea.schemas.models import TraceLog @@ -55,7 +55,7 @@ def read_input_file(file_path): data_inputs = read_input_file(args.inputs) redis_logs_key = f"parea-trace-logs-{int(time.time())}" - os.putenv("_redis_logs_key", redis_logs_key) + os.putenv("_parea_redis_logs_key", redis_logs_key) with concurrent.futures.ProcessPoolExecutor() as executor: futures = [executor.submit(fn, data_input) for data_input in data_inputs] @@ -63,7 +63,7 @@ def read_input_file(file_path): pass print(f"Done with {len(futures)} inputs") - redis_cache = RedisLRUCache(key_logs=redis_logs_key) + redis_cache = RedisCache(key_logs=redis_logs_key) trace_logs: list[TraceLog] = redis_cache.read_logs() From 8b0a41ba6c3f4f9aa5916bb76042f851bc4f4f7d Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 09:58:08 +0200 Subject: [PATCH 16/37] docs: add local cache --- README.md | 42 +++++++++++++++++++++++++++++++++++++++-- parea/__init__.py | 1 + parea/cache/__init__.py | 1 + requirements.txt | 1 + 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b4926652..12fdec3d 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,45 @@ or install with `Poetry` poetry add parea-ai ``` -## Getting Started +## Debugging Chains & Agents + +You can iterate on your chains & agents much faster by using a local cache. This will allow you to make changes to your +code & prompts without waiting for all previous, valid completions. Simply add these two lines to the beginning your code and start +[a local redis cache](https://redis.io/docs/getting-started/install-stack/): + +```python +from parea import init + +init() +``` + +Above will use the default redis cache at `localhost:6379` with no password. 
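
If the cache does not seem to be used, it is worth confirming that the Redis server is actually reachable. A minimal sanity check, assuming the same default host and port as above:

```python
import redis

# Assumes the default local setup described above; adjust host/port as needed.
r = redis.Redis(host="localhost", port=6379)
r.ping()  # raises redis.exceptions.ConnectionError if no server is running
```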
+You can also specify your own redis database:
+
+```python
+import os
+
+from parea import init, RedisCache
+
+cache = RedisCache(
+    host=os.getenv("REDIS_HOST", "localhost"),  # default value
+    port=int(os.getenv("REDIS_PORT", 6379)),  # default value
+    password=os.getenv("REDIS_PASSWORT", None)  # default value
+)
+init(cache=cache)
+```
+
+### Automatically log all your LLM call traces
+
+You can automatically log all your LLM traces to the Parea dashboard by setting the `PAREA_API_KEY` environment variable or specifying it in the `init` function.
+This will help you debug issues your customers are facing by stepping through the LLM call traces and recreating the issue
+in your local setup & code.
+
+```python
+import os
+
+from parea import init
+
+init(api_key=os.getenv("PAREA_API_KEY"))  # default value
+```
+
+
+## Use a deployed prompt

 ```python
 import os
@@ -78,7 +116,7 @@ async def main_async():
     print(deployed_prompt)
 ```

-### Logging results from LLM providers
+### Logging results from LLM providers [Example]

 ```python
 import os
diff --git a/parea/__init__.py b/parea/__init__.py
index 9facc5c8..9d830f85 100644
--- a/parea/__init__.py
+++ b/parea/__init__.py
@@ -12,6 +12,7 @@
 from importlib import metadata as importlib_metadata

 from parea.client import Parea, init
+from parea.cache import RedisCache


 def get_version() -> str:
diff --git a/parea/cache/__init__.py b/parea/cache/__init__.py
index e69de29b..86b8c5ec 100644
--- a/parea/cache/__init__.py
+++ b/parea/cache/__init__.py
@@ -0,0 +1 @@
+from .redis import RedisCache
diff --git a/requirements.txt b/requirements.txt
index 991f7349..314d3f4e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,4 @@ httpx~=0.24.1
 python-dotenv~=1.0.0
 contextvars~=2.4.0
 openai
+redis~=5.0.1

From f98a621c0a97177e2508c39fa7a5c3a00a5300a1 Mon Sep 17 00:00:00 2001
From: Joschka Braun
Date: Tue, 10 Oct 2023 09:59:09 +0200
Subject: [PATCH 17/37] fix: revert to original example

---
 .../tracing_with_open_ai_endpoint_directly.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py
index eb34b2a4..90002ad4 100644
--- a/parea/cookbook/tracing_with_open_ai_endpoint_directly.py
+++ b/parea/cookbook/tracing_with_open_ai_endpoint_directly.py
@@ -5,6 +5,7 @@
 from dotenv import load_dotenv

 from parea import Parea
+from parea.schemas.models import FeedbackRequest
 from parea.utils.trace_utils import get_current_trace_id, trace

 load_dotenv()
@@ -82,9 +83,9 @@ def argument_chain(query: str, additional_description: str = "") -> tuple[str, s
         additional_description="Provide a concise, few sentence argument on why sparkling wine is good for you.",
     )
     print(result)
-    # # p.record_feedback(
-    # #     FeedbackRequest(
-    # #         trace_id=trace_id,
-    # #         score=0.7,  # 0.0 (bad) to 1.0 (good)
-    # #     )
-    # # )
+    p.record_feedback(
+        FeedbackRequest(
+            trace_id=trace_id,
+            score=0.7,  # 0.0 (bad) to 1.0 (good)
+        )
+    )
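
A minimal end-to-end sketch of the behavior this series is building toward, assuming a local Redis instance and the legacy `openai` 0.x client (model name and prompt are illustrative):

```python
import openai

from parea import init

init()  # should wrap the OpenAI completion endpoint and attach the default RedisCache

request = dict(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Make an argument for sparkling water."}],
    temperature=0.0,
)

first = openai.ChatCompletion.create(**request)   # real API call; TraceLog written to the cache
second = openai.ChatCompletion.create(**request)  # identical arguments; should be served from Redis
```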
From bbcfa10792a13981d85a5ca1abc5987816b15c74 Mon Sep 17 00:00:00 2001
From: Joschka Braun
Date: Tue, 10 Oct 2023 09:59:41 +0200
Subject: [PATCH 18/37] style

---
 parea/__init__.py       | 2 +-
 parea/wrapper/openai.py | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/parea/__init__.py b/parea/__init__.py
index 9d830f85..c7c2b4a8 100644
--- a/parea/__init__.py
+++ b/parea/__init__.py
@@ -11,8 +11,8 @@
 from importlib import metadata as importlib_metadata

-from parea.client import Parea, init
 from parea.cache import RedisCache
+from parea.client import Parea, init


 def get_version() -> str:
diff --git a/parea/wrapper/openai.py b/parea/wrapper/openai.py
index 7220c684..71d9f36b 100644
--- a/parea/wrapper/openai.py
+++ b/parea/wrapper/openai.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Union, AsyncIterator
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, Optional, Sequence, Union

 import json
 from collections import defaultdict
@@ -204,7 +204,7 @@ def _convert_cache_to_response(_args: Sequence[Any], kwargs: Dict[str, Any], cac
         except json.JSONDecodeError:
             message["content"] = content

-        message_field = 'delta' if kwargs.get('stream', False) else 'message'
+        message_field = "delta" if kwargs.get("stream", False) else "message"

         return convert_to_openai_object(
             {
@@ -236,11 +236,14 @@ def convert_cache_to_response(_args: Sequence[Any], kwargs: Dict[str, Any], cach
     def aconvert_cache_to_response(_args: Sequence[Any], kwargs: Dict[str, Any], cache_response: TraceLog) -> Union[OpenAIObject, AsyncIterator[OpenAIObject]]:
         response = OpenAIWrapper._convert_cache_to_response(_args, kwargs, cache_response)
         if kwargs.get("stream", False):
+
             def aiterator(iterable):
                 async def gen():
                     for item in iterable:
                         yield item
+
                 return gen()
+
             return aiterator([response])
         else:
             return response

From 88817e26460342a20794d36cbf8ad86ca28ccac4 Mon Sep 17 00:00:00 2001
From: Joschka Braun
Date: Tue, 10 Oct 2023 10:48:40 +0200
Subject: [PATCH 19/37] docs: update

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 12fdec3d..9300bca4 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ poetry add parea-ai
 ## Debugging Chains & Agents

 You can iterate on your chains & agents much faster by using a local cache. This will allow you to make changes to your
-code & prompts without waiting for all previous, valid completions. Simply add these two lines to the beginning your code and start
+code & prompts without waiting for all previous, valid LLM responses. Simply add these two lines to the beginning of your code and start
 [a local redis cache](https://redis.io/docs/getting-started/install-stack/):

From 576b308d50e792b640eaad90f813c84ba1d96ea7 Mon Sep 17 00:00:00 2001
From: Joschka Braun
Date: Tue, 10 Oct 2023 10:59:13 +0200
Subject: [PATCH 20/37] fix: fix lint

---
 parea/cache/cache.py | 12 ++++++------
 parea/cache/redis.py |  6 +++++-
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/parea/cache/cache.py b/parea/cache/cache.py
index 55f60ea4..fa4b6ad1 100644
--- a/parea/cache/cache.py
+++ b/parea/cache/cache.py
@@ -6,7 +6,7 @@


 class Cache(ABC):
-    def get(self, key: CacheRequest) -> Optional[TraceLog]:
+    def get(self, key: CacheRequest) -> Optional[TraceLog]:  # noqa: DAR401, DAR202
         """
         Get a normal response from the cache.

@@ -18,7 +18,7 @@ def get(self, key: CacheRequest) -> Optional[TraceLog]:
         """
         raise NotImplementedError

-    async def aget(self, key: CacheRequest) -> Optional[TraceLog]:
+    async def aget(self, key: CacheRequest) -> Optional[TraceLog]:  # noqa: DAR401, DAR202
         """
         Get a normal response from the cache.

@@ -30,7 +30,7 @@ async def aget(self, key: CacheRequest) -> Optional[TraceLog]:
         """
         raise NotImplementedError

-    def set(self, key: CacheRequest, value: TraceLog):
+    def set(self, key: CacheRequest, value: TraceLog):  # noqa: DAR401
         """
         Set a normal response in the cache.
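
The `aiterator` helper restyled in the patch above is what lets a cached completion be replayed to callers that requested `stream=True`. The same pattern in isolation, with plain strings standing in for `OpenAIObject` chunks:

```python
import asyncio


def aiterator(iterable):
    # Wrap an in-memory sequence in an async generator so a cached response
    # can be consumed with `async for`, exactly like a live streamed one.
    async def gen():
        for item in iterable:
            yield item

    return gen()


async def consume():
    async for chunk in aiterator(["cached-chunk-1", "cached-chunk-2"]):
        print(chunk)


asyncio.run(consume())
```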
@@ -40,7 +40,7 @@ def set(self, key: CacheRequest, value: TraceLog): """ raise NotImplementedError - async def aset(self, key: CacheRequest, value: TraceLog): + async def aset(self, key: CacheRequest, value: TraceLog): # noqa: DAR401 """ Set a normal response in the cache. @@ -50,7 +50,7 @@ async def aset(self, key: CacheRequest, value: TraceLog): """ raise NotImplementedError - def invalidate(self, key: CacheRequest): + def invalidate(self, key: CacheRequest): # noqa: DAR401 """ Invalidate a key in the cache. @@ -59,7 +59,7 @@ def invalidate(self, key: CacheRequest): """ raise NotImplementedError - async def ainvalidate(self, key: CacheRequest): + async def ainvalidate(self, key: CacheRequest): # noqa: DAR401 """ Invalidate a key in the cache. diff --git a/parea/cache/redis.py b/parea/cache/redis.py index 4e81d657..b274c264 100644 --- a/parea/cache/redis.py +++ b/parea/cache/redis.py @@ -38,6 +38,10 @@ def __init__( Initialize the cache. Args: + key_logs (str): The Redis key for the logs. + host (str): The Redis host. + port (int): The Redis port. + password (str): The Redis password. ttl (int): The default TTL for cache entries, in seconds. """ self.r = redis.Redis( @@ -74,7 +78,7 @@ def invalidate(self, key: CacheRequest): Invalidate a key in the cache. Args: - key (str): The cache key. + key (CacheRequest): The cache key. """ try: self.r.delete(json.dumps(asdict(key))) From 0453ab5b642fc09d137344be907728817f422570 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:04:05 +0200 Subject: [PATCH 21/37] fix: fix linter --- parea/cache/cache.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/parea/cache/cache.py b/parea/cache/cache.py index fa4b6ad1..e7a30177 100644 --- a/parea/cache/cache.py +++ b/parea/cache/cache.py @@ -6,7 +6,7 @@ class Cache(ABC): - def get(self, key: CacheRequest) -> Optional[TraceLog]: # noqa: DAR401, DAR202 + def get(self, key: CacheRequest) -> Optional[TraceLog]: """ Get a normal response from the cache. @@ -14,11 +14,14 @@ def get(self, key: CacheRequest) -> Optional[TraceLog]: # noqa: DAR401, DAR202 key (CacheRequest): The cache key. Returns: - The cached response, or None if the key was not found. + Optional[TraceLog]: The cached response, or None if the key was not found. + + Raises: + NotImplementedError: This method must be overridden in a subclass. """ raise NotImplementedError - async def aget(self, key: CacheRequest) -> Optional[TraceLog]: # noqa: DAR401, DAR202 + async def aget(self, key: CacheRequest) -> Optional[TraceLog]: """ Get a normal response from the cache. @@ -26,44 +29,59 @@ async def aget(self, key: CacheRequest) -> Optional[TraceLog]: # noqa: DAR401, key (CacheRequest): The cache key. Returns: - The cached response, or None if the key was not found. + Optional[TraceLog]: The cached response, or None if the key was not found. + + Raises: + NotImplementedError: This method must be overridden in a subclass. """ raise NotImplementedError - def set(self, key: CacheRequest, value: TraceLog): # noqa: DAR401 + def set(self, key: CacheRequest, value: TraceLog): """ Set a normal response in the cache. Args: key (CacheRequest): The cache key. value (TraceLog): The response to cache. + + Raises: + NotImplementedError: This method must be overridden in a subclass. """ raise NotImplementedError - async def aset(self, key: CacheRequest, value: TraceLog): # noqa: DAR401 + async def aset(self, key: CacheRequest, value: TraceLog): """ Set a normal response in the cache. 
Args: key (CacheRequest): The cache key. value (TraceLog): The response to cache. + + Raises: + NotImplementedError: This method must be overridden in a subclass. """ raise NotImplementedError - def invalidate(self, key: CacheRequest): # noqa: DAR401 + def invalidate(self, key: CacheRequest): """ Invalidate a key in the cache. Args: key (CacheRequest): The cache key. + + Raises: + NotImplementedError: This method must be overridden in a subclass. """ raise NotImplementedError - async def ainvalidate(self, key: CacheRequest): # noqa: DAR401 + async def ainvalidate(self, key: CacheRequest): """ Invalidate a key in the cache. Args: key (CacheRequest): The cache key. + + Raises: + NotImplementedError: This method must be overridden in a subclass. """ raise NotImplementedError From b8063b026c2ced596a72e8f1b92f77460c66c78f Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:08:55 +0200 Subject: [PATCH 22/37] fix: fix linter --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 3c46a08c..6f7c3ce9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,3 +2,4 @@ # https://github.com/terrencepreilly/darglint strictness = long docstring_style = google +ignore=DAR202 From 5bcaa3b99e843070001c1295725f09bcc6d3dd10 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:09:34 +0200 Subject: [PATCH 23/37] fix: fix linter --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 4c0c8c27..bdb5d79d 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ formatting: codestyle .PHONY: test test: PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=parea tests/ - poetry run coverage-badge -o assets/images/coverage.svg -f + poetry run coverage-badge -o assets/images/coverage.svg -f --ignore DAR202 .PHONY: check-codestyle check-codestyle: From 303dd0121aa9c53fcaf501b91e07093e6dc284f5 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:10:13 +0200 Subject: [PATCH 24/37] fix: fix linter --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 6f7c3ce9..3c46a08c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,4 +2,3 @@ # https://github.com/terrencepreilly/darglint strictness = long docstring_style = google -ignore=DAR202 From 08a5a3e10dd42d22e0d5ae5d4c1c79f99aeb5938 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:12:18 +0200 Subject: [PATCH 25/37] feat: add redis --- poetry.lock | 32 ++++++++++++++++++++++++++++++-- pyproject.toml | 1 + 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 7c0fc773..81c5b4a6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -3132,6 +3132,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3139,8 +3140,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3157,6 +3165,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3164,6 +3173,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3421,6 +3431,24 @@ files = [ [package.extras] full = ["numpy"] +[[package]] +name = "redis" +version = "5.0.1" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"}, + {file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""} + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "requests" version = "2.31.0" @@ -4271,4 +4299,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "d62f81f0d7b74d569d4ca077bef8acacc77e0dc349582d76d65bf97fd8054ce8" +content-hash = "43b1aaf47321c70e0bc0d47b975835fcef5bed8647428ee64fe7bbf0aac78368" diff --git a/pyproject.toml b/pyproject.toml index 6f500586..699986c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ pyupgrade = "^3.9.0" jupyter = "^1.0.0" contextvars = "^2.4" openai = "^0.27.9" +redis = "^5.0.1" [tool.poetry.dev-dependencies] bandit = "^1.7.1" From de8d0a2535f855701de19e72533e7840b20aa7cd Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:15:55 +0200 Subject: [PATCH 26/37] fix: ignore DAR202 --- 
Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index bdb5d79d..40eb833e 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ test: check-codestyle: poetry run isort --diff --check-only --settings-path pyproject.toml ./ poetry run black --diff --check --config pyproject.toml ./ - poetry run darglint --verbosity 2 parea tests + poetry run darglint --verbosity 2 --ignore DAR202 parea tests .PHONY: mypy mypy: From fad9ed9687d6017adb38ecfacc7543b5f3c22c20 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:19:00 +0200 Subject: [PATCH 27/37] fix: ignore DAR202 --- parea/cache/cache.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/parea/cache/cache.py b/parea/cache/cache.py index e7a30177..0e49bb0e 100644 --- a/parea/cache/cache.py +++ b/parea/cache/cache.py @@ -16,6 +16,8 @@ def get(self, key: CacheRequest) -> Optional[TraceLog]: Returns: Optional[TraceLog]: The cached response, or None if the key was not found. + # noqa: DAR202 + Raises: NotImplementedError: This method must be overridden in a subclass. """ @@ -31,6 +33,8 @@ async def aget(self, key: CacheRequest) -> Optional[TraceLog]: Returns: Optional[TraceLog]: The cached response, or None if the key was not found. + # noqa: DAR202 + Raises: NotImplementedError: This method must be overridden in a subclass. """ From 434cd714d4c1834175f1062ea83c1a725815ffea Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:19:22 +0200 Subject: [PATCH 28/37] fix: ignore DAR202 --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 40eb833e..4c0c8c27 100644 --- a/Makefile +++ b/Makefile @@ -41,13 +41,13 @@ formatting: codestyle .PHONY: test test: PYTHONPATH=$(PYTHONPATH) poetry run pytest -c pyproject.toml --cov-report=html --cov=parea tests/ - poetry run coverage-badge -o assets/images/coverage.svg -f --ignore DAR202 + poetry run coverage-badge -o assets/images/coverage.svg -f .PHONY: check-codestyle check-codestyle: poetry run isort --diff --check-only --settings-path pyproject.toml ./ poetry run black --diff --check --config pyproject.toml ./ - poetry run darglint --verbosity 2 --ignore DAR202 parea tests + poetry run darglint --verbosity 2 parea tests .PHONY: mypy mypy: From 685c2f4f689a7906a96e2b4bdfcde50674befa1a Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:21:24 +0200 Subject: [PATCH 29/37] fix: ignore DAR401 --- parea/cache/cache.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/parea/cache/cache.py b/parea/cache/cache.py index 0e49bb0e..4f7f98d3 100644 --- a/parea/cache/cache.py +++ b/parea/cache/cache.py @@ -16,10 +16,7 @@ def get(self, key: CacheRequest) -> Optional[TraceLog]: Returns: Optional[TraceLog]: The cached response, or None if the key was not found. - # noqa: DAR202 - - Raises: - NotImplementedError: This method must be overridden in a subclass. + # noqa: DAR202, DAR401 """ raise NotImplementedError @@ -33,10 +30,7 @@ async def aget(self, key: CacheRequest) -> Optional[TraceLog]: Returns: Optional[TraceLog]: The cached response, or None if the key was not found. - # noqa: DAR202 - - Raises: - NotImplementedError: This method must be overridden in a subclass. + # noqa: DAR202, DAR401 """ raise NotImplementedError @@ -48,8 +42,7 @@ def set(self, key: CacheRequest, value: TraceLog): key (CacheRequest): The cache key. value (TraceLog): The response to cache. 
- Raises: - NotImplementedError: This method must be overridden in a subclass. + # noqa: DAR401 """ raise NotImplementedError @@ -61,8 +54,7 @@ async def aset(self, key: CacheRequest, value: TraceLog): key (CacheRequest): The cache key. value (TraceLog): The response to cache. - Raises: - NotImplementedError: This method must be overridden in a subclass. + # noqa: DAR401 """ raise NotImplementedError @@ -73,8 +65,7 @@ def invalidate(self, key: CacheRequest): Args: key (CacheRequest): The cache key. - Raises: - NotImplementedError: This method must be overridden in a subclass. + # noqa: DAR401 """ raise NotImplementedError @@ -85,7 +76,6 @@ async def ainvalidate(self, key: CacheRequest): Args: key (CacheRequest): The cache key. - Raises: - NotImplementedError: This method must be overridden in a subclass. + # noqa: DAR401 """ raise NotImplementedError From 529233dcfcd43f7147587ba333d733b1494cc6d7 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 11:22:46 +0200 Subject: [PATCH 30/37] fix: ignore DAR401 --- parea/cache/cache.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/parea/cache/cache.py b/parea/cache/cache.py index 4f7f98d3..35dff655 100644 --- a/parea/cache/cache.py +++ b/parea/cache/cache.py @@ -16,7 +16,8 @@ def get(self, key: CacheRequest) -> Optional[TraceLog]: Returns: Optional[TraceLog]: The cached response, or None if the key was not found. - # noqa: DAR202, DAR401 + # noqa: DAR202 + # noqa: DAR401 """ raise NotImplementedError @@ -30,7 +31,8 @@ async def aget(self, key: CacheRequest) -> Optional[TraceLog]: Returns: Optional[TraceLog]: The cached response, or None if the key was not found. - # noqa: DAR202, DAR401 + # noqa: DAR202 + # noqa: DAR401 """ raise NotImplementedError From 75dfe9dc956d426c5a3c76aab10a281b91eddc24 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 15:20:59 +0200 Subject: [PATCH 31/37] feat: read in kwargs --- parea/tester.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/parea/tester.py b/parea/tester.py index 2f9acbc9..569ccc0e 100644 --- a/parea/tester.py +++ b/parea/tester.py @@ -37,9 +37,9 @@ def load_from_path(module_path, attr_name): return fn -def read_input_file(file_path): +def read_input_file(file_path) -> List[dict]: with open(file_path) as file: - reader = csv.reader(file) + reader = csv.DictReader(file) inputs = list(reader) return inputs @@ -58,7 +58,7 @@ def read_input_file(file_path): os.putenv("_parea_redis_logs_key", redis_logs_key) with concurrent.futures.ProcessPoolExecutor() as executor: - futures = [executor.submit(fn, data_input) for data_input in data_inputs] + futures = [executor.submit(fn, **data_input) for data_input in data_inputs] for f in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): pass print(f"Done with {len(futures)} inputs") From e501e257a81e75f6ee939cc98126e774bd9e1a0f Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 16:10:08 +0200 Subject: [PATCH 32/37] feat: add support for async in benchmark test --- parea/tester.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/parea/tester.py b/parea/tester.py index 569ccc0e..a3f9ae38 100644 --- a/parea/tester.py +++ b/parea/tester.py @@ -1,3 +1,5 @@ +import asyncio +import inspect from typing import List import argparse @@ -44,6 +46,10 @@ def read_input_file(file_path) -> List[dict]: return inputs +def async_wrapper(fn, **kwargs): + return asyncio.run(fn(**kwargs)) + + if __name__ == "__main__": parser = 
argparse.ArgumentParser() parser.add_argument("--user_func", help="User function to test e.g., path/to/user_code.py:argument_chain", type=str) @@ -58,7 +64,10 @@ def read_input_file(file_path) -> List[dict]: os.putenv("_parea_redis_logs_key", redis_logs_key) with concurrent.futures.ProcessPoolExecutor() as executor: - futures = [executor.submit(fn, **data_input) for data_input in data_inputs] + if inspect.iscoroutinefunction(fn): + futures = [executor.submit(async_wrapper, fn, **data_input) for data_input in data_inputs] + else: + futures = [executor.submit(fn, **data_input) for data_input in data_inputs] for f in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): pass print(f"Done with {len(futures)} inputs") From 0d64b6472263f4714481649123a04f539cceb7b6 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 16:11:36 +0200 Subject: [PATCH 33/37] style --- parea/tester.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/parea/tester.py b/parea/tester.py index a3f9ae38..c5705c40 100644 --- a/parea/tester.py +++ b/parea/tester.py @@ -1,11 +1,9 @@ -import asyncio -import inspect -from typing import List - import argparse +import asyncio import concurrent import csv import importlib +import inspect import os import sys import time @@ -39,7 +37,7 @@ def load_from_path(module_path, attr_name): return fn -def read_input_file(file_path) -> List[dict]: +def read_input_file(file_path) -> list[dict]: with open(file_path) as file: reader = csv.DictReader(file) inputs = list(reader) From ae8281662496b7ff41fce7371b0b07bd728f97aa Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 21:26:48 +0200 Subject: [PATCH 34/37] feat: add benchmark command --- parea/__init__.py | 11 ++++++++++- parea/{tester.py => benchmark.py} | 11 +++++++---- pyproject.toml | 3 +++ 3 files changed, 20 insertions(+), 5 deletions(-) rename parea/{tester.py => benchmark.py} (84%) diff --git a/parea/__init__.py b/parea/__init__.py index c7c2b4a8..8ce05d05 100644 --- a/parea/__init__.py +++ b/parea/__init__.py @@ -8,9 +8,10 @@ To install the official [Python SDK](https://pypi.org/project/parea/), run the following command: ```bash pip install parea ```. 
""" - +import sys from importlib import metadata as importlib_metadata +from parea.benchmark import run_benchmark from parea.cache import RedisCache from parea.client import Parea, init @@ -23,3 +24,11 @@ def get_version() -> str: version: str = get_version() + + +def main(): + args = sys.argv[1:] + if args[0] == "benchmark": + run_benchmark(args[1:]) + else: + print(f"Unknown command: '{args[0]}'") diff --git a/parea/tester.py b/parea/benchmark.py similarity index 84% rename from parea/tester.py rename to parea/benchmark.py index c5705c40..c173dd4c 100644 --- a/parea/tester.py +++ b/parea/benchmark.py @@ -48,15 +48,18 @@ def async_wrapper(fn, **kwargs): return asyncio.run(fn(**kwargs)) -if __name__ == "__main__": +def run_benchmark(args): parser = argparse.ArgumentParser() parser.add_argument("--user_func", help="User function to test e.g., path/to/user_code.py:argument_chain", type=str) parser.add_argument("--inputs", help="Path to the input CSV file", type=str) - args = parser.parse_args() + parser.add_argument('--redis_host', help='Redis host', type=str, default=os.getenv("REDIS_HOST", "localhost")) + parser.add_argument('--redis_port', help='Redis port', type=int, default=int(os.getenv("REDIS_PORT", 6379))) + parser.add_argument('--redis_password', help='Redis password', type=str, default=None) + parsed_args = parser.parse_args(args) - fn = load_from_path(*args.user_func.rsplit(":", 1)) + fn = load_from_path(*parsed_args.user_func.rsplit(":", 1)) - data_inputs = read_input_file(args.inputs) + data_inputs = read_input_file(parsed_args.inputs) redis_logs_key = f"parea-trace-logs-{int(time.time())}" os.putenv("_parea_redis_logs_key", redis_logs_key) diff --git a/pyproject.toml b/pyproject.toml index 699986c8..a4ca767c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,9 @@ coverage-badge = "^1.1.0" pytest-html = "^3.1.1" pytest-cov = "^4.1.0" +[tool.poetry.scripts] +parea = 'parea:main' + [tool.black] # https://github.com/psf/black target-version = ["py39"] From 0638bd6b26f5d62edf2e7395fbdf71dde3708f2a Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 21:28:40 +0200 Subject: [PATCH 35/37] style --- parea/benchmark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/parea/benchmark.py b/parea/benchmark.py index c173dd4c..badf7e36 100644 --- a/parea/benchmark.py +++ b/parea/benchmark.py @@ -52,9 +52,9 @@ def run_benchmark(args): parser = argparse.ArgumentParser() parser.add_argument("--user_func", help="User function to test e.g., path/to/user_code.py:argument_chain", type=str) parser.add_argument("--inputs", help="Path to the input CSV file", type=str) - parser.add_argument('--redis_host', help='Redis host', type=str, default=os.getenv("REDIS_HOST", "localhost")) - parser.add_argument('--redis_port', help='Redis port', type=int, default=int(os.getenv("REDIS_PORT", 6379))) - parser.add_argument('--redis_password', help='Redis password', type=str, default=None) + parser.add_argument("--redis_host", help="Redis host", type=str, default=os.getenv("REDIS_HOST", "localhost")) + parser.add_argument("--redis_port", help="Redis port", type=int, default=int(os.getenv("REDIS_PORT", 6379))) + parser.add_argument("--redis_password", help="Redis password", type=str, default=None) parsed_args = parser.parse_args(args) fn = load_from_path(*parsed_args.user_func.rsplit(":", 1)) From 4e1ef3a62ae5a97bb73821a33851bca00f01ac2c Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 21:38:49 +0200 Subject: [PATCH 36/37] fix: parent trace latency 
--- parea/helpers.py | 5 +++++ parea/wrapper/wrapper.py | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/parea/helpers.py b/parea/helpers.py index 4329834d..09636859 100644 --- a/parea/helpers.py +++ b/parea/helpers.py @@ -9,3 +9,8 @@ def gen_trace_id() -> str: def to_date_and_time_string(timestamp: float) -> str: return time.strftime("%Y-%m-%d %H:%M:%S %Z", time.localtime(timestamp)) + + +def date_and_time_string_to_timestamp(date_and_time_string: str) -> float: + return time.mktime(time.strptime(date_and_time_string, "%Y-%m-%d %H:%M:%S %Z")) + diff --git a/parea/wrapper/wrapper.py b/parea/wrapper/wrapper.py index a32d124c..35f60fb0 100644 --- a/parea/wrapper/wrapper.py +++ b/parea/wrapper/wrapper.py @@ -6,6 +6,7 @@ from uuid import uuid4 from parea.cache.cache import Cache +from parea.helpers import date_and_time_string_to_timestamp from parea.schemas.models import TraceLog from parea.utils.trace_utils import to_date_and_time_string, trace_context, trace_data @@ -151,10 +152,16 @@ def final_log(): trace_data.get()[trace_id].end_timestamp = to_date_and_time_string(end_time) trace_data.get()[trace_id].latency = end_time - start_time + parent_id = trace_context.get()[-2] + trace_data.get()[parent_id].end_timestamp = to_date_and_time_string(end_time) + start_time_parent = date_and_time_string_to_timestamp(trace_data.get()[parent_id].start_timestamp) + trace_data.get()[parent_id].latency = end_time - start_time_parent + if not error and self.cache: self.cache.set(self.convert_kwargs_to_cache_request(args, kwargs), trace_data.get()[trace_id]) self.log(trace_id) + self.log(parent_id) trace_context.get().pop() return final_log From a5390f9190406c15c14f4f899fec62e945e724e3 Mon Sep 17 00:00:00 2001 From: Joschka Braun Date: Tue, 10 Oct 2023 21:40:22 +0200 Subject: [PATCH 37/37] chore: bump version --- parea/helpers.py | 1 - pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/parea/helpers.py b/parea/helpers.py index 09636859..67eac45b 100644 --- a/parea/helpers.py +++ b/parea/helpers.py @@ -13,4 +13,3 @@ def to_date_and_time_string(timestamp: float) -> str: def date_and_time_string_to_timestamp(date_and_time_string: str) -> float: return time.mktime(time.strptime(date_and_time_string, "%Y-%m-%d %H:%M:%S %Z")) - diff --git a/pyproject.toml b/pyproject.toml index a4ca767c..fde9ac0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "parea-ai" packages = [{ include = "parea" }] -version = "0.2.7" +version = "0.2.8" description = "Parea python sdk" readme = "README.md" authors = ["joel-parea-ai "]
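
The new `parea benchmark` entry point can be exercised end to end with a module along these lines (the file name, function body, and CSV layout are illustrative, not taken from the patches):

```python
# chain_module.py -- a hypothetical target for the benchmark CLI
from parea import init

init()  # attach the cache so each worker process logs its traces to Redis


def argument_chain(query: str, additional_description: str = "") -> str:
    # A real chain would call the wrapped OpenAI client here.
    return f"Argument for {query!r}. {additional_description}"
```

Given an `inputs.csv` whose header row is `query,additional_description`, each row is read via `csv.DictReader` and passed to the function as keyword arguments, so `parea benchmark --user_func chain_module.py:argument_chain --inputs inputs.csv` fans the rows out over a process pool and afterwards reads the collected trace logs back from the run's Redis key.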