
Commit 3985744: speedups and bugfixes
1 parent: 98f3b99
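In brief:
- measurement_model.py: fix the axis used when stacking per-measure rows into the measurement matrix (dim=-1 -> dim=-2).
- internals/utils.py: drop the lru_cache on get_meshgrids and add a no-nan fast path to get_nan_groups.
- predictions.py: Predictions.log_prob() now accepts pre-computed nan-groups instead of recomputing them on each call.
- state_space.py: fit() pre-computes nan-groups once and threads them through forward() into the update step; it also fixes an indexing bug in _get_measure_scaling() and replaces the default-loss lambda with a module-level default_get_loss().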

File tree: 4 files changed (+89, -65 lines)

torchcast/internals/batch_design/measurement_model.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def _adjust_measure_mat(self,
             # apply measure-wide adjustment
             measure_mat[i] = self.measure_funs[measure].adjust_measure_mat(measure_mat[i], measured_mean[i])
 
-        return torch.stack(measure_mat, dim=-1)
+        return torch.stack(measure_mat, dim=-2)
 
     @cached_property
     def measure2idx(self) -> dict[str, int]:
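
The one-line fix above changes which axis the per-measure rows are stacked along when assembling the measurement matrix. A minimal sketch of the difference (the shapes are illustrative assumptions, not taken from torchcast):

    import torch

    # suppose each of 3 measures contributes a state-coefficient row of length 4:
    rows = [torch.randn(4) for _ in range(3)]

    # dim=-1 turns the rows into columns: shape (4, 3), i.e. transposed
    print(torch.stack(rows, dim=-1).shape)  # torch.Size([4, 3])
    # dim=-2 keeps them as rows: shape (3, 4), a (num_measures, state_dim) matrix
    print(torch.stack(rows, dim=-2).shape)  # torch.Size([3, 4])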

torchcast/internals/utils.py

Lines changed: 10 additions & 3 deletions
@@ -12,7 +12,6 @@ def get_subclasses(cls: Type) -> Iterable[Type]:
         yield subclass
 
 
-@functools.lru_cache(maxsize=100)
 def get_meshgrids(groups: torch.Tensor,
                   val_idx: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
     """
@@ -189,13 +188,21 @@ def get_nan_groups(isnan: torch.Tensor) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
     """
     assert len(isnan.shape) == 2
     state_dim = isnan.shape[-1]
-    out: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = []
+
+    out = []
     if state_dim == 1:
         # shortcut for univariate
         group_idx = (~isnan.squeeze(-1)).nonzero().view(-1)
         out.append((group_idx, None))
         return out
-    for nan_combo in torch.unique(isnan, dim=0):
+
+    nan_combos = torch.unique(isnan, dim=0)
+    if len(nan_combos) == 1 and nan_combos[0].sum() == 0:
+        # shortcut for no nans
+        out.append((torch.arange(isnan.shape[0]), None))
+        return out
+
+    for nan_combo in nan_combos:
         num_nan = nan_combo.sum()
         if num_nan < state_dim:
             c1 = (isnan * nan_combo[None, :]).sum(1) == num_nan
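
For orientation, ``get_nan_groups`` partitions the rows of a 2-D nan-mask into groups sharing the same nan-pattern, so each group can be updated in one masked batch operation; a ``None`` second element means the group needs no masking. With the new shortcut, a fully observed batch returns a single ``(arange(n), None)`` group immediately instead of going through the per-pattern indexing. A small usage sketch (example values are assumptions):

    import torch
    from torchcast.internals.utils import get_nan_groups

    obs = torch.tensor([[1.0, 2.0],
                        [float('nan'), 3.0],
                        [4.0, 5.0]])

    # rows 0 and 2 share one nan-pattern, row 1 has its own:
    for group_idx, valid_idx in get_nan_groups(torch.isnan(obs)):
        print(group_idx.tolist(), None if valid_idx is None else valid_idx.tolist())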

torchcast/state_space/predictions.py

Lines changed: 12 additions & 2 deletions
@@ -425,12 +425,18 @@ def measure_covs_flat(self) -> torch.Tensor:
             self._state_means_flat, self._state_covs_flat, self._mcovs_flat = self._flatten()
         return self._mcovs_flat
 
-    def log_prob(self, obs: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def log_prob(self,
+                 obs: torch.Tensor,
+                 weights: Optional[torch.Tensor] = None,
+                 nan_groups_flat: Optional[Sequence[tuple[torch.Tensor, Optional[torch.Tensor]]]] = None
+                 ) -> torch.Tensor:
         """
         Compute the log-probability of data (e.g. data that was originally fed into the ``StateSpaceModel``).
 
         :param obs: A Tensor that could be used in the ``StateSpaceModel`` forward pass.
         :param weights: If specified, will be used to weight the log-probability of each group X timestep.
+        :param nan_groups_flat: Used by ``StateSpaceModel.fit()`` to speed up computations: nan-masks are
+         pre-computed once at the start of fitting, rather than on each call to ``log_prob()``.
         :return: A tensor with one element for each group X timestep indicating the log-probability.
         """
         assert len(obs.shape) == 3
@@ -447,7 +453,11 @@ def log_prob(self, obs: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
         measure_covs_flat = self.measure_covs.view(-1, measure_rank, measure_rank)
 
         lp_flat = torch.zeros(obs_flat.shape[0], dtype=self.state_means.dtype, device=self.state_means.device)
-        for gt_idx, valid_idx in get_nan_groups(torch.isnan(obs_flat)):
+
+        if nan_groups_flat is None:
+            nan_groups_flat = get_nan_groups(torch.isnan(obs_flat))
+
+        for gt_idx, valid_idx in nan_groups_flat:
             if valid_idx is None:
                 gt_obs = obs_flat[gt_idx]
                 gt_mcov = measure_covs_flat[gt_idx]
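
The new parameter lets callers hoist the nan-grouping out of a training loop. A sketch of how a custom loop might use it, mirroring what ``fit()`` now does (``model``, ``y``, ``optimizer``, and ``n_epochs`` are assumed to exist):

    import torch
    from torchcast.internals.utils import get_nan_groups

    # flatten (groups, times, measures) -> (groups*times, measures) and group once:
    nan_groups_flat = get_nan_groups(torch.isnan(y).reshape(-1, y.shape[-1]))

    for _ in range(n_epochs):
        optimizer.zero_grad()
        pred = model(y)  # forward pass of an assumed StateSpaceModel instance
        loss = -pred.log_prob(y, nan_groups_flat=nan_groups_flat).mean()
        loss.backward()
        optimizer.step()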

torchcast/state_space/state_space.py

Lines changed: 66 additions & 59 deletions
@@ -8,7 +8,7 @@
 
 from torchcast.internals.batch_design import TransitionModel, MeasurementModel, MeasureFun
 from torchcast.internals.hessian import hessian
-from torchcast.internals.utils import repeat, true1d_idx, get_nan_groups, mask_mats, get_meshgrids
+from torchcast.internals.utils import repeat, true1d_idx, get_nan_groups, get_meshgrids
 from torchcast.covariance import Covariance
 from torchcast.state_space.predictions import Predictions
 from torchcast.process.regression import Process
@@ -74,11 +74,6 @@ def __init__(self,
         else:
             self.dt_unit = process.dt_unit
 
-    # @property
-    # def is_nonlinear(self) -> bool:
-    #     return any(not p.linear_measurement for p in self.processes.values()) or self.measure_funs
-
-    @torch.jit.ignore()
     def forward(self,
                 y: Optional[torch.Tensor] = None,
                 n_step: Union[int, float] = 1,
@@ -88,7 +83,6 @@ def forward(self,
                 every_step: bool = True,
                 include_updates_in_output: bool = False,
                 simulate: Optional[int] = None,
-                last_measured_per_group: Optional[torch.Tensor] = None,
                 prediction_kwargs: Optional[dict] = None,
                 **kwargs) -> 'Predictions':
         """
@@ -119,14 +113,6 @@ def forward(self,
         :param include_updates_in_output: If False, only the ``n_step`` ahead predictions are included in the output.
          This means that we cannot use this output to generate the ``initial_state`` for subsequent forward-passes. Set
          to True to allow this -- False by default to reduce memory.
-        :param last_measured_per_group: This provides a method to reduce unused computations in training. On each call
-         to forward in training, you can supply to this argument a tensor indicating the last measured timestep for
-         each group in the batch (this can be computed with ``last_measured_per_group=batch.get_durations()``, where
-         ``batch`` is a :class:`TimeSeriesDataset`). In this case, predictions will not be generated after the
-         specified timestep for each group; these can be discarded in training because, without any measurements, they
-         wouldn't have been used in loss calculations anyway. Naturally this should never be set for
-         inference/forecasting. This will automatically be set when calling ``fit()``, but if you're instead using a
-         custom training loop, you can pass this manually.
         :param simulate: If specified, will generate `simulate` samples from the model.
         :param prediction_kwargs: A dictionary of kwargs to pass to initialize ``Predictions()``.
         :param kwargs: Further arguments passed to the `processes`. For example, the :class:`.LinearModel` expects an
@@ -176,8 +162,14 @@ def forward(self,
             out_timesteps=out_timesteps
         )
 
+        # used by fit() to reduce unneeded computations:
+        last_measured_per_group = kwargs.pop('last_measured_per_group', None)
         if last_measured_per_group is None:
             last_measured_per_group = torch.full((num_groups,), out_timesteps, dtype=torch.int, device=meanu.device)
+        nan_groups = kwargs.pop('nan_groups', None)
+        if nan_groups is None:
+            nan_groups = [None] * out_timesteps
+        # /
 
         # todo: update Covariance class to make this less hacky:
         mcov_kwargs = {}
@@ -242,6 +234,7 @@ def forward(self,
                     measured_mean=measured_mean,
                     measure_mat=measure_mat,
                     measure_cov=measure_covs[t],
+                    nan_groups=nan_groups[t],
                     **{k: v[t] for k, v in update_kwargs.items()}
                 )
             if self.adaptive_measure_var and t < len(measure_covs) - 1:
@@ -302,7 +295,13 @@ def forward(self,
                 device=meanu.device,
                 dtype=meanu.dtype
             )
-        preds = self._generate_predictions(preds, updates, measure_covs, measurement_model, **prediction_kwargs)
+        preds = self._generate_predictions(
+            preds=preds,
+            updates=updates,
+            measure_covs=measure_covs,
+            measurement_model=measurement_model,
+            **prediction_kwargs
+        )
         return preds.set_metadata(
             start_offsets=start_offsets if start_offsets is not None else np.zeros(num_groups, dtype='int'),
             dt_unit=self.dt_unit
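
Note the signature change: ``last_measured_per_group`` is no longer a formal parameter of ``forward()``; it and the new per-timestep ``nan_groups`` are popped from ``**kwargs`` as fit-only hooks. A hedged sketch of supplying them from a custom training loop (``model``, ``batch``, and ``y`` are assumed to exist; ``batch.get_durations()`` comes from the removed docstring):

    import torch
    from torchcast.internals.utils import get_nan_groups

    isnan = torch.isnan(y)
    pred = model(
        y,
        last_measured_per_group=batch.get_durations(),            # skip predictions past the last measurement
        nan_groups=[get_nan_groups(t) for t in isnan.unbind(1)],  # one grouping per timestep
    )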
@@ -357,9 +356,6 @@ def fit(self,
         if set_initial_values:
             self._set_initial_values(y, verbose=verbose > 1, **kwargs)
 
-        if not get_loss:
-            get_loss = lambda _pred, _y: -_pred.log_prob(_y).mean()
-
         _deprecated = {k: kwargs.pop(k) for k in ['tol', 'patience', 'max_iter'] if k in kwargs}
         _dmsg = f"The following are deprecated, use `stopping` arg instead:\n{set(_deprecated)}"
         if stopping is None:
@@ -381,6 +377,11 @@ def fit(self,
         kwargs = self._prepare_fit_kwargs(y, **kwargs)
 
+        if get_loss is None:
+            # precompute nan-groups instead of doing it on each call to log_prob:
+            nan_groups_flat = get_nan_groups(torch.isnan(y).reshape(-1, y.shape[-1]))
+            get_loss = lambda _pred, _y: -_pred.log_prob(_y, nan_groups_flat=nan_groups_flat).mean()
+
         closure = _OptimizerClosure(
             ss_model=self,
             y=y,
@@ -415,21 +416,25 @@ def is_nonlinear(self) -> bool:
         return any(not p.linear_measurement for p in self.processes.values()) or self.measure_funs
 
     def _prepare_fit_kwargs(self, y: torch.Tensor, **kwargs) -> dict:
-        mc_samples = kwargs.pop('mc_samples', None)
+        # precompute nan-groups for the forward pass:
+        isnan = torch.isnan(y)
+        kwargs['nan_groups'] = [get_nan_groups(isnan_t) for isnan_t in isnan.unbind(1)]
 
+        prediction_kwargs = kwargs.pop('prediction_kwargs', None) or {}
+        # monte-carlo white noise for Predictions.log_prob:
+        mc_samples = kwargs.pop('mc_samples', None)
         if self.is_nonlinear and not mc_samples:
             raise ValueError("Nonlinear state-space models require `mc_samples` to be set.")
-
         if mc_samples:
-            prediction_kwargs = kwargs.pop('prediction_kwargs', None) or {}
             if 'mc_white_noise' not in prediction_kwargs:
                 emmat_rank = MeasurementModel.get_extended_mmat_rank(self.processes.values(), self.measures)
                 prediction_kwargs['mc_white_noise'] = torch.randn(
                     (mc_samples, emmat_rank),
                     device=y.device,
                     dtype=y.dtype
                 )
-            kwargs['prediction_kwargs'] = prediction_kwargs
+        kwargs['prediction_kwargs'] = prediction_kwargs
 
         # see `last_measured_per_group` in forward docstring
         # todo: duplicate code in ``TimeSeriesDataset.get_durations()``
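
For reference, ``Tensor.unbind(1)`` splits along the time dimension, so the precomputed ``nan_groups`` list has one grouping per timestep (shapes here are illustrative):

    import torch
    y = torch.randn(4, 10, 2)         # (groups, timesteps, measures)
    masks = torch.isnan(y).unbind(1)  # tuple of 10 (groups, measures) masks, one per timestep

Also note that ``prediction_kwargs`` is now always placed back into ``kwargs``, not only when ``mc_samples`` is given.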
@@ -447,6 +452,7 @@ def _generate_predictions(self,
                               updates: Optional[tuple[list[torch.Tensor], list[torch.Tensor]]],
                               measure_covs: torch.Tensor,
                               measurement_model: 'MeasurementModel',
+                              nan_groups: Optional[List[Sequence[tuple[torch.Tensor, Optional[torch.Tensor]]]]] = None,
                               mc_white_noise: Optional[torch.Tensor] = None,
                               **kwargs
                               ) -> 'Predictions':
@@ -523,41 +529,48 @@ def _update_step_with_nans(self,
                                measured_mean: torch.Tensor,
                                measure_mat: torch.Tensor,
                                measure_cov: torch.Tensor,
+                               nan_groups: Optional[Sequence[tuple[torch.Tensor, Optional[torch.Tensor]]]] = None,
                                **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
-        isnan = torch.isnan(input)
-        if isnan.all():
-            return mean, cov
-        if isnan.any():
-            new_mean = mean.clone()
-            new_cov = cov.clone()
-            for groups, val_idx in get_nan_groups(isnan):
-                masked = self._mask_mats(
-                    groups,
-                    val_idx,
+        if nan_groups is None:
+            nan_groups = get_nan_groups(torch.isnan(input))
+        if len(nan_groups) == 1:
+            group_idx, val_idx = nan_groups[0]
+            if len(group_idx) == len(input) and val_idx is None:
+                # no nans, no masking:
+                return self._update_step(
                     input=input,
+                    mean=mean,
+                    cov=cov,
                     measured_mean=measured_mean,
                     measure_mat=measure_mat,
                     measure_cov=measure_cov,
                     **kwargs
                 )
-                new_mean[groups], new_cov[groups] = self._update_step(
-                    mean=mean[groups],
-                    cov=cov[groups],
-                    **masked,
-                    **{k: v for k, v in kwargs.items() if k not in masked}
-                )
-            return new_mean, new_cov
-        else:
-            return self._update_step(
+        elif not len(nan_groups):
+            # all nans, nothing to do:
+            return mean, cov
+
+        new_mean = mean.clone()
+        new_cov = cov.clone()
+        for groups, val_idx in nan_groups:
+            masked = self._mask_mats(
+                groups,
+                val_idx,
                 input=input,
-                mean=mean,
-                cov=cov,
                 measured_mean=measured_mean,
                 measure_mat=measure_mat,
                 measure_cov=measure_cov,
                 **kwargs
             )
 
+            new_mean[groups], new_cov[groups] = self._update_step(
+                mean=mean[groups],
+                cov=cov[groups],
+                **masked,
+                **{k: v for k, v in kwargs.items() if k not in masked}
+            )
+        return new_mean, new_cov
+
     def _mask_mats(self,
                    groups: torch.Tensor,
                    val_idx: Optional[torch.Tensor],
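
The rewrite dispatches on the precomputed groups rather than calling ``torch.isnan``/``get_nan_groups`` on every update: a single group covering every row with ``valid_idx=None`` (the no-nan shortcut added in utils.py) goes straight to ``_update_step`` with no masking; an empty list, which ``get_nan_groups`` produces when every multivariate row is fully nan, returns the prior mean and covariance untouched; anything else falls through to the original per-pattern masked updates.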
@@ -677,14 +690,15 @@ def state_rank(self) -> int:
 
     def _get_measure_scaling(self) -> torch.Tensor:
         mcov = self.measure_covariance({}, num_groups=1, num_times=1, _ignore_input=True)[0, 0]
-        measure_var = mcov.diagonal(dim1=-2, dim2=-1).unbind()
+        measure_var = list(mcov.diagonal(dim1=-2, dim2=-1).unbind())
+        for idx in self.measure_covariance.empty_idx:
+            measure_var[idx] = torch.ones_like(measure_var[idx])  # empty measures have no variance, so set to 1
 
         multi = [
             measure_var[self.measures.index(process.measure)].expand(process.rank).sqrt()
             for process in self.processes.values()
         ]
-        for idx in self.measure_covariance.empty_idx:
-            multi[idx] = torch.ones_like(multi[idx])  # empty measures have no variance, so set to 1
+
         multi = torch.cat(multi)
         if (multi <= 0).any():
            raise RuntimeError(f"measure-cov diagonal is not positive:{measure_var}")
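
This also fixes an indexing bug: ``empty_idx`` indexes measures, but the old code applied it to ``multi``, which has one entry per process. Substituting ones into the per-measure ``measure_var`` before ``multi`` is built means empty measures no longer leave zeros behind that the ``(multi <= 0)`` check would reject.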
@@ -787,19 +801,12 @@ def simulate(self,
         )
 
 
+def default_get_loss(pred: 'Predictions', y: torch.Tensor, **kwargs) -> torch.Tensor:
+    return -pred.log_prob(y, **kwargs).mean()
+
+
 class _OptimizerClosure:
-    """
-    closure = _OptimizerClosure(
-        ss_model=self,
-        y=y,
-        get_loss=get_loss,
-        prog=prog,
-        callable_kwargs=callable_kwargs,
-        optimizer=optimizer,
-        stopping=stopping,
-        kwargs=kwargs,
-    )
-    """
+
     def __init__(self,
                  ss_model: StateSpaceModel,
                  y: torch.Tensor,
0 commit comments

Comments
 (0)