
Commit

- extend plan initialize to have additional optional argument to take init model path (pbuf format)

- added new function to utils.py
- rebased 21.Jan.1
- reduce cyclo-complexity of initialize function
- address review comments
- removed unused imports/vars
Signed-off-by: Shailesh Pant <[email protected]>
ishaileshpant committed Jan 23, 2025
1 parent a19f869 commit 5c8ddf6
Showing 3 changed files with 115 additions and 37 deletions.
33 changes: 33 additions & 0 deletions openfl/federated/plan/plan.py
@@ -777,3 +777,36 @@ def restore_object(self, filename):
return None
obj = serializer_plugin.restore_object(filename)
return obj

def save_model_to_state_file(self, tensor_dict, round_number, output_path):
"""Save model weights to a protobuf state file.
This method serializes the model weights into a protobuf format and saves
them to a file. The serialization is done using the tensor pipe to ensure
proper compression and formatting.
Args:
tensor_dict (dict): Dictionary containing model weights and their
corresponding tensors.
round_number (int): The current federation round number.
output_path (str): Path where the serialized model state will be
saved.
Raises:
Exception: If there is an error during model proto creation or saving
to file.
"""
from openfl.protocols import utils # Import here to avoid circular imports

# Get tensor pipe to properly serialize the weights
tensor_pipe = self.get_tensor_pipe()

# Create and save the protobuf message
try:
model_proto = utils.construct_model_proto(
tensor_dict=tensor_dict, round_number=round_number, tensor_pipe=tensor_pipe
)
utils.dump_proto(model_proto=model_proto, fpath=output_path)
except Exception as e:
self.logger.error(f"Failed to create or save model proto: {e}")
raise
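
A minimal sketch of how the new Plan.save_model_to_state_file helper could be driven outside the fx CLI flow; the plan path, tensor names, and output path below are illustrative assumptions, not part of this commit:

from pathlib import Path

import numpy as np

from openfl.federated import Plan

# Parse an existing workspace plan (path is a placeholder).
plan = Plan.parse(plan_config_path=Path("plan/plan.yaml").absolute())

# Stand-in weights; in the real flow these come from
# task_runner.get_tensor_dict(False) or a loaded protobuf state file.
tensor_dict = {"conv1.weight": np.zeros((8, 3, 3, 3), dtype=np.float32)}

# Serialize through the plan's tensor pipe and write the round-0 state file.
plan.save_model_to_state_file(
    tensor_dict=tensor_dict,
    round_number=0,
    output_path="save/init.pbuf",
)
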
118 changes: 82 additions & 36 deletions openfl/interface/plan.py
@@ -95,6 +95,13 @@ def plan(context):
help="Install packages listed under 'requirements.txt'. True/False [Default: True]",
default=True,
)
@option(
"-i",
"--init_model_path",
required=False,
help="Path to initial model protobuf file",
type=ClickPath(exists=True),
)
def initialize(
context,
plan_config,
@@ -104,6 +111,7 @@ def initialize(
input_shape,
gandlf_config,
install_reqs,
init_model_path,
):
"""Initialize Data Science plan.
@@ -119,6 +127,7 @@
feature_shape (str): The input shape to the model.
gandlf_config (str): GaNDLF Configuration File Path.
install_reqs (bool): Whether to install packages listed under 'requirements.txt'.
init_model_path (str): Optional path to initialization model protobuf file.
"""

for p in [plan_config, cols_config, data_config]:
@@ -133,29 +142,8 @@
gandlf_config = Path(gandlf_config).absolute()

if install_reqs:
requirements_filename = "requirements.txt"
requirements_path = Path(requirements_filename).absolute()

if isfile(f"{str(requirements_path)}"):
check_call(
[
sys.executable,
"-m",
"pip",
"install",
"-r",
f"{str(requirements_path)}",
],
shell=False,
)
echo(f"Successfully installed packages from {requirements_path}.")

# Required to restart the process for newly installed packages to be recognized
args_restart = [arg for arg in sys.argv if not arg.startswith("--install_reqs")]
args_restart.append("--install_reqs=False")
os.execv(args_restart[0], args_restart)
else:
echo("No additional requirements for workspace defined. Skipping...")
requirements_path = Path("requirements.txt").absolute()
_handle_requirements_install(requirements_path)

plan = Plan.parse(
plan_config_path=plan_config,
@@ -165,21 +153,20 @@
)

init_state_path = plan.config["aggregator"]["settings"]["init_state_path"]

# This is needed to bypass data being locally available
if input_shape is not None:
logger.info(
f"Attempting to generate initial model weights with custom input shape {input_shape}"
)

data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)

task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()
# Initialize tensor dictionary
init_tensor_dict, task_runner, round_number = _initialize_tensor_dict(
plan, input_shape, init_model_path
)

tensor_dict, holdout_params = split_tensor_dict_for_holdouts(
logger,
task_runner.get_tensor_dict(False),
init_tensor_dict,
**task_runner.tensor_dict_split_fn_kwargs,
)

@@ -189,13 +176,15 @@
f" values: {list(holdout_params.keys())}"
)

model_snap = utils.construct_model_proto(
tensor_dict=tensor_dict, round_number=0, tensor_pipe=tensor_pipe
)

logger.info("Creating Initial Weights File 🠆 %s", init_state_path)

utils.dump_proto(model_proto=model_snap, fpath=init_state_path)
# Save the model state
try:
logger.info(f"Saving model state to {init_state_path}")
plan.save_model_to_state_file(
tensor_dict=tensor_dict, round_number=round_number, output_path=init_state_path
)
except Exception as e:
logger.error(f"Failed to save model state: {e}")
raise

plan_origin = Plan.parse(
plan_config_path=plan_config,
@@ -223,6 +212,63 @@ def initialize(
logger.info(f"{context.obj['plans']}")


def _handle_requirements_install(requirements_path):
"""Handle the installation of requirements and process restart if needed.
This method checks if a requirements.txt file exists at the provided path.
If found, it installs the packages listed in the file using pip. After
successful installation, it restarts the current process with the same
arguments, but with the --install_reqs flag set to False to avoid
re-installing requirements.
If no requirements.txt file is found, it prints a message indicating that
no additional requirements are defined for the workspace and skips the
installation.
Args:
requirements_path (str or Path): The path to the requirements.txt file.
"""
if isfile(str(requirements_path)):
check_call(
[sys.executable, "-m", "pip", "install", "-r", str(requirements_path)],
shell=False,
)
echo(f"Successfully installed packages from {requirements_path}.")

# Required to restart the process for newly installed packages to be recognized
args_restart = [arg for arg in sys.argv if not arg.startswith("--install_reqs")]
args_restart.append("--install_reqs=False")
os.execv(args_restart[0], args_restart)
else:
echo("No additional requirements for workspace defined. Skipping...")


def _initialize_tensor_dict(plan, input_shape, init_model_path):
"""Initialize and return the tensor dictionary.
Args:
plan: The federation plan object
input_shape: The input shape to the model
init_model_path: Path to initial model protobuf file
Returns:
Tuple of (tensor_dict, task_runner, round_number)
"""
data_loader = get_dataloader(plan, prefer_minimal=True, input_shape=input_shape)
task_runner = plan.get_task_runner(data_loader)
tensor_pipe = plan.get_tensor_pipe()
round_number = 0

if init_model_path and isfile(init_model_path):
logger.info(f"Loading initial model from {init_model_path}")
model_proto = utils.load_proto(init_model_path)
init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)
else:
init_tensor_dict = task_runner.get_tensor_dict(False)

return init_tensor_dict, task_runner, round_number


# TODO: looks like Plan.method
def freeze_plan(plan_config):
"""Dump the plan to YAML file.
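
The new -i/--init_model_path option lets fx plan initialize seed the aggregator's initial state from an existing protobuf file rather than from freshly initialized task-runner weights, resuming at whatever round number the file was saved at. A rough sketch of the load path that _initialize_tensor_dict performs, with the workspace and file paths assumed for illustration:

from pathlib import Path

from openfl.federated import Plan
from openfl.protocols import utils

# Workspace plan path is a placeholder.
plan = Plan.parse(plan_config_path=Path("plan/plan.yaml").absolute())

# The tensor pipe must match the one used when the state file was written.
tensor_pipe = plan.get_tensor_pipe()

# Load a previously saved model protobuf, e.g. one produced by
# save_model_to_state_file above (path is a placeholder).
model_proto = utils.load_proto("save/init.pbuf")

# Recover the named tensors plus the round number they were saved at;
# initialization then proceeds from that round instead of round 0.
init_tensor_dict, round_number = utils.deconstruct_model_proto(model_proto, tensor_pipe)
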
1 change: 0 additions & 1 deletion openfl/protocols/utils.py
@@ -1,7 +1,6 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Proto utils."""

from openfl.protocols import base_pb2
