 
 from torchcast.internals.batch_design import TransitionModel, MeasurementModel, MeasureFun
 from torchcast.internals.hessian import hessian
-from torchcast.internals.utils import repeat, true1d_idx, get_nan_groups, mask_mats, get_meshgrids
+from torchcast.internals.utils import repeat, true1d_idx, get_nan_groups, get_meshgrids
 from torchcast.covariance import Covariance
 from torchcast.state_space.predictions import Predictions
 from torchcast.process.regression import Process
@@ -74,11 +74,6 @@ def __init__(self,
         else:
             self.dt_unit = process.dt_unit
 
-    # @property
-    # def is_nonlinear(self) -> bool:
-    #     return any(not p.linear_measurement for p in self.processes.values()) or self.measure_funs
-
-    @torch.jit.ignore()
     def forward(self,
                 y: Optional[torch.Tensor] = None,
                 n_step: Union[int, float] = 1,
@@ -88,7 +83,6 @@ def forward(self,
                 every_step: bool = True,
                 include_updates_in_output: bool = False,
                 simulate: Optional[int] = None,
-                last_measured_per_group: Optional[torch.Tensor] = None,
                 prediction_kwargs: Optional[dict] = None,
                 **kwargs) -> 'Predictions':
         """
@@ -119,14 +113,6 @@ def forward(self,
         :param include_updates_in_output: If False, only the ``n_step`` ahead predictions are included in the output.
           This means that we cannot use this output to generate the ``initial_state`` for subsequent forward-passes. Set
           to True to allow this -- False by default to reduce memory.
-        :param last_measured_per_group: This provides a method to reduce unused computations in training. On each call
-          to forward in training, you can supply to this argument a tensor indicating the last measured timestep for
-          each group in the batch (this can be computed with ``last_measured_per_group=batch.get_durations()``, where
-          ``batch`` is a :class:`TimeSeriesDataset`). In this case, predictions will not be generated after the
-          specified timestep for each group; these can be discarded in training because, without any measurements, they
-          wouldn't have been used in loss calculations anyway. Naturally this should never be set for
-          inference/forecasting. This will automatically be set when calling ``fit()``, but if you're instead using a
-          custom training loop, you can pass this manually.
         :param simulate: If specified, will generate `simulate` samples from the model.
         :param prediction_kwargs: A dictionary of kwargs to pass to initialize ``Predictions()``.
         :param kwargs: Further arguments passed to the `processes`. For example, the :class:`.LinearModel` expects an
@@ -176,8 +162,14 @@ def forward(self,
             out_timesteps=out_timesteps
         )
 
+        # used by fit() to reduce unneeded computations:
+        last_measured_per_group = kwargs.pop('last_measured_per_group', None)
         if last_measured_per_group is None:
             last_measured_per_group = torch.full((num_groups,), out_timesteps, dtype=torch.int, device=meanu.device)
+        nan_groups = kwargs.pop('nan_groups', None)
+        if nan_groups is None:
+            nan_groups = [None] * out_timesteps
+        # /
 
         # todo: update Covariance class to make this less hacky:
         mcov_kwargs = {}
@@ -242,6 +234,7 @@ def forward(self,
                 measured_mean=measured_mean,
                 measure_mat=measure_mat,
                 measure_cov=measure_covs[t],
+                nan_groups=nan_groups[t],
                 **{k: v[t] for k, v in update_kwargs.items()}
             )
             if self.adaptive_measure_var and t < len(measure_covs) - 1:
@@ -302,7 +295,13 @@ def forward(self,
             device=meanu.device,
             dtype=meanu.dtype
         )
-        preds = self._generate_predictions(preds, updates, measure_covs, measurement_model, **prediction_kwargs)
+        preds = self._generate_predictions(
+            preds=preds,
+            updates=updates,
+            measure_covs=measure_covs,
+            measurement_model=measurement_model,
+            **prediction_kwargs
+        )
         return preds.set_metadata(
             start_offsets=start_offsets if start_offsets is not None else np.zeros(num_groups, dtype='int'),
             dt_unit=self.dt_unit
@@ -357,9 +356,6 @@ def fit(self,
         if set_initial_values:
             self._set_initial_values(y, verbose=verbose > 1, **kwargs)
 
-        if not get_loss:
-            get_loss = lambda _pred, _y: -_pred.log_prob(_y).mean()
-
         _deprecated = {k: kwargs.pop(k) for k in ['tol', 'patience', 'max_iter'] if k in kwargs}
         _dmsg = f"The following are deprecated, use `stopping` arg instead:\n{set(_deprecated)}"
         if stopping is None:
@@ -381,6 +377,11 @@ def fit(self,
 
         kwargs = self._prepare_fit_kwargs(y, **kwargs)
 
+        if get_loss is None:
+            # precompute nan-groups instead of doing it on each call to log_prob:
+            nan_groups_flat = get_nan_groups(torch.isnan(y).reshape(-1, y.shape[-1]))
+            get_loss = lambda _pred, _y: -_pred.log_prob(_y, nan_groups_flat=nan_groups_flat).mean()
+
         closure = _OptimizerClosure(
             ss_model=self,
             y=y,
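One subtlety worth spelling out in the added default loss: ``log_prob`` scores every (group, time) cell, so the 3D nan-mask can be flattened to 2D rows and grouped once per ``fit()`` call, instead of regrouping on every iteration. A quick illustration of the reshape (pure torch, no torchcast needed):

```python
import torch

y = torch.randn(2, 3, 2)                    # (groups, times, measures)
y[0, 1, 0] = float('nan')
flat_mask = torch.isnan(y).reshape(-1, y.shape[-1])
print(flat_mask.shape)                      # torch.Size([6, 2]): one row per cell
```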
@@ -415,21 +416,25 @@ def is_nonlinear(self) -> bool:
         return any(not p.linear_measurement for p in self.processes.values()) or self.measure_funs
 
     def _prepare_fit_kwargs(self, y: torch.Tensor, **kwargs) -> dict:
-        mc_samples = kwargs.pop('mc_samples', None)
+        # precompute nan-groups for forward pass
+        isnan = torch.isnan(y)
+        kwargs['nan_groups'] = [get_nan_groups(isnan_t) for isnan_t in isnan.unbind(1)]
 
+        #
+        prediction_kwargs = kwargs.pop('prediction_kwargs', None) or {}
+        # monte-carlo for Predictions.log_prob:
+        mc_samples = kwargs.pop('mc_samples', None)
         if self.is_nonlinear and not mc_samples:
             raise ValueError("Nonlinear state-space models require `mc_samples` to be set.")
-
         if mc_samples:
-            prediction_kwargs = kwargs.pop('prediction_kwargs', None) or {}
             if 'mc_white_noise' not in prediction_kwargs:
                 emmat_rank = MeasurementModel.get_extended_mmat_rank(self.processes.values(), self.measures)
                 prediction_kwargs['mc_white_noise'] = torch.randn(
                     (mc_samples, emmat_rank),
                     device=y.device,
                     dtype=y.dtype
                 )
-            kwargs['prediction_kwargs'] = prediction_kwargs
+        kwargs['prediction_kwargs'] = prediction_kwargs
 
         # see `last_measured_per_group` in forward docstring
         # todo: duplicate code in ``TimeSeriesDataset.get_durations()``
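The ``nan_groups`` entry built here is a per-timestep list, matching the ``[None] * out_timesteps`` fallback in ``forward()``: ``isnan.unbind(1)`` yields one (groups, measures) mask per timestep. A small shape check of that contract (pure torch):

```python
import torch

y = torch.randn(4, 5, 2)                      # (groups, times, measures)
isnan = torch.isnan(y)
# _prepare_fit_kwargs builds: [get_nan_groups(m) for m in isnan.unbind(1)]
per_step_masks = isnan.unbind(1)
assert len(per_step_masks) == y.shape[1]      # one entry per timestep
assert per_step_masks[0].shape == (4, 2)      # a (groups, measures) mask
```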
@@ -447,6 +452,7 @@ def _generate_predictions(self,
                               updates: Optional[tuple[list[torch.Tensor], list[torch.Tensor]]],
                               measure_covs: torch.Tensor,
                               measurement_model: 'MeasurementModel',
+                              nan_groups: Optional[List[Sequence[tuple[torch.Tensor, Optional[torch.Tensor]]]]] = None,
                               mc_white_noise: Optional[torch.Tensor] = None,
                               **kwargs
                               ) -> 'Predictions':
@@ -523,41 +529,48 @@ def _update_step_with_nans(self,
                                measured_mean: torch.Tensor,
                                measure_mat: torch.Tensor,
                                measure_cov: torch.Tensor,
+                               nan_groups: Optional[Sequence[tuple[torch.Tensor, Optional[torch.Tensor]]]] = None,
                                **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
-        isnan = torch.isnan(input)
-        if isnan.all():
-            return mean, cov
-        if isnan.any():
-            new_mean = mean.clone()
-            new_cov = cov.clone()
-            for groups, val_idx in get_nan_groups(isnan):
-                masked = self._mask_mats(
-                    groups,
-                    val_idx,
+        if nan_groups is None:
+            nan_groups = get_nan_groups(torch.isnan(input))
+        if len(nan_groups) == 1:
+            group_idx, val_idx = nan_groups[0]
+            if len(group_idx) == len(input) and val_idx is None:
+                # no nans, no masking:
+                return self._update_step(
                     input=input,
+                    mean=mean,
+                    cov=cov,
                     measured_mean=measured_mean,
                     measure_mat=measure_mat,
                     measure_cov=measure_cov,
                     **kwargs
                 )
-                new_mean[groups], new_cov[groups] = self._update_step(
-                    mean=mean[groups],
-                    cov=cov[groups],
-                    **masked,
-                    **{k: v for k, v in kwargs.items() if k not in masked}
-                )
-            return new_mean, new_cov
-        else:
-            return self._update_step(
+        elif not len(nan_groups):
+            # all nans, nothing to do:
+            return mean, cov
+
+        new_mean = mean.clone()
+        new_cov = cov.clone()
+        for groups, val_idx in nan_groups:
+            masked = self._mask_mats(
+                groups,
+                val_idx,
                 input=input,
-                mean=mean,
-                cov=cov,
                 measured_mean=measured_mean,
                 measure_mat=measure_mat,
                 measure_cov=measure_cov,
                 **kwargs
             )
 
+            new_mean[groups], new_cov[groups] = self._update_step(
+                mean=mean[groups],
+                cov=cov[groups],
+                **masked,
+                **{k: v for k, v in kwargs.items() if k not in masked}
+            )
+        return new_mean, new_cov
+
     def _mask_mats(self,
                    groups: torch.Tensor,
                    val_idx: Optional[torch.Tensor],
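For readers new to ``get_nan_groups``: from how it is consumed above (``val_idx is None`` signals a fully observed pattern, and an empty list signals all-nan input), it evidently groups batch rows by their nan-pattern, so masking happens once per pattern rather than once per row. A behavioral sketch under those assumptions, not the library's actual implementation:

```python
import torch

def get_nan_groups_sketch(isnan: torch.Tensor):
    """Group rows of a (groups, measures) nan-mask by shared nan-pattern."""
    out = []
    for pattern in isnan.unique(dim=0):
        if pattern.all():
            continue  # all-nan rows contribute no update at all
        group_idx = (isnan == pattern).all(1).nonzero().squeeze(-1)
        # None signals "every measure observed, no masking needed":
        val_idx = None if not pattern.any() else (~pattern).nonzero().squeeze(-1)
        out.append((group_idx, val_idx))
    return out

mask = torch.tensor([[False, False], [False, True], [False, True], [True, True]])
for g, v in get_nan_groups_sketch(mask):
    print(g.tolist(), v if v is None else v.tolist())
# [0] None
# [1, 2] [0]
```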
@@ -677,14 +690,15 @@ def state_rank(self) -> int:
 
     def _get_measure_scaling(self) -> torch.Tensor:
         mcov = self.measure_covariance({}, num_groups=1, num_times=1, _ignore_input=True)[0, 0]
-        measure_var = mcov.diagonal(dim1=-2, dim2=-1).unbind()
+        measure_var = list(mcov.diagonal(dim1=-2, dim2=-1).unbind())
+        for idx in self.measure_covariance.empty_idx:
+            measure_var[idx] = torch.ones_like(measure_var[idx])  # empty measures have no variance, so set to 1
 
         multi = [
             measure_var[self.measures.index(process.measure)].expand(process.rank).sqrt()
             for process in self.processes.values()
         ]
-        for idx in self.measure_covariance.empty_idx:
-            multi[idx] = torch.ones_like(multi[idx])  # empty measures have no variance, so set to 1
+
         multi = torch.cat(multi)
         if (multi <= 0).any():
             raise RuntimeError(f"measure-cov diagonal is not positive:{measure_var}")
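Relocating the ``empty_idx`` loop from ``multi`` to ``measure_var`` reads like a genuine fix, not just a tidy-up: ``empty_idx`` indexes measures, while ``multi`` has one entry per process, so the old loop could scale the wrong entry whenever processes and measures aren't 1:1. A toy illustration of the mismatch (all names hypothetical):

```python
# two processes share one measure, so the per-process and per-measure
# lists have different lengths, and a measure index can't safely index `multi`
measures = ['y']
processes = [('trend', 'y'), ('season', 'y')]
measure_var = [1.0]  # per measure: a safe target for empty_idx
multi = [measure_var[measures.index(m)] for _, m in processes]  # per process
assert len(multi) == 2 and len(measure_var) == 1
```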
@@ -787,19 +801,12 @@ def simulate(self,
         )
 
 
+def default_get_loss(pred: 'Predictions', y: torch.Tensor, **kwargs) -> torch.Tensor:
+    return -pred.log_prob(y, **kwargs).mean()
+
+
 class _OptimizerClosure:
-    """
-    closure = _OptimizerClosure(
-        ss_model=self,
-        y=y,
-        get_loss=get_loss,
-        prog=prog,
-        callable_kwargs=callable_kwargs,
-        optimizer=optimizer,
-        stopping=stopping,
-        kwargs=kwargs,
-    )
-    """
+
     def __init__(self,
                  ss_model: StateSpaceModel,
                  y: torch.Tensor,
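The new module-level ``default_get_loss`` also makes it easy to wrap the default when passing a custom ``get_loss`` to ``fit()``. A hedged sketch; the import path and the ridge penalty are assumptions, only the ``(pred, y)`` call signature comes from this diff:

```python
from torchcast.state_space.base import default_get_loss  # assumed module path

def make_get_loss(model):
    def get_loss(pred, y):
        nll = default_get_loss(pred, y)          # -pred.log_prob(y).mean()
        ridge = sum((p ** 2).sum() for p in model.parameters())
        return nll + 0.01 * ridge
    return get_loss

# usage (hypothetical): model.fit(y, get_loss=make_get_loss(model))
```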