Skip to content

Commit

Permalink
- Fix the leaky abstraction of task_group to Aggregator
Browse files Browse the repository at this point in the history
- Added is_task_group_evaluation function in Assigner class

Signed-off-by: Shailesh Pant <[email protected]>
  • Loading branch information
ishaileshpant committed Jan 23, 2025
1 parent 57f5094 commit 662f997
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 20 deletions.
2 changes: 1 addition & 1 deletion openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ settings :
db_store_rounds : 2
persist_checkpoint: True
persistent_db_path: local_state/tensor.db
task_group: learning

20 changes: 6 additions & 14 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(
callbacks: Optional[List] = None,
persist_checkpoint=True,
persistent_db_path=None,
task_group: str = "learning",
):
"""Initializes the Aggregator.
Expand All @@ -111,9 +110,7 @@ def __init__(
Defaults to 1.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
task_group (str, optional): Selected task_group for assignment.
"""
self.task_group = task_group
self.round_number = 0
self.next_model_round_number = 0

Expand All @@ -132,11 +129,10 @@ def __init__(
)

self.rounds_to_train = rounds_to_train
if self.task_group == "evaluation":
self.assigner = assigner
if self.assigner.is_task_group_evaluation():
self.rounds_to_train = 1
logger.info(
f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}"
)
logger.info(f"For evaluation tasks setting rounds_to_train = {self.rounds_to_train}")

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []
Expand All @@ -145,11 +141,7 @@ def __init__(
self.authorized_cols = authorized_cols
self.uuid = aggregator_uuid
self.federation_uuid = federation_uuid
# # override the assigner selected_task_group
# # FIXME check the case of CustomAssigner as base class Assigner is redefined
# # and doesn't have selected_task_group as attribute
# assigner.selected_task_group = task_group
self.assigner = assigner

self.quit_job_sent_to = []

self.tensor_db = TensorDB()
Expand Down Expand Up @@ -311,8 +303,8 @@ def _load_initial_tensors(self):
)

# Check selected task_group before updating round number
if self.task_group == "evaluation":
logger.info(f"Skipping round_number check for {self.task_group} task_group")
if self.assigner.is_task_group_evaluation():
logger.info("Skipping round_number check for evaluation run")
elif round_number > self.round_number:
logger.info(f"Starting training from round {round_number} of previously saved model")
self.round_number = round_number
Expand Down
10 changes: 10 additions & 0 deletions openfl/component/assigner/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ def get_collaborators_for_task(self, task_name, round_number):
"""Abstract method."""
raise NotImplementedError

def is_task_group_evaluation(self):
"""Check if the selected task group is for 'evaluation' run.
Returns:
bool: True if the selected task group is 'evaluation', False otherwise.
"""
if hasattr(self, "selected_task_group"):
return self.selected_task_group == "evaluation"
return False

def get_all_tasks_for_round(self, round_number):
"""Return tasks for the current round.
Expand Down
1 change: 0 additions & 1 deletion openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def start_(plan, authorized_cols, task_group):
# Set task_group in aggregator settings
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

Expand Down
5 changes: 1 addition & 4 deletions tests/openfl/component/aggregator/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,13 @@ def agg(mocker, model, assigner):
'some_uuid',
'federation_uuid',
['col1', 'col2'],

'init_state_path',
'best_state_path',
'last_state_path',

assigner,
)
)
return agg


@pytest.mark.parametrize(
'cert_common_name,collaborator_common_name,authorized_cols,single_cccn,expected_is_valid', [
('col1', 'col1', ['col1', 'col2'], '', True),
Expand Down

0 comments on commit 662f997

Please sign in to comment.