-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #147 from parea-ai/PAI-310-local-debug-experience
Pai 310 local debug experience
- Loading branch information
Showing
19 changed files
with
689 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import argparse | ||
import asyncio | ||
import concurrent | ||
import csv | ||
import importlib | ||
import inspect | ||
import os | ||
import sys | ||
import time | ||
from importlib import util | ||
|
||
from attr import asdict, fields_dict | ||
from tqdm import tqdm | ||
|
||
from parea.cache.redis import RedisCache | ||
from parea.schemas.models import TraceLog | ||
|
||
|
||
def load_from_path(module_path, attr_name):
    """Dynamically import a module from a source file and return one attribute.

    Args:
        module_path: Path to the module WITHOUT the ``.py`` extension
            (e.g. ``path/to/user_code``).
        attr_name: Name of the attribute (typically a function) to fetch from
            the freshly imported module.

    Returns:
        The attribute named ``attr_name`` from the imported module.

    Raises:
        AttributeError: If the module does not define ``attr_name``.
    """
    # Ensure the directory of the user-provided script is importable so the
    # script's own sibling imports resolve.
    dir_name = os.path.dirname(module_path)
    if dir_name not in sys.path:
        sys.path.insert(0, dir_name)

    module_name = os.path.basename(module_path)
    # Add the .py extension back so spec_from_file_location finds the file.
    module_path_with_ext = f"{module_path}.py"

    spec = importlib.util.spec_from_file_location(module_name, module_path_with_ext)
    module = importlib.util.module_from_spec(spec)
    # Register the module in sys.modules BEFORE executing it, as recommended
    # by the importlib docs: code that runs at import time (dataclasses,
    # pickling, self-referential imports) may look the module up by name.
    if spec.name not in sys.modules:
        sys.modules[spec.name] = module
    spec.loader.exec_module(module)

    return getattr(module, attr_name)
|
||
|
||
def read_input_file(file_path) -> list[dict]:
    """Read a CSV file and return its rows as dicts keyed by the header row.

    Args:
        file_path: Path to a CSV file whose first row is the column header.

    Returns:
        list[dict]: One dict per data row, mapping column name -> cell value
        (all values are strings, as produced by csv.DictReader).
    """
    # newline="" is required by the csv module docs; without it, newlines
    # embedded inside quoted fields are not handled correctly.
    with open(file_path, newline="") as file:
        return list(csv.DictReader(file))
|
||
|
||
def async_wrapper(fn, **kwargs):
    """Drive the coroutine function *fn* to completion on a fresh event loop.

    Defined at module level so it can be pickled and submitted to
    ProcessPoolExecutor workers (see run_benchmark).
    """
    coroutine = fn(**kwargs)
    return asyncio.run(coroutine)
|
||
|
||
def run_benchmark(args):
    """CLI entry point: run a user function over every row of an input CSV in
    parallel worker processes, then dump the collected Redis trace logs to CSV.

    Args:
        args: Raw argument list (e.g. ``sys.argv[2:]``) parsed by argparse below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--user_func", help="User function to test e.g., path/to/user_code.py:argument_chain", type=str)
    parser.add_argument("--inputs", help="Path to the input CSV file", type=str)
    parser.add_argument("--redis_host", help="Redis host", type=str, default=os.getenv("REDIS_HOST", "localhost"))
    parser.add_argument("--redis_port", help="Redis port", type=int, default=int(os.getenv("REDIS_PORT", 6379)))
    # Fall back to the env var, mirroring the host/port flags above.
    parser.add_argument("--redis_password", help="Redis password", type=str, default=os.getenv("REDIS_PASSWORD"))
    parsed_args = parser.parse_args(args)

    fn = load_from_path(*parsed_args.user_func.rsplit(":", 1))

    data_inputs = read_input_file(parsed_args.inputs)

    redis_logs_key = f"parea-trace-logs-{int(time.time())}"
    # os.putenv does NOT update os.environ (see os module docs), so neither
    # this process nor the worker processes could read the key back via
    # os.getenv. Assigning to os.environ updates both.
    os.environ["_parea_redis_logs_key"] = redis_logs_key

    with concurrent.futures.ProcessPoolExecutor() as executor:
        if inspect.iscoroutinefunction(fn):
            # Coroutine functions cannot be submitted directly; async_wrapper
            # runs them to completion inside the worker process.
            futures = [executor.submit(async_wrapper, fn, **data_input) for data_input in data_inputs]
        else:
            futures = [executor.submit(fn, **data_input) for data_input in data_inputs]
        for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            pass
        print(f"Done with {len(futures)} inputs")

    # Pass the parsed connection settings through explicitly — previously the
    # --redis_host/--redis_port/--redis_password flags were silently ignored.
    redis_cache = RedisCache(
        key_logs=redis_logs_key,
        host=parsed_args.redis_host,
        port=parsed_args.redis_port,
        password=parsed_args.redis_password,
    )

    trace_logs: list[TraceLog] = redis_cache.read_logs()

    # Write the collected trace logs to a timestamped CSV.
    path_csv = f"trace_logs-{int(time.time())}.csv"
    with open(path_csv, "w", newline="") as file:
        # Header comes from the TraceLog attrs fields, so columns always match rows.
        columns = fields_dict(TraceLog).keys()
        writer = csv.DictWriter(file, fieldnames=columns)
        writer.writeheader()
        for trace_log in trace_logs:
            writer.writerow(asdict(trace_log))

    print(f"Wrote CSV of results to: {path_csv}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .redis import RedisCache |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from typing import Optional | ||
|
||
from abc import ABC | ||
|
||
from parea.schemas.models import CacheRequest, TraceLog | ||
|
||
|
||
class Cache(ABC):
    """Abstract interface for caching LLM responses.

    Maps a CacheRequest key to a TraceLog value. Each operation has a sync
    and an async variant; concrete implementations (e.g. a Redis-backed
    cache) override these stubs, which all raise NotImplementedError.
    """

    def get(self, key: CacheRequest) -> Optional[TraceLog]:
        """
        Get a normal response from the cache.

        Args:
            key (CacheRequest): The cache key.

        Returns:
            Optional[TraceLog]: The cached response, or None if the key was not found.

        # noqa: DAR202
        # noqa: DAR401
        """
        raise NotImplementedError

    async def aget(self, key: CacheRequest) -> Optional[TraceLog]:
        """
        Get a normal response from the cache (async variant).

        Args:
            key (CacheRequest): The cache key.

        Returns:
            Optional[TraceLog]: The cached response, or None if the key was not found.

        # noqa: DAR202
        # noqa: DAR401
        """
        raise NotImplementedError

    def set(self, key: CacheRequest, value: TraceLog):
        """
        Set a normal response in the cache.

        Args:
            key (CacheRequest): The cache key.
            value (TraceLog): The response to cache.

        # noqa: DAR401
        """
        raise NotImplementedError

    async def aset(self, key: CacheRequest, value: TraceLog):
        """
        Set a normal response in the cache (async variant).

        Args:
            key (CacheRequest): The cache key.
            value (TraceLog): The response to cache.

        # noqa: DAR401
        """
        raise NotImplementedError

    def invalidate(self, key: CacheRequest):
        """
        Invalidate a key in the cache.

        Args:
            key (CacheRequest): The cache key.

        # noqa: DAR401
        """
        raise NotImplementedError

    async def ainvalidate(self, key: CacheRequest):
        """
        Invalidate a key in the cache (async variant).

        Args:
            key (CacheRequest): The cache key.

        # noqa: DAR401
        """
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from typing import List, Optional | ||
|
||
import json | ||
import logging | ||
import os | ||
import time | ||
import uuid | ||
|
||
import redis | ||
from attr import asdict | ||
|
||
from parea.cache.cache import Cache | ||
from parea.schemas.models import CacheRequest, TraceLog | ||
|
||
logger = logging.getLogger() | ||
|
||
|
||
def is_uuid(value: str) -> bool:
    """Return True if *value* parses as a UUID string, False otherwise."""
    try:
        uuid.UUID(value)
        return True
    except ValueError:
        return False
|
||
|
||
class RedisCache(Cache):
    """A Redis-based cache for caching LLM responses, plus a Redis hash of trace logs."""

    def __init__(
        self,
        key_logs: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        password: Optional[str] = None,
        ttl: int = 3600 * 6,
    ):
        """
        Initialize the cache.

        Defaults are resolved at call time (not at import time, as keyword
        defaults would be), so environment variables set after this module is
        imported — e.g. ``_parea_redis_logs_key`` set by the benchmark
        runner — are honored, and the ``trace-logs-<timestamp>`` fallback is
        not frozen to import time.

        Args:
            key_logs (Optional[str]): The Redis key for the logs; defaults to
                ``$_parea_redis_logs_key`` or a timestamped fallback.
            host (Optional[str]): The Redis host; defaults to ``$REDIS_HOST`` or "localhost".
            port (Optional[int]): The Redis port; defaults to ``$REDIS_PORT`` or 6379.
            password (Optional[str]): The Redis password; defaults to ``$REDIS_PASSWORD``.
            ttl (int): The default TTL for cache entries, in seconds.
        """
        if key_logs is None:
            key_logs = os.getenv("_parea_redis_logs_key", f"trace-logs-{time.time()}")
        if host is None:
            host = os.getenv("REDIS_HOST", "localhost")
        if port is None:
            port = int(os.getenv("REDIS_PORT", 6379))
        if password is None:
            # Fixed env-var typo: was "REDIS_PASSWORT", which could never match.
            password = os.getenv("REDIS_PASSWORD")
        self.r = redis.Redis(
            host=host,
            port=port,
            password=password,
        )
        self.ttl = ttl
        self.key_logs = key_logs

    def get(self, key: CacheRequest) -> Optional[TraceLog]:
        """Return the cached TraceLog for *key*, or None on miss or Redis error."""
        try:
            # The serialized CacheRequest is the Redis key.
            result = self.r.get(json.dumps(asdict(key)))
            if result is not None:
                return TraceLog(**json.loads(result))
        except redis.RedisError as e:
            logger.error(f"Error getting key {key} from cache: {e}")
        return None

    async def aget(self, key: CacheRequest) -> Optional[TraceLog]:
        """Async variant of :meth:`get` (delegates to the sync implementation)."""
        return self.get(key)

    def set(self, key: CacheRequest, value: TraceLog):
        """Store *value* under *key* with the configured TTL; errors are logged, not raised."""
        try:
            self.r.set(json.dumps(asdict(key)), json.dumps(asdict(value)), ex=self.ttl)
        except redis.RedisError as e:
            logger.error(f"Error setting key {key} in cache: {e}")

    async def aset(self, key: CacheRequest, value: TraceLog):
        """Async variant of :meth:`set` (delegates to the sync implementation)."""
        self.set(key, value)

    def invalidate(self, key: CacheRequest):
        """
        Invalidate a key in the cache.

        Args:
            key (CacheRequest): The cache key.
        """
        try:
            self.r.delete(json.dumps(asdict(key)))
        except redis.RedisError as e:
            logger.error(f"Error invalidating key {key} from cache: {e}")

    async def ainvalidate(self, key: CacheRequest):
        """Async variant of :meth:`invalidate` (delegates to the sync implementation)."""
        self.invalidate(key)

    def log(self, value: TraceLog):
        """Upsert *value* into the logs hash, merging over any previous entry for its trace_id."""
        try:
            prev_logs = self.r.hget(self.key_logs, value.trace_id)
            log_dict = asdict(value)
            if prev_logs:
                # Later fields win; earlier fields for the same trace are preserved.
                log_dict = {**json.loads(prev_logs), **log_dict}
            self.r.hset(self.key_logs, value.trace_id, json.dumps(log_dict))
        except redis.RedisError as e:
            logger.error(f"Error adding to list in cache: {e}")

    def read_logs(self) -> List[TraceLog]:
        """Return all trace logs recorded under ``self.key_logs`` (empty list on Redis error)."""
        try:
            trace_logs_raw = self.r.hgetall(self.key_logs)
            trace_logs = []
            for trace_log_raw in trace_logs_raw.values():
                trace_logs.append(TraceLog(**json.loads(trace_log_raw)))
            return trace_logs
        except redis.RedisError as e:
            logger.error(f"Error reading list from cache: {e}")
            return []
Oops, something went wrong.