-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(agents): add functions-based agent to SDK (#454)
This is meant to provide a basic implementation of an agent that works with tools using OpenAI's functions extension as the mechanism for tools selection. This is an initial implementation that will likely need refinement as use cases are explored. --------- Co-authored-by: Douglas Reid <[email protected]>
- Loading branch information
1 parent
9457ed8
commit f97e411
Showing
16 changed files
with
472 additions
and
9 deletions.
There are no files selected for viewing
75 changes: 75 additions & 0 deletions
75
src/steamship/agents/examples/my_functions_based_assistant.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import uuid | ||
from typing import List, Optional | ||
|
||
from steamship import Block | ||
from steamship.agents.functional import FunctionsBasedAgent | ||
from steamship.agents.llms.openai import ChatOpenAI | ||
from steamship.agents.schema import AgentContext | ||
from steamship.agents.schema.context import Metadata | ||
from steamship.agents.schema.message_selectors import MessageWindowMessageSelector | ||
from steamship.agents.service.agent_service import AgentService | ||
from steamship.agents.tools.image_generation import DalleTool | ||
from steamship.agents.tools.search import SearchTool | ||
from steamship.invocable import post | ||
from steamship.utils.repl import AgentREPL | ||
|
||
|
||
class MyFunctionsBasedAssistant(AgentService):
    """Example AgentService exposing a single test endpoint for Agent-based invocations.

    The agent is wired up with two simple Tools (search and image generation)
    to give an overview of the kinds of tasks it can accomplish.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        example_tools = [
            SearchTool(),
            DalleTool(),
        ]
        chat_llm = ChatOpenAI(self.client, temperature=0)
        self._agent = FunctionsBasedAgent(
            tools=example_tools,
            llm=chat_llm,
            conversation_memory=MessageWindowMessageSelector(k=2),
        )

    @post("prompt")
    def prompt(self, prompt: str, context_id: Optional[uuid.UUID] = None) -> str:
        """Run the agent with the provided text as input and return its output text.

        Args:
            prompt: The user text to hand to the agent.
            context_id: Identifier of the AgentContext to reuse. A fresh random
                UUID is generated when none is supplied.

        Returns:
            The concatenated text of everything the agent emitted for this prompt.
        """

        # An AgentContext carries the information the AgentService needs to run
        # an agent. Reuse the caller's context (or start a new one) and record
        # the prompt in that context's message history.
        context_key = str(context_id or uuid.uuid4())
        context = AgentContext.get_or_create(self.client, {"id": context_key})
        context.chat_history.append_user_message(prompt)

        # AgentServices provide an emit function hook to access the output of
        # running agents and tools. The emit functions fire after the supplied
        # agent emits a "FinishAction".
        #
        # Here the output is gathered synchronously through such a hook. An
        # alternative would be to read the final Action from
        # `context.completed_steps` after the call to `run_agent()`.
        emitted: List[str] = []

        def sync_emit(blocks: List[Block], meta: Metadata):
            # Text blocks contribute their text; any other block is rendered
            # as a "(mime_type: id)" placeholder.
            rendered = []
            for block in blocks:
                if block.is_text():
                    rendered.append(block.text)
                else:
                    rendered.append(f"({block.mime_type}: {block.id})")
            emitted.append("\n".join(rendered))

        context.emit_funcs.append(sync_emit)
        self.run_agent(self._agent, context)

        output = "".join(emitted)
        context.chat_history.append_assistant_message(output)
        return output
|
||
|
||
if __name__ == "__main__":
    # AgentREPL runs an AgentService method locally, which simplifies debugging
    # while agents and tools are being developed and added. Each REPL session
    # is started with a freshly generated context id.
    repl = AgentREPL(MyFunctionsBasedAssistant, "prompt", agent_package_config={})
    repl.run(context_id=uuid.uuid4())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .functions_based import FunctionsBasedAgent | ||
from .output_parser import FunctionsBasedOutputParser | ||
|
||
__all__ = [ | ||
"FunctionsBasedAgent", | ||
"FunctionsBasedOutputParser", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from typing import List | ||
|
||
from steamship import Block | ||
from steamship.agents.functional.output_parser import FunctionsBasedOutputParser | ||
from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, Tool | ||
from steamship.data.tags.tag_constants import RoleTag | ||
|
||
|
||
class FunctionsBasedAgent(ChatAgent):
    """Selects actions for AgentService based on OpenAI Function style LLM Prompting."""

    PROMPT = """You are a helpful AI assistant.
NOTE: Some functions return images, video, and audio files. These multimedia files will be represented in messages as
UUIDs for Steamship Blocks. When responding directly to a user, you SHOULD print the Steamship Blocks for the images,
video, or audio as follows: `Block(UUID for the block)`.
Example response for a request that generated an image:
Here is the image you requested: Block(288A2CA1-4753-4298-9716-53C1E42B726B).
Only use the functions you have been provided with."""

    def __init__(self, tools: List[Tool], llm: ChatLLM, **kwargs):
        """Create the agent with a FunctionsBasedOutputParser built over the same tools.

        Args:
            tools: Tools the agent may select; also handed to the output parser.
            llm: Chat LLM used to pick the next action.
            **kwargs: Passed through unchanged to ``ChatAgent``.
        """
        super().__init__(
            output_parser=FunctionsBasedOutputParser(tools=tools), llm=llm, tools=tools, **kwargs
        )

    def next_action(self, context: AgentContext) -> Action:
        """Assemble the chat transcript for the LLM and parse its reply into an Action.

        The message list is built in this order: the system prompt, de-duplicated
        messages from memory (search results first when the history is
        searchable, then the messages chosen by ``self.message_selector``), the
        latest user message, and finally the chat messages of every completed
        step.

        Args:
            context: Agent context supplying the chat history and completed steps.

        Returns:
            The Action produced by the output parser from the first block the
            LLM returns.
        """
        chat_history = context.chat_history
        last_user_message = chat_history.last_user_message

        # get system message
        system_message = Block(text=self.PROMPT)
        system_message.set_chat_role(RoleTag.SYSTEM)
        messages = [system_message]

        messages_from_memory = []
        # get prior conversations
        if chat_history.is_searchable():
            messages_from_memory.extend(
                chat_history.search(last_user_message.text, k=3).wait().to_ranked_blocks()
            )

        # TODO(dougreid): we need a way to threshold message inclusion, especially for small contexts

        # get most recent context
        messages_from_memory.extend(chat_history.select_messages(self.message_selector))

        # De-dupe the messages from memory. The seen-set starts with the current
        # prompt's id so the prompt is dropped from BOTH the search results
        # (where it would be an exact match) and the selector's results; it is
        # appended exactly once further down. Without this seed, a selector
        # that returns the latest user message makes the prompt appear twice.
        seen_ids = {last_user_message.id}
        for msg in messages_from_memory:
            if msg.id not in seen_ids:
                messages.append(msg)
                seen_ids.add(msg.id)

        # TODO(dougreid): sort by dates? we SHOULD ensure ordering, given semantic search

        # put the user prompt in the appropriate message location
        # this should happen BEFORE any agent/assistant messages related to tool selection
        messages.append(last_user_message)

        # get completed steps
        for action in context.completed_steps:
            messages.extend(action.to_chat_messages())

        # call chat()
        output_blocks = self.llm.chat(messages=messages, tools=self.tools)

        return self.output_parser.parse(output_blocks[0].text, context)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import json | ||
import re | ||
from typing import Dict, List, Optional | ||
|
||
from steamship import Block, MimeTypes, Steamship | ||
from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool | ||
from steamship.data.tags.tag_constants import RoleTag | ||
|
||
|
||
# TODO(dougreid): extract shared bits from this and the ReACT output parser into a utility? | ||
class FunctionsBasedOutputParser(OutputParser):
    """Turns LLM output into an Action: a tool invocation or a FinishAction."""

    # Maps tool name -> Tool; populated from the `tools` kwarg in __init__.
    tools_lookup_dict: Optional[Dict[str, Tool]] = None

    def __init__(self, **kwargs):
        """Build the name-to-tool lookup from the optional ``tools`` keyword argument."""
        tools_lookup_dict = {tool.name: tool for tool in kwargs.pop("tools", [])}
        super().__init__(tools_lookup_dict=tools_lookup_dict, **kwargs)

    def _extract_action_from_function_call(self, text: str, context: AgentContext) -> Action:
        """Convert a JSON-encoded function call into an Action for the named tool.

        ``text`` is decoded as JSON and its ``function_call`` entry is read for
        the tool ``name`` and a JSON-encoded ``arguments`` string. A ``text``
        argument becomes a new text Block; otherwise the ``uuid`` argument is
        used to fetch an existing Block.

        Raises:
            RuntimeError: If the named tool is not in ``tools_lookup_dict``.
            json.JSONDecodeError: If ``text`` or the arguments are not valid JSON.
        """
        wrapper = json.loads(text)
        fc = wrapper.get("function_call")
        name = fc.get("name", "")
        tool = self.tools_lookup_dict.get(name, None)
        if tool is None:
            raise RuntimeError(
                f"Could not find tool from function call: `{name}`. Known tools: {self.tools_lookup_dict.keys()}"
            )

        input_blocks = []
        arguments = fc.get("arguments", "")
        args = json.loads(arguments)
        # TODO(dougreid): validation and error handling?

        if text_arg := args.get("text"):
            input_blocks.append(Block(text=text_arg, mime_type=MimeTypes.TXT))
        else:
            block_id = args.get("uuid")
            # Uses the `_id` keyword, matching the Block.get call in
            # _blocks_from_text (the original passed `id=` here).
            input_blocks.append(Block.get(context.client, _id=block_id))

        return Action(tool=tool, input=input_blocks, context=context)

    @staticmethod
    def _blocks_from_text(client: Steamship, text: str) -> List[Block]:
        """Split text into Blocks, resolving embedded block ids via ``Block.get``.

        Only the portion after the last ``"AI:"`` marker is considered. Text
        between block references becomes plain text Blocks; each upper-case
        hex UUID, optionally wrapped like ``Block(<uuid>)`` or ``[<uuid>]``,
        is replaced by the Block fetched for that id.
        """
        last_response = text.split("AI:")[-1].strip()

        block_id_regex = r"(?:(?:\[|\()?Block)?\(?([A-F0-9]{8}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{12})\)?(?:(\]|\)))?"
        remaining_text = last_response
        result_blocks: List[Block] = []
        while remaining_text:
            match = re.search(block_id_regex, remaining_text)
            if match:
                pre_block_text = FunctionsBasedOutputParser._remove_block_prefix(
                    candidate=remaining_text[0 : match.start()]
                )
                if pre_block_text:
                    result_blocks.append(Block(text=pre_block_text))
                result_blocks.append(Block.get(client, _id=match.group(1)))
                remaining_text = FunctionsBasedOutputParser._remove_block_suffix(
                    remaining_text[match.end() :]
                )
            else:
                result_blocks.append(Block(text=remaining_text))
                remaining_text = ""

        return result_blocks

    @staticmethod
    def _remove_block_prefix(candidate: str) -> str:
        """Strip a trailing ``(Block``, ``[Block`` or ``Block`` marker from ``candidate``.

        The marker sits at the END of the text that precedes a block id, so it
        is cut from the end (the original sliced the same number of characters
        off the front, deleting real text and leaving the marker in place).
        """
        # Bracketed forms are checked first because they also end with "Block".
        for marker in ("(Block", "[Block", "Block"):
            if candidate.endswith(marker):
                return candidate[: -len(marker)]
        return candidate

    @staticmethod
    def _remove_block_suffix(candidate: str) -> str:
        """Strip one leading ``)`` or ``]`` left over after a block id.

        Only the START of ``candidate`` is inspected (the original tested
        ``endswith("]")`` and then removed the first character, which dropped
        an unrelated character whenever the text merely ended in ``]``).
        """
        if candidate.startswith((")", "]")):
            return candidate[1:]
        return candidate

    def parse(self, text: str, context: AgentContext) -> Action:
        """Return a tool Action if ``text`` mentions a function call, else a FinishAction.

        Any text containing the substring ``function_call`` is treated as a
        JSON function-call payload; all other text is wrapped in a single
        assistant-role Block as the final output.
        """
        if "function_call" in text:
            return self._extract_action_from_function_call(text, context)

        finish_block = Block(text=text)
        finish_block.set_chat_role(RoleTag.ASSISTANT)
        return FinishAction(output=[finish_block], context=context)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.