From 7f90f42ca16ae4636d560d131abf76ee0fc0bb2e Mon Sep 17 00:00:00 2001 From: Yehonatan Buchnik Date: Tue, 28 Jan 2025 12:51:55 +0200 Subject: [PATCH] A callback for re-fetching the root ca in the aggregator (#1315) * dynamically let the aggregator learn the clients certificates on every TLS handshake Signed-off-by: Buchnik, Yehonatan * formatting Signed-off-by: Buchnik, Yehonatan * formatting Signed-off-by: Buchnik, Yehonatan --------- Signed-off-by: Buchnik, Yehonatan --- openfl/transport/grpc/aggregator_server.py | 31 +++++++++++++++++----- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index 12f658a0aa..3aa7977ebb 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -10,7 +10,12 @@ from random import random from time import sleep -from grpc import StatusCode, server, ssl_server_credentials +from grpc import ( + StatusCode, + dynamic_ssl_server_credentials, + server, + ssl_server_certificate_configuration, +) from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils from openfl.transport.grpc.grpc_channel_options import channel_options @@ -50,6 +55,7 @@ def __init__( root_certificate=None, certificate=None, private_key=None, + root_certificate_refresher_cb=None, **kwargs, ): """ @@ -68,6 +74,8 @@ def __init__( TLS connection. private_key (str): The path to the server's private key for the TLS connection. + root_certificate_refresher_cb (Callable): A callback function + that receive no arguments and return the current root certificate. **kwargs: Additional keyword arguments. """ self.aggregator = aggregator @@ -81,6 +89,7 @@ def __init__( self.server_credentials = None self.logger = logging.getLogger(__name__) + self.root_certificate_refresher_cb = root_certificate_refresher_cb def validate_collaborator(self, request, context): """Validate the collaborator. @@ -325,13 +334,23 @@ def get_server(self): if not self.require_client_auth: self.logger.warning("Client-side authentication is disabled.") - - self.server_credentials = ssl_server_credentials( - ((private_key_b, certificate_b),), - root_certificates=root_certificate_b, - require_client_auth=self.require_client_auth, + cert_config = ssl_server_certificate_configuration( + ((private_key_b, certificate_b),), root_certificates=root_certificate_b ) + def certificate_configuration_fetcher(): + root_cert = root_certificate_b + if self.root_certificate_refresher_cb is not None: + root_cert = self.root_certificate_refresher_cb() + return ssl_server_certificate_configuration( + ((private_key_b, certificate_b),), root_certificates=root_cert + ) + + self.server_credentials = dynamic_ssl_server_credentials( + cert_config, + certificate_configuration_fetcher, + require_client_authentication=self.require_client_auth, + ) self.server.add_secure_port(self.uri, self.server_credentials) return self.server