-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for running single-agent workflows within the BaseWorkflowAgent class (#18038)
- Loading branch information
Showing
5 changed files
with
206 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
llama-index-core/llama_index/core/agent/workflow/single_agent_workflow.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from abc import ABC | ||
from typing import Any, List, Optional, Union, TypeVar | ||
|
||
from llama_index.core.llms import ChatMessage | ||
from llama_index.core.memory import BaseMemory | ||
from llama_index.core.workflow import ( | ||
Context, | ||
) | ||
from llama_index.core.workflow.checkpointer import CheckpointCallback | ||
from llama_index.core.workflow.handler import WorkflowHandler | ||
|
||
# Type var bound to BaseWorkflowAgent (forward ref; defined in a sibling module).
T = TypeVar("T", bound="BaseWorkflowAgent")  # type: ignore[name-defined]
|
||
|
||
class SingleAgentRunnerMixin(ABC):
    """Mixin class for executing a single agent within a workflow system.

    This class provides the necessary interface for running a single agent:
    the agent is wrapped in a one-agent ``AgentWorkflow`` so callers get the
    full workflow surface (streaming, stepwise execution, checkpointing).
    """

    def run(
        self: T,
        user_msg: Optional[Union[str, ChatMessage]] = None,
        chat_history: Optional[List[ChatMessage]] = None,
        memory: Optional[BaseMemory] = None,
        ctx: Optional[Context] = None,
        stepwise: bool = False,
        checkpoint_callback: Optional[CheckpointCallback] = None,
        **workflow_kwargs: Any,
    ) -> WorkflowHandler:
        """Run the agent by delegating to a single-agent ``AgentWorkflow``."""
        # Imported lazily to avoid a circular import at module load time.
        from llama_index.core.agent.workflow import AgentWorkflow

        single_agent_workflow = AgentWorkflow(agents=[self], **workflow_kwargs)
        return single_agent_workflow.run(
            user_msg=user_msg,
            chat_history=chat_history,
            memory=memory,
            ctx=ctx,
            stepwise=stepwise,
            checkpoint_callback=checkpoint_callback,
        )
145 changes: 145 additions & 0 deletions
145
llama-index-core/tests/agent/workflow/test_single_agent_workflow.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from typing import List, Any | ||
|
||
import pytest | ||
|
||
from llama_index.core.agent.workflow import FunctionAgent, ReActAgent | ||
from llama_index.core.base.llms.types import ( | ||
ChatMessage, | ||
LLMMetadata, | ||
ChatResponseAsyncGen, | ||
ChatResponse, | ||
MessageRole, | ||
) | ||
from llama_index.core.llms import MockLLM | ||
from llama_index.core.llms.llm import ToolSelection | ||
from llama_index.core.memory import ChatMemoryBuffer | ||
from llama_index.core.tools import FunctionTool | ||
|
||
|
||
class MockLLM(MockLLM):
    """Scripted mock LLM that replays a fixed list of chat messages.

    Messages are returned one per call, in order, wrapping around when the
    list is exhausted. Both streaming entry points share the same scripted
    sequence (the original duplicated this logic; it is factored into
    ``_next_response_gen`` here).
    """

    def __init__(self, responses: List[ChatMessage]):
        super().__init__()
        # NOTE(review): undeclared underscore attrs on a pydantic-based LLM;
        # works with the base class config in practice — confirm if pydantic
        # private-attribute declaration is required here.
        self._responses = responses
        self._response_index = 0

    @property
    def metadata(self) -> LLMMetadata:
        # Advertise function-calling support so FunctionAgent accepts this LLM.
        return LLMMetadata(is_function_calling_model=True)

    def _next_response_gen(self) -> ChatResponseAsyncGen:
        """Consume the next scripted message (round-robin) as an async gen.

        Yields nothing if no responses were scripted.
        """
        response_msg = None
        if self._responses:
            response_msg = self._responses[self._response_index]
            self._response_index = (self._response_index + 1) % len(self._responses)

        async def _gen():
            if response_msg:
                yield ChatResponse(
                    message=response_msg,
                    delta=response_msg.content,
                    raw={"content": response_msg.content},
                )

        return _gen()

    async def astream_chat(
        self, messages: List[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Stream the next scripted response, ignoring the input messages."""
        return self._next_response_gen()

    async def astream_chat_with_tools(
        self, tools: List[Any], chat_history: List[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Stream the next scripted response, ignoring tools and history."""
        return self._next_response_gen()

    def get_tool_calls_from_response(
        self, response: ChatResponse, **kwargs: Any
    ) -> List[ToolSelection]:
        """Read tool calls straight out of the message's additional kwargs."""
        return response.message.additional_kwargs.get("tool_calls", [])
|
||
|
||
@pytest.fixture()
def function_agent():
    """A FunctionAgent backed by a mock LLM that always reports success."""
    scripted_llm = MockLLM(
        responses=[
            ChatMessage(
                role=MessageRole.ASSISTANT, content="Success with the FunctionAgent"
            )
        ],
    )
    return FunctionAgent(
        name="retriever",
        description="Manages data retrieval",
        system_prompt="You are a retrieval assistant.",
        llm=scripted_llm,
    )
|
||
|
||
def add(a: int, b: int) -> int:
    """Add two numbers."""
    total = a + b
    return total
|
||
|
||
def subtract(a: int, b: int) -> int:
    """Subtract two numbers."""
    difference = a - b
    return difference
|
||
|
||
@pytest.fixture()
def calculator_agent():
    """A ReActAgent with add/subtract tools and a scripted two-step LLM.

    The scripted responses drive one tool call (add 5 and 3) followed by a
    final answer, exercising the full ReAct loop.
    """
    return ReActAgent(
        name="calculator",
        description="Performs basic arithmetic operations",
        system_prompt="You are a calculator assistant.",
        tools=[
            FunctionTool.from_defaults(fn=add),
            FunctionTool.from_defaults(fn=subtract),
        ],
        llm=MockLLM(
            responses=[
                ChatMessage(
                    role=MessageRole.ASSISTANT,
                    content='Thought: I need to add these numbers\nAction: add\nAction Input: {"a": 5, "b": 3}\n',
                ),
                ChatMessage(
                    role=MessageRole.ASSISTANT,
                    # Fixed: the original used the raw string
                    # r"Thought: The result is 8\Answer: ...", which yields a
                    # literal backslash-A rather than the intended newline
                    # before "Answer:" in the ReAct transcript.
                    content="Thought: The result is 8\nAnswer: The sum is 8",
                ),
            ]
        ),
    )
|
||
|
||
@pytest.mark.asyncio()
async def test_single_function_agent(function_agent):
    """Test single agent with state management."""
    handler = function_agent.run(user_msg="test")

    # Drain the event stream so the workflow runs to completion.
    async for _ in handler.stream_events():
        pass

    result = await handler
    assert "Success with the FunctionAgent" in str(result.response)
|
||
|
||
@pytest.mark.asyncio()
async def test_single_react_agent(calculator_agent):
    """Verify execution of basic ReAct single agent.

    Runs a scripted add(5, 3) tool call through the ReAct loop and checks
    the final answer mentions the result.
    """
    memory = ChatMemoryBuffer.from_defaults()
    handler = calculator_agent.run(user_msg="Can you add 5 and 3?", memory=memory)

    # Drain the stream to drive the workflow; the events themselves are not
    # inspected (the original collected them into an unused list).
    async for _ in handler.stream_events():
        pass

    response = await handler

    assert "8" in str(response.response)