Skip to content
This repository has been archived by the owner on Mar 16, 2024. It is now read-only.

Commit

Permalink
'Refactored by Sourcery' (#229)
Browse files Browse the repository at this point in the history
Co-authored-by: Sourcery AI <>
  • Loading branch information
sourcery-ai[bot] authored Jul 17, 2023
1 parent 88b84d0 commit dee1360
Show file tree
Hide file tree
Showing 14 changed files with 36 additions and 55 deletions.
3 changes: 1 addition & 2 deletions automata/code_parsers/py/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def get_docstring_from_node(node: Optional[AST]) -> str:
return AST_NO_RESULT_FOUND

elif isinstance(node, (AsyncFunctionDef, ClassDef, FunctionDef, Module)):
doc_string = get_docstring(node)
if doc_string:
if doc_string := get_docstring(node):
return doc_string.replace('"""', "").replace("'''", "")
else:
return AST_NO_RESULT_FOUND
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,10 @@ def _process_classes_and_methods(
class_header = "\n".join(decorators + [class_header])
interface += self.process_entry(f"{class_header}")

# Handle class attributes
attributes = [
if attributes := [
f"{unparse(a.target)}: {unparse(a.annotation)}"
for a in get_all_attributes(cls)
]
if attributes:
]:
interface += "\n".join(attributes)

with self.increased_indentation():
Expand Down
4 changes: 1 addition & 3 deletions automata/code_parsers/py/doc_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ def generate_index_files(self, docs_dir: str) -> None:
+ existing_content[end_idx:]
)

# Add new auto-generated content
auto_content = auto_start_marker
auto_content += " .. toctree::\n"
auto_content = auto_start_marker + " .. toctree::\n"
auto_content += (
" :maxdepth: 2\n\n"
if not root_dir_node or root_dir_node.is_root_dir() # type: ignore
Expand Down
10 changes: 4 additions & 6 deletions automata/core/base/patterns/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@ class Singleton(abc.ABCMeta, type):

_instances: Dict[str, Any] = {}

def __call__(cls, *args, **kwargs):
def __call__(self, *args, **kwargs):
"""Call method for the singleton metaclass."""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(
*args, **kwargs
)
return cls._instances[cls]
if self not in self._instances:
self._instances[self] = super(Singleton, self).__call__(*args, **kwargs)
return self._instances[self]
6 changes: 2 additions & 4 deletions automata/experimental/search/rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ def _prepare_query_to_symbol_similarity(
"""
if query_to_symbol_similarity is None:
return {k: 1.0 / node_count for k in stochastic_graph}
missing = set(self.graph) - set(query_to_symbol_similarity)
if missing:
if missing := set(self.graph) - set(query_to_symbol_similarity):
raise NetworkXError(
f"query_to_symbol_similarity dictionary must have a value for every node. Missing {len(missing)} nodes."
)
Expand All @@ -197,8 +196,7 @@ def _prepare_dangling_weights(
) -> Dict[Symbol, float]:
if dangling is None:
return query_to_symbol_similarity
missing = set(self.graph) - set(dangling)
if missing:
if missing := set(self.graph) - set(dangling):
raise NetworkXError(
f"Dangling node dictionary must have a value for every node. Missing nodes {missing}"
)
Expand Down
5 changes: 2 additions & 3 deletions automata/experimental/search/symbol_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,11 @@ def _find_pattern_in_modules(self, pattern: str) -> Dict[str, List[int]]:
for module_path, module in py_module_loader.items():
if module:
lines = py_ast_unparse(module).splitlines()
line_numbers = [
if line_numbers := [
i + 1
for i, line in enumerate(lines)
if pattern in line.strip()
]
if line_numbers:
]:
matches[module_path] = line_numbers
return matches

Expand Down
9 changes: 4 additions & 5 deletions automata/singletons/dependency_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,12 @@ def get(self, dependency: str) -> Any:
return self._instances[dependency]

method_name = f"create_{dependency}"
if hasattr(self, method_name):
creation_method = getattr(self, method_name)
logger.info(f"Creating dependency {dependency}")
instance = creation_method()
else:
if not hasattr(self, method_name):
raise AgentGeneralError(f"Dependency {dependency} not found.")

creation_method = getattr(self, method_name)
logger.info(f"Creating dependency {dependency}")
instance = creation_method()
self._instances[dependency] = instance

# Perform synchronization
Expand Down
2 changes: 1 addition & 1 deletion automata/symbol/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def accept_character(self, r: str, what: str):

@staticmethod
def is_identifier_character(c: str) -> bool:
return c.isalpha() or c.isdigit() or c in ["-", "+", "$", "_"]
return c.isalpha() or c.isdigit() or c in {"-", "+", "$", "_"}


def parse_symbol(symbol_uri: str) -> Symbol:
Expand Down
21 changes: 8 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def temp_output_vector_dir():
# The TemporaryDirectory context manager should already clean up the directory,
# but just in case it doesn't (e.g. due to an error), we'll try removing it manually as well.
try:
shutil.rmtree(filename + "/")
shutil.rmtree(f"{filename}/")
except OSError:
pass

Expand All @@ -57,7 +57,7 @@ def temp_output_filename():
# The TemporaryDirectory context manager should already clean up the directory,
# but just in case it doesn't (e.g. due to an error), we'll try removing it manually as well.
try:
shutil.rmtree(filename + "/")
shutil.rmtree(f"{filename}/")
except OSError:
pass

Expand All @@ -71,7 +71,7 @@ def symbols():
These symbols at one point reflected existing code
but they are not guaranteed to be up to date.
"""
symbols = [
return [
# Symbol with a simple attribute
parse_symbol(
"scip-python python automata v0.0.0 `config.automata_agent_config`/AutomataAgentConfig#description."
Expand Down Expand Up @@ -106,8 +106,6 @@ def symbols():
),
]

return symbols


EXAMPLE_SYMBOL_PREFIX = (
"scip-python python automata v0.0.0 `config.automata_agent_config`/"
Expand Down Expand Up @@ -190,17 +188,17 @@ def automata_agent(mocker, automata_agent_config_builder):
"""Creates a mock AutomataAgent object for testing"""

llm_toolkits_list = ["context-oracle"]
kwargs = {}

dependencies: Set[Any] = set()
for tool in llm_toolkits_list:
for dependency_name, _ in AgentToolFactory.TOOLKIT_TYPE_TO_ARGS[
AgentToolkitNames(tool)
]:
dependencies.add(dependency_name)

for dependency in dependencies:
kwargs[dependency] = dependency_factory.get(dependency)
kwargs = {
dependency: dependency_factory.get(dependency)
for dependency in dependencies
}
tools = AgentToolFactory.build_tools(["context-oracle"], **kwargs)

instructions = "Test instruction."
Expand Down Expand Up @@ -267,10 +265,7 @@ def environment():
@pytest.fixture
def registry(task):
def mock_get_tasks_by_query(query, params):
if params[0] == task.task_id:
return [task]
else:
return []
return [task] if params[0] == task.task_id else []

db = MagicMock()
db.get_tasks_by_query.side_effect = (
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sample_modules/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def sample_function(name):
"""This is a sample function."""
return f"Hello, {name}! Sqrt(2) = " + str(math.sqrt(2))
return f"Hello, {name}! Sqrt(2) = {str(math.sqrt(2))}"


class Person:
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/singletons/test_singletons_dependency_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,6 @@ def test_reset(dependency_factory):

dependency_factory.reset()

assert dependency_factory._class_cache == {}
assert dependency_factory._instances == {}
assert dependency_factory.overrides == {}
assert not dependency_factory._class_cache
assert not dependency_factory._instances
assert not dependency_factory.overrides
6 changes: 3 additions & 3 deletions tests/unit/singletons/tests_singletons_dependency_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,6 @@ def test_reset(dependency_factory):

dependency_factory.reset()

assert dependency_factory._class_cache == {}
assert dependency_factory._instances == {}
assert dependency_factory.overrides == {}
assert not dependency_factory._class_cache
assert not dependency_factory._instances
assert not dependency_factory.overrides
9 changes: 3 additions & 6 deletions tests/unit/symbol/test_symbol_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,19 @@

@pytest.fixture
def cem(mock_provider, mock_db):
cem = SymbolCodeEmbeddingHandler(
return SymbolCodeEmbeddingHandler(
embedding_builder=mock_provider, embedding_db=mock_db
)
return cem


@pytest.fixture
def mock_db():
mock_db = MagicMock(ChromaSymbolEmbeddingVectorDatabase)
return mock_db
return MagicMock(ChromaSymbolEmbeddingVectorDatabase)


@pytest.fixture
def mock_provider():
mock_provider = Mock(EmbeddingBuilder)
return mock_provider
return Mock(EmbeddingBuilder)


def test_update_embeddings(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/symbol/test_symbol_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_is_global_symbol(symbols):

def test_is_local_symbol(symbols):
for symbol in symbols:
assert is_local_symbol("local " + symbol.uri)
assert is_local_symbol(f"local {symbol.uri}")


def _unparse(symbol: Symbol):
Expand Down

0 comments on commit dee1360

Please sign in to comment.