Skip to content

Commit

Permalink
Merge pull request #703 from parea-ai/PAI-966-anthropic-beta-tool-use-support
Browse files Browse the repository at this point in the history

PAI-966: Anthropic beta tool use support
  • Loading branch information
joschkabraun authored Apr 5, 2024
2 parents 12a2ab4 + aea7ee9 commit 2f48a00
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 4 deletions.
124 changes: 124 additions & 0 deletions parea/cookbook/data/anthropic_tool_use_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Minimal tool-use request: one tool (get_weather) is exposed to the model,
# and the user prompt is phrased so the model should invoke it.
single_tool_use = dict(
    model="claude-3-opus-20240229",
    max_tokens=1024,
    messages=[
        {"role": "user", "content": "What's the weather like in San Francisco?"},
    ],
    tools=[
        dict(
            name="get_weather",
            description="Get the current weather in a given location",
            # JSON Schema describing the arguments the model must supply.
            input_schema={
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                },
                "required": ["location"],
            },
        ),
    ],
)


# Two tools registered at once (weather + time); the prompt asks two questions
# so the model should emit a tool_use block for each tool in one response.
multiple_tool_use = dict(
    model="claude-3-opus-20240229",
    max_tokens=1024,
    messages=[
        {"role": "user", "content": "What is the weather like right now in New York? Also what time is it there?"},
    ],
    tools=[
        dict(
            name="get_weather",
            description="Get the current weather in a given location",
            input_schema={
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The unit of temperature, either 'celsius' or 'fahrenheit'",
                    },
                },
                # "unit" is optional; only the location is mandatory.
                "required": ["location"],
            },
        ),
        dict(
            name="get_time",
            description="Get the current time in a given time zone",
            input_schema={
                "type": "object",
                "properties": {
                    "timezone": {
                        "type": "string",
                        "description": "The IANA time zone name, e.g. America/Los_Angeles",
                    },
                },
                "required": ["timezone"],
            },
        ),
    ],
)


# Multi-turn tool-use transcript: the history already contains the model's
# tool_use request and the matching user tool_result ("65 degrees"), so the
# next completion should answer from the supplied result.
missing_information = dict(
    model="claude-3-opus-20240229",
    max_tokens=1024,
    tools=[
        dict(
            name="get_weather",
            description="Get the current weather in a given location",
            input_schema={
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": 'The unit of temperature, either "celsius" or "fahrenheit"',
                    },
                },
                "required": ["location"],
            },
        ),
    ],
    messages=[
        {
            "role": "user",
            "content": "What is the weather like in San Francisco?",
        },
        # The assistant turn mixes a text block with the tool invocation.
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "<thinking>I need to use get_weather, and the user wants SF, which is likely San Francisco, CA.</thinking>",
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {
                        "location": "San Francisco, CA",
                        "unit": "celsius",
                    },
                },
            ],
        },
        # The tool_result's tool_use_id must match the id of the tool_use above.
        {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "65 degrees",
                },
            ],
        },
    ],
)
36 changes: 36 additions & 0 deletions parea/cookbook/tracing_anthropic_tool_use.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import asyncio
import os

import anthropic
from dotenv import load_dotenv

from parea import Parea
from parea.cookbook.data.anthropic_tool_use_examples import missing_information, multiple_tool_use, single_tool_use

# Pull environment variables (PAREA_API_KEY, and presumably the Anthropic
# key read by the SDK) from a local .env file.
load_dotenv()

# One synchronous and one asynchronous Anthropic client.
client = anthropic.Anthropic()
aclient = anthropic.AsyncAnthropic()

# Wrap both clients so their calls are captured by Parea under the
# "testing" project — NOTE(review): exact tracing semantics of
# wrap_anthropic_client are defined in the parea package, not shown here.
p = Parea(api_key=os.getenv("PAREA_API_KEY"), project_name="testing")
p.wrap_anthropic_client(client)
p.wrap_anthropic_client(aclient)


def anthropic_sync(create_kwargs):
    """Send one beta tool-use request via the sync client and print the content blocks."""
    response = client.beta.tools.messages.create(**create_kwargs)
    print(response.content)


async def async_anthropic(create_kwargs):
    """Send one beta tool-use request via the async client and print the content blocks."""
    response = await aclient.beta.tools.messages.create(**create_kwargs)
    print(response.content)


if __name__ == "__main__":
anthropic_sync(single_tool_use)
anthropic_sync(multiple_tool_use)
anthropic_sync(missing_information)
# asyncio.run(async_anthropic(single_tool_use))
# asyncio.run(async_anthropic(multiple_tool_use))
# asyncio.run(async_anthropic(missing_information))
27 changes: 23 additions & 4 deletions parea/wrapper/anthropic/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from datetime import datetime

from anthropic import AsyncMessageStreamManager, AsyncStream, Client, MessageStreamManager, Stream
from anthropic.types import ContentBlockDeltaEvent, Message, MessageDeltaEvent, MessageStartEvent
from anthropic.types import ContentBlockDeltaEvent, Message, MessageDeltaEvent, MessageStartEvent, TextBlock

from parea.cache.cache import Cache
from parea.helpers import timezone_aware_now
from parea.schemas import CacheRequest, LLMInputs, ModelParams, TraceLog
from parea.utils.trace_utils import trace_data
from parea.utils.trace_utils import make_output, trace_data
from parea.wrapper import Wrapper
from parea.wrapper.anthropic.stream_wrapper import AnthropicAsyncStreamWrapper, AnthropicStreamWrapper, MessageAsyncStreamManagerWrapper, MessageStreamManagerWrapper
from parea.wrapper.utils import _compute_cost
Expand All @@ -18,14 +18,17 @@
class AnthropicWrapper:

def init(self, log: Callable, cache: Cache, client: Client):
func_names = ["messages.create", "messages.stream"]
if hasattr(client, "beta") and hasattr(client.beta, "tools") and hasattr(client.beta.tools, "messages") and hasattr(client.beta.tools.messages, "create"):
func_names.append("beta.tools.messages.create")
Wrapper(
resolver=self.resolver,
gen_resolver=self.gen_resolver,
agen_resolver=self.agen_resolver,
should_use_gen_resolver=self.should_use_gen_resolver,
log=log,
module=client,
func_names=["messages.create", "messages.stream"],
func_names=func_names,
cache=cache,
convert_kwargs_to_cache_request=self.convert_kwargs_to_cache_request,
convert_cache_to_response=self.convert_cache_to_response,
Expand All @@ -35,7 +38,18 @@ def init(self, log: Callable, cache: Cache, client: Client):
@staticmethod
def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Message]) -> Optional[Any]:
if response:
output = response.content[0].text
if len(response.content) > 1:
from anthropic.types.beta.tools import ToolUseBlock

output_list = []
for content in response.content:
if isinstance(content, TextBlock):
output_list.append(content.text)
elif isinstance(content, ToolUseBlock):
output_list.append(content.model_dump())
output = make_output(output_list, islist=True)
else:
output = response.content[0].text
input_tokens = response.usage.input_tokens
output_tokens = response.usage.output_tokens
total_tokens = input_tokens + output_tokens
Expand Down Expand Up @@ -109,6 +123,10 @@ def _get_default_dict_streaming():

@staticmethod
def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
functions = [d for d in kwargs.get("tools", [])]
for func in functions:
if 'input_schema' in func:
func["parameters"] = func.pop("input_schema")
return LLMInputs(
model=model or kwargs.get("model", None),
provider="anthropic",
Expand All @@ -118,6 +136,7 @@ def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs:
max_length=kwargs.get("max_tokens", None),
top_p=kwargs.get("top_p", 1.0),
),
functions=functions,
)

def convert_kwargs_to_cache_request(self, _args: Sequence[Any], kwargs: Dict[str, Any]) -> CacheRequest:
Expand Down

0 comments on commit 2f48a00

Please sign in to comment.