From f97e411317be26435c0f8c9522d3252d2234150f Mon Sep 17 00:00:00 2001 From: Douglas Reid <douglas-reid@users.noreply.github.com> Date: Fri, 7 Jul 2023 15:27:42 -0400 Subject: [PATCH] feat(agents): add functions-based agent to SDK (#454) This is meant to provide a basic implementation of an agent that works with tools using OpenAI's functions extension as the mechanism for tools selection. This is an initial implementation that will likely need refinement as use cases are explored. --------- Co-authored-by: Douglas Reid <doug@steamship.com> --- .../examples/my_functions_based_assistant.py | 75 +++++++++++ src/steamship/agents/functional/__init__.py | 7 + .../agents/functional/functions_based.py | 78 +++++++++++ .../agents/functional/output_parser.py | 89 +++++++++++++ src/steamship/agents/llms/openai.py | 34 ++++- src/steamship/agents/react/output_parser.py | 2 +- src/steamship/agents/schema/__init__.py | 6 +- src/steamship/agents/schema/action.py | 24 +++- src/steamship/agents/schema/agent.py | 13 +- src/steamship/agents/schema/chathistory.py | 3 + src/steamship/agents/schema/llm.py | 13 ++ src/steamship/agents/service/agent_service.py | 3 + .../agents/tools/image_generation/dalle.py | 2 +- src/steamship/data/block.py | 4 +- src/steamship/data/tags/tag_constants.py | 3 + .../agents/test_functions_based_agent.py | 125 ++++++++++++++++++ 16 files changed, 472 insertions(+), 9 deletions(-) create mode 100644 src/steamship/agents/examples/my_functions_based_assistant.py create mode 100644 src/steamship/agents/functional/__init__.py create mode 100644 src/steamship/agents/functional/functions_based.py create mode 100644 src/steamship/agents/functional/output_parser.py create mode 100644 tests/steamship_tests/agents/test_functions_based_agent.py diff --git a/src/steamship/agents/examples/my_functions_based_assistant.py b/src/steamship/agents/examples/my_functions_based_assistant.py new file mode 100644 index 000000000..0220c0673 --- /dev/null +++ b/src/steamship/agents/examples/my_functions_based_assistant.py @@ -0,0 +1,75 @@ +import uuid +from typing import List, Optional + +from steamship import Block +from steamship.agents.functional import FunctionsBasedAgent +from steamship.agents.llms.openai import ChatOpenAI +from steamship.agents.schema import AgentContext +from steamship.agents.schema.context import Metadata +from steamship.agents.schema.message_selectors import MessageWindowMessageSelector +from steamship.agents.service.agent_service import AgentService +from steamship.agents.tools.image_generation import DalleTool +from steamship.agents.tools.search import SearchTool +from steamship.invocable import post +from steamship.utils.repl import AgentREPL + + +class MyFunctionsBasedAssistant(AgentService): + """MyFunctionsBasedAssistant is an example AgentService that exposes a single test endpoint + for trying out Agent-based invocations. It is configured with two simple Tools + to provide an overview of the types of tasks it can accomplish (here, search + and image generation).""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._agent = FunctionsBasedAgent( + tools=[ + SearchTool(), + DalleTool(), + ], + llm=ChatOpenAI(self.client, temperature=0), + conversation_memory=MessageWindowMessageSelector(k=2), + ) + + @post("prompt") + def prompt(self, prompt: str, context_id: Optional[uuid.UUID] = None) -> str: + """Run an agent with the provided text as the input.""" + + # AgentContexts serve to allow the AgentService to run agents + # with appropriate information about the desired tasking. + # Here, we use the passed in context (or a new context) for the prompt, + # and append the prompt to the message history stored in the context. + if not context_id: + context_id = uuid.uuid4() + context = AgentContext.get_or_create(self.client, {"id": f"{context_id}"}) + context.chat_history.append_user_message(prompt) + + # AgentServices provide an emit function hook to access the output of running + # agents and tools. The emit functions fire at after the supplied agent emits + # a "FinishAction". + # + # Here, we show one way of accessing the output in a synchronous fashion. An + # alternative way would be to access the final Action in the `context.completed_steps` + # after the call to `run_agent()`. + output = "" + + def sync_emit(blocks: List[Block], meta: Metadata): + nonlocal output + block_text = "\n".join( + [b.text if b.is_text() else f"({b.mime_type}: {b.id})" for b in blocks] + ) + output += block_text + + context.emit_funcs.append(sync_emit) + self.run_agent(self._agent, context) + context.chat_history.append_assistant_message(output) + return output + + +if __name__ == "__main__": + # AgentREPL provides a mechanism for local execution of an AgentService method. + # This is used for simplified debugging as agents and tools are developed and + # added. + AgentREPL(MyFunctionsBasedAssistant, "prompt", agent_package_config={}).run( + context_id=uuid.uuid4() + ) diff --git a/src/steamship/agents/functional/__init__.py b/src/steamship/agents/functional/__init__.py new file mode 100644 index 000000000..736cf60dd --- /dev/null +++ b/src/steamship/agents/functional/__init__.py @@ -0,0 +1,7 @@ +from .functions_based import FunctionsBasedAgent +from .output_parser import FunctionsBasedOutputParser + +__all__ = [ + "FunctionsBasedAgent", + "FunctionsBasedOutputParser", +] diff --git a/src/steamship/agents/functional/functions_based.py b/src/steamship/agents/functional/functions_based.py new file mode 100644 index 000000000..4b9aaafef --- /dev/null +++ b/src/steamship/agents/functional/functions_based.py @@ -0,0 +1,78 @@ +from typing import List + +from steamship import Block +from steamship.agents.functional.output_parser import FunctionsBasedOutputParser +from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, Tool +from steamship.data.tags.tag_constants import RoleTag + + +class FunctionsBasedAgent(ChatAgent): + """Selects actions for AgentService based on OpenAI Function style LLM Prompting.""" + + PROMPT = """You are a helpful AI assistant. + +NOTE: Some functions return images, video, and audio files. These multimedia files will be represented in messages as +UUIDs for Steamship Blocks. When responding directly to a user, you SHOULD print the Steamship Blocks for the images, +video, or audio as follows: `Block(UUID for the block)`. + +Example response for a request that generated an image: +Here is the image you requested: Block(288A2CA1-4753-4298-9716-53C1E42B726B). + +Only use the functions you have been provided with.""" + + def __init__(self, tools: List[Tool], llm: ChatLLM, **kwargs): + super().__init__( + output_parser=FunctionsBasedOutputParser(tools=tools), llm=llm, tools=tools, **kwargs + ) + + def next_action(self, context: AgentContext) -> Action: + messages = [] + + # get system messsage + system_message = Block(text=self.PROMPT) + system_message.set_chat_role(RoleTag.SYSTEM) + messages.append(system_message) + + messages_from_memory = [] + # get prior conversations + if context.chat_history.is_searchable(): + messages_from_memory.extend( + context.chat_history.search(context.chat_history.last_user_message.text, k=3) + .wait() + .to_ranked_blocks() + ) + + # TODO(dougreid): we need a way to threshold message inclusion, especially for small contexts + + # remove the actual prompt from the semantic search (it will be an exact match) + messages_from_memory = [ + msg + for msg in messages_from_memory + if msg.id != context.chat_history.last_user_message.id + ] + + # get most recent context + messages_from_memory.extend(context.chat_history.select_messages(self.message_selector)) + + # de-dupe the messages from memory + ids = [] + for msg in messages_from_memory: + if msg.id not in ids: + messages.append(msg) + ids.append(msg.id) + + # TODO(dougreid): sort by dates? we SHOULD ensure ordering, given semantic search + + # put the user prompt in the appropriate message location + # this should happen BEFORE any agent/assistant messages related to tool selection + messages.append(context.chat_history.last_user_message) + + # get completed steps + actions = context.completed_steps + for action in actions: + messages.extend(action.to_chat_messages()) + + # call chat() + output_blocks = self.llm.chat(messages=messages, tools=self.tools) + + return self.output_parser.parse(output_blocks[0].text, context) diff --git a/src/steamship/agents/functional/output_parser.py b/src/steamship/agents/functional/output_parser.py new file mode 100644 index 000000000..22a6c718b --- /dev/null +++ b/src/steamship/agents/functional/output_parser.py @@ -0,0 +1,89 @@ +import json +import re +from typing import Dict, List, Optional + +from steamship import Block, MimeTypes, Steamship +from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool +from steamship.data.tags.tag_constants import RoleTag + + +# TODO(dougreid): extract shared bits from this and the ReACT output parser into a utility? +class FunctionsBasedOutputParser(OutputParser): + + tools_lookup_dict: Optional[Dict[str, Tool]] = None + + def __init__(self, **kwargs): + tools_lookup_dict = {tool.name: tool for tool in kwargs.pop("tools", [])} + super().__init__(tools_lookup_dict=tools_lookup_dict, **kwargs) + + def _extract_action_from_function_call(self, text: str, context: AgentContext) -> Action: + wrapper = json.loads(text) + fc = wrapper.get("function_call") + name = fc.get("name", "") + tool = self.tools_lookup_dict.get(name, None) + if tool is None: + raise RuntimeError( + f"Could not find tool from function call: `{name}`. Known tools: {self.tools_lookup_dict.keys()}" + ) + + input_blocks = [] + arguments = fc.get("arguments", "") + args = json.loads(arguments) + # TODO(dougreid): validation and error handling? + + if text := args.get("text"): + input_blocks.append(Block(text=text, mime_type=MimeTypes.TXT)) + else: + uuid = args.get("uuid") + input_blocks.append(Block.get(context.client, id=uuid)) + + return Action(tool=tool, input=input_blocks, context=context) + + @staticmethod + def _blocks_from_text(client: Steamship, text: str) -> List[Block]: + last_response = text.split("AI:")[-1].strip() + + block_id_regex = r"(?:(?:\[|\()?Block)?\(?([A-F0-9]{8}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{12})\)?(?:(\]|\)))?" + remaining_text = last_response + result_blocks: List[Block] = [] + while remaining_text is not None and len(remaining_text) > 0: + match = re.search(block_id_regex, remaining_text) + if match: + pre_block_text = FunctionsBasedOutputParser._remove_block_prefix( + candidate=remaining_text[0 : match.start()] + ) + if len(pre_block_text) > 0: + result_blocks.append(Block(text=pre_block_text)) + result_blocks.append(Block.get(client, _id=match.group(1))) + remaining_text = FunctionsBasedOutputParser._remove_block_suffix( + remaining_text[match.end() :] + ) + else: + result_blocks.append(Block(text=remaining_text)) + remaining_text = "" + + return result_blocks + + @staticmethod + def _remove_block_prefix(candidate: str) -> str: + removed = candidate + if removed.endswith("(Block") or removed.endswith("[Block"): + removed = removed[len("Block") + 1 :] + elif removed.endswith("Block"): + removed = removed[len("Block") :] + return removed + + @staticmethod + def _remove_block_suffix(candidate: str) -> str: + removed = candidate + if removed.startswith(")") or removed.endswith("]"): + removed = removed[1:] + return removed + + def parse(self, text: str, context: AgentContext) -> Action: + if "function_call" in text: + return self._extract_action_from_function_call(text, context) + + finish_block = Block(text=text) + finish_block.set_chat_role(RoleTag.ASSISTANT) + return FinishAction(output=[finish_block], context=context) diff --git a/src/steamship/agents/llms/openai.py b/src/steamship/agents/llms/openai.py index 57d944b0c..fb1b15dfc 100644 --- a/src/steamship/agents/llms/openai.py +++ b/src/steamship/agents/llms/openai.py @@ -1,7 +1,7 @@ from typing import List, Optional -from steamship import Block, PluginInstance, Steamship -from steamship.agents.schema import LLM +from steamship import Block, File, PluginInstance, Steamship +from steamship.agents.schema import LLM, ChatLLM, Tool PLUGIN_HANDLE = "gpt-4" @@ -38,3 +38,33 @@ def complete(self, prompt: str, stop: Optional[str] = None) -> List[Block]: action_task = self.generator.generate(text=prompt, options=options) action_task.wait() return action_task.output.blocks + + +class ChatOpenAI(ChatLLM, OpenAI): + """ChatLLM that uses Steamship's OpenAI plugin to generate chat completions.""" + + def __init__( + self, client, model_name: str = "gpt-4-0613", temperature: float = 0.4, *args, **kwargs + ): + """Create a new instance. + + Valid model names are: + - gpt-4 + - gpt-4-0613 + """ + super().__init__(client=client, model_name=model_name, *args, **kwargs) + + def chat(self, messages: List[Block], tools: Optional[List[Tool]]) -> List[Block]: + # TODO(dougreid): this feels icky. find a better way? + temp_file = File.create(client=self.client, blocks=messages) + + options = {} + if len(tools) > 0: + functions = [] + for tool in tools: + functions.append(tool.as_openai_function()) + options["functions"] = functions + + tool_selection_task = self.generator.generate(input_file_id=temp_file.id, options=options) + tool_selection_task.wait() + return tool_selection_task.output.blocks diff --git a/src/steamship/agents/react/output_parser.py b/src/steamship/agents/react/output_parser.py index 9d1c71f41..6c7e4e017 100644 --- a/src/steamship/agents/react/output_parser.py +++ b/src/steamship/agents/react/output_parser.py @@ -12,7 +12,7 @@ class ReACTOutputParser(OutputParser): tools_lookup_dict: Optional[Dict[str, Tool]] = None def __init__(self, **kwargs): - tools_lookup_dict = {tool.name: tool for tool in kwargs.pop("tools", None)} + tools_lookup_dict = {tool.name: tool for tool in kwargs.pop("tools", [])} super().__init__(tools_lookup_dict=tools_lookup_dict, **kwargs) def parse(self, text: str, context: AgentContext) -> Action: diff --git a/src/steamship/agents/schema/__init__.py b/src/steamship/agents/schema/__init__.py index a7e7509d2..3f70165bf 100644 --- a/src/steamship/agents/schema/__init__.py +++ b/src/steamship/agents/schema/__init__.py @@ -1,8 +1,8 @@ from .action import Action, FinishAction -from .agent import Agent, LLMAgent +from .agent import Agent, ChatAgent, LLMAgent from .chathistory import ChatHistory from .context import AgentContext, EmitFunc, Metadata -from .llm import LLM +from .llm import LLM, ChatLLM from .output_parser import OutputParser from .tool import Tool @@ -10,6 +10,8 @@ "Action", "Agent", "AgentContext", + "ChatLLM", + "ChatAgent", "EmitFunc", "FinishAction", "Metadata", diff --git a/src/steamship/agents/schema/action.py b/src/steamship/agents/schema/action.py index 13054733a..73ce8762b 100644 --- a/src/steamship/agents/schema/action.py +++ b/src/steamship/agents/schema/action.py @@ -2,8 +2,10 @@ from pydantic import BaseModel -from steamship import Block, Task +from steamship import Block, Tag, Task from steamship.agents.schema.tool import AgentContext, Tool +from steamship.data import TagKind +from steamship.data.tags.tag_constants import RoleTag class Action(BaseModel): @@ -21,6 +23,26 @@ class Action(BaseModel): output: Optional[List[Block]] = [] """Any direct output produced by the Tool.""" + def to_chat_messages(self) -> List[Block]: + tags = [ + Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION), + Tag(kind="name", name=self.tool.name), + ] + blocks = [] + for block in self.output: + # TODO(dougreid): should we revisit as_llm_input? we might need only the UUID... + blocks.append( + Block( + text=block.as_llm_input(exclude_block_wrapper=True), + tags=tags, + mime_type=block.mime_type, + ) + ) + + # TODO(dougreid): revisit when have multiple output functions. + # Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation. + return blocks + class AgentTool(Tool): def run(self, tool_input: List[Block], context: AgentContext) -> Union[List[Block], Task[Any]]: diff --git a/src/steamship/agents/schema/agent.py b/src/steamship/agents/schema/agent.py index c413a1b92..5c5ce658d 100644 --- a/src/steamship/agents/schema/agent.py +++ b/src/steamship/agents/schema/agent.py @@ -6,7 +6,7 @@ from steamship import Block from steamship.agents.schema.action import Action from steamship.agents.schema.context import AgentContext -from steamship.agents.schema.llm import LLM +from steamship.agents.schema.llm import LLM, ChatLLM from steamship.agents.schema.message_selectors import MessageSelector, NoMessages from steamship.agents.schema.output_parser import OutputParser from steamship.agents.schema.tool import Tool @@ -57,4 +57,15 @@ def messages_to_prompt_history(messages: List[Block]) -> str: as_strings.append(f"System: {block.text}") elif role == RoleTag.AGENT: as_strings.append(f"Agent: {block.text}") + elif role == RoleTag.FUNCTION: + as_strings.append(f"Function: {block.text}") return "\n".join(as_strings) + + +class ChatAgent(Agent, ABC): + """ChatAgents choose next actions for an AgentService based on chat-based interactions with an LLM.""" + + llm: ChatLLM + + output_parser: OutputParser + """Utility responsible for converting LLM output into Actions""" diff --git a/src/steamship/agents/schema/chathistory.py b/src/steamship/agents/schema/chathistory.py index ae6214d6c..c3277e9f0 100644 --- a/src/steamship/agents/schema/chathistory.py +++ b/src/steamship/agents/schema/chathistory.py @@ -227,3 +227,6 @@ def search(self, text: str, k=None) -> Task[SearchResults]: if self.embedding_index is None: raise SteamshipError("This ChatHistory has no embedding index and is not searchable.") return self.embedding_index.search(text, k) + + def is_searchable(self) -> bool: + return self.embedding_index is not None diff --git a/src/steamship/agents/schema/llm.py b/src/steamship/agents/schema/llm.py index d96ed6698..6a1ebf367 100644 --- a/src/steamship/agents/schema/llm.py +++ b/src/steamship/agents/schema/llm.py @@ -4,6 +4,7 @@ from pydantic.main import BaseModel from steamship import Block +from steamship.agents.schema.tool import Tool class LLM(BaseModel, ABC): @@ -15,3 +16,15 @@ class LLM(BaseModel, ABC): def complete(self, prompt: str, stop: Optional[str] = None) -> List[Block]: """Completes the provided prompt, stopping when the stop sequeunce is found.""" pass + + +# TODO(dougreid): should LLM and ConversationalLLM share a common parent? +class ChatLLM(BaseModel, ABC): + """ChatLLM wraps large language model-based backends that use a chat completion style interation. + + They may be used with Agents in Action selection, or for direct prompt completion.""" + + @abstractmethod + def chat(self, messages: List[Block], tools: Optional[List[Tool]]) -> List[Block]: + """Sends the set of chat messages to the LLM, returning the next part of the conversation""" + pass diff --git a/src/steamship/agents/service/agent_service.py b/src/steamship/agents/service/agent_service.py index bf348dd55..30c100144 100644 --- a/src/steamship/agents/service/agent_service.py +++ b/src/steamship/agents/service/agent_service.py @@ -43,6 +43,9 @@ def run_action(self, action: Action, context: AgentContext): context.completed_steps.append(action) def run_agent(self, agent: Agent, context: AgentContext): + # first, clear any prior agent steps from set of completed steps + # this will allow the agent to select tools/dispatch actions based on a new context + context.completed_steps = [] action = agent.next_action(context=context) while not isinstance(action, FinishAction): # TODO: Arrive at a solid design for the details of this structured log object diff --git a/src/steamship/agents/tools/image_generation/dalle.py b/src/steamship/agents/tools/image_generation/dalle.py index dddbf240d..6d3cc7802 100644 --- a/src/steamship/agents/tools/image_generation/dalle.py +++ b/src/steamship/agents/tools/image_generation/dalle.py @@ -12,7 +12,7 @@ class DalleTool(ImageGeneratorTool): name: str = "DalleTool" human_description: str = "Generates an image from text." agent_description = ( - "Used to generate images from text prompts. Only use if the user has asked directly for an " + "Used to generate still images from text prompts. Only use if the user has asked directly for an " "image. When using this tool, the input should be a plain text string that describes, " "in detail, the desired image." ) diff --git a/src/steamship/data/block.py b/src/steamship/data/block.py index d7acac8ac..ade840b81 100644 --- a/src/steamship/data/block.py +++ b/src/steamship/data/block.py @@ -293,10 +293,12 @@ def _one_time_set_tag(self, tag_kind: str, tag_name: str, string_value: str): self.tags.append(tag) - def as_llm_input(self) -> str: + def as_llm_input(self, exclude_block_wrapper: Optional[bool] = False) -> str: if self.is_text(): return self.text else: + if exclude_block_wrapper: + return f"{self.id}" return f"Block({self.id})" diff --git a/src/steamship/data/tags/tag_constants.py b/src/steamship/data/tags/tag_constants.py index 4342702a8..caf3f904e 100644 --- a/src/steamship/data/tags/tag_constants.py +++ b/src/steamship/data/tags/tag_constants.py @@ -253,6 +253,9 @@ class RoleTag(str, Enum): # This block's content was created by a non-human agent participating in the chat AGENT = "agent" + # This block was created by a Tool, as selected by an OpenAI Function call + FUNCTION = "function" + class ChatTag(str, Enum): """A set of `name` constants for Tags with a `kind` of `TagKind.CHAT`.""" diff --git a/tests/steamship_tests/agents/test_functions_based_agent.py b/tests/steamship_tests/agents/test_functions_based_agent.py new file mode 100644 index 000000000..8c108c74b --- /dev/null +++ b/tests/steamship_tests/agents/test_functions_based_agent.py @@ -0,0 +1,125 @@ +import pytest + +from steamship import Block, Steamship +from steamship.agents.functional import FunctionsBasedAgent +from steamship.agents.llms.openai import ChatOpenAI +from steamship.agents.schema import AgentContext, FinishAction +from steamship.agents.schema.message_selectors import MessageWindowMessageSelector +from steamship.agents.tools.image_generation import DalleTool +from steamship.agents.tools.search import SearchTool + + +@pytest.mark.usefixtures("client") +def test_functions_based_agent_with_no_tools_and_no_memory(client: Steamship): + agent = FunctionsBasedAgent(tools=[], llm=ChatOpenAI(client, temperature=0)) + + ctx_keys = {"id": "testing-foo"} + ctx = AgentContext.get_or_create(client=client, context_keys=ctx_keys, searchable=False) + ctx.chat_history.append_user_message("what should I eat for dinner?") + + action = agent.next_action(context=ctx) + assert isinstance(action, FinishAction) + + +@pytest.mark.usefixtures("client") +def test_functions_based_agent_single_tool_selection_with_no_memory(client: Steamship): + agent = FunctionsBasedAgent( + tools=[SearchTool(), DalleTool()], llm=ChatOpenAI(client, temperature=0) + ) + + ctx_keys = {"id": "testing-foo"} + ctx = AgentContext.get_or_create(client=client, context_keys=ctx_keys, searchable=False) + ctx.chat_history.append_user_message("who won the men's french open in 2023?") + + action = agent.next_action(context=ctx) + assert not isinstance(action, FinishAction) + assert action.tool.name == "SearchTool" + + +@pytest.mark.usefixtures("client") +def test_functions_based_agent_multimodal_tool_selection_with_no_memory(client: Steamship): + agent = FunctionsBasedAgent( + tools=[SearchTool(), DalleTool()], llm=ChatOpenAI(client, temperature=0) + ) + + ctx_keys = {"id": "testing-foo"} + ctx = AgentContext.get_or_create(client=client, context_keys=ctx_keys, searchable=False) + ctx.chat_history.append_user_message("paint a picture of a cat in a silly hat") + + action = agent.next_action(context=ctx) + assert not isinstance(action, FinishAction) + assert action.tool.name == "DalleTool" + + +@pytest.mark.usefixtures("client") +def test_functions_based_agent_no_appropriate_tool_with_no_memory(client: Steamship): + agent = FunctionsBasedAgent( + tools=[SearchTool(), DalleTool()], llm=ChatOpenAI(client, temperature=0) + ) + + ctx_keys = {"id": "testing-foo"} + ctx = AgentContext.get_or_create(client=client, context_keys=ctx_keys, searchable=False) + ctx.chat_history.append_user_message( + "make a 10 second movie about a hamburger that learns to talk" + ) + + action = agent.next_action(context=ctx) + assert isinstance(action, FinishAction) + + +@pytest.mark.usefixtures("client") +def test_functions_based_agent_tool_chaining_without_memory(client: Steamship): + agent = FunctionsBasedAgent( + tools=[SearchTool(), DalleTool()], + llm=ChatOpenAI(client, temperature=0), + ) + + ctx_keys = {"id": "testing-foo"} + ctx = AgentContext.get_or_create(client=client, context_keys=ctx_keys, searchable=False) + ctx.chat_history.append_user_message( + "search to find the first president of the United States and then paint a picture of them" + ) + + action = agent.next_action(context=ctx) + assert not isinstance(action, FinishAction) + assert action.tool.name == "SearchTool" + + action.output.append(Block(text="George Washington")) + ctx.completed_steps.append(action) + + second_action = agent.next_action(context=ctx) + assert not isinstance(second_action, FinishAction) + assert second_action.tool.name == "DalleTool" + + +@pytest.mark.usefixtures("client") +def test_functions_based_agent_tools_with_memory(client: Steamship): + agent = FunctionsBasedAgent( + tools=[SearchTool(), DalleTool()], + llm=ChatOpenAI(client, temperature=0), + message_selector=MessageWindowMessageSelector(k=2), + ) + + ctx_keys = {"id": "testing-foo"} + ctx = AgentContext.get_or_create(client=client, context_keys=ctx_keys) + ctx.chat_history.append_user_message("who is the president of Taiwan?") + + action = agent.next_action(context=ctx) + assert not isinstance(action, FinishAction) + assert action.tool.name == "SearchTool" + + ctx.chat_history.append_assistant_message("Tsai Ing-wen") + ctx.completed_steps = [] + + ctx.chat_history.append_user_message("draw them standing at a podium") + + second_action = agent.next_action(context=ctx) + assert not isinstance(second_action, FinishAction) + assert second_action.tool.name == "DalleTool" + + found = False + for block in second_action.input: + if "Tsai Ing-wen" in block.text: + found = True + + assert found