From 7cc7a7fe6459377eb3021366e5a3b20d70fb352f Mon Sep 17 00:00:00 2001 From: Douglas Reid Date: Thu, 7 Sep 2023 10:38:15 -0700 Subject: [PATCH 1/4] feat(streaming): proof of concept for streaming agentservice execution This PR presents a series of changes that should support a way to stream response information back to a client via an AgentService. In order to achieve the streaming result, a new method on the AgentService is exposed: `async_prompt`. This new method returns a new `StreamingResponse` that has two fields: `task` and `file`. These fields provide access to (a) the async task that will be streaming results and (b) a file (here the `ChatHistory` file) to which all status messages and assistant interactions will be saved. This PR relies on a full deployment of https://github.com/steamship-plugins/gpt4/pull/10 to the target environment for testing / validation. --- .../agents/functional/functions_based.py | 82 +++++++++++++++---- .../agents/functional/output_parser.py | 43 ++++++++-- src/steamship/agents/llms/openai.py | 41 +++++++--- src/steamship/agents/schema/action.py | 27 +----- src/steamship/agents/schema/agent.py | 19 +++-- src/steamship/agents/schema/chathistory.py | 41 +++++++--- .../agents/schema/message_selectors.py | 30 +++++-- src/steamship/agents/service/agent_service.py | 44 +++++++++- src/steamship/data/tags/tag_constants.py | 3 + src/steamship/utils/repl.py | 19 ++++- .../agents/test_agent_service.py | 37 ++++++++- 11 files changed, 299 insertions(+), 87 deletions(-) diff --git a/src/steamship/agents/functional/functions_based.py b/src/steamship/agents/functional/functions_based.py index 367ded3b0..3bfb5537c 100644 --- a/src/steamship/agents/functional/functions_based.py +++ b/src/steamship/agents/functional/functions_based.py @@ -1,9 +1,11 @@ +import json from typing import List -from steamship import Block +from steamship import Block, MimeTypes, Tag from steamship.agents.functional.output_parser import FunctionsBasedOutputParser -from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, Tool -from steamship.data.tags.tag_constants import RoleTag +from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, FinishAction, Tool +from steamship.data.tags.tag_constants import RoleTag, TagKind, TagValueKey +from steamship.data.tags.tag_utils import get_tag class FunctionsBasedAgent(ChatAgent): @@ -41,21 +43,15 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]: .wait() .to_ranked_blocks() ) - # TODO(dougreid): we need a way to threshold message inclusion, especially for small contexts - # remove the actual prompt from the semantic search (it will be an exact match) - messages_from_memory = [ - msg - for msg in messages_from_memory - if msg.id != context.chat_history.last_user_message.id - ] - # get most recent context messages_from_memory.extend(context.chat_history.select_messages(self.message_selector)) # de-dupe the messages from memory - ids = [] + ids = [ + context.chat_history.last_user_message.id + ] # filter out last user message, it is appended afterwards for msg in messages_from_memory: if msg.id not in ids: messages.append(msg) @@ -67,10 +63,8 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]: # this should happen BEFORE any agent/assistant messages related to tool selection messages.append(context.chat_history.last_user_message) - # get completed steps - actions = context.completed_steps - for action in actions: - messages.extend(action.to_chat_messages()) + # get working history (completed actions) + messages.extend(self._function_calls_since_last_user_message(context)) return messages @@ -81,4 +75,58 @@ def next_action(self, context: AgentContext) -> Action: # Run the default LLM on those messages output_blocks = self.llm.chat(messages=messages, tools=self.tools) - return self.output_parser.parse(output_blocks[0].text, context) + future_action = self.output_parser.parse(output_blocks[0].text, context) + if not isinstance(future_action, FinishAction): + # record the LLM's function response in history + self._record_action_selection(future_action, context) + return future_action + + def _function_calls_since_last_user_message(self, context: AgentContext) -> List[Block]: + function_calls = [] + for block in context.chat_history.messages[::-1]: # is this too inefficient at scale? + if block.chat_role == RoleTag.USER: + return reversed(function_calls) + if get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION): + function_calls.append(block) + elif get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION): + function_calls.append(block) + return reversed(function_calls) + + def _to_openai_function_selection(self, action: Action) -> str: + fc = {"name": action.tool} + args = {} + for block in action.input: + for t in block.tags: + if t.kind == TagKind.FUNCTION_ARG: + args[t.name] = block.as_llm_input(exclude_block_wrapper=True) + + fc["arguments"] = json.dumps(args) # the arguments must be a string value NOT a dict + return json.dumps(fc) + + def _record_action_selection(self, action: Action, context: AgentContext): + tags = [ + Tag(kind=TagKind.ROLE, name=RoleTag.ASSISTANT), + Tag(kind=TagKind.FUNCTION_SELECTION, name=action.tool), + ] + context.chat_history.file.append_block( + text=self._to_openai_function_selection(action), tags=tags, mime_type=MimeTypes.TXT + ) + + def record_action_run(self, action: Action, context: AgentContext): + super().record_action_run(action, context) + + tags = [ + Tag( + kind=TagKind.ROLE, + name=RoleTag.FUNCTION, + value={TagValueKey.STRING_VALUE: action.tool}, + ), + ] + # TODO(dougreid): I'm not convinced this is correct for tools that return multiple values. + # It _feels_ like these should be named and inlined as a single message in history, etc. + for block in action.output: + context.chat_history.file.append_block( + text=block.as_llm_input(exclude_block_wrapper=True), + tags=tags, + mime_type=block.mime_type, + ) diff --git a/src/steamship/agents/functional/output_parser.py b/src/steamship/agents/functional/output_parser.py index ccd5ce420..cb1fae6d2 100644 --- a/src/steamship/agents/functional/output_parser.py +++ b/src/steamship/agents/functional/output_parser.py @@ -3,9 +3,9 @@ import string from typing import Dict, List, Optional -from steamship import Block, MimeTypes, Steamship +from steamship import Block, MimeTypes, Steamship, Tag from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool -from steamship.data.tags.tag_constants import RoleTag +from steamship.data.tags.tag_constants import RoleTag, TagKind from steamship.utils.utils import is_valid_uuid4 @@ -42,16 +42,45 @@ def _extract_action_from_function_call(self, text: str, context: AgentContext) - try: args = json.loads(arguments) if text := args.get("text"): - input_blocks.append(Block(text=text, mime_type=MimeTypes.TXT)) + input_blocks.append( + Block( + text=text, + tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")], + mime_type=MimeTypes.TXT, + ) + ) elif uuid_arg := args.get("uuid"): - input_blocks.append(Block.get(context.client, _id=uuid_arg)) + existing_block = Block.get(context.client, _id=uuid_arg) + tag = Tag.create( + existing_block.client, + file_id=existing_block.file_id, + block_id=existing_block.id, + kind=TagKind.FUNCTION_ARG, + name="uuid", + ) + existing_block.tags.append(tag) + input_blocks.append(existing_block) except json.decoder.JSONDecodeError: if isinstance(arguments, str): if is_valid_uuid4(arguments): - input_blocks.append(Block.get(context.client, _id=uuid_arg)) + existing_block = Block.get(context.client, _id=arguments) + tag = Tag.create( + existing_block.client, + file_id=existing_block.file_id, + block_id=existing_block.id, + kind=TagKind.FUNCTION_ARG, + name="uuid", + ) + existing_block.tags.append(tag) + input_blocks.append(existing_block) else: - input_blocks.append(Block(text=arguments, mime_type=MimeTypes.TXT)) - + input_blocks.append( + Block( + text=arguments, + tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")], + mime_type=MimeTypes.TXT, + ) + ) return Action(tool=tool.name, input=input_blocks, context=context) @staticmethod diff --git a/src/steamship/agents/llms/openai.py b/src/steamship/agents/llms/openai.py index 19398677f..c02039459 100644 --- a/src/steamship/agents/llms/openai.py +++ b/src/steamship/agents/llms/openai.py @@ -58,6 +58,7 @@ def complete(self, prompt: str, stop: Optional[str] = None, **kwargs) -> List[Bl if "max_tokens" in kwargs: options["max_tokens"] = kwargs["max_tokens"] + # TODO(dougreid): do we care about streaming here? should we take a kwarg that is file_id ? action_task = self.generator.generate(text=prompt, options=options) action_task.wait() return action_task.output.blocks @@ -84,12 +85,8 @@ def chat(self, messages: List[Block], tools: Optional[List[Tool]], **kwargs) -> Supported kwargs include: - `max_tokens` (controls the size of LLM responses) """ - - temp_file = File.create( - client=self.client, - blocks=messages, - tags=[Tag(kind=TagKind.GENERATION, name=GenerationTag.PROMPT_COMPLETION)], - ) + if len(messages) <= 0: + return [] options = {} if len(tools) > 0: @@ -119,7 +116,31 @@ def chat(self, messages: List[Block], tools: Optional[List[Tool]], **kwargs) -> logging.info(f"OpenAI ChatComplete ({messages[-1].as_llm_input()})", extra=extra) - tool_selection_task = self.generator.generate(input_file_id=temp_file.id, options=options) - tool_selection_task.wait() - - return tool_selection_task.output.blocks + # for streaming use cases, we want to always use the existing file + # the way to detect this would be if all messages were from the same file + if self._from_same_file(blocks=messages): + file_id = messages[0].id + block_ids = [b.id for b in messages] + generate_task = self.generator.generate( + input_file_id=file_id, + input_file_block_index_list=block_ids, + options=options, + append_output_to_file=True, + ) + else: + tags = [Tag(kind=TagKind.GENERATION, name=GenerationTag.PROMPT_COMPLETION)] + temp_file = File.create(client=self.client, blocks=messages, tags=tags) + generate_task = self.generator.generate(input_file_id=temp_file.id, options=options) + + generate_task.wait() + + return generate_task.output.blocks + + def _from_same_file(self, blocks: List[Block]) -> bool: + if len(blocks) <= 1: + return True + file_id = blocks[0].file_id + for b in blocks[1:]: + if b.file_id != file_id: + return False + return True diff --git a/src/steamship/agents/schema/action.py b/src/steamship/agents/schema/action.py index 929f5afde..e57ec6ebb 100644 --- a/src/steamship/agents/schema/action.py +++ b/src/steamship/agents/schema/action.py @@ -1,10 +1,9 @@ from typing import List, Optional from pydantic import BaseModel +from pydantic.fields import Field -from steamship import Block, Tag -from steamship.data import TagKind -from steamship.data.tags.tag_constants import RoleTag +from steamship import Block class Action(BaseModel): @@ -22,32 +21,12 @@ class Action(BaseModel): output: Optional[List[Block]] """Any direct output produced by the Tool.""" - is_final: bool = False + is_final: bool = Field(default=False) """Whether this Action should be the final action performed in a reasoning loop. Setting this to True means that the executing Agent should halt any reasoning. """ - def to_chat_messages(self) -> List[Block]: - tags = [ - Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION), - Tag(kind="name", name=self.tool), - ] - blocks = [] - for block in self.output: - # TODO(dougreid): should we revisit as_llm_input? we might need only the UUID... - blocks.append( - Block( - text=block.as_llm_input(exclude_block_wrapper=True), - tags=tags, - mime_type=block.mime_type, - ) - ) - - # TODO(dougreid): revisit when have multiple output functions. - # Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation. - return blocks - class FinishAction(Action): """Represents a final selected action in an Agent Execution.""" diff --git a/src/steamship/agents/schema/agent.py b/src/steamship/agents/schema/agent.py index 890b29dc4..110e3b8af 100644 --- a/src/steamship/agents/schema/agent.py +++ b/src/steamship/agents/schema/agent.py @@ -11,7 +11,8 @@ from steamship.agents.schema.message_selectors import MessageSelector, NoMessages from steamship.agents.schema.output_parser import OutputParser from steamship.agents.schema.tool import Tool -from steamship.data.tags.tag_constants import RoleTag +from steamship.data.tags.tag_constants import RoleTag, TagKind +from steamship.data.tags.tag_utils import get_tag class Agent(BaseModel, ABC): @@ -31,6 +32,11 @@ class Agent(BaseModel, ABC): def next_action(self, context: AgentContext) -> Action: pass + @abstractmethod + def record_action_run(self, action: Action, context: AgentContext): + # TODO(dougreid): should this method (or just bit) actually be on AgentContext? + context.completed_steps.append(action) + class LLMAgent(Agent): """LLMAgents choose next actions for an AgentService based on interactions with an LLM.""" @@ -53,17 +59,16 @@ def messages_to_prompt_history(messages: List[Block]) -> str: # Internal Status Messages are not considered part of **prompt** history. # Their inclusion could lead to problematic LLM behavior, etc. # As such are explicitly skipped here: - # - DON'T RETURN AGENT MESSAGES - # - DON'T RETURN TOOL MESSAGES - # - DON'T RETURN LLM MESSAGES + # - DON'T RETURN STATUS MESSAGES + # - DON'T RETURN FUNCTION or FUNCTION_SELECTION MESSAGES if role == RoleTag.USER: as_strings.append(f"User: {block.text}") - elif role == RoleTag.ASSISTANT: + elif role == RoleTag.ASSISTANT and ( + get_tag(block.tags, TagKind.FUNCTION_SELECTION) is None + ): as_strings.append(f"Assistant: {block.text}") elif role == RoleTag.SYSTEM: as_strings.append(f"System: {block.text}") - elif role == RoleTag.FUNCTION: - as_strings.append(f"Function: {block.text}") return "\n".join(as_strings) diff --git a/src/steamship/agents/schema/chathistory.py b/src/steamship/agents/schema/chathistory.py index a3051f87b..609eebc81 100644 --- a/src/steamship/agents/schema/chathistory.py +++ b/src/steamship/agents/schema/chathistory.py @@ -279,6 +279,28 @@ def clear(self): self.refresh() + def append_status_message_with_role( + self, + text: str = None, + role: RoleTag = RoleTag.USER, + tags: List[Tag] = None, + content: Union[str, bytes] = None, + url: Optional[str] = None, + mime_type: Optional[MimeTypes] = None, + ) -> Block: + """Append a new block to this with content provided by the end-user.""" + tags = tags or [] + tags.append( + Tag( + kind=TagKind.STATUS_MESSAGE, + name=ChatTag.ROLE, + value={TagValueKey.STRING_VALUE: role}, + ) + ) + return self.file.append_block( + text=text, tags=tags, content=content, url=url, mime_type=mime_type + ) + def append_agent_message( self, text: str = None, @@ -288,7 +310,9 @@ def append_agent_message( mime_type: Optional[MimeTypes] = None, ) -> Block: """Append a new block to this with status update messages from the Agent.""" - return self.append_message_with_role(text, RoleTag.AGENT, tags, content, url, mime_type) + return self.append_status_message_with_role( + text, RoleTag.AGENT, tags, content, url, mime_type + ) def append_tool_message( self, @@ -299,7 +323,9 @@ def append_tool_message( mime_type: Optional[MimeTypes] = None, ) -> Block: """Append a new block to this with status update messages from the Agent.""" - return self.append_message_with_role(text, RoleTag.TOOL, tags, content, url, mime_type) + return self.append_status_message_with_role( + text, RoleTag.TOOL, tags, content, url, mime_type + ) def append_llm_message( self, @@ -310,7 +336,9 @@ def append_llm_message( mime_type: Optional[MimeTypes] = None, ) -> Block: """Append a new block to this with status update messages from the Agent.""" - return self.append_message_with_role(text, RoleTag.LLM, tags, content, url, mime_type) + return self.append_status_message_with_role( + text, RoleTag.LLM, tags, content, url, mime_type + ) class ChatHistoryLoggingHandler(StreamHandler): @@ -365,13 +393,6 @@ def _append_message(self, message_dict: dict, author_kind: str): if author_kind == AgentLogging.AGENT: return self.chat_history.append_agent_message( text=message, - tags=[ - Tag( - kind=TagKind.AGENT_STATUS_MESSAGE, - name=message_type, - value={TagValueKey.STRING_VALUE: message}, - ), - ], mime_type=MimeTypes.TXT, ) elif author_kind == AgentLogging.TOOL: diff --git a/src/steamship/agents/schema/message_selectors.py b/src/steamship/agents/schema/message_selectors.py index 6c5b0f260..dfd31b4b9 100644 --- a/src/steamship/agents/schema/message_selectors.py +++ b/src/steamship/agents/schema/message_selectors.py @@ -5,7 +5,8 @@ from pydantic.main import BaseModel from steamship import Block -from steamship.data.tags.tag_constants import RoleTag +from steamship.data.tags.tag_constants import RoleTag, TagKind +from steamship.data.tags.tag_utils import get_tag class MessageSelector(BaseModel, ABC): @@ -29,20 +30,34 @@ def is_assistant_message(block: Block) -> bool: return role == RoleTag.ASSISTANT +def is_assistant_function_message(block: Block) -> bool: + is_function_selection = get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION) + return is_assistant_message(block) and is_function_selection + + +def is_user_history_message(block: Block) -> bool: + return is_user_message(block) or ( + is_assistant_message(block) and not is_assistant_function_message(block) + ) + + class MessageWindowMessageSelector(MessageSelector): k: int def get_messages(self, messages: List[Block]) -> List[Block]: msgs = messages[:] msgs.pop() # don't add the current prompt to the memory - if len(msgs) <= (self.k * 2): - return msgs + history_msgs = [ + msg for msg in msgs if is_user_history_message(msg) + ] # filter to only user history messages + if len(history_msgs) <= (self.k * 2): + return history_msgs selected_msgs = [] limit = self.k * 2 - scope = msgs[len(messages) - limit :] + scope = history_msgs[len(history_msgs) - limit :] for block in scope: - if is_user_message(block) or is_assistant_message(block): + if is_user_history_message(block): selected_msgs.append(block) return selected_msgs @@ -63,7 +78,10 @@ def get_messages(self, messages: List[Block]) -> List[Block]: msgs = messages[:] msgs.pop() # don't add the current prompt to the memory - for block in reversed(msgs): + history_msgs = [ + msg for msg in msgs if is_user_history_message(msg) + ] # filter to only user history messages + for block in reversed(history_msgs): if block.chat_role != RoleTag.SYSTEM and current_tokens < self.max_tokens: block_tokens = tokens(block) if block_tokens + current_tokens < self.max_tokens: diff --git a/src/steamship/agents/service/agent_service.py b/src/steamship/agents/service/agent_service.py index 5476f00ad..9bc245619 100644 --- a/src/steamship/agents/service/agent_service.py +++ b/src/steamship/agents/service/agent_service.py @@ -2,15 +2,24 @@ import uuid from typing import List, Optional -from steamship import Block, SteamshipError, Task +from pydantic.main import BaseModel + +from steamship import Block, File, SteamshipError, Task from steamship.agents.llms.openai import OpenAI from steamship.agents.logging import AgentLogging, StreamingOpts from steamship.agents.schema import Action, Agent, FinishAction from steamship.agents.schema.context import AgentContext, Metadata from steamship.agents.utils import with_llm +from steamship.data import TagKind +from steamship.data.tags.tag_constants import ChatTag from steamship.invocable import PackageService, post +class StreamingResponse(BaseModel): + task: Task + file: File + + class AgentService(PackageService): """AgentService is a Steamship Package that can use an Agent, Tools, and a provided AgentContext to respond to user input.""" @@ -84,6 +93,7 @@ def next_action(self, agent: Agent, input_blocks: List[Block], context: AgentCon }, ) + # save action selection to history... return action def run_action(self, agent: Agent, action: Action, context: AgentContext): @@ -109,7 +119,7 @@ def run_action(self, agent: Agent, action: Action, context: AgentContext): }, ) action.output = output_blocks - context.completed_steps.append(action) + agent.record_action_run(action, context) return tool = next((tool for tool in agent.tools if tool.name == action.tool), None) @@ -150,7 +160,8 @@ def run_action(self, agent: Agent, action: Action, context: AgentContext): action.is_final = ( tool.is_final ) # Permit the tool to decide if this action should halt the reasoning loop. - context.completed_steps.append(action) + + agent.record_action_run(action, context) if context.action_cache and tool.cacheable: context.action_cache.update(key=action, value=action.output) @@ -295,6 +306,33 @@ def build_default_context(self, context_id: Optional[str] = None, **kwargs) -> A context = with_llm(context=context, llm=llm) return context + @post("async_prompt") + def async_prompt( + self, prompt: Optional[str] = None, context_id: Optional[str] = None, **kwargs + ) -> StreamingResponse: + with self.build_default_context(context_id, **kwargs) as context: + ctx_id = context_id + + # if no context ID is provided, we need to make sure that the streaming context ID + # is the same one as the non-streaming. + if not ctx_id: + ctx_file = context.chat_history.file + for tag in ctx_file.tags: + if tag.kind == TagKind.CHAT and tag.name == ChatTag.CONTEXT_KEYS: + if value := tag.value: + ctx_id = value.get("id", None) + + # if you can't find a consistent context_id, then there is no way to provide an accurate + # streaming endpoint. + if not ctx_id: + # TODO(dougreid): this points to a slight flaw in the context_keys vs. context_id + raise SteamshipError("Error setting up context: no id found for context.") + + task = self.invoke_later( + "/prompt", arguments={"prompt": prompt, "context_id": ctx_id, **kwargs} + ) + return StreamingResponse(task=task, file=context.chat_history.file) + @post("prompt") def prompt( self, prompt: Optional[str] = None, context_id: Optional[str] = None, **kwargs diff --git a/src/steamship/data/tags/tag_constants.py b/src/steamship/data/tags/tag_constants.py index d4847bf97..381db0f15 100644 --- a/src/steamship/data/tags/tag_constants.py +++ b/src/steamship/data/tags/tag_constants.py @@ -30,9 +30,12 @@ class TagKind(str, Enum): CHAT = "chat" CHAT_HISTORY_CONTEXT = "chat-history-context" MESSAGE_ID = "message-id" + STATUS_MESSAGE = "status-message" AGENT_STATUS_MESSAGE = "agent-status-message" TOOL_STATUS_MESSAGE = "tool-status-message" LLM_STATUS_MESSAGE = "llm-status-message" + FUNCTION_ARG = "function-arg" + FUNCTION_SELECTION = "function-selection" class DocTag(str, Enum): diff --git a/src/steamship/utils/repl.py b/src/steamship/utils/repl.py index bc1cf0311..a4565286d 100644 --- a/src/steamship/utils/repl.py +++ b/src/steamship/utils/repl.py @@ -14,6 +14,8 @@ from steamship.agents.logging import AgentLogging from steamship.agents.schema import AgentContext, Tool from steamship.agents.service.agent_service import AgentService +from steamship.data import TagKind, TagValueKey +from steamship.data.tags.tag_utils import get_tag from steamship.data.workspace import Workspace from steamship.invocable.dev_logging_handler import DevelopmentLoggingHandler @@ -202,10 +204,23 @@ def print_history(self, client: Steamship, *args, **kwargs): history = agent_ctx.chat_history history.refresh() for block in history.messages: + chat_role = block.chat_role + status_msg = get_tag(block.tags, kind=TagKind.STATUS_MESSAGE) + if not chat_role and not status_msg: + continue + + if chat_role: + prefix = f"[{chat_role}]" + else: + if value := status_msg.value: + prefix = f"[{value.get(TagValueKey.STRING_VALUE)} status]" + else: + prefix = "[status]" + if block.is_text(): - print(f"[{block.chat_role}] {block.text}") + print(f"{prefix} {block.text}") else: - print(f"[{block.chat_role}] {block.id} ({block.mime_type})") + print(f"{prefix} {block.id} ({block.mime_type})") print("\n------------------------------\n") exit(0) diff --git a/tests/steamship_tests/agents/test_agent_service.py b/tests/steamship_tests/agents/test_agent_service.py index ecb9ffecd..ecaf18bee 100644 --- a/tests/steamship_tests/agents/test_agent_service.py +++ b/tests/steamship_tests/agents/test_agent_service.py @@ -5,7 +5,7 @@ from steamship_tests import SRC_PATH from steamship_tests.utils.deployables import deploy_package -from steamship import Block, Steamship, SteamshipError, Task +from steamship import Block, File, Steamship, SteamshipError, Task, TaskState from steamship.agents.functional import FunctionsBasedAgent from steamship.agents.llms.openai import ChatOpenAI from steamship.agents.schema import Action, AgentContext, Tool @@ -248,3 +248,38 @@ def test_context_logging_to_chat_history_everything(client: Steamship): assert not has_status_message(chat_history.messages, RoleTag.AGENT) assert not has_status_message(chat_history.messages, RoleTag.LLM) assert has_status_message(chat_history.messages, RoleTag.TOOL) + + +@pytest.mark.usefixtures("client") +def test_async_prompt(client: Steamship): + example_agent_service_path = ( + SRC_PATH / "steamship" / "agents" / "examples" / "example_assistant.py" + ) + with deploy_package(client, example_agent_service_path, wait_for_init=True) as ( + _, + _, + agent_service, + ): + context_id = "some_async_fun" + streaming_resp = agent_service.invoke( + "async_prompt", + prompt="who is the current president of the United States?", + context_id=context_id, + ) + + assert streaming_resp is not None + assert streaming_resp["file"] is not None + assert streaming_resp["task"] is not None + + file = File(client=client, **(streaming_resp["file"])) + task = Task(client=client, **(streaming_resp["task"])) + + original_len = len(file.blocks) + + task.wait() + assert task.state in [TaskState.succeeded, TaskState.failed] + + file.refresh() + assert ( + len(file.blocks) > original_len + ), "File should have increased in size during AgentService execution" From 91bfebeae7fb1dffd42032a2d524f5a5ba145f99 Mon Sep 17 00:00:00 2001 From: Douglas Reid Date: Sun, 24 Sep 2023 22:06:25 -0700 Subject: [PATCH 2/4] modifications --- requirements.dev.txt | 1 + .../agents/functional/functions_based.py | 22 ++++- .../agents/functional/output_parser.py | 1 + src/steamship/agents/llms/openai.py | 6 +- src/steamship/agents/schema/chathistory.py | 16 +++- src/steamship/agents/schema/context.py | 17 +++- src/steamship/agents/service/agent_service.py | 1 + src/steamship/data/block.py | 7 ++ .../agents/test_agent_service.py | 80 ++++++++++++++++--- 9 files changed, 130 insertions(+), 21 deletions(-) diff --git a/requirements.dev.txt b/requirements.dev.txt index 549c86261..5322aeb89 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -51,6 +51,7 @@ sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 +sseclient-py==1.8.0 toml==0.10.2 tomli==2.0.0 typing-extensions==4.2.0 diff --git a/src/steamship/agents/functional/functions_based.py b/src/steamship/agents/functional/functions_based.py index 3bfb5537c..a2253b6c8 100644 --- a/src/steamship/agents/functional/functions_based.py +++ b/src/steamship/agents/functional/functions_based.py @@ -27,13 +27,17 @@ def __init__(self, tools: List[Tool], llm: ChatLLM, **kwargs): output_parser=FunctionsBasedOutputParser(tools=tools), llm=llm, tools=tools, **kwargs ) + def _get_or_create_system_message(self, context: AgentContext) -> Block: + sys_msg = context.chat_history.last_system_message + if sys_msg: + return sys_msg + + return context.chat_history.append_system_message(text=self.PROMPT, mime_type=MimeTypes.TXT) + def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]: - messages: List[Block] = [] + messages: List[Block] = [self._get_or_create_system_message(context)] # get system message - system_message = Block(text=self.PROMPT) - system_message.set_chat_role(RoleTag.SYSTEM) - messages.append(system_message) messages_from_memory = [] # get prior conversations @@ -107,6 +111,11 @@ def _record_action_selection(self, action: Action, context: AgentContext): tags = [ Tag(kind=TagKind.ROLE, name=RoleTag.ASSISTANT), Tag(kind=TagKind.FUNCTION_SELECTION, name=action.tool), + Tag( + kind="request-id", + name=context.request_id, + value={TagValueKey.STRING_VALUE: context.request_id}, + ), ] context.chat_history.file.append_block( text=self._to_openai_function_selection(action), tags=tags, mime_type=MimeTypes.TXT @@ -121,6 +130,11 @@ def record_action_run(self, action: Action, context: AgentContext): name=RoleTag.FUNCTION, value={TagValueKey.STRING_VALUE: action.tool}, ), + Tag( + kind="request-id", + name=context.request_id, + value={TagValueKey.STRING_VALUE: context.request_id}, + ), ] # TODO(dougreid): I'm not convinced this is correct for tools that return multiple values. # It _feels_ like these should be named and inlined as a single message in history, etc. diff --git a/src/steamship/agents/functional/output_parser.py b/src/steamship/agents/functional/output_parser.py index cb1fae6d2..b1ae50b71 100644 --- a/src/steamship/agents/functional/output_parser.py +++ b/src/steamship/agents/functional/output_parser.py @@ -136,4 +136,5 @@ def parse(self, text: str, context: AgentContext) -> Action: finish_blocks = FunctionsBasedOutputParser._blocks_from_text(context.client, text) for finish_block in finish_blocks: finish_block.set_chat_role(RoleTag.ASSISTANT) + finish_block.set_request_id(context.request_id) return FinishAction(output=finish_blocks, context=context) diff --git a/src/steamship/agents/llms/openai.py b/src/steamship/agents/llms/openai.py index c02039459..d05c629ea 100644 --- a/src/steamship/agents/llms/openai.py +++ b/src/steamship/agents/llms/openai.py @@ -119,11 +119,11 @@ def chat(self, messages: List[Block], tools: Optional[List[Tool]], **kwargs) -> # for streaming use cases, we want to always use the existing file # the way to detect this would be if all messages were from the same file if self._from_same_file(blocks=messages): - file_id = messages[0].id - block_ids = [b.id for b in messages] + file_id = messages[0].file_id + block_indices = [b.index_in_file for b in messages] generate_task = self.generator.generate( input_file_id=file_id, - input_file_block_index_list=block_ids, + input_file_block_index_list=block_indices, options=options, append_output_to_file=True, ) diff --git a/src/steamship/agents/schema/chathistory.py b/src/steamship/agents/schema/chathistory.py index 609eebc81..be75e95ad 100644 --- a/src/steamship/agents/schema/chathistory.py +++ b/src/steamship/agents/schema/chathistory.py @@ -350,10 +350,12 @@ class ChatHistoryLoggingHandler(StreamHandler): chat_history: ChatHistory log_level: any streaming_opts: StreamingOpts + request_id: str def __init__( self, chat_history: ChatHistory, + request_id: str, log_level: any = logging.INFO, streaming_opts: Optional[StreamingOpts] = None, ): @@ -366,6 +368,7 @@ def __init__( self.streaming_opts = streaming_opts else: self.streaming_opts = StreamingOpts() + self.request_id = request_id def emit(self, record): if record.levelno < self.log_level: @@ -390,9 +393,16 @@ def _append_message(self, message_dict: dict, author_kind: str): message = message_dict.get("message", None) message_type = message_dict.get(AgentLogging.MESSAGE_TYPE, AgentLogging.MESSAGE) + req_id_tag = Tag( + kind="request-id", + name=self.request_id, + value={TagValueKey.STRING_VALUE: self.request_id}, + ) + if author_kind == AgentLogging.AGENT: return self.chat_history.append_agent_message( text=message, + tags=[req_id_tag], mime_type=MimeTypes.TXT, ) elif author_kind == AgentLogging.TOOL: @@ -404,7 +414,8 @@ def _append_message(self, message_dict: dict, author_kind: str): kind=TagKind.TOOL_STATUS_MESSAGE, name=message_type, value={TagValueKey.STRING_VALUE: message, "tool": tool_name}, - ) + ), + req_id_tag, ], mime_type=MimeTypes.TXT, ) @@ -417,7 +428,8 @@ def _append_message(self, message_dict: dict, author_kind: str): kind=TagKind.LLM_STATUS_MESSAGE, name=message_type, value={TagValueKey.STRING_VALUE: message, "llm": llm_name}, - ) + ), + req_id_tag, ], mime_type=MimeTypes.TXT, ) diff --git a/src/steamship/agents/schema/context.py b/src/steamship/agents/schema/context.py index dad86531e..854325d32 100644 --- a/src/steamship/agents/schema/context.py +++ b/src/steamship/agents/schema/context.py @@ -1,4 +1,5 @@ import logging +import uuid from typing import Any, Callable, Dict, List, Optional from steamship import Block, Steamship, Tag @@ -49,10 +50,14 @@ def id(self) -> str: """Caches all interations with LLMs within a Context. This provides a way to avoid duplicated calls to LLMs when within the same context.""" - def __init__(self, streaming_opts: Optional[StreamingOpts] = None): + request_id: str + """Identifier for the current request being handled by this context.""" + + def __init__(self, request_id: str, streaming_opts: Optional[StreamingOpts] = None): self.metadata = {} self.completed_steps = [] self.emit_funcs = [] + self.request_id = request_id # TODO: protect this? if streaming_opts is not None: self._streaming_opts = streaming_opts else: @@ -67,14 +72,18 @@ def get_or_create( use_llm_cache: Optional[bool] = False, use_action_cache: Optional[bool] = False, streaming_opts: Optional[StreamingOpts] = None, + request_id: Optional[str] = None, ): from steamship.agents.schema.chathistory import ChatHistory if streaming_opts is None: streaming_opts = StreamingOpts() + if request_id is None: + request_id = str(uuid.uuid4()) + history = ChatHistory.get_or_create(client, context_keys, tags, searchable=searchable) - context = AgentContext(streaming_opts=streaming_opts) + context = AgentContext(streaming_opts=streaming_opts, request_id=request_id) context.chat_history = history context.client = client @@ -97,7 +106,9 @@ def __enter__(self): if self._streaming_opts.stream_intermediate_events: self._chat_history_logger = ChatHistoryLoggingHandler( - chat_history=self.chat_history, streaming_opts=self._streaming_opts + chat_history=self.chat_history, + streaming_opts=self._streaming_opts, + request_id=self.request_id, ) logger = logging.getLogger() logger.addHandler(self._chat_history_logger) diff --git a/src/steamship/agents/service/agent_service.py b/src/steamship/agents/service/agent_service.py index 9bc245619..4b6524a94 100644 --- a/src/steamship/agents/service/agent_service.py +++ b/src/steamship/agents/service/agent_service.py @@ -285,6 +285,7 @@ def build_default_context(self, context_id: Optional[str] = None, **kwargs) -> A context = AgentContext.get_or_create( client=self.client, + request_id=self.client.config.request_id, context_keys={"id": f"{context_id}"}, use_llm_cache=use_llm_cache, use_action_cache=use_action_cache, diff --git a/src/steamship/data/block.py b/src/steamship/data/block.py index f7c4c7891..9501ff69a 100644 --- a/src/steamship/data/block.py +++ b/src/steamship/data/block.py @@ -274,6 +274,13 @@ def set_chat_id(self, chat_id: str): tag_kind=DocTag.CHAT, tag_name=ChatTag.CHAT_ID, string_value=chat_id ) + def set_request_id(self, request_id: Optional[str]): + if not request_id or len(request_id.strip()) == 0: + return + return self._one_time_set_tag( + tag_kind="request-id", tag_name=request_id, string_value=request_id + ) + @property def thread_id(self) -> Optional[str]: return get_tag_value_key( diff --git a/tests/steamship_tests/agents/test_agent_service.py b/tests/steamship_tests/agents/test_agent_service.py index ecaf18bee..4b7bf9bff 100644 --- a/tests/steamship_tests/agents/test_agent_service.py +++ b/tests/steamship_tests/agents/test_agent_service.py @@ -261,25 +261,87 @@ def test_async_prompt(client: Steamship): agent_service, ): context_id = "some_async_fun" - streaming_resp = agent_service.invoke( - "async_prompt", - prompt="who is the current president of the United States?", - context_id=context_id, - ) + try: + streaming_resp = agent_service.invoke( + "async_prompt", + prompt="who is the current president of the United States?", + context_id=context_id, + ) + except SteamshipError as error: + pytest.fail(f"failed request: {error}") assert streaming_resp is not None assert streaming_resp["file"] is not None assert streaming_resp["task"] is not None file = File(client=client, **(streaming_resp["file"])) - task = Task(client=client, **(streaming_resp["task"])) + streaming_task = Task(client=client, **(streaming_resp["task"])) original_len = len(file.blocks) - - task.wait() - assert task.state in [TaskState.succeeded, TaskState.failed] + req_id = streaming_task.request_id + + import sseclient + + # import requests + sse_source = f"{client.config.api_base}file/{file.id}/stream?tagKindFilter=request-id&tagNameFilter={req_id}" + print(f"\nSSE SOURCE: {sse_source}") + headers = { + "Accept": "text/event-stream", + "X-Workspace-Id": client.get_workspace().id, + "Authorization": f"Bearer {client.config.api_key.get_secret_value()}", + } + print(headers) + # sse_response = requests.get(sse_source, stream=True, headers=headers) + + # TODO: it is concerning that putting this block after a Task completes is the only way it seems to work + # import urllib3 + # timeout = urllib3.Timeout(connect=2.0, read=10.0) + # http = urllib3.PoolManager(timeout=timeout) + # sse_response = http.request('GET', sse_source, preload_content=False, headers=headers) + # + # sse_client = sseclient.SSEClient(sse_response) + # num_events = 0 + # try: + # for event in sse_client.events(): + # num_events += 1 + # + # # TODO: should we assert these are blockCreated events? + # print(event) + # except urllib3.exceptions.ReadTimeoutError: + # # TODO: this is required to get a termination on the streaming wait. + # # TODO: do we have a way to signal events have stopped? + # pass + + try: + streaming_task.wait_until_completed() + except SteamshipError as error: + pytest.fail(f"Task failed to complete: {error}") + + assert streaming_task.state in [TaskState.succeeded] + + # TODO: it is concerning that putting this block after a Task completes is the only way it seems to work + import urllib3 + + timeout = urllib3.Timeout(connect=2.0, read=10.0) + http = urllib3.PoolManager(timeout=timeout) + sse_response = http.request("GET", sse_source, preload_content=False, headers=headers) + + sse_client = sseclient.SSEClient(sse_response) + num_events = 0 + try: + for event in sse_client.events(): + num_events += 1 + + # TODO: should we assert these are blockCreated events? + print(event) + except urllib3.exceptions.ReadTimeoutError: + # TODO: this is required to get a termination on the streaming wait. + # TODO: do we have a way to signal events have stopped? + pass file.refresh() assert ( len(file.blocks) > original_len ), "File should have increased in size during AgentService execution" + + assert num_events == 13 From e8a9f8d999f332f28260f788a231619bbf14b00c Mon Sep 17 00:00:00 2001 From: Douglas Reid Date: Mon, 25 Sep 2023 10:40:29 -0700 Subject: [PATCH 3/4] clean up test --- .../agents/test_agent_service.py | 97 ++++++++++++------- 1 file changed, 60 insertions(+), 37 deletions(-) diff --git a/tests/steamship_tests/agents/test_agent_service.py b/tests/steamship_tests/agents/test_agent_service.py index 4b7bf9bff..58074ef99 100644 --- a/tests/steamship_tests/agents/test_agent_service.py +++ b/tests/steamship_tests/agents/test_agent_service.py @@ -1,3 +1,4 @@ +import json from typing import Any, List, Union import pytest @@ -8,6 +9,7 @@ from steamship import Block, File, Steamship, SteamshipError, Task, TaskState from steamship.agents.functional import FunctionsBasedAgent from steamship.agents.llms.openai import ChatOpenAI +from steamship.agents.logging import AgentLogging from steamship.agents.schema import Action, AgentContext, Tool from steamship.agents.service.agent_service import AgentService from steamship.data.tags.tag_constants import ChatTag, RoleTag, TagKind, TagValueKey @@ -251,7 +253,7 @@ def test_context_logging_to_chat_history_everything(client: Steamship): @pytest.mark.usefixtures("client") -def test_async_prompt(client: Steamship): +def test_async_prompt(client: Steamship): # noqa: C901 example_agent_service_path = ( SRC_PATH / "steamship" / "agents" / "examples" / "example_assistant.py" ) @@ -280,37 +282,17 @@ def test_async_prompt(client: Steamship): original_len = len(file.blocks) req_id = streaming_task.request_id + import requests import sseclient - # import requests - sse_source = f"{client.config.api_base}file/{file.id}/stream?tagKindFilter=request-id&tagNameFilter={req_id}" - print(f"\nSSE SOURCE: {sse_source}") + sse_source = f"{client.config.api_base}file/{file.id}/stream?tagKindFilter=request-id&tagNameFilter={req_id}&timeoutSeconds=30" + # print(f"\nSSE SOURCE: {sse_source}") headers = { "Accept": "text/event-stream", "X-Workspace-Id": client.get_workspace().id, "Authorization": f"Bearer {client.config.api_key.get_secret_value()}", } - print(headers) - # sse_response = requests.get(sse_source, stream=True, headers=headers) - - # TODO: it is concerning that putting this block after a Task completes is the only way it seems to work - # import urllib3 - # timeout = urllib3.Timeout(connect=2.0, read=10.0) - # http = urllib3.PoolManager(timeout=timeout) - # sse_response = http.request('GET', sse_source, preload_content=False, headers=headers) - # - # sse_client = sseclient.SSEClient(sse_response) - # num_events = 0 - # try: - # for event in sse_client.events(): - # num_events += 1 - # - # # TODO: should we assert these are blockCreated events? - # print(event) - # except urllib3.exceptions.ReadTimeoutError: - # # TODO: this is required to get a termination on the streaming wait. - # # TODO: do we have a way to signal events have stopped? - # pass + # print(headers) try: streaming_task.wait_until_completed() @@ -319,29 +301,70 @@ def test_async_prompt(client: Steamship): assert streaming_task.state in [TaskState.succeeded] - # TODO: it is concerning that putting this block after a Task completes is the only way it seems to work - import urllib3 - - timeout = urllib3.Timeout(connect=2.0, read=10.0) - http = urllib3.PoolManager(timeout=timeout) - sse_response = http.request("GET", sse_source, preload_content=False, headers=headers) + llm_prompt_event_count = 0 + function_selection_event = False + tool_execution_event = False + function_complete_event = False + assistant_chat_response_event = False + # TODO: it is concerning that putting this block after a Task completes is the only way it seems to work + sse_response = requests.get(sse_source, stream=True, headers=headers, timeout=45) sse_client = sseclient.SSEClient(sse_response) num_events = 0 try: + # sse event format: {'blockCreated': {'blockId': '', 'createdAt': '2023-09-25T16:12:54Z'}} for event in sse_client.events(): num_events += 1 - - # TODO: should we assert these are blockCreated events? - print(event) - except urllib3.exceptions.ReadTimeoutError: + block_creation_event = json.loads(event.data) + block_created = block_creation_event["blockCreated"] + block_id = block_created["blockId"] + block = Block.get(client=client, _id=block_id) + for t in block.tags: + match t.kind: + case TagKind.LLM_STATUS_MESSAGE: + if t.name == AgentLogging.PROMPT: + llm_prompt_event_count += 1 + case TagKind.FUNCTION_SELECTION: + if t.name == "SearchTool": + function_selection_event = True + case TagKind.TOOL_STATUS_MESSAGE: + tool_execution_event = True + case TagKind.ROLE: + if t.name == RoleTag.FUNCTION: + function_complete_event = True + case TagKind.CHAT: + if ( + t.name == ChatTag.ROLE + and t.value.get(TagValueKey.STRING_VALUE, "") == RoleTag.ASSISTANT + ): + assistant_chat_response_event = True + except requests.exceptions.ConnectionError as err: + sse_response.close() # TODO: this is required to get a termination on the streaming wait. # TODO: do we have a way to signal events have stopped? - pass + if "Read timed out." in str(err): + pass + else: + raise err + except Exception as err: + sse_response.close() + raise err + else: + sse_response.close() file.refresh() assert ( len(file.blocks) > original_len ), "File should have increased in size during AgentService execution" - assert num_events == 13 + assert num_events > 0, "Events should have been streamed during execution" + assert llm_prompt_event_count == 2, ( + "At least 3 llm prompts should have happened (first for tool selection, " + "second for generating final answer)" + ) + assert function_selection_event is True, "SearchTool should have been selected" + assert tool_execution_event is True, "SearchTool should log a status message" + assert function_complete_event is True, "SearchTool should return a response" + assert ( + assistant_chat_response_event is True + ), "Agent should have sent the assistant chat response" From 832eb1fe72cd288d1bbf64f42c0c26f357c10cd5 Mon Sep 17 00:00:00 2001 From: Douglas Reid Date: Mon, 25 Sep 2023 21:36:47 -0700 Subject: [PATCH 4/4] full test with in-flight streaming --- .../agents/test_agent_service.py | 178 +++++++++++------- 1 file changed, 113 insertions(+), 65 deletions(-) diff --git a/tests/steamship_tests/agents/test_agent_service.py b/tests/steamship_tests/agents/test_agent_service.py index 58074ef99..927cb0d85 100644 --- a/tests/steamship_tests/agents/test_agent_service.py +++ b/tests/steamship_tests/agents/test_agent_service.py @@ -1,7 +1,9 @@ import json +import time from typing import Any, List, Union import pytest +import requests from pydantic.fields import PrivateAttr from steamship_tests import SRC_PATH from steamship_tests.utils.deployables import deploy_package @@ -27,7 +29,6 @@ def _blocks_from_invoke(client: Steamship, potential_blocks) -> List[Block]: @pytest.mark.usefixtures("client") def test_example_with_caching_service(client: Steamship): - # TODO(dougreid): replace the example agent with fake/free/fast tools to minimize test time / costs? example_caching_agent_path = ( @@ -87,7 +88,6 @@ def test_example_with_caching_service(client: Steamship): class FakeUncachableTool(Tool): - name = "FakeUncacheableTool" human_description = "Fake tool" agent_description = "Ignored" @@ -252,8 +252,16 @@ def test_context_logging_to_chat_history_everything(client: Steamship): assert has_status_message(chat_history.messages, RoleTag.TOOL) -@pytest.mark.usefixtures("client") -def test_async_prompt(client: Steamship): # noqa: C901 +@pytest.fixture() +def close_clients(): + to_close = [] + yield to_close + for item in to_close: + item.close() + + +@pytest.mark.usefixtures("client", "close_clients") +def test_async_prompt(client: Steamship, close_clients): # noqa: C901 example_agent_service_path = ( SRC_PATH / "steamship" / "agents" / "examples" / "example_assistant.py" ) @@ -280,27 +288,16 @@ def test_async_prompt(client: Steamship): # noqa: C901 streaming_task = Task(client=client, **(streaming_resp["task"])) original_len = len(file.blocks) - req_id = streaming_task.request_id - import requests - import sseclient + while streaming_task.state in [TaskState.waiting]: + # tight loop to check on waiting status of Task + # we only want to try to stream once a Task **starts** + time.sleep(0.1) + streaming_task.refresh() - sse_source = f"{client.config.api_base}file/{file.id}/stream?tagKindFilter=request-id&tagNameFilter={req_id}&timeoutSeconds=30" - # print(f"\nSSE SOURCE: {sse_source}") - headers = { - "Accept": "text/event-stream", - "X-Workspace-Id": client.get_workspace().id, - "Authorization": f"Bearer {client.config.api_key.get_secret_value()}", - } - # print(headers) - - try: - streaming_task.wait_until_completed() - except SteamshipError as error: - pytest.fail(f"Task failed to complete: {error}") - - assert streaming_task.state in [TaskState.succeeded] + assert streaming_task.state in [TaskState.running] + block_ids_seen = [] llm_prompt_event_count = 0 function_selection_event = False tool_execution_event = False @@ -308,58 +305,47 @@ def test_async_prompt(client: Steamship): # noqa: C901 assistant_chat_response_event = False # TODO: it is concerning that putting this block after a Task completes is the only way it seems to work - sse_response = requests.get(sse_source, stream=True, headers=headers, timeout=45) - sse_client = sseclient.SSEClient(sse_response) num_events = 0 - try: - # sse event format: {'blockCreated': {'blockId': '', 'createdAt': '2023-09-25T16:12:54Z'}} - for event in sse_client.events(): - num_events += 1 - block_creation_event = json.loads(event.data) - block_created = block_creation_event["blockCreated"] - block_id = block_created["blockId"] - block = Block.get(client=client, _id=block_id) - for t in block.tags: - match t.kind: - case TagKind.LLM_STATUS_MESSAGE: - if t.name == AgentLogging.PROMPT: - llm_prompt_event_count += 1 - case TagKind.FUNCTION_SELECTION: - if t.name == "SearchTool": - function_selection_event = True - case TagKind.TOOL_STATUS_MESSAGE: - tool_execution_event = True - case TagKind.ROLE: - if t.name == RoleTag.FUNCTION: - function_complete_event = True - case TagKind.CHAT: - if ( - t.name == ChatTag.ROLE - and t.value.get(TagValueKey.STRING_VALUE, "") == RoleTag.ASSISTANT - ): - assistant_chat_response_event = True - except requests.exceptions.ConnectionError as err: - sse_response.close() - # TODO: this is required to get a termination on the streaming wait. - # TODO: do we have a way to signal events have stopped? - if "Read timed out." in str(err): - pass - else: - raise err - except Exception as err: - sse_response.close() - raise err - else: - sse_response.close() + # sse event format: {'blockCreated': {'blockId': '', 'createdAt': '2023-09-25T16:12:54Z'}} + for event in events_while_running(client, streaming_task, file): + # TODO: it seems like event ids aren't consistent + block_creation_event = json.loads(event.data) + block_created = block_creation_event["blockCreated"] + block_id = block_created["blockId"] + if block_id in block_ids_seen: + continue + block_ids_seen.append(block_id) + num_events += 1 + block = Block.get(client=client, _id=block_id) + for t in block.tags: + match t.kind: + case TagKind.LLM_STATUS_MESSAGE: + if t.name == AgentLogging.PROMPT: + llm_prompt_event_count += 1 + case TagKind.FUNCTION_SELECTION: + if t.name == "SearchTool": + function_selection_event = True + case TagKind.TOOL_STATUS_MESSAGE: + tool_execution_event = True + case TagKind.ROLE: + if t.name == RoleTag.FUNCTION: + function_complete_event = True + case TagKind.CHAT: + if ( + t.name == ChatTag.ROLE + and t.value.get(TagValueKey.STRING_VALUE, "") == RoleTag.ASSISTANT + ): + assistant_chat_response_event = True file.refresh() assert ( len(file.blocks) > original_len ), "File should have increased in size during AgentService execution" + print(f"num events: {num_events}") assert num_events > 0, "Events should have been streamed during execution" assert llm_prompt_event_count == 2, ( - "At least 3 llm prompts should have happened (first for tool selection, " + "At least 2 llm prompts should have happened (first for tool selection, " "second for generating final answer)" ) assert function_selection_event is True, "SearchTool should have been selected" @@ -368,3 +354,65 @@ def test_async_prompt(client: Steamship): # noqa: C901 assert ( assistant_chat_response_event is True ), "Agent should have sent the assistant chat response" + # assert False + + +def events_while_running(client: Steamship, task: Task, file: File): + req_id = task.request_id + while task.state in [TaskState.running]: + yields = 0 + # NOTE: I'm convinced that this generator of generators approach is not correct, but... + event_gen = events_for_file(client, file.id, req_id) + try: + for event in event_gen: + yields += 1 + yield event + except StopIteration: + # not ready to stream, or done streaming. + print("stop iteration") + pass + print(f"total yields: {yields}") + # This is not ideal, but it at least should make sure we get **all** events + task.refresh() + print(f"task is complete: {task.state}") + + +def events_for_file(client: Steamship, file_id: str, req_id: str): + print("getting events for file.") + import sseclient + + sse_source = f"{client.config.api_base}file/{file_id}/stream?tagKindFilter=request-id&tagNameFilter={req_id}&timeoutSeconds=30" + headers = { + "Accept": "text/event-stream", + "X-Workspace-Id": client.get_workspace().id, + "Authorization": f"Bearer {client.config.api_key.get_secret_value()}", + } + + sse_response = requests.get(sse_source, stream=True, headers=headers, timeout=45) + sse_client = sseclient.SSEClient(sse_response) + yields = 0 + try: + for event in sse_client.events(): + yields += 1 + print(f"--> yield: {yields}") + yield event + except requests.exceptions.ConnectionError as err: + if "Read timed out." in str(err): + print("-- timeout") + pass + else: + sse_client.close() + sse_response.close() + raise err + except StopIteration: + print("-- stop iteration") + sse_client.close() + sse_response.close() + except Exception as err: + sse_client.close() + sse_response.close() + raise err + else: + print("-- successful close of stream.") + sse_client.close() + sse_response.close()