Skip to content

Commit

Permalink
Implement control channel in python host servicer
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Feb 7, 2025
1 parent 4c1c12d commit a79c9ec
Showing 1 changed file with 42 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar

from autogen_core import TopicId
from autogen_core._agent_id import AgentId
from autogen_core._runtime_impl_helpers import SubscriptionManager

from ._constants import GRPC_IMPORT_ERROR_STR
Expand Down Expand Up @@ -100,6 +101,9 @@ def __init__(self) -> None:
self._data_connections: Dict[
ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message]
] = {}
self._control_connections: Dict[
ClientConnectionId, ChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage]
] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
Expand Down Expand Up @@ -140,7 +144,23 @@ async def OpenControlChannel( # type: ignore
request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage],
context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage],
) -> AsyncIterator[agent_worker_pb2.ControlMessage]:
raise NotImplementedError("Method not implemented.")
client_id = await get_client_id_or_abort(context)

async def handle_callback(message: agent_worker_pb2.ControlMessage) -> None:
await self._receive_control_message(client_id, message)

connection = CallbackChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage](
request_iterator, client_id, handle_callback=handle_callback
)
self._control_connections[client_id] = connection
logger.info(f"Client {client_id} connected.")

try:
async for message in connection:
yield message
finally:
# Clean up the client connection.
del self._control_connections[client_id]

async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
async with self._agent_type_to_client_id_lock:
Expand Down Expand Up @@ -182,6 +202,27 @@ async def _receive_message(self, client_id: ClientConnectionId, message: agent_w
case None:
logger.warning("Received empty message")

async def _receive_control_message(self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage) -> None:
logger.info(f"Received message from client {client_id}: {message}")
destination = message.destination
if destination.startswith("agentid="):
agent_id = AgentId.from_str(destination[len("agentid="):])
target_client_id = self._agent_type_to_client_id.get(agent_id.type)
if target_client_id is None:
logger.error(f"Agent client id not found for agent type {agent_id.type}.")
return
elif destination.startswith("clientid="):
target_client_id = destination[len("clientid="):]
else:
logger.error(f"Invalid destination {destination}")
return

target_send_queue = self._control_connections.get(target_client_id)
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.send(message)

async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
# Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock:
Expand Down

0 comments on commit a79c9ec

Please sign in to comment.