diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index 19ad6e7a..3fcc34bf 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -71,6 +71,11 @@ class KFACLinearOperator(_LinearOperator):
 
     _SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss)
     _SUPPORTED_MODULES = (Linear,)
+    _SUPPORTED_LOSS_AVERAGE: Tuple[Union[None, str], ...] = (
+        None,
+        "batch",
+        "batch+sequence",
+    )
 
     def __init__(
         self,
@@ -84,6 +89,7 @@ def __init__(
         seed: int = 2147483647,
         fisher_type: str = "mc",
         mc_samples: int = 1,
+        loss_average: Union[None, str] = "batch",
         separate_weight_and_bias: bool = True,
     ):
         """Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
@@ -128,6 +134,15 @@ def __init__(
             mc_samples: The number of Monte-Carlo samples to use per data point.
                 Has to be set to ``1`` when ``fisher_type != 'mc'``.
                 Defaults to ``1``.
+            loss_average: Whether the loss function is a mean over per-sample
+                losses and if yes, over which dimensions the mean is taken.
+                If `"batch"`, the loss function is a mean over as many terms as
+                the size of the mini-batch. If `"batch+sequence"`, the loss
+                function is a mean over as many terms as the size of the
+                mini-batch times the sequence length, e.g. in the case of
+                language modeling. If `None`, the loss function is a sum. This
+                argument is used to ensure that the preconditioner is scaled
+                consistently with the loss and the gradient. Default: `"batch"`.
             separate_weight_and_bias: Whether to treat weights and biases
                 separately. Defaults to ``True``.
 
@@ -139,6 +154,16 @@ def __init__(
             raise ValueError(
                 f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
             )
+        if loss_average not in self._SUPPORTED_LOSS_AVERAGE:
+            raise ValueError(
+                f"Invalid loss_average: {loss_average}. "
+                f"Supported: {self._SUPPORTED_LOSS_AVERAGE}."
+            )
+        if loss_average is None and loss_func.reduction != "sum":
+            raise ValueError(
+                f"Invalid loss_average: {loss_average}. "
+                f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
+            )
         if fisher_type != "mc" and mc_samples != 1:
             raise ValueError(
                 f"Invalid mc_samples: {mc_samples}. "
@@ -167,6 +192,7 @@ def __init__(
         self._separate_weight_and_bias = separate_weight_and_bias
         self._fisher_type = fisher_type
         self._mc_samples = mc_samples
+        self._loss_average = loss_average
         self._input_covariances: Dict[str, Tensor] = {}
         self._gradient_covariances: Dict[str, Tensor] = {}
 
@@ -284,14 +310,17 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
         Raises:
             ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
                 ``'empirical'``.
-            NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the
-                output is not 2d.
         """
+        # if >2d output we convert to an equivalent 2d output
+        if output.ndim > 2:
+            if isinstance(self._loss_func, CrossEntropyLoss):
+                output = rearrange(output, "batch c ... -> (batch ...) c")
+                y = rearrange(y, "batch ... -> (batch ...)")
+            else:
+                output = rearrange(output, "batch ... c -> (batch ...) c")
+                y = rearrange(y, "batch ... c -> (batch ...) c")
+
         if self._fisher_type == "type-2":
-            if output.ndim != 2:
-                raise NotImplementedError(
-                    "Type-2 Fisher not implemented for non-2d output."
-                )
             # Compute per-sample Hessian square root, then concatenate over samples.
             # Result has shape `(batch_size, num_classes, num_classes)`
             hessian_sqrts = stack(
@@ -365,20 +394,12 @@ def draw_label(self, output: Tensor) -> Tensor:
             return output.clone().detach() + perturbation
 
         elif isinstance(self._loss_func, CrossEntropyLoss):
-            # TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting
-            # from these labels might be off
-            if output.ndim != 2:
-                raise NotImplementedError(
-                    "Only 2D output is supported for CrossEntropyLoss for now."
-                )
-            probs = output.softmax(dim=1)
-            # each row contains a vector describing a categorical
-            probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
-            labels = probs_as_mat.multinomial(
+            probs = output.softmax(dim=1)
+            labels = probs.multinomial(
                 num_samples=1, generator=self._generator
             ).squeeze(-1)
-            label_shape = output.shape[:1] + output.shape[2:]
-            return labels.reshape(label_shape)
+            return labels
 
         else:
             raise NotImplementedError
@@ -422,20 +443,21 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
             NotImplementedError: If the layer is not supported.
         """
         g = grad_output.data.detach()
+        sequence_length = g.shape[1:-1].numel()
+        if self._loss_average is not None:
+            num_loss_terms = g.shape[0]  # batch_size
+            if self._loss_average == "batch+sequence":
+                # Number of all loss terms = batch_size * sequence_length
+                num_loss_terms *= sequence_length
 
-        if isinstance(module, Linear):
-            if g.ndim != 2:
-                # TODO Support weight sharing
-                raise NotImplementedError(
-                    "Only 2d grad_outputs are supported for linear layers. "
-                    + f"Got {g.ndim}d."
-                )
+        g = rearrange(g, "batch ... d_out -> (batch ...) d_out")
 
-            batch_size = g.shape[0]
+        if isinstance(module, Linear):
             # self._mc_samples will be 1 if fisher_type != "mc"
             correction = {
                 "sum": 1.0 / self._mc_samples,
-                "mean": batch_size**2 / (self._N_data * self._mc_samples),
+                "mean": num_loss_terms**2
+                / (self._N_data * self._mc_samples * sequence_length),
             }[self._loss_func.reduction]
             covariance = einsum("bi,bj->ij", g, g).mul_(correction)
         else:
@@ -468,21 +490,18 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
         if len(inputs) != 1:
             raise ValueError("Modules with multiple inputs are not supported.")
         x = inputs[0].data.detach()
+        sequence_length = x.shape[1:-1].numel()
 
-        if isinstance(module, Linear):
-            if x.ndim != 2:
-                # TODO Support weight sharing
-                raise NotImplementedError(
-                    f"Only 2d inputs are supported for linear layers. Got {x.ndim}d."
-                )
+        x = rearrange(x, "batch ... d_in -> (batch ...) d_in")
 
+        if isinstance(module, Linear):
             if (
                 self.in_params(module.weight, module.bias)
                 and not self._separate_weight_and_bias
            ):
                 x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)
 
-            covariance = einsum("bi,bj->ij", x, x).div_(self._N_data)
+            covariance = einsum("bi,bj->ij", x, x).div_(self._N_data * sequence_length)
         else:
             # TODO Support convolutions
             raise NotImplementedError(f"Layer of type {type(module)} is unsupported.")
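For reference, the >2d flattening convention added in _compute_loss_and_backward can be reproduced in isolation. The sketch below uses made-up shapes and is not part of the patch; it only illustrates why CrossEntropyLoss outputs are rearranged with the class dimension kept last while the batch and sequence dimensions are merged.

# Standalone sketch of the >2d flattening in _compute_loss_and_backward
# (illustrative shapes, not part of the patch). CrossEntropyLoss carries the
# class dimension at position 1, so it is moved last while batch and sequence
# are merged; regression targets already have the feature dimension last.
from einops import rearrange
from torch import rand, randint

batch, num_classes, seq = 4, 3, 10
logits = rand(batch, num_classes, seq)       # CrossEntropyLoss: classes at dim 1
labels = randint(num_classes, (batch, seq))  # integer labels per sequence position

logits_2d = rearrange(logits, "batch c ... -> (batch ...) c")  # (batch*seq, c)
labels_1d = rearrange(labels, "batch ... -> (batch ...)")      # (batch*seq,)
print(logits_2d.shape, labels_1d.shape)  # torch.Size([40, 3]) torch.Size([40])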
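The scaling of the gradient covariance can be checked the same way. This is a standalone sketch of the arithmetic used in _accumulate_gradient_covariance for reduction="mean" with loss_average="batch+sequence"; the shapes, N_data, and mc_samples values are made up.

# Standalone sketch of the correction factor in _accumulate_gradient_covariance
# (illustrative numbers, not part of the patch). With reduction="mean" and
# loss_average="batch+sequence", the loss averages over
# batch_size * sequence_length terms, and the extra 1 / sequence_length keeps
# the Kronecker factors consistent after folding the shared dimension into the
# batch dimension.
from einops import rearrange
from torch import einsum, rand

batch_size, sequence_length, d_out = 4, 10, 3
N_data, mc_samples = 40, 1

g = rand(batch_size, sequence_length, d_out)   # gradient w.r.t. the layer output
num_loss_terms = batch_size * sequence_length  # loss_average="batch+sequence"

g = rearrange(g, "batch ... d_out -> (batch ...) d_out")
correction = num_loss_terms**2 / (N_data * mc_samples * sequence_length)
covariance = einsum("bi,bj->ij", g, g).mul_(correction)
print(covariance.shape)  # torch.Size([3, 3])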
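Finally, a usage sketch of the new loss_average argument. The model, data, and shapes are invented, and it assumes the leading constructor arguments (model_func, loss_func, params, data) of KFACLinearOperator are unchanged by this patch.

# Usage sketch for loss_average (hypothetical model and data; the constructor
# interface other than loss_average is assumed unchanged by this patch). With a
# mean-reduction loss on sequence data, "batch+sequence" tells KFAC that the
# loss averages over batch_size * sequence_length terms.
from numpy import ones
from torch import rand
from torch.nn import Linear, MSELoss

from curvlinops.kfac import KFACLinearOperator

batch_size, seq_len, d_in, d_out = 4, 10, 8, 3
model = Linear(d_in, d_out)
data = [(rand(batch_size, seq_len, d_in), rand(batch_size, seq_len, d_out))]

kfac = KFACLinearOperator(
    model,
    MSELoss(reduction="mean"),  # mean over batch, sequence, and output dims
    list(model.parameters()),
    data,
    fisher_type="mc",
    mc_samples=1,
    loss_average="batch+sequence",  # loss is averaged over batch and sequence
)
_ = kfac @ ones(kfac.shape[1])  # multiply the KFAC approximation onto a vector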