Skip to content

Commit

Permalink
Straggler handling Follow-up (#1097)
Browse files Browse the repository at this point in the history
* Renamed Straggler Handling package

Signed-off-by: Ishant Thakare <[email protected]>

* Incorporated review comments

Signed-off-by: Ishant Thakare <[email protected]>

* Incorporated karan's review comments

Signed-off-by: Ishant Thakare <[email protected]>

* Resolving merge conflicts

Signed-off-by: Ishant Thakare <[email protected]>

* Fix code format

Signed-off-by: Ishant Thakare <[email protected]>

* Incorporated review comments

Signed-off-by: Ishant Thakare <[email protected]>

* Review comment incorporated

Signed-off-by: Ishant Thakare <[email protected]>

* Incorporated review comments

Signed-off-by: Ishant Thakare <[email protected]>

---------

Signed-off-by: Ishant Thakare <[email protected]>
  • Loading branch information
ishant162 authored Jan 30, 2025
1 parent 0ed1338 commit c71080e
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ The Open Federated Learning (OpenFL) framework supports straggler handling inter

The following are the straggler handling algorithms supported in OpenFL:

``CutoffTimeBasedStragglerHandling``
``CutoffTimePolicy``
Identifies stragglers based on the cutoff time specified in the settings. Arguments to the function are:
- *Cutoff Time* (straggler_cutoff_time), specifies the cutoff time by which the aggregator should end the round early.
- *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model.

For example, in a federation of 5 collaborators, if :code:`straggler_cutoff_time` (in seconds) is set to 20 and :code:`minimum_reporting` is set to 2, atleast 2 collaborators (or more) would be included in the round, provided that the time limit of 20 seconds is not exceeded.
In an event where :code:`minimum_reporting` collaborators don't make it within the :code:`straggler_cutoff_time`, the straggler handling policy is disregarded.

``PercentageBasedStragglerHandling``
``PercentagePolicy``
Identifies stragglers based on the percetage specified. Arguments to the function are:
- *Percentage of collaborators* (percent_collaborators_needed), specifies a percentage of collaborators enough to end the round early.
- *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model.
Expand All @@ -29,12 +29,12 @@ The following are the straggler handling algorithms supported in OpenFL:
Demonstration of adding the straggler handling interface
=========================================================

The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentageBasedStragglerHandling``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffTimeBasedStragglerHandling** function instead:
The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentagePolicy``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffTimePolicy** function instead:

.. code-block:: yaml
straggler_handling_policy :
template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling
template : openfl.component.aggregator.straggler_handling.CutoffTimePolicy
settings :
straggler_cutoff_time : 20
minimum_reporting : 1
1 change: 0 additions & 1 deletion docs/openfl.component.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@
collaborator
director
envoy
straggler_handling_functions

.. TODO(MasterSkepticista) Shrink API namespace
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml

straggler_handling_policy :
template : openfl.component.straggler_handling_functions.PercentageBasedStragglerHandling
template : openfl.component.aggregator.straggler_handling.PercentagePolicy
settings :
percent_collaborators_needed : 0.5
minimum_reporting : 1
14 changes: 5 additions & 9 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@
"""OpenFL Component Module."""

from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.aggregator.straggler_handling import (
CutoffTimePolicy,
PercentagePolicy,
StragglerPolicy,
)
from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner
from openfl.component.collaborator.collaborator import Collaborator
from openfl.component.straggler_handling_functions.cutoff_time_based_straggler_handling import (
CutoffTimeBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.percentage_based_straggler_handling import (
PercentageBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingPolicy,
)
5 changes: 5 additions & 0 deletions openfl/component/aggregator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@


from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.aggregator.straggler_handling import (
CutoffTimePolicy,
PercentagePolicy,
StragglerPolicy,
)
9 changes: 3 additions & 6 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import List, Optional

import openfl.callbacks as callbacks_module
from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
from openfl.component.aggregator.straggler_handling import CutoffTimePolicy, StragglerPolicy
from openfl.databases import PersistentTensorDB, TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
last_state_path,
assigner,
use_delta_updates=True,
straggler_handling_policy=None,
straggler_handling_policy: StragglerPolicy = CutoffTimePolicy,
rounds_to_train=256,
single_col_cert_common_name=None,
compression_pipeline=None,
Expand All @@ -100,7 +100,6 @@ def __init__(
weight.
assigner: Assigner object.
straggler_handling_policy (optional): Straggler handling policy.
Defaults to CutoffTimeBasedStragglerHandling.
rounds_to_train (int, optional): Number of rounds to train.
Defaults to 256.
single_col_cert_common_name (str, optional): Common name for single
Expand All @@ -127,9 +126,7 @@ def __init__(
# FIXME: "" instead of None is for protobuf compatibility.
self.single_col_cert_common_name = single_col_cert_common_name or ""

self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
self.straggler_handling_policy = straggler_handling_policy()

self.rounds_to_train = rounds_to_train
if self.task_group == "evaluation":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,77 @@
# SPDX-License-Identifier: Apache-2.0


"""Cutoff time based Straggler Handling function."""
"""Straggler handling module."""

import threading
import time
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Callable

import numpy as np

from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingPolicy,
)
logger = getLogger(__name__)


class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy):
class StragglerPolicy(ABC):
"""Federated Learning straggler handling interface."""

@abstractmethod
def start_policy(self, **kwargs) -> None:
"""
Start straggler handling policy for collaborator for a particular round.
NOTE: Refer CutoffTimePolicy class for reference.
Args:
**kwargs
"""
raise NotImplementedError

@abstractmethod
def reset_policy_for_round(self) -> None:
"""Reset policy for the next round."""
raise NotImplementedError

@abstractmethod
def straggler_cutoff_check(
self, num_collaborators_done: int, num_all_collaborators: int, **kwargs
) -> bool:
"""
Determines whether the round should end early when straggler policy conditions are met.
Args:
num_collaborators_done: int
Number of collaborators finished.
num_all_collaborators: int
Total number of collaborators.
Returns:
bool: True if it is time to end the round early, False otherwise.
Raises:
NotImplementedError: This method must be implemented by a subclass.
"""
raise NotImplementedError


class CutoffTimePolicy(StragglerPolicy):
"""Cutoff time based Straggler Handling function."""

def __init__(
self, round_start_time=None, straggler_cutoff_time=np.inf, minimum_reporting=1, **kwargs
):
"""
Initialize a CutoffTimeBasedStragglerHandling object.
Initialize a CutoffTimePolicy object.
Args:
round_start_time (optional): The start time of the round. Defaults
to None.
straggler_cutoff_time (float, optional): The cutoff time for
stragglers. Defaults to np.inf.
minimum_reporting (int, optional): The minimum number of
collaborators that should report. Defaults to 1.
collaborators that should report before moving to the next round.
Defaults to 1.
**kwargs: Variable length argument list.
"""
if minimum_reporting <= 0:
Expand All @@ -40,21 +81,18 @@ def __init__(
self.round_start_time = round_start_time
self.straggler_cutoff_time = straggler_cutoff_time
self.minimum_reporting = minimum_reporting
self.logger = getLogger(__name__)
self.is_timer_started = False

if self.straggler_cutoff_time == np.inf:
self.logger.warning(
"CutoffTimeBasedStragglerHandling is disabled as straggler_cutoff_time "
"is set to np.inf."
logger.warning(
"CutoffTimePolicy is disabled as straggler_cutoff_time is set to np.inf."
)

def reset_policy_for_round(self) -> None:
"""
Reset timer for the next round.
"""
"""Reset timer for the next round."""
if hasattr(self, "timer"):
self.timer.cancel()
delattr(self, "timer")
self.is_timer_started = False

def start_policy(self, callback: Callable) -> None:
"""
Expand All @@ -64,22 +102,21 @@ def start_policy(self, callback: Callable) -> None:
Args:
callback: Callable
Callback function for when straggler_cutoff_time elapses
Returns:
None
"""
# If straggler_cutoff_time is set to infinity
# or if the timer is already running,
# do not start the policy.
if self.straggler_cutoff_time == np.inf or hasattr(self, "timer"):
if self.straggler_cutoff_time == np.inf or self.is_timer_started:
return

self.round_start_time = time.time()
self.timer = threading.Timer(
self.straggler_cutoff_time,
callback,
)
self.timer.daemon = True
self.timer.start()
self.is_timer_started = True

def straggler_cutoff_check(
self,
Expand Down Expand Up @@ -108,13 +145,13 @@ def straggler_cutoff_check(
# Time has expired
# Check if minimum_reporting collaborators have reported results
elif self.__minimum_collaborators_reported(num_collaborators_done):
self.logger.info(
logger.info(
f"{num_collaborators_done} collaborators have reported results. "
"Applying cutoff policy and proceeding with end of round."
)
return True
else:
self.logger.info(
logger.info(
f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results."
)
return False
Expand All @@ -141,3 +178,66 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool:
False otherwise.
"""
return num_collaborators_done >= self.minimum_reporting


class PercentagePolicy(StragglerPolicy):
"""Percentage based Straggler Handling function."""

def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs):
"""Initialize a PercentagePolicy object.
Args:
percent_collaborators_needed (float, optional): The percentage of
collaborators needed. Defaults to 1.0.
minimum_reporting (int, optional): The minimum number of
collaborators that should report. Defaults to 1.
**kwargs: Variable length argument list.
"""
if minimum_reporting <= 0:
raise ValueError("minimum_reporting must be >0")

self.percent_collaborators_needed = percent_collaborators_needed
self.minimum_reporting = minimum_reporting

def reset_policy_for_round(self) -> None:
"""Not required in PercentagePolicy."""
pass

def start_policy(self, **kwargs) -> None:
"""Not required in PercentagePolicy."""
pass

def straggler_cutoff_check(
self,
num_collaborators_done: int,
num_all_collaborators: int,
) -> bool:
"""
If percent_collaborators_needed and minimum_reporting collaborators have
reported results, then it is time to end round early.
Args:
num_collaborators_done (int): The number of collaborators that
have reported.
all_collaborators (list): All the collaborators.
Returns:
bool: True if the straggler cutoff conditions are met, False
otherwise.
"""
return (
num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators
) and self.__minimum_collaborators_reported(num_collaborators_done)

def __minimum_collaborators_reported(self, num_collaborators_done) -> bool:
"""Check if the minimum number of collaborators have reported.
Args:
num_collaborators_done (int): The number of collaborators that
have reported.
Returns:
bool: True if the minimum number of collaborators have reported,
False otherwise.
"""
return num_collaborators_done >= self.minimum_reporting
13 changes: 0 additions & 13 deletions openfl/component/straggler_handling_functions/__init__.py

This file was deleted.

Loading

0 comments on commit c71080e

Please sign in to comment.