diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py
index 13d446e145..953425f90c 100644
--- a/openfl/federated/plan/plan.py
+++ b/openfl/federated/plan/plan.py
@@ -777,3 +777,47 @@ def restore_object(self, filename):
             return None
         obj = serializer_plugin.restore_object(filename)
         return obj
+
+    def save_model_to_state_file(self, tensor_dict, round_number, output_path):
+        """Save model weights to a protobuf state file."""
+        from openfl.protocols import utils  # Import here to avoid circular imports
+
+        # Get the tensor pipe so weights are serialized with the plan's compression pipeline
+        tensor_pipe = self.get_tensor_pipe()
+
+        # Create and save the protobuf message
+        try:
+            model_proto = utils.construct_model_proto(
+                tensor_dict=tensor_dict, round_number=round_number, tensor_pipe=tensor_pipe
+            )
+            utils.dump_proto(model_proto=model_proto, fpath=output_path)
+        except Exception as e:
+            self.logger.error(f"Failed to create or save model proto: {e}")
+            raise
+
+    def load_model_proto_from_file(self, path: str, force_clean: bool = False):
+        """Load a model protobuf message from a file.
+
+        Args:
+            path: Path to the protobuf file.
+            force_clean: If True, reset per-tensor round numbers to 0.
+
+        Returns:
+            ModelProto: The loaded model protobuf message.
+        """
+        from openfl.protocols import utils  # Import here to avoid circular imports
+
+        # Load the protobuf model
+        model_proto = utils.load_proto(path)
+
+        # If force_clean, reset round numbers in the protobuf before conversion
+        if force_clean:
+            for tensor in model_proto.tensors:
+                # round_number is a scalar field; a non-zero value means it was explicitly set
+                if tensor.round_number != 0:
+                    self.logger.info(
+                        "Resetting round_number %d to 0 for %s", tensor.round_number, tensor.name
+                    )
+                    tensor.round_number = 0
+
+        return model_proto
diff --git a/openfl/interface/plan.py b/openfl/interface/plan.py
index 503693e581..78624c081e 100644
--- a/openfl/interface/plan.py
+++ b/openfl/interface/plan.py
@@ -95,6 +95,19 @@ def plan(context):
     help="Install packages listed under 'requirements.txt'. True/False [Default: True]",
     default=True,
 )
+@option(
+    "-i",
+    "--init_model_path",
+    required=False,
+    help="Path to initialization model protobuf file",
+    type=ClickPath(exists=False),
+)
+@option(
+    "--clean_round",
+    is_flag=True,
+    default=False,
+    help="Clean round information when initializing from model",
+)
 def initialize(
     context,
     plan_config,
@@ -104,6 +117,8 @@ def initialize(
     input_shape,
     gandlf_config,
     install_reqs,
+    init_model_path,
+    clean_round,
 ):
     """Initialize Data Science plan.
 
@@ -119,6 +134,8 @@ def initialize(
        feature_shape (str): The input shape to the model.
        gandlf_config (str): GaNDLF Configuration File Path.
        install_reqs (bool): Whether to install packages listed under 'requirements.txt'.
+       init_model_path (str): Optional path to initialization model protobuf file.
+       clean_round (bool): Whether to clean round information when initializing from model.
     """
 
     for p in [plan_config, cols_config, data_config]:
@@ -165,7 +182,6 @@ def initialize(
     )
 
     init_state_path = plan.config["aggregator"]["settings"]["init_state_path"]
-
     # This is needed to bypass data being locally available
     if input_shape is not None:
         logger.info(
@@ -173,13 +189,24 @@ def initialize(
     )
 
     data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)
 
     task_runner = plan.get_task_runner(data_loader)
     tensor_pipe = plan.get_tensor_pipe()
+    round_number = 0
+
+    if init_model_path and isfile(init_model_path):
+        logger.info(f"Loading initialization model from {init_model_path}")
+        model_proto = plan.load_model_proto_from_file(init_model_path, force_clean=clean_round)
+        init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)
+        if clean_round:
+            round_number = 0
+            logger.info("Cleaning round information as requested")
+    else:
+        init_tensor_dict = task_runner.get_tensor_dict(False)
 
     tensor_dict, holdout_params = split_tensor_dict_for_holdouts(
         logger,
-        task_runner.get_tensor_dict(False),
+        init_tensor_dict,
         **task_runner.tensor_dict_split_fn_kwargs,
     )
 
@@ -189,13 +216,15 @@ def initialize(
         f" values: {list(holdout_params.keys())}"
     )
 
-    model_snap = utils.construct_model_proto(
-        tensor_dict=tensor_dict, round_number=0, tensor_pipe=tensor_pipe
-    )
-
-    logger.info("Creating Initial Weights File 🠆 %s", init_state_path)
-
-    utils.dump_proto(model_proto=model_snap, fpath=init_state_path)
+    # Save the model state
+    try:
+        logger.info(f"Saving model state to {init_state_path}")
+        plan.save_model_to_state_file(
+            tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path
+        )
+    except Exception as e:
+        logger.error(f"Failed to save model state: {e}")
+        raise
 
     plan_origin = Plan.parse(
         plan_config_path=plan_config,
diff --git a/openfl/protocols/utils.py b/openfl/protocols/utils.py
index e1d3da888a..3ccdd0ab05 100644
--- a/openfl/protocols/utils.py
+++ b/openfl/protocols/utils.py
@@ -4,9 +4,13 @@
 
 """Proto utils."""
 
+import logging
+
 from openfl.protocols import base_pb2
 from openfl.utilities import TensorKey
 
+logger = logging.getLogger(__name__)
+
 
 def model_proto_to_bytes_and_metadata(model_proto):
     """Convert the model protobuf to bytes and metadata.
@@ -356,3 +360,38 @@ def get_headers(context) -> dict:
         values are the corresponding header values.
     """
     return {header[0]: header[1] for header in context.invocation_metadata()}
+
+
+def construct_tensor_dict_from_proto(model_proto, tensor_pipe):
+    """Convert a model protobuf message to a tensor dictionary."""
+    tensor_dict = {}
+    logger.debug("Processing proto message with %d tensors", len(model_proto.tensors))
+
+    for tensor in model_proto.tensors:
+        logger.debug("Processing proto tensor %s", tensor.name)
+        try:
+            # Extract transformer metadata from the tensor proto
+            transformer_metadata = [
+                {
+                    "int_to_float": proto.int_to_float,
+                    "int_list": proto.int_list,
+                    "bool_list": proto.bool_list,
+                }
+                for proto in tensor.transformer_metadata
+            ]
+
+            # Decompress the tensor value using the compression pipeline
+            decompressed_tensor = tensor_pipe.backward(
+                data=tensor.data_bytes, transformer_metadata=transformer_metadata
+            )
+
+            # Store the decompressed value under the tensor's name
+            tensor_dict[tensor.name] = decompressed_tensor
+
+        except Exception:
+            logger.exception("Failed to process tensor %s", tensor.name)
+            raise
+
+    logger.debug("Successfully processed %d tensors", len(tensor_dict))
+
+    return tensor_dict
diff --git a/torch_cnn_mnist_init.pbuf b/torch_cnn_mnist_init.pbuf
new file mode 100644
index 0000000000..36a3bdb969
Binary files /dev/null and b/torch_cnn_mnist_init.pbuf differ