From 9caa6bb48787fc97e8b396105c592977c8331f47 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 31 Jan 2024 18:48:46 -0500 Subject: [PATCH] [ADD] Option to specify number of points in data set --- curvlinops/_base.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/curvlinops/_base.py b/curvlinops/_base.py index d6df534b..7b13b6b7 100644 --- a/curvlinops/_base.py +++ b/curvlinops/_base.py @@ -38,6 +38,7 @@ def __init__( progressbar: bool = False, check_deterministic: bool = True, shape: Optional[Tuple[int, int]] = None, + num_data: Optional[int] = None, ): """Linear operator for DNN matrices. @@ -64,6 +65,8 @@ def __init__( safeguard, only turn it off if you know what you are doing. shape: Shape of the represented matrix. If ``None`` assumes ``(D, D)`` where ``D`` is the total number of parameters + num_data: Number of data points. If ``None``, it is inferred from the data + at the cost of one traversal through the data loader. Raises: RuntimeError: If the check for deterministic behavior fails. @@ -80,8 +83,10 @@ def __init__( self._device = self._infer_device(self._params) self._progressbar = progressbar - self._N_data = sum( - X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data") + self._N_data = ( + num_data + if num_data is not None + else sum(X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data")) ) if check_deterministic: