Skip to content

Commit

Permalink
Add support for running single-agent workflows within the BaseWorkflowAgent class (#18038)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbuos authored Mar 6, 2025
1 parent 8b3e456 commit 3dee964
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 3 deletions.
17 changes: 16 additions & 1 deletion llama-index-core/llama_index/core/agent/workflow/base_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Sequence, Optional, Union
from typing import Callable, List, Sequence, Optional, Union, Any

from llama_index.core.agent.workflow.workflow_events import (
AgentOutput,
Expand All @@ -18,6 +18,8 @@
from llama_index.core.workflow import Context
from llama_index.core.objects import ObjectRetriever
from llama_index.core.settings import Settings
from llama_index.core.workflow.checkpointer import CheckpointCallback
from llama_index.core.workflow.handler import WorkflowHandler


def get_default_llm() -> LLM:
Expand Down Expand Up @@ -109,3 +111,16 @@ async def finalize(
self, ctx: Context, output: AgentOutput, memory: BaseMemory
) -> AgentOutput:
"""Finalize the agent's execution."""

@abstractmethod
def run(
self,
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
ctx: Optional[Context] = None,
stepwise: bool = False,
checkpoint_callback: Optional[CheckpointCallback] = None,
**workflow_kwargs: Any,
) -> WorkflowHandler:
"""Run the agent."""
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Sequence

from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent
from llama_index.core.agent.workflow.single_agent_workflow import SingleAgentRunnerMixin
from llama_index.core.agent.workflow.workflow_events import (
AgentInput,
AgentOutput,
Expand All @@ -15,7 +16,7 @@
from llama_index.core.workflow import Context


class FunctionAgent(BaseWorkflowAgent):
class FunctionAgent(SingleAgentRunnerMixin, BaseWorkflowAgent):
"""Function calling agent implementation."""

scratchpad_key: str = "scratchpad"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ResponseReasoningStep,
)
from llama_index.core.agent.workflow.base_agent import BaseWorkflowAgent
from llama_index.core.agent.workflow.single_agent_workflow import SingleAgentRunnerMixin
from llama_index.core.agent.workflow.workflow_events import (
AgentInput,
AgentOutput,
Expand All @@ -32,7 +33,7 @@ def default_formatter() -> ReActChatFormatter:
return ReActChatFormatter.from_defaults(context="some context")


class ReActAgent(BaseWorkflowAgent):
class ReActAgent(SingleAgentRunnerMixin, BaseWorkflowAgent):
"""React agent implementation."""

reasoning_key: str = "current_reasoning"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC
from typing import Any, List, Optional, Union, TypeVar

from llama_index.core.llms import ChatMessage
from llama_index.core.memory import BaseMemory
from llama_index.core.workflow import (
Context,
)
from llama_index.core.workflow.checkpointer import CheckpointCallback
from llama_index.core.workflow.handler import WorkflowHandler

T = TypeVar("T", bound="BaseWorkflowAgent") # type: ignore[name-defined]


class SingleAgentRunnerMixin(ABC):
"""Mixin class for executing a single agent within a workflow system.
This class provides the necessary interface for running a single agent.
"""

def run(
self: T,
user_msg: Optional[Union[str, ChatMessage]] = None,
chat_history: Optional[List[ChatMessage]] = None,
memory: Optional[BaseMemory] = None,
ctx: Optional[Context] = None,
stepwise: bool = False,
checkpoint_callback: Optional[CheckpointCallback] = None,
**workflow_kwargs: Any,
) -> WorkflowHandler:
"""Run the agent."""
from llama_index.core.agent.workflow import AgentWorkflow

workflow = AgentWorkflow(agents=[self], **workflow_kwargs)
return workflow.run(
user_msg=user_msg,
chat_history=chat_history,
memory=memory,
ctx=ctx,
stepwise=stepwise,
checkpoint_callback=checkpoint_callback,
)
145 changes: 145 additions & 0 deletions llama-index-core/tests/agent/workflow/test_single_agent_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from typing import List, Any

import pytest

from llama_index.core.agent.workflow import FunctionAgent, ReActAgent
from llama_index.core.base.llms.types import (
ChatMessage,
LLMMetadata,
ChatResponseAsyncGen,
ChatResponse,
MessageRole,
)
from llama_index.core.llms import MockLLM
from llama_index.core.llms.llm import ToolSelection
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool


class MockLLM(MockLLM):
    """Mock LLM that replays a fixed list of responses, cycling round-robin.

    Both streaming entry points previously duplicated the exact same
    response-selection and generator logic; it now lives in one private
    helper so the two methods cannot drift apart.
    """

    def __init__(self, responses: List[ChatMessage]):
        super().__init__()
        # Canned responses replayed in order, wrapping around when exhausted.
        self._responses = responses
        self._response_index = 0

    @property
    def metadata(self) -> LLMMetadata:
        # Advertise function-calling support so agents take the tool path.
        return LLMMetadata(is_function_calling_model=True)

    def _next_response_gen(self) -> ChatResponseAsyncGen:
        """Select the next canned response and wrap it in an async generator.

        Yields nothing when no responses were configured.
        """
        response_msg = None
        if self._responses:
            response_msg = self._responses[self._response_index]
            self._response_index = (self._response_index + 1) % len(self._responses)

        async def _gen():
            if response_msg:
                yield ChatResponse(
                    message=response_msg,
                    delta=response_msg.content,
                    raw={"content": response_msg.content},
                )

        return _gen()

    async def astream_chat(
        self, messages: List[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Stream the next canned response, ignoring the input messages."""
        return self._next_response_gen()

    async def astream_chat_with_tools(
        self, tools: List[Any], chat_history: List[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Stream the next canned response, ignoring tools and history."""
        return self._next_response_gen()

    def get_tool_calls_from_response(
        self, response: ChatResponse, **kwargs: Any
    ) -> List[ToolSelection]:
        """Extract tool calls stashed in the message's additional kwargs."""
        return response.message.additional_kwargs.get("tool_calls", [])


@pytest.fixture()
def function_agent():
    """Build a FunctionAgent backed by a mock LLM with one canned reply."""
    canned_replies = [
        ChatMessage(
            role=MessageRole.ASSISTANT, content="Success with the FunctionAgent"
        )
    ]
    return FunctionAgent(
        name="retriever",
        description="Manages data retrieval",
        system_prompt="You are a retrieval assistant.",
        llm=MockLLM(responses=canned_replies),
    )


def add(a: int, b: int) -> int:
    """Return the sum of *a* and *b*."""
    total = a + b
    return total


def subtract(a: int, b: int) -> int:
    """Return *a* minus *b*."""
    difference = a - b
    return difference


@pytest.fixture()
def calculator_agent():
    """Build a ReActAgent with add/subtract tools and scripted LLM replies.

    Fix: the second canned reply used a raw string with ``\\Answer`` — a typo
    for the ``\\nAnswer`` newline separator the ReAct output parser expects
    between the Thought and Answer lines (compare the first reply, which
    terminates each ReAct segment with ``\\n``).
    """
    return ReActAgent(
        name="calculator",
        description="Performs basic arithmetic operations",
        system_prompt="You are a calculator assistant.",
        tools=[
            FunctionTool.from_defaults(fn=add),
            FunctionTool.from_defaults(fn=subtract),
        ],
        llm=MockLLM(
            responses=[
                # First turn: a tool invocation in ReAct format.
                ChatMessage(
                    role=MessageRole.ASSISTANT,
                    content='Thought: I need to add these numbers\nAction: add\nAction Input: {"a": 5, "b": 3}\n',
                ),
                # Second turn: the final answer, with Answer on its own line.
                ChatMessage(
                    role=MessageRole.ASSISTANT,
                    content="Thought: The result is 8\nAnswer: The sum is 8",
                ),
            ]
        ),
    )


@pytest.mark.asyncio()
async def test_single_function_agent(function_agent):
    """A lone FunctionAgent can run directly and produce its canned reply."""
    handler = function_agent.run(user_msg="test")

    # Drain the event stream so the workflow can progress to completion.
    async for _event in handler.stream_events():
        pass

    result = await handler
    assert "Success with the FunctionAgent" in str(result.response)


@pytest.mark.asyncio()
async def test_single_react_agent(calculator_agent):
    """A lone ReActAgent runs its tool loop and reports the computed sum."""
    memory = ChatMemoryBuffer.from_defaults()
    handler = calculator_agent.run(user_msg="Can you add 5 and 3?", memory=memory)

    # Consume all intermediate events emitted while the agent works.
    collected_events = [event async for event in handler.stream_events()]

    result = await handler
    assert "8" in str(result.response)

0 comments on commit 3dee964

Please sign in to comment.