Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion agentstack/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ruamel.yaml import YAML, YAMLError
from ruamel.yaml.scalarstring import FoldedScalarString
from agentstack import conf, log
from agentstack import frameworks
from agentstack.exceptions import ValidationError


Expand Down Expand Up @@ -71,10 +70,12 @@ def __init__(self, name: str):

@property
def provider(self) -> str:
from agentstack import frameworks
return frameworks.parse_llm(self.llm)[0]

@property
def model(self) -> str:
from agentstack import frameworks
return frameworks.parse_llm(self.llm)[1]

@property
Expand Down
27 changes: 24 additions & 3 deletions agentstack/frameworks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from agentstack import conf
from agentstack.exceptions import ValidationError
from agentstack.utils import get_framework
from agentstack.agents import AgentConfig, get_all_agent_names
from agentstack.tasks import TaskConfig, get_all_task_names
from agentstack._tools import ToolConfig
from agentstack import graph

if TYPE_CHECKING:
from agentstack.generation import InsertionPoint
from agentstack.agents import AgentConfig
from agentstack.tasks import TaskConfig


CREWAI = 'crewai'
Expand Down Expand Up @@ -122,7 +122,28 @@ def validate_project():
"""
Validate that the user's project is ready to run.
"""
return get_framework_module(get_framework()).validate_project()
framework = get_framework()
entrypoint_path = get_entrypoint_path(framework)
_module = get_framework_module(framework)

# Run framework-specific validation
_module.validate_project()

# Verify that agents defined in agents.yaml are present in the codebase
agent_method_names = _module.get_agent_method_names()
for agent_name in get_all_agent_names():
if agent_name not in agent_method_names:
raise ValidationError(
f"Agent `{agent_name}` is defined in agents.yaml but not in {entrypoint_path}"
)

# Verify that tasks defined in tasks.yaml are present in the codebase
task_method_names = _module.get_task_method_names()
for task_name in get_all_task_names():
if task_name not in task_method_names:
raise ValidationError(
f"Task `{task_name}` is defined in tasks.yaml but not in {entrypoint_path}"
)


def parse_llm(llm: str) -> tuple[str, str]:
Expand Down
12 changes: 9 additions & 3 deletions tests/fixtures/frameworks/crewai/entrypoint_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@ class TestCrew:
def agent_name(self) -> Agent:
return Agent(config=self.agents_config['agent_name'], tools=[], verbose=True)

@agent
def second_agent_name(self) -> Agent:
return Agent(config=self.agents_config['second_agent_name'], tools=[], verbose=True)

@task
def task_name(self) -> Task:
return Task(
config=self.tasks_config['task_name'],
)
return Task(config=self.tasks_config['task_name'])

@task
def task_name_two(self) -> Task:
return Task(config=self.tasks_config['task_name_two'])

@crew
def crew(self) -> Crew:
Expand Down
31 changes: 31 additions & 0 deletions tests/fixtures/frameworks/langgraph/entrypoint_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def agent_name(self, state: State):
)
return {'messages': [response, ]}

@agentstack.agent
def second_agent_name(self, state: State):
agent_config = agentstack.get_agent('second_agent_name')
messages = ChatPromptTemplate.from_messages([
("user", agent_config.prompt),
])
messages = messages.format_messages(**state['inputs'])
agent = ChatOpenAI(model=agent_config.model)
agent = agent.bind_tools([])
response = agent.invoke(
messages + state['messages'],
)
return {'messages': [response, ]}

@agentstack.task
def task_name(self, state: State):
task_config = agentstack.get_task('task_name')
Expand All @@ -40,6 +54,15 @@ def task_name(self, state: State):
messages = messages.format_messages(**state['inputs'])
return {'messages': messages + state['messages']}

@agentstack.task
def task_name_two(self, state: State):
task_config = agentstack.get_task('task_name_two')
messages = ChatPromptTemplate.from_messages([
("user", task_config.prompt),
])
messages = messages.format_messages(**state['inputs'])
return {'messages': messages + state['messages']}

def run(self, inputs: list[str]):
self.graph = StateGraph(State)
tools = ToolNode([])
Expand All @@ -49,11 +72,19 @@ def run(self, inputs: list[str]):
self.graph.add_edge("agent_name", "tools")
self.graph.add_conditional_edges("agent_name", tools_condition)

self.graph.add_node("second_agent_name", self.agent_name)
self.graph.add_edge("second_agent_name", "tools")
self.graph.add_conditional_edges("second_agent_name", tools_condition)

self.graph.add_node("task_name", self.task_name)
self.graph.add_node("task_name_two", self.task_name)

self.graph.add_edge(START, "task_name")
self.graph.add_edge(START, "task_name_two")
self.graph.add_edge("task_name", "agent_name")
self.graph.add_edge("task_name_two", "second_agent_name")
self.graph.add_edge("agent_name", END)
self.graph.add_edge("second_agent_name", END)

app = self.graph.compile()
result = app.invoke({
Expand Down
47 changes: 45 additions & 2 deletions tests/test_frameworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,21 @@ def _populate_max_entrypoint(self):
"""This entrypoint has tools and agents."""
entrypoint_path = frameworks.get_entrypoint_path(self.framework)
shutil.copy(BASE_PATH / f"fixtures/frameworks/{self.framework}/entrypoint_max.py", entrypoint_path)
shutil.copy(BASE_PATH / 'fixtures/agents_max.yaml', self.project_dir / AGENTS_FILENAME)
shutil.copy(BASE_PATH / 'fixtures/tasks_max.yaml', self.project_dir / TASKS_FILENAME)

def _get_test_agent(self) -> AgentConfig:
shutil.copy(BASE_PATH / 'fixtures/agents_max.yaml', self.project_dir / AGENTS_FILENAME)
return AgentConfig('agent_name')

def _get_test_agent_alternate(self) -> AgentConfig:
return AgentConfig('second_agent_name')

def _get_test_task(self) -> TaskConfig:
shutil.copy(BASE_PATH / 'fixtures/tasks_max.yaml', self.project_dir / TASKS_FILENAME)
return TaskConfig('task_name')

def _get_test_task_alternate(self) -> TaskConfig:
return TaskConfig('task_name_two')

def _get_test_tool(self) -> ToolConfig:
return ToolConfig(name='test_tool', category='test', tools=['test_tool'])

Expand Down Expand Up @@ -88,6 +94,8 @@ def test_validate_project_invalid(self):

def test_validate_project_has_agent_no_task_invalid(self):
self._populate_min_entrypoint()
shutil.copy(BASE_PATH / 'fixtures/agents_max.yaml', self.project_dir / AGENTS_FILENAME)

frameworks.add_agent(self._get_test_agent())
with self.assertRaises(ValidationError) as context:
frameworks.validate_project()
Expand All @@ -98,6 +106,38 @@ def test_validate_project_has_task_no_agent_invalid(self):
with self.assertRaises(ValidationError) as context:
frameworks.validate_project()

def test_validate_project_missing_agent_method_invalid(self):
"""Ensure that all agents have a method defined in the entrypoint."""
self._populate_max_entrypoint()
# add an extra entry to agents.yaml
with open(self.project_dir / AGENTS_FILENAME, 'a') as f:
f.write("""\nextra_agent:
role: >-
role
goal: >-
this is a goal
backstory: >-
this is a backstory
llm: openai/gpt-4o""")
with self.assertRaises(ValidationError) as context:
frameworks.validate_project()

def test_validate_project_missing_task_method_invalid(self):
"""Ensure that all tasks have a method defined in the entrypoint."""
self._populate_max_entrypoint()
# add an extra entry to tasks.yaml
with open(self.project_dir / TASKS_FILENAME, 'a') as f:
f.write("""\nextra_task:
description: >-
Add your description here
expected_output: >-
Add your expected output here
agent: >-
default_agent""")

with self.assertRaises(ValidationError) as context:
frameworks.validate_project()

def test_get_agent_tool_names(self):
self._populate_max_entrypoint()
frameworks.add_tool(self._get_test_tool(), 'agent_name')
Expand Down Expand Up @@ -167,6 +207,9 @@ def test_get_tool_callables(self, tool_config):

def test_get_graph(self):
self._populate_max_entrypoint()
shutil.copy(BASE_PATH / 'fixtures/agents_max.yaml', self.project_dir / AGENTS_FILENAME)
shutil.copy(BASE_PATH / 'fixtures/tasks_max.yaml', self.project_dir / TASKS_FILENAME)

self._get_test_agent()
self._get_test_task()

Expand Down