Skip to content

Commit a66b5ac

Browse files
authored
Merge branch 'securefederatedai:develop' into straggler_handling_update
2 parents 2aa0b22 + 3375609 commit a66b5ac

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

openfl/component/collaborator/collaborator.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -382,15 +382,17 @@ def get_data_for_tensorkey(self, tensor_key):
382382
return nparray
383383
prior_round -= 1
384384
logger.info(f"Cannot find any prior version of tensor {tensor_name} locally...")
385-
logger.debug(
386-
"Unable to get tensor from local store..." "attempting to retrieve from client"
387-
)
388385
# Determine whether there are additional compression related
389386
# dependencies.
390387
# Typically, dependencies are only relevant to model layers
391388
tensor_dependencies = self.tensor_codec.find_dependencies(
392389
tensor_key, self.delta_updates
393390
)
391+
logger.debug(
392+
"Unable to get tensor from local store..."
393+
"attempting to retrieve from client len tensor_dependencies"
394+
f" tensor_key {tensor_key}"
395+
)
394396
if len(tensor_dependencies) > 0:
395397
# Resolve dependencies
396398
# tensor_dependencies[0] corresponds to the prior version
@@ -411,10 +413,10 @@ def get_data_for_tensorkey(self, tensor_key):
411413
self.tensor_db.cache_tensor({new_model_tk: nparray})
412414
else:
413415
logger.info(
414-
"Count not find previous model layer."
416+
"Could not find previous model layer."
415417
"Fetching latest layer from aggregator"
416418
)
417-
# The original model tensor should be fetched from client
419+
# The original model tensor should be fetched from aggregator
418420
nparray = self.get_aggregated_tensor_from_aggregator(
419421
tensor_key, require_lossless=True
420422
)
@@ -423,6 +425,18 @@ def get_data_for_tensorkey(self, tensor_key):
423425
nparray = self.get_aggregated_tensor_from_aggregator(
424426
tensor_key, require_lossless=True
425427
)
428+
else:
429+
# we should try fetching the tensor from aggregator
430+
tensor_name, origin, round_number, report, tags = tensor_key
431+
tags = (self.collaborator_name,) + tags
432+
tensor_key = (tensor_name, origin, round_number, report, tags)
433+
logger.info(
434+
"Could not find previous model layer."
435+
f"Fetching latest layer from aggregator {tensor_key}"
436+
)
437+
nparray = self.get_aggregated_tensor_from_aggregator(
438+
tensor_key, require_lossless=True
439+
)
426440
else:
427441
logger.debug("Found tensor %s in local TensorDB", tensor_key)
428442

openfl/federated/task/runner_keras.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -182,7 +182,16 @@ def train_(self, batch_generator, metrics: list = None, **kwargs):
182182
# initialization (build_model).
183183
# If metrics are added (i.e. not a subset of what was originally
184184
# defined) then the model must be recompiled.
185-
results = self.model.get_metrics_result()
185+
try:
186+
results = self.model.get_metrics_result()
187+
except ValueError:
188+
if "batch_size" in kwargs:
189+
batch_size = kwargs["batch_size"]
190+
else:
191+
batch_size = 1
192+
# evaluation needed before metrics can be resolved
193+
self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1)
194+
results = self.model.get_metrics_result()
186195

187196
# TODO if there are new metrics in the flplan that were not included
188197
# in the originally

0 commit comments

Comments
 (0)