refactor(claude)!: begin refactor to messages api
zhudotexe committed Apr 8, 2024
1 parent ae77bb2 commit 680f3e6
Showing 2 changed files with 123 additions and 39 deletions.
160 changes: 122 additions & 38 deletions kani/engines/anthropic/engine.py
@@ -1,9 +1,11 @@
import functools
import json
import os
import warnings

from kani.ai_function import AIFunction
from kani.exceptions import MissingModelDependencies, PromptError
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage, ChatRole
from kani.prompts.pipeline import PromptPipeline
from ..base import BaseEngine, Completion

try:
@@ -18,11 +20,44 @@
) from None

CONTEXT_SIZES_BY_PREFIX = [
("claude-3", 200000),
("claude-2.1", 200000),
("", 100000),
]
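
A prefix table like this is typically consumed by taking the first matching entry, with the empty-string prefix acting as the catch-all default; a minimal sketch (the helper name is hypothetical, not part of this commit):

def _context_size_for(model: str) -> int:
    # first prefix that matches the model id wins; "" matches everything as a fallback
    return next(size for prefix, size in CONTEXT_SIZES_BY_PREFIX if model.startswith(prefix))

# e.g. _context_size_for("claude-3-haiku") -> 200000; _context_size_for("claude-instant-1.2") -> 100000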


# ==== pipe ====
def content_transform(msg: ChatMessage):
# FUNCTION messages should look like:
# {
# "role": "user",
# "content": [
# {
# "type": "tool_result",
# "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
# "content": "65 degrees"
# }
# ]
# }
if msg.role != ChatRole.FUNCTION:
return msg.text
# todo is_error
result = {"type": "tool_result", "tool_use_id": msg.tool_call_id, "content": msg.text}
return [result]


# assumes system messages are plucked before calling
CLAUDE_PIPELINE = (
PromptPipeline()
.translate_role(role=ChatRole.SYSTEM, to=ChatRole.USER)
.merge_consecutive(role=ChatRole.USER, sep="\n")
.merge_consecutive(role=ChatRole.ASSISTANT, sep=" ")
.ensure_bound_function_calls()
.ensure_start(role=ChatRole.USER)
.conversation_dict(function_role="user", content_transform=content_transform)
)
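
A rough sketch of the transformation this pipeline performs; the exact output shape of conversation_dict is an assumption here, not shown in this diff:

# hypothetical input (system messages already plucked by the caller, per the comment above)
msgs = [
    ChatMessage.user("What's the weather in Tokyo?"),
    ChatMessage.user("Use celsius."),  # merged into the previous USER message with "\n"
    ChatMessage.assistant("It's 18 degrees."),
]
# CLAUDE_PIPELINE(msgs) should produce something shaped like:
# [
#     {"role": "user", "content": "What's the weather in Tokyo?\nUse celsius."},
#     {"role": "assistant", "content": "It's 18 degrees."},
# ]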


class AnthropicEngine(BaseEngine):
"""Engine for using the Anthropic API.
@@ -32,13 +67,14 @@ class AnthropicEngine(BaseEngine):
See https://docs.anthropic.com/claude/reference/selecting-a-model for a list of available models.
"""

token_reserve = 4 # each prompt ends with \n\nAssistant:
# because we have to estimate tokens wildly and the ctx is so long we'll just reserve a bunch
token_reserve = 500

def __init__(
self,
api_key: str = None,
model: str = "claude-2.1",
max_tokens_to_sample: int = 512,
model: str = "claude-3-haiku",
max_tokens: int = 512,
max_context_size: int = None,
*,
retry: int = 2,
@@ -51,7 +87,7 @@ def __init__(
:param api_key: Your Anthropic API key. By default, the API key will be read from the `ANTHROPIC_API_KEY`
environment variable.
:param model: The id of the model to use (e.g. "claude-2.1", "claude-instant-1.2").
:param max_tokens_to_sample: The maximum number of tokens to sample at each generation (defaults to 450).
:param max_tokens: The maximum number of tokens to sample at each generation (defaults to 512).
Generally, you should set this to the same number as your Kani's ``desired_response_tokens``.
:param max_context_size: The maximum amount of tokens allowed in the chat prompt. If None, uses the given
model's full context size.
@@ -79,12 +115,54 @@ def __init__(
api_key=api_key, max_retries=retry, base_url=api_base, default_headers=headers
)
self.model = model
self.max_tokens_to_sample = max_tokens_to_sample
self.max_tokens = max_tokens
self.max_context_size = max_context_size
self.hyperparams = hyperparams
self.tokenizer = sync_get_tokenizer()

# token counting - claude 3+ does not release tokenizer so we have to do heuristics and cache
self.token_cache = {}
if model.startswith("claude-2"):
self.tokenizer = sync_get_tokenizer()
else:
# claude 3 tokenizer just... doesn't exist
# https://github.com/anthropics/anthropic-sdk-python/issues/375 pain
self.tokenizer = None

# ==== token counting ====
@staticmethod
def message_cache_key(message: ChatMessage):
# (role, content, tool calls)

# we'll use msgpart identity for the hash here since we'll always have a ref as long as it's in a message
# history
hashable_content = tuple(part if isinstance(part, str) else id(part) for part in message.parts)

# use (name, args) for tool calls
if message.tool_calls:
hashable_tool_calls = tuple((tc.function.name, tc.function.arguments) for tc in message.tool_calls)
else:
hashable_tool_calls = message.tool_calls

return hash((message.role, hashable_content, hashable_tool_calls))

def message_len(self, message: ChatMessage) -> int:
# use cache
cache_key = self.message_cache_key(message)
if cache_key in self.token_cache:
return self.token_cache[cache_key]

# use tokenizer
if self.tokenizer is not None:
return self._message_len_tokenizer(message)

# panik - I guess we'll pretend that 4 chars = 1 token...?
n = len(message.role.value) + len(message.text)
if message.tool_calls:
for tc in message.tool_calls:
n += len(tc.function.name) + len(tc.function.arguments)
return n // 4
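
A quick worked example of this fallback estimate (illustrative only, not part of this commit):

msg = ChatMessage.user("Hello world!")
estimate = (len(msg.role.value) + len(msg.text)) // 4  # (4 + 12) // 4 == 4 "tokens"
# estimates this rough are why token_reserve is padded to 500 above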

def _message_len_tokenizer(self, message):
# human messages are prefixed with `\n\nHuman: ` and assistant with `\n\nAssistant:`
if message.role == ChatRole.USER:
mlen = 5
@@ -97,45 +175,51 @@ def message_len(self, message: ChatMessage) -> int:
mlen += len(self.tokenizer.encode(message.text).ids)
return mlen

@staticmethod
def build_prompt(messages: list[ChatMessage]):
# Claude prompts must start with a human message
first_human_idx = next((i for i, m in enumerate(messages) if m.role == ChatRole.USER), None)
if first_human_idx is None:
raise PromptError("Prompts to Anthropic models must contain at least one USER message.")
def function_token_reserve(self, functions: list[AIFunction]) -> int:
if not functions:
return 0
# wrap an inner impl to use lru_cache with frozensets
return self._function_token_reserve_impl(frozenset(functions))

# and make sure the system messages are included
last_system_idx = next((i for i, m in enumerate(messages) if m.role != ChatRole.SYSTEM), None)
if last_system_idx:
out = ["\n\n".join(m.text for m in messages[:last_system_idx])]
else:
out = []

for idx, message in enumerate(messages[first_human_idx:]):
if message.role == ChatRole.USER:
out.append(f"{HUMAN_PROMPT} {message.text}")
elif message.role == ChatRole.ASSISTANT:
out.append(f"{AI_PROMPT} {message.text}")
else:
warnings.warn(
f"Encountered a {message.role} message in the middle of the prompt - Anthropic models expect an"
" optional SYSTEM message followed by alternating USER and ASSISTANT messages. Appending the"
" content to the prompt..."
)
out.append(f"\n\n{message.text}")
return "".join(out) + AI_PROMPT
@functools.lru_cache(maxsize=256)
def _function_token_reserve_impl(self, functions):
# panik, also assume len/4?
n = sum(len(f.name) + len(f.desc) + len(json.dumps(f.json_schema)) for f in functions)
return n // 4
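
The frozenset wrapper exists because functools.lru_cache hashes its arguments and a list of AIFunctions is not hashable; the same pattern in isolation (a standalone sketch, not part of this commit):

import functools

def reserve(functions: list[str]) -> int:
    if not functions:
        return 0
    # frozenset is hashable, so it can serve as the lru_cache key
    return _reserve_impl(frozenset(functions))

@functools.lru_cache(maxsize=256)
def _reserve_impl(functions: frozenset[str]) -> int:
    return sum(len(f) for f in functions) // 4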

# ==== requests ====
async def predict(
self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
) -> Completion:
prompt = self.build_prompt(messages)
completion = await self.client.completions.create(
kwargs = {}

# --- messages ---
# pluck system messages
last_system_idx = next((i for i, m in enumerate(messages) if m.role != ChatRole.SYSTEM), None)
if last_system_idx:
kwargs["system"] = "\n\n".join(m.text for m in messages[:last_system_idx])
messages = messages[last_system_idx:]

# enforce ordering and function call bindings
# and translate to dict spec
messages = CLAUDE_PIPELINE(messages)

# --- tools ---
if functions:
kwargs["tools"] = [
{"name": f.name, "description": f.desc, "input_schema": f.json_schema} for f in functions
]

completion = await self.client.messages.create(
model=self.model,
max_tokens_to_sample=self.max_tokens_to_sample,
prompt=prompt,
max_tokens=self.max_tokens,
messages=messages,
**kwargs,
**self.hyperparams,
**hyperparams,
)

# todo translate to kani
return Completion(message=ChatMessage.assistant(completion.completion.strip()))
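
One possible shape for that TODO, as a hedged sketch: assuming the Messages API response exposes "text" and "tool_use" content blocks, and assuming kani's ToolCall/FunctionCall models accept these fields (neither is confirmed by this diff), the translation might look roughly like:

from kani.models import FunctionCall, ToolCall  # assumed import, not present in this file

def _translate_completion(completion) -> Completion:
    # join all text blocks into the assistant message content
    text = "".join(block.text for block in completion.content if block.type == "text")
    # map each tool_use block to a kani tool call; field names on the kani side are assumptions
    tool_calls = [
        ToolCall(
            id=block.id,
            type="function",
            function=FunctionCall(name=block.name, arguments=json.dumps(block.input)),
        )
        for block in completion.content
        if block.type == "tool_use"
    ]
    return Completion(
        message=ChatMessage(role=ChatRole.ASSISTANT, content=text.strip(), tool_calls=tool_calls or None)
    )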

async def close(self):
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -51,7 +51,7 @@ openai = [
]

anthropic = [
"anthropic>=0.7.3,<1.0.0",
"anthropic>=0.23.0,<1.0.0",
]

[project.urls]
