Skip to content

Commit

Permalink
[computer-use-demo] Add prompt caching
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinji committed Oct 23, 2024
1 parent 4acbe40 commit a98a0e6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
17 changes: 14 additions & 3 deletions computer-use-demo/computer_use_demo/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ToolResultBlockParam,
)
from anthropic.types.beta import (
BetaCacheControlEphemeralParam,
BetaContentBlock,
BetaContentBlockParam,
BetaImageBlockParam,
Expand All @@ -24,8 +25,6 @@

from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult

BETA_FLAG = "computer-use-2024-10-22"


class APIProvider(StrEnum):
ANTHROPIC = "anthropic"
Expand Down Expand Up @@ -74,6 +73,7 @@ async def sampling_loop(
api_key: str,
only_n_most_recent_images: int | None = None,
max_tokens: int = 4096,
prompt_caching: bool = True,
):
"""
Agentic sampling loop for the assistant/tool interaction of computer use.
Expand All @@ -98,6 +98,17 @@ async def sampling_loop(
elif provider == APIProvider.BEDROCK:
client = AnthropicBedrock()

betas = ["computer-use-2024-10-22"]
if prompt_caching:
betas.append("prompt-caching-2024-07-31")
for message in messages:
if isinstance(message["content"], str):
continue
for content_block_param in message["content"]:
content_block_param["cache_control"] = (
BetaCacheControlEphemeralParam(type="ephemeral")
)

# Call the API
# we use raw_response to provide debug information to streamlit. Your
# implementation may be able call the SDK directly with:
Expand All @@ -108,7 +119,7 @@ async def sampling_loop(
model=model,
system=system,
tools=tool_collection.to_params(),
betas=["computer-use-2024-10-22"],
betas=betas,
)

api_response_callback(cast(APIResponse[BetaMessage], raw_response))
Expand Down
3 changes: 2 additions & 1 deletion computer-use-demo/tests/loop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ async def test_loop():
)

assert len(result) == 4
assert result[0] == {"role": "user", "content": "Test message"}
assert result[0]["role"] == "user"
assert result[0]["content"] == "Test message"
assert result[1]["role"] == "assistant"
assert result[2]["role"] == "user"
assert result[3]["role"] == "assistant"
Expand Down

0 comments on commit a98a0e6

Please sign in to comment.