diff --git a/agentstack/__init__.py b/agentstack/__init__.py index e69de29b..e645be5c 100644 --- a/agentstack/__init__.py +++ b/agentstack/__init__.py @@ -0,0 +1,8 @@ + + +class ValidationError(Exception): + """ + Raised when a validation error occurs ie. a file does not meet the required + format or a syntax error is found. + """ + pass \ No newline at end of file diff --git a/agentstack/agents.py b/agentstack/agents.py new file mode 100644 index 00000000..8040f26f --- /dev/null +++ b/agentstack/agents.py @@ -0,0 +1,98 @@ +from typing import Optional +import os +from pathlib import Path +import pydantic +from ruamel.yaml import YAML, YAMLError +from ruamel.yaml.scalarstring import FoldedScalarString +from agentstack import ValidationError + + +AGENTS_FILENAME: Path = Path("src/config/agents.yaml") + +yaml = YAML() +yaml.preserve_quotes = True # Preserve quotes in existing data + + +class AgentConfig(pydantic.BaseModel): + """ + Interface for interacting with an agent configuration. + + Multiple agents are stored in a single YAML file, so we always look up the + requested agent by `name`. + + Use it as a context manager to make and save edits: + ```python + with AgentConfig('agent_name') as config: + config.llm = "openai/gpt-4o" + + Config Schema + ------------- + name: str + The name of the agent; used for lookup. + role: Optional[str] + The role of the agent. + goal: Optional[str] + The goal of the agent. + backstory: Optional[str] + The backstory of the agent. + llm: Optional[str] + The model this agent should use. + Adheres to the format set by the framework. + """ + + name: str + role: Optional[str] = "" + goal: Optional[str] = "" + backstory: Optional[str] = "" + llm: Optional[str] = "" + + def __init__(self, name: str, path: Optional[Path] = None): + if not path: + path = Path() + + filename = path / AGENTS_FILENAME + if not os.path.exists(filename): + os.makedirs(filename.parent, exist_ok=True) + filename.touch() + + try: + with open(filename, 'r') as f: + data = yaml.load(f) or {} + data = data.get(name, {}) or {} + super().__init__(**{**{'name': name}, **data}) + except YAMLError as e: + # TODO format MarkedYAMLError lines/messages + raise ValidationError(f"Error parsing agents file: {filename}\n{e}") + except pydantic.ValidationError as e: + error_str = "Error validating agent config:\n" + for error in e.errors(): + error_str += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n" + raise ValidationError(f"Error loading agent {name} from {filename}.\n{error_str}") + + # store the path *after* loading data + self._path = path + + def model_dump(self, *args, **kwargs) -> dict: + dump = super().model_dump(*args, **kwargs) + dump.pop('name') # name is the key, so keep it out of the data + # format these as FoldedScalarStrings + for key in ('role', 'goal', 'backstory'): + dump[key] = FoldedScalarString(dump.get(key) or "") + return {self.name: dump} + + def write(self): + filename = self._path / AGENTS_FILENAME + + with open(filename, 'r') as f: + data = yaml.load(f) or {} + + data.update(self.model_dump()) + + with open(filename, 'w') as f: + yaml.dump(data, f) + + def __enter__(self) -> 'AgentConfig': + return self + + def __exit__(self, *args): + self.write() diff --git a/agentstack/cli/__init__.py b/agentstack/cli/__init__.py index afd42af5..1a35e913 100644 --- a/agentstack/cli/__init__.py +++ b/agentstack/cli/__init__.py @@ -1 +1 @@ -from .cli import init_project_builder, list_tools, configure_default_model +from .cli import init_project_builder, list_tools, configure_default_model, run_project diff --git a/agentstack/cli/cli.py b/agentstack/cli/cli.py index 3fdea1b5..15b61676 100644 --- a/agentstack/cli/cli.py +++ b/agentstack/cli/cli.py @@ -4,6 +4,7 @@ import time from datetime import datetime from typing import Optional +from pathlib import Path import requests import itertools @@ -21,12 +22,15 @@ ) from agentstack.logger import log from agentstack.utils import get_package_path +from agentstack.tools import get_all_tools from agentstack.generation.files import ConfigFile -from agentstack.generation.tool_generation import get_all_tools -from agentstack import packaging, generation +from agentstack import frameworks +from agentstack import packaging +from agentstack import generation from agentstack.utils import open_json_file, term_color, is_snake_case from agentstack.update import AGENTSTACK_PACKAGE + PREFERRED_MODELS = [ 'openai/gpt-4o', 'anthropic/claude-3-5-sonnet', @@ -158,13 +162,32 @@ def configure_default_model(path: Optional[str] = None): ) if model == other_msg: # If the user selects "Other", prompt for a model name - print(f'A list of available models is available at: "https://docs.litellm.ai/docs/providers"') + print('A list of available models is available at: "https://docs.litellm.ai/docs/providers"') model = inquirer.text(message="Enter the model name") with ConfigFile(path) as agentstack_config: agentstack_config.default_model = model +def run_project(framework: str, path: str = ''): + """Validate that the project is ready to run and then run it.""" + if framework not in frameworks.SUPPORTED_FRAMEWORKS: + print(term_color(f"Framework {framework} is not supported by agentstack.", 'red')) + sys.exit(1) + + _path = Path(path) + + try: + frameworks.validate_project(framework, _path) + except frameworks.ValidationError as e: + print(term_color("Project validation failed:", 'red')) + print(e) + sys.exit(1) + + entrypoint = _path / frameworks.get_entrypoint_path(framework) + os.system(f'python {entrypoint}') + + def ask_framework() -> str: framework = "CrewAI" # framework = inquirer.list_input( diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py new file mode 100644 index 00000000..fef72bc6 --- /dev/null +++ b/agentstack/frameworks/__init__.py @@ -0,0 +1,116 @@ +from typing import Optional, Protocol +from types import ModuleType +from importlib import import_module +from pathlib import Path +from agentstack import ValidationError +from agentstack.tools import ToolConfig +from agentstack.agents import AgentConfig +from agentstack.tasks import TaskConfig + + +CREWAI = 'crewai' +SUPPORTED_FRAMEWORKS = [CREWAI, ] + +class FrameworkModule(Protocol): + """ + Protocol spec for a framework implementation module. + """ + ENTRYPOINT: Path + """ + Relative path to the entrypoint file for the framework in the user's project. + ie. `src/crewai.py` + """ + + def validate_project(self, path: Optional[Path] = None) -> None: + """ + Validate that a user's project is ready to run. + Raises a `ValidationError` if the project is not valid. + """ + ... + + def add_tool(self, tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None: + """ + Add a tool to an agent in the user's project. + """ + ... + + def remove_tool(self, tool: ToolConfig, agent_name: str, path: Optional[Path] = None) -> None: + """ + Remove a tool from an agent in user's project. + """ + ... + + def get_agent_names(self, path: Optional[Path] = None) -> list[str]: + """ + Get a list of agent names in the user's project. + """ + ... + + def add_agent(self, agent: AgentConfig, path: Optional[Path] = None) -> None: + """ + Add an agent to the user's project. + """ + ... + + def add_task(self, task: TaskConfig, path: Optional[Path] = None) -> None: + """ + Add a task to the user's project. + """ + ... + + +def get_framework_module(framework: str) -> FrameworkModule: + """ + Get the module for a framework. + """ + try: + return import_module(f".{framework}", package=__package__) + except ImportError: + raise Exception(f"Framework {framework} could not be imported.") + +def get_entrypoint_path(framework: str, path: Optional[Path] = None) -> Path: + """ + Get the path to the entrypoint file for a framework. + """ + if path is None: + path = Path() + return path / get_framework_module(framework).ENTRYPOINT + +def validate_project(framework: str, path: Optional[Path] = None): + """ + Validate that the user's project is ready to run. + """ + return get_framework_module(framework).validate_project(path) + +def add_tool(framework: str, tool: ToolConfig, agent_name: str, path: Optional[Path] = None): + """ + Add a tool to the user's project. + The tool will have aready been installed in the user's application and have + all dependencies installed. We're just handling code generation here. + """ + return get_framework_module(framework).add_tool(tool, agent_name, path) + +def remove_tool(framework: str, tool: ToolConfig, agent_name: str, path: Optional[Path] = None): + """ + Remove a tool from the user's project. + """ + return get_framework_module(framework).remove_tool(tool, agent_name, path) + +def get_agent_names(framework: str, path: Optional[Path] = None) -> list[str]: + """ + Get a list of agent names in the user's project. + """ + return get_framework_module(framework).get_agent_names(path) + +def add_agent(framework: str, agent: AgentConfig, path: Optional[Path] = None): + """ + Add an agent to the user's project. + """ + return get_framework_module(framework).add_agent(agent, path) + +def add_task(framework: str, task: TaskConfig, path: Optional[Path] = None): + """ + Add a task to the user's project. + """ + return get_framework_module(framework).add_task(task, path) + diff --git a/agentstack/frameworks/crewai.py b/agentstack/frameworks/crewai.py new file mode 100644 index 00000000..7c7877fc --- /dev/null +++ b/agentstack/frameworks/crewai.py @@ -0,0 +1,281 @@ +from typing import Optional +from pathlib import Path +import ast +from agentstack import ValidationError +from agentstack.tools import ToolConfig +from agentstack.tasks import TaskConfig +from agentstack.agents import AgentConfig +from agentstack.generation import asttools + + +ENTRYPOINT: Path = Path('src/crew.py') + + +class CrewFile(asttools.File): + """ + Parses and manipulates the CrewAI entrypoint file. + All AST interactions should happen within the methods of this class. + """ + + _base_class: Optional[ast.ClassDef] = None + + def get_base_class(self) -> ast.ClassDef: + """A base class is a class decorated with `@CrewBase`.""" + if self._base_class is None: # Gets cached to save repeat iteration + try: + self._base_class = asttools.find_class_with_decorator(self.tree, 'CrewBase')[0] + except IndexError: + raise ValidationError(f"`@CrewBase` decorated class not found in {ENTRYPOINT}") + return self._base_class + + def get_crew_method(self) -> ast.FunctionDef: + """A `crew` method is a method decorated with `@crew`.""" + try: + base_class = self.get_base_class() + return asttools.find_decorated_method_in_class(base_class, 'crew')[0] + except IndexError: + raise ValidationError( + f"`@crew` decorated method not found in `{base_class.name}` class in {ENTRYPOINT}" + ) + + def get_task_methods(self) -> list[ast.FunctionDef]: + """A `task` method is a method decorated with `@task`.""" + return asttools.find_decorated_method_in_class(self.get_base_class(), 'task') + + def add_task_method(self, task: TaskConfig): + """Add a new task method to the CrewAI entrypoint.""" + task_methods = self.get_task_methods() + if task.name in [method.name for method in task_methods]: + # TODO this should check all methods in the class for duplicates + raise ValidationError(f"Task `{task.name}` already exists in {ENTRYPOINT}") + if task_methods: + # Add after the existing task methods + _, pos = self.get_node_range(task_methods[-1]) + else: + # Add before the `crew` method + crew_method = self.get_crew_method() + pos, _ = self.get_node_range(crew_method) + + code = f""" @task + def {task.name}(self) -> Task: + return Task( + config=self.tasks_config['{task.name}'], + )""" + if not self.source[:pos].endswith('\n'): + code = '\n\n' + code + if not self.source[pos:].startswith('\n'): + code += '\n\n' + self.edit_node_range(pos, pos, code) + + def get_agent_methods(self) -> list[ast.FunctionDef]: + """An `agent` method is a method decorated with `@agent`.""" + return asttools.find_decorated_method_in_class(self.get_base_class(), 'agent') + + def add_agent_method(self, agent: AgentConfig): + """Add a new agent method to the CrewAI entrypoint.""" + # TODO do we want to pre-populate any tools? + agent_methods = self.get_agent_methods() + if agent.name in [method.name for method in agent_methods]: + # TODO this should check all methods in the class for duplicates + raise ValidationError(f"Agent `{agent.name}` already exists in {ENTRYPOINT}") + if agent_methods: + # Add after the existing agent methods + _, pos = self.get_node_range(agent_methods[-1]) + else: + # Add before the `crew` method + crew_method = self.get_crew_method() + pos, _ = self.get_node_range(crew_method) + + code = f""" @agent + def {agent.name}(self) -> Agent: + return Agent( + config=self.agents_config['{agent.name}'], + tools=[], # add tools here or use `agentstack tools add + verbose=True, + )""" + if not self.source[:pos].endswith('\n'): + code = '\n\n' + code + if not self.source[pos:].startswith('\n'): + code += '\n\n' + self.edit_node_range(pos, pos, code) + + def get_agent_tools(self, agent_name: str) -> ast.List: + """ + Get the tools used by an agent as AST nodes. + + Tool definitons are inside of the methods marked with an `@agent` decorator. + The method returns a new class instance with the tools as a list of callables + under the kwarg `tools`. + """ + method = asttools.find_method(self.get_agent_methods(), agent_name) + if method is None: + raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") + + agent_class = asttools.find_class_instantiation(method, 'Agent') + if agent_class is None: + raise ValidationError( + f"`@agent` method `{agent_name}` does not have an `Agent` class instantiation in {ENTRYPOINT}" + ) + + tools_kwarg = asttools.find_kwarg_in_method_call(agent_class, 'tools') + if not tools_kwarg: + raise ValidationError( + f"`@agent` method `{agent_name}` does not have a keyword argument `tools` in {ENTRYPOINT}" + ) + + if not isinstance(tools_kwarg.value, ast.List): + raise ValidationError( + f"`@agent` method `{agent_name}` has a non-list value for the `tools` kwarg in {ENTRYPOINT}" + ) + + return tools_kwarg.value + + def add_agent_tools(self, agent_name: str, tool: ToolConfig): + """ + Add new tools to be used by an agent. + + Tool definitons are inside of the methods marked with an `@agent` decorator. + The method returns a new class instance with the tools as a list of callables + under the kwarg `tools`. + """ + method = asttools.find_method(self.get_agent_methods(), agent_name) + if method is None: + raise ValidationError(f"`@agent` method `{agent_name}` does not exist in {ENTRYPOINT}") + + new_tool_nodes: set[ast.expr] = set() + for tool_name in tool.tools: + # This prefixes the tool name with the 'tools' module + node: ast.expr = asttools.create_attribute('tools', tool_name) + if tool.tools_bundled: # Splat the variable if it's bundled + node = ast.Starred(value=node, ctx=ast.Load()) + new_tool_nodes.add(node) + + existing_node: ast.List = self.get_agent_tools(agent_name) + elts: set[ast.expr] = set(existing_node.elts) | new_tool_nodes + new_node = ast.List(elts=list(elts), ctx=ast.Load()) + start, end = self.get_node_range(existing_node) + self.edit_node_range(start, end, new_node) + + def remove_agent_tools(self, agent_name: str, tool: ToolConfig): + """ + Remove tools from an agent belonging to `tool`. + """ + existing_node: ast.List = self.get_agent_tools(agent_name) + start, end = self.get_node_range(existing_node) + + # modify the existing node to remove any matching tools + for tool_name in tool.tools: + for node in existing_node.elts: + if isinstance(node, ast.Starred): + if isinstance(node.value, ast.Attribute): + attr_name = node.value.attr + else: + continue # not an attribute node + elif isinstance(node, ast.Attribute): + attr_name = node.attr + else: + continue # not an attribute node + if attr_name == tool_name: + existing_node.elts.remove(node) + + self.edit_node_range(start, end, existing_node) + + +def validate_project(path: Optional[Path] = None) -> None: + """ + Validate that a CrewAI project is ready to run. + Raises an `agentstack.VaidationError` if the project is not valid. + """ + if path is None: + path = Path() + try: + crew_file = CrewFile(path / ENTRYPOINT) + except ValidationError as e: + raise e + + # A valid project must have a class in the crew.py file decorated with `@CrewBase` + try: + class_node = crew_file.get_base_class() + except ValidationError as e: + raise e + + # The Crew class must have one method decorated with `@crew` + try: + crew_file.get_crew_method() + except ValidationError as e: + raise e + + # The Crew class must have one or more methods decorated with `@task` + if len(crew_file.get_task_methods()) < 1: + raise ValidationError( + f"`@task` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" + "Create a new task using `agentstack generate task `." + ) + + # The Crew class must have one or more methods decorated with `@agent` + if len(crew_file.get_agent_methods()) < 1: + raise ValidationError( + f"`@agent` decorated method not found in `{class_node.name}` class in {ENTRYPOINT}.\n" + "Create a new agent using `agentstack generate agent `." + ) + + +def get_task_names(path: Optional[Path] = None) -> list[str]: + """ + Get a list of task names (methods with an @task decorator). + """ + if path is None: + path = Path() + crew_file = CrewFile(path / ENTRYPOINT) + return [method.name for method in crew_file.get_task_methods()] + + +def add_task(task: TaskConfig, path: Optional[Path] = None) -> None: + """ + Add a task method to the CrewAI entrypoint. + """ + if path is None: + path = Path() + with CrewFile(path / ENTRYPOINT) as crew_file: + crew_file.add_task_method(task) + + +def get_agent_names(path: Optional[Path] = None) -> list[str]: + """ + Get a list of agent names (methods with an @agent decorator). + """ + if path is None: + path = Path() + crew_file = CrewFile(path / ENTRYPOINT) + return [method.name for method in crew_file.get_agent_methods()] + + +def add_agent(agent: AgentConfig, path: Optional[Path] = None) -> None: + """ + Add an agent method to the CrewAI entrypoint. + """ + if path is None: + path = Path() + with CrewFile(path / ENTRYPOINT) as crew_file: + crew_file.add_agent_method(agent) + + +def add_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): + """ + Add a tool to the CrewAI entrypoint for the specified agent. + The agent should already exist in the crew class and have a keyword argument `tools`. + """ + if path is None: + path = Path() + with CrewFile(path / ENTRYPOINT) as crew_file: + crew_file.add_agent_tools(agent_name, tool) + + +def remove_tool(tool: ToolConfig, agent_name: str, path: Optional[Path] = None): + """ + Remove a tool from the CrewAI framework for the specified agent. + """ + if path is None: + path = Path() + with CrewFile(path / ENTRYPOINT) as crew_file: + crew_file.remove_agent_tools(agent_name, tool) diff --git a/agentstack/generation/__init__.py b/agentstack/generation/__init__.py index 49d62c82..82e2eb56 100644 --- a/agentstack/generation/__init__.py +++ b/agentstack/generation/__init__.py @@ -1,4 +1,4 @@ -from .agent_generation import generate_agent, get_agent_names -from .task_generation import generate_task, get_task_names +from .agent_generation import add_agent +from .task_generation import add_task from .tool_generation import add_tool, remove_tool from .files import ConfigFile, EnvFile, CONFIG_FILENAME \ No newline at end of file diff --git a/agentstack/generation/agent_generation.py b/agentstack/generation/agent_generation.py index 8eb865f4..2bc03157 100644 --- a/agentstack/generation/agent_generation.py +++ b/agentstack/generation/agent_generation.py @@ -1,105 +1,39 @@ -from typing import Optional, List - -from .gen_utils import insert_code_after_tag, get_crew_components, CrewComponent -from agentstack.utils import verify_agentstack_project, get_framework +import sys +from typing import Optional +from pathlib import Path +from agentstack import ValidationError +from agentstack import frameworks +from agentstack.utils import verify_agentstack_project +from agentstack.agents import AgentConfig, AGENTS_FILENAME from agentstack.generation.files import ConfigFile -import os -from ruamel.yaml import YAML -from ruamel.yaml.scalarstring import FoldedScalarString -def generate_agent( - name, - role: Optional[str], - goal: Optional[str], - backstory: Optional[str], - llm: Optional[str], +def add_agent( + agent_name: str, + role: Optional[str] = None, + goal: Optional[str] = None, + backstory: Optional[str] = None, + llm: Optional[str] = None, + path: Optional[Path] = None, ): - agentstack_config = ConfigFile() # TODO path - if not role: - role = 'Add your role here' - if not goal: - goal = 'Add your goal here' - if not backstory: - backstory = 'Add your backstory here' - if not llm: - llm = agentstack_config.default_model - - verify_agentstack_project() - - framework = get_framework() - - if framework == 'crewai': - generate_crew_agent(name, role, goal, backstory, llm) - print(" > Added to src/config/agents.yaml") - else: - print(f"This function is not yet implemented for {framework}") - return - - print(f"Added agent \"{name}\" to your AgentStack project successfully!") - - -def generate_crew_agent( - name, - role: Optional[str] = 'Add your role here', - goal: Optional[str] = 'Add your goal here', - backstory: Optional[str] = 'Add your backstory here', - llm: Optional[str] = 'openai/gpt-4o', -): - config_path = os.path.join('src', 'config', 'agents.yaml') - - # Ensure the directory exists - os.makedirs(os.path.dirname(config_path), exist_ok=True) - - yaml = YAML() - yaml.preserve_quotes = True # Preserve quotes in existing data - - # Read existing data - if os.path.exists(config_path): - with open(config_path, 'r') as file: - try: - data = yaml.load(file) or {} - except Exception as exc: - print(f"Error parsing YAML file: {exc}") - data = {} - else: - data = {} - - # Handle None values - role_str = FoldedScalarString(role) if role else FoldedScalarString('') - goals_str = FoldedScalarString(goal) if goal else FoldedScalarString('') - backstory_str = FoldedScalarString(backstory) if backstory else FoldedScalarString('') - model_str = llm if llm else '' - - # Add new agent details - data[name] = { - 'role': role_str, - 'goal': goals_str, - 'backstory': backstory_str, - 'llm': model_str, - } - - # Write back to the file without altering existing content - with open(config_path, 'w') as file: - yaml.dump(data, file) - - # Now lets add the agent to crew.py - file_path = 'src/crew.py' - tag = '# Agent definitions' - code_to_insert = [ - "@agent", - f"def {name}(self) -> Agent:", - " return Agent(", - f" config=self.agents_config['{name}'],", - " tools=[], # add tools here or use `agentstack tools add ", # TODO: Add any tools in agentstack.json - " verbose=True", - " )", - "", - ] - - insert_code_after_tag(file_path, tag, code_to_insert) - - -def get_agent_names(framework: str = 'crewai', path: str = '') -> List[str]: - """Get only agent names from the crew file""" - return get_crew_components(framework, CrewComponent.AGENT, path)['agents'] + if path is None: + path = Path() + verify_agentstack_project(path) + agentstack_config = ConfigFile(path) + framework = agentstack_config.framework + + agent = AgentConfig(agent_name, path) + with agent as config: + config.role = role or "Add your role here" + config.goal = goal or "Add your goal here" + config.backstory = backstory or "Add your backstory here" + config.llm = llm or agentstack_config.default_model + + try: + frameworks.add_agent(framework, agent, path) + print(f" > Added to {AGENTS_FILENAME}") + except ValidationError as e: + print(f"Error adding agent to project:\n{e}") + sys.exit(1) + + print(f"Added agent \"{agent_name}\" to your AgentStack project successfully!") diff --git a/agentstack/generation/asttools.py b/agentstack/generation/asttools.py new file mode 100644 index 00000000..575e0403 --- /dev/null +++ b/agentstack/generation/asttools.py @@ -0,0 +1,177 @@ +""" +Tools for working with ASTs. + +We include convenience functions here based on real needs inside the codebase, +such as finding a method definition in a class, or finding a method by its decorator. + +It's not optimal to have a fully-featured set of functions as this would be +unwieldy, but since our use-cases are well-defined, we can provide a set of +functions that are useful for the specific tasks we need to accomplish. +""" + +from typing import TypeVar, Optional, Union, Iterable +from pathlib import Path +import ast +import astor +import asttokens +from agentstack import ValidationError + + +FileT = TypeVar('FileT', bound='File') +ASTT = TypeVar('ASTT', bound=ast.AST) + + +class File: + """ + Parses and manipulates a Python source file with an AST. + + Use it as a context manager to make and save edits: + ```python + with File(filename) as f: + f.edit_node_range(start, end, new_node) + ``` + + Lookups are done using the built-in `ast` module, which we only use to find + and read nodes in the tree. + + Edits are done using string indexing on the source code, which preserves a + majority of the original formatting and prevents comments from being lost. + + In cases where we are constructing new AST nodes, we use `astor` to render + the node as source code. + """ + + filename: Path + source: str + atok: asttokens.ASTTokens + tree: ast.Module + + def __init__(self, filename: Path): + self.filename = filename + self.read() + + def read(self): + try: + with open(self.filename, 'r') as f: + self.source = f.read() + self.atok = asttokens.ASTTokens(self.source, parse=True) + self.tree = self.atok.tree + except (FileNotFoundError, SyntaxError) as e: + raise ValidationError(f"Failed to parse {self.filename}\n{e}") + + def write(self): + with open(self.filename, 'w', encoding='utf-8') as f: + f.write(self.source) + + def get_node_range(self, node: ast.AST) -> tuple[int, int]: + """Get the string start and end indexes for a node in the source code.""" + return self.atok.get_text_range(node) + + def edit_node_range(self, start: int, end: int, node: Union[str, ast.AST]): + """Splice a new node or string into the source code at the given range.""" + if isinstance(node, ast.expr): + module = ast.Module(body=[ast.Expr(value=node)], type_ignores=[]) + _node = astor.to_source(module).strip() + else: + _node = node + + self.source = self.source[:start] + _node + self.source[end:] + # In order to continue accurately modifying the AST, we need to re-parse the source. + self.atok = asttokens.ASTTokens(self.source, parse=True) + + if self.atok.tree: + self.tree = self.atok.tree + else: + raise ValidationError(f"Failed to parse {self.filename} after edit") + + def __enter__(self: FileT) -> FileT: + return self + + def __exit__(self, *args): + self.write() + + +def get_all_imports(tree: ast.Module) -> list[ast.ImportFrom]: + """Find all import statements in an AST.""" + imports = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ImportFrom): # NOTE must be in format `from x import y` + imports.append(node) + return imports + + +def find_method(tree: Union[Iterable[ASTT], ASTT], method_name: str) -> Optional[ast.FunctionDef]: + """Find a method definition in an AST.""" + if isinstance(tree, ast.AST): + _tree = list(ast.iter_child_nodes(tree)) + else: + _tree = list(tree) + + for node in _tree: + if isinstance(node, ast.FunctionDef) and node.name == method_name: + return node + return None + + +def find_kwarg_in_method_call(node: ast.Call, kwarg_name: str) -> Optional[ast.keyword]: + """Find a keyword argument in a method call or class instantiation.""" + for arg in node.keywords: + if isinstance(arg, ast.keyword) and arg.arg == kwarg_name: + return arg + return None + + +def find_class_instantiation(tree: Union[Iterable[ast.AST], ast.AST], class_name: str) -> Optional[ast.Call]: + """ + Find a class instantiation statement in an AST by the class name. + This can either be an assignment to a variable or a return statement. + """ + if isinstance(tree, ast.AST): + _tree = list(ast.iter_child_nodes(tree)) + else: + _tree = list(tree) + + for node in _tree: + if isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and isinstance(node.value, ast.Call) + and target.id == class_name + ): + return node.value + elif ( + isinstance(node, ast.Return) + and isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == class_name + ): + return node.value + return None + + +def find_class_with_decorator(tree: ast.Module, decorator_name: str) -> list[ast.ClassDef]: + """Find a class definition that is marked by a decorator in an AST.""" + nodes = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.ClassDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == decorator_name: + nodes.append(node) + return nodes + + +def find_decorated_method_in_class(classdef: ast.ClassDef, decorator_name: str) -> list[ast.FunctionDef]: + """Find all method definitions in a class definition which are decorated with a specific decorator.""" + nodes = [] + for node in ast.iter_child_nodes(classdef): + if isinstance(node, ast.FunctionDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == decorator_name: + nodes.append(node) + return nodes + + +def create_attribute(base_name: str, attr_name: str) -> ast.Attribute: + """Create an AST node for an attribute""" + return ast.Attribute(value=ast.Name(id=base_name, ctx=ast.Load()), attr=attr_name, ctx=ast.Load()) diff --git a/agentstack/generation/gen_utils.py b/agentstack/generation/gen_utils.py index d4a0fbab..551e12cc 100644 --- a/agentstack/generation/gen_utils.py +++ b/agentstack/generation/gen_utils.py @@ -1,9 +1,4 @@ import ast -import sys -from enum import Enum -from typing import Optional, Union, List - -from agentstack.utils import term_color def insert_code_after_tag(file_path, tag, code_to_insert, next_line=False): @@ -73,71 +68,3 @@ def string_in_file(file_path: str, str_to_match: str) -> bool: with open(file_path, 'r') as file: file_content = file.read() return str_to_match in file_content - - -def _framework_filename(framework: str, path: str = ''): - if framework == 'crewai': - return f'{path}src/crew.py' - - print(term_color(f'Unknown framework: {framework}', 'red')) - sys.exit(1) - - -class CrewComponent(str, Enum): - AGENT = "agent" - TASK = "task" - - -def get_crew_components( - framework: str = 'crewai', - component_type: Optional[Union[CrewComponent, List[CrewComponent]]] = None, - path: str = '', -) -> dict[str, List[str]]: - """ - Get names of components (agents and/or tasks) defined in a crew file. - - Args: - framework: Name of the framework - component_type: Optional filter for specific component types. - Can be CrewComponentType.AGENT, CrewComponentType.TASK, - or a list of types. If None, returns all components. - path: Optional path to the framework file - - Returns: - Dictionary with 'agents' and 'tasks' keys containing lists of names - """ - filename = _framework_filename(framework, path) - - # Convert single component type to list for consistent handling - if isinstance(component_type, CrewComponent): - component_type = [component_type] - - # Read the source file - with open(filename, 'r') as f: - source = f.read() - - # Parse the source into an AST - tree = ast.parse(source) - - components = {'agents': [], 'tasks': []} - - # Find all function definitions with relevant decorators - for node in ast.walk(tree): - if isinstance(node, ast.FunctionDef): - # Check decorators - for decorator in node.decorator_list: - if isinstance(decorator, ast.Name): - if ( - component_type is None or CrewComponent.AGENT in component_type - ) and decorator.id == 'agent': - components['agents'].append(node.name) - elif ( - component_type is None or CrewComponent.TASK in component_type - ) and decorator.id == 'task': - components['tasks'].append(node.name) - - # If specific types were requested, only return those - if component_type: - return {k: v for k, v in components.items() if CrewComponent(k[:-1]) in component_type} - - return components diff --git a/agentstack/generation/task_generation.py b/agentstack/generation/task_generation.py index 99df034e..a6e1d662 100644 --- a/agentstack/generation/task_generation.py +++ b/agentstack/generation/task_generation.py @@ -1,94 +1,36 @@ -from typing import Optional, List - -from .gen_utils import insert_after_tasks, get_crew_components, CrewComponent -from ..utils import verify_agentstack_project, get_framework -import os -from ruamel.yaml import YAML -from ruamel.yaml.scalarstring import FoldedScalarString - - -def generate_task( - name, - description: Optional[str], - expected_output: Optional[str], - agent: Optional[str], -): - if not description: - description = 'Add your description here' - if not expected_output: - expected_output = 'Add your expected_output here' - if not agent: - agent = 'default_agent' - - verify_agentstack_project() - - framework = get_framework() - - if framework == 'crewai': - generate_crew_task(name, description, expected_output, agent) - print(" > Added to src/config/tasks.yaml") - else: - print(f"This function is not yet implemented for {framework}") - return - - print(f"Added task \"{name}\" to your AgentStack project successfully!") - - -def generate_crew_task( - name, - description: Optional[str], - expected_output: Optional[str], - agent: Optional[str], +import sys +from typing import Optional +from pathlib import Path +from agentstack import ValidationError +from agentstack import frameworks +from agentstack.utils import verify_agentstack_project +from agentstack.tasks import TaskConfig, TASKS_FILENAME +from agentstack.generation.files import ConfigFile + + +def add_task( + task_name: str, + description: Optional[str] = None, + expected_output: Optional[str] = None, + agent: Optional[str] = None, + path: Optional[Path] = None, ): - config_path = os.path.join('src', 'config', 'tasks.yaml') - - # Ensure the directory exists - os.makedirs(os.path.dirname(config_path), exist_ok=True) - - yaml = YAML() - yaml.preserve_quotes = True # Preserve quotes in existing data - - # Read existing data - if os.path.exists(config_path): - with open(config_path, 'r') as file: - try: - data = yaml.load(file) or {} - except Exception as exc: - print(f"Error parsing YAML file: {exc}") - data = {} - else: - data = {} - - # Handle None values - description_str = FoldedScalarString(description) if description else FoldedScalarString('') - expected_output_str = FoldedScalarString(expected_output) if expected_output else FoldedScalarString('') - agent_str = FoldedScalarString(agent) if agent else FoldedScalarString('') - - # Add new agent details - data[name] = { - 'description': description_str, - 'expected_output': expected_output_str, - 'agent': agent_str, - } - - # Write back to the file without altering existing content - with open(config_path, 'w') as file: - yaml.dump(data, file) - - # Add task to crew.py - file_path = 'src/crew.py' - code_to_insert = [ - "@task", - f"def {name}(self) -> Task:", - " return Task(", - f" config=self.tasks_config['{name}'],", - " )", - "", - ] - - insert_after_tasks(file_path, code_to_insert) - - -def get_task_names(framework: str, path: str = '') -> List[str]: - """Get only task names from the crew file""" - return get_crew_components(framework, CrewComponent.TASK, path)['tasks'] + if path is None: + path = Path() + verify_agentstack_project(path) + agentstack_config = ConfigFile(path) + framework = agentstack_config.framework + + task = TaskConfig(task_name, path) + with task as config: + config.description = description or "Add your description here" + config.expected_output = expected_output or "Add your expected_output here" + config.agent = agent or "agent_name" + + try: + frameworks.add_task(framework, task, path) + print(f" > Added to {TASKS_FILENAME}") + except ValidationError as e: + print(f"Error adding task to project:\n{e}") + sys.exit(1) + print(f"Added task \"{task_name}\" to your AgentStack project successfully!") diff --git a/agentstack/generation/tool_generation.py b/agentstack/generation/tool_generation.py index 38989181..4ecb2b21 100644 --- a/agentstack/generation/tool_generation.py +++ b/agentstack/generation/tool_generation.py @@ -1,411 +1,172 @@ import os import sys -from typing import Optional, List +from typing import Optional from pathlib import Path -from typing import Union - -from . import get_agent_names -from .gen_utils import insert_code_after_tag, _framework_filename -from ..utils import open_json_file, get_framework, term_color import shutil -import fileinput -import astor import ast -from pydantic import BaseModel, ValidationError +from agentstack import frameworks from agentstack import packaging -from agentstack.utils import get_package_path +from agentstack import ValidationError +from agentstack.utils import term_color +from agentstack.tools import ToolConfig +from agentstack.generation import asttools from agentstack.generation.files import ConfigFile, EnvFile -TOOL_INIT_FILENAME = "src/tools/__init__.py" -FRAMEWORK_FILENAMES: dict[str, str] = { - 'crewai': 'src/crew.py', -} - - -def get_framework_filename(framework: str, path: str = ''): - if path: - path = path.endswith('/') and path or path + '/' - else: - path = './' - try: - return f"{path}{FRAMEWORK_FILENAMES[framework]}" - except KeyError: - print(term_color(f'Unknown framework: {framework}', 'red')) - sys.exit(1) - - -class ToolConfig(BaseModel): - name: str - category: str - tools: list[str] - url: Optional[str] = None - tools_bundled: bool = False - cta: Optional[str] = None - env: Optional[dict] = None - packages: Optional[List[str]] = None - post_install: Optional[str] = None - post_remove: Optional[str] = None - - @classmethod - def from_tool_name(cls, name: str) -> 'ToolConfig': - path = get_package_path() / f'tools/{name}.json' - if not os.path.exists(path): - print(term_color(f'No known agentstack tool: {name}', 'red')) - sys.exit(1) - return cls.from_json(path) - - @classmethod - def from_json(cls, path: Path) -> 'ToolConfig': - data = open_json_file(path) - try: - return cls(**data) - except ValidationError as e: - print(term_color(f"Error validating tool config JSON: \n{path}", 'red')) - for error in e.errors(): - print(f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}") - sys.exit(1) +# This is the filename of the location of tool imports in the user's project. +TOOLS_INIT_FILENAME: Path = Path("src/tools/__init__.py") - def get_import_statement(self) -> str: - return f"from .{self.name}_tool import {', '.join(self.tools)}" - def get_impl_file_path(self, framework: str) -> Path: - return get_package_path() / f'templates/{framework}/tools/{self.name}_tool.py' - - -def get_all_tool_paths() -> list[Path]: - paths = [] - tools_dir = get_package_path() / 'tools' - for file in tools_dir.iterdir(): - if file.is_file() and file.suffix == '.json': - paths.append(file) - return paths - - -def get_all_tool_names() -> list[str]: - return [path.stem for path in get_all_tool_paths()] +class ToolsInitFile(asttools.File): + """ + Modifiable AST representation of the tools init file. + Use it as a context manager to make and save edits: + ```python + with ToolsInitFile(filename) as tools_init: + tools_init.add_import_for_tool(...) + ``` + """ -def get_all_tools() -> list[ToolConfig]: - return [ToolConfig.from_json(path) for path in get_all_tool_paths()] + def get_import_for_tool(self, tool: ToolConfig) -> Optional[ast.ImportFrom]: + """ + Get the import statement for a tool. + raises a ValidationError if the tool is imported multiple times. + """ + all_imports = asttools.get_all_imports(self.tree) + tool_imports = [i for i in all_imports if tool.module_name == i.module] + if len(tool_imports) > 1: + raise ValidationError(f"Multiple imports for tool {tool.name} found in {self.filename}") -def add_tool(tool_name: str, path: Optional[str] = None, agents: Optional[List[str]] = []): - if path: - path = path.endswith('/') and path or path + '/' - else: - path = './' + try: + return tool_imports[0] + except IndexError: + return None + + def add_import_for_tool(self, framework: str, tool: ToolConfig): + """ + Add an import for a tool. + raises a ValidationError if the tool is already imported. + """ + tool_import = self.get_import_for_tool(tool) + if tool_import: + raise ValidationError(f"Tool {tool.name} already imported in {self.filename}") - framework = get_framework(path) + try: + last_import = asttools.get_all_imports(self.tree)[-1] + start, end = self.get_node_range(last_import) + except IndexError: + start, end = 0, 0 # No imports in the file + + import_statement = tool.get_import_statement(framework) + self.edit_node_range(end, end, f"\n{import_statement}") + + def remove_import_for_tool(self, framework: str, tool: ToolConfig): + """ + Remove an import for a tool. + raises a ValidationError if the tool is not imported. + """ + tool_import = self.get_import_for_tool(tool) + if not tool_import: + raise ValidationError(f"Tool {tool.name} not imported in {self.filename}") + + start, end = self.get_node_range(tool_import) + self.edit_node_range(start, end, "") + + +def add_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None): + if path is None: + path = Path() agentstack_config = ConfigFile(path) + framework = agentstack_config.framework if tool_name in agentstack_config.tools: print(term_color(f'Tool {tool_name} is already installed', 'red')) sys.exit(1) - tool_data = ToolConfig.from_tool_name(tool_name) - tool_file_path = tool_data.get_impl_file_path(framework) + tool = ToolConfig.from_tool_name(tool_name) + tool_file_path = tool.get_impl_file_path(framework) + + if tool.packages: + packaging.install(' '.join(tool.packages)) + + # Move tool from package to project + shutil.copy(tool_file_path, path / f'src/tools/{tool.module_name}.py') + + try: # Edit the user's project tool init file to include the tool + with ToolsInitFile(path / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(framework, tool) + except ValidationError as e: + print(term_color(f"Error adding tool:\n{e}", 'red')) - if tool_data.packages: - packaging.install(' '.join(tool_data.packages)) - shutil.copy(tool_file_path, f'{path}src/tools/{tool_name}_tool.py') # Move tool from package to project - add_tool_to_tools_init(tool_data, path) # Export tool from tools dir - add_tool_to_agent_definition( - framework=framework, tool_data=tool_data, path=path, agents=agents - ) # Add tool to agent definition + # Edit the framework entrypoint file to include the tool in the agent definition + if not agents: # If no agents are specified, add the tool to all agents + agents = frameworks.get_agent_names(framework, path) + for agent_name in agents: + frameworks.add_tool(framework, tool, agent_name, path) - if tool_data.env: # add environment variables which don't exist + if tool.env: # add environment variables which don't exist with EnvFile(path) as env: - for var, value in tool_data.env.items(): + for var, value in tool.env.items(): env.append_if_new(var, value) with EnvFile(path, filename=".env.example") as env: - for var, value in tool_data.env.items(): + for var, value in tool.env.items(): env.append_if_new(var, value) - if tool_data.post_install: - os.system(tool_data.post_install) + if tool.post_install: + os.system(tool.post_install) with agentstack_config as config: - config.tools.append(tool_name) + config.tools.append(tool.name) - print(term_color(f'🔨 Tool {tool_name} added to agentstack project successfully', 'green')) - if tool_data.cta: - print(term_color(f'🪩 {tool_data.cta}', 'blue')) + print(term_color(f'🔨 Tool {tool.name} added to agentstack project successfully', 'green')) + if tool.cta: + print(term_color(f'🪩 {tool.cta}', 'blue')) -def remove_tool(tool_name: str, path: Optional[str] = None): - if path: - path = path.endswith('/') and path or path + '/' - else: - path = './' - - framework = get_framework() +def remove_tool(tool_name: str, agents: Optional[list[str]] = [], path: Optional[Path] = None): + if path is None: + path = Path() agentstack_config = ConfigFile(path) + framework = agentstack_config.framework if tool_name not in agentstack_config.tools: print(term_color(f'Tool {tool_name} is not installed', 'red')) sys.exit(1) - tool_data = ToolConfig.from_tool_name(tool_name) - if tool_data.packages: - packaging.remove(' '.join(tool_data.packages)) + tool = ToolConfig.from_tool_name(tool_name) + if tool.packages: + packaging.remove(' '.join(tool.packages)) + + # TODO ensure that other agents in the project are not using the tool. try: - os.remove(f'{path}src/tools/{tool_name}_tool.py') + os.remove(path / f'src/tools/{tool.module_name}.py') except FileNotFoundError: - print(f'"src/tools/{tool_name}_tool.py" not found') - remove_tool_from_tools_init(tool_data, path) - remove_tool_from_agent_definition(framework, tool_data, path) - if tool_data.post_remove: - os.system(tool_data.post_remove) + print(f'"src/tools/{tool.module_name}.py" not found') + + try: # Edit the user's project tool init file to exclude the tool + with ToolsInitFile(path / TOOLS_INIT_FILENAME) as tools_init: + tools_init.remove_import_for_tool(framework, tool) + except ValidationError as e: + print(term_color(f"Error removing tool:\n{e}", 'red')) + + # Edit the framework entrypoint file to exclude the tool in the agent definition + if not agents: # If no agents are specified, remove the tool from all agents + agents = frameworks.get_agent_names(framework, path) + for agent_name in agents: + frameworks.remove_tool(framework, tool, agent_name, path) + + if tool.post_remove: + os.system(tool.post_remove) # We don't remove the .env variables to preserve user data. with agentstack_config as config: - config.tools.remove(tool_name) + config.tools.remove(tool.name) print( term_color(f'🔨 Tool {tool_name}', 'green'), term_color('removed', 'red'), term_color('from agentstack project successfully', 'green'), ) - - -def add_tool_to_tools_init(tool_data: ToolConfig, path: str = ''): - file_path = f'{path}{TOOL_INIT_FILENAME}' - tag = '# tool import' - code_to_insert = [ - tool_data.get_import_statement(), - ] - insert_code_after_tag(file_path, tag, code_to_insert, next_line=True) - - -def remove_tool_from_tools_init(tool_data: ToolConfig, path: str = ''): - """Search for the import statement in the init and remove it.""" - file_path = f'{path}{TOOL_INIT_FILENAME}' - import_statement = tool_data.get_import_statement() - with fileinput.input(files=file_path, inplace=True) as f: - for line in f: - if line.strip() != import_statement: - print(line, end='') - - -def add_tool_to_agent_definition( - framework: str, - tool_data: ToolConfig, - path: str = '', - agents: Optional[list[str]] = [], -): - """ - Add tools to specific agent definitions using AST transformation. - - Args: - framework: Name of the framework - tool_data: ToolConfig - agents: Optional list of agent names to modify. If None, modifies all agents. - path: Optional path to the framework file - """ - modify_agent_tools( - framework=framework, - tool_data=tool_data, - operation='add', - agents=agents, - path=path, - base_name='tools', - ) - - -def remove_tool_from_agent_definition(framework: str, tool_data: ToolConfig, path: str = ''): - modify_agent_tools( - framework=framework, - tool_data=tool_data, - operation='remove', - agents=None, - path=path, - base_name='tools', - ) - - -def _create_tool_attribute(tool_name: str, base_name: str = 'tools') -> ast.Attribute: - """Create an AST node for a tool attribute""" - return ast.Attribute(value=ast.Name(id=base_name, ctx=ast.Load()), attr=tool_name, ctx=ast.Load()) - - -def _create_starred_tool(tool_name: str, base_name: str = 'tools') -> ast.Starred: - """Create an AST node for a starred tool expression""" - return ast.Starred( - value=ast.Attribute(value=ast.Name(id=base_name, ctx=ast.Load()), attr=tool_name, ctx=ast.Load()), - ctx=ast.Load(), - ) - - -def _create_tool_attributes(tool_names: List[str], base_name: str = 'tools') -> List[ast.Attribute]: - """Create AST nodes for multiple tool attributes""" - return [_create_tool_attribute(name, base_name) for name in tool_names] - - -def _create_tool_nodes( - tool_names: List[str], is_bundled: bool = False, base_name: str = 'tools' -) -> List[Union[ast.Attribute, ast.Starred]]: - """Create AST nodes for multiple tool attributes""" - return [ - _create_starred_tool(name, base_name) if is_bundled else _create_tool_attribute(name, base_name) - for name in tool_names - ] - - -def _is_tool_node_match(node: ast.AST, tool_name: str, base_name: str = 'tools') -> bool: - """ - Check if an AST node matches a tool reference, regardless of whether it's starred - - Args: - node: AST node to check (can be Attribute or Starred) - tool_name: Name of the tool to match - base_name: Base module name (default: 'tools') - - Returns: - bool: True if the node matches the tool reference - """ - # If it's a Starred node, check its value - if isinstance(node, ast.Starred): - node = node.value - - # Extract the attribute name and base regardless of node type - if isinstance(node, ast.Attribute): - is_base_match = isinstance(node.value, ast.Name) and node.value.id == base_name - is_name_match = node.attr == tool_name - return is_base_match and is_name_match - - return False - - -def _process_tools_list( - current_tools: List[ast.AST], - tool_data: ToolConfig, - operation: str, - base_name: str = 'tools', -) -> List[ast.AST]: # type: ignore[return-type,arg-type] - """ - Process a tools list according to the specified operation. - - Args: - current_tools: Current list of tool nodes - tool_data: Tool configuration - operation: Operation to perform ('add' or 'remove') - base_name: Base module name for tools - """ - if operation == 'add': - new_tools = current_tools.copy() - # Add new tools with bundling if specified - new_tools.extend(_create_tool_nodes(tool_data.tools, tool_data.tools_bundled, base_name)) - return new_tools - - elif operation == 'remove': - # Filter out tools that match any in the removal list - return [ - tool - for tool in current_tools - if not any(_is_tool_node_match(tool, name, base_name) for name in tool_data.tools) - ] - - raise ValueError(f"Unsupported operation: {operation}") - - -def _modify_agent_tools( - node: ast.FunctionDef, - tool_data: ToolConfig, - operation: str, - agents: Optional[List[str]] = None, - base_name: str = 'tools', -) -> ast.FunctionDef: - """ - Modify the tools list in an agent definition. - - Args: - node: AST node of the function to modify - tool_data: Tool configuration - operation: Operation to perform ('add' or 'remove') - agents: Optional list of agent names to modify - base_name: Base module name for tools - """ - # Skip if not in specified agents list - if agents is not None and agents != []: - if node.name not in agents: - return node - - # Check if this is an agent-decorated function - if not any(isinstance(d, ast.Name) and d.id == 'agent' for d in node.decorator_list): - return node - - # Find the Return statement and modify tools - for item in node.body: - if isinstance(item, ast.Return): - agent_call = item.value - if isinstance(agent_call, ast.Call): - for kw in agent_call.keywords: - if kw.arg == 'tools': - if isinstance(kw.value, ast.List): - # Process the tools list - new_tools = _process_tools_list(kw.value.elts, tool_data, operation, base_name) # type: ignore - # Replace with new list - kw.value = ast.List(elts=new_tools, ctx=ast.Load()) # type: ignore - - return node - - -def modify_agent_tools( - framework: str, - tool_data: ToolConfig, - operation: str, - agents: Optional[List[str]] = None, - path: str = '', - base_name: str = 'tools', -) -> None: - """ - Modify tools in agent definitions using AST transformation. - - Args: - framework: Name of the framework - tool_data: ToolConfig - operation: Operation to perform ('add' or 'remove') - agents: Optional list of agent names to modify - path: Optional path to the framework file - base_name: Base module name for tools (default: 'tools') - """ - if agents is not None: - valid_agents = get_agent_names(path=path) - for agent in agents: - if agent not in valid_agents: - print(term_color(f"Agent '{agent}' not found in the project.", 'red')) - sys.exit(1) - - filename = _framework_filename(framework, path) - - with open(filename, 'r', encoding='utf-8') as f: - source_lines = f.readlines() - - # Create a map of line numbers to comments - comments = {} - for i, line in enumerate(source_lines): - stripped = line.strip() - if stripped.startswith('#'): - comments[i + 1] = line - - tree = ast.parse(''.join(source_lines)) - - class ModifierTransformer(ast.NodeTransformer): - def visit_FunctionDef(self, node): - return _modify_agent_tools(node, tool_data, operation, agents, base_name) - - modified_tree = ModifierTransformer().visit(tree) - modified_source = astor.to_source(modified_tree) - modified_lines = modified_source.splitlines() - - # Reinsert comments - final_lines = [] - for i, line in enumerate(modified_lines, 1): - if i in comments: - final_lines.append(comments[i]) - final_lines.append(line + '\n') - - with open(filename, 'w', encoding='utf-8') as f: - f.write(''.join(final_lines)) diff --git a/agentstack/main.py b/agentstack/main.py index 9cc1333f..f117c1d1 100644 --- a/agentstack/main.py +++ b/agentstack/main.py @@ -2,10 +2,10 @@ import os import sys -from agentstack.cli import init_project_builder, list_tools, configure_default_model +from agentstack.cli import init_project_builder, list_tools, configure_default_model, run_project from agentstack.telemetry import track_cli_command from agentstack.utils import get_version, get_framework -import agentstack.generation as generation +from agentstack import generation from agentstack.update import check_for_updates import webbrowser @@ -107,15 +107,14 @@ def main(): init_project_builder(args.slug_name, args.template, args.wizard) elif args.command in ["run", "r"]: framework = get_framework() - if framework == "crewai": - os.system("python src/main.py") - elif args.command in ["generate", "g"]: - if args.generate_command in ["agent", "a"]: + run_project(framework) + elif args.command in ['generate', 'g']: + if args.generate_command in ['agent', 'a']: if not args.llm: configure_default_model() - generation.generate_agent(args.name, args.role, args.goal, args.backstory, args.llm) - elif args.generate_command in ["task", "t"]: - generation.generate_task(args.name, args.description, args.expected_output, args.agent) + generation.add_agent(args.name, args.role, args.goal, args.backstory, args.llm) + elif args.generate_command in ['task', 't']: + generation.add_task(args.name, args.description, args.expected_output, args.agent) else: generate_parser.print_help() elif args.command in ["tools", "t"]: diff --git a/agentstack/tasks.py b/agentstack/tasks.py new file mode 100644 index 00000000..bad1d52e --- /dev/null +++ b/agentstack/tasks.py @@ -0,0 +1,94 @@ +from typing import Optional +import os +from pathlib import Path +import pydantic +from ruamel.yaml import YAML, YAMLError +from ruamel.yaml.scalarstring import FoldedScalarString +from agentstack import ValidationError + + +TASKS_FILENAME: Path = Path("src/config/tasks.yaml") + +yaml = YAML() +yaml.preserve_quotes = True # Preserve quotes in existing data + + +class TaskConfig(pydantic.BaseModel): + """ + Interface for interacting with a task configuration. + + Multiple tasks are stored in a single YAML file, so we always look up the + requested task by `name`. + + Use it as a context manager to make and save edits: + ```python + with TaskConfig('task_name') as config: + config.description = "foo" + + Config Schema + ------------- + name: str + The name of the agent; used for lookup. + description: Optional[str] + The description of the task. + expected_output: Optional[str] + The expected output of the task. + agent: Optional[str] + The agent to use for the task. + """ + + name: str + description: Optional[str] = "" + expected_output: Optional[str] = "" + agent: Optional[str] = "" + + def __init__(self, name: str, path: Optional[Path] = None): + if not path: + path = Path() + + filename = path / TASKS_FILENAME + if not os.path.exists(filename): + os.makedirs(filename.parent, exist_ok=True) + filename.touch() + + try: + with open(filename, 'r') as f: + data = yaml.load(f) or {} + data = data.get(name, {}) or {} + super().__init__(**{**{'name': name}, **data}) + except YAMLError as e: + # TODO format MarkedYAMLError lines/messages + raise ValidationError(f"Error parsing tasks file: {filename}\n{e}") + except pydantic.ValidationError as e: + error_str = "Error validating tasks config:\n" + for error in e.errors(): + error_str += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n" + raise ValidationError(f"Error loading task {name} from {filename}.\n{error_str}") + + # store the path *after* loading data + self._path = path + + def model_dump(self, *args, **kwargs) -> dict: + dump = super().model_dump(*args, **kwargs) + dump.pop('name') # name is the key, so keep it out of the data + # format these as FoldedScalarStrings + for key in ('description', 'expected_output', 'agent'): + dump[key] = FoldedScalarString(dump.get(key) or "") + return {self.name: dump} + + def write(self): + filename = self._path / TASKS_FILENAME + + with open(filename, 'r') as f: + data = yaml.load(f) or {} + + data.update(self.model_dump()) + + with open(filename, 'w') as f: + yaml.dump(data, f) + + def __enter__(self) -> 'TaskConfig': + return self + + def __exit__(self, *args): + self.write() diff --git a/agentstack/tools.py b/agentstack/tools.py new file mode 100644 index 00000000..1acb8d97 --- /dev/null +++ b/agentstack/tools.py @@ -0,0 +1,71 @@ +from typing import Optional +import os +import sys +from pathlib import Path +import pydantic +from agentstack.utils import get_package_path, open_json_file, term_color + + +class ToolConfig(pydantic.BaseModel): + """ + This represents the configuration data for a tool. + It parses and validates the `config.json` file for a tool. + """ + + name: str + category: str + tools: list[str] + url: Optional[str] = None + tools_bundled: bool = False + cta: Optional[str] = None + env: Optional[dict] = None + packages: Optional[list[str]] = None + post_install: Optional[str] = None + post_remove: Optional[str] = None + + @classmethod + def from_tool_name(cls, name: str) -> 'ToolConfig': + path = get_package_path() / f'tools/{name}.json' + if not os.path.exists(path): # TODO raise exceptions and handle message/exit in cli + print(term_color(f'No known agentstack tool: {name}', 'red')) + sys.exit(1) + return cls.from_json(path) + + @classmethod + def from_json(cls, path: Path) -> 'ToolConfig': + data = open_json_file(path) + try: + return cls(**data) + except pydantic.ValidationError as e: + # TODO raise exceptions and handle message/exit in cli + print(term_color(f"Error validating tool config JSON: \n{path}", 'red')) + for error in e.errors(): + print(f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}") + sys.exit(1) + + @property + def module_name(self) -> str: + return f"{self.name}_tool" + + def get_import_statement(self, framework: str) -> str: + return f"from .{self.module_name} import {', '.join(self.tools)}" + + def get_impl_file_path(self, framework: str) -> Path: + return get_package_path() / f'templates/{framework}/tools/{self.module_name}.py' + + +def get_all_tool_paths() -> list[Path]: + paths = [] + tools_dir = get_package_path() / 'tools' + for file in tools_dir.iterdir(): + if file.is_file() and file.suffix == '.json': + paths.append(file) + return paths + + +def get_all_tool_names() -> list[str]: + return [path.stem for path in get_all_tool_paths()] + + +def get_all_tools() -> list[ToolConfig]: + return [ToolConfig.from_json(path) for path in get_all_tool_paths()] diff --git a/agentstack/update.py b/agentstack/update.py index 041d84ec..145623fe 100644 --- a/agentstack/update.py +++ b/agentstack/update.py @@ -47,8 +47,7 @@ def get_latest_version(package: str) -> Version: import requests # defer import until we know we need it response = requests.get( - f"{ENDPOINT_URL}/{package}/", - headers={"Accept": "application/vnd.pypi.simple.v1+json"}, + f"{ENDPOINT_URL}/{package}/", headers={"Accept": "application/vnd.pypi.simple.v1+json"} ) if response.status_code != 200: raise Exception(f"Failed to fetch package data from pypi.") @@ -126,17 +125,13 @@ def check_for_updates(update_requested: bool = False): packaging.upgrade(f'{AGENTSTACK_PACKAGE}[{get_framework()}]') print( term_color( - f"{AGENTSTACK_PACKAGE} updated. Re-run your command to use the latest version.", - 'green', + f"{AGENTSTACK_PACKAGE} updated. Re-run your command to use the latest version.", 'green' ) ) sys.exit(0) else: print( - term_color( - "Skipping update. Run `agentstack update` to install the latest version.", - 'blue', - ) + term_color("Skipping update. Run `agentstack update` to install the latest version.", 'blue') ) else: print(f"{AGENTSTACK_PACKAGE} is up to date ({installed_version})") diff --git a/agentstack/utils.py b/agentstack/utils.py index b07cc4db..de008489 100644 --- a/agentstack/utils.py +++ b/agentstack/utils.py @@ -1,6 +1,7 @@ from typing import Optional import sys import json +from ruamel.yaml import YAML import re from importlib.metadata import version from pathlib import Path @@ -15,7 +16,7 @@ def get_version(package: str = 'agentstack'): return "Unknown version" -def verify_agentstack_project(path: Optional[str] = None): +def verify_agentstack_project(path: Optional[Path] = None): from agentstack.generation import ConfigFile try: @@ -78,6 +79,15 @@ def open_json_file(path) -> dict: return data +def open_yaml_file(path) -> dict: + yaml = YAML() + yaml.preserve_quotes = True # Preserve quotes in existing data + + with open(path, 'r') as f: + data = yaml.load(f) + return data + + def clean_input(input_string): special_char_pattern = re.compile(r'[^a-zA-Z0-9\s_]') return re.sub(special_char_pattern, '', input_string).lower().replace(' ', '_').replace('-', '_') diff --git a/foo.yaml b/foo.yaml new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index 9a26d4c5..1767a397 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "cookiecutter==2.6.0", "psutil==5.9.8", "astor==0.8.1", + "asttokens", "pydantic>=2.10", "packaging==23.2", "requests>=2.32", diff --git a/tests/fixtures/agents_max.yaml b/tests/fixtures/agents_max.yaml new file mode 100644 index 00000000..d532bbf4 --- /dev/null +++ b/tests/fixtures/agents_max.yaml @@ -0,0 +1,16 @@ +agent_name: + role: >- + role + goal: >- + this is a goal + backstory: >- + backstory + llm: provider/model +second_agent_name: + role: >- + role + goal: >- + this is a goal + backstory: >- + this is a backstory + llm: provider/model \ No newline at end of file diff --git a/tests/fixtures/agents_min.yaml b/tests/fixtures/agents_min.yaml new file mode 100644 index 00000000..1cbea78a --- /dev/null +++ b/tests/fixtures/agents_min.yaml @@ -0,0 +1,2 @@ +agent_name: + \ No newline at end of file diff --git a/tests/fixtures/agentstack.json b/tests/fixtures/agentstack.json index 4ca18a10..f39237b1 100644 --- a/tests/fixtures/agentstack.json +++ b/tests/fixtures/agentstack.json @@ -1,4 +1,4 @@ { "framework": "crewai", - "tools": ["tool1", "tool2"] + "tools": [] } \ No newline at end of file diff --git a/tests/fixtures/frameworks/crewai/entrypoint_max.py b/tests/fixtures/frameworks/crewai/entrypoint_max.py new file mode 100644 index 00000000..6ccb9af2 --- /dev/null +++ b/tests/fixtures/frameworks/crewai/entrypoint_max.py @@ -0,0 +1,25 @@ +from crewai import Agent, Crew, Process, Task +from crewai.project import CrewBase, agent, crew, task +import tools + + +@CrewBase +class TestCrew: + @agent + def test_agent(self) -> Agent: + return Agent(config=self.agents_config['test_agent'], tools=[], verbose=True) + + @task + def test_task(self) -> Task: + return Task( + config=self.tasks_config['test_task'], + ) + + @crew + def crew(self) -> Crew: + return Crew( + agents=self.agents, + tasks=self.tasks, + process=Process.sequential, + verbose=True, + ) diff --git a/tests/fixtures/frameworks/crewai/entrypoint_min.py b/tests/fixtures/frameworks/crewai/entrypoint_min.py new file mode 100644 index 00000000..f423807f --- /dev/null +++ b/tests/fixtures/frameworks/crewai/entrypoint_min.py @@ -0,0 +1,15 @@ +from crewai import Agent, Crew, Process, Task +from crewai.project import CrewBase, agent, crew, task +import tools + + +@CrewBase +class TestCrew: + @crew + def crew(self) -> Crew: + return Crew( + agents=self.agents, + tasks=self.tasks, + process=Process.sequential, + verbose=True, + ) diff --git a/tests/fixtures/tasks_max.yaml b/tests/fixtures/tasks_max.yaml new file mode 100644 index 00000000..355ebd98 --- /dev/null +++ b/tests/fixtures/tasks_max.yaml @@ -0,0 +1,14 @@ +task_name: + description: >- + Add your description here + expected_output: >- + Add your expected output here + agent: >- + default_agent +task_name_two: + description: >- + Add your description here + expected_output: >- + Add your expected output here + agent: >- + default_agent diff --git a/tests/fixtures/tasks_min.yaml b/tests/fixtures/tasks_min.yaml new file mode 100644 index 00000000..1fd435c0 --- /dev/null +++ b/tests/fixtures/tasks_min.yaml @@ -0,0 +1 @@ +task_name: diff --git a/tests/test_agents_config.py b/tests/test_agents_config.py new file mode 100644 index 00000000..657d9318 --- /dev/null +++ b/tests/test_agents_config.py @@ -0,0 +1,83 @@ +import json +import os, sys +import shutil +import unittest +import importlib.resources +from pathlib import Path +from agentstack.agents import AgentConfig, AGENTS_FILENAME + +BASE_PATH = Path(__file__).parent + + +class AgentConfigTest(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH / 'tmp/agent_config' + os.makedirs(self.project_dir / 'src/config') + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_empty_file(self): + config = AgentConfig("agent_name", self.project_dir) + assert config.name == "agent_name" + assert config.role is "" + assert config.goal is "" + assert config.backstory is "" + assert config.llm is "" + + def test_read_minimal_yaml(self): + shutil.copy(BASE_PATH / "fixtures/agents_min.yaml", self.project_dir / AGENTS_FILENAME) + config = AgentConfig("agent_name", self.project_dir) + assert config.name == "agent_name" + assert config.role == "" + assert config.goal == "" + assert config.backstory == "" + assert config.llm == "" + + def test_read_maximal_yaml(self): + shutil.copy(BASE_PATH / "fixtures/agents_max.yaml", self.project_dir / AGENTS_FILENAME) + config = AgentConfig("agent_name", self.project_dir) + assert config.name == "agent_name" + assert config.role == "role" + assert config.goal == "this is a goal" + assert config.backstory == "backstory" + assert config.llm == "provider/model" + + def test_write_yaml(self): + with AgentConfig("agent_name", self.project_dir) as config: + config.role = "role" + config.goal = "this is a goal" + config.backstory = "backstory" + config.llm = "provider/model" + + yaml_src = open(self.project_dir / AGENTS_FILENAME).read() + assert ( + yaml_src + == """agent_name: + role: >- + role + goal: >- + this is a goal + backstory: >- + backstory + llm: provider/model +""" + ) + + def test_write_none_values(self): + with AgentConfig("agent_name", self.project_dir) as config: + config.role = None + config.goal = None + config.backstory = None + config.llm = None + + yaml_src = open(self.project_dir / AGENTS_FILENAME).read() + assert ( + yaml_src + == """agent_name: + role: > + goal: > + backstory: > + llm: +""" + ) diff --git a/tests/test_cli_loads.py b/tests/test_cli_loads.py index 16648640..7dfaf0bf 100644 --- a/tests/test_cli_loads.py +++ b/tests/test_cli_loads.py @@ -1,16 +1,19 @@ import subprocess -import sys +import os, sys import unittest from pathlib import Path import shutil +BASE_PATH = Path(__file__).parent + class TestAgentStackCLI(unittest.TestCase): + # Replace with your actual CLI entry point if different CLI_ENTRY = [ sys.executable, "-m", "agentstack.main", - ] # Replace with your actual CLI entry point if different + ] def run_cli(self, *args): """Helper method to run the CLI with arguments.""" @@ -31,12 +34,14 @@ def test_invalid_command(self): def test_init_command(self): """Test the 'init' command to create a project directory.""" - test_dir = Path("test_project") + test_dir = Path(BASE_PATH / 'tmp/test_project') # Ensure the directory doesn't exist from previous runs if test_dir.exists(): shutil.rmtree(test_dir) + os.makedirs(test_dir) + os.chdir(test_dir) result = self.run_cli("init", str(test_dir)) self.assertEqual(result.returncode, 0) self.assertTrue(test_dir.exists()) @@ -44,6 +49,24 @@ def test_init_command(self): # Clean up shutil.rmtree(test_dir) + def test_run_command_invalid_project(self): + """Test the 'run' command on an invalid project.""" + test_dir = Path(BASE_PATH / 'tmp/test_project') + if test_dir.exists(): + shutil.rmtree(test_dir) + os.makedirs(test_dir) + + # Write a basic agentstack.json file + with (test_dir / 'agentstack.json').open('w') as f: + f.write(open(BASE_PATH / 'fixtures/agentstack.json', 'r').read()) + + os.chdir(test_dir) + result = self.run_cli('run') + self.assertNotEqual(result.returncode, 0) + self.assertIn("Project validation failed", result.stdout) + + shutil.rmtree(test_dir) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py new file mode 100644 index 00000000..8b27b84a --- /dev/null +++ b/tests/test_frameworks.py @@ -0,0 +1,119 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class + +from agentstack import ValidationError +from agentstack import frameworks +from agentstack.tools import ToolConfig + +BASE_PATH = Path(__file__).parent + + +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) +class TestFrameworks(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH / 'tmp' / self.framework + + os.makedirs(self.project_dir) + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'tools') + + (self.project_dir / 'src' / '__init__.py').touch() + (self.project_dir / 'src' / 'tools' / '__init__.py').touch() + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def _populate_min_entrypoint(self): + """This entrypoint does not have any tools or agents.""" + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_min.py", entrypoint_path) + + def _populate_max_entrypoint(self): + """This entrypoint has tools and agents.""" + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + + def _get_test_tool(self) -> ToolConfig: + return ToolConfig(name='test_tool', category='test', tools=['test_tool']) + + def _get_test_tool_starred(self) -> ToolConfig: + return ToolConfig( + name='test_tool_star', category='test', tools=['test_tool_star'], tools_bundled=True + ) + + def test_get_framework_module(self): + module = frameworks.get_framework_module(self.framework) + assert module.__name__ == f"agentstack.frameworks.{self.framework}" + + def test_get_framework_module_invalid(self): + with self.assertRaises(Exception) as context: + frameworks.get_framework_module('invalid') + + def test_validate_project(self): + self._populate_max_entrypoint() + frameworks.validate_project(self.framework, self.project_dir) + + def test_validate_project_invalid(self): + self._populate_min_entrypoint() + with self.assertRaises(ValidationError) as context: + frameworks.validate_project(self.framework, self.project_dir) + + def test_add_tool(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + # TODO these asserts are not framework agnostic + assert 'tools=[tools.test_tool' in entrypoint_src + + def test_add_tool_starred(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[*tools.test_tool_star' in entrypoint_src + + def test_add_tool_invalid(self): + self._populate_min_entrypoint() + with self.assertRaises(ValidationError) as context: + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + + def test_remove_tool(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.remove_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[tools.test_tool' not in entrypoint_src + + def test_remove_tool_starred(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + frameworks.remove_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[*tools.test_tool_star' not in entrypoint_src + + def test_add_multiple_tools(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert ( # ordering is not guaranteed + 'tools=[tools.test_tool, *tools.test_tool_star' in entrypoint_src + or 'tools=[*tools.test_tool_star, tools.test_tool' in entrypoint_src + ) + + def test_remove_one_tool_of_multiple(self): + self._populate_max_entrypoint() + frameworks.add_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + frameworks.add_tool(self.framework, self._get_test_tool_starred(), 'test_agent', self.project_dir) + frameworks.remove_tool(self.framework, self._get_test_tool(), 'test_agent', self.project_dir) + + entrypoint_src = open(frameworks.get_entrypoint_path(self.framework, self.project_dir)).read() + assert 'tools=[tools.test_tool' not in entrypoint_src + assert 'tools=[*tools.test_tool_star' in entrypoint_src diff --git a/tests/test_generation_agent.py b/tests/test_generation_agent.py new file mode 100644 index 00000000..2f836e5e --- /dev/null +++ b/tests/test_generation_agent.py @@ -0,0 +1,64 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class +import ast + +from agentstack import frameworks, ValidationError +from agentstack.generation.files import ConfigFile +from agentstack.generation.agent_generation import add_agent + +BASE_PATH = Path(__file__).parent + + +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) +class TestGenerationAgent(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH / 'tmp' / 'agent_generation' + + os.makedirs(self.project_dir) + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'config') + (self.project_dir / 'src' / '__init__.py').touch() + + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + + # set the framework in agentstack.json + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') + with ConfigFile(self.project_dir) as config: + config.framework = self.framework + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_add_agent(self): + add_agent( + 'test_agent_two', + role='role', + goal='goal', + backstory='backstory', + llm='llm', + path=self.project_dir, + ) + + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_src = open(entrypoint_path).read() + # agents.yaml is covered in test_agents_config.py + # TODO framework-specific validation for code structure + assert 'def test_agent_two' in entrypoint_src + # verify that the file's syntax is valid with ast + ast.parse(entrypoint_src) + + def test_add_agent_exists(self): + with self.assertRaises(SystemExit) as context: + add_agent( + 'test_agent', + role='role', + goal='goal', + backstory='backstory', + llm='llm', + path=self.project_dir, + ) diff --git a/tests/test_generation_files.py b/tests/test_generation_files.py index e6ea52fe..c5fae48b 100644 --- a/tests/test_generation_files.py +++ b/tests/test_generation_files.py @@ -16,17 +16,14 @@ class GenerationFilesTest(unittest.TestCase): def test_read_config(self): config = ConfigFile(BASE_PATH / "fixtures") # + agentstack.json assert config.framework == "crewai" - assert config.tools == ["tool1", "tool2"] + assert config.tools == [] assert config.telemetry_opt_out is None assert config.default_model is None def test_write_config(self): try: os.makedirs(BASE_PATH / "tmp", exist_ok=True) - shutil.copy( - BASE_PATH / "fixtures/agentstack.json", - BASE_PATH / "tmp/agentstack.json", - ) + shutil.copy(BASE_PATH / "fixtures/agentstack.json", BASE_PATH / "tmp/agentstack.json") with ConfigFile(BASE_PATH / "tmp") as config: config.framework = "crewai" diff --git a/tests/test_generation_tasks.py b/tests/test_generation_tasks.py new file mode 100644 index 00000000..430a3695 --- /dev/null +++ b/tests/test_generation_tasks.py @@ -0,0 +1,62 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class +import ast + +from agentstack import frameworks, ValidationError +from agentstack.generation.files import ConfigFile +from agentstack.generation.task_generation import add_task + +BASE_PATH = Path(__file__).parent + + +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) +class TestGenerationAgent(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH / 'tmp' / 'agent_generation' + + os.makedirs(self.project_dir) + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'config') + (self.project_dir / 'src' / '__init__.py').touch() + + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + + # set the framework in agentstack.json + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') + with ConfigFile(self.project_dir) as config: + config.framework = self.framework + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_add_task(self): + add_task( + 'task_test_two', + description='description', + expected_output='expected_output', + agent='agent', + path=self.project_dir, + ) + + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_src = open(entrypoint_path).read() + # agents.yaml is covered in test_agents_config.py + # TODO framework-specific validation for code structure + assert 'def task_test_two' in entrypoint_src + # verify that the file's syntax is valid with ast + ast.parse(entrypoint_src) + + def test_add_agent_exists(self): + with self.assertRaises(SystemExit) as context: + add_task( + 'test_task', + description='description', + expected_output='expected_output', + agent='agent', + path=self.project_dir, + ) diff --git a/tests/test_generation_tool.py b/tests/test_generation_tool.py new file mode 100644 index 00000000..d2122689 --- /dev/null +++ b/tests/test_generation_tool.py @@ -0,0 +1,70 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class +import ast + +from agentstack import frameworks +from agentstack.tools import get_all_tools, ToolConfig +from agentstack.generation.files import ConfigFile +from agentstack.generation.tool_generation import add_tool, remove_tool, TOOLS_INIT_FILENAME + + +BASE_PATH = Path(__file__).parent + + +# TODO parameterize all tools +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) +class TestGenerationTool(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH / 'tmp' / 'tool_generation' + + os.makedirs(self.project_dir) + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'tools') + (self.project_dir / 'src' / '__init__.py').touch() + (self.project_dir / TOOLS_INIT_FILENAME).touch() + + # populate the entrypoint + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path) + + # set the framework in agentstack.json + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') + with ConfigFile(self.project_dir) as config: + config.framework = self.framework + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_add_tool(self): + tool_conf = ToolConfig.from_tool_name('agent_connect') + add_tool('agent_connect', path=self.project_dir) + + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_src = open(entrypoint_path).read() + ast.parse(entrypoint_src) + tools_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() + + # TODO verify tool is added to all agents (this is covered in test_frameworks.py) + # assert 'agent_connect' in entrypoint_src + assert f'from .{tool_conf.module_name} import' in tools_init_src + assert (self.project_dir / 'src' / 'tools' / f'{tool_conf.module_name}.py').exists() + assert 'agent_connect' in open(self.project_dir / 'agentstack.json').read() + + def test_remove_tool(self): + tool_conf = ToolConfig.from_tool_name('agent_connect') + add_tool('agent_connect', path=self.project_dir) + remove_tool('agent_connect', path=self.project_dir) + + entrypoint_path = frameworks.get_entrypoint_path(self.framework, self.project_dir) + entrypoint_src = open(entrypoint_path).read() + ast.parse(entrypoint_src) + tools_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() + + # TODO verify tool is removed from all agents (this is covered in test_frameworks.py) + # assert 'agent_connect' not in entrypoint_src + assert f'from .{tool_conf.module_name} import' not in tools_init_src + assert not (self.project_dir / 'src' / 'tools' / f'{tool_conf.module_name}.py').exists() + assert 'agent_connect' not in open(self.project_dir / 'agentstack.json').read() diff --git a/tests/test_tasks_config.py b/tests/test_tasks_config.py new file mode 100644 index 00000000..c95665bc --- /dev/null +++ b/tests/test_tasks_config.py @@ -0,0 +1,76 @@ +import json +import os, sys +import shutil +import unittest +import importlib.resources +from pathlib import Path +from agentstack.tasks import TaskConfig, TASKS_FILENAME + +BASE_PATH = Path(__file__).parent + + +class AgentConfigTest(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH / 'tmp/task_config' + os.makedirs(self.project_dir / 'src/config') + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def test_empty_file(self): + config = TaskConfig("task_name", self.project_dir) + assert config.name == "task_name" + assert config.description is "" + assert config.expected_output is "" + assert config.agent is "" + + def test_read_minimal_yaml(self): + shutil.copy(BASE_PATH / "fixtures/tasks_min.yaml", self.project_dir / TASKS_FILENAME) + config = TaskConfig("task_name", self.project_dir) + assert config.name == "task_name" + assert config.description is "" + assert config.expected_output is "" + assert config.agent is "" + + def test_read_maximal_yaml(self): + shutil.copy(BASE_PATH / "fixtures/tasks_max.yaml", self.project_dir / TASKS_FILENAME) + config = TaskConfig("task_name", self.project_dir) + assert config.name == "task_name" + assert config.description == "Add your description here" + assert config.expected_output == "Add your expected output here" + assert config.agent == "default_agent" + + def test_write_yaml(self): + with TaskConfig("task_name", self.project_dir) as config: + config.description = "Add your description here" + config.expected_output = "Add your expected output here" + config.agent = "default_agent" + + yaml_src = open(self.project_dir / TASKS_FILENAME).read() + assert ( + yaml_src + == """task_name: + description: >- + Add your description here + expected_output: >- + Add your expected output here + agent: >- + default_agent +""" + ) + + def test_write_none_values(self): + with TaskConfig("task_name", self.project_dir) as config: + config.description = None + config.expected_output = None + config.agent = None + + yaml_src = open(self.project_dir / TASKS_FILENAME).read() + assert ( + yaml_src + == """task_name: + description: > + expected_output: > + agent: > +""" + ) diff --git a/tests/test_tool_config.py b/tests/test_tool_config.py index 20f820d6..5a8aad31 100644 --- a/tests/test_tool_config.py +++ b/tests/test_tool_config.py @@ -1,11 +1,7 @@ import json import unittest from pathlib import Path -from agentstack.generation.tool_generation import ( - get_all_tool_paths, - get_all_tool_names, - ToolConfig, -) +from agentstack.tools import ToolConfig, get_all_tool_paths, get_all_tool_names BASE_PATH = Path(__file__).parent diff --git a/tests/test_tool_generation_init.py b/tests/test_tool_generation_init.py new file mode 100644 index 00000000..7bb79580 --- /dev/null +++ b/tests/test_tool_generation_init.py @@ -0,0 +1,80 @@ +import os, sys +from pathlib import Path +import shutil +import unittest +from parameterized import parameterized_class + +from agentstack import ValidationError +from agentstack import frameworks +from agentstack.tools import ToolConfig +from agentstack.generation.files import ConfigFile +from agentstack.generation.tool_generation import ToolsInitFile, TOOLS_INIT_FILENAME + + +BASE_PATH = Path(__file__).parent + + +@parameterized_class([{"framework": framework} for framework in frameworks.SUPPORTED_FRAMEWORKS]) +class TestToolGenerationInit(unittest.TestCase): + def setUp(self): + self.project_dir = BASE_PATH / 'tmp' / 'tool_generation_init' + os.makedirs(self.project_dir) + os.makedirs(self.project_dir / 'src') + os.makedirs(self.project_dir / 'src' / 'tools') + (self.project_dir / 'src' / '__init__.py').touch() + (self.project_dir / 'src' / 'tools' / '__init__.py').touch() + shutil.copy(BASE_PATH / 'fixtures' / 'agentstack.json', self.project_dir / 'agentstack.json') + # set the framework in agentstack.json + with ConfigFile(self.project_dir) as config: + config.framework = self.framework + + def tearDown(self): + shutil.rmtree(self.project_dir) + + def _get_test_tool(self) -> ToolConfig: + return ToolConfig(name='test_tool', category='test', tools=['test_tool']) + + def _get_test_tool_alt(self) -> ToolConfig: + return ToolConfig(name='test_tool_alt', category='test', tools=['test_tool_alt']) + + def test_tools_init_file(self): + tools_init = ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) + # file is empty + assert tools_init.get_import_for_tool(self._get_test_tool()) == None + + def test_tools_init_file_missing(self): + with self.assertRaises(ValidationError) as context: + tools_init = ToolsInitFile(self.project_dir / 'missing') + + def test_tools_init_file_add_import(self): + tool = self._get_test_tool() + with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(self.framework, tool) + + tool_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() + assert tool.get_import_statement(self.framework) in tool_init_src + + def test_tools_init_file_add_import_multiple(self): + tool = self._get_test_tool() + tool_alt = self._get_test_tool_alt() + with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(self.framework, tool) + + with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(self.framework, tool_alt) + + # Should not be able to re-add a tool import + with self.assertRaises(ValidationError) as context: + with ToolsInitFile(self.project_dir / TOOLS_INIT_FILENAME) as tools_init: + tools_init.add_import_for_tool(self.framework, tool) + + tool_init_src = open(self.project_dir / TOOLS_INIT_FILENAME).read() + assert tool.get_import_statement(self.framework) in tool_init_src + assert tool_alt.get_import_statement(self.framework) in tool_init_src + # TODO this might be a little too strict + assert ( + tool_init_src + == """ +from .test_tool_tool import test_tool +from .test_tool_alt_tool import test_tool_alt""" + ) diff --git a/tox.ini b/tox.ini index 6ab5f4a5..6733c3ab 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ envlist = py310,py311,py312 [testenv] deps = pytest + parameterized mypy: mypy commands = pytest -v