Skip to content

Commit

Permalink
Impl remove and get subscription APIs for python xlang (#5365)
Browse files Browse the repository at this point in the history
Closes #5297

---------

Co-authored-by: Ryan Sweet <[email protected]>
Co-authored-by: Jacob Alber <[email protected]>
Co-authored-by: Jacob Alber <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
5 people authored Feb 11, 2025
1 parent 392aa14 commit dc877d5
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Awaitable, Callable, DefaultDict, List, Set
from typing import Awaitable, Callable, DefaultDict, List, Set, Sequence

from ._agent import Agent
from ._agent_id import AgentId
Expand Down Expand Up @@ -35,6 +35,10 @@ def __init__(self) -> None:
self._seen_topics: Set[TopicId] = set()
self._subscribed_recipients: DefaultDict[TopicId, List[AgentId]] = defaultdict(list)

@property
def subscriptions(self) -> Sequence[Subscription]:
return self._subscriptions

async def add_subscription(self, subscription: Subscription) -> None:
# Check if the subscription already exists
if any(sub == subscription for sub in self._subscriptions):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,15 @@ async def add_subscription(self, subscription: Subscription) -> None:
await self._subscription_manager.add_subscription(subscription)

async def remove_subscription(self, id: str) -> None:
raise NotImplementedError("Subscriptions cannot be removed while using distributed runtime currently.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")

message = agent_worker_pb2.RemoveSubscriptionRequest(id=id)
_response: agent_worker_pb2.RemoveSubscriptionResponse = await self._host_connection.stub.RemoveSubscription(
message, metadata=self._host_connection.metadata
)

await self._subscription_manager.remove_subscription(id)

async def get(
self, id_or_type: AgentId | AgentType | str, /, key: str = "default", *, lazy: bool = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from autogen_core._runtime_impl_helpers import SubscriptionManager

from ._constants import GRPC_IMPORT_ERROR_STR
from ._utils import subscription_from_proto
from ._utils import subscription_from_proto, subscription_to_proto

try:
import grpc
Expand Down Expand Up @@ -170,7 +170,11 @@ async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
del self._agent_type_to_client_id[agent_type]
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
await self._subscription_manager.remove_subscription(sub_id)
try:
await self._subscription_manager.remove_subscription(sub_id)
# Catch and ignore if the subscription does not exist.
except ValueError:
continue
logger.info(f"Client {client_id} disconnected successfully")

def _raise_on_exception(self, task: Task[Any]) -> None:
Expand Down Expand Up @@ -327,7 +331,8 @@ async def RemoveSubscription( # type: ignore
],
) -> agent_worker_pb2.RemoveSubscriptionResponse:
_client_id = await get_client_id_or_abort(context)
raise NotImplementedError("Method not implemented.")
await self._subscription_manager.remove_subscription(request.id)
return agent_worker_pb2.RemoveSubscriptionResponse()

async def GetSubscriptions( # type: ignore
self,
Expand All @@ -337,4 +342,23 @@ async def GetSubscriptions( # type: ignore
],
) -> agent_worker_pb2.GetSubscriptionsResponse:
_client_id = await get_client_id_or_abort(context)
raise NotImplementedError("Method not implemented.")
subscriptions = self._subscription_manager.subscriptions
return agent_worker_pb2.GetSubscriptionsResponse(
subscriptions=[subscription_to_proto(sub) for sub in subscriptions]
)

# async def GetState( # type: ignore
# self,
# request: agent_worker_pb2.AgentId,
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse],
# ) -> agent_worker_pb2.GetStateResponse:
# _client_id = await get_client_id_or_abort(context)
# raise NotImplementedError("Method not implemented!")

# async def SaveState( # type: ignore
# self,
# request: agent_worker_pb2.AgentState,
# context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse],
# ) -> agent_worker_pb2.SaveStateResponse:
# _client_id = await get_client_id_or_abort(context)
# raise NotImplementedError("Method not implemented!")
42 changes: 42 additions & 0 deletions python/packages/autogen-ext/tests/test_worker_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
PROTOBUF_DATA_CONTENT_TYPE,
AgentId,
AgentType,
DefaultSubscription,
DefaultTopicId,
MessageContext,
RoutedAgent,
Expand Down Expand Up @@ -129,6 +130,47 @@ async def test_register_receives_publish() -> None:


@pytest.mark.grpc
@pytest.mark.asyncio
async def test_register_doesnt_receive_after_removing_subscription() -> None:
host_address = "localhost:50053"
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()

worker1 = GrpcWorkerAgentRuntime(host_address=host_address)
worker1.start()
worker1.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await worker1.register_factory(
type=AgentType("name1"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
)
sub = DefaultSubscription(agent_type="name1")
await worker1.add_subscription(sub)

agent_1_instance = await worker1.try_get_underlying_agent_instance(AgentId("name1", "default"), LoopbackAgent)
# Publish message from worker1
await worker1.publish_message(MessageType(), topic_id=DefaultTopicId())

# Let the agent run for a bit.
await agent_1_instance.event.wait()
agent_1_instance.event.clear()

# Agents in default topic source should have received the message.
assert agent_1_instance.num_calls == 1

await worker1.remove_subscription(sub.id)

# Publish message from worker1
await worker1.publish_message(MessageType(), topic_id=DefaultTopicId())

# Let the agent run for a bit.
await asyncio.sleep(2)

# Agent should not have received the message.
assert agent_1_instance.num_calls == 1

await worker1.stop()
await host.stop()


@pytest.mark.asyncio
async def test_register_receives_publish_cascade_single_worker() -> None:
host_address = "localhost:50054"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
)
from pydantic import BaseModel

from asyncio import Event


@dataclass
class MessageType: ...
Expand All @@ -36,13 +38,15 @@ def __init__(self) -> None:
super().__init__("A loop back agent.")
self.num_calls = 0
self.received_messages: list[Any] = []
self.event = Event()

@message_handler
async def on_new_message(
self, message: MessageType | ContentMessage, ctx: MessageContext
) -> MessageType | ContentMessage:
self.num_calls += 1
self.received_messages.append(message)
self.event.set()
return message


Expand Down

0 comments on commit dc877d5

Please sign in to comment.