Skip to content

Commit 10b6993

Browse files
committed
- implement a new task_group filtering decorator in Assigner class
- update all the sub-classes that use task_groups to use the decorator - update fedeval sample workspace to use default assigner, tasks and aggregator - use of federated-evaluation/aggregator.yaml for FedEval specific workspace example to use round_number as 1 - removed assigner and tasks yaml from defaults/federated-evaluation, superseded by default assigner/tasks - added additional checks for assigner sub-classes that might not have task_groups - Addressing review comments - Updated existing test cases for Assigner sub-classes - Remove hard-coded setting in assigner for torch_cnn_mnist ws, refer to default as in other Workspaces - Use aggregator supplied --task_group to override the assinger selected_task_group - update existing test cases of aggregator cli - add test cases for the decorator - rebased 23-Jan.2 Signed-off-by: Shailesh Pant <[email protected]>
1 parent ecd3603 commit 10b6993

File tree

17 files changed

+156
-51
lines changed

17 files changed

+156
-51
lines changed

openfl-workspace/torch_cnn_mnist/plan/plan.yaml

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,8 @@ aggregator:
1010
rounds_to_train: 2
1111
write_logs: false
1212
template: openfl.component.aggregator.Aggregator
13-
assigner:
14-
settings:
15-
task_groups:
16-
- name: learning
17-
percentage: 1.0
18-
tasks:
19-
- aggregated_model_validation
20-
- train
21-
- locally_tuned_model_validation
22-
template: openfl.component.RandomGroupedAssigner
13+
assigner :
14+
defaults : plan/defaults/assigner.yaml
2315
collaborator:
2416
settings:
2517
db_store_rounds: 1

openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ network :
3232
defaults : plan/defaults/network.yaml
3333

3434
assigner :
35-
defaults : plan/defaults/federated-evaluation/assigner.yaml
36-
35+
defaults : plan/defaults/assigner.yaml
36+
settings :
37+
selected_task_group : evaluation
38+
3739
tasks :
38-
defaults : plan/defaults/federated-evaluation/tasks_torch.yaml
40+
defaults : plan/defaults/tasks_torch.yaml
3941

4042
compression_pipeline :
4143
defaults : plan/defaults/compression_pipeline.yaml

openfl-workspace/workspace/plan/defaults/aggregator.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ settings :
33
db_store_rounds : 2
44
persist_checkpoint: True
55
persistent_db_path: local_state/tensor.db
6+
task_group: learning

openfl-workspace/workspace/plan/defaults/assigner.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@ settings :
77
- aggregated_model_validation
88
- train
99
- locally_tuned_model_validation
10+
- name : evaluation
11+
percentage : 1.0
12+
tasks :
13+
- aggregated_model_validation

openfl-workspace/workspace/plan/defaults/federated-evaluation/assigner.yaml

Lines changed: 0 additions & 7 deletions
This file was deleted.

openfl-workspace/workspace/plan/defaults/federated-evaluation/tasks_torch.yaml

Lines changed: 0 additions & 7 deletions
This file was deleted.

openfl/component/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2020-2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""OpenFL Component Module."""
45

56
from openfl.component.aggregator.aggregator import Aggregator
67
from openfl.component.assigner.assigner import Assigner

openfl/component/aggregator/aggregator.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,25 @@ def __init__(
130130
self.straggler_handling_policy = (
131131
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
132132
)
133-
self._end_of_round_check_done = [False] * rounds_to_train
134-
self.stragglers = []
135133

136134
self.rounds_to_train = rounds_to_train
135+
if self.task_group == "evaluation":
136+
self.rounds_to_train = 1
137+
logger.info(
138+
f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}"
139+
)
140+
141+
self._end_of_round_check_done = [False] * rounds_to_train
142+
self.stragglers = []
137143

138144
# if the collaborator requests a delta, this value is set to true
139145
self.authorized_cols = authorized_cols
140146
self.uuid = aggregator_uuid
141147
self.federation_uuid = federation_uuid
148+
# # override the assigner selected_task_group
149+
# # FIXME check the case of CustomAssigner as base class Assigner is redefined
150+
# # and doesn't have selected_task_group as attribute
151+
# assigner.selected_task_group = task_group
142152
self.assigner = assigner
143153
self.quit_job_sent_to = []
144154

openfl/component/assigner/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2020-2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""OpenFL Assigner Module."""
45

56
from openfl.component.assigner.assigner import Assigner
67
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner

openfl/component/assigner/assigner.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
"""Assigner module."""
66

7+
import logging
8+
from functools import wraps
9+
10+
logger = logging.getLogger(__name__)
11+
712

813
class Assigner:
914
r"""
@@ -35,18 +40,27 @@ class Assigner:
3540
\* - ``tasks`` argument is taken from ``tasks`` section of FL plan YAML file.
3641
"""
3742

38-
def __init__(self, tasks, authorized_cols, rounds_to_train, **kwargs):
43+
def __init__(
44+
self,
45+
tasks,
46+
authorized_cols,
47+
rounds_to_train,
48+
selected_task_group: str = "learning",
49+
**kwargs,
50+
):
3951
"""Initializes the Assigner.
4052
4153
Args:
4254
tasks (list of object): List of tasks to assign.
4355
authorized_cols (list of str): Collaborators.
4456
rounds_to_train (int): Number of training rounds.
57+
selected_task_group (str, optional): Selected task_group. Defaults to "learning".
4558
**kwargs: Additional keyword arguments.
4659
"""
4760
self.tasks = tasks
4861
self.authorized_cols = authorized_cols
4962
self.rounds = rounds_to_train
63+
self.selected_task_group = selected_task_group
5064
self.all_tasks_in_groups = []
5165

5266
self.task_group_collaborators = {}
@@ -93,3 +107,37 @@ def get_aggregation_type_for_task(self, task_name):
93107
if "aggregation_type" not in self.tasks[task_name]:
94108
return None
95109
return self.tasks[task_name]["aggregation_type"]
110+
111+
@classmethod
112+
def task_group_filtering(cls, func):
113+
"""Decorator to filter task groups based on selected_task_group.
114+
115+
This decorator should be applied to define_task_assignments() method
116+
in Assigner subclasses to handle task_group filtering.
117+
"""
118+
119+
@wraps(func)
120+
def wrapper(self, *args, **kwargs):
121+
# First check if selection of task_group is applicable
122+
if hasattr(self, "selected_task_group"):
123+
# Verify task_groups exists before attempting filtering
124+
if not hasattr(self, "task_groups"):
125+
logger.warning(
126+
"Task group specified for selection but no task_groups found. "
127+
"Skipping filtering. This might be intentional for custom assigners."
128+
)
129+
return func(self, *args, **kwargs)
130+
131+
assert self.task_groups, "No task_groups defined in assigner."
132+
133+
# Perform the filtering
134+
self.task_groups = [
135+
group for group in self.task_groups if group["name"] == self.selected_task_group
136+
]
137+
138+
assert self.task_groups, f"No task groups found for : {self.selected_task_group}"
139+
140+
# Call the original method
141+
return func(self, *args, **kwargs)
142+
143+
return wrapper

0 commit comments

Comments
 (0)