This PR:
1) Adds the `VectorSearchTool` base class
2) Refactors `VectorSearchQATool` atop it
3) Introduces `VectorSearchLearnerTool`, which learns facts
4) Introduces the `FactLearner` example agent, which persistently learns facts and can answer questions about them from a vector store
Showing 5 changed files with 179 additions and 31 deletions.
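Before the file-by-file changes, here is a minimal sketch (not part of the changeset) of how the new pieces are meant to fit together: `VectorSearchLearnerTool` writes facts into the shared embedding index, and `VectorSearchQATool` answers questions from the same index. Everything beyond the imported Steamship classes, including the workspace id and the sample fact, is illustrative, and the QA tool's return shape is assumed (not verified) to be a list of text blocks.

from steamship import Block
from steamship.agents.llms import OpenAI
from steamship.agents.schema import AgentContext
from steamship.agents.tools.question_answering import VectorSearchQATool
from steamship.agents.tools.question_answering.vector_search_learner_tool import (
    VectorSearchLearnerTool,
)
from steamship.agents.utils import with_llm
from steamship.utils.repl import ToolREPL

learner = VectorSearchLearnerTool()
qa = VectorSearchQATool()

# Both tools default to the same embedding index instance handle (see
# VectorSearchTool below), so a fact stored by the learner is visible to the QA tool.
with ToolREPL(learner).temporary_workspace() as client:
    context = with_llm(
        context=AgentContext.get_or_create(client, {"id": "demo"}),
        llm=OpenAI(client=client),
    )

    # Store a fact in the vector index...
    learner.run([Block(text="Nermal is a gray tabby kitten.")], context=context)

    # ...then answer a question grounded in that fact.
    answer = qa.run([Block(text="What kind of animal is Nermal?")], context=context)
    if isinstance(answer, list) and answer:
        print(answer[0].text)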
@@ -0,0 +1,75 @@
import uuid
from typing import List

from steamship import Block
from steamship.agents.llms.openai import OpenAI
from steamship.agents.react import ReACTAgent
from steamship.agents.schema import AgentContext
from steamship.agents.schema.context import Metadata
from steamship.agents.service.agent_service import AgentService
from steamship.agents.tools.question_answering import VectorSearchQATool
from steamship.agents.tools.question_answering.vector_search_learner_tool import (
    VectorSearchLearnerTool,
)
from steamship.agents.utils import with_llm
from steamship.invocable import post
from steamship.utils.repl import AgentREPL


class FactLearner(AgentService):
    """FactLearner is an example AgentService that contains an Agent which:
    1) Learns facts to a vector store
    2) Can answer questions based on those facts"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._agent = ReACTAgent(
            tools=[
                VectorSearchLearnerTool(),
                VectorSearchQATool(),
            ],
            llm=OpenAI(self.client),
        )

    @post("prompt")
    def prompt(self, prompt: str) -> str:
        """Run an agent with the provided text as the input."""

        # AgentContexts serve to allow the AgentService to run agents
        # with appropriate information about the desired tasking.
        # Here, we create a new context on each prompt, and append the
        # prompt to the message history stored in the context.
        context_id = uuid.uuid4()
        context = AgentContext.get_or_create(self.client, {"id": f"{context_id}"})
        context.chat_history.append_user_message(prompt)

        # Add the LLM
        context = with_llm(context=context, llm=OpenAI(client=self.client))

        # AgentServices provide an emit function hook to access the output of running
        # agents and tools. The emit functions fire after the supplied agent emits
        # a "FinishAction".
        #
        # Here, we show one way of accessing the output in a synchronous fashion. An
        # alternative way would be to access the final Action in `context.completed_steps`
        # after the call to `run_agent()`.
        output = ""

        def sync_emit(blocks: List[Block], meta: Metadata):
            nonlocal output
            block_text = "\n".join(
                [b.text if b.is_text() else f"({b.mime_type}: {b.id})" for b in blocks]
            )
            output += block_text

        context.emit_funcs.append(sync_emit)
        self.run_agent(self._agent, context)
        return output


if __name__ == "__main__":
    # AgentREPL provides a mechanism for local execution of an AgentService method.
    # This is used for simplified debugging as agents and tools are developed and
    # added.
    AgentREPL(FactLearner, "prompt", agent_package_config={}).run()
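As the comment in `prompt()` notes, an alternative to registering an emit function is to read the final action from the context once `run_agent()` returns. A rough sketch of that variant is below; it assumes `context.completed_steps` ends with the finishing action and that the action's `output` is a list of `Block`s, which is not verified here.

        # Hypothetical variant of the tail of prompt(), reading the result from
        # completed_steps instead of via an emit function. Assumes the last completed
        # step's `output` is a list of Blocks; adjust to the actual Action schema.
        self.run_agent(self._agent, context)
        final_action = context.completed_steps[-1]
        return "\n".join(
            b.text if b.is_text() else f"({b.mime_type}: {b.id})" for b in final_action.output
        )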
@@ -1,7 +1,5 @@
 from .prompt_database_question_answerer import PromptDatabaseQATool
+from .vector_search_learner_tool import VectorSearchLearnerTool
 from .vector_search_qa_tool import VectorSearchQATool
 
-__all__ = [
-    "PromptDatabaseQATool",
-    "VectorSearchQATool",
-]
+__all__ = ["PromptDatabaseQATool", "VectorSearchQATool", "VectorSearchLearnerTool"]
src/steamship/agents/tools/question_answering/vector_search_learner_tool.py (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
"""Answers questions with the assistance of a VectorSearch plugin.""" | ||
from typing import Any, List, Optional, Union | ||
|
||
from steamship import Block, Tag, Task | ||
from steamship.agents.llms import OpenAI | ||
from steamship.agents.schema import AgentContext | ||
from steamship.agents.tools.question_answering.vector_search_tool import VectorSearchTool | ||
from steamship.agents.utils import with_llm | ||
from steamship.utils.repl import ToolREPL | ||
|
||
|
||
class VectorSearchLearnerTool(VectorSearchTool): | ||
"""Tool to answer questions with the assistance of a vector search plugin.""" | ||
|
||
name: str = "VectorSearchLearnerTool" | ||
human_description: str = "Learns a new fact and puts it in the Vector Database." | ||
agent_description: str = ( | ||
"Used to remember a fact. Only use this tool if someone asks to remember or learn something. ", | ||
"The input is a fact to learn. ", | ||
"The output is a confirmation that the fact has been learned.", | ||
) | ||
|
||
def learn_sentence(self, sentence: str, context: AgentContext, metadata: Optional[dict] = None): | ||
"""Learns a sigle sentence-sized piece of text. | ||
GUIDANCE: No more than about a short sentence is a useful unit of embedding search & lookup. | ||
""" | ||
index = self.get_embedding_index(context.client) | ||
tag = Tag(text=sentence, metadata=metadata) | ||
index.insert(tags=[tag]) | ||
|
||
def run(self, tool_input: List[Block], context: AgentContext) -> Union[List[Block], Task[Any]]: | ||
"""Learns a fact with the assistance of an Embedding Index plugin. | ||
Inputs | ||
------ | ||
tool_input: List[Block] | ||
A list of blocks to be rewritten if text-containing. | ||
context: AgentContext | ||
The active AgentContext. | ||
Output | ||
------ | ||
output: List[Blocks] | ||
A lit of blocks containing the answers. | ||
""" | ||
        output = []
        for input_block in tool_input:
            if input_block.is_text():
                self.learn_sentence(input_block.text, context=context)
                output.append(Block(text=f"I'll remember: {input_block.text}"))
        return output


if __name__ == "__main__":
    tool = VectorSearchLearnerTool()
    repl = ToolREPL(tool)

    with repl.temporary_workspace() as client:
        repl.run_with_client(
            client, context=with_llm(context=AgentContext(), llm=OpenAI(client=client))
        )
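Because `learn_sentence` accepts an optional metadata dict that is stored on the `Tag`, callers can attach provenance to each fact. A small sketch follows; the fact and the metadata keys ("source", "learned_at") are invented for illustration, not a required schema.

from datetime import datetime, timezone

from steamship.agents.schema import AgentContext
from steamship.agents.tools.question_answering.vector_search_learner_tool import (
    VectorSearchLearnerTool,
)
from steamship.utils.repl import ToolREPL

tool = VectorSearchLearnerTool()
with ToolREPL(tool).temporary_workspace() as client:
    # learn_sentence only needs a context whose client is set; no LLM is required
    # for the insert itself.
    context = AgentContext.get_or_create(client, {"id": "metadata-demo"})
    tool.learn_sentence(
        "The staging database is refreshed every Monday.",
        context=context,
        metadata={"source": "user_prompt", "learned_at": datetime.now(timezone.utc).isoformat()},
    )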
src/steamship/agents/tools/question_answering/vector_search_tool.py (32 additions, 0 deletions)
@@ -0,0 +1,32 @@
"""Answers questions with the assistance of a VectorSearch plugin.""" | ||
from abc import ABC
from typing import Optional, cast

from steamship import Steamship
from steamship.agents.schema import Tool
from steamship.data.plugin.index_plugin_instance import EmbeddingIndexPluginInstance


class VectorSearchTool(Tool, ABC):
    """Abstract Base Class that provides helper data for a tool that uses Vector Search."""

    embedding_index_handle: Optional[str] = "embedding-index"
    embedding_index_version: Optional[str] = None
    embedding_index_config: Optional[dict] = {
        "embedder": {
            "plugin_handle": "openai-embedder",
"plugin_instance-handle": "text-embedding-ada-002", | ||
"fetch_if_exists": True, | ||
"config": {"model": "text-embedding-ada-002", "dimensionality": 1536}, | ||
} | ||
} | ||
embedding_index_instance_handle: str = "default-embedding-index" | ||
|
||
def get_embedding_index(self, client: Steamship) -> EmbeddingIndexPluginInstance: | ||
index = client.use_plugin( | ||
plugin_handle=self.embedding_index_handle, | ||
instance_handle=self.embedding_index_instance_handle, | ||
config=self.embedding_index_config, | ||
fetch_if_exists=True, | ||
) | ||
return cast(EmbeddingIndexPluginInstance, index) |
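Since `VectorSearchTool` only carries the index configuration plus the `get_embedding_index` helper, adding another tool that shares the same index amounts to subclassing it and implementing `run`. The sketch below is a hypothetical example, not part of this PR; the class name, descriptions, and behavior are invented, and it simply stores every text block it receives, mirroring the learner tool above.

from typing import Any, List, Union

from steamship import Block, Tag, Task
from steamship.agents.schema import AgentContext


class VectorSearchArchiverTool(VectorSearchTool):
    """Hypothetical tool that archives every text block it sees in the shared index."""

    name: str = "VectorSearchArchiverTool"
    human_description: str = "Stores incoming text in the Vector Database."
    agent_description: str = "Used to archive text for later semantic lookup."

    def run(self, tool_input: List[Block], context: AgentContext) -> Union[List[Block], Task[Any]]:
        # Reuse the inherited index configuration to fetch (or create) the index.
        index = self.get_embedding_index(context.client)
        archived = 0
        for block in tool_input:
            if block.is_text():
                index.insert(tags=[Tag(text=block.text)])
                archived += 1
        return [Block(text=f"Archived {archived} block(s).")]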