Skip to content

Commit

Permalink
[ADD] Option to specify number of points in data set
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jan 31, 2024
1 parent cb40294 commit 9caa6bb
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions curvlinops/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
progressbar: bool = False,
check_deterministic: bool = True,
shape: Optional[Tuple[int, int]] = None,
num_data: Optional[int] = None,
):
"""Linear operator for DNN matrices.
Expand All @@ -64,6 +65,8 @@ def __init__(
safeguard, only turn it off if you know what you are doing.
shape: Shape of the represented matrix. If ``None`` assumes ``(D, D)``
where ``D`` is the total number of parameters
num_data: Number of data points. If ``None``, it is inferred from the data
at the cost of one traversal through the data loader.
Raises:
RuntimeError: If the check for deterministic behavior fails.
Expand All @@ -80,8 +83,10 @@ def __init__(
self._device = self._infer_device(self._params)
self._progressbar = progressbar

self._N_data = sum(
X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data")
self._N_data = (
num_data
if num_data is not None
else sum(X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data"))
)

if check_deterministic:
Expand Down

0 comments on commit 9caa6bb

Please sign in to comment.