from functools import partial
from math import sqrt
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union

from einops import einsum, rearrange, reduce
from numpy import ndarray
@@ -211,10 +211,6 @@ def __init__(
                f"Supported: {self._SUPPORTED_KFAC_APPROX}."
            )

-        self.param_ids, self.param_ids_to_hooked_modules = (
-            self.parameter_to_module_mapping(params, model_func)
-        )
-
        self._seed = seed
        self._generator: Union[None, Generator] = None
        self._separate_weight_and_bias = separate_weight_and_bias
@@ -224,6 +220,7 @@ def __init__(
        self._loss_average = loss_average
        self._input_covariances: Dict[str, Tensor] = {}
        self._gradient_covariances: Dict[str, Tensor] = {}
+        self._mapping = self.compute_parameter_mapping(params, model_func)

        super().__init__(
            model_func,
@@ -244,75 +241,47 @@ def _matmat(self, M: ndarray) -> ndarray:

        Returns:
            Matrix-multiplication result ``KFAC @ M``. Has shape ``[D, K]``.
-
-        Raises:
-            RuntimeError: If the incoming matrix was not fully processed, indicating an
-                error due to the internal mapping from parameters to modules.
        """
-        # Need to update parameter mapping if they have changed (e.g. device
-        # transfer), and reset caches
-        if self.param_ids != [p.data_ptr() for p in self._params]:
-            print("Invalidated parameter mapping detected")
-            self.param_ids, self.param_ids_to_hooked_modules = (
-                self.parameter_to_module_mapping(self._params, self._model_func)
-            )
-            self._input_covariances, self._gradient_covariances = {}, {}
-
        if not self._input_covariances and not self._gradient_covariances:
            self._compute_kfac()

        M_torch = super()._preprocess(M)
-        processed = set()
-
-        for name in self.param_ids_to_hooked_modules.values():
-            mod = self._model_func.get_submodule(name)

+        for mod_name, param_pos in self._mapping.items():
            # bias and weights are treated jointly
-            if not self._separate_weight_and_bias and self.in_params(
-                mod.weight, mod.bias
+            if (
+                not self._separate_weight_and_bias
+                and "weight" in param_pos.keys()
+                and "bias" in param_pos.keys()
            ):
-                w_pos, b_pos = self.param_pos(mod.weight), self.param_pos(mod.bias)
+                w_pos, b_pos = param_pos["weight"], param_pos["bias"]
                # v denotes the free dimension for treating multiple vectors in parallel
                M_w = rearrange(M_torch[w_pos], "v c_out ... -> v c_out (...)")
                M_joint = cat([M_w, M_torch[b_pos].unsqueeze(-1)], dim=2)
-                aaT = self._input_covariances[name]
-                ggT = self._gradient_covariances[name]
+                aaT = self._input_covariances[mod_name]
+                ggT = self._gradient_covariances[mod_name]
                M_joint = einsum(ggT, M_joint, aaT, "i j,v j k,k l -> v i l")

                w_cols = M_w.shape[2]
                M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2)
-                processed.update([w_pos, b_pos])

            # for weights we need to multiply from the right with aaT
            # for weights and biases we need to multiply from the left with ggT
            else:
-                for p_name in ["weight", "bias"]:
-                    p = getattr(mod, p_name)
-                    if self.in_params(p):
-                        pos = self.param_pos(p)
-
-                        if p_name == "weight":
-                            M_w = rearrange(
-                                M_torch[pos], "v c_out ... -> v c_out (...)"
-                            )
-                            M_torch[pos] = einsum(
-                                M_w,
-                                self._input_covariances[name],
-                                "v c_out j,j k -> v c_out k",
-                            )
-
+                for p_name, pos in param_pos.items():
+                    if p_name == "weight":
+                        M_w = rearrange(M_torch[pos], "v c_out ... -> v c_out (...)")
                        M_torch[pos] = einsum(
-                            self._gradient_covariances[name],
-                            M_torch[pos],
-                            "j k,v k ... -> v j ...",
+                            M_w,
+                            self._input_covariances[mod_name],
+                            "v c_out j,j k -> v c_out k",
                        )
-                        processed.add(pos)

-        if processed != set(range(len(M_torch))):
-            raise RuntimeError(
-                "Some entries of the matrix were not modified."
-                + f" Out of {len(M_torch)}, the following entries were processed: {processed}."
-            )
+                    M_torch[pos] = einsum(
+                        self._gradient_covariances[mod_name],
+                        M_torch[pos],
+                        "j k,v k ... -> v j ...",
+                    )

        return self._postprocess(M_torch)
@@ -331,21 +300,26 @@ def _compute_kfac(self):
        # install forward and backward hooks
        hook_handles: List[RemovableHandle] = []

-        for name in self.param_ids_to_hooked_modules.values():
-            module = self._model_func.get_submodule(name)
+        for mod_name, param_pos in self._mapping.items():
+            module = self._model_func.get_submodule(mod_name)

            # input covariance only required for weights
-            if self.in_params(module.weight):
+            if "weight" in param_pos.keys():
                hook_handles.append(
                    module.register_forward_pre_hook(
-                        self._hook_accumulate_input_covariance
+                        partial(
+                            self._hook_accumulate_input_covariance, module_name=mod_name
+                        )
                    )
                )

            # gradient covariance required for weights and biases
            hook_handles.append(
                module.register_forward_hook(
-                    self._register_tensor_hook_on_output_to_accumulate_gradient_covariance
+                    partial(
+                        self._register_tensor_hook_on_output_to_accumulate_gradient_covariance,
+                        module_name=mod_name,
+                    )
                )
            )
@@ -471,7 +445,7 @@ def draw_label(self, output: Tensor) -> Tensor:
        raise NotImplementedError

    def _register_tensor_hook_on_output_to_accumulate_gradient_covariance(
-        self, module: Module, inputs: Tuple[Tensor], output: Tensor
+        self, module: Module, inputs: Tuple[Tensor], output: Tensor, module_name: str
    ):
        """Register tensor hook on layer's output to accumulate the grad. covariance.
@@ -491,18 +465,24 @@ def _register_tensor_hook_on_output_to_accumulate_gradient_covariance(
                covariance will be installed.
            inputs: The layer's input tensors.
            output: The layer's output tensor.
+            module_name: The name of the layer in the neural network.
        """
-        tensor_hook = partial(self._accumulate_gradient_covariance, module)
+        tensor_hook = partial(
+            self._accumulate_gradient_covariance, module=module, module_name=module_name
+        )
        output.register_hook(tensor_hook)

-    def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
+    def _accumulate_gradient_covariance(
+        self, grad_output: Tensor, module: Module, module_name: str
+    ):
        """Accumulate the gradient covariance for a layer's output.

        Updates ``self._gradient_covariances``.

        Args:
-            module: The layer whose output's gradient covariance will be accumulated.
            grad_output: The gradient w.r.t. the output.
+            module: The layer whose output's gradient covariance will be accumulated.
+            module_name: The name of the layer in the neural network.
        """
        g = grad_output.data.detach()
        batch_size = g.shape[0]
@@ -531,20 +511,22 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
        }[self._loss_average]
        covariance = einsum(g, g, "b i,b j->i j").mul_(correction)

-        name = self.get_module_name(module)
-        if name not in self._gradient_covariances:
-            self._gradient_covariances[name] = covariance
+        if module_name not in self._gradient_covariances:
+            self._gradient_covariances[module_name] = covariance
        else:
-            self._gradient_covariances[name].add_(covariance)
+            self._gradient_covariances[module_name].add_(covariance)

-    def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor]):
+    def _hook_accumulate_input_covariance(
+        self, module: Module, inputs: Tuple[Tensor], module_name: str
+    ):
        """Pre-forward hook that accumulates the input covariance of a layer.

        Updates ``self._input_covariances``.

        Args:
            module: Module on which the hook is called.
            inputs: Inputs to the module.
+            module_name: Name of the module in the neural network.

        Raises:
            ValueError: If the module has multiple inputs.
@@ -576,88 +558,58 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
            scale = 1.0  # since we use a mean reduction
            x = reduce(x, "batch ... d_in -> batch d_in", "mean")

+        params = self._mapping[module_name]
        if (
-            self.in_params(module.weight, module.bias)
+            "weight" in params.keys()
+            and "bias" in params.keys()
            and not self._separate_weight_and_bias
        ):
            x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

        covariance = einsum(x, x, "b i,b j -> i j").div_(self._N_data * scale)

-        name = self.get_module_name(module)
-        if name not in self._input_covariances:
-            self._input_covariances[name] = covariance
+        if module_name not in self._input_covariances:
+            self._input_covariances[module_name] = covariance
        else:
-            self._input_covariances[name].add_(covariance)
-
-    def get_module_name(self, module: Module) -> str:
-        """Get the name of a module.
-
-        Args:
-            module: The module.
-
-        Returns:
-            The name of the module.
-        """
-        p_ids = tuple(p.data_ptr() for p in module.parameters())
-        return self.param_ids_to_hooked_modules[p_ids]
-
-    def in_params(self, *params: Union[Parameter, Tensor, None]) -> bool:
-        """Check if all parameters are used in KFAC.
-
-        Args:
-            params: Parameters to check.
-
-        Returns:
-            Whether all parameters are used in KFAC.
-        """
-        return all(p is not None and p.data_ptr() in self.param_ids for p in params)
-
-    def param_pos(self, param: Union[Parameter, Tensor]) -> int:
-        """Get the position of a parameter in the list of parameters used in KFAC.
-
-        Args:
-            param: The parameter.
-
-        Returns:
-            The parameter's position in the parameter list.
-        """
-        return self.param_ids.index(param.data_ptr())
+            self._input_covariances[module_name].add_(covariance)

    @classmethod
-    def parameter_to_module_mapping(
-        cls, params: List[Tensor], model_func: Module
-    ) -> Tuple[List[int], Dict[Tuple[int, ...], str]]:
-        """Construct the mapping between parameters and modules.
+    def compute_parameter_mapping(
+        cls, params: List[Union[Tensor, Parameter]], model_func: Module
+    ) -> Dict[str, Dict[str, int]]:
+        """Construct the mapping between layers, their parameters, and positions.

        Args:
            params: List of parameters.
            model_func: The model function.

        Returns:
-            A tuple containing:
-            - A list of parameter data pointers.
-            - A dictionary mapping from tuples of parameter data pointers in a module
-              to its name.
+            A dictionary of dictionaries. The outer dictionary's keys are the names of
+            the layers that contain parameters. The interior dictionary's keys are the
+            parameter names, and the values their respective positions.

        Raises:
            NotImplementedError: If parameters are found outside supported layers.
        """
        param_ids = [p.data_ptr() for p in params]
-        # mapping from tuples of parameter data pointers in a module to its name
-        param_ids_to_hooked_modules: Dict[Tuple[int, ...], str] = {}
+        positions = {}
+        processed = set()

-        hooked_param_ids: Set[int] = set()
-        for name, mod in model_func.named_modules():
-            p_ids = tuple(p.data_ptr() for p in mod.parameters())
+        for mod_name, mod in model_func.named_modules():
            if isinstance(mod, cls._SUPPORTED_MODULES) and any(
-                p_id in param_ids for p_id in p_ids
+                p.data_ptr() in param_ids for p in mod.parameters()
            ):
-                param_ids_to_hooked_modules[p_ids] = name
-                hooked_param_ids.update(set(p_ids))
-
-        # check that all parameters are in hooked modules
-        if not set(param_ids).issubset(hooked_param_ids):
-            raise NotImplementedError("Found parameters outside supported layers.")
-
-        return param_ids, param_ids_to_hooked_modules
+                param_positions = {}
+                for p_name, p in mod.named_parameters():
+                    p_id = p.data_ptr()
+                    if p_id in param_ids:
+                        pos = param_ids.index(p_id)
+                        param_positions[p_name] = pos
+                        processed.add(p_id)
+                positions[mod_name] = param_positions

+        # check that all parameters are in known modules
+        if len(processed) != len(param_ids):
+            raise NotImplementedError("Found parameters in un-supported layers.")
+
+        return positions
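For illustration, a minimal, self-contained sketch of the mapping this patch introduces, assuming a toy Sequential model; the standalone compute_parameter_mapping and SUPPORTED_MODULES below are stand-ins for the new classmethod and cls._SUPPORTED_MODULES, not the library API:

from typing import Dict, List, Union

from torch import Tensor
from torch.nn import Linear, Module, Parameter, Sequential

SUPPORTED_MODULES = (Linear,)  # stand-in for cls._SUPPORTED_MODULES


def compute_parameter_mapping(
    params: List[Union[Tensor, Parameter]], model_func: Module
) -> Dict[str, Dict[str, int]]:
    """Map layer names to ``{parameter name: position in params}`` (as in the patch)."""
    param_ids = [p.data_ptr() for p in params]
    positions, processed = {}, set()

    for mod_name, mod in model_func.named_modules():
        # only supported layers that own at least one of the requested parameters
        if isinstance(mod, SUPPORTED_MODULES) and any(
            p.data_ptr() in param_ids for p in mod.parameters()
        ):
            param_positions = {}
            for p_name, p in mod.named_parameters():
                p_id = p.data_ptr()
                if p_id in param_ids:
                    param_positions[p_name] = param_ids.index(p_id)
                    processed.add(p_id)
            positions[mod_name] = param_positions

    # every requested parameter must live in a supported layer
    if len(processed) != len(param_ids):
        raise NotImplementedError("Found parameters in un-supported layers.")

    return positions


model = Sequential(Linear(4, 3), Linear(3, 2))
params = list(model.parameters())
# e.g. {'0': {'weight': 0, 'bias': 1}, '1': {'weight': 2, 'bias': 3}}
print(compute_parameter_mapping(params, model))

The resulting {layer name: {parameter name: position}} dictionary is what _matmat and _compute_kfac now iterate over, replacing the old data-pointer bookkeeping (param_ids, in_params, param_pos, get_module_name).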