Skip to content

Commit

Permalink
Merge pull request emrgnt-cmplxty#26 from EmergentAGI/feature/refacto…
Browse files Browse the repository at this point in the history
…r-tool-layout

Refactor agent tool layout
  • Loading branch information
emrgnt-cmplxty authored Jun 17, 2023
2 parents 0db821b + 03b242c commit b4bebcc
Show file tree
Hide file tree
Showing 28 changed files with 257 additions and 95 deletions.
1 change: 0 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
OPENAI_API_KEY=your_openai_api_key
TASK_DB_PATH=your_task_db_path
CONVERSATION_DB_PATH=your_conversation_db_path
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@ This project is inspired by the theory that code is essentially a form of memory

## Installation and Usage

---

### Initial Setup

Follow these steps to setup the Automata environment

```bash
# Clone the repository
git clone [email protected]:EmergentAGI/AutomataDocs.git
cd AutomataDocs
git clone [email protected]:EmergentAGI/Automata.git
cd Automata

# Create the local environment
python3 -m venv local_env
Expand All @@ -31,7 +29,6 @@ pre-commit install
cp .env.example .env
MY_API_KEY=your_openai_api_key_here
sed -i "s/your_openai_api_key/${MY_API_KEY}/" .env
sed -i "s/your_task_db_path/$PWD/tasks.sqlite3/" .env
sed -i "s/your_openai_api_key/$PWD/conversations.sqlite3/" .env
```

Expand Down
2 changes: 1 addition & 1 deletion automata/config/agent_config_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
AutomataInstructionPayload,
InstructionConfigVersion,
)
from automata.core.agent.tool_management.tool_management_utils import build_llm_toolkits
from automata.core.agent.tools.tool_utils import build_llm_toolkits
from automata.core.base.tool import Toolkit, ToolkitType


Expand Down
4 changes: 1 addition & 3 deletions automata/config/config_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ def setup(self):
@classmethod
def load_automata_yaml_config(cls, config_name: AgentConfigName) -> Dict:
"""Loads the automata.yaml config file."""
from automata.core.agent.tool_management.tool_management_utils import (
build_llm_toolkits,
)
from automata.core.agent.tools.tool_utils import build_llm_toolkits

file_dir_path = os.path.dirname(os.path.abspath(__file__))
config_abs_path = os.path.join(
Expand Down
Binary file modified automata/config/symbol/index.scip
Binary file not shown.
2 changes: 1 addition & 1 deletion automata/config/symbol/symbol_code_embedding.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion automata/config/symbol/symbol_doc_embedding_l2.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion automata/config/symbol/symbol_doc_embedding_l3.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions automata/core/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def iter_step(self) -> Optional[Tuple[OpenAIChatMessage, OpenAIChatMessage]]:
observations = self._generate_observations(response_text)

completion_message = retrieve_completion_message(observations)
print("completion_message = ", completion_message)
if completion_message is not None:
self.completed = True
self._save_message(
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any


class BaseToolManager(ABC):
class AgentTool(ABC):
def __init__(self, **kwargs):
pass

@abstractmethod
def build_tools(self) -> Any:
def build(self) -> Any:
pass
87 changes: 87 additions & 0 deletions automata/core/agent/tools/context_oracle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import textwrap
from typing import List

from automata.core.agent.tools.agent_tool import AgentTool
from automata.core.base.tool import Tool
from automata.core.embedding.symbol_similarity import SymbolSimilarity
from automata.core.symbol.search.symbol_search import SymbolSearch


class ContextOracle(AgentTool):
"""
ContextOracleManager is responsible for managing context oracle tools.
"""

def __init__(
self,
symbol_search: SymbolSearch,
symbol_doc_similarity: SymbolSimilarity,
):
"""
Initializes ContextOracleManager with given SymbolSearch, SymbolSimilarity,
optional ContextTool, and post processing function.
Args:
symbol_search (SymbolSearch): The symbol search object.
symbol_doc_similarity (SymbolSimilarity): The symbol doc similarity object.
"""
self.symbol_search = symbol_search
self.symbol_doc_similarity = symbol_doc_similarity

def build(self) -> List[Tool]:
"""
Builds all the context tools.
Returns:
List[Tool]: The list of built tools.
"""
tools = [
Tool(
name="context-oracle",
func=self._context_oracle_processor,
description=textwrap.dedent(
"""
This tool combines SymbolSearch and SymbolSimilarity to create contexts.
Given a query, it uses SymbolSimilarity calculate the similarity between each symbol's documentation and the query returns the most similar document.
Then, it leverages SymbolSearch to combine Semantic Search with PageRank to find the most relevant symbols to the query.
The overview documentation of these symbols is then concated to the result of the SymbolSimilarity query to create a context.
For instance, if a query reads 'Tell me about SymbolRank', it will find the most similar document to this query from the embeddings,
which in this case would be the documentation for the SymbolRank class.
Then, it will use SymbolSearch to fetch some of the most relevant symbols which would be 'Symbol', 'SymbolSearch', 'SymbolGraph', etc.
This results in a comprehensive context for the query.
"""
),
return_direct=True,
)
]
return tools

def _context_oracle_processor(self, query: str) -> str:
"""
The context oracle tool processor function.
Args:
query (str): The query string.
Returns:
str: The processed result.
"""
doc_output = self.symbol_doc_similarity.get_query_similarity_dict(query)
rank_output = self.symbol_search.symbol_rank_search(query)

result = self.symbol_doc_similarity.embedding_handler.get_embedding(
sorted(doc_output.items(), key=lambda x: -x[1])[0][0]
).embedding_source

for symbol, _ in rank_output[0:10]:
try:
print("Processing symbol = ", symbol)
result += "%s\n" % symbol.dotpath
result += self.symbol_doc_similarity.embedding_handler.get_embedding(
symbol
).summary
except Exception as e:
print("Exception = ", e)
continue
return result
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import logging
from typing import List, Optional

from automata.core.agent.tool_management.base_tool_manager import BaseToolManager
from automata.core.agent.tools.agent_tool import AgentTool
from automata.core.base.tool import Tool
from automata.core.coding.py_coding.py_utils import NO_RESULT_FOUND_STR
from automata.core.coding.py_coding.retriever import PyCodeRetriever

logger = logging.getLogger(__name__)


class PyCodeRetrieverToolManager(BaseToolManager):
class PyCodeRetrieverTool(AgentTool):
"""
PyCodeRetrieverToolManager
PyCodeRetrieverTool
A class for interacting with the PythonIndexer API, which provides functionality to read
the code state of a of local Python files.
"""
Expand All @@ -32,7 +32,7 @@ def __init__(self, **kwargs):
self.verbose = kwargs.get("verbose") or False
self.stream = kwargs.get("stream") or True

def build_tools(self) -> List[Tool]:
def build(self) -> List[Tool]:
"""Builds a list of Tool objects for interacting with PythonIndexer."""
tools = [
Tool(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from automata.core.base.tool import Tool
from automata.core.coding.py_coding.writer import PyCodeWriter

from .base_tool_manager import BaseToolManager
from .agent_tool import AgentTool

logger = logging.getLogger(__name__)


class PyCodeWriterToolManager(BaseToolManager):
class PyCodeWriterTool(AgentTool):
"""
PyCodeWriterToolManager
PyCodeWriterTool
A class for interacting with the PythonWriter API, which provides functionality to modify
the code state of a given directory of Python files.
"""
Expand All @@ -22,7 +22,7 @@ def __init__(
**kwargs,
):
"""
Initializes a PyCodeWriterToolManager object with the given inputs.
Initializes a PyCodeWriterTool object with the given inputs.
Args:
- writer (PythonWriter): A PythonWriter object representing the code writer to work with.
Expand All @@ -41,7 +41,7 @@ def __init__(
self.temperature = kwargs.get("temperature", 0.7)
self.do_write = kwargs.get("do_write", True)

def build_tools(self) -> List[Tool]:
def build(self) -> List[Tool]:
"""Builds a list of Tool object for interacting with PythonWriter."""
tools = [
Tool(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class SearchTool(Enum):
EXACT_SEARCH = "exact-search"


class SymbolSearchToolManager:
class SymbolSearchTool:
def __init__(
self,
symbol_search: SymbolSearch,
Expand Down Expand Up @@ -50,13 +50,13 @@ def build_tool(self, tool_type: SearchTool) -> Tool:
)
raise ValueError(f"Invalid tool type: {tool_type}")

def build_tools(self) -> List[Tool]:
def build(self) -> List[Tool]:
return [self.build_tool(tool_type) for tool_type in self.search_tools]

def process_query(
self, tool_type: SearchTool, query: str
) -> Union[SymbolReferencesResult, SymbolRankResult, SourceCodeResult, ExactSearchResult,]:
tools_dict = {tool.name: tool.func for tool in self.build_tools()}
tools_dict = {tool.name: tool.func for tool in self.build()}
result = tools_dict[tool_type.value](query)
if self.post_processing:
result = self.post_processing(result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Dict, List

from automata.config.config_types import ConfigCategory
from automata.core.agent.tool_management.base_tool_manager import BaseToolManager
from automata.core.agent.tools.agent_tool import AgentTool
from automata.core.base.tool import Tool, Toolkit, ToolkitType
from automata.core.coding.py_coding.retriever import PyCodeRetriever
from automata.core.coding.py_coding.writer import PyCodeWriter
Expand All @@ -28,31 +28,29 @@ class ToolManagerFactory:
_retriever_instance = None # store instance of PyCodeRetriever

@staticmethod
def create_tool_manager(toolkit_type: ToolkitType) -> BaseToolManager:
def create_tool_manager(toolkit_type: ToolkitType) -> AgentTool:
if toolkit_type == ToolkitType.PYTHON_RETRIEVER:
if ToolManagerFactory._retriever_instance is None:
ToolManagerFactory._retriever_instance = PyCodeRetriever()

PyCodeRetrieverToolManager = importlib.import_module(
"automata.core.agent.tool_management.python_code_retriever_tool_manager"
).PyCodeRetrieverToolManager
return PyCodeRetrieverToolManager(
python_retriever=ToolManagerFactory._retriever_instance
)
PyCodeRetrieverTool = importlib.import_module(
"automata.core.agent.tools.py_code_retriever"
).PyCodeRetrieverTool
return PyCodeRetrieverTool(python_retriever=ToolManagerFactory._retriever_instance)
elif toolkit_type == ToolkitType.PYTHON_WRITER:
if ToolManagerFactory._retriever_instance is None:
ToolManagerFactory._retriever_instance = PyCodeRetriever()

PyCodeWriterToolManager = importlib.import_module(
"automata.core.agent.tool_management.python_code_writer_tool_manager"
).PyCodeWriterToolManager
return PyCodeWriterToolManager(
PyCodeWriterTool = importlib.import_module(
"automata.core.agent.tools.py_code_writer"
).PyCodeWriterTool
return PyCodeWriterTool(
python_writer=PyCodeWriter(ToolManagerFactory._retriever_instance)
)
elif toolkit_type == ToolkitType.SYMBOL_SEARCHER:
SymbolSearchToolManager = importlib.import_module(
"automata.core.agent.tool_management.symbol_search_manager"
).SymbolSearchToolManager
SymbolSearchTool = importlib.import_module(
"automata.core.agent.tools.symbol_search_manager"
).SymbolSearchTool

graph = SymbolGraph()
subgraph = graph.get_rankable_symbol_subgraph()
Expand All @@ -72,7 +70,7 @@ def create_tool_manager(toolkit_type: ToolkitType) -> BaseToolManager:
symbol_rank_config=SymbolRankConfig(),
code_subgraph=subgraph,
)
return SymbolSearchToolManager(symbol_search=symbol_search)
return SymbolSearchTool(symbol_search=symbol_search)
else:
raise ValueError("Unknown toolkit type: %s" % toolkit_type)

Expand All @@ -81,21 +79,21 @@ class ToolkitBuilder:
def __init__(self, **kwargs):
"""Initializes a ToolkitBuilder object with the given inputs."""

self._tool_management: Dict[ToolkitType, BaseToolManager] = {}
self._tool_management: Dict[ToolkitType, AgentTool] = {}

def _build_toolkit(self, toolkit_type: ToolkitType) -> Toolkit:
"""Builds a toolkit of the given type."""
tool_manager = ToolManagerFactory.create_tool_manager(toolkit_type)

if not tool_manager:
raise ValueError("Unknown toolkit type: %s" % toolkit_type)
tools = ToolkitBuilder.build_tools(tool_manager)
tools = ToolkitBuilder.build(tool_manager)
return Toolkit(tools)

@staticmethod
def build_tools(tool_manager: BaseToolManager) -> List[Tool]:
def build(tool_manager: AgentTool) -> List[Tool]:
"""Build tools from a tool manager."""
return tool_manager.build_tools()
return tool_manager.build()


def build_llm_toolkits(tool_list: List[str], **kwargs) -> Dict[ToolkitType, Toolkit]:
Expand Down
10 changes: 0 additions & 10 deletions automata/core/embedding/code_embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from typing import List

from automata.core.database.vector import VectorDatabaseProvider
from automata.core.symbol.symbol_types import Symbol, SymbolCodeEmbedding
Expand Down Expand Up @@ -98,12 +97,3 @@ def update_existing_embedding(self, source_code: str, symbol: Symbol):
else:
logger.debug("Passing for %s", symbol)
pass

def get_all_supported_symbols(self) -> List[Symbol]:
"""
Get all the symbols in the database.
Returns:
List[Symbol]: List of all the symbols in the database
"""
return self.embedding_db.get_all_symbols()
12 changes: 11 additions & 1 deletion automata/core/embedding/embedding_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import logging
from enum import Enum
from typing import Any, Dict
from typing import Any, Dict, List

import numpy as np
import openai
Expand Down Expand Up @@ -53,6 +53,7 @@ def build_embedding(self, symbol_source: str) -> np.ndarray:
class SymbolEmbeddingHandler(abc.ABC):
"""An abstract class to handle the embedding of symbols"""

@abc.abstractmethod
def __init__(
self,
embedding_db: VectorDatabaseProvider,
Expand All @@ -72,6 +73,15 @@ def update_embedding(self, symbol: Symbol):
"""An abstract method to update the embedding for a symbol"""
pass

def get_all_supported_symbols(self) -> List[Symbol]:
"""
Get all the symbols in the database.
Returns:
List[Symbol]: List of all the symbols in the database
"""
return self.embedding_db.get_all_symbols()


class EmbeddingSimilarity(abc.ABC):
@abc.abstractmethod
Expand Down
Loading

0 comments on commit b4bebcc

Please sign in to comment.