diff --git a/parea/cookbook/data/anthropic_tool_use_examples.py b/parea/cookbook/data/anthropic_tool_use_examples.py new file mode 100644 index 00000000..b540b2e3 --- /dev/null +++ b/parea/cookbook/data/anthropic_tool_use_examples.py @@ -0,0 +1,124 @@ +single_tool_use = { + "model": "claude-3-opus-20240229", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "What's the weather like in San Francisco?"}], + "tools": [ + { + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + } + }, + "required": ["location"], + }, + } + ] +} + + +multiple_tool_use = { + "model": "claude-3-opus-20240229", + "max_tokens": 1024, + "messages": [{"role": "user", "content": "What is the weather like right now in New York? Also what time is it there?"}], + "tools": [ + { + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature, either 'celsius' or 'fahrenheit'" + } + }, + "required": ["location"] + } + }, + { + "name": "get_time", + "description": "Get the current time in a given time zone", + "input_schema": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The IANA time zone name, e.g. 
America/Los_Angeles" + } + }, + "required": ["timezone"] + } + } + ] +} + + +missing_information = { + "model": "claude-3-opus-20240229", + "max_tokens": 1024, + "tools": [ + { + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature, either \"celsius\" or \"fahrenheit\"" + } + }, + "required": ["location"] + } + } + ], + "messages": [ + { + "role": "user", + "content": "What is the weather like in San Francisco?" + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "I need to use get_weather, and the user wants SF, which is likely San Francisco, CA." + }, + { + "type": "tool_use", + "id": "toolu_01A09q90qw90lq917835lq9", + "name": "get_weather", + "input": { + "location": "San Francisco, CA", + "unit": "celsius" + } + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_01A09q90qw90lq917835lq9", + "content": "65 degrees" + } + ] + } + ] +} diff --git a/parea/cookbook/tracing_anthropic_tool_use.py b/parea/cookbook/tracing_anthropic_tool_use.py new file mode 100644 index 00000000..95066c30 --- /dev/null +++ b/parea/cookbook/tracing_anthropic_tool_use.py @@ -0,0 +1,36 @@ +import asyncio +import os + +import anthropic +from dotenv import load_dotenv + +from parea import Parea +from parea.cookbook.data.anthropic_tool_use_examples import missing_information, multiple_tool_use, single_tool_use + +load_dotenv() + +client = anthropic.Anthropic() +aclient = anthropic.AsyncAnthropic() + +p = Parea(api_key=os.getenv("PAREA_API_KEY"), project_name="testing") +p.wrap_anthropic_client(client) +p.wrap_anthropic_client(aclient) + + +def anthropic_sync(create_kwargs): + message = 
client.beta.tools.messages.create(**create_kwargs) + print(message.content) + + +async def async_anthropic(create_kwargs): + message = await aclient.beta.tools.messages.create(**create_kwargs) + print(message.content) + + +if __name__ == "__main__": + anthropic_sync(single_tool_use) + anthropic_sync(multiple_tool_use) + anthropic_sync(missing_information) + # asyncio.run(async_anthropic(single_tool_use)) + # asyncio.run(async_anthropic(multiple_tool_use)) + # asyncio.run(async_anthropic(missing_information)) diff --git a/parea/wrapper/anthropic/anthropic.py b/parea/wrapper/anthropic/anthropic.py index c88357c5..1fc3f9e2 100644 --- a/parea/wrapper/anthropic/anthropic.py +++ b/parea/wrapper/anthropic/anthropic.py @@ -4,12 +4,12 @@ from datetime import datetime from anthropic import AsyncMessageStreamManager, AsyncStream, Client, MessageStreamManager, Stream -from anthropic.types import ContentBlockDeltaEvent, Message, MessageDeltaEvent, MessageStartEvent +from anthropic.types import ContentBlockDeltaEvent, Message, MessageDeltaEvent, MessageStartEvent, TextBlock from parea.cache.cache import Cache from parea.helpers import timezone_aware_now from parea.schemas import CacheRequest, LLMInputs, ModelParams, TraceLog -from parea.utils.trace_utils import trace_data +from parea.utils.trace_utils import make_output, trace_data from parea.wrapper import Wrapper from parea.wrapper.anthropic.stream_wrapper import AnthropicAsyncStreamWrapper, AnthropicStreamWrapper, MessageAsyncStreamManagerWrapper, MessageStreamManagerWrapper from parea.wrapper.utils import _compute_cost @@ -18,6 +18,9 @@ class AnthropicWrapper: def init(self, log: Callable, cache: Cache, client: Client): + func_names = ["messages.create", "messages.stream"] + if hasattr(client, "beta") and hasattr(client.beta, "tools") and hasattr(client.beta.tools, "messages") and hasattr(client.beta.tools.messages, "create"): + func_names.append("beta.tools.messages.create") Wrapper( resolver=self.resolver, 
gen_resolver=self.gen_resolver, @@ -25,7 +28,7 @@ def init(self, log: Callable, cache: Cache, client: Client): should_use_gen_resolver=self.should_use_gen_resolver, log=log, module=client, - func_names=["messages.create", "messages.stream"], + func_names=func_names, cache=cache, convert_kwargs_to_cache_request=self.convert_kwargs_to_cache_request, convert_cache_to_response=self.convert_cache_to_response, @@ -35,7 +38,18 @@ def init(self, log: Callable, cache: Cache, client: Client): @staticmethod def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Message]) -> Optional[Any]: if response: - output = response.content[0].text + if len(response.content) > 1 or not isinstance(response.content[0], TextBlock): + from anthropic.types.beta.tools import ToolUseBlock + + output_list = [] + for content in response.content: + if isinstance(content, TextBlock): + output_list.append(content.text) + elif isinstance(content, ToolUseBlock): + output_list.append(content.model_dump()) + output = make_output(output_list, islist=True) + else: + output = response.content[0].text input_tokens = response.usage.input_tokens output_tokens = response.usage.output_tokens total_tokens = input_tokens + output_tokens @@ -109,6 +123,10 @@ def _get_default_dict_streaming(): @staticmethod def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs: + functions = [dict(d) for d in kwargs.get("tools", [])] + for func in functions: + if "input_schema" in func: + func["parameters"] = func.pop("input_schema") return LLMInputs( model=model or kwargs.get("model", None), provider="anthropic", @@ -118,6 +136,7 @@ def _kwargs_to_llm_configuration(kwargs, model=None) -> LLMInputs: max_length=kwargs.get("max_tokens", None), top_p=kwargs.get("top_p", 1.0), ), + functions=functions, ) def convert_kwargs_to_cache_request(self, _args: Sequence[Any], kwargs: Dict[str, Any]) -> CacheRequest: