Skip to content

Commit

Permalink
Merge pull request emrgnt-cmplxty#10 from EmergentAGI/feature/udpate-…
Browse files Browse the repository at this point in the history
…run-embedding

update run_embedding
  • Loading branch information
emrgnt-cmplxty authored Jun 13, 2023
2 parents f07abbb + 8311eeb commit 96dbd3c
Show file tree
Hide file tree
Showing 24 changed files with 95 additions and 102 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
**/sample_modules/**
.env
local
notebooks/
# Local Cruft
automata_docs.egg-info/
**/**/.DS_Store
Expand Down
57 changes: 27 additions & 30 deletions automata_docs/cli/scripts/run_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from tqdm import tqdm

from automata_docs.configs.config_enums import ConfigCategory
from automata_docs.core.embedding.symbol_embedding import SymbolEmbeddingMap
from automata_docs.core.database.vector import JSONVectorDatabase
from automata_docs.core.embedding.symbol_embedding import SymbolCodeEmbeddingHandler
from automata_docs.core.symbol.symbol_graph import SymbolGraph
from automata_docs.core.symbol.symbol_utils import get_rankable_symbols
from automata_docs.core.utils import config_path
Expand All @@ -17,37 +18,27 @@ def main(*args, **kwargs):
"""
Update the distance embedding based on the symbols present in the system.
"""
scip_path = os.path.join(config_path(), ConfigCategory.SYMBOLS.value, "index.scip")
scip_path = os.path.join(
config_path(), ConfigCategory.SYMBOLS.value, kwargs.get("index_file", "index.scip")
)
embedding_path = os.path.join(
config_path(), ConfigCategory.SYMBOLS.value, "symbol_embedding.json"
config_path(),
ConfigCategory.SYMBOLS.value,
kwargs.get("embedding_file", "symbol_embedding.json"),
)

symbol_graph = SymbolGraph(scip_path)

if kwargs.get("update_embedding_map") or kwargs.get("build_new_embedding_map"):
all_defined_symbols = symbol_graph.get_all_available_symbols()
filtered_symbols = get_rankable_symbols(all_defined_symbols)
chunks = [
filtered_symbols[i : i + CHUNK_SIZE]
for i in range(0, len(filtered_symbols), CHUNK_SIZE)
]

for chunk in tqdm(chunks):
if kwargs.get("build_new_embedding_map") and chunk == chunks[0]:
symbol_embedding = SymbolEmbeddingMap(
all_defined_symbols=chunk,
build_new_embedding_map=True,
embedding_path=embedding_path,
)
else:
symbol_embedding = SymbolEmbeddingMap(
load_embedding_map=True,
embedding_path=embedding_path,
)
symbol_embedding.update_embeddings(chunk)

symbol_embedding.save(embedding_path, overwrite=True)
return "Success"
all_defined_symbols = symbol_graph.get_all_available_symbols()
filtered_symbols = sorted(get_rankable_symbols(all_defined_symbols), key=lambda x: x.path)

embedding_db = JSONVectorDatabase(embedding_path)
embedding_handler = SymbolCodeEmbeddingHandler(embedding_db)

for symbol in tqdm(filtered_symbols):
embedding_handler.update_embedding(symbol)
embedding_db.save()
return "Success"


if __name__ == "__main__":
Expand All @@ -61,9 +52,15 @@ def main(*args, **kwargs):

# Add the arguments
parser.add_argument(
"--update_embedding_map",
action="store_true",
help="Flag to update the embedding map.",
"--index_file",
default="index.scip",
help="Which index file to use for the embedding modifications.",
)
# Add the arguments
parser.add_argument(
"--embedding_file",
default="symbol_embedding.json",
help="Which embedding file to save to.",
)

parser.add_argument(
Expand Down
Binary file modified automata_docs/configs/symbols/index.scip
Binary file not shown.
1 change: 0 additions & 1 deletion automata_docs/configs/symbols/symbol_documentation.json

This file was deleted.

2 changes: 1 addition & 1 deletion automata_docs/configs/symbols/symbol_embedding.json

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions automata_docs/core/embedding/symbol_embedding.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import abc
import logging
from typing import Dict, List, Optional

import numpy as np
import openai
from typing import List, Optional

from automata_docs.core.database.vector import VectorDatabaseProvider
from automata_docs.core.symbol.symbol_types import Symbol, SymbolCodeEmbedding, SymbolEmbedding
Expand Down Expand Up @@ -66,12 +63,17 @@ def update_embedding(self, symbol: Symbol) -> None:
convert_to_fst_object,
)

desc_path_to_symbol = {
".".join([desc.name for desc in symbol.descriptors]): symbol
for symbol in self.embedding_db.get_all_symbols()
}
try:
symbol_source = str(convert_to_fst_object(symbol))
symbol_desc_identifier = ".".join([desc.name for desc in symbol.descriptors])

if self.embedding_db.contains(symbol):
existing_embedding = self.embedding_db.get(symbol)
symbol_source = str(convert_to_fst_object(symbol))
if symbol_desc_identifier in desc_path_to_symbol:
existing_embedding = self.embedding_db.get(
desc_path_to_symbol[symbol_desc_identifier]
)

if isinstance(existing_embedding, SymbolCodeEmbedding):
# If the symbol is already in the embedding map, check if the source code is the same
Expand All @@ -81,7 +83,7 @@ def update_embedding(self, symbol: Symbol) -> None:
new_embedding = self.embedding_provider.build_embedding(symbol_source)
existing_embedding.vector = new_embedding
existing_embedding.source_code = symbol_source

print("calling update...")
# Update the embedding in the database
self.embedding_db.update(existing_embedding)
else:
Expand Down
42 changes: 21 additions & 21 deletions automata_docs/core/indexing/python_indexing/module_tree_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DOT_SEP,
convert_fpath_to_module_dotpath,
)
from automata_docs.core.utils import root_path, root_py_path
from automata_docs.core.utils import root_path

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,25 +82,6 @@ def put_module(self, module_dotpath: str, module: RedBaron):
self._loaded_modules[module_dotpath] = module
self._dotpath_map.put_module(module_dotpath)

@staticmethod
def _load_module_from_fpath(path) -> Optional[RedBaron]:
"""
Loads and returns an FST object for the given file path.
Args:
path (str): The file path of the Python source code.
Returns:
Module: RedBaron FST object.
"""

try:
module = RedBaron(open(path).read())
return module
except Exception as e:
logger.error(f"Failed to load module '{path}' due to: {e}")
return None

def get_existing_module_dotpath(self, module_obj: RedBaron) -> Optional[str]:
"""
Returns the module dotpath for the specified module object.
Expand Down Expand Up @@ -150,4 +131,23 @@ def __contains__(self, item):
@classmethod
@lru_cache(maxsize=1)
def cached_default(cls) -> "LazyModuleTreeMap":
return cls(root_py_path())
return cls(root_path())

@staticmethod
def _load_module_from_fpath(path) -> Optional[RedBaron]:
"""
Loads and returns an FST object for the given file path.
Args:
path (str): The file path of the Python source code.
Returns:
Module: RedBaron FST object.
"""

try:
module = RedBaron(open(path).read())
return module
except Exception as e:
logger.error(f"Failed to load module '{path}' due to: {e}")
return None
20 changes: 10 additions & 10 deletions automata_docs/core/symbol/search/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,41 @@ def symbols():
symbols = [
# Symbol with a simple attribute
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#description."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#description."
),
# Symbol with a method with foreign argument
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#load().(config_name)"
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#load().(config_name)"
),
# Symbol with a class method, self as argument
# parse_symbol(
# "scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `tools.python_tools.python_ast_indexer`/PythonASTIndexer#get_module_path().(self)"
# "scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `tools.python_tools.python_ast_indexer`/PythonASTIndexer#get_module_path().(self)"
# ),
# Symbol with a locally defined object
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.tasks.automata_task_executor`/logger."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.tasks.automata_task_executor`/logger."
),
# Symbol with a class object and class variable
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#verbose."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#verbose."
),
# Symbol with a function in a module
# parse_symbol("scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.coordinator.tests.test_automata_coordinator`/test().(coordinator)"),
# parse_symbol("scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.coordinator.tests.test_automata_coordinator`/test().(coordinator)"),
# Symbol with a class method
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `evals.eval_helpers`/EvalAction#__init__().(action)"
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `evals.eval_helpers`/EvalAction#__init__().(action)"
),
# Symbol with an object
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.agent.automata_agent_enums`/ActionIndicator#CODE."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.agent.automata_agent_enums`/ActionIndicator#CODE."
),
# Class Name
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.agent.automata_agent_enums`/ActionIndicator#"
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.agent.automata_agent_enums`/ActionIndicator#"
),
# Init
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.base.tool`/ToolNotFoundError#__init__()."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.base.tool`/ToolNotFoundError#__init__()."
),
]

Expand Down
4 changes: 2 additions & 2 deletions automata_docs/core/symbol/symbol_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ class Symbol:
from automata_docs.core.symbol.search.symbol_parser import parse_symbol
symbol_class = parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `automata_docs.core.agent.automata_agent_enums`/ActionIndicator#"
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `automata_docs.core.agent.automata_agent_enums`/ActionIndicator#"
)
symbol_method = parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `automata_docs.core.base.tool`/ToolNotFoundError#__init__()."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `automata_docs.core.base.tool`/ToolNotFoundError#__init__()."
)
"""

Expand Down
7 changes: 5 additions & 2 deletions automata_docs/core/symbol/symbol_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import List, Optional

from redbaron import RedBaron

Expand Down Expand Up @@ -26,7 +26,10 @@ def convert_to_fst_object(
descriptors = list(symbol.descriptors)
obj = None
module_map = module_map or LazyModuleTreeMap.cached_default()

# print('descriptors = ', descriptors)
# print('module_map._dotpath_map = ', module_map._dotpath_map)
# print('module_map._dotpath_map._module_dotpath_to_fpath_map = ', module_map._dotpath_map._module_dotpath_to_fpath_map)
# print('module_map._loaded_modules = ', module_map._loaded_modules)
while descriptors:
top_descriptor = descriptors.pop(0)
if (
Expand Down
8 changes: 1 addition & 7 deletions automata_docs/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,7 @@ def get_logging_config(
log_level: int = logging.INFO, log_file: Optional[str] = None
) -> dict[str, Any]:
"""Returns logging configuration."""
color_scheme = {
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bold_red",
}

logging_config: LoggingConfig = {
"version": 1,
"disable_existing_loggers": False,
Expand Down
File renamed without changes.
22 changes: 11 additions & 11 deletions tests/unit/conftest.py → automata_docs/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def temp_output_filename():
os.remove(filename)


prefix = "scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/"
prefix = "scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/"


@pytest.fixture
Expand Down Expand Up @@ -59,41 +59,41 @@ def symbols():
symbols = [
# Symbol with a simple attribute
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#description."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#description."
),
# Symbol with a method with foreign argument
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#load().(config_name)"
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#load().(config_name)"
),
# Symbol with a class method, self as argument
# parse_symbol(
# "scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `tools.python_tools.python_ast_indexer`/PythonASTIndexer#get_module_path().(self)"
# "scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `tools.python_tools.python_ast_indexer`/PythonASTIndexer#get_module_path().(self)"
# ),
# Symbol with a locally defined object
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.tasks.automata_task_executor`/logger."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.tasks.automata_task_executor`/logger."
),
# Symbol with a class object and class variable
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#verbose."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `configs.automata_agent_configs`/AutomataAgentConfig#verbose."
),
# Symbol with a function in a module
# parse_symbol("scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.coordinator.tests.test_automata_coordinator`/test().(coordinator)"),
# parse_symbol("scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.coordinator.tests.test_automata_coordinator`/test().(coordinator)"),
# Symbol with a class method
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `evals.eval_helpers`/EvalAction#__init__().(action)"
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `evals.eval_helpers`/EvalAction#__init__().(action)"
),
# Symbol with an object
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.agent.automata_agent_enums`/ActionIndicator#CODE."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.agent.automata_agent_enums`/ActionIndicator#CODE."
),
# Class Name
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.agent.automata_agent_enums`/ActionIndicator#"
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.agent.automata_agent_enums`/ActionIndicator#"
),
# Init
parse_symbol(
"scip-python python automata 75482692a6fe30c72db516201a6f47d9fb4af065 `core.base.tool`/ToolNotFoundError#__init__()."
"scip-python python automata_docs 75482692a6fe30c72db516201a6f47d9fb4af065 `core.base.tool`/ToolNotFoundError#__init__()."
),
]

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ def test_update_embedding(monkeypatch, mock_simple_method_symbols):
)
cem.embedding_provider.build_embedding.return_value = [1, 2, 3, 4]
cem.embedding_db.contains = lambda x: True
cem.embedding_db.get_all_symbols = lambda: [mock_simple_method_symbols[0]]

cem.update_embedding(mock_simple_method_symbols[0])
embedding = cem.embedding_db.data[0].vector
print("cem.embedding_db = ", cem.embedding_db.data)
assert len(cem.embedding_db.data) == 1 # Expect empty embedding map because of exception
assert embedding == [1, 2, 3, 4]

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def test_parse_symbol(symbols):
for symbol in symbols:
assert symbol.scheme == "scip-python"
assert symbol.package.manager == "python"
assert symbol.package.name == "automata"
assert symbol.package.name == "automata_docs"
assert symbol.package.version == "75482692a6fe30c72db516201a6f47d9fb4af065"
assert len(symbol.descriptors) > 0

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock

import numpy as np

Expand All @@ -7,11 +7,9 @@
EmbeddingsProvider,
SymbolCodeEmbeddingHandler,
)
from automata_docs.core.embedding.symbol_similarity import NormType, SymbolSimilarity
from automata_docs.core.embedding.symbol_similarity import SymbolSimilarity
from automata_docs.core.symbol.symbol_types import SymbolCodeEmbedding

from .conftest import get_sem, patch_get_embedding


def test_get_nearest_symbols_for_query(
monkeypatch, mock_embedding, mock_simple_method_symbols, temp_output_filename
Expand Down
Loading

0 comments on commit 96dbd3c

Please sign in to comment.