55from typing import Sequence , TYPE_CHECKING , Optional , Union
66
77from torchcast .covariance import Covariance
8- from torchcast .internals .utils import get_meshgrids
98from torchcast .kalman_filter import KalmanFilter
109from torchcast .state_space import Predictions
1110from torchcast .internals .batch_design import MeasurementModel , Sigmoid
@@ -101,20 +100,21 @@ def _generate_predictions(self,
101100 updates = updates ,
102101 mc_white_noise = mc_white_noise ,
103102 num_obs = num_obs ,
104- observed_counts = observed_counts
103+ observed_counts = observed_counts ,
105104 )
106105
107106 def _mask_mats (self ,
108107 groups : torch .Tensor ,
109- val_idx : Optional [torch .Tensor ],
108+ masks : Optional [tuple [ torch .Tensor , torch . Tensor , torch . Tensor ] ],
110109 binary_idx : Optional [Sequence [int ]] = None ,
111110 ** kwargs ) -> dict :
112- out = super ()._mask_mats (groups , val_idx , ** kwargs )
113- if val_idx is None or binary_idx is None :
111+ out = super ()._mask_mats (groups , masks , ** kwargs )
112+ if masks is None or binary_idx is None :
114113 return out
114+ val_idx = masks [0 ]
115115 out ['binary_idx' ] = [i for i in binary_idx if i in val_idx ]
116116 _binary_subset_idx = torch .tensor ([i1 for i1 , i2 in enumerate (binary_idx ) if i2 in val_idx ], dtype = torch .long )
117- m1d , _ = get_meshgrids (groups , _binary_subset_idx )
117+ m1d = torch . meshgrid (groups , _binary_subset_idx , indexing = 'ij' )
118118 out ['num_obs' ] = kwargs ['num_obs' ][m1d ]
119119 return out
120120
@@ -271,15 +271,13 @@ def __init__(self,
271271 measure_covs : Union [Sequence [torch .Tensor ], torch .Tensor ],
272272 num_obs : Sequence [torch .Tensor ],
273273 observed_counts : Optional [bool ] = None ,
274- updates : Optional [tuple [torch .Tensor , torch .Tensor ]] = None ,
275- mc_white_noise : Optional [torch .Tensor ] = None ):
274+ ** kwargs ):
276275
277276 super ().__init__ (
278277 measurement_model = measurement_model ,
279278 states = states ,
280279 measure_covs = measure_covs ,
281- updates = updates ,
282- mc_white_noise = mc_white_noise
280+ ** kwargs
283281 )
284282
285283 self .observed_counts = observed_counts
@@ -381,17 +379,17 @@ def _get_posterior_predict_samples(self) -> torch.Tensor:
381379 return samples
382380
383381
384- def main (num_groups : int = 100 , num_timesteps : int = 100 , bias : float = - 1 , prop_common : float = 1. ):
385- from torchcast .process import LocalLevel
382+ def main (num_groups : int = 50 , num_timesteps : int = 365 , bias : float = - 1 , prop_common : float = 1. ):
383+ from torchcast .process import LocalLevel , Season
386384 from torchcast .utils import TimeSeriesDataset
387385 from scipy .special import expit
388386 import pandas as pd
389387 from plotnine import geom_line , aes , ggtitle
390388 torch .manual_seed (1234 )
391389
392390 TOTAL_COUNT = 4
393- measures = ['dim1' , 'dim2' , 'dim3' ]
394- binary_measures = ['dim1' ]
391+ measures = ['dim1' , 'dim2' ]
392+ binary_measures = []
395393 latent_common = torch .cumsum (.05 * torch .randn ((num_groups , num_timesteps , 1 )), dim = 1 )
396394 latent_ind = torch .cumsum (.05 * torch .randn ((num_groups , num_timesteps , len (measures ))), dim = 1 )
397395 assert 0 <= prop_common <= 1
@@ -424,22 +422,29 @@ def main(num_groups: int = 100, num_timesteps: int = 100, bias: float = -1, prop
424422 )
425423
426424 bf = BinomialFilter (
427- processes = [LocalLevel (id = f'level_{ m } ' , measure = m ) for m in measures ],
425+ processes = [LocalLevel (id = f'level_{ m } ' , measure = m ) for m in measures ]
426+ + [Season (id = f'season_{ m } ' , measure = m , dt_unit = 'D' , period = 7 , K = 2 ) for m in measures ],
428427 measures = measures ,
429428 binary_measures = binary_measures ,
430429 observed_counts = False
431430 )
432431
433432 y = dataset .tensors [0 ]
434433 bf .fit (y ,
435- stopping = {'monitor_params' : True },
436- num_obs = TOTAL_COUNT ,
437- mc_samples = 32 )
434+ stopping = {
435+ # 'max_iter': 10
436+ # 'monitor_params': True
437+ },
438+ start_offsets = dataset .start_offsets ,
439+ mc_samples = 32
440+ )
438441 _kwargs = {}
439- if TOTAL_COUNT != 1 :
440- _kwargs ['num_obs' ] = TOTAL_COUNT
442+ # if TOTAL_COUNT != 1:
443+ # _kwargs['num_obs'] = TOTAL_COUNT
441444 preds = bf (
442- dataset .tensors [0 ], ** _kwargs ,
445+ dataset .tensors [0 ],
446+ start_offsets = dataset .start_offsets ,
447+ ** _kwargs ,
443448 )
444449 df_preds = preds .to_dataframe (dataset )
445450 if bf .observed_counts :
@@ -458,8 +463,8 @@ def main(num_groups: int = 100, num_timesteps: int = 100, bias: float = -1, prop
458463 + geom_line (aes (y = 'latent' ), color = 'purple' )
459464 + ggtitle (g )
460465 ).show ()
461- # preds._white_noise = torch.zeros((1, len(binary_measures)))
462- # print(preds.log_prob(y).mean())
466+ preds ._white_noise = torch .zeros ((1 , len (binary_measures )))
467+ print (preds .log_prob (y ).mean ())
463468
464469
465470if __name__ == '__main__' :
0 commit comments