Commit
Add support for >2d linear layer inputs (expand) and model outputs
runame committed Jan 9, 2024
1 parent 5ae9553 commit 2acd800
Showing 1 changed file with 52 additions and 33 deletions: curvlinops/kfac.py
@@ -71,6 +71,11 @@ class KFACLinearOperator(_LinearOperator):

_SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss)
_SUPPORTED_MODULES = (Linear,)
_SUPPORTED_LOSS_AVERAGE: Tuple[Union[None, str], ...] = (
None,
"batch",
"batch+sequence",
)

def __init__(
self,
@@ -84,6 +89,7 @@ def __init__(
seed: int = 2147483647,
fisher_type: str = "mc",
mc_samples: int = 1,
loss_average: Union[None, str] = "batch",
separate_weight_and_bias: bool = True,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
@@ -128,6 +134,15 @@ def __init__(
mc_samples: The number of Monte-Carlo samples to use per data point.
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Defaults to ``1``.
loss_average: Whether the loss function is a mean over per-sample
losses and, if so, over which dimensions the mean is taken.
If ``"batch"``, the loss function is a mean over as many terms as
the size of the mini-batch. If ``"batch+sequence"``, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If ``None``, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient. Default: ``"batch"``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
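A minimal usage sketch of the new loss_average argument (not part of the commit). It assumes KFACLinearOperator is importable from curvlinops and takes the positional arguments model_func, loss_func, params, data as in the surrounding file; the model and data below are made up.

from curvlinops import KFACLinearOperator
from torch import nn, randint, randn

model = nn.Linear(10, 3)
params = [p for p in model.parameters() if p.requires_grad]
data = [(randn(32, 10), randint(0, 3, (32,)))]

# reduction="mean" averages over the mini-batch -> loss_average="batch"
kfac_mean = KFACLinearOperator(
    model, nn.CrossEntropyLoss(reduction="mean"), params, data, loss_average="batch"
)

# reduction="sum" performs no averaging -> loss_average=None
kfac_sum = KFACLinearOperator(
    model, nn.CrossEntropyLoss(reduction="sum"), params, data, loss_average=None
)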
@@ -139,6 +154,16 @@ def __init__(
raise ValueError(
f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
)
if loss_average not in self._SUPPORTED_LOSS_AVERAGE:
raise ValueError(
f"Invalid loss_average: {loss_average}. "
f"Supported: {self._SUPPORTED_LOSS_AVERAGE}."
)
if loss_average is None and loss_func.reduction != "sum":
raise ValueError(
f"Invalid loss_average: {loss_average}. "
f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
)
if fisher_type != "mc" and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
@@ -167,6 +192,7 @@ def __init__(
self._separate_weight_and_bias = separate_weight_and_bias
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._loss_average = loss_average
self._input_covariances: Dict[str, Tensor] = {}
self._gradient_covariances: Dict[str, Tensor] = {}

@@ -284,14 +310,17 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
Raises:
ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
``'empirical'``.
NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the
output is not 2d.
"""
# If the output is >2d, convert it to an equivalent 2d output
if output.ndim > 2:
if isinstance(self._loss_func, CrossEntropyLoss):
output = rearrange(output, "batch c ... -> (batch ...) c")
y = rearrange(y, "batch ... -> (batch ...)")
else:
output = rearrange(output, "batch ... c -> (batch ...) c")
y = rearrange(y, "batch ... c -> (batch ...) c")

if self._fisher_type == "type-2":
if output.ndim != 2:
raise NotImplementedError(
"Type-2 Fisher not implemented for non-2d output."
)
# Compute per-sample Hessian square root, then concatenate over samples.
# Result has shape `(batch_size, num_classes, num_classes)`
hessian_sqrts = stack(
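To make the flattening above concrete, a small standalone shape check (not part of the commit; the sizes are made up). CrossEntropyLoss places the class dimension right after the batch dimension, while the other supported losses keep the feature dimension last.

from einops import rearrange
from torch import randint, randn

# CrossEntropyLoss convention: (batch, classes, sequence)
output = randn(32, 5, 128)
y = randint(0, 5, (32, 128))
print(rearrange(output, "batch c ... -> (batch ...) c").shape)  # (4096, 5)
print(rearrange(y, "batch ... -> (batch ...)").shape)           # (4096,)

# MSELoss convention: (batch, sequence, features)
output = randn(32, 128, 5)
print(rearrange(output, "batch ... c -> (batch ...) c").shape)  # (4096, 5)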
@@ -365,20 +394,12 @@ def draw_label(self, output: Tensor) -> Tensor:
return output.clone().detach() + perturbation

elif isinstance(self._loss_func, CrossEntropyLoss):
# TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting
# from these labels might be off
if output.ndim != 2:
raise NotImplementedError(
"Only 2D output is supported for CrossEntropyLoss for now."
)
probs = output.softmax(dim=1)
# each row contains a vector describing a categorical
probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
labels = probs_as_mat.multinomial(
probs = output.softmax(dim=1)
labels = probs.multinomial(
num_samples=1, generator=self._generator
).squeeze(-1)
label_shape = output.shape[:1] + output.shape[2:]
return labels.reshape(label_shape)
return labels

else:
raise NotImplementedError
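A hedged, standalone sketch of the CrossEntropyLoss sampling above, applied to an already flattened 2d output: each row is treated as an independent categorical distribution. The sizes are made up; the seed matches the class default.

from torch import Generator, randn

output = randn(4096, 5)  # (batch * sequence, classes) after flattening
generator = Generator().manual_seed(2147483647)

probs = output.softmax(dim=1)
labels = probs.multinomial(num_samples=1, generator=generator).squeeze(-1)
print(labels.shape)  # (4096,)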
@@ -422,20 +443,21 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
NotImplementedError: If the layer is not supported.
"""
g = grad_output.data.detach()
sequence_length = g.shape[1:-1].numel()
if self._loss_average is not None:
num_loss_terms = g.shape[0] # batch_size
if self._loss_average == "batch+sequence":
# Number of all loss terms = batch_size * sequence_length
num_loss_terms *= sequence_length

if isinstance(module, Linear):
if g.ndim != 2:
# TODO Support weight sharing
raise NotImplementedError(
"Only 2d grad_outputs are supported for linear layers. "
+ f"Got {g.ndim}d."
)
g = rearrange(g, "batch ... d_out -> (batch ...) d_out")

batch_size = g.shape[0]
if isinstance(module, Linear):
# self._mc_samples will be 1 if fisher_type != "mc"
correction = {
"sum": 1.0 / self._mc_samples,
"mean": batch_size**2 / (self._N_data * self._mc_samples),
"mean": num_loss_terms**2
/ (self._N_data * self._mc_samples * sequence_length),
}[self._loss_func.reduction]
covariance = einsum("bi,bj->ij", g, g).mul_(correction)
else:
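A worked numeric check of the "mean" correction above (the sizes are made up, not from the commit): the scaling differs by a factor of sequence_length**2 between the two averaging modes.

batch_size, sequence_length, N_data, mc_samples = 32, 128, 1024, 1

# loss_average="batch": the loss averages over batch_size terms
num_loss_terms = batch_size
print(num_loss_terms**2 / (N_data * mc_samples * sequence_length))  # 0.0078125

# loss_average="batch+sequence": averages over batch_size * sequence_length terms
num_loss_terms = batch_size * sequence_length
print(num_loss_terms**2 / (N_data * mc_samples * sequence_length))  # 128.0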
@@ -468,21 +490,18 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
if len(inputs) != 1:
raise ValueError("Modules with multiple inputs are not supported.")
x = inputs[0].data.detach()
sequence_length = x.shape[1:-1].numel()

if isinstance(module, Linear):
if x.ndim != 2:
# TODO Support weight sharing
raise NotImplementedError(
f"Only 2d inputs are supported for linear layers. Got {x.ndim}d."
)
x = rearrange(x, "batch ... d_in -> (batch ...) d_in")

if isinstance(module, Linear):
if (
self.in_params(module.weight, module.bias)
and not self._separate_weight_and_bias
):
x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

covariance = einsum("bi,bj->ij", x, x).div_(self._N_data)
covariance = einsum("bi,bj->ij", x, x).div_(self._N_data * sequence_length)
else:
# TODO Support convolutions
raise NotImplementedError(f"Layer of type {type(module)} is unsupported.")
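A hedged, standalone sketch of the input-covariance update above for a 3d (weight-sharing) input, including the appended column of ones used when weight and bias are treated jointly. The sizes are made up and N_data plays the role of self._N_data.

from einops import rearrange
from torch import cat, einsum, randn

N_data, sequence_length, d_in = 1024, 128, 10
x = randn(32, sequence_length, d_in)  # one mini-batch of layer inputs
x = rearrange(x, "batch ... d_in -> (batch ...) d_in")

# jointly treated weight and bias: append a column of ones
x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

covariance = einsum("bi,bj->ij", x, x) / (N_data * sequence_length)
print(covariance.shape)  # (11, 11)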
