@@ -35,28 +35,34 @@ def aggregate(self, agg_info):
     def _aggre_with_normbounding(self, models):
         models_temp = []
         for each_model in models:
-            param = self._flatten_updates(each_model[1])
+            param, ignore_keys = self._flatten_updates(each_model[1])
             if torch.norm(param, p=2) > self.norm_bound:
                 scaling_rate = self.norm_bound / torch.norm(param, p=2)
                 scaled_param = scaling_rate * param
                 models_temp.append(
-                    (each_model[0], self._reconstruct_updates(scaled_param)))
+                    (each_model[0],
+                     self._reconstruct_updates(scaled_param, ignore_keys)))
             else:
                 models_temp.append(each_model)
         return self._para_weighted_avg(models_temp)

     def _flatten_updates(self, model):
-        model_update = []
+        model_update, ignore_keys = [], []
         init_model = self.model.state_dict()
         for key in init_model:
+            if key not in model:
+                ignore_keys.append(key)
+                continue
             model_update.append(model[key].view(-1))
-        return torch.cat(model_update, dim=0)
+        return torch.cat(model_update, dim=0), ignore_keys

-    def _reconstruct_updates(self, flatten_updates):
+    def _reconstruct_updates(self, flatten_updates, ignore_keys):
         start_idx = 0
         init_model = self.model.state_dict()
         reconstructed_model = copy.deepcopy(init_model)
         for key in init_model:
+            if key in ignore_keys:
+                continue
             reconstructed_model[key] = flatten_updates[
                 start_idx:start_idx + len(init_model[key].view(-1))].reshape(
                     init_model[key].shape)
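
For reference, a minimal, self-contained sketch of what the patched aggregator does end to end: flatten a client's (possibly partial) state-dict update, clip it to an L2-norm bound, and rebuild the state dict while skipping the keys the client did not send. The helper name clip_update, the norm_bound value, and the toy nn.Linear model below are illustrative assumptions, not part of the project's API.

import copy
import torch
import torch.nn as nn

def clip_update(global_model, client_update, norm_bound=1.0):
    # Hypothetical standalone version of _flatten_updates / _reconstruct_updates.
    init_state = global_model.state_dict()

    # Flatten only the keys the client actually sent; remember the rest.
    flat, ignore_keys = [], []
    for key in init_state:
        if key not in client_update:
            ignore_keys.append(key)
            continue
        flat.append(client_update[key].view(-1))
    flat = torch.cat(flat, dim=0)

    # Scale the whole update down if its L2 norm exceeds the bound.
    norm = torch.norm(flat, p=2)
    if norm > norm_bound:
        flat = (norm_bound / norm) * flat

    # Rebuild a state dict; ignored keys keep the global model's values.
    rebuilt = copy.deepcopy(init_state)
    start = 0
    for key in init_state:
        if key in ignore_keys:
            continue
        numel = init_state[key].numel()
        rebuilt[key] = flat[start:start + numel].reshape(init_state[key].shape)
        start += numel
    return rebuilt

# Usage: clip a partial update (only 'weight') against a small linear model.
model = nn.Linear(4, 2)
update = {k: torch.randn_like(v)
          for k, v in model.state_dict().items() if k == 'weight'}
clipped = clip_update(model, update, norm_bound=0.5)
print(torch.norm(clipped['weight'].view(-1), p=2))  # <= 0.5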