Skip to content

Commit

Permalink
fix: do not count agent event in MaxMessageTermination condition (#5436)
Browse files Browse the repository at this point in the history
Resolves #5425
  • Loading branch information
ekzhu authored Feb 7, 2025
1 parent 707c3cf commit 0008c9c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self

from ..base import TerminatedException, TerminationCondition
from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
from ..messages import AgentEvent, BaseChatMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage


class StopMessageTerminationConfig(BaseModel):
Expand Down Expand Up @@ -48,21 +48,25 @@ def _from_config(cls, config: StopMessageTerminationConfig) -> Self:

class MaxMessageTerminationConfig(BaseModel):
max_messages: int
include_agent_event: bool = False


class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminationConfig]):
"""Terminate the conversation after a maximum number of messages have been exchanged.
Args:
max_messages: The maximum number of messages allowed in the conversation.
include_agent_event: If True, include :class:`~autogen_agentchat.messages.AgentEvent` in the message count.
Otherwise, only include :class:`~autogen_agentchat.messages.ChatMessage`. Defaults to False.
"""

component_config_schema = MaxMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.MaxMessageTermination"

def __init__(self, max_messages: int) -> None:
def __init__(self, max_messages: int, include_agent_event: bool = False) -> None:
self._max_messages = max_messages
self._message_count = 0
self._include_agent_event = include_agent_event

@property
def terminated(self) -> bool:
Expand All @@ -71,7 +75,7 @@ def terminated(self) -> bool:
async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
if self.terminated:
raise TerminatedException("Termination condition has already been reached")
self._message_count += len(messages)
self._message_count += len([m for m in messages if self._include_agent_event or isinstance(m, BaseChatMessage)])
if self._message_count >= self._max_messages:
return StopMessage(
content=f"Maximum number of messages {self._max_messages} reached, current message count: {self._message_count}",
Expand All @@ -83,11 +87,13 @@ async def reset(self) -> None:
self._message_count = 0

def _to_config(self) -> MaxMessageTerminationConfig:
return MaxMessageTerminationConfig(max_messages=self._max_messages)
return MaxMessageTerminationConfig(
max_messages=self._max_messages, include_agent_event=self._include_agent_event
)

@classmethod
def _from_config(cls, config: MaxMessageTerminationConfig) -> Self:
return cls(max_messages=config.max_messages)
return cls(max_messages=config.max_messages, include_agent_event=config.include_agent_event)


class TextMentionTerminationConfig(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
TimeoutTermination,
TokenUsageTermination,
)
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage, UserInputRequestedEvent
from autogen_core.models import RequestUsage


Expand Down Expand Up @@ -74,6 +74,18 @@ async def test_max_message_termination() -> None:
is not None
)

termination = MaxMessageTermination(2, include_agent_event=True)
assert await termination([]) is None
await termination.reset()
assert await termination([TextMessage(content="Hello", source="user")]) is None
await termination.reset()
assert (
await termination(
[TextMessage(content="Hello", source="user"), UserInputRequestedEvent(request_id="1", source="agent")]
)
is not None
)


@pytest.mark.asyncio
async def test_mention_termination() -> None:
Expand Down

0 comments on commit 0008c9c

Please sign in to comment.