Skip to content

Commit

Permalink
feat(agents): add functions-based agent to SDK (#454)
Browse files Browse the repository at this point in the history
This provides a basic implementation of an agent that works with tools,
using OpenAI's functions extension as the mechanism for tool selection.
This is an initial implementation that will likely need refinement as
use cases are explored.

---------

Co-authored-by: Douglas Reid <[email protected]>
  • Loading branch information
douglas-reid and Douglas Reid authored Jul 7, 2023
1 parent 9457ed8 commit f97e411
Show file tree
Hide file tree
Showing 16 changed files with 472 additions and 9 deletions.
75 changes: 75 additions & 0 deletions src/steamship/agents/examples/my_functions_based_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import uuid
from typing import List, Optional

from steamship import Block
from steamship.agents.functional import FunctionsBasedAgent
from steamship.agents.llms.openai import ChatOpenAI
from steamship.agents.schema import AgentContext
from steamship.agents.schema.context import Metadata
from steamship.agents.schema.message_selectors import MessageWindowMessageSelector
from steamship.agents.service.agent_service import AgentService
from steamship.agents.tools.image_generation import DalleTool
from steamship.agents.tools.search import SearchTool
from steamship.invocable import post
from steamship.utils.repl import AgentREPL


class MyFunctionsBasedAssistant(AgentService):
    """Example AgentService that exposes a single test endpoint for Agent-based invocations.

    It is configured with two simple Tools to provide an overview of the types of
    tasks it can accomplish (here, search and image generation).
    """

    def __init__(self, **kwargs):
        """Create the service and the agent it runs.

        Args:
            **kwargs: Forwarded unchanged to ``AgentService.__init__``.
        """
        super().__init__(**kwargs)
        self._agent = FunctionsBasedAgent(
            tools=[
                SearchTool(),
                DalleTool(),
            ],
            llm=ChatOpenAI(self.client, temperature=0),
            # FunctionsBasedAgent.next_action() picks recent history through
            # `self.message_selector`, so the selector has to be supplied under that
            # name; the previous `conversation_memory=` keyword was never read.
            # NOTE(review): assumes the Agent base class declares `message_selector`
            # as a constructor field — confirm against steamship.agents.schema.agent.
            message_selector=MessageWindowMessageSelector(k=2),
        )

    @post("prompt")
    def prompt(self, prompt: str, context_id: Optional[uuid.UUID] = None) -> str:
        """Run an agent with the provided text as the input.

        Args:
            prompt: The user's message.
            context_id: Identifier of the AgentContext to reuse. A fresh UUID is
                generated when it is not supplied.

        Returns:
            The concatenated text of everything the agent emitted. Non-text blocks
            are rendered as ``(<mime type>: <block id>)``.
        """
        # AgentContexts serve to allow the AgentService to run agents
        # with appropriate information about the desired tasking.
        # Here, we use the passed in context (or a new context) for the prompt,
        # and append the prompt to the message history stored in the context.
        if not context_id:
            context_id = uuid.uuid4()
        context = AgentContext.get_or_create(self.client, {"id": f"{context_id}"})
        context.chat_history.append_user_message(prompt)

        # AgentServices provide an emit function hook to access the output of running
        # agents and tools. The emit functions fire after the supplied agent emits
        # a "FinishAction".
        #
        # Here, we show one way of accessing the output in a synchronous fashion. An
        # alternative way would be to access the final Action in the
        # `context.completed_steps` after the call to `run_agent()`.
        emitted: List[str] = []

        def sync_emit(blocks: List[Block], meta: Metadata):
            # One emit call becomes one chunk: its blocks joined by newlines.
            emitted.append(
                "\n".join(
                    b.text if b.is_text() else f"({b.mime_type}: {b.id})" for b in blocks
                )
            )

        context.emit_funcs.append(sync_emit)
        self.run_agent(self._agent, context)

        # Chunks from separate emit calls are concatenated without a separator.
        output = "".join(emitted)
        context.chat_history.append_assistant_message(output)
        return output


if __name__ == "__main__":
    # Local entry point: AgentREPL drives the service's "prompt" method from an
    # interactive loop, which simplifies debugging while agents and tools are
    # being developed. A fresh context id is generated for each run of the script.
    repl = AgentREPL(MyFunctionsBasedAssistant, "prompt", agent_package_config={})
    repl.run(context_id=uuid.uuid4())
7 changes: 7 additions & 0 deletions src/steamship/agents/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .functions_based import FunctionsBasedAgent
from .output_parser import FunctionsBasedOutputParser

# Names re-exported as the public API of this package (both are imported above).
__all__ = [
"FunctionsBasedAgent",
"FunctionsBasedOutputParser",
]
78 changes: 78 additions & 0 deletions src/steamship/agents/functional/functions_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import List

from steamship import Block
from steamship.agents.functional.output_parser import FunctionsBasedOutputParser
from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, Tool
from steamship.data.tags.tag_constants import RoleTag


class FunctionsBasedAgent(ChatAgent):
    """Selects actions for AgentService based on OpenAI Function style LLM Prompting."""

    PROMPT = """You are a helpful AI assistant.
NOTE: Some functions return images, video, and audio files. These multimedia files will be represented in messages as
UUIDs for Steamship Blocks. When responding directly to a user, you SHOULD print the Steamship Blocks for the images,
video, or audio as follows: `Block(UUID for the block)`.
Example response for a request that generated an image:
Here is the image you requested: Block(288A2CA1-4753-4298-9716-53C1E42B726B).
Only use the functions you have been provided with."""

    def __init__(self, tools: List[Tool], llm: ChatLLM, **kwargs):
        """Create the agent.

        Args:
            tools: Tools the LLM may select; they are also handed to the output
                parser so function calls can be resolved back to Tool instances.
            llm: Chat LLM used to pick the next action.
            **kwargs: Forwarded unchanged to ``ChatAgent.__init__``.
        """
        super().__init__(
            output_parser=FunctionsBasedOutputParser(tools=tools), llm=llm, tools=tools, **kwargs
        )

    def next_action(self, context: AgentContext) -> Action:
        """Ask the LLM for the next Action, given the conversation so far.

        The message list sent to the LLM is, in order: the system prompt, messages
        recalled from chat history (semantic search hits, then the messages chosen
        by ``self.message_selector``, de-duplicated by block id), the latest user
        message, and one group of function messages per completed step.

        Args:
            context: AgentContext holding the chat history and completed steps.

        Returns:
            The Action parsed from the text of the first block the LLM returned.
        """
        history = context.chat_history
        user_message = history.last_user_message

        # system message always comes first
        system_message = Block(text=self.PROMPT)
        system_message.set_chat_role(RoleTag.SYSTEM)
        messages = [system_message]

        recalled = []
        # prior conversation found via semantic search (only if the history is indexed)
        if history.is_searchable():
            recalled.extend(history.search(user_message.text, k=3).wait().to_ranked_blocks())

        # TODO(dougreid): we need a way to threshold message inclusion, especially for small contexts

        # most recent context
        recalled.extend(history.select_messages(self.message_selector))

        # De-dupe the recalled messages. The id set is seeded with the current
        # prompt so the prompt is dropped from BOTH sources: it is an exact match
        # for the semantic search, and a message selector may return it as well.
        # It is appended exactly once, below.
        seen_ids = set() if user_message is None else {user_message.id}
        for msg in recalled:
            if msg.id not in seen_ids:
                seen_ids.add(msg.id)
                messages.append(msg)

        # TODO(dougreid): sort by dates? we SHOULD ensure ordering, given semantic search

        # put the user prompt in the appropriate message location
        # this should happen BEFORE any agent/assistant messages related to tool selection
        messages.append(user_message)

        # results of the steps already completed for this prompt
        for action in context.completed_steps:
            messages.extend(action.to_chat_messages())

        output_blocks = self.llm.chat(messages=messages, tools=self.tools)

        return self.output_parser.parse(output_blocks[0].text, context)
89 changes: 89 additions & 0 deletions src/steamship/agents/functional/output_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json
import re
from typing import Dict, List, Optional

from steamship import Block, MimeTypes, Steamship
from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool
from steamship.data.tags.tag_constants import RoleTag


# TODO(dougreid): extract shared bits from this and the ReACT output parser into a utility?
class FunctionsBasedOutputParser(OutputParser):
    """Turns the text of an OpenAI-functions-style chat completion into an Action.

    Text that is a JSON object carrying a ``function_call`` entry becomes an Action
    for the named tool; any other text becomes a FinishAction wrapping that text.
    """

    tools_lookup_dict: Optional[Dict[str, Tool]] = None

    def __init__(self, **kwargs):
        """Create the parser.

        Args:
            **kwargs: May contain ``tools`` (an iterable of Tool, or None), which is
                converted into a name -> Tool lookup. Everything else is forwarded
                to ``OutputParser.__init__``.
        """
        # `or []` also covers an explicit tools=None, which would not be iterable.
        tools = kwargs.pop("tools", None) or []
        tools_lookup_dict = {tool.name: tool for tool in tools}
        super().__init__(tools_lookup_dict=tools_lookup_dict, **kwargs)

    def _extract_action_from_function_call(self, text: str, context: AgentContext) -> Action:
        """Build the Action described by a serialized function call.

        Args:
            text: JSON object with a ``function_call`` entry holding ``name`` and
                ``arguments`` (itself a JSON-encoded object).
            context: AgentContext; its client is used to fetch a Block when the
                call passes a ``uuid`` argument.

        Returns:
            An Action for the named tool with a single input block.

        Raises:
            RuntimeError: If the tool name is unknown, or if the arguments contain
                neither a ``text`` nor a ``uuid`` value.
            json.JSONDecodeError: If ``text`` or the arguments are not valid JSON.
        """
        wrapper = json.loads(text)
        function_call = wrapper.get("function_call") or {}
        name = function_call.get("name", "")

        known_tools = self.tools_lookup_dict or {}
        tool = known_tools.get(name)
        if tool is None:
            raise RuntimeError(
                f"Could not find tool from function call: `{name}`. Known tools: {known_tools.keys()}"
            )

        # A call without arguments is treated as an empty argument object rather
        # than being fed to json.loads as "" (which is not valid JSON).
        args = json.loads(function_call.get("arguments") or "{}")
        # TODO(dougreid): validation and error handling?

        if tool_text := args.get("text"):
            input_blocks = [Block(text=tool_text, mime_type=MimeTypes.TXT)]
        elif block_id := args.get("uuid"):
            # The id is passed positionally: the two Block.get() call sites in this
            # class previously disagreed on the keyword name (`id=` vs `_id=`).
            input_blocks = [Block.get(context.client, block_id)]
        else:
            raise RuntimeError(
                f"Function call for tool `{name}` provided neither `text` nor `uuid` arguments."
            )

        return Action(tool=tool, input=input_blocks, context=context)

    @staticmethod
    def _blocks_from_text(client: Steamship, text: str) -> List[Block]:
        """Split a response into text Blocks and Blocks fetched by referenced id.

        Only the part of ``text`` after the last ``AI:`` marker is considered.
        Every block UUID found in it (optionally wrapped as ``Block(...)``,
        ``(Block ...)`` or ``[Block ...]``) is fetched through ``client``; the
        text between references is returned as plain text Blocks, in order.
        """
        last_response = text.split("AI:")[-1].strip()

        block_id_regex = r"(?:(?:\[|\()?Block)?\(?([A-F0-9]{8}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{12})\)?(?:(\]|\)))?"
        remaining_text = last_response
        result_blocks: List[Block] = []
        while remaining_text:
            match = re.search(block_id_regex, remaining_text)
            if match is None:
                # no further references: whatever is left is plain text
                result_blocks.append(Block(text=remaining_text))
                break

            pre_block_text = FunctionsBasedOutputParser._remove_block_prefix(
                candidate=remaining_text[: match.start()]
            )
            if pre_block_text:
                result_blocks.append(Block(text=pre_block_text))
            result_blocks.append(Block.get(client, match.group(1)))
            remaining_text = FunctionsBasedOutputParser._remove_block_suffix(
                remaining_text[match.end() :]
            )

        return result_blocks

    @staticmethod
    def _remove_block_prefix(candidate: str) -> str:
        """Strip a trailing ``(Block``, ``[Block`` or ``Block`` marker from ``candidate``.

        The marker sits at the END of the text that precedes a block id, so it is
        the tail that is cut off (the previous slice removed leading characters).
        """
        # Longer markers are tested first so "(Block" is not reduced to "(".
        for marker in ("(Block", "[Block", "Block"):
            if candidate.endswith(marker):
                return candidate[: -len(marker)]
        return candidate

    @staticmethod
    def _remove_block_suffix(candidate: str) -> str:
        """Strip one leading ``)`` or ``]`` left over after a block id."""
        # Both closers are checked at the START of the text; the previous code
        # tested `]` with endswith() and then removed the first character anyway.
        if candidate.startswith((")", "]")):
            return candidate[1:]
        return candidate

    def parse(self, text: str, context: AgentContext) -> Action:
        """Convert LLM output text into the next Action.

        Args:
            text: Text of the LLM's response.
            context: AgentContext attached to the returned Action.

        Returns:
            A tool Action when ``text`` is a JSON object containing a
            ``function_call`` object; otherwise a FinishAction whose single output
            block holds ``text`` tagged with the assistant role.

        Raises:
            RuntimeError: If a function call names an unknown tool or lacks usable
                arguments.
        """
        # The substring test is only a cheap pre-filter: an ordinary answer that
        # merely mentions "function_call" must not be treated as a function call.
        if "function_call" in text:
            try:
                wrapper = json.loads(text)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                wrapper = None
            if isinstance(wrapper, dict) and isinstance(wrapper.get("function_call"), dict):
                return self._extract_action_from_function_call(text, context)

        finish_block = Block(text=text)
        finish_block.set_chat_role(RoleTag.ASSISTANT)
        return FinishAction(output=[finish_block], context=context)
34 changes: 32 additions & 2 deletions src/steamship/agents/llms/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional

from steamship import Block, PluginInstance, Steamship
from steamship.agents.schema import LLM
from steamship import Block, File, PluginInstance, Steamship
from steamship.agents.schema import LLM, ChatLLM, Tool

PLUGIN_HANDLE = "gpt-4"

Expand Down Expand Up @@ -38,3 +38,33 @@ def complete(self, prompt: str, stop: Optional[str] = None) -> List[Block]:
action_task = self.generator.generate(text=prompt, options=options)
action_task.wait()
return action_task.output.blocks


class ChatOpenAI(ChatLLM, OpenAI):
    """ChatLLM that uses Steamship's OpenAI plugin to generate chat completions."""

    def __init__(
        self, client, model_name: str = "gpt-4-0613", temperature: float = 0.4, *args, **kwargs
    ):
        """Create a new instance.

        Valid model names are:
         - gpt-4
         - gpt-4-0613

        Args:
            client: Steamship client used to reach the plugin.
            model_name: Name of the OpenAI chat model to use.
            temperature: Sampling temperature requested for completions.
            *args: Forwarded to the parent initializer.
            **kwargs: Forwarded to the parent initializer.
        """
        # `temperature` is forwarded explicitly: it is a named parameter here, so
        # it is not part of **kwargs and was previously dropped on the floor.
        # NOTE(review): assumes OpenAI.__init__ accepts `temperature` — its
        # signature is not visible in this diff; confirm.
        super().__init__(
            client=client, model_name=model_name, temperature=temperature, *args, **kwargs
        )

    def chat(self, messages: List[Block], tools: Optional[List[Tool]] = None) -> List[Block]:
        """Send chat messages to the LLM, optionally offering tools as functions.

        Args:
            messages: Chat messages, in order, to send to the model.
            tools: Tools to expose as OpenAI functions. ``None`` or an empty list
                sends no ``functions`` option.

        Returns:
            The output blocks of the generation task.
        """
        # TODO(dougreid): this feels icky. find a better way?
        # The messages are packaged into a File so the generator can read them.
        temp_file = File.create(client=self.client, blocks=messages)

        options = {}
        if tools:  # truthiness covers both None and []
            options["functions"] = [tool.as_openai_function() for tool in tools]

        tool_selection_task = self.generator.generate(input_file_id=temp_file.id, options=options)
        tool_selection_task.wait()
        return tool_selection_task.output.blocks
2 changes: 1 addition & 1 deletion src/steamship/agents/react/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class ReACTOutputParser(OutputParser):
tools_lookup_dict: Optional[Dict[str, Tool]] = None

def __init__(self, **kwargs):
tools_lookup_dict = {tool.name: tool for tool in kwargs.pop("tools", None)}
tools_lookup_dict = {tool.name: tool for tool in kwargs.pop("tools", [])}
super().__init__(tools_lookup_dict=tools_lookup_dict, **kwargs)

def parse(self, text: str, context: AgentContext) -> Action:
Expand Down
6 changes: 4 additions & 2 deletions src/steamship/agents/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from .action import Action, FinishAction
from .agent import Agent, LLMAgent
from .agent import Agent, ChatAgent, LLMAgent
from .chathistory import ChatHistory
from .context import AgentContext, EmitFunc, Metadata
from .llm import LLM
from .llm import LLM, ChatLLM
from .output_parser import OutputParser
from .tool import Tool

__all__ = [
"Action",
"Agent",
"AgentContext",
"ChatLLM",
"ChatAgent",
"EmitFunc",
"FinishAction",
"Metadata",
Expand Down
24 changes: 23 additions & 1 deletion src/steamship/agents/schema/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from pydantic import BaseModel

from steamship import Block, Task
from steamship import Block, Tag, Task
from steamship.agents.schema.tool import AgentContext, Tool
from steamship.data import TagKind
from steamship.data.tags.tag_constants import RoleTag


class Action(BaseModel):
Expand All @@ -21,6 +23,26 @@ class Action(BaseModel):
output: Optional[List[Block]] = []
"""Any direct output produced by the Tool."""

def to_chat_messages(self) -> List[Block]:
    """Render this Action's output as function-role chat messages.

    Returns:
        One new Block per output block. Each carries the output's LLM-input text,
        the output's mime type, a role tag marking it as a function message and a
        ``name`` tag holding the tool's name. The list is empty when the Action
        has no output (``output`` is ``None`` or empty).
    """
    # TODO(dougreid): should we revisit as_llm_input? we might need only the UUID...
    # TODO(dougreid): revisit when have multiple output functions.
    # Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation.
    return [
        Block(
            text=block.as_llm_input(exclude_block_wrapper=True),
            # Fresh Tag objects per message, so no two Blocks share tag instances.
            tags=[
                Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION),
                Tag(kind="name", name=self.tool.name),
            ],
            mime_type=block.mime_type,
        )
        # `output` is declared Optional; treat None the same as "no output".
        for block in (self.output or [])
    ]


class AgentTool(Tool):
def run(self, tool_input: List[Block], context: AgentContext) -> Union[List[Block], Task[Any]]:
Expand Down
13 changes: 12 additions & 1 deletion src/steamship/agents/schema/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from steamship import Block
from steamship.agents.schema.action import Action
from steamship.agents.schema.context import AgentContext
from steamship.agents.schema.llm import LLM
from steamship.agents.schema.llm import LLM, ChatLLM
from steamship.agents.schema.message_selectors import MessageSelector, NoMessages
from steamship.agents.schema.output_parser import OutputParser
from steamship.agents.schema.tool import Tool
Expand Down Expand Up @@ -57,4 +57,15 @@ def messages_to_prompt_history(messages: List[Block]) -> str:
as_strings.append(f"System: {block.text}")
elif role == RoleTag.AGENT:
as_strings.append(f"Agent: {block.text}")
elif role == RoleTag.FUNCTION:
as_strings.append(f"Function: {block.text}")
return "\n".join(as_strings)


class ChatAgent(Agent, ABC):
    """Agent that chooses the next action for an AgentService by chatting with an LLM."""

    # Chat-style LLM the agent consults when selecting actions.
    llm: ChatLLM

    output_parser: OutputParser
    """Converts the LLM's output into Actions."""
3 changes: 3 additions & 0 deletions src/steamship/agents/schema/chathistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,6 @@ def search(self, text: str, k=None) -> Task[SearchResults]:
if self.embedding_index is None:
raise SteamshipError("This ChatHistory has no embedding index and is not searchable.")
return self.embedding_index.search(text, k)

def is_searchable(self) -> bool:
    """Report whether this history has an embedding index to search against."""
    if self.embedding_index is None:
        return False
    return True
Loading

0 comments on commit f97e411

Please sign in to comment.