diff --git a/curvlinops/experimental/activation_hessian.py b/curvlinops/experimental/activation_hessian.py index aaf9de32..039786c0 100644 --- a/curvlinops/experimental/activation_hessian.py +++ b/curvlinops/experimental/activation_hessian.py @@ -150,6 +150,14 @@ def _matvec_batch( return hessian_vector_product(loss, [activation], x_list) def _preprocess(self, x: ndarray) -> List[Tensor]: + """Reshape the incoming vector into the activation shape and convert to PyTorch. + + Args: + x: Vector in NumPy format onto which the linear operator is applied. + + Returns: + Vector in PyTorch format. Has same shape as the activation. + """ return [from_numpy(x).to(self._device).reshape(self._activation_shape)]