Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement control channel in python host servicer #5427

Merged
merged 5 commits into from
Feb 10, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar

from autogen_core import TopicId
from autogen_core._agent_id import AgentId
from autogen_core._runtime_impl_helpers import SubscriptionManager

from ._constants import GRPC_IMPORT_ERROR_STR
Expand Down Expand Up @@ -100,6 +101,9 @@ def __init__(self) -> None:
self._data_connections: Dict[
ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message]
] = {}
self._control_connections: Dict[
ClientConnectionId, ChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage]
] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
Expand Down Expand Up @@ -140,7 +144,23 @@ async def OpenControlChannel( # type: ignore
request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage],
context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage],
) -> AsyncIterator[agent_worker_pb2.ControlMessage]:
raise NotImplementedError("Method not implemented.")
client_id = await get_client_id_or_abort(context)

async def handle_callback(message: agent_worker_pb2.ControlMessage) -> None:
await self._receive_control_message(client_id, message)

connection = CallbackChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage](
request_iterator, client_id, handle_callback=handle_callback
)
self._control_connections[client_id] = connection
logger.info(f"Client {client_id} connected.")

try:
async for message in connection:
yield message
finally:
# Clean up the client connection.
del self._control_connections[client_id]

async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None:
async with self._agent_type_to_client_id_lock:
Expand Down Expand Up @@ -182,6 +202,29 @@ async def _receive_message(self, client_id: ClientConnectionId, message: agent_w
case None:
logger.warning("Received empty message")

async def _receive_control_message(
self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage
) -> None:
logger.info(f"Received message from client {client_id}: {message}")
destination = message.destination
if destination.startswith("agentid="):
agent_id = AgentId.from_str(destination[len("agentid=") :])
target_client_id = self._agent_type_to_client_id.get(agent_id.type)
if target_client_id is None:
logger.error(f"Agent client id not found for agent type {agent_id.type}.")
return
elif destination.startswith("clientid="):
target_client_id = destination[len("clientid=") :]
else:
logger.error(f"Invalid destination {destination}")
return

target_send_queue = self._control_connections.get(target_client_id)
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.send(message)

async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None:
# Deliver the message to a client given the target agent type.
async with self._agent_type_to_client_id_lock:
Expand Down
Loading