Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from agent_framework import Agent, SupportsAgentRun
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
from agent_framework._sessions import AgentSession
from agent_framework._sessions import AgentSession, BaseHistoryProvider, InMemoryHistoryProvider
from agent_framework._tools import FunctionTool, tool
from agent_framework._types import AgentResponse, Content, Message
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest
Expand Down Expand Up @@ -368,12 +368,32 @@ def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]:
# restore the original tools, in case they are shared between agents
options["tools"] = tools_from_options

# Filter out history providers to prevent duplicate messages.
# The HandoffAgentExecutor manages conversation history via _full_conversation,
# so history providers would re-inject previously stored messages on each
# agent.run() call, causing the entire conversation to appear twice.
# A no-op InMemoryHistoryProvider placeholder prevents the agent from
# auto-injecting a default one at runtime.
filtered_providers = [
p for p in agent.context_providers
if not isinstance(p, BaseHistoryProvider)
]
# Always add a no-op placeholder to prevent the agent from
# auto-injecting a default InMemoryHistoryProvider at runtime.
filtered_providers.append(
InMemoryHistoryProvider(
load_messages=False,
store_inputs=False,
store_outputs=False,
)
)

return Agent(
client=agent.client,
id=agent.id,
name=agent.name,
description=agent.description,
context_providers=agent.context_providers,
context_providers=filtered_providers,
middleware=agent.agent_middleware,
default_options=cloned_options, # type: ignore[assignment]
)
Expand Down
111 changes: 111 additions & 0 deletions python/packages/orchestrations/tests/test_handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,3 +1117,114 @@ def get_session(self, *, service_session_id, **kwargs):

with pytest.raises(TypeError, match="Participants must be Agent instances"):
HandoffBuilder().participants([fake])


class CapturingChatClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
    """Test double for a chat client that keeps a copy of every message list it is sent.

    Each invocation of ``_inner_get_response`` appends a snapshot of the incoming
    messages to ``captured_calls`` so a test can inspect exactly what was sent.
    When ``handoff_to`` is given, only the first call carries a handoff call id;
    the target is cleared afterwards.
    """

    def __init__(
        self,
        *,
        name: str = "",
        handoff_to: str | None = None,
    ) -> None:
        """Initialise every base layer explicitly and reset the recording state.

        Args:
            name: Label passed to the reply builder and used in generated call ids.
            handoff_to: Optional handoff target used for the first call only.
        """
        ChatMiddlewareLayer.__init__(self)
        FunctionInvocationLayer.__init__(self)
        BaseChatClient.__init__(self)
        self._name = name
        self._handoff_to = handoff_to
        self._call_index = 0
        # One entry per call, each a copy of the messages received on that call.
        self.captured_calls: list[list[Message]] = []

    def _inner_get_response(
        self,
        *,
        messages: Sequence[Message],
        stream: bool,
        options: Mapping[str, Any],
        **kwargs: Any,
    ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Record the incoming messages and produce a canned assistant reply.

        Returns a ``ResponseStream`` when ``stream`` is truthy, otherwise an
        awaitable resolving to a ``ChatResponse``.
        """
        snapshot = list(messages)
        self.captured_calls.append(snapshot)

        # Read the handoff target *before* asking for the call id:
        # _next_call_id() clears self._handoff_to as a side effect.
        handoff_target = self._handoff_to
        call_id = self._next_call_id()
        # NOTE(review): _build_reply_contents is defined elsewhere in this file;
        # presumably it yields the reply contents (and handoff call) -- confirm.
        contents = _build_reply_contents(self._name, handoff_target, call_id)
        reply = Message(role="assistant", contents=contents)

        async def _respond() -> ChatResponse:
            return ChatResponse(messages=reply, response_id="mock_response")

        if not stream:
            return _respond()
        return self._build_streaming_response(contents, dict(options))

    def _build_streaming_response(
        self, contents: list[Content], options: dict[str, Any]
    ) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
        """Wrap ``contents`` in a single-update stream with a finalizer.

        The finalizer forwards ``options["response_format"]`` as the output
        format type only when that value is a class; otherwise it passes ``None``.
        """

        def _collapse(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
            requested_format = options.get("response_format")
            if isinstance(requested_format, type):
                format_type = requested_format
            else:
                format_type = None
            return ChatResponse.from_updates(updates, output_format_type=format_type)

        async def _emit() -> AsyncIterable[ChatResponseUpdate]:
            # The whole reply is delivered as one update.
            yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

        return ResponseStream(_emit(), finalizer=_collapse)

    def _next_call_id(self) -> str | None:
        """Return a handoff call id once, then ``None`` on every later call.

        Side effects: bumps the call counter and clears the handoff target so
        that only the first call triggers a handoff.
        """
        if self._handoff_to:
            issued_id = f"{self._name}-handoff-{self._call_index}"
            self._call_index += 1
            # Only handoff on first call
            self._handoff_to = None
            return issued_id
        return None


async def test_no_duplicate_messages_after_handoff_and_resume() -> None:
    """Guard against issue #4695: duplicated messages in a Handoff workflow.

    If a history provider is active on the handoff agents, stored messages get
    re-injected next to the executor's own conversation, so the conversation
    reaches the chat client twice. This test runs a handoff followed by a
    resumed turn and checks that the specialist's second request contains no
    repeated message.
    """
    # Triage hands off on its first call; the specialist never hands off.
    triage_client = CapturingChatClient(name="triage", handoff_to="specialist")
    specialist_client = CapturingChatClient(name="specialist")

    triage = Agent(client=triage_client, name="triage", id="triage")
    specialist = Agent(client=specialist_client, name="specialist", id="specialist")

    builder = HandoffBuilder(
        participants=[triage, specialist],
        # Never terminate on its own, so the workflow pauses for user input.
        termination_condition=lambda _: False,
    )
    workflow = builder.with_start_agent(triage).build()

    # Turn 1: triage -> specialist handoff, specialist answers.
    initial_events = await _drain(workflow.run("Hello, I need help.", stream=True))
    pending_request = _latest_request_info_event(initial_events)

    # Turn 2: the user replies to the pending request, specialist answers again.
    follow_up = [Message(role="user", text="More details please.")]
    await _drain(
        workflow.run(
            stream=True,
            responses={pending_request.request_id: follow_up},
        )
    )

    # The specialist client must have been invoked exactly once per turn.
    assert len(specialist_client.captured_calls) == 2

    # Fingerprint each message structurally (role plus the type/text of every
    # content item) rather than by .text alone, which would miss duplicated
    # non-text content and could false-fail on legitimately identical text.
    fingerprints = []
    for message in specialist_client.captured_calls[1]:
        content_signature = tuple((item.type, item.text) for item in message.contents)
        fingerprints.append((message.role, content_signature))

    unique_fingerprints = set(fingerprints)
    assert len(fingerprints) == len(unique_fingerprints), (
        f"Duplicate messages detected in specialist's second call: {fingerprints}"
    )