Skip to content

Commit 73ee0eb

Browse files
authored
[REF | KFAC] Use module names instead of data_ptr()s for mapping (#79)
* [REF] Use module names for internal mapping to parameter positions
* [ADD] Load Kronecker factors to new device during `to_device`
* [REF] Overwrite params using keys
* [REF] Update matmat of inverse KFAC using the new mapping
1 parent 494ec00 commit 73ee0eb

File tree

3 files changed

+108
-154
lines changed

3 files changed

+108
-154
lines changed

curvlinops/inverse.py

Lines changed: 15 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -288,18 +288,17 @@ def _matmat(self, M: ndarray) -> ndarray:
288288

289289
M_torch = self._A._preprocess(M)
290290

291-
for name in self._A.param_ids_to_hooked_modules.values():
292-
mod = self._A._model_func.get_submodule(name)
293-
291+
for mod_name, param_pos in self._A._mapping.items():
294292
# retrieve the inverses of the Kronecker factors from cache or invert them
295-
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(name)
293+
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name)
296294

297295
# bias and weights are treated jointly
298-
weight, bias = mod.weight, mod.bias
299-
if not self._A._separate_weight_and_bias and self._A.in_params(
300-
weight, bias
296+
if (
297+
not self._A._separate_weight_and_bias
298+
and "weight" in param_pos.keys()
299+
and "bias" in param_pos.keys()
301300
):
302-
w_pos, b_pos = self._A.param_pos(weight), self._A.param_pos(bias)
301+
w_pos, b_pos = param_pos["weight"], param_pos["bias"]
303302
M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)")
304303
M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2)
305304
M_joint = einsum(ggT_inv, M_joint, aaT_inv, "i j,m j k, k l -> m i l")
@@ -310,19 +309,13 @@ def _matmat(self, M: ndarray) -> ndarray:
310309
# for weights we need to multiply from the right with aaT
311310
# for weights and biases we need to multiply from the left with ggT
312311
else:
313-
for p_name in ["weight", "bias"]:
314-
p = getattr(mod, p_name)
315-
if self._A.in_params(p):
316-
pos = self._A.param_pos(p)
317-
318-
if p_name == "weight":
319-
M_w = rearrange(
320-
M_torch[pos], "m c_out ... -> m c_out (...)"
321-
)
322-
M_torch[pos] = einsum(M_w, aaT_inv, "m i j, j k -> m i k")
323-
324-
M_torch[pos] = einsum(
325-
ggT_inv, M_torch[pos], "i j, m j ... -> m i ..."
326-
)
312+
for p_name, pos in param_pos.items():
313+
if p_name == "weight":
314+
M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)")
315+
M_torch[pos] = einsum(M_w, aaT_inv, "m i j, j k -> m i k")
316+
317+
M_torch[pos] = einsum(
318+
ggT_inv, M_torch[pos], "i j, m j ... -> m i ..."
319+
)
327320

328321
return self._A._postprocess(M_torch)

0 commit comments

Comments
 (0)