Skip to content

Commit

Permalink
support caching of anthropic system prompt (#18008)
Browse files Browse the repository at this point in the history
  • Loading branch information
nalbion authored Mar 5, 2025
1 parent ea1f987 commit 5731f91
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -467,14 +467,20 @@ def _prepare_chat_with_tools(
chat_history.append(user_msg)

tool_dicts = []
for tool in tools:
tool_dicts.append(
{
"name": tool.metadata.name,
"description": tool.metadata.description,
"input_schema": tool.metadata.get_parameters_dict(),
}
)
if tools:
for tool in tools:
tool_dicts.append(
{
"name": tool.metadata.name,
"description": tool.metadata.description,
"input_schema": tool.metadata.get_parameters_dict(),
}
)
if "prompt-caching" in kwargs.get("extra_headers", {}).get(
"anthropic-beta", ""
):
tool_dicts[-1]["cache_control"] = {"type": "ephemeral"}

return {"messages": chat_history, "tools": tool_dicts, **kwargs}

def _validate_chat_with_tools_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Utility functions for the Anthropic SDK LLM integration.
"""

from typing import Dict, Sequence, Tuple
from typing import Any, Dict, Sequence, Tuple

from llama_index.core.base.llms.types import (
ChatMessage,
Expand Down Expand Up @@ -139,13 +139,16 @@ def messages_to_anthropic_messages(
- System prompt
"""
anthropic_messages = []
system_prompt = ""
system_prompt = []
for message in messages:
if message.role == MessageRole.SYSTEM:
# For system messages, concatenate all text blocks
for block in message.blocks:
if isinstance(block, TextBlock):
system_prompt += block.text + "\n"
if isinstance(block, TextBlock) and block.text:
system_prompt.append(
_text_block_to_anthropic_message(
block, message.additional_kwargs
)
)
elif message.role == MessageRole.FUNCTION or message.role == MessageRole.TOOL:
content = ToolResultBlockParam(
tool_use_id=message.additional_kwargs["tool_call_id"],
Expand All @@ -161,19 +164,12 @@ def messages_to_anthropic_messages(
content: list[TextBlockParam | ImageBlockParam] = []
for block in message.blocks:
if isinstance(block, TextBlock):
content_ = (
TextBlockParam(
text=block.text,
type="text",
cache_control=CacheControlEphemeralParam(type="ephemeral"),
if block.text:
content.append(
_text_block_to_anthropic_message(
block, message.additional_kwargs
)
)
if "cache_control" in message.additional_kwargs
else TextBlockParam(text=block.text, type="text")
)

# avoid empty text blocks
if content_["text"]:
content.append(content_)
elif isinstance(block, ImageBlock):
# FUTURE: Claude does not support URLs, so we need to always convert to base64
img_bytes = block.resolve_image(as_base64=True).read()
Expand Down Expand Up @@ -214,7 +210,19 @@ def messages_to_anthropic_messages(
content=content,
)
anthropic_messages.append(anth_message)
return __merge_common_role_msgs(anthropic_messages), system_prompt.strip()
return __merge_common_role_msgs(anthropic_messages), system_prompt


def _text_block_to_anthropic_message(
    block: TextBlock, kwargs: dict[str, Any]
) -> TextBlockParam:
    """Convert a llama-index ``TextBlock`` into an Anthropic ``TextBlockParam``.

    If the originating message's additional kwargs contain a
    ``cache_control`` entry, the resulting block is marked as an
    ephemeral prompt-cache breakpoint; otherwise a plain text block
    is returned.
    """
    # Plain text block when the caller did not request caching.
    if "cache_control" not in kwargs:
        return TextBlockParam(text=block.text, type="text")
    # "ephemeral" is the only cache_control type Anthropic currently accepts.
    return TextBlockParam(
        text=block.text,
        type="text",
        cache_control=CacheControlEphemeralParam(type="ephemeral"),
    )


# Function used in bedrock
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-anthropic"
readme = "README.md"
version = "0.6.7"
version = "0.6.8"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down

0 comments on commit 5731f91

Please sign in to comment.