import openfl.callbacks as callbacks_module
from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy
- from openfl.databases import TensorDB
+ from openfl.databases import PersistentTensorDB, TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
+ from openfl.protocols.base_pb2 import NamedTensor
from openfl.utilities import TaskResultKey, TensorKey, change_tags

logger = logging.getLogger(__name__)
@@ -82,6 +83,8 @@ def __init__(
log_memory_usage=False,
write_logs=False,
callbacks: Optional[List] = None,
+ persist_checkpoint=True,
+ persistent_db_path=None,
):
"""Initializes the Aggregator.
@@ -109,6 +112,7 @@ def __init__(
callbacks: List of callbacks to be used during the experiment.
"""
self.round_number = 0
+ self.next_model_round_number = 0

if single_col_cert_common_name:
logger.warning(
@@ -135,6 +139,16 @@ def __init__(
self.quit_job_sent_to = []

self.tensor_db = TensorDB()
+ if persist_checkpoint:
+ persistent_db_path = persistent_db_path or "tensor.db"
+ logger.info(
+ "Persistent checkpoint is enabled, setting persistent db at path %s",
+ persistent_db_path,
+ )
+ self.persistent_db = PersistentTensorDB(persistent_db_path)
+ else:
+ logger.info("Persistent checkpoint is disabled")
+ self.persistent_db = None
# FIXME: I think next line generates an error on the second round
# if it is set to 1 for the aggregator.
self.db_store_rounds = db_store_rounds
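For readers unfamiliar with the new dependency: `PersistentTensorDB` gives the aggregator a tensor store that survives a process restart, in contrast to the in-memory `TensorDB`. The sketch below only illustrates that idea under assumed names and schema (SQLite plus pickled keys and arrays); it is not the actual `openfl.databases.PersistentTensorDB` implementation.

```python
# Hypothetical illustration of a restart-surviving tensor store; the real
# PersistentTensorDB in openfl.databases has its own schema and API.
import pickle
import sqlite3


class TinyPersistentStore:
    def __init__(self, path="tensor.db"):
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS tensors (key BLOB PRIMARY KEY, value BLOB)"
        )

    def save_tensor(self, tensor_key, nparray):
        # Pickle the TensorKey tuple and the numpy array so both can be reloaded later
        with self.conn:  # commits on success, rolls back on error
            self.conn.execute(
                "INSERT OR REPLACE INTO tensors (key, value) VALUES (?, ?)",
                (pickle.dumps(tensor_key), pickle.dumps(nparray)),
            )

    def load_tensors(self):
        # Returns {TensorKey: nparray}, mirroring what the recovery path expects
        rows = self.conn.execute("SELECT key, value FROM tensors").fetchall()
        return {pickle.loads(k): pickle.loads(v) for k, v in rows}
```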
@@ -152,8 +166,25 @@ def __init__(
# TODO: Remove. Used in deprecated interactive and native APIs
self.best_tensor_dict: dict = {}
self.last_tensor_dict: dict = {}
+ # these enable getting all tensors for a task
+ self.collaborator_tasks_results = {}  # {TaskResultKey: list of TensorKeys}
+ self.collaborator_task_weight = {}  # {TaskResultKey: data_size}

- if initial_tensor_dict:
+ # maintain a list of collaborators that have completed their tasks and
+ # reported results in a given round
+ self.collaborators_done = []
+ # Initialize a lock for thread safety
+ self.lock = Lock()
+ self.use_delta_updates = use_delta_updates
+
+ self.model = None  # Initialize the model attribute to None
+ if self.persistent_db and self._recover():
+ logger.info("Recovered the state of the aggregator")
+
+ # The model is built by recovery if at least one round has finished
+ if self.model:
+ logger.info("Model was loaded by recovery")
+ elif initial_tensor_dict:
self._load_initial_tensors_from_dict(initial_tensor_dict)
self.model = utils.construct_model_proto(
tensor_dict=initial_tensor_dict,
@@ -166,20 +197,6 @@ def __init__(

self.collaborator_tensor_results = {}  # {TensorKey: nparray}}

- # these enable getting all tensors for a task
- self.collaborator_tasks_results = {}  # {TaskResultKey: list of TensorKeys}
-
- self.collaborator_task_weight = {}  # {TaskResultKey: data_size}
-
- # maintain a list of collaborators that have completed task and
- # reported results in a given round
- self.collaborators_done = []
-
- # Initialize a lock for thread safety
- self.lock = Lock()
-
- self.use_delta_updates = use_delta_updates
-
# Callbacks
self.callbacks = callbacks_module.CallbackList(
callbacks,
@@ -193,6 +210,79 @@ def __init__(
self.callbacks.on_experiment_begin()
self.callbacks.on_round_begin(self.round_number)

+ def _recover(self):
+ """Restores the aggregator state to what it was prior to a restart."""
+ recovered = False
+ # load tensors from the persistent DB
+ tensor_key_dict = self.persistent_db.load_tensors(
+ self.persistent_db.get_tensors_table_name()
+ )
+ if len(tensor_key_dict) > 0:
+ logger.info(f"Recovering {len(tensor_key_dict)} model tensors")
+ recovered = True
+ self.tensor_db.cache_tensor(tensor_key_dict)
+ committed_round_number, self.best_model_score = (
+ self.persistent_db.get_round_and_best_score()
+ )
+ logger.info("Recovery - Setting model proto")
+ to_proto_tensor_dict = {}
+ for tk in tensor_key_dict:
+ tk_name, _, _, _, _ = tk
+ to_proto_tensor_dict[tk_name] = tensor_key_dict[tk]
+ self.model = utils.construct_model_proto(
+ to_proto_tensor_dict, committed_round_number, self.compression_pipeline
+ )
+ # round number is the current round, which is still in progress,
+ # i.e. committed_round_number + 1
+ self.round_number = committed_round_number + 1
+ logger.info(
+ "Recovery - loaded round number %s and best score %s",
+ self.round_number,
+ self.best_model_score,
+ )
+
+ next_round_tensor_key_dict = self.persistent_db.load_tensors(
+ self.persistent_db.get_next_round_tensors_table_name()
+ )
+ if len(next_round_tensor_key_dict) > 0:
+ logger.info(f"Recovering {len(next_round_tensor_key_dict)} next round model tensors")
+ recovered = True
+ self.tensor_db.cache_tensor(next_round_tensor_key_dict)
+
+ logger.debug("Recovery - this is the tensor_db after recovery: %s", self.tensor_db)
+
+ if self.persistent_db.is_task_table_empty():
+ logger.debug("Task table is empty")
+ return recovered
+
+ logger.info("Recovery - Replaying saved task results")
+ task_id = 1
+ while True:
+ task_result = self.persistent_db.get_task_result_by_id(task_id)
+ if not task_result:
+ break
+ recovered = True
+ collaborator_name = task_result["collaborator_name"]
+ round_number = task_result["round_number"]
+ task_name = task_result["task_name"]
+ data_size = task_result["data_size"]
+ serialized_tensors = task_result["named_tensors"]
+ named_tensors = [
+ NamedTensor.FromString(serialized_tensor)
+ for serialized_tensor in serialized_tensors
+ ]
+ logger.info(
+ "Recovery - Replaying task results %s %s %s",
+ collaborator_name,
+ round_number,
+ task_name,
+ )
+ self.process_task_results(
+ collaborator_name, round_number, task_name, data_size, named_tensors
+ )
+ task_id += 1
+ return recovered
+
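Replay works because each collaborator's `NamedTensor` protos are stored as raw bytes and rebuilt unchanged on restart. A minimal sketch of that round-trip, using only the standard protobuf `SerializeToString`/`FromString` calls the diff relies on (the helper names here are illustrative, not part of the PR):

```python
# Illustrative helpers: protobuf messages serialize to bytes, which a
# persistent store can keep as BLOBs and hand back verbatim on recovery.
from openfl.protocols.base_pb2 import NamedTensor


def serialize_for_storage(named_tensors):
    # One bytes object per NamedTensor, suitable for persisting
    return [tensor.SerializeToString() for tensor in named_tensors]


def deserialize_from_storage(serialized_tensors):
    # FromString reconstructs an identical NamedTensor from the stored bytes
    return [NamedTensor.FromString(blob) for blob in serialized_tensors]
```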
def _load_initial_tensors(self):
"""Load all of the tensors required to begin federated learning.
@@ -253,9 +343,12 @@ def _save_model(self, round_number, file_path):
for k, v in og_tensor_dict.items()
]
tensor_dict = {}
+ tensor_tuple_dict = {}
for tk in tensor_keys:
tk_name, _, _, _, _ = tk
- tensor_dict[tk_name] = self.tensor_db.get_tensor_from_cache(tk)
+ tensor_value = self.tensor_db.get_tensor_from_cache(tk)
+ tensor_dict[tk_name] = tensor_value
+ tensor_tuple_dict[tk] = tensor_value
if tensor_dict[tk_name] is None:
logger.info(
"Cannot save model for round %s. Continuing...",
@@ -265,6 +358,19 @@ def _save_model(self, round_number, file_path):
if file_path == self.best_state_path:
self.best_tensor_dict = tensor_dict
if file_path == self.last_state_path:
+ # Transaction to persist/delete all data needed to increment the round
+ if self.persistent_db:
+ if self.next_model_round_number > 0:
+ next_round_tensors = self.tensor_db.get_tensors_by_round_and_tags(
+ self.next_model_round_number, ("model",)
+ )
+ self.persistent_db.finalize_round(
+ tensor_tuple_dict, next_round_tensors, self.round_number, self.best_model_score
+ )
+ logger.info(
+ "Persisted model and cleaned task results for round %s",
+ round_number,
+ )
self.last_tensor_dict = tensor_dict
self.model = utils.construct_model_proto(
tensor_dict, round_number, self.compression_pipeline
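`finalize_round` is the crash-consistency point: the committed model tensors, next-round tensors, round number, best score, and the purge of the replayed task log have to land together or not at all. A minimal sketch of that all-or-nothing behaviour, assuming a SQLite backend with hypothetical table and column names (the real method defines its own schema):

```python
# Hypothetical sketch of an atomic round commit, not the actual finalize_round.
import sqlite3


def finalize_round_sketch(conn: sqlite3.Connection, round_number: int, best_score: float):
    # The connection context manager commits on success and rolls back on any
    # exception, so a crash mid-way leaves the previous round's state intact.
    with conn:
        conn.execute("DELETE FROM task_results")  # replay log no longer needed
        conn.execute(
            "INSERT OR REPLACE INTO key_value (key, value) VALUES ('round_number', ?)",
            (round_number,),
        )
        conn.execute(
            "INSERT OR REPLACE INTO key_value (key, value) VALUES ('best_score', ?)",
            (best_score,),
        )
```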
@@ -364,7 +470,7 @@ def get_tasks(self, collaborator_name):
# if no tasks, tell the collaborator to sleep
if len(tasks) == 0:
tasks = None
- sleep_time = self._get_sleep_time()
+ sleep_time = Aggregator._get_sleep_time()

return tasks, self.round_number, sleep_time, time_to_quit

@@ -394,7 +500,7 @@ def get_tasks(self, collaborator_name):
# been completed
if len(tasks) == 0:
tasks = None
- sleep_time = self._get_sleep_time()
+ sleep_time = Aggregator._get_sleep_time()

return tasks, self.round_number, sleep_time, time_to_quit
@@ -604,6 +710,31 @@ def send_local_task_results(
Returns:
None
"""
+ # Save the task results and their metadata for recovery
+ serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors]
+ if self.persistent_db:
+ self.persistent_db.save_task_results(
+ collaborator_name, round_number, task_name, data_size, serialized_tensors
+ )
+ logger.debug(
+ f"Persisting task results {task_name} from {collaborator_name} round {round_number}"
+ )
+ logger.info(
+ f"Collaborator {collaborator_name} is sending task results "
+ f"for {task_name}, round {round_number}"
+ )
+ self.process_task_results(
+ collaborator_name, round_number, task_name, data_size, named_tensors
+ )
+
+ def process_task_results(
+ self,
+ collaborator_name,
+ round_number,
+ task_name,
+ data_size,
+ named_tensors,
+ ):
if self._time_to_quit() or collaborator_name in self.stragglers:
logger.warning(
f"STRAGGLER: Collaborator {collaborator_name} is reporting results "
@@ -618,11 +749,6 @@ def send_local_task_results(
)
return

- logger.info(
- f"Collaborator {collaborator_name} is sending task results "
- f"for {task_name}, round {round_number}"
- )
-
task_key = TaskResultKey(task_name, collaborator_name, round_number)

# we mustn't have results already
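The replay in `_recover` leans on results being stored in arrival order under monotonically increasing ids, so fetching id 1, 2, 3, ... until a miss reproduces the original sequence. A minimal sketch of such an append-only task log with assumed table and column names (the real `save_task_results`/`get_task_result_by_id` live in `PersistentTensorDB` and also store the serialized tensors):

```python
# Hypothetical append-only replay log; illustrates the id-ordered storage the
# recovery loop assumes, not the actual PersistentTensorDB schema.
import sqlite3


def init_log(conn: sqlite3.Connection):
    conn.execute(
        "CREATE TABLE IF NOT EXISTS task_results ("
        "id INTEGER PRIMARY KEY AUTOINCREMENT, "
        "collaborator_name TEXT, round_number INTEGER, "
        "task_name TEXT, data_size INTEGER)"
    )


def append_task_result(conn, collaborator_name, round_number, task_name, data_size):
    # Each saved result gets the next id, preserving arrival order
    with conn:
        conn.execute(
            "INSERT INTO task_results "
            "(collaborator_name, round_number, task_name, data_size) VALUES (?, ?, ?, ?)",
            (collaborator_name, round_number, task_name, data_size),
        )


def replay(conn):
    # Yield results in insertion order, as _recover does by walking task_id upward
    for row in conn.execute("SELECT * FROM task_results ORDER BY id"):
        yield row
```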
@@ -862,7 +988,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
new_model_report,
("model",),
)
-
+ self.next_model_round_number = new_model_round_number
# Finally, cache the updated model tensor
self.tensor_db.cache_tensor({final_model_tk: new_model_nparray})