@@ -382,15 +382,17 @@ def get_data_for_tensorkey(self, tensor_key):
382
382
return nparray
383
383
prior_round -= 1
384
384
logger .info (f"Cannot find any prior version of tensor { tensor_name } locally..." )
385
- logger .debug (
386
- "Unable to get tensor from local store..." "attempting to retrieve from client"
387
- )
388
385
# Determine whether there are additional compression related
389
386
# dependencies.
390
387
# Typically, dependencies are only relevant to model layers
391
388
tensor_dependencies = self .tensor_codec .find_dependencies (
392
389
tensor_key , self .delta_updates
393
390
)
391
+ logger .debug (
392
+ "Unable to get tensor from local store..."
393
+ "attempting to retrieve from client len tensor_dependencies"
394
+ f" tensor_key { tensor_key } "
395
+ )
394
396
if len (tensor_dependencies ) > 0 :
395
397
# Resolve dependencies
396
398
# tensor_dependencies[0] corresponds to the prior version
@@ -411,10 +413,10 @@ def get_data_for_tensorkey(self, tensor_key):
411
413
self .tensor_db .cache_tensor ({new_model_tk : nparray })
412
414
else :
413
415
logger .info (
414
- "Count not find previous model layer."
416
+ "Could not find previous model layer."
415
417
"Fetching latest layer from aggregator"
416
418
)
417
- # The original model tensor should be fetched from client
419
+ # The original model tensor should be fetched from aggregator
418
420
nparray = self .get_aggregated_tensor_from_aggregator (
419
421
tensor_key , require_lossless = True
420
422
)
@@ -423,6 +425,18 @@ def get_data_for_tensorkey(self, tensor_key):
423
425
nparray = self .get_aggregated_tensor_from_aggregator (
424
426
tensor_key , require_lossless = True
425
427
)
428
+ else :
429
+ # we should try fetching the tensor from aggregator
430
+ tensor_name , origin , round_number , report , tags = tensor_key
431
+ tags = (self .collaborator_name ,) + tags
432
+ tensor_key = (tensor_name , origin , round_number , report , tags )
433
+ logger .info (
434
+ "Could not find previous model layer."
435
+ f"Fetching latest layer from aggregator { tensor_key } "
436
+ )
437
+ nparray = self .get_aggregated_tensor_from_aggregator (
438
+ tensor_key , require_lossless = True
439
+ )
426
440
else :
427
441
logger .debug ("Found tensor %s in local TensorDB" , tensor_key )
428
442
0 commit comments