Skip to content

Commit

Permalink
Vector Fact Learner (#398)
Browse files Browse the repository at this point in the history
This PR:

1) Adds the `VectorSearchTool` base class
2) Refactors the `VectorSearchQATool` atop it
3) Introduces the `VectorSearchLearnerTool` tool, which learns facts
4) Introduces the `FactLearner` example agent which persistently learns
facts and can answer questions about them from a vector store
  • Loading branch information
eob authored Jun 6, 2023
1 parent 891360a commit 1091ffa
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 31 deletions.
75 changes: 75 additions & 0 deletions src/steamship/agents/examples/fact_learner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import uuid
from typing import List

from steamship import Block
from steamship.agents.llms.openai import OpenAI
from steamship.agents.react import ReACTAgent
from steamship.agents.schema import AgentContext
from steamship.agents.schema.context import Metadata
from steamship.agents.service.agent_service import AgentService
from steamship.agents.tools.question_answering import VectorSearchQATool
from steamship.agents.tools.question_answering.vector_search_learner_tool import (
VectorSearchLearnerTool,
)
from steamship.agents.utils import with_llm
from steamship.invocable import post
from steamship.utils.repl import AgentREPL


class FactLearner(AgentService):
"""FactLearner is an example AgentService contains an Agent which:
1) Learns facts to a vector store
2) Can answer questions based on those facts"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._agent = ReACTAgent(
tools=[
VectorSearchLearnerTool(),
VectorSearchQATool(),
],
llm=OpenAI(self.client),
)

@post("prompt")
def prompt(self, prompt: str) -> str:
"""Run an agent with the provided text as the input."""

# AgentContexts serve to allow the AgentService to run agents
# with appropriate information about the desired tasking.
# Here, we create a new context on each prompt, and append the
# prompt to the message history stored in the context.
context_id = uuid.uuid4()
context = AgentContext.get_or_create(self.client, {"id": f"{context_id}"})
context.chat_history.append_user_message(prompt)

# Add the LLM
context = with_llm(context=context, llm=OpenAI(client=self.client))

# AgentServices provide an emit function hook to access the output of running
# agents and tools. The emit functions fire at after the supplied agent emits
# a "FinishAction".
#
# Here, we show one way of accessing the output in a synchronous fashion. An
# alternative way would be to access the final Action in the `context.completed_steps`
# after the call to `run_agent()`.
output = ""

def sync_emit(blocks: List[Block], meta: Metadata):
nonlocal output
block_text = "\n".join(
[b.text if b.is_text() else f"({b.mime_type}: {b.id})" for b in blocks]
)
output += block_text

context.emit_funcs.append(sync_emit)
self.run_agent(self._agent, context)
return output


if __name__ == "__main__":
# AgentREPL provides a mechanism for local execution of an AgentService method.
# This is used for simplified debugging as agents and tools are developed and
# added.
AgentREPL(FactLearner, "prompt", agent_package_config={}).run()
6 changes: 2 additions & 4 deletions src/steamship/agents/tools/question_answering/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .prompt_database_question_answerer import PromptDatabaseQATool
from .vector_search_learner_tool import VectorSearchLearnerTool
from .vector_search_qa_tool import VectorSearchQATool

__all__ = [
"PromptDatabaseQATool",
"VectorSearchQATool",
]
__all__ = ["PromptDatabaseQATool", "VectorSearchQATool", "VectorSearchLearnerTool"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Answers questions with the assistance of a VectorSearch plugin."""
from typing import Any, List, Optional, Union

from steamship import Block, Tag, Task
from steamship.agents.llms import OpenAI
from steamship.agents.schema import AgentContext
from steamship.agents.tools.question_answering.vector_search_tool import VectorSearchTool
from steamship.agents.utils import with_llm
from steamship.utils.repl import ToolREPL


class VectorSearchLearnerTool(VectorSearchTool):
"""Tool to answer questions with the assistance of a vector search plugin."""

name: str = "VectorSearchLearnerTool"
human_description: str = "Learns a new fact and puts it in the Vector Database."
agent_description: str = (
"Used to remember a fact. Only use this tool if someone asks to remember or learn something. ",
"The input is a fact to learn. ",
"The output is a confirmation that the fact has been learned.",
)

def learn_sentence(self, sentence: str, context: AgentContext, metadata: Optional[dict] = None):
"""Learns a sigle sentence-sized piece of text.
GUIDANCE: No more than about a short sentence is a useful unit of embedding search & lookup.
"""
index = self.get_embedding_index(context.client)
tag = Tag(text=sentence, metadata=metadata)
index.insert(tags=[tag])

def run(self, tool_input: List[Block], context: AgentContext) -> Union[List[Block], Task[Any]]:
"""Learns a fact with the assistance of an Embedding Index plugin.
Inputs
------
tool_input: List[Block]
A list of blocks to be rewritten if text-containing.
context: AgentContext
The active AgentContext.
Output
------
output: List[Blocks]
A lit of blocks containing the answers.
"""

output = []
for input_block in tool_input:
if input_block.is_text():
self.learn_sentence(input_block.text, context=context)
output.append(Block(text=f"I'll remember: {input_block.text}"))
return output


if __name__ == "__main__":
tool = VectorSearchLearnerTool()
repl = ToolREPL(tool)

with repl.temporary_workspace() as client:
repl.run_with_client(
client, context=with_llm(context=AgentContext(), llm=OpenAI(client=client))
)
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Answers questions with the assistance of a VectorSearch plugin."""
from typing import Any, List, Optional, Union, cast
from typing import Any, List, Optional, Union

from steamship import Block, Steamship, Tag, Task
from steamship import Block, Tag, Task
from steamship.agents.llms import OpenAI
from steamship.agents.schema import AgentContext, Tool
from steamship.agents.schema import AgentContext
from steamship.agents.tools.question_answering.vector_search_tool import VectorSearchTool
from steamship.agents.utils import get_llm, with_llm
from steamship.data.plugin.index_plugin_instance import EmbeddingIndexPluginInstance
from steamship.utils.repl import ToolREPL

DEFAULT_QUESTION_ANSWERING_PROMPT = (
Expand All @@ -23,7 +23,7 @@
DEFAULT_SOURCE_DOCUMENT_PROMPT = "Source Document: {text}"


class VectorSearchQATool(Tool):
class VectorSearchQATool(VectorSearchTool):
"""Tool to answer questions with the assistance of a vector search plugin."""

name: str = "VectorSearchQATool"
Expand All @@ -33,29 +33,9 @@ class VectorSearchQATool(Tool):
"The input should be a plain text question. ",
"The output is a plain text answer",
)
embedding_index_handle: Optional[str] = "embedding-index"
embedding_index_version: Optional[str] = None
question_answering_prompt: Optional[str] = DEFAULT_QUESTION_ANSWERING_PROMPT
source_document_prompt: Optional[str] = DEFAULT_SOURCE_DOCUMENT_PROMPT
embedding_index_config: Optional[dict] = {
"embedder": {
"plugin_handle": "openai-embedder",
"plugin_instance-handle": "text-embedding-ada-002",
"fetch_if_exists": True,
"config": {"model": "text-embedding-ada-002", "dimensionality": 1536},
}
}
load_docs_count: int = 2
embedding_index_instance_handle: str = "default-embedding-index"

def get_embedding_index(self, client: Steamship) -> EmbeddingIndexPluginInstance:
index = client.use_plugin(
plugin_handle=self.embedding_index_handle,
instance_handle=self.embedding_index_instance_handle,
config=self.embedding_index_config,
fetch_if_exists=True,
)
return cast(EmbeddingIndexPluginInstance, index)

def answer_question(self, question: str, context: AgentContext) -> List[Block]:
index = self.get_embedding_index(context.client)
Expand All @@ -80,9 +60,9 @@ def run(self, tool_input: List[Block], context: AgentContext) -> Union[List[Bloc
Inputs
------
input: List[Block]
tool_input: List[Block]
A list of blocks to be rewritten if text-containing.
memory: AgentContext
context: AgentContext
The active AgentContext.
Output
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Answers questions with the assistance of a VectorSearch plugin."""
from abc import ABC
from typing import Optional, cast

from steamship import Steamship
from steamship.agents.schema import Tool
from steamship.data.plugin.index_plugin_instance import EmbeddingIndexPluginInstance


class VectorSearchTool(Tool, ABC):
"""Abstract Base Class that provides helper data for a tool that uses Vector Search."""

embedding_index_handle: Optional[str] = "embedding-index"
embedding_index_version: Optional[str] = None
embedding_index_config: Optional[dict] = {
"embedder": {
"plugin_handle": "openai-embedder",
"plugin_instance-handle": "text-embedding-ada-002",
"fetch_if_exists": True,
"config": {"model": "text-embedding-ada-002", "dimensionality": 1536},
}
}
embedding_index_instance_handle: str = "default-embedding-index"

def get_embedding_index(self, client: Steamship) -> EmbeddingIndexPluginInstance:
index = client.use_plugin(
plugin_handle=self.embedding_index_handle,
instance_handle=self.embedding_index_instance_handle,
config=self.embedding_index_config,
fetch_if_exists=True,
)
return cast(EmbeddingIndexPluginInstance, index)

0 comments on commit 1091ffa

Please sign in to comment.