Skip to content

Commit

Permalink
- extend plan initialize to have additional optional argument to take…
Browse files Browse the repository at this point in the history
… init model path (pbuf format)

- added clean_round flag to reset round_number while reading supplied pbuf file
- added new function to utils.py
- added debug logs & added a sample pre-trained pbuf file
- rebased 21.Jan.1
Signed-off-by: Shailesh Pant <[email protected]>
  • Loading branch information
ishaileshpant committed Jan 21, 2025
1 parent cc3f12d commit 2672ec8
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 10 deletions.
43 changes: 43 additions & 0 deletions openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,46 @@ def restore_object(self, filename):
return None
obj = serializer_plugin.restore_object(filename)
return obj

def save_model_to_state_file(self, tensor_dict, round_number, output_path):
"""Save model weights to a protobuf state file."""
from openfl.protocols import utils # Import here to avoid circular imports

# Get tensor pipe to properly serialize the weights
tensor_pipe = self.get_tensor_pipe()

# Create and save the protobuf message
try:
model_proto = utils.construct_model_proto(
tensor_dict=tensor_dict, round_number=round_number, tensor_pipe=tensor_pipe
)
utils.dump_proto(model_proto=model_proto, fpath=output_path)
except Exception as e:
self.logger.error(f"Failed to create or save model proto: {e}")
raise

def load_model_proto_from_file(self, path: str, force_clean: bool = False):
"""Load model protobuf from file.
Args:
path: Path to protobuf file
force_clean: If True, reset round numbers to 0
Returns:
ModelProto: Loaded model protobuf message
"""
from openfl.protocols import utils # Import here to avoid circular imports

# Load the protobuf model
model_proto = utils.load_proto(path)

# If force_clean, clear round_number from protobuf before conversion
if force_clean:
for tensor in model_proto.tensors:
Plan.logger.info("Resetting tensor %s", tensor)
# For scalar fields like round_number, just check if it's set to non-default
if tensor.round_number != 0:
Plan.logger.info("Resetting round_number %d to 0", tensor.round_number)
tensor.round_number = 0

return model_proto
50 changes: 40 additions & 10 deletions openfl/interface/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from shutil import copyfile, rmtree
from subprocess import check_call # nosec

pass

from click import Path as ClickPath
from click import echo, group, option, pass_context
from yaml import FullLoader, dump, load
Expand Down Expand Up @@ -95,6 +97,19 @@ def plan(context):
help="Install packages listed under 'requirements.txt'. True/False [Default: True]",
default=True,
)
@option(
"-i",
"--init_model_path",
required=False,
help="Path to initialization model protobuf file",
type=ClickPath(exists=False),
)
@option(
"--clean_round",
is_flag=True,
default=False,
help="Clean round information when initializing from model",
)
def initialize(
context,
plan_config,
Expand All @@ -104,6 +119,8 @@ def initialize(
input_shape,
gandlf_config,
install_reqs,
init_model_path,
clean_round,
):
"""Initialize Data Science plan.
Expand All @@ -119,6 +136,8 @@ def initialize(
feature_shape (str): The input shape to the model.
gandlf_config (str): GaNDLF Configuration File Path.
install_reqs (bool): Whether to install packages listed under 'requirements.txt'.
init_model_path (str): Optional path to initialization model protobuf file.
clean_round (bool): Whether to clean round information when initializing from model.
"""

for p in [plan_config, cols_config, data_config]:
Expand Down Expand Up @@ -165,21 +184,30 @@ def initialize(
)

init_state_path = plan.config["aggregator"]["settings"]["init_state_path"]

# This is needed to bypass data being locally available
if input_shape is not None:
logger.info(
f"Attempting to generate initial model weights with custom input shape {input_shape}"
)

data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)

task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()
round_number = 0

if init_model_path and isfile(init_model_path):
logger.info(f"Loading initialization model from {init_model_path}")
if clean_round:
round_number = 0
logger.info("Cleaning round information as requested")
model_proto = plan.load_model_proto_from_file(init_model_path, force_clean=clean_round)
init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)
else:
init_tensor_dict = task_runner.get_tensor_dict(False)

tensor_dict, holdout_params = split_tensor_dict_for_holdouts(
logger,
task_runner.get_tensor_dict(False),
init_tensor_dict,
**task_runner.tensor_dict_split_fn_kwargs,
)

Expand All @@ -189,13 +217,15 @@ def initialize(
f" values: {list(holdout_params.keys())}"
)

model_snap = utils.construct_model_proto(
tensor_dict=tensor_dict, round_number=0, tensor_pipe=tensor_pipe
)

logger.info("Creating Initial Weights File 🠆 %s", init_state_path)

utils.dump_proto(model_proto=model_snap, fpath=init_state_path)
# Save the model state
try:
logger.info(f"Saving model state to {init_state_path}")
plan.save_model_to_state_file(
tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path
)
except Exception as e:
logger.error(f"Failed to save model state: {e}")
raise

plan_origin = Plan.parse(
plan_config_path=plan_config,
Expand Down
44 changes: 44 additions & 0 deletions openfl/protocols/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

"""Proto utils."""

import logging

from openfl.protocols import base_pb2
from openfl.utilities import TensorKey

logger = logging.getLogger(__name__)


def model_proto_to_bytes_and_metadata(model_proto):
"""Convert the model protobuf to bytes and metadata.
Expand Down Expand Up @@ -356,3 +360,43 @@ def get_headers(context) -> dict:
values are the corresponding header values.
"""
return {header[0]: header[1] for header in context.invocation_metadata()}


def construct_tensor_dict_from_proto(model_proto, tensor_pipe):
"""Convert a model protobuf message to a tensor dictionary."""
tensor_dict = {}
logger.info("\n=== Processing Proto Message ===")
logger.info(f"Number of tensors in proto: {len(model_proto.tensors)}")

for tensor in model_proto.tensors:
logger.info(f"\nProcessing proto tensor: {tensor.name}")
logger.info("-" * 50)
try:
# Extract metadata from the tensor proto
transformer_metadata = [
{
"int_to_float": proto.int_to_float,
"int_list": proto.int_list,
"bool_list": proto.bool_list,
}
for proto in tensor.transformer_metadata
]

# Decompress the tensor value using the compression pipeline
logger.info("Decompressing tensor...")
decompressed_tensor = tensor_pipe.backward(
data=tensor.data_bytes, transformer_metadata=transformer_metadata
)

# Store in dictionary
tensor_dict[tensor.name] = decompressed_tensor

except Exception as e:
logger.error(f"Failed to process tensor {tensor.name}")
logger.error(f"Error: {str(e)}")
raise

logger.info("\n=== Finished Processing Proto Message ===")
logger.info(f"Successfully processed {len(tensor_dict)} tensors")

return tensor_dict
Binary file added torch_cnn_mnist_init.pbuf
Binary file not shown.

0 comments on commit 2672ec8

Please sign in to comment.