Skip to content

Commit

Permalink
[REF] Update matmat of inverse KFAC using the new mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Feb 15, 2024
1 parent 2ba3aa2 commit cd8e328
Showing 1 changed file with 15 additions and 22 deletions.
37 changes: 15 additions & 22 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,18 +288,17 @@ def _matmat(self, M: ndarray) -> ndarray:

M_torch = self._A._preprocess(M)

for name in self._A.param_ids_to_hooked_modules.values():
mod = self._A._model_func.get_submodule(name)

for mod_name, param_pos in self._A._mapping.items():
# retrieve the inverses of the Kronecker factors from cache or invert them
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(name)
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name)

# bias and weights are treated jointly
weight, bias = mod.weight, mod.bias
if not self._A._separate_weight_and_bias and self._A.in_params(
weight, bias
if (
not self._A._separate_weight_and_bias
and "weight" in param_pos.keys()
and "bias" in param_pos.keys()
):
w_pos, b_pos = self._A.param_pos(weight), self._A.param_pos(bias)
w_pos, b_pos = param_pos["weight"], param_pos["bias"]
M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)")
M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2)
M_joint = einsum(ggT_inv, M_joint, aaT_inv, "i j,m j k, k l -> m i l")
Expand All @@ -310,19 +309,13 @@ def _matmat(self, M: ndarray) -> ndarray:
# for weights we need to multiply from the right with aaT
# for weights and biases we need to multiply from the left with ggT
else:
for p_name in ["weight", "bias"]:
p = getattr(mod, p_name)
if self._A.in_params(p):
pos = self._A.param_pos(p)

if p_name == "weight":
M_w = rearrange(
M_torch[pos], "m c_out ... -> m c_out (...)"
)
M_torch[pos] = einsum(M_w, aaT_inv, "m i j, j k -> m i k")

M_torch[pos] = einsum(
ggT_inv, M_torch[pos], "i j, m j ... -> m i ..."
)
for p_name, pos in param_pos.items():
if p_name == "weight":
M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)")
M_torch[pos] = einsum(M_w, aaT_inv, "m i j, j k -> m i k")

M_torch[pos] = einsum(
ggT_inv, M_torch[pos], "i j, m j ... -> m i ..."
)

return self._A._postprocess(M_torch)

0 comments on commit cd8e328

Please sign in to comment.