Commit 2aa0b22

Merge branch 'develop' into straggler_handling_update
2 parents d117bd2 + fdad4fb commit 2aa0b22

File tree

12 files changed: +587 −50 lines changed


.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion

@@ -25,6 +25,6 @@ jobs:
       - name: Install linters
         run: |
           python -m pip install --upgrade pip
-          pip install -r linters-requirements.txt
+          pip install -r linters-requirements.txt
       - name: Lint with OpenFL-specific rules
         run: bash scripts/lint.sh

.github/workflows/taskrunner.yml

Lines changed: 4 additions & 14 deletions

@@ -17,29 +17,19 @@ env:
 jobs:
   build:
     if: github.event.pull_request.draft == false
-    strategy:
-      matrix:
-        os: ['ubuntu-latest', 'windows-latest']
-        python-version: ["3.10", "3.11", "3.12"]
-    runs-on: ${{ matrix.os }}
+    runs-on: ubuntu-latest
     timeout-minutes: 15

     steps:
       - uses: actions/checkout@v3
       - name: Set up Python
         uses: actions/setup-python@v4
         with:
-          python-version: ${{ matrix.python-version }}
+          python-version: "3.10"
       - name: Install dependencies ubuntu
-        if: matrix.os == 'ubuntu-latest'
         run: |
           python -m pip install --upgrade pip
           pip install .
-      - name: Install dependencies windows
-        if: matrix.os == 'windows-latest'
+      - name: Task Runner API
         run: |
-          python -m pip install --upgrade pip
-          pip install .
-      - name: Test TaskRunner API
-        run: |
-          python -m tests.github.test_hello_federation --template keras_cnn_mnist --fed_workspace aggregator --col1 col1 --col2 col2 --rounds-to-train 3 --save-model output_model
+          python -m tests.github.test_hello_federation --template torch_cnn_mnist --fed_workspace aggregator --col1 collaborator1 --col2 collaborator2 --rounds-to-train 3 --save-model output_model

.github/workflows/ubuntu.yml

Lines changed: 4 additions & 1 deletion

@@ -13,6 +13,9 @@ env:

 jobs:
   pytest-coverage: # from pytest_coverage.yml
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
     runs-on: ubuntu-latest
     timeout-minutes: 15

@@ -21,7 +24,7 @@ jobs:
       - name: Set up Python 3
         uses: actions/setup-python@v3
         with:
-          python-version: "3.10"
+          python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip

.github/workflows/windows.yml

Lines changed: 5 additions & 1 deletion

@@ -13,14 +13,18 @@ env:

 jobs:
   pytest-coverage: # from pytest_coverage.yml
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
     runs-on: windows-latest
     timeout-minutes: 15
+
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python 3
         uses: actions/setup-python@v3
         with:
-          python-version: "3.10"
+          python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip

docs/about/features_index/taskrunner.rst

Lines changed: 3 additions & 2 deletions

@@ -44,8 +44,9 @@ Configurable Settings
 - :code:`best_state_path`: (str:path) Defines the weight protobuf file path that will be saved to for the highest accuracy model during the experiment.
 - :code:`last_state_path`: (str:path) Defines the weight protobuf file path that will be saved to during the last round completed in each experiment.
 - :code:`rounds_to_train`: (int) Specifies the number of rounds in a federation. A federated learning round is defined as one complete iteration when the collaborators train the model and send the updated model weights back to the aggregator to form a new global model. Within a round, collaborators can train the model for multiple iterations called epochs.
-- :code:`write_logs`: (boolean) Metric logging callback feature. By default, logging is done through `tensorboard <https://www.tensorflow.org/tensorboard/get_started>`_ but users can also use custom metric logging function for each task.
-
+- :code:`write_logs`: (boolean) Metric logging callback feature. By default, logging is done through `tensorboard <https://www.tensorflow.org/tensorboard/get_started>`_ but users can also use custom metric logging function for each task.
+- :code:`persist_checkpoint`: (boolean) Specifies whether to enable the storage of a persistent checkpoint in non-volatile storage for recovery purposes. When enabled, the aggregator will restore its state to what it was prior to the restart, ensuring continuity after a restart.
+- :code:`persistent_db_path`: (str:path) Defines the persisted database path.

 - :class:`Collaborator <openfl.component.Collaborator>`
   `openfl.component.Collaborator <https://github.com/intel/openfl/blob/develop/openfl/component/collaborator/collaborator.py>`_
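Putting the two new settings in context, an aggregator section of a plan might look like the following sketch. This is an illustrative fragment, not taken from the repository: the setting names come from the documented list above and from the plan diff in this commit, while the concrete paths and round count are assumptions.

```yaml
aggregator:
  template: openfl.component.Aggregator
  settings:
    best_state_path: save/best.pbuf     # illustrative path
    last_state_path: save/last.pbuf     # illustrative path
    rounds_to_train: 10                 # illustrative value
    write_logs: True
    persist_checkpoint: True            # enable crash-recovery checkpointing
    persistent_db_path: save/tensor.db  # where the persistent DB is stored
```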

openfl-tutorials/experimental/workflow/105_Numpy_Linear_Regression_Workflow.ipynb

Lines changed: 1 addition & 5 deletions

@@ -42,10 +42,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Below code will display the print statement output on screen as well\n",
-    "import sys\n",
-    "sys.stdout = open('/dev/stdout', 'w')\n",
-    "\n",
     "!pip install git+https://github.com/securefederatedai/openfl.git\n",
     "!pip install -r workflow_interface_requirements.txt\n",
     "!pip install matplotlib\n",
@@ -308,7 +304,7 @@
     "        self.current_round += 1\n",
     "        if self.current_round < self.rounds:\n",
     "            self.next(self.aggregated_model_validation,\n",
-    "                      foreach='collaborators', exclude=['private'])\n",
+    "                      foreach='collaborators')\n",
     "        else:\n",
     "            self.next(self.end)\n",
     "\n",
Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
 template : openfl.component.Aggregator
 settings :
   db_store_rounds : 2
+  persist_checkpoint: True
+  persistent_db_path: save/tensor.db

openfl/component/aggregator/aggregator.py

Lines changed: 151 additions & 25 deletions

@@ -12,10 +12,11 @@
 import openfl.callbacks as callbacks_module
 from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy
-from openfl.databases import TensorDB
+from openfl.databases import PersistentTensorDB, TensorDB
 from openfl.interface.aggregation_functions import WeightedAverage
 from openfl.pipelines import NoCompressionPipeline, TensorCodec
 from openfl.protocols import base_pb2, utils
+from openfl.protocols.base_pb2 import NamedTensor
 from openfl.utilities import TaskResultKey, TensorKey, change_tags

 logger = logging.getLogger(__name__)

@@ -82,6 +83,8 @@ def __init__(
         log_memory_usage=False,
         write_logs=False,
         callbacks: Optional[List] = None,
+        persist_checkpoint=True,
+        persistent_db_path=None,
     ):
         """Initializes the Aggregator.

@@ -109,6 +112,7 @@ def __init__(
             callbacks: List of callbacks to be used during the experiment.
         """
         self.round_number = 0
+        self.next_model_round_number = 0

         if single_col_cert_common_name:
             logger.warning(

@@ -135,6 +139,16 @@ def __init__(
         self.quit_job_sent_to = []

         self.tensor_db = TensorDB()
+        if persist_checkpoint:
+            persistent_db_path = persistent_db_path or "tensor.db"
+            logger.info(
+                "Persistent checkpoint is enabled, setting persistent db at path %s",
+                persistent_db_path,
+            )
+            self.persistent_db = PersistentTensorDB(persistent_db_path)
+        else:
+            logger.info("Persistent checkpoint is disabled")
+            self.persistent_db = None
         # FIXME: I think next line generates an error on the second round
         # if it is set to 1 for the aggregator.
         self.db_store_rounds = db_store_rounds
@@ -152,8 +166,25 @@ def __init__(
         # TODO: Remove. Used in deprecated interactive and native APIs
         self.best_tensor_dict: dict = {}
         self.last_tensor_dict: dict = {}
+        # these enable getting all tensors for a task
+        self.collaborator_tasks_results = {}  # {TaskResultKey: list of TensorKeys}
+        self.collaborator_task_weight = {}  # {TaskResultKey: data_size}

-        if initial_tensor_dict:
+        # maintain a list of collaborators that have completed task and
+        # reported results in a given round
+        self.collaborators_done = []
+        # Initialize a lock for thread safety
+        self.lock = Lock()
+        self.use_delta_updates = use_delta_updates
+
+        self.model = None  # Initialize the model attribute to None
+        if self.persistent_db and self._recover():
+            logger.info("recovered state of aggregator")
+
+        # The model is built by recovery if at least one round has finished
+        if self.model:
+            logger.info("Model was loaded by recovery")
+        elif initial_tensor_dict:
             self._load_initial_tensors_from_dict(initial_tensor_dict)
             self.model = utils.construct_model_proto(
                 tensor_dict=initial_tensor_dict,

@@ -166,20 +197,6 @@ def __init__(

         self.collaborator_tensor_results = {}  # {TensorKey: nparray}}

-        # these enable getting all tensors for a task
-        self.collaborator_tasks_results = {}  # {TaskResultKey: list of TensorKeys}
-
-        self.collaborator_task_weight = {}  # {TaskResultKey: data_size}
-
-        # maintain a list of collaborators that have completed task and
-        # reported results in a given round
-        self.collaborators_done = []
-
-        # Initialize a lock for thread safety
-        self.lock = Lock()
-
-        self.use_delta_updates = use_delta_updates
-
         # Callbacks
         self.callbacks = callbacks_module.CallbackList(
             callbacks,

@@ -193,6 +210,79 @@ def __init__(
         self.callbacks.on_experiment_begin()
         self.callbacks.on_round_begin(self.round_number)

+    def _recover(self):
+        """Populates the aggregator state to the state it was prior a restart"""
+        recovered = False
+        # load tensors persistent DB
+        tensor_key_dict = self.persistent_db.load_tensors(
+            self.persistent_db.get_tensors_table_name()
+        )
+        if len(tensor_key_dict) > 0:
+            logger.info(f"Recovering {len(tensor_key_dict)} model tensors")
+            recovered = True
+            self.tensor_db.cache_tensor(tensor_key_dict)
+            committed_round_number, self.best_model_score = (
+                self.persistent_db.get_round_and_best_score()
+            )
+            logger.info("Recovery - Setting model proto")
+            to_proto_tensor_dict = {}
+            for tk in tensor_key_dict:
+                tk_name, _, _, _, _ = tk
+                to_proto_tensor_dict[tk_name] = tensor_key_dict[tk]
+            self.model = utils.construct_model_proto(
+                to_proto_tensor_dict, committed_round_number, self.compression_pipeline
+            )
+            # round number is the current round which is still in process
+            # i.e. committed_round_number + 1
+            self.round_number = committed_round_number + 1
+            logger.info(
+                "Recovery - loaded round number %s and best score %s",
+                self.round_number,
+                self.best_model_score,
+            )
+
+        next_round_tensor_key_dict = self.persistent_db.load_tensors(
+            self.persistent_db.get_next_round_tensors_table_name()
+        )
+        if len(next_round_tensor_key_dict) > 0:
+            logger.info(f"Recovering {len(next_round_tensor_key_dict)} next round model tensors")
+            recovered = True
+            self.tensor_db.cache_tensor(next_round_tensor_key_dict)
+
+        logger.debug("Recovery - this is the tensor_db after recovery: %s", self.tensor_db)
+
+        if self.persistent_db.is_task_table_empty():
+            logger.debug("task table is empty")
+            return recovered
+
+        logger.info("Recovery - Replaying saved task results")
+        task_id = 1
+        while True:
+            task_result = self.persistent_db.get_task_result_by_id(task_id)
+            if not task_result:
+                break
+            recovered = True
+            collaborator_name = task_result["collaborator_name"]
+            round_number = task_result["round_number"]
+            task_name = task_result["task_name"]
+            data_size = task_result["data_size"]
+            serialized_tensors = task_result["named_tensors"]
+            named_tensors = [
+                NamedTensor.FromString(serialized_tensor)
+                for serialized_tensor in serialized_tensors
+            ]
+            logger.info(
+                "Recovery - Replaying task results %s %s %s",
+                collaborator_name,
+                round_number,
+                task_name,
+            )
+            self.process_task_results(
+                collaborator_name, round_number, task_name, data_size, named_tensors
+            )
+            task_id += 1
+        return recovered
+
     def _load_initial_tensors(self):
         """Load all of the tensors required to begin federated learning.
@@ -253,9 +343,12 @@ def _save_model(self, round_number, file_path):
             for k, v in og_tensor_dict.items()
         ]
         tensor_dict = {}
+        tensor_tuple_dict = {}
         for tk in tensor_keys:
             tk_name, _, _, _, _ = tk
-            tensor_dict[tk_name] = self.tensor_db.get_tensor_from_cache(tk)
+            tensor_value = self.tensor_db.get_tensor_from_cache(tk)
+            tensor_dict[tk_name] = tensor_value
+            tensor_tuple_dict[tk] = tensor_value
             if tensor_dict[tk_name] is None:
                 logger.info(
                     "Cannot save model for round %s. Continuing...",

@@ -265,6 +358,19 @@ def _save_model(self, round_number, file_path):
         if file_path == self.best_state_path:
             self.best_tensor_dict = tensor_dict
         if file_path == self.last_state_path:
+            # Transaction to persist/delete all data needed to increment the round
+            if self.persistent_db:
+                if self.next_model_round_number > 0:
+                    next_round_tensors = self.tensor_db.get_tensors_by_round_and_tags(
+                        self.next_model_round_number, ("model",)
+                    )
+                    self.persistent_db.finalize_round(
+                        tensor_tuple_dict, next_round_tensors, self.round_number, self.best_model_score
+                    )
+                    logger.info(
+                        "Persist model and clean task result for round %s",
+                        round_number,
+                    )
             self.last_tensor_dict = tensor_dict
             self.model = utils.construct_model_proto(
                 tensor_dict, round_number, self.compression_pipeline

@@ -364,7 +470,7 @@ def get_tasks(self, collaborator_name):
         # if no tasks, tell the collaborator to sleep
         if len(tasks) == 0:
             tasks = None
-            sleep_time = self._get_sleep_time()
+            sleep_time = Aggregator._get_sleep_time()

         return tasks, self.round_number, sleep_time, time_to_quit

@@ -394,7 +500,7 @@ def get_tasks(self, collaborator_name):
         # been completed
         if len(tasks) == 0:
             tasks = None
-            sleep_time = self._get_sleep_time()
+            sleep_time = Aggregator._get_sleep_time()

         return tasks, self.round_number, sleep_time, time_to_quit
@@ -604,6 +710,31 @@ def send_local_task_results(
         Returns:
             None
         """
+        # Save task and its metadata for recovery
+        serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors]
+        if self.persistent_db:
+            self.persistent_db.save_task_results(
+                collaborator_name, round_number, task_name, data_size, serialized_tensors
+            )
+            logger.debug(
+                f"Persisting task results {task_name} from {collaborator_name} round {round_number}"
+            )
+        logger.info(
+            f"Collaborator {collaborator_name} is sending task results "
+            f"for {task_name}, round {round_number}"
+        )
+        self.process_task_results(
+            collaborator_name, round_number, task_name, data_size, named_tensors
+        )
+
+    def process_task_results(
+        self,
+        collaborator_name,
+        round_number,
+        task_name,
+        data_size,
+        named_tensors,
+    ):
         if self._time_to_quit() or collaborator_name in self.stragglers:
             logger.warning(
                 f"STRAGGLER: Collaborator {collaborator_name} is reporting results "

@@ -618,11 +749,6 @@ def send_local_task_results(
             )
             return

-        logger.info(
-            f"Collaborator {collaborator_name} is sending task results "
-            f"for {task_name}, round {round_number}"
-        )
-
         task_key = TaskResultKey(task_name, collaborator_name, round_number)

         # we mustn't have results already

@@ -862,7 +988,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
             new_model_report,
             ("model",),
         )
-
+        self.next_model_round_number = new_model_round_number
         # Finally, cache the updated model tensor
         self.tensor_db.cache_tensor({final_model_tk: new_model_nparray})
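The aggregator changes above follow a common crash-recovery pattern: persist each task result as it arrives, commit the round's tensors and round number in one transaction at round end (dropping the now-folded task results), and on restart reload the committed state and replay any pending results. The sketch below illustrates that pattern with a minimal SQLite-backed store. It is not the real `PersistentTensorDB`; the class, schema, and pickled payloads are all assumptions for illustration only.

```python
import pickle
import sqlite3


class MiniPersistentDB:
    """Illustrative sketch of a persistent checkpoint store (not OpenFL's PersistentTensorDB)."""

    def __init__(self, path=":memory:"):
        self.conn = sqlite3.connect(path)
        self.conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS tensors (name TEXT PRIMARY KEY, value BLOB);
            CREATE TABLE IF NOT EXISTS task_results (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                collaborator TEXT, round INTEGER, task TEXT, payload BLOB);
            CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value INTEGER);
            """
        )

    def save_task_result(self, collaborator, round_number, task, payload):
        # Persist each incoming result immediately so it can be replayed after a crash.
        self.conn.execute(
            "INSERT INTO task_results (collaborator, round, task, payload) VALUES (?, ?, ?, ?)",
            (collaborator, round_number, task, pickle.dumps(payload)),
        )
        self.conn.commit()

    def finalize_round(self, tensors, round_number):
        # One transaction: store the committed model, record the round,
        # and delete task results that are now folded into the model.
        with self.conn:
            for name, value in tensors.items():
                self.conn.execute(
                    "INSERT OR REPLACE INTO tensors (name, value) VALUES (?, ?)",
                    (name, pickle.dumps(value)),
                )
            self.conn.execute(
                "INSERT OR REPLACE INTO meta (key, value) VALUES ('round', ?)",
                (round_number,),
            )
            self.conn.execute("DELETE FROM task_results")

    def recover(self):
        # Restore committed tensors, the committed round, and any unreplayed task results.
        tensors = {
            name: pickle.loads(blob)
            for name, blob in self.conn.execute("SELECT name, value FROM tensors")
        }
        row = self.conn.execute("SELECT value FROM meta WHERE key = 'round'").fetchone()
        committed_round = row[0] if row else -1
        pending = [
            (collab, rnd, task, pickle.loads(blob))
            for collab, rnd, task, blob in self.conn.execute(
                "SELECT collaborator, round, task, payload FROM task_results ORDER BY id"
            )
        ]
        return tensors, committed_round, pending


db = MiniPersistentDB()
db.finalize_round({"conv1.weight": [0.1, 0.2]}, round_number=0)
db.save_task_result("collaborator1", 1, "train", {"loss": 0.5})
tensors, committed_round, pending = db.recover()
# A restarted aggregator would resume at committed_round + 1
# and replay `pending` through something like process_task_results().
```

As in the diff, the in-progress round equals the committed round plus one, which is why recovery replays pending task results rather than restarting the round from scratch.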
