diff --git a/agentstack/frameworks/__init__.py b/agentstack/frameworks/__init__.py index 1a38e5eb..0da933ab 100644 --- a/agentstack/frameworks/__init__.py +++ b/agentstack/frameworks/__init__.py @@ -72,6 +72,12 @@ def add_task(self, task: TaskConfig) -> None: """ ... + def get_task_names(self) -> list[str]: + """ + Get a list of task names in the user's project. + """ + ... + def get_framework_module(framework: str) -> FrameworkModule: """ @@ -132,3 +138,9 @@ def add_task(task: TaskConfig): """ return get_framework_module(get_framework()).add_task(task) +def get_task_names() -> list[str]: + """ + Get a list of task names in the user's project. + """ + return get_framework_module(get_framework()).get_task_names() + diff --git a/agentstack/generation/task_generation.py b/agentstack/generation/task_generation.py index f15f7e50..91bee560 100644 --- a/agentstack/generation/task_generation.py +++ b/agentstack/generation/task_generation.py @@ -15,6 +15,11 @@ def add_task( ): verify_agentstack_project() + agents = frameworks.get_agent_names() + if not agent and len(agents) == 1: + # if there's only one agent, use it by default + agent = agents[0] + task = TaskConfig(task_name) with task as config: config.description = description or "Add your description here" diff --git a/tests/test_generation_tasks.py b/tests/test_generation_tasks.py index 106ec124..7c871cd2 100644 --- a/tests/test_generation_tasks.py +++ b/tests/test_generation_tasks.py @@ -8,7 +8,9 @@ from agentstack.conf import ConfigFile, set_path from agentstack.exceptions import ValidationError from agentstack import frameworks +from agentstack.tasks import TaskConfig from agentstack.generation.task_generation import add_task +from agentstack.generation.agent_generation import add_agent BASE_PATH = Path(__file__).parent @@ -60,3 +62,13 @@ def test_add_agent_exists(self): expected_output='expected_output', agent='agent', ) + + def test_add_task_selects_single_agent(self): + add_task( + 'task_test', + description='description', + expected_output='expected_output', + ) + + task_config = TaskConfig('task_test') + assert task_config.agent == 'test_agent' # defined in entrypoint_max.py