Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend plan initialize to optionally load OpenFL trained model protobuf file #1290

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,36 @@ def restore_object(self, filename):
return None
obj = serializer_plugin.restore_object(filename)
return obj

def save_model_to_state_file(self, tensor_dict, round_number, output_path):
    """Serialize model weights and write them to a protobuf state file.

    The plan's tensor pipe is used for serialization so the weights are
    compressed and encoded consistently with the rest of the federation.

    Args:
        tensor_dict (dict): Mapping of tensor names to model weight tensors.
        round_number (int): Federation round the weights belong to.
        output_path (str): Destination path for the serialized model state.

    Raises:
        Exception: Re-raised (after logging) if proto construction or the
            write to ``output_path`` fails.
    """
    from openfl.protocols import utils  # Import here to avoid circular imports

    # The tensor pipe handles compression/formatting of the weights.
    tensor_pipe = self.get_tensor_pipe()

    try:
        proto = utils.construct_model_proto(
            tensor_dict=tensor_dict,
            round_number=round_number,
            tensor_pipe=tensor_pipe,
        )
        utils.dump_proto(model_proto=proto, fpath=output_path)
    except Exception as e:
        self.logger.error(f"Failed to create or save model proto: {e}")
        raise
118 changes: 82 additions & 36 deletions openfl/interface/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def plan(context):
help="Install packages listed under 'requirements.txt'. True/False [Default: True]",
default=True,
)
@option(
"-i",
"--init_model_path",
required=False,
help="Path to initial model protobuf file",
type=ClickPath(exists=True),
)
def initialize(
context,
plan_config,
Expand All @@ -104,6 +111,7 @@ def initialize(
input_shape,
gandlf_config,
install_reqs,
init_model_path,
):
"""Initialize Data Science plan.
Expand All @@ -119,6 +127,7 @@ def initialize(
feature_shape (str): The input shape to the model.
gandlf_config (str): GaNDLF Configuration File Path.
install_reqs (bool): Whether to install packages listed under 'requirements.txt'.
init_model_path (str): Optional path to initialization model protobuf file.
"""

for p in [plan_config, cols_config, data_config]:
Expand All @@ -133,29 +142,8 @@ def initialize(
gandlf_config = Path(gandlf_config).absolute()

if install_reqs:
requirements_filename = "requirements.txt"
requirements_path = Path(requirements_filename).absolute()

if isfile(f"{str(requirements_path)}"):
check_call(
[
sys.executable,
"-m",
"pip",
"install",
"-r",
f"{str(requirements_path)}",
],
shell=False,
)
echo(f"Successfully installed packages from {requirements_path}.")

# Required to restart the process for newly installed packages to be recognized
args_restart = [arg for arg in sys.argv if not arg.startswith("--install_reqs")]
args_restart.append("--install_reqs=False")
os.execv(args_restart[0], args_restart)
else:
echo("No additional requirements for workspace defined. Skipping...")
requirements_path = Path("requirements.txt").absolute()
_handle_requirements_install(requirements_path)

plan = Plan.parse(
plan_config_path=plan_config,
Expand All @@ -165,21 +153,20 @@ def initialize(
)

init_state_path = plan.config["aggregator"]["settings"]["init_state_path"]

# This is needed to bypass data being locally available
if input_shape is not None:
logger.info(
f"Attempting to generate initial model weights with custom input shape {input_shape}"
)

data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)

task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()
# Initialize tensor dictionary
init_tensor_dict, task_runner, round_number = _initialize_tensor_dict(
plan, input_shape, init_model_path
)

tensor_dict, holdout_params = split_tensor_dict_for_holdouts(
logger,
task_runner.get_tensor_dict(False),
init_tensor_dict,
**task_runner.tensor_dict_split_fn_kwargs,
)

Expand All @@ -189,13 +176,15 @@ def initialize(
f" values: {list(holdout_params.keys())}"
)

model_snap = utils.construct_model_proto(
tensor_dict=tensor_dict, round_number=0, tensor_pipe=tensor_pipe
)

logger.info("Creating Initial Weights File 🠆 %s", init_state_path)

utils.dump_proto(model_proto=model_snap, fpath=init_state_path)
# Save the model state
try:
logger.info(f"Saving model state to {init_state_path}")
plan.save_model_to_state_file(
tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path
)
except Exception as e:
logger.error(f"Failed to save model state: {e}")
raise

plan_origin = Plan.parse(
plan_config_path=plan_config,
Expand Down Expand Up @@ -223,6 +212,63 @@ def initialize(
logger.info(f"{context.obj['plans']}")


def _handle_requirements_install(requirements_path):
    """Install workspace requirements and restart the process when needed.

    If a requirements file exists at ``requirements_path``, its packages are
    installed with pip and the current process is then re-executed with the
    same arguments — except that ``--install_reqs`` is forced to ``False`` so
    the restarted process does not install (and restart) again. If no file is
    found, a message is echoed and installation is skipped.

    Args:
        requirements_path (str or Path): Path to the requirements.txt file.
    """
    if not isfile(str(requirements_path)):
        echo("No additional requirements for workspace defined. Skipping...")
        return

    pip_command = [sys.executable, "-m", "pip", "install", "-r", str(requirements_path)]
    check_call(pip_command, shell=False)
    echo(f"Successfully installed packages from {requirements_path}.")

    # Restart the process so freshly installed packages become importable;
    # strip any --install_reqs flag and re-append it as False to avoid looping.
    restart_args = [arg for arg in sys.argv if not arg.startswith("--install_reqs")]
    restart_args.append("--install_reqs=False")
    os.execv(restart_args[0], restart_args)


def _initialize_tensor_dict(plan, input_shape, init_model_path):
    """Build the initial tensor dictionary for plan initialization.

    When ``init_model_path`` points to an existing regular file, the model
    weights and the round number they were saved at are loaded from that
    protobuf; otherwise the tensor dict comes from a freshly constructed
    task runner (round 0).

    Args:
        plan: The federation plan object.
        input_shape: The input shape to the model; used to build a minimal
            dataloader so local data is not required.
        init_model_path: Optional path to an initial model protobuf file.

    Returns:
        Tuple of (init_tensor_dict, task_runner, round_number).
    """
    data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)
    task_runner = plan.get_task_runner(data_loader)
    round_number = 0

    if init_model_path and isfile(init_model_path):
        logger.info(f"Loading initial model from {init_model_path}")
        # Tensor pipe is only needed to decode the serialized weights.
        tensor_pipe = plan.get_tensor_pipe()
        model_proto = utils.load_proto(init_model_path)
        init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)
    else:
        if init_model_path:
            # Click only verifies the path exists, so a directory (or other
            # non-file) would previously be ignored silently — surface the
            # fallback so the user knows their model was NOT loaded.
            logger.warning(
                f"init_model_path '{init_model_path}' is not a regular file; "
                "falling back to the task runner's initial weights."
            )
        init_tensor_dict = task_runner.get_tensor_dict(False)

    return init_tensor_dict, task_runner, round_number


# TODO: looks like Plan.method
def freeze_plan(plan_config):
"""Dump the plan to YAML file.
Expand Down
1 change: 0 additions & 1 deletion openfl/protocols/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Proto utils."""

from openfl.protocols import base_pb2
Expand Down
Loading