Skip to content

Commit

Permalink
add stop/interrupt capability
Browse files Browse the repository at this point in the history
  • Loading branch information
x5a committed Nov 21, 2024
1 parent eabba4b commit 90e2386
Showing 1 changed file with 49 additions and 6 deletions.
55 changes: 49 additions & 6 deletions computer-use-demo/computer_use_demo/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import subprocess
import traceback
from contextlib import contextmanager
from datetime import datetime, timedelta
from enum import StrEnum
from functools import partial
Expand All @@ -19,6 +20,7 @@
from anthropic.types.beta import (
BetaContentBlockParam,
BetaTextBlockParam,
BetaToolResultBlockParam,
)
from streamlit.delta_generator import DeltaGenerator

Expand All @@ -33,10 +35,14 @@
API_KEY_FILE = CONFIG_DIR / "api_key"
STREAMLIT_STYLE = """
<style>
/* Hide chat input while agent loop is running */
.stApp[data-teststate=running] .stChatInput textarea,
.stApp[data-test-script-state=running] .stChatInput textarea {
display: none;
/* Highlight the stop button in red */
button[kind=header] {
background-color: rgb(255, 75, 75);
border: 1px solid rgb(255, 75, 75);
color: rgb(255, 255, 255);
}
button[kind=header]:hover {
background-color: rgb(255, 51, 51);
}
/* Hide the streamlit deploy button */
.stAppDeployButton {
Expand All @@ -46,6 +52,8 @@
"""

WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior"
INTERRUPT_TEXT = "(user stopped or interrupted and wrote the following)"
INTERRUPT_TOOL_ERROR = "human stopped or interrupted tool execution"


class Sender(StrEnum):
Expand Down Expand Up @@ -82,6 +90,8 @@ def setup_state():
st.session_state.custom_system_prompt = load_from_storage("system_prompt") or ""
if "hide_images" not in st.session_state:
st.session_state.hide_images = False
if "in_sampling_loop" not in st.session_state:
st.session_state.in_sampling_loop = False


def _reset_model():
Expand Down Expand Up @@ -195,7 +205,10 @@ def _reset_api_provider():
st.session_state.messages.append(
{
"role": Sender.USER,
"content": [BetaTextBlockParam(type="text", text=new_message)],
"content": [
*maybe_add_interruption_blocks(),
BetaTextBlockParam(type="text", text=new_message),
],
}
)
_render_message(Sender.USER, new_message)
Expand All @@ -209,7 +222,7 @@ def _reset_api_provider():
# we don't have a user message to respond to, exit early
return

with st.spinner("Running Agent..."):
with track_sampling_loop():
# run the agent sampling loop with the newest message
st.session_state.messages = await sampling_loop(
system_prompt_suffix=st.session_state.custom_system_prompt,
Expand All @@ -230,6 +243,36 @@ def _reset_api_provider():
)


def maybe_add_interruption_blocks():
if not st.session_state.in_sampling_loop:
return []
# If this function is called while we're in the sampling loop, we can assume that the previous sampling loop was interrupted
# and we should annotate the conversation with additional context for the model and heal any incomplete tool use calls
result = []
last_message = st.session_state.messages[-1]
previous_tool_use_ids = [
block["id"] for block in last_message["content"] if block["type"] == "tool_use"
]
for tool_use_id in previous_tool_use_ids:
tool_result = BetaToolResultBlockParam(
tool_use_id=tool_use_id,
type="tool_result",
content=INTERRUPT_TOOL_ERROR,
is_error=True,
)
st.session_state.tools[tool_use_id] = tool_result
result.append(tool_result)
result.append(BetaTextBlockParam(type="text", text=INTERRUPT_TEXT))
return result


@contextmanager
def track_sampling_loop():
st.session_state.in_sampling_loop = True
yield
st.session_state.in_sampling_loop = False


def validate_auth(provider: APIProvider, api_key: str | None):
if provider == APIProvider.ANTHROPIC:
if not api_key:
Expand Down

0 comments on commit 90e2386

Please sign in to comment.