Merge pull request #147 from parea-ai/PAI-310-local-debug-experience
PAI-310 local debug experience
joschkabraun authored Oct 10, 2023
2 parents e8a2a24 + a5390f9 commit d694293
Showing 19 changed files with 689 additions and 84 deletions.
42 changes: 40 additions & 2 deletions README.md
@@ -26,7 +26,45 @@ or install with `Poetry`

```bash
poetry add parea-ai
```

## Getting Started
## Debugging Chains & Agents

You can iterate on your chains & agents much faster by using a local cache. This allows you to make changes to your
code & prompts without waiting on previous, still-valid LLM responses, which are replayed from the cache. Simply add these two lines to the beginning of your code and start
[a local Redis cache](https://redis.io/docs/getting-started/install-stack/):

```python
from parea import init

init()
```

The above will use the default Redis cache at `localhost:6379` with no password. You can also point it at your own Redis database:

```python
import os

from parea import init, RedisCache

cache = RedisCache(
    host=os.getenv("REDIS_HOST", "localhost"),  # default value
    port=int(os.getenv("REDIS_PORT", 6379)),  # default value
    password=os.getenv("REDIS_PASSWORD", None),  # default value
)
init(cache=cache)
```

### Automatically log all your LLM call traces

You can automatically log all your LLM call traces to the Parea dashboard by setting the `PAREA_API_KEY` environment variable or passing it to `init`.
This helps you debug issues your customers are facing by stepping through their LLM call traces and recreating the issues
in your local setup & code.

```python
import os

from parea import init

init(api_key=os.getenv("PAREA_API_KEY"))  # default value
```


## Use a deployed prompt

```python
import os
```

@@ -78,7 +116,7 @@ async def main_async():

```python
print(deployed_prompt)
```

### Logging results from LLM providers
### Logging results from LLM providers [Example]

```python
import os
```
15 changes: 12 additions & 3 deletions parea/__init__.py
@@ -8,11 +8,12 @@
To install the official [Python SDK](https://pypi.org/project/parea/),
run the following command: ```bash pip install parea ```.
"""

import sys
from importlib import metadata as importlib_metadata

import parea.wrapper # noqa: F401
from parea.client import Parea
from parea.benchmark import run_benchmark
from parea.cache import RedisCache
from parea.client import Parea, init


def get_version() -> str:
@@ -23,3 +24,11 @@ def get_version() -> str:


version: str = get_version()


def main():
    args = sys.argv[1:]
    if not args:  # guard against a bare invocation, which would raise IndexError on args[0]
        print("No command given. Available command: 'benchmark'")
    elif args[0] == "benchmark":
        run_benchmark(args[1:])
    else:
        print(f"Unknown command: '{args[0]}'")
2 changes: 1 addition & 1 deletion parea/api_client.py
@@ -5,7 +5,7 @@

class HTTPClient:
_instance = None
base_url = "https://optimus-prompt-backend.vercel.app/api/parea/v1"
base_url = "https://parea-ai-backend-e2adf7624bcb3980.onporter.run/api/parea/v1"
api_key = None

def __new__(cls, *args, **kwargs):
91 changes: 91 additions & 0 deletions parea/benchmark.py
@@ -0,0 +1,91 @@
import argparse
import asyncio
import concurrent.futures
import csv
import importlib
import inspect
import os
import sys
import time
from importlib import util

from attr import asdict, fields_dict
from tqdm import tqdm

from parea.cache.redis import RedisCache
from parea.schemas.models import TraceLog


def load_from_path(module_path, attr_name):
# Ensure the directory of user-provided script is in the system path
dir_name = os.path.dirname(module_path)
if dir_name not in sys.path:
sys.path.insert(0, dir_name)

module_name = os.path.basename(module_path)
# Add .py extension back in to allow import correctly
module_path_with_ext = f"{module_path}.py"

spec = importlib.util.spec_from_file_location(module_name, module_path_with_ext)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

if spec.name not in sys.modules:
sys.modules[spec.name] = module

fn = getattr(module, attr_name)
return fn


def read_input_file(file_path) -> list[dict]:
with open(file_path) as file:
reader = csv.DictReader(file)
inputs = list(reader)
return inputs


def async_wrapper(fn, **kwargs):
return asyncio.run(fn(**kwargs))


def run_benchmark(args):
parser = argparse.ArgumentParser()
parser.add_argument("--user_func", help="User function to test e.g., path/to/user_code.py:argument_chain", type=str)
parser.add_argument("--inputs", help="Path to the input CSV file", type=str)
parser.add_argument("--redis_host", help="Redis host", type=str, default=os.getenv("REDIS_HOST", "localhost"))
parser.add_argument("--redis_port", help="Redis port", type=int, default=int(os.getenv("REDIS_PORT", 6379)))
parser.add_argument("--redis_password", help="Redis password", type=str, default=None)
parsed_args = parser.parse_args(args)

fn = load_from_path(*parsed_args.user_func.rsplit(":", 1))

data_inputs = read_input_file(parsed_args.inputs)

redis_logs_key = f"parea-trace-logs-{int(time.time())}"
    os.environ["_parea_redis_logs_key"] = redis_logs_key  # visible to os.getenv in worker processes (os.putenv would not be)

with concurrent.futures.ProcessPoolExecutor() as executor:
if inspect.iscoroutinefunction(fn):
futures = [executor.submit(async_wrapper, fn, **data_input) for data_input in data_inputs]
else:
futures = [executor.submit(fn, **data_input) for data_input in data_inputs]
for f in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
pass
print(f"Done with {len(futures)} inputs")

redis_cache = RedisCache(key_logs=redis_logs_key)

trace_logs: list[TraceLog] = redis_cache.read_logs()

# write to csv
path_csv = f"trace_logs-{int(time.time())}.csv"
with open(path_csv, "w", newline="") as file:
# write header
columns = fields_dict(TraceLog).keys()
writer = csv.DictWriter(file, fieldnames=columns)
writer.writeheader()
# write rows
for trace_log in trace_logs:
writer.writerow(asdict(trace_log))

print(f"Wrote CSV of results to: {path_csv}")
1 change: 1 addition & 0 deletions parea/cache/__init__.py
@@ -0,0 +1 @@
from .redis import RedisCache
83 changes: 83 additions & 0 deletions parea/cache/cache.py
@@ -0,0 +1,83 @@
from typing import Optional

from abc import ABC

from parea.schemas.models import CacheRequest, TraceLog


class Cache(ABC):
def get(self, key: CacheRequest) -> Optional[TraceLog]:
"""
Get a normal response from the cache.
Args:
key (CacheRequest): The cache key.
Returns:
Optional[TraceLog]: The cached response, or None if the key was not found.
# noqa: DAR202
# noqa: DAR401
"""
raise NotImplementedError

async def aget(self, key: CacheRequest) -> Optional[TraceLog]:
"""
Get a normal response from the cache.
Args:
key (CacheRequest): The cache key.
Returns:
Optional[TraceLog]: The cached response, or None if the key was not found.
# noqa: DAR202
# noqa: DAR401
"""
raise NotImplementedError

def set(self, key: CacheRequest, value: TraceLog):
"""
Set a normal response in the cache.
Args:
key (CacheRequest): The cache key.
value (TraceLog): The response to cache.
# noqa: DAR401
"""
raise NotImplementedError

async def aset(self, key: CacheRequest, value: TraceLog):
"""
Set a normal response in the cache.
Args:
key (CacheRequest): The cache key.
value (TraceLog): The response to cache.
# noqa: DAR401
"""
raise NotImplementedError

def invalidate(self, key: CacheRequest):
"""
Invalidate a key in the cache.
Args:
key (CacheRequest): The cache key.
# noqa: DAR401
"""
raise NotImplementedError

async def ainvalidate(self, key: CacheRequest):
"""
Invalidate a key in the cache.
Args:
key (CacheRequest): The cache key.
# noqa: DAR401
"""
raise NotImplementedError
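
Every method of the abstract base raises `NotImplementedError`, so a concrete backend simply overrides the operations it supports. A minimal in-memory sketch of that contract (illustrative only; the `RedisCache` below is the backend this commit actually ships):

```python
from typing import Dict, Optional

import json

from attr import asdict

from parea.cache.cache import Cache
from parea.schemas.models import CacheRequest, TraceLog


class InMemoryCache(Cache):
    """Illustrative dict-backed Cache; mirrors how RedisCache keys entries."""

    def __init__(self):
        self._store: Dict[str, TraceLog] = {}

    def _key(self, key: CacheRequest) -> str:
        # Same keying scheme RedisCache uses: the JSON-serialized request.
        return json.dumps(asdict(key))

    def get(self, key: CacheRequest) -> Optional[TraceLog]:
        return self._store.get(self._key(key))

    async def aget(self, key: CacheRequest) -> Optional[TraceLog]:
        return self.get(key)

    def set(self, key: CacheRequest, value: TraceLog):
        self._store[self._key(key)] = value

    async def aset(self, key: CacheRequest, value: TraceLog):
        self.set(key, value)

    def invalidate(self, key: CacheRequest):
        self._store.pop(self._key(key), None)

    async def ainvalidate(self, key: CacheRequest):
        self.invalidate(key)
```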
110 changes: 110 additions & 0 deletions parea/cache/redis.py
@@ -0,0 +1,110 @@
from typing import List, Optional

import json
import logging
import os
import time
import uuid

import redis
from attr import asdict

from parea.cache.cache import Cache
from parea.schemas.models import CacheRequest, TraceLog

logger = logging.getLogger()


def is_uuid(value: str) -> bool:
try:
uuid.UUID(value)
except ValueError:
return False
return True


class RedisCache(Cache):
"""A Redis-based cache for caching LLM responses."""

def __init__(
self,
key_logs: str = os.getenv("_parea_redis_logs_key", f"trace-logs-{time.time()}"),
host: str = os.getenv("REDIS_HOST", "localhost"),
port: int = int(os.getenv("REDIS_PORT", 6379)),
        password: str = os.getenv("REDIS_PASSWORD", None),
ttl=3600 * 6,
):
"""
Initialize the cache.
Args:
key_logs (str): The Redis key for the logs.
host (str): The Redis host.
port (int): The Redis port.
password (str): The Redis password.
ttl (int): The default TTL for cache entries, in seconds.
"""
self.r = redis.Redis(
host=host,
port=port,
password=password,
)
self.ttl = ttl
self.key_logs = key_logs

def get(self, key: CacheRequest) -> Optional[TraceLog]:
try:
result = self.r.get(json.dumps(asdict(key)))
if result is not None:
return TraceLog(**json.loads(result))
except redis.RedisError as e:
logger.error(f"Error getting key {key} from cache: {e}")
return None

async def aget(self, key: CacheRequest) -> Optional[TraceLog]:
return self.get(key)

def set(self, key: CacheRequest, value: TraceLog):
try:
self.r.set(json.dumps(asdict(key)), json.dumps(asdict(value)), ex=self.ttl)
except redis.RedisError as e:
logger.error(f"Error setting key {key} in cache: {e}")

async def aset(self, key: CacheRequest, value: TraceLog):
self.set(key, value)

def invalidate(self, key: CacheRequest):
"""
Invalidate a key in the cache.
Args:
key (CacheRequest): The cache key.
"""
try:
self.r.delete(json.dumps(asdict(key)))
except redis.RedisError as e:
logger.error(f"Error invalidating key {key} from cache: {e}")

async def ainvalidate(self, key: CacheRequest):
self.invalidate(key)

def log(self, value: TraceLog):
try:
prev_logs = self.r.hget(self.key_logs, value.trace_id)
log_dict = asdict(value)
if prev_logs:
log_dict = {**json.loads(prev_logs), **log_dict}
self.r.hset(self.key_logs, value.trace_id, json.dumps(log_dict))
except redis.RedisError as e:
logger.error(f"Error adding to list in cache: {e}")

def read_logs(self) -> List[TraceLog]:
try:
trace_logs_raw = self.r.hgetall(self.key_logs)
trace_logs = []
for trace_log_raw in trace_logs_raw.values():
trace_logs.append(TraceLog(**json.loads(trace_log_raw)))
return trace_logs
except redis.RedisError as e:
logger.error(f"Error reading list from cache: {e}")
return []
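
The `log`/`read_logs` pair is what `run_benchmark` uses to collect results: each write goes into a Redis hash keyed by `trace_id`, and re-logging the same trace merges the new fields over the old ones. A rough sketch of reading a run's logs back, assuming a Redis server at `localhost:6379`; the logs key is a hypothetical placeholder:

```python
# Rough sketch: inspect what a benchmark run recorded (hypothetical key).
from parea.cache.redis import RedisCache

cache = RedisCache(key_logs="parea-trace-logs-1696912345")  # placeholder key
for trace_log in cache.read_logs():
    print(trace_log.trace_id)
```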