
Commit 37c3f9b

[REF] Use module names for internal mapping to parameter positions
1 parent 494ec00 commit 37c3f9b

File tree: curvlinops/kfac.py, test/test_kfac.py

2 files changed: +80 -131 lines

curvlinops/kfac.py

Lines changed: 79 additions & 127 deletions
@@ -20,7 +20,7 @@
 
 from functools import partial
 from math import sqrt
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from einops import einsum, rearrange, reduce
 from numpy import ndarray
@@ -211,10 +211,6 @@ def __init__(
                 f"Supported: {self._SUPPORTED_KFAC_APPROX}."
             )
 
-        self.param_ids, self.param_ids_to_hooked_modules = (
-            self.parameter_to_module_mapping(params, model_func)
-        )
-
         self._seed = seed
         self._generator: Union[None, Generator] = None
         self._separate_weight_and_bias = separate_weight_and_bias
@@ -224,6 +220,7 @@ def __init__(
         self._loss_average = loss_average
         self._input_covariances: Dict[str, Tensor] = {}
         self._gradient_covariances: Dict[str, Tensor] = {}
+        self._mapping = self.compute_parameter_mapping(params, model_func)
 
         super().__init__(
             model_func,
@@ -244,75 +241,47 @@ def _matmat(self, M: ndarray) -> ndarray:
 
         Returns:
             Matrix-multiplication result ``KFAC @ M``. Has shape ``[D, K]``.
-
-        Raises:
-            RuntimeError: If the incoming matrix was not fully processed, indicating an
-                error due to the internal mapping from parameters to modules.
         """
-        # Need to update parameter mapping if they have changed (e.g. device
-        # transfer), and reset caches
-        if self.param_ids != [p.data_ptr() for p in self._params]:
-            print("Invalidated parameter mapping detected")
-            self.param_ids, self.param_ids_to_hooked_modules = (
-                self.parameter_to_module_mapping(self._params, self._model_func)
-            )
-            self._input_covariances, self._gradient_covariances = {}, {}
-
         if not self._input_covariances and not self._gradient_covariances:
             self._compute_kfac()
 
         M_torch = super()._preprocess(M)
-        processed = set()
-
-        for name in self.param_ids_to_hooked_modules.values():
-            mod = self._model_func.get_submodule(name)
 
+        for mod_name, param_pos in self._mapping.items():
             # bias and weights are treated jointly
-            if not self._separate_weight_and_bias and self.in_params(
-                mod.weight, mod.bias
+            if (
+                not self._separate_weight_and_bias
+                and "weight" in param_pos.keys()
+                and "bias" in param_pos.keys()
             ):
-                w_pos, b_pos = self.param_pos(mod.weight), self.param_pos(mod.bias)
+                w_pos, b_pos = param_pos["weight"], param_pos["bias"]
                 # v denotes the free dimension for treating multiple vectors in parallel
                 M_w = rearrange(M_torch[w_pos], "v c_out ... -> v c_out (...)")
                 M_joint = cat([M_w, M_torch[b_pos].unsqueeze(-1)], dim=2)
-                aaT = self._input_covariances[name]
-                ggT = self._gradient_covariances[name]
+                aaT = self._input_covariances[mod_name]
+                ggT = self._gradient_covariances[mod_name]
                 M_joint = einsum(ggT, M_joint, aaT, "i j,v j k,k l -> v i l")
 
                 w_cols = M_w.shape[2]
                 M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2)
-                processed.update([w_pos, b_pos])
 
             # for weights we need to multiply from the right with aaT
             # for weights and biases we need to multiply from the left with ggT
             else:
-                for p_name in ["weight", "bias"]:
-                    p = getattr(mod, p_name)
-                    if self.in_params(p):
-                        pos = self.param_pos(p)
-
-                        if p_name == "weight":
-                            M_w = rearrange(
-                                M_torch[pos], "v c_out ... -> v c_out (...)"
-                            )
-                            M_torch[pos] = einsum(
-                                M_w,
-                                self._input_covariances[name],
-                                "v c_out j,j k -> v c_out k",
-                            )
-
+                for p_name, pos in param_pos.items():
+                    if p_name == "weight":
+                        M_w = rearrange(M_torch[pos], "v c_out ... -> v c_out (...)")
                         M_torch[pos] = einsum(
-                            self._gradient_covariances[name],
-                            M_torch[pos],
-                            "j k,v k ... -> v j ...",
+                            M_w,
+                            self._input_covariances[mod_name],
+                            "v c_out j,j k -> v c_out k",
                         )
-                        processed.add(pos)
 
-        if processed != set(range(len(M_torch))):
-            raise RuntimeError(
-                "Some entries of the matrix were not modified."
-                + f" Out of {len(M_torch)}, the following entries were processed: {processed}."
-            )
+                    M_torch[pos] = einsum(
+                        self._gradient_covariances[mod_name],
+                        M_torch[pos],
+                        "j k,v k ... -> v j ...",
+                    )
 
         return self._postprocess(M_torch)
 
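In the joint weight-and-bias branch above, the contraction "i j,v j k,k l -> v i l" applies the Kronecker-factored block ggT ⊗ aaT of one layer to each matricized vector. A minimal numeric sketch of that identity (tensor names and sizes invented for illustration, not taken from the commit):

from einops import einsum, rearrange
from torch import allclose, kron, manual_seed, rand

manual_seed(0)
d_out, d_in, num_vecs = 3, 4, 2

# symmetric stand-ins for one layer's gradient and input covariances
G, A = rand(d_out, d_out), rand(d_in, d_in)
ggT, aaT = G @ G.T, A @ A.T
V = rand(num_vecs, d_out, d_in)  # "v" indexes the vectors multiplied in parallel

# the per-layer contraction used in _matmat
left = einsum(ggT, V, aaT, "i j,v j k,k l -> v i l")

# same result via the explicit Kronecker-factored block
flat = rearrange(V, "v i j -> (i j) v")
right = rearrange(kron(ggT, aaT) @ flat, "(i j) v -> v i j", i=d_out)

print(allclose(left, right, atol=1e-5))  # True

The check relies on both covariance factors being symmetric, so multiplying by aaT from the right matches multiplying the flattened vector by kron(ggT, aaT).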

@@ -331,21 +300,26 @@ def _compute_kfac(self):
         # install forward and backward hooks
         hook_handles: List[RemovableHandle] = []
 
-        for name in self.param_ids_to_hooked_modules.values():
-            module = self._model_func.get_submodule(name)
+        for mod_name, param_pos in self._mapping.items():
+            module = self._model_func.get_submodule(mod_name)
 
             # input covariance only required for weights
-            if self.in_params(module.weight):
+            if "weight" in param_pos.keys():
                 hook_handles.append(
                     module.register_forward_pre_hook(
-                        self._hook_accumulate_input_covariance
+                        partial(
+                            self._hook_accumulate_input_covariance, module_name=mod_name
+                        )
                     )
                 )
 
             # gradient covariance required for weights and biases
             hook_handles.append(
                 module.register_forward_hook(
-                    self._register_tensor_hook_on_output_to_accumulate_gradient_covariance
+                    partial(
+                        self._register_tensor_hook_on_output_to_accumulate_gradient_covariance,
+                        module_name=mod_name,
+                    )
                 )
             )
 
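The registration above uses functools.partial to pre-bind the module name as a keyword argument, so PyTorch still calls the hook with its usual (module, inputs) signature. A standalone sketch of that pattern, with a made-up hook that only records input shapes:

from functools import partial

from torch import rand
from torch.nn import Linear


def record_input_shape(module, inputs, module_name, store):
    # called as a forward pre-hook: (module, inputs) positional, the rest pre-bound
    store[module_name] = tuple(inputs[0].shape)


layer = Linear(4, 3)
seen = {}
handle = layer.register_forward_pre_hook(
    partial(record_input_shape, module_name="0", store=seen)
)
layer(rand(8, 4))
handle.remove()
print(seen)  # {'0': (8, 4)}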

@@ -471,7 +445,7 @@ def draw_label(self, output: Tensor) -> Tensor:
         raise NotImplementedError
 
     def _register_tensor_hook_on_output_to_accumulate_gradient_covariance(
-        self, module: Module, inputs: Tuple[Tensor], output: Tensor
+        self, module: Module, inputs: Tuple[Tensor], output: Tensor, module_name: str
     ):
         """Register tensor hook on layer's output to accumulate the grad. covariance.
 
@@ -491,18 +465,24 @@ def _register_tensor_hook_on_output_to_accumulate_gradient_covariance(
                 covariance will be installed.
             inputs: The layer's input tensors.
             output: The layer's output tensor.
+            module_name: The name of the layer in the neural network.
         """
-        tensor_hook = partial(self._accumulate_gradient_covariance, module)
+        tensor_hook = partial(
+            self._accumulate_gradient_covariance, module=module, module_name=module_name
+        )
         output.register_hook(tensor_hook)
 
-    def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
+    def _accumulate_gradient_covariance(
+        self, grad_output: Tensor, module: Module, module_name: str
+    ):
         """Accumulate the gradient covariance for a layer's output.
 
         Updates ``self._gradient_covariances``.
 
         Args:
-            module: The layer whose output's gradient covariance will be accumulated.
             grad_output: The gradient w.r.t. the output.
+            module: The layer whose output's gradient covariance will be accumulated.
+            module_name: The name of the layer in the neural network.
         """
         g = grad_output.data.detach()
         batch_size = g.shape[0]
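The tensor hook follows the same idea: Tensor.register_hook only passes the gradient, so module and module_name are pre-bound via partial. A minimal sketch with an invented hook body:

from functools import partial

from torch import rand
from torch.nn import Linear

grads = {}


def collect_grad(grad_output, module, module_name):
    # Tensor.register_hook passes only the gradient; module info is pre-bound
    grads[module_name] = grad_output.detach().clone()


layer = Linear(4, 3)
out = layer(rand(8, 4))
out.register_hook(partial(collect_grad, module=layer, module_name="0"))
out.sum().backward()
print(grads["0"].shape)  # torch.Size([8, 3])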
@@ -531,20 +511,22 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
         }[self._loss_average]
         covariance = einsum(g, g, "b i,b j->i j").mul_(correction)
 
-        name = self.get_module_name(module)
-        if name not in self._gradient_covariances:
-            self._gradient_covariances[name] = covariance
+        if module_name not in self._gradient_covariances:
+            self._gradient_covariances[module_name] = covariance
         else:
-            self._gradient_covariances[name].add_(covariance)
+            self._gradient_covariances[module_name].add_(covariance)
 
-    def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor]):
+    def _hook_accumulate_input_covariance(
+        self, module: Module, inputs: Tuple[Tensor], module_name: str
+    ):
         """Pre-forward hook that accumulates the input covariance of a layer.
 
         Updates ``self._input_covariances``.
 
         Args:
             module: Module on which the hook is called.
             inputs: Inputs to the module.
+            module_name: Name of the module in the neural network.
 
         Raises:
             ValueError: If the module has multiple inputs.
@@ -576,88 +558,58 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
             scale = 1.0  # since we use a mean reduction
             x = reduce(x, "batch ... d_in -> batch d_in", "mean")
 
+        params = self._mapping[module_name]
         if (
-            self.in_params(module.weight, module.bias)
+            "weight" in params.keys()
+            and "bias" in params.keys()
             and not self._separate_weight_and_bias
         ):
             x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)
 
         covariance = einsum(x, x, "b i,b j -> i j").div_(self._N_data * scale)
 
-        name = self.get_module_name(module)
-        if name not in self._input_covariances:
-            self._input_covariances[name] = covariance
+        if module_name not in self._input_covariances:
+            self._input_covariances[module_name] = covariance
         else:
-            self._input_covariances[name].add_(covariance)
-
-    def get_module_name(self, module: Module) -> str:
-        """Get the name of a module.
-
-        Args:
-            module: The module.
-
-        Returns:
-            The name of the module.
-        """
-        p_ids = tuple(p.data_ptr() for p in module.parameters())
-        return self.param_ids_to_hooked_modules[p_ids]
-
-    def in_params(self, *params: Union[Parameter, Tensor, None]) -> bool:
-        """Check if all parameters are used in KFAC.
-
-        Args:
-            params: Parameters to check.
-
-        Returns:
-            Whether all parameters are used in KFAC.
-        """
-        return all(p is not None and p.data_ptr() in self.param_ids for p in params)
-
-    def param_pos(self, param: Union[Parameter, Tensor]) -> int:
-        """Get the position of a parameter in the list of parameters used in KFAC.
-
-        Args:
-            param: The parameter.
-
-        Returns:
-            The parameter's position in the parameter list.
-        """
-        return self.param_ids.index(param.data_ptr())
+            self._input_covariances[module_name].add_(covariance)
 
     @classmethod
-    def parameter_to_module_mapping(
-        cls, params: List[Tensor], model_func: Module
-    ) -> Tuple[List[int], Dict[Tuple[int, ...], str]]:
-        """Construct the mapping between parameters and modules.
+    def compute_parameter_mapping(
+        cls, params: List[Union[Tensor, Parameter]], model_func: Module
+    ) -> Dict[str, Dict[str, int]]:
+        """Construct the mapping between layers, their parameters, and positions.
 
         Args:
             params: List of parameters.
             model_func: The model function.
 
         Returns:
-            A tuple containing:
-            - A list of parameter data pointers.
-            - A dictionary mapping from tuples of parameter data pointers in a module
-              to its name.
+            A dictionary of dictionaries. The outer dictionary's keys are the names of
+            the layers that contain parameters. The interior dictionary's keys are the
+            parameter names, and the values their respective positions.
 
         Raises:
             NotImplementedError: If parameters are found outside supported layers.
         """
         param_ids = [p.data_ptr() for p in params]
-        # mapping from tuples of parameter data pointers in a module to its name
-        param_ids_to_hooked_modules: Dict[Tuple[int, ...], str] = {}
+        positions = {}
+        processed = set()
 
-        hooked_param_ids: Set[int] = set()
-        for name, mod in model_func.named_modules():
-            p_ids = tuple(p.data_ptr() for p in mod.parameters())
+        for mod_name, mod in model_func.named_modules():
             if isinstance(mod, cls._SUPPORTED_MODULES) and any(
-                p_id in param_ids for p_id in p_ids
+                p.data_ptr() in param_ids for p in mod.parameters()
             ):
-                param_ids_to_hooked_modules[p_ids] = name
-                hooked_param_ids.update(set(p_ids))
-
-        # check that all parameters are in hooked modules
-        if not set(param_ids).issubset(hooked_param_ids):
-            raise NotImplementedError("Found parameters outside supported layers.")
-
-        return param_ids, param_ids_to_hooked_modules
+                param_positions = {}
+                for p_name, p in mod.named_parameters():
+                    p_id = p.data_ptr()
+                    if p_id in param_ids:
+                        pos = param_ids.index(p_id)
+                        param_positions[p_name] = pos
+                        processed.add(p_id)
+                positions[mod_name] = param_positions
+
+        # check that all parameters are in known modules
+        if len(processed) != len(param_ids):
+            raise NotImplementedError("Found parameters in un-supported layers.")
+
+        return positions
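For reference, a self-contained sketch of the data structure compute_parameter_mapping returns, i.e. module name -> {parameter name -> position in the parameter list}. The helper name and toy model below are invented, and the sketch skips the cls._SUPPORTED_MODULES filter of the real method:

from typing import Dict, List

from torch import Tensor
from torch.nn import Linear, Module, ReLU, Sequential


def sketch_parameter_mapping(
    params: List[Tensor], model: Module
) -> Dict[str, Dict[str, int]]:
    """Map module name -> {parameter name -> position in ``params``}."""
    param_ids = [p.data_ptr() for p in params]
    positions: Dict[str, Dict[str, int]] = {}
    for mod_name, mod in model.named_modules():
        pos = {
            p_name: param_ids.index(p.data_ptr())
            for p_name, p in mod.named_parameters(recurse=False)
            if p.data_ptr() in param_ids
        }
        if pos:
            positions[mod_name] = pos
    return positions


model = Sequential(Linear(4, 3), ReLU(), Linear(3, 2))
params = list(model.parameters())
print(sketch_parameter_mapping(params, model))
# {'0': {'weight': 0, 'bias': 1}, '2': {'weight': 2, 'bias': 3}}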

test/test_kfac.py

Lines changed: 1 addition & 4 deletions
@@ -422,10 +422,7 @@ def test_bug_device_change_invalidates_parameter_mapping():
     x = rand(kfac.shape[1]).numpy()
     kfac_x_gpu = kfac @ x
 
-    kfac.to_device(cpu)  # invalidates internal mapping
-    assert kfac.param_ids != [p.data_ptr() for p in kfac._params]
+    kfac.to_device(cpu)
     kfac_x_cpu = kfac @ x
-    # make sure invalidation is detected and fixed inside ``matmat``
-    assert kfac.param_ids == [p.data_ptr() for p in kfac._params]
 
     report_nonclose(kfac_x_gpu, kfac_x_cpu)
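The deleted assertions guarded the old pointer-based bookkeeping: data_ptr() values change whenever parameter storage is reallocated (such as a device transfer), while names from named_modules() stay fixed, which is why the name-keyed mapping needs no invalidation check. A CPU-only sketch of that difference, using a dtype conversion as a stand-in for a device move (toy model, not from the test suite):

from torch import float64
from torch.nn import Linear, Sequential

model = Sequential(Linear(4, 3))
old_data = [p.data for p in model.parameters()]  # keep old storages alive
ptrs_before = [t.data_ptr() for t in old_data]
names_before = [name for name, _ in model.named_modules() if name]

model.to(float64)  # reallocates parameter storage, like a cpu<->gpu move would

ptrs_after = [p.data_ptr() for p in model.parameters()]
names_after = [name for name, _ in model.named_modules() if name]

print(ptrs_before == ptrs_after)    # False: pointer-based mapping is invalidated
print(names_before == names_after)  # True: name-based mapping survives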
