From 5731f91571d631e541a2f5219282b674f4c86132 Mon Sep 17 00:00:00 2001
From: Nicholas Albion
Date: Thu, 6 Mar 2025 10:25:51 +1100
Subject: [PATCH] support caching of anthropic system prompt (#18008)

The system prompt is now converted to a list of Anthropic text blocks
instead of a single concatenated string, so each block can carry a
cache_control marker (taken from the message's additional_kwargs) and
empty text blocks are skipped. When the anthropic-beta header requests
prompt-caching, the last tool definition is also marked with
cache_control so the tool list can be cached as part of the prompt
prefix.
---
 .../llama_index/llms/anthropic/base.py        | 22 ++++----
 .../llama_index/llms/anthropic/utils.py       | 44 +++++++++++--------
 .../llama-index-llms-anthropic/pyproject.toml |  2 +-
 3 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py
index 81a8c6ac1620b..278388d355624 100644
--- a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py
@@ -467,14 +467,20 @@ def _prepare_chat_with_tools(
         chat_history.append(user_msg)
 
     tool_dicts = []
-    for tool in tools:
-        tool_dicts.append(
-            {
-                "name": tool.metadata.name,
-                "description": tool.metadata.description,
-                "input_schema": tool.metadata.get_parameters_dict(),
-            }
-        )
+    if tools:
+        for tool in tools:
+            tool_dicts.append(
+                {
+                    "name": tool.metadata.name,
+                    "description": tool.metadata.description,
+                    "input_schema": tool.metadata.get_parameters_dict(),
+                }
+            )
+        if "prompt-caching" in kwargs.get("extra_headers", {}).get(
+            "anthropic-beta", ""
+        ):
+            tool_dicts[-1]["cache_control"] = {"type": "ephemeral"}
+
     return {"messages": chat_history, "tools": tool_dicts, **kwargs}
 
     def _validate_chat_with_tools_response(
diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py
index 06e27c5fffb66..42731637dd8f3 100644
--- a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py
@@ -2,7 +2,7 @@
 Utility functions for the Anthropic SDK LLM integration.
""" -from typing import Dict, Sequence, Tuple +from typing import Any, Dict, Sequence, Tuple from llama_index.core.base.llms.types import ( ChatMessage, @@ -139,13 +139,16 @@ def messages_to_anthropic_messages( - System prompt """ anthropic_messages = [] - system_prompt = "" + system_prompt = [] for message in messages: if message.role == MessageRole.SYSTEM: - # For system messages, concatenate all text blocks for block in message.blocks: - if isinstance(block, TextBlock): - system_prompt += block.text + "\n" + if isinstance(block, TextBlock) and block.text: + system_prompt.append( + _text_block_to_anthropic_message( + block, message.additional_kwargs + ) + ) elif message.role == MessageRole.FUNCTION or message.role == MessageRole.TOOL: content = ToolResultBlockParam( tool_use_id=message.additional_kwargs["tool_call_id"], @@ -161,19 +164,12 @@ def messages_to_anthropic_messages( content: list[TextBlockParam | ImageBlockParam] = [] for block in message.blocks: if isinstance(block, TextBlock): - content_ = ( - TextBlockParam( - text=block.text, - type="text", - cache_control=CacheControlEphemeralParam(type="ephemeral"), + if block.text: + content.append( + _text_block_to_anthropic_message( + block, message.additional_kwargs + ) ) - if "cache_control" in message.additional_kwargs - else TextBlockParam(text=block.text, type="text") - ) - - # avoid empty text blocks - if content_["text"]: - content.append(content_) elif isinstance(block, ImageBlock): # FUTURE: Claude does not support URLs, so we need to always convert to base64 img_bytes = block.resolve_image(as_base64=True).read() @@ -214,7 +210,19 @@ def messages_to_anthropic_messages( content=content, ) anthropic_messages.append(anth_message) - return __merge_common_role_msgs(anthropic_messages), system_prompt.strip() + return __merge_common_role_msgs(anthropic_messages), system_prompt + + +def _text_block_to_anthropic_message( + block: TextBlock, kwargs: dict[str, Any] +) -> TextBlockParam: + if "cache_control" in kwargs: + return TextBlockParam( + text=block.text, + type="text", + cache_control=CacheControlEphemeralParam(type="ephemeral"), + ) + return TextBlockParam(text=block.text, type="text") # Function used in bedrock diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml index 51bc9a6d93648..b7ae443cbc794 100644 --- a/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-anthropic" readme = "README.md" -version = "0.6.7" +version = "0.6.8" [tool.poetry.dependencies] python = ">=3.9,<4.0"