diff --git a/python/packages/autogen-core/src/autogen_core/_runtime_impl_helpers.py b/python/packages/autogen-core/src/autogen_core/_runtime_impl_helpers.py index dcd0bb6c991c..3c025f0e1c05 100644 --- a/python/packages/autogen-core/src/autogen_core/_runtime_impl_helpers.py +++ b/python/packages/autogen-core/src/autogen_core/_runtime_impl_helpers.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Awaitable, Callable, DefaultDict, List, Set +from typing import Awaitable, Callable, DefaultDict, List, Set, Sequence from ._agent import Agent from ._agent_id import AgentId @@ -35,6 +35,10 @@ def __init__(self) -> None: self._seen_topics: Set[TopicId] = set() self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list) + @property + def subscriptions(self) -> Sequence[Subscription]: + return self._subscriptions + async def add_subscription(self, subscription: Subscription) -> None: # Check if the subscription already exists if any(sub == subscription for sub in self._subscriptions): diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 4a87793ffb7c..41f08e432973 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -790,7 +790,15 @@ async def add_subscription(self, subscription: Subscription) -> None: await self._subscription_manager.add_subscription(subscription) async def remove_subscription(self, id: str) -> None: - raise NotImplementedError("Subscriptions cannot be removed while using distributed runtime currently.") + if self._host_connection is None: + raise RuntimeError("Host connection is not set.") + + message = agent_worker_pb2.RemoveSubscriptionRequest(id=id) + _response: agent_worker_pb2.RemoveSubscriptionResponse = await self._host_connection.stub.RemoveSubscription( + message, metadata=self._host_connection.metadata + ) + + await self._subscription_manager.remove_subscription(id) async def get( self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index 621ed9511eb4..1c0b57a440ed 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -11,7 +11,7 @@ from autogen_core._runtime_impl_helpers import SubscriptionManager from ._constants import GRPC_IMPORT_ERROR_STR -from ._utils import subscription_from_proto +from ._utils import subscription_from_proto, subscription_to_proto try: import grpc @@ -170,7 +170,11 @@ async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None: del self._agent_type_to_client_id[agent_type] for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()): logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}") - await self._subscription_manager.remove_subscription(sub_id) + try: + await self._subscription_manager.remove_subscription(sub_id) + # Catch and ignore if the subscription does not exist. + except ValueError: + continue logger.info(f"Client {client_id} disconnected successfully") def _raise_on_exception(self, task: Task[Any]) -> None: @@ -327,7 +331,8 @@ async def RemoveSubscription( # type: ignore ], ) -> agent_worker_pb2.RemoveSubscriptionResponse: _client_id = await get_client_id_or_abort(context) - raise NotImplementedError("Method not implemented.") + await self._subscription_manager.remove_subscription(request.id) + return agent_worker_pb2.RemoveSubscriptionResponse() async def GetSubscriptions( # type: ignore self, @@ -337,4 +342,23 @@ async def GetSubscriptions( # type: ignore ], ) -> agent_worker_pb2.GetSubscriptionsResponse: _client_id = await get_client_id_or_abort(context) - raise NotImplementedError("Method not implemented.") + subscriptions = self._subscription_manager.subscriptions + return agent_worker_pb2.GetSubscriptionsResponse( + subscriptions=[subscription_to_proto(sub) for sub in subscriptions] + ) + + # async def GetState( # type: ignore + # self, + # request: agent_worker_pb2.AgentId, + # context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse], + # ) -> agent_worker_pb2.GetStateResponse: + # _client_id = await get_client_id_or_abort(context) + # raise NotImplementedError("Method not implemented!") + + # async def SaveState( # type: ignore + # self, + # request: agent_worker_pb2.AgentState, + # context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse], + # ) -> agent_worker_pb2.SaveStateResponse: + # _client_id = await get_client_id_or_abort(context) + # raise NotImplementedError("Method not implemented!") diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index 88c9d1e0bf02..7849dd54768f 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -8,6 +8,7 @@ PROTOBUF_DATA_CONTENT_TYPE, AgentId, AgentType, + DefaultSubscription, DefaultTopicId, MessageContext, RoutedAgent, @@ -129,6 +130,47 @@ async def test_register_receives_publish() -> None: @pytest.mark.grpc +@pytest.mark.asyncio +async def test_register_doesnt_receive_after_removing_subscription() -> None: + host_address = "localhost:50053" + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + worker1 = GrpcWorkerAgentRuntime(host_address=host_address) + worker1.start() + worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType)) + await worker1.register_factory( + type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent + ) + sub = DefaultSubscription(agent_type="name1") + await worker1.add_subscription(sub) + + agent_1_instance = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent) + # Publish message from worker1 + await worker1.publish_message(MessageType(), topic_id=DefaultTopicId()) + + # Let the agent run for a bit. + await agent_1_instance.event.wait() + agent_1_instance.event.clear() + + # Agents in default topic source should have received the message. + assert agent_1_instance.num_calls == 1 + + await worker1.remove_subscription(sub.id) + + # Publish message from worker1 + await worker1.publish_message(MessageType(), topic_id=DefaultTopicId()) + + # Let the agent run for a bit. + await asyncio.sleep(2) + + # Agent should not have received the message. + assert agent_1_instance.num_calls == 1 + + await worker1.stop() + await host.stop() + + @pytest.mark.asyncio async def test_register_receives_publish_cascade_single_worker() -> None: host_address = "localhost:50054" diff --git a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py index b917194b1d82..0dfb1ece4a07 100644 --- a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py +++ b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py @@ -16,6 +16,8 @@ ) from pydantic import BaseModel +from asyncio import Event + @dataclass class MessageType: ... @@ -36,6 +38,7 @@ def __init__(self) -> None: super().__init__("A loop back agent.") self.num_calls = 0 self.received_messages: list[Any] = [] + self.event = Event() @message_handler async def on_new_message( @@ -43,6 +46,7 @@ async def on_new_message( ) -> MessageType | ContentMessage: self.num_calls += 1 self.received_messages.append(message) + self.event.set() return message