Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Aggregator / Assigner leaky abstraction #1301

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions openfl-workspace/torch_cnn_mnist/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,8 @@ aggregator:
rounds_to_train: 2
write_logs: false
template: openfl.component.aggregator.Aggregator
assigner:
settings:
task_groups:
- name: learning
percentage: 1.0
tasks:
- aggregated_model_validation
- train
- locally_tuned_model_validation
template: openfl.component.RandomGroupedAssigner
assigner :
defaults : plan/defaults/assigner.yaml
collaborator:
settings:
db_store_rounds: 1
Expand Down
8 changes: 5 additions & 3 deletions openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ network :
defaults : plan/defaults/network.yaml

assigner :
defaults : plan/defaults/federated-evaluation/assigner.yaml

defaults : plan/defaults/assigner.yaml
settings :
selected_task_group : evaluation

tasks :
defaults : plan/defaults/federated-evaluation/tasks_torch.yaml
defaults : plan/defaults/tasks_torch.yaml

compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ settings :
db_store_rounds : 2
persist_checkpoint: True
persistent_db_path: local_state/tensor.db

4 changes: 4 additions & 0 deletions openfl-workspace/workspace/plan/defaults/assigner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ settings :
- aggregated_model_validation
- train
- locally_tuned_model_validation
- name : evaluation
percentage : 1.0
tasks :
- aggregated_model_validation

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Component Module."""

from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.assigner.assigner import Assigner
Expand Down
18 changes: 10 additions & 8 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(
callbacks: Optional[List] = None,
persist_checkpoint=True,
persistent_db_path=None,
task_group: str = "learning",
):
"""Initializes the Aggregator.

Expand All @@ -111,9 +110,7 @@ def __init__(
Defaults to 1.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
task_group (str, optional): Selected task_group for assignment.
"""
self.task_group = task_group
self.round_number = 0
self.next_model_round_number = 0

Expand All @@ -130,16 +127,21 @@ def __init__(
self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

self.rounds_to_train = rounds_to_train
self.assigner = assigner
if self.assigner.is_task_group_evaluation():
self.rounds_to_train = 1
logger.info(f"For evaluation tasks setting rounds_to_train = {self.rounds_to_train}")

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

# if the collaborator requests a delta, this value is set to true
self.authorized_cols = authorized_cols
self.uuid = aggregator_uuid
self.federation_uuid = federation_uuid
self.assigner = assigner

self.quit_job_sent_to = []

self.tensor_db = TensorDB()
Expand Down Expand Up @@ -301,8 +303,8 @@ def _load_initial_tensors(self):
)

# Check selected task_group before updating round number
if self.task_group == "evaluation":
logger.info(f"Skipping round_number check for {self.task_group} task_group")
if self.assigner.is_task_group_evaluation():
logger.info("Skipping round_number check for evaluation run")
elif round_number > self.round_number:
logger.info(f"Starting training from round {round_number} of previously saved model")
self.round_number = round_number
Expand Down
1 change: 1 addition & 0 deletions openfl/component/assigner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Assigner Module."""

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
Expand Down
60 changes: 59 additions & 1 deletion openfl/component/assigner/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

"""Assigner module."""

import logging
from functools import wraps

logger = logging.getLogger(__name__)


class Assigner:
r"""
Expand Down Expand Up @@ -35,18 +40,27 @@ class Assigner:
\* - ``tasks`` argument is taken from ``tasks`` section of FL plan YAML file.
"""

def __init__(self, tasks, authorized_cols, rounds_to_train, **kwargs):
def __init__(
self,
tasks,
authorized_cols,
rounds_to_train,
selected_task_group: str = "learning",
**kwargs,
):
"""Initializes the Assigner.

Args:
tasks (list of object): List of tasks to assign.
authorized_cols (list of str): Collaborators.
rounds_to_train (int): Number of training rounds.
selected_task_group (str, optional): Selected task_group. Defaults to "learning".
**kwargs: Additional keyword arguments.
"""
self.tasks = tasks
self.authorized_cols = authorized_cols
self.rounds = rounds_to_train
self.selected_task_group = selected_task_group
self.all_tasks_in_groups = []

self.task_group_collaborators = {}
Expand All @@ -67,6 +81,16 @@ def get_collaborators_for_task(self, task_name, round_number):
"""Abstract method."""
raise NotImplementedError

def is_task_group_evaluation(self):
    """Report whether this assigner is configured for an 'evaluation' run.

    Returns:
        bool: True when ``selected_task_group`` is present on the instance
            and equals 'evaluation'; False when it differs or is absent.
    """
    # A missing attribute falls back to None, which never equals "evaluation".
    return getattr(self, "selected_task_group", None) == "evaluation"

def get_all_tasks_for_round(self, round_number):
"""Return tasks for the current round.

Expand All @@ -93,3 +117,37 @@ def get_aggregation_type_for_task(self, task_name):
if "aggregation_type" not in self.tasks[task_name]:
return None
return self.tasks[task_name]["aggregation_type"]

@classmethod
def task_group_filtering(cls, func):
    """Decorator that narrows ``task_groups`` to the selected task group.

    Intended for the ``define_task_assignments()`` method of Assigner
    subclasses. Before the wrapped method runs, ``self.task_groups`` is
    reduced to the groups whose ``"name"`` equals
    ``self.selected_task_group``.

    Args:
        func (callable): The method to wrap; it is called as
            ``func(self, *args, **kwargs)``.

    Returns:
        callable: The wrapping function, carrying ``func``'s metadata.

    Raises:
        AssertionError: Raised by the wrapper when ``task_groups`` is empty
            or when no group matches ``selected_task_group``. Raised
            explicitly (not via ``assert``) so the check also runs under
            ``python -O``.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # No selection configured on this instance: nothing to filter.
        if not hasattr(self, "selected_task_group"):
            return func(self, *args, **kwargs)

        # A selection without task_groups is tolerated and only reported.
        if not hasattr(self, "task_groups"):
            logger.warning(
                "Task group specified for selection but no task_groups found. "
                "Skipping filtering. This might be intentional for custom assigners."
            )
            return func(self, *args, **kwargs)

        if not self.task_groups:
            raise AssertionError("No task_groups defined in assigner.")

        matching_groups = [
            group for group in self.task_groups if group["name"] == self.selected_task_group
        ]
        if not matching_groups:
            raise AssertionError(f"No task groups found for : {self.selected_task_group}")

        # Assign only after validation so a failed selection does not
        # leave the instance with an emptied task_groups list.
        self.task_groups = matching_groups

        return func(self, *args, **kwargs)

    return wrapper
7 changes: 5 additions & 2 deletions openfl/component/assigner/random_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner import Assigner


class RandomGroupedAssigner(Assigner):
Expand All @@ -33,16 +33,19 @@ class RandomGroupedAssigner(Assigner):
\* - Plan setting.
"""

task_group_filtering = Assigner.task_group_filtering

def __init__(self, task_groups, **kwargs):
"""Initializes the RandomGroupedAssigner.

Args:
task_groups (list of object): Task groups to assign.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments, including selected_task_group.
"""
self.task_groups = task_groups
super().__init__(**kwargs)

@task_group_filtering
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.

Expand Down
3 changes: 3 additions & 0 deletions openfl/component/assigner/static_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class StaticGroupedAssigner(Assigner):
\* - Plan setting.
"""

task_group_filtering = Assigner.task_group_filtering

def __init__(self, task_groups, **kwargs):
"""Initializes the StaticGroupedAssigner.

Expand All @@ -42,6 +44,7 @@ def __init__(self, task_groups, **kwargs):
self.task_groups = task_groups
super().__init__(**kwargs)

@task_group_filtering
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.

Expand Down
8 changes: 4 additions & 4 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def start_(plan, authorized_cols, task_group):
cols_config_path=Path(authorized_cols).absolute(),
)

# Set task_group in aggregator settings
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
# Set selected_task_group in assigner settings
if "settings" not in parsed_plan.config["assigner"]:
parsed_plan.config["assigner"]["settings"] = {}
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

logger.info("🧿 Starting the Aggregator Service.")
Expand Down
5 changes: 1 addition & 4 deletions tests/openfl/component/aggregator/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,13 @@ def agg(mocker, model, assigner):
'some_uuid',
'federation_uuid',
['col1', 'col2'],

'init_state_path',
'best_state_path',
'last_state_path',

assigner,
)
)
return agg


@pytest.mark.parametrize(
'cert_common_name,collaborator_common_name,authorized_cols,single_cccn,expected_is_valid', [
('col1', 'col1', ['col1', 'col2'], '', True),
Expand Down
28 changes: 18 additions & 10 deletions tests/openfl/component/assigner/test_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def assigner():

def test_get_aggregation_type_for_task_none(assigner):
"""Assert that aggregation type of custom task is None."""
task_name = 'test_name'
task_name = "test_name"
tasks = {task_name: {}}

assigner = assigner(tasks, None, None)
Expand All @@ -31,11 +31,9 @@ def test_get_aggregation_type_for_task_none(assigner):

def test_get_aggregation_type_for_task(assigner):
"""Assert that aggregation type of task is getting correctly."""
task_name = 'test_name'
test_aggregation_type = 'test_aggregation_type'
tasks = {task_name: {
'aggregation_type': test_aggregation_type
}}
task_name = "test_name"
test_aggregation_type = "test_aggregation_type"
tasks = {task_name: {"aggregation_type": test_aggregation_type}}
assigner = assigner(tasks, None, None)

aggregation_type = assigner.get_aggregation_type_for_task(task_name)
Expand All @@ -46,13 +44,23 @@ def test_get_aggregation_type_for_task(assigner):
def test_get_all_tasks_for_round(assigner):
"""Assert that assigner tasks object is list."""
assigner = Assigner(None, None, None)
tasks = assigner.get_all_tasks_for_round('test')
tasks = assigner.get_all_tasks_for_round("test")

assert isinstance(tasks, list)

def test_default_task_group(assigner):
    """Check that an Assigner built without a task group selects 'learning'."""
    default_instance = Assigner(None, None, None)
    assert default_instance.selected_task_group == "learning"

class TestNotImplError(TestCase):
def test_task_group_filtering_no_task_groups(assigner):
    """Assert that task_group_filtering skips filtering when no task_groups exist."""
    assigner = Assigner(None, None, None)
    assigner.selected_task_group = "test_group"

    # Wrap a stand-in with the decorator under test. The decorator is only
    # shown applied in Assigner subclasses, so calling the base-class
    # define_task_assignments() would presumably never reach it.
    received = []
    wrapped = Assigner.task_group_filtering(lambda self: received.append(self))
    wrapped(assigner)

    # The wrapped callable still runs, and no task_groups attribute is created.
    assert len(received) == 1 and received[0] is assigner
    assert not hasattr(assigner, "task_groups")

class TestNotImplError(TestCase):
def test_define_task_assignments(self):
# TODO: define_task_assignments is defined as a mock in multiple fixtures,
# which leads the function to behave as a mock here and other tests.
Expand All @@ -61,9 +69,9 @@ def test_define_task_assignments(self):
def test_get_tasks_for_collaborator(self):
with self.assertRaises(NotImplementedError):
assigner = Assigner(None, None, None)
assigner.get_tasks_for_collaborator('col1', 0)
assigner.get_tasks_for_collaborator("col1", 0)

def test_get_collaborators_for_task(self):
with self.assertRaises(NotImplementedError):
assigner = Assigner(None, None, None)
assigner.get_collaborators_for_task('task_name', 0)
assigner.get_collaborators_for_task("task_name", 0)
Loading
Loading