Skip to content

Commit fc89bec

Browse files
committed
- implement a new mode_based_filtering decorator in Assigner class
- update all the sub-classes that use task_groups to use the decorator - update fedeval sample workspace to use default assigner, tasks and aggregator - use of federated-evaluation/aggregator.yaml for FedEval specific workspace example to use round_number as 1 - removed assigner and tasks yaml from defaults/federated-evaluation, superseded by default assigner/tasks - Rebase 21-Jan-2025.2 - added additional checks for assigner sub-classes that might not have task_groups - Addressing review comments Signed-off-by: Shailesh Pant <[email protected]>
1 parent 8104144 commit fc89bec

File tree

9 files changed

+68
-19
lines changed

9 files changed

+68
-19
lines changed

openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ network :
3232
defaults : plan/defaults/network.yaml
3333

3434
assigner :
35-
defaults : plan/defaults/federated-evaluation/assigner.yaml
36-
35+
defaults : plan/defaults/assigner.yaml
36+
settings :
37+
mode : evaluate
38+
3739
tasks :
38-
defaults : plan/defaults/federated-evaluation/tasks_torch.yaml
40+
defaults : plan/defaults/tasks_torch.yaml
3941

4042
compression_pipeline :
4143
defaults : plan/defaults/compression_pipeline.yaml

openfl-workspace/workspace/plan/defaults/assigner.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,8 @@ settings :
77
- aggregated_model_validation
88
- train
99
- locally_tuned_model_validation
10+
- name : evaluation
11+
percentage : 1.0
12+
tasks :
13+
- aggregated_model_validation
14+
selected_task_group: learning

openfl-workspace/workspace/plan/defaults/federated-evaluation/assigner.yaml

Lines changed: 0 additions & 7 deletions
This file was deleted.

openfl-workspace/workspace/plan/defaults/federated-evaluation/tasks_torch.yaml

Lines changed: 0 additions & 7 deletions
This file was deleted.

openfl/component/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2020-2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""OpenFL Component Module."""
45

56
from openfl.component.aggregator.aggregator import Aggregator
67
from openfl.component.assigner.assigner import Assigner

openfl/component/assigner/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2020-2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""OpenFL Assigner Module."""
45

56
from openfl.component.assigner.assigner import Assigner
67
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner

openfl/component/assigner/assigner.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
"""Assigner module."""
66

7+
import logging
8+
from functools import wraps
9+
10+
logger = logging.getLogger(__name__)
11+
712

813
class Assigner:
914
r"""
@@ -35,18 +40,27 @@ class Assigner:
3540
\* - ``tasks`` argument is taken from ``tasks`` section of FL plan YAML file.
3641
"""
3742

38-
def __init__(self, tasks, authorized_cols, rounds_to_train, **kwargs):
43+
def __init__(
44+
self,
45+
tasks,
46+
authorized_cols,
47+
rounds_to_train,
48+
selected_task_group: str = "learning",
49+
**kwargs,
50+
):
3951
"""Initializes the Assigner.
4052
4153
Args:
4254
tasks (list of object): List of tasks to assign.
4355
authorized_cols (list of str): Collaborators.
4456
rounds_to_train (int): Number of training rounds.
57+
selected_task_group (str, optional): Selected task_group. Defaults to "learning".
4558
**kwargs: Additional keyword arguments.
4659
"""
4760
self.tasks = tasks
4861
self.authorized_cols = authorized_cols
4962
self.rounds = rounds_to_train
63+
self.selected_task_group = selected_task_group
5064
self.all_tasks_in_groups = []
5165

5266
self.task_group_collaborators = {}
@@ -93,3 +107,41 @@ def get_aggregation_type_for_task(self, task_name):
93107
if "aggregation_type" not in self.tasks[task_name]:
94108
return None
95109
return self.tasks[task_name]["aggregation_type"]
110+
111+
@classmethod
112+
def task_group_filtering(cls, func):
113+
"""Decorator to filter task groups based on selected_task_group.
114+
115+
This decorator should be applied to define_task_assignments() method
116+
in Assigner subclasses to handle task_group filtering.
117+
"""
118+
119+
@wraps(func)
120+
def wrapper(self, *args, **kwargs):
121+
# First check if selection of task_group is applicable
122+
if hasattr(self, "selected_task_group"):
123+
# Verify task_groups exists before attempting filtering
124+
if not hasattr(self, "task_groups"):
125+
logger.warning(
126+
"Task group specified for selection but no task_groups found. "
127+
"Skipping filtering. This might be intentional for custom assigners."
128+
)
129+
return func(self, *args, **kwargs)
130+
131+
assert self.task_groups, "No task_groups defined in assigner."
132+
133+
# Perform the filtering
134+
self.task_groups = [
135+
group for group in self.task_groups if group["name"] == self.selected_task_group
136+
]
137+
138+
assert self.task_groups, f"No task groups found for : {self.selected_task_group}"
139+
140+
# Mode-specific validations
141+
if self.selected_task_group == "evaluation":
142+
assert self.rounds == 1, "Number of rounds should be 1 for evaluation"
143+
144+
# Call the original method
145+
return func(self, *args, **kwargs)
146+
147+
return wrapper

openfl/component/assigner/random_grouped_assigner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@ def __init__(self, task_groups, **kwargs):
3838
3939
Args:
4040
task_groups (list of object): Task groups to assign.
41-
**kwargs: Additional keyword arguments.
41+
**kwargs: Additional keyword arguments, including mode.
4242
"""
4343
self.task_groups = task_groups
4444
super().__init__(**kwargs)
4545

46+
@Assigner.task_group_filtering
4647
def define_task_assignments(self):
4748
"""Define task assignments for each round and collaborator.
4849

openfl/component/assigner/static_grouped_assigner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, task_groups, **kwargs):
4242
self.task_groups = task_groups
4343
super().__init__(**kwargs)
4444

45+
@Assigner.task_group_filtering
4546
def define_task_assignments(self):
4647
"""Define task assignments for each round and collaborator.
4748

0 commit comments

Comments
 (0)