diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index d77b42dd17f2..21862559e91d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -517,6 +517,7 @@ async def save_state(self) -> Mapping[str, Any]: "remaining": {target: dict(counter) for target, counter in self._remaining.items()}, "enqueued_any": dict(self._enqueued_any), "ready": list(self._ready), + "triggered_activation_groups": {target: list(groups) for target, groups in self._triggered_activation_groups.items()}, } return state @@ -527,6 +528,7 @@ async def load_state(self, state: Mapping[str, Any]) -> None: self._remaining = {target: Counter(groups) for target, groups in state["remaining"].items()} self._enqueued_any = state["enqueued_any"] self._ready = deque(state["ready"]) + self._triggered_activation_groups = {target: set(groups) for target, groups in state["triggered_activation_groups"].items()} async def reset(self) -> None: """Reset execution state to the start of the graph.""" diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py index b3a82e9ee203..ef77ec570c30 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py @@ -13,7 +13,7 @@ PerSourceFilter, ) from autogen_agentchat.base import Response, TaskResult -from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination +from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination, StopMessageTermination from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage from autogen_agentchat.teams import ( DiGraphBuilder, @@ -1766,3 +1766,62 @@ async def test_digraph_group_chat_resume_with_termination_condition(runtime: Age assert agent_a.total_messages == 1 # Still only ran once assert agent_b.total_messages == 1 # Still only ran once assert agent_c.total_messages == 1 # Now ran once + +@pytest.mark.asyncio +async def test_digraph_group_chat_resumes_from_state_with_triggered_activation_groups(runtime: AgentRuntime | None) -> None: + class _NoOpAgent(BaseChatAgent): + def __init__(self, name: str, description: str = "", target: str = "", stop_when: str = "") -> None: + super().__init__(name, description) + self._target = target + self._stop_when = stop_when + + @property + def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: + return (TextMessage,) + + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: + + # Force the stop message when specific string is matched + if self.name == self._stop_when: + return Response(chat_message=StopMessage(content=self.name, source=self.name)) + + return Response(chat_message=TextMessage(content=self._target, source=self.name)) + + async def on_reset(self, cancellation_token: CancellationToken) -> None: + pass + + def create_flow() -> GraphFlow: + agent_a = _NoOpAgent(name="A", target="B", stop_when="A") + agent_b = _NoOpAgent(name="B", target="A") + + builder = DiGraphBuilder() + builder.add_node(agent_a).add_node(agent_b) + + # Agent A will loop forever + builder.add_edge(agent_a, agent_a, "A", "loopback") + builder.add_edge(agent_a, agent_b, "B") + + builder.set_entry_point(agent_a) + + return GraphFlow( + participants=builder.get_participants(), + graph=builder.build(), + runtime=runtime, + termination_condition=StopMessageTermination(), + ) + + # Run the graph flow until termination condition is reached + team_one = create_flow() + result_one: TaskResult = await team_one.run(task="Start") + assert result_one.stop_reason == "Stop message received" + + # Export state. + state = await team_one.save_state() + + # Load team 1's state into team 2. + team_two = create_flow() + await team_two.load_state(state) + + # Team 2 should resume and immediately hit the stop condition again + result_two: TaskResult = await team_two.run(task="Continue") + assert result_two.stop_reason == "Stop message received" \ No newline at end of file