Skip to content

Commit

Permalink
- implement a new task_group filtering decorator in Assigner class
Browse files Browse the repository at this point in the history
- update all the sub-classes that use task_groups to use the decorator
- update fedeval sample workspace to use default assigner, tasks and aggregator
- use of federated-evaluation/aggregator.yaml for FedEval specific workspace example to use round_number as 1
- removed assigner and tasks yaml from defaults/federated-evaluation, superseded by default assigner/tasks
- Rebase 21-Jan-2025.2
- added additional checks for assigner sub-classes that might not have task_groups
- Addressing review comments
- Updated existing test cases for Assigner sub-classes
- Remove hard-coded setting in assigner for torch_cnn_mnist ws, refer to default as in other Workspaces
- Use aggregator supplied --task_group to override the assinger selected_task_group
- update existing test cases of aggregator cli
- add test cases for the decorator
Signed-off-by: Shailesh Pant <[email protected]>
  • Loading branch information
ishaileshpant committed Jan 23, 2025
1 parent ecd3603 commit 455b691
Show file tree
Hide file tree
Showing 17 changed files with 156 additions and 51 deletions.
12 changes: 2 additions & 10 deletions openfl-workspace/torch_cnn_mnist/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,8 @@ aggregator:
rounds_to_train: 2
write_logs: false
template: openfl.component.aggregator.Aggregator
assigner:
settings:
task_groups:
- name: learning
percentage: 1.0
tasks:
- aggregated_model_validation
- train
- locally_tuned_model_validation
template: openfl.component.RandomGroupedAssigner
assigner :
defaults : plan/defaults/assigner.yaml
collaborator:
settings:
db_store_rounds: 1
Expand Down
8 changes: 5 additions & 3 deletions openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ network :
defaults : plan/defaults/network.yaml

assigner :
defaults : plan/defaults/federated-evaluation/assigner.yaml

defaults : plan/defaults/assigner.yaml
settings :
selected_task_group : evaluation

tasks :
defaults : plan/defaults/federated-evaluation/tasks_torch.yaml
defaults : plan/defaults/tasks_torch.yaml

compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ settings :
db_store_rounds : 2
persist_checkpoint: True
persistent_db_path: local_state/tensor.db
task_group: learning
4 changes: 4 additions & 0 deletions openfl-workspace/workspace/plan/defaults/assigner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ settings :
- aggregated_model_validation
- train
- locally_tuned_model_validation
- name : evaluation
percentage : 1.0
tasks :
- aggregated_model_validation

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Component Module."""

from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.assigner.assigner import Assigner
Expand Down
14 changes: 12 additions & 2 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,25 @@ def __init__(
self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

self.rounds_to_train = rounds_to_train
if self.task_group == "evaluation":
self.rounds_to_train = 1
logger.info(
f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}"
)

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

# if the collaborator requests a delta, this value is set to true
self.authorized_cols = authorized_cols
self.uuid = aggregator_uuid
self.federation_uuid = federation_uuid
# # override the assigner selected_task_group
# # FIXME check the case of CustomAssigner as base class Assigner is redefined
# # and doesn't have selected_task_group as attribute
# assigner.selected_task_group = task_group
self.assigner = assigner
self.quit_job_sent_to = []

Expand Down
1 change: 1 addition & 0 deletions openfl/component/assigner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Assigner Module."""

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
Expand Down
50 changes: 49 additions & 1 deletion openfl/component/assigner/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

"""Assigner module."""

import logging
from functools import wraps

logger = logging.getLogger(__name__)


class Assigner:
r"""
Expand Down Expand Up @@ -35,18 +40,27 @@ class Assigner:
\* - ``tasks`` argument is taken from ``tasks`` section of FL plan YAML file.
"""

def __init__(self, tasks, authorized_cols, rounds_to_train, **kwargs):
def __init__(
self,
tasks,
authorized_cols,
rounds_to_train,
selected_task_group: str = "learning",
**kwargs,
):
"""Initializes the Assigner.
Args:
tasks (list of object): List of tasks to assign.
authorized_cols (list of str): Collaborators.
rounds_to_train (int): Number of training rounds.
selected_task_group (str, optional): Selected task_group. Defaults to "learning".
**kwargs: Additional keyword arguments.
"""
self.tasks = tasks
self.authorized_cols = authorized_cols
self.rounds = rounds_to_train
self.selected_task_group = selected_task_group
self.all_tasks_in_groups = []

self.task_group_collaborators = {}
Expand Down Expand Up @@ -93,3 +107,37 @@ def get_aggregation_type_for_task(self, task_name):
if "aggregation_type" not in self.tasks[task_name]:
return None
return self.tasks[task_name]["aggregation_type"]

@classmethod
def task_group_filtering(cls, func):
"""Decorator to filter task groups based on selected_task_group.
This decorator should be applied to define_task_assignments() method
in Assigner subclasses to handle task_group filtering.
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
# First check if selection of task_group is applicable
if hasattr(self, "selected_task_group"):
# Verify task_groups exists before attempting filtering
if not hasattr(self, "task_groups"):
logger.warning(
"Task group specified for selection but no task_groups found. "
"Skipping filtering. This might be intentional for custom assigners."
)
return func(self, *args, **kwargs)

assert self.task_groups, "No task_groups defined in assigner."

# Perform the filtering
self.task_groups = [
group for group in self.task_groups if group["name"] == self.selected_task_group
]

assert self.task_groups, f"No task groups found for : {self.selected_task_group}"

# Call the original method
return func(self, *args, **kwargs)

return wrapper
7 changes: 5 additions & 2 deletions openfl/component/assigner/random_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner import Assigner


class RandomGroupedAssigner(Assigner):
Expand All @@ -33,16 +33,19 @@ class RandomGroupedAssigner(Assigner):
\* - Plan setting.
"""

task_group_filtering = Assigner.task_group_filtering

def __init__(self, task_groups, **kwargs):
"""Initializes the RandomGroupedAssigner.
Args:
task_groups (list of object): Task groups to assign.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments, including mode.
"""
self.task_groups = task_groups
super().__init__(**kwargs)

@task_group_filtering
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.
Expand Down
3 changes: 3 additions & 0 deletions openfl/component/assigner/static_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class StaticGroupedAssigner(Assigner):
\* - Plan setting.
"""

task_group_filtering = Assigner.task_group_filtering

def __init__(self, task_groups, **kwargs):
"""Initializes the StaticGroupedAssigner.
Expand All @@ -42,6 +44,7 @@ def __init__(self, task_groups, **kwargs):
self.task_groups = task_groups
super().__init__(**kwargs)

@task_group_filtering
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.
Expand Down
1 change: 1 addition & 0 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def start_(plan, authorized_cols, task_group):
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

logger.info("🧿 Starting the Aggregator Service.")
Expand Down
28 changes: 18 additions & 10 deletions tests/openfl/component/assigner/test_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def assigner():

def test_get_aggregation_type_for_task_none(assigner):
"""Assert that aggregation type of custom task is None."""
task_name = 'test_name'
task_name = "test_name"
tasks = {task_name: {}}

assigner = assigner(tasks, None, None)
Expand All @@ -31,11 +31,9 @@ def test_get_aggregation_type_for_task_none(assigner):

def test_get_aggregation_type_for_task(assigner):
"""Assert that aggregation type of task is getting correctly."""
task_name = 'test_name'
test_aggregation_type = 'test_aggregation_type'
tasks = {task_name: {
'aggregation_type': test_aggregation_type
}}
task_name = "test_name"
test_aggregation_type = "test_aggregation_type"
tasks = {task_name: {"aggregation_type": test_aggregation_type}}
assigner = assigner(tasks, None, None)

aggregation_type = assigner.get_aggregation_type_for_task(task_name)
Expand All @@ -46,13 +44,23 @@ def test_get_aggregation_type_for_task(assigner):
def test_get_all_tasks_for_round(assigner):
"""Assert that assigner tasks object is list."""
assigner = Assigner(None, None, None)
tasks = assigner.get_all_tasks_for_round('test')
tasks = assigner.get_all_tasks_for_round("test")

assert isinstance(tasks, list)

def test_default_task_group(assigner):
"""Assert that by default learning task_group is assigned."""
assigner = Assigner(None,None,None)
assert assigner.selected_task_group == 'learning'

class TestNotImplError(TestCase):
def test_task_group_filtering_no_task_groups(assigner):
"""Assert that task_group_filtering does not filter when no task_groups are defined."""
assigner = Assigner(None,None,None)
assigner.selected_task_group = "test_group"
assigner.define_task_assignments()
assert not hasattr(assigner, "task_groups")

class TestNotImplError(TestCase):
def test_define_task_assignments(self):
# TODO: define_task_assignments is defined as a mock in multiple fixtures,
# which leads the function to behave as a mock here and other tests.
Expand All @@ -61,9 +69,9 @@ def test_define_task_assignments(self):
def test_get_tasks_for_collaborator(self):
with self.assertRaises(NotImplementedError):
assigner = Assigner(None, None, None)
assigner.get_tasks_for_collaborator('col1', 0)
assigner.get_tasks_for_collaborator("col1", 0)

def test_get_collaborators_for_task(self):
with self.assertRaises(NotImplementedError):
assigner = Assigner(None, None, None)
assigner.get_collaborators_for_task('task_name', 0)
assigner.get_collaborators_for_task("task_name", 0)
46 changes: 38 additions & 8 deletions tests/openfl/component/assigner/test_random_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from openfl.component.assigner import RandomGroupedAssigner
from openfl.component.assigner import RandomGroupedAssigner, Assigner

ROUNDS_TO_TRAIN = 10

Expand All @@ -14,13 +14,20 @@ def task_groups():
"""Initialize task groups."""
task_groups = [
{
'name': 'train_and_validate',
'name': 'learning',
'percentage': 1.0,
'tasks': [
'aggregated_model_validation',
'train',
'locally_tuned_model_validation'
]
},
{
'name': 'evaluation',
'percentage': 1.0,
'tasks': [
'aggregated_model_validation'
]
}
]
return task_groups
Expand All @@ -35,19 +42,42 @@ def authorized_cols():
@pytest.fixture
def assigner(task_groups, authorized_cols):
"""Initialize assigner."""
assigner = RandomGroupedAssigner

assigner = assigner(task_groups,
tasks=None,
authorized_cols=authorized_cols,
rounds_to_train=ROUNDS_TO_TRAIN)
assigner = RandomGroupedAssigner(
task_groups=task_groups, # Pass task_groups here
tasks=None,
authorized_cols=authorized_cols,
rounds_to_train=ROUNDS_TO_TRAIN
)
return assigner


def test_define_task_assignments(assigner):
"""Test `define_task_assignments` is working."""
assigner.define_task_assignments()

def test_check_default_task_group():
"""Assert that by default learning task_group is assigned."""
assigner = Assigner(None, None, None)
assert assigner.selected_task_group == 'learning'

@pytest.mark.parametrize('round_number', range(ROUNDS_TO_TRAIN))
def test_get_default_tasks_for_collaborator(assigner, task_groups,
authorized_cols, round_number):
"""Test that assigner tasks correspond to task groups defined."""
tasks = assigner.get_tasks_for_collaborator(
authorized_cols[0], round_number)
assert tasks == task_groups[0]['tasks']
assert assigner.selected_task_group == task_groups[0]['name']

# @pytest.mark.parametrize('round_number', range(ROUNDS_TO_TRAIN))
# def test_get_filtered_tasks_for_collaborator(assigner, task_groups,
# authorized_cols, round_number):
# """Test that assigner tasks correspond to task groups defined."""
# assigner.selected_task_group=task_groups[1]['name']
# assigner.define_task_assignments()
# tasks = assigner.get_tasks_for_collaborator(
# authorized_cols[0], round_number)
# assert tasks == task_groups[1]['tasks']

@pytest.mark.parametrize('round_number', range(ROUNDS_TO_TRAIN))
def test_get_tasks_for_collaborator(assigner, task_groups,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def task_groups(authorized_cols):
"""Initialize task groups."""
task_groups = [
{
'name': 'train_and_validate',
'name': 'learning',
'percentage': 1.0,
'collaborators': authorized_cols,
'tasks': [
Expand Down
Loading

0 comments on commit 455b691

Please sign in to comment.