From ea1541bc9763c38bbf7894a4ecafd700d7091639 Mon Sep 17 00:00:00 2001 From: Shailesh Pant Date: Thu, 23 Jan 2025 06:36:01 +0000 Subject: [PATCH] - Fix the leaky abstraction of task_group to Aggregator - Added is_task_group_evaluation function in Assigner class - Fix existing aggregator interface test cases Signed-off-by: Shailesh Pant --- .../workspace/plan/defaults/aggregator.yaml | 2 +- openfl/component/aggregator/aggregator.py | 20 ++++++------------- openfl/component/assigner/assigner.py | 10 ++++++++++ openfl/interface/aggregator.py | 7 +++---- .../component/aggregator/test_aggregator.py | 5 +---- tests/openfl/interface/test_aggregator_api.py | 15 -------------- 6 files changed, 21 insertions(+), 38 deletions(-) diff --git a/openfl-workspace/workspace/plan/defaults/aggregator.yaml b/openfl-workspace/workspace/plan/defaults/aggregator.yaml index 2a233287eb..5ef44847f6 100644 --- a/openfl-workspace/workspace/plan/defaults/aggregator.yaml +++ b/openfl-workspace/workspace/plan/defaults/aggregator.yaml @@ -3,4 +3,4 @@ settings : db_store_rounds : 2 persist_checkpoint: True persistent_db_path: local_state/tensor.db - task_group: learning + diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 8a9cd757e8..07c1d38ac1 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -84,7 +84,6 @@ def __init__( callbacks: Optional[List] = None, persist_checkpoint=True, persistent_db_path=None, - task_group: str = "learning", ): """Initializes the Aggregator. @@ -111,9 +110,7 @@ def __init__( Defaults to 1. initial_tensor_dict (dict, optional): Initial tensor dictionary. callbacks: List of callbacks to be used during the experiment. - task_group (str, optional): Selected task_group for assignment. """ - self.task_group = task_group self.round_number = 0 self.next_model_round_number = 0 @@ -132,11 +129,10 @@ def __init__( ) self.rounds_to_train = rounds_to_train - if self.task_group == "evaluation": + self.assigner = assigner + if self.assigner.is_task_group_evaluation(): self.rounds_to_train = 1 - logger.info( - f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}" - ) + logger.info(f"For evaluation tasks setting rounds_to_train = {self.rounds_to_train}") self._end_of_round_check_done = [False] * rounds_to_train self.stragglers = [] @@ -145,11 +141,7 @@ def __init__( self.authorized_cols = authorized_cols self.uuid = aggregator_uuid self.federation_uuid = federation_uuid - # # override the assigner selected_task_group - # # FIXME check the case of CustomAssigner as base class Assigner is redefined - # # and doesn't have selected_task_group as attribute - # assigner.selected_task_group = task_group - self.assigner = assigner + self.quit_job_sent_to = [] self.tensor_db = TensorDB() @@ -311,8 +303,8 @@ def _load_initial_tensors(self): ) # Check selected task_group before updating round number - if self.task_group == "evaluation": - logger.info(f"Skipping round_number check for {self.task_group} task_group") + if self.assigner.is_task_group_evaluation(): + logger.info("Skipping round_number check for evaluation run") elif round_number > self.round_number: logger.info(f"Starting training from round {round_number} of previously saved model") self.round_number = round_number diff --git a/openfl/component/assigner/assigner.py b/openfl/component/assigner/assigner.py index 0b5fc36e88..d49a68ffd6 100644 --- a/openfl/component/assigner/assigner.py +++ b/openfl/component/assigner/assigner.py @@ -81,6 +81,16 @@ def get_collaborators_for_task(self, task_name, round_number): """Abstract method.""" raise NotImplementedError + def is_task_group_evaluation(self): + """Check if the selected task group is for 'evaluation' run. + + Returns: + bool: True if the selected task group is 'evaluation', False otherwise. + """ + if hasattr(self, "selected_task_group"): + return self.selected_task_group == "evaluation" + return False + def get_all_tasks_for_round(self, round_number): """Return tasks for the current round. diff --git a/openfl/interface/aggregator.py b/openfl/interface/aggregator.py index 8c922eee19..73f25311f7 100644 --- a/openfl/interface/aggregator.py +++ b/openfl/interface/aggregator.py @@ -94,10 +94,9 @@ def start_(plan, authorized_cols, task_group): cols_config_path=Path(authorized_cols).absolute(), ) - # Set task_group in aggregator settings - if "settings" not in parsed_plan.config["aggregator"]: - parsed_plan.config["aggregator"]["settings"] = {} - parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group + # Set selected_task_group in assigner settings + if "settings" not in parsed_plan.config["assigner"]: + parsed_plan.config["assigner"]["settings"] = {} parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group logger.info(f"Setting aggregator to assign: {task_group} task_group") diff --git a/tests/openfl/component/aggregator/test_aggregator.py b/tests/openfl/component/aggregator/test_aggregator.py index f90b457925..f9883fd7b7 100644 --- a/tests/openfl/component/aggregator/test_aggregator.py +++ b/tests/openfl/component/aggregator/test_aggregator.py @@ -48,16 +48,13 @@ def agg(mocker, model, assigner): 'some_uuid', 'federation_uuid', ['col1', 'col2'], - 'init_state_path', 'best_state_path', 'last_state_path', - assigner, - ) + ) return agg - @pytest.mark.parametrize( 'cert_common_name,collaborator_common_name,authorized_cols,single_cccn,expected_is_valid', [ ('col1', 'col1', ['col1', 'col2'], '', True), diff --git a/tests/openfl/interface/test_aggregator_api.py b/tests/openfl/interface/test_aggregator_api.py index 0e5aa963dc..9a010e3f10 100644 --- a/tests/openfl/interface/test_aggregator_api.py +++ b/tests/openfl/interface/test_aggregator_api.py @@ -23,11 +23,6 @@ def test_aggregator_start(mock_parse): mock_plan.get = {'task_group': 'learning'}.get # Add the config attribute with proper nesting mock_plan.config = { - 'aggregator': { - 'settings': { - 'task_group': 'learning' - } - }, 'assigner': { 'settings': { 'selected_task_group': 'learning' @@ -55,11 +50,6 @@ def test_aggregator_start_illegal_plan(mock_parse, mock_is_directory_traversal): mock_plan.get = {'task_group': 'learning'}.get # Add the config attribute with proper nesting mock_plan.config = { - 'aggregator': { - 'settings': { - 'task_group': 'learning' - } - }, 'assigner': { 'settings': { 'selected_task_group': 'learning' @@ -89,11 +79,6 @@ def test_aggregator_start_illegal_cols(mock_parse, mock_is_directory_traversal): mock_plan.get = {'task_group': 'learning'}.get # Add the config attribute with proper nesting mock_plan.config = { - 'aggregator': { - 'settings': { - 'task_group': 'learning' - } - }, 'assigner': { 'settings': { 'selected_task_group': 'learning'