Commit 44b5dd3

speedups and bugfixes

1 parent 3eb7965 commit 44b5dd3

File tree

9 files changed: +165 -164 lines changed

tests/test_kalman_filter.py

Lines changed: 4 additions & 4 deletions
@@ -28,17 +28,17 @@ def test_nans(self, ndim: int = 1, n_step: int = 1):
         data[2, 2, 0] = float('nan')

         # test critical helper fun:
-        get_nan_groups2 = torch.jit.script(get_nan_groups)
         nan_groups = {2}
         if ndim > 1:
             nan_groups.add(0)
         for t in range(ntimes):
-            for group_idx, valid_idx in get_nan_groups2(torch.isnan(data[:, t])):
+            for group_idx, masks in get_nan_groups(torch.isnan(data[:, t])):
                 if t == 2:
-                    if valid_idx is None:
+                    if masks is None:
                         self.assertEqual(len(group_idx), data.shape[0] - len(nan_groups))
                         self.assertFalse(bool(set(group_idx.tolist()).intersection(nan_groups)))
                     else:
+                        valid_idx, m1d, m2d = masks
                         self.assertLess(len(valid_idx), ndim)
                         self.assertGreater(len(valid_idx), 0)
                         if len(valid_idx) == 1:
@@ -52,7 +52,7 @@ def test_nans(self, ndim: int = 1, n_step: int = 1):
                             self.assertSetEqual(set(valid_idx.tolist()), {1, 2})
                             self.assertSetEqual(set(group_idx.tolist()), {2})
                 else:
-                    self.assertIsNone(valid_idx)
+                    self.assertIsNone(masks)

         # test `update`
         # TODO: measure dim vs. state-dim
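
For context on the test change above: `get_nan_groups` now returns `(group_idx, None)` when every measure is valid for a group, and `(group_idx, (valid_idx, m1d, m2d))` otherwise, so callers unpack a masks tuple rather than a bare `valid_idx`. A minimal sketch of the consumption pattern, using a toy tensor rather than the test-suite fixtures:

    import torch
    from torchcast.internals.utils import get_nan_groups

    data = torch.randn(4, 3)           # 4 groups x 3 measures (toy)
    data[2, 0] = float('nan')          # group 2 is missing its first measure

    for group_idx, masks in get_nan_groups(torch.isnan(data)):
        if masks is None:
            print(group_idx.tolist(), '-> fully observed')
        else:
            valid_idx, m1d, m2d = masks   # valid indices plus precomputed meshgrids
            print(group_idx.tolist(), '-> observed measures:', valid_idx.tolist())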

torchcast/exp_smooth/exp_smooth.py

Lines changed: 12 additions & 11 deletions
@@ -10,7 +10,7 @@

 from torchcast.exp_smooth.smoothing_matrix import SmoothingMatrix
 from torchcast.covariance import Covariance
-from torchcast.internals.utils import update_tensor, get_meshgrids, transpose_last_dims
+from torchcast.internals.utils import update_tensor, transpose_last_dims
 from torchcast.process import Process
 from torchcast.state_space import StateSpaceModel

@@ -43,14 +43,14 @@ def initial_covariance(self, inputs: dict, num_groups: int, num_times: int, _ign

     def _mask_mats(self,
                    groups: torch.Tensor,
-                   val_idx: Optional[torch.Tensor],
+                   masks: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
                    **kwargs) -> dict[str, torch.Tensor]:
-        out = super()._mask_mats(groups, val_idx, **kwargs)
-        if val_idx is None:
+        out = super()._mask_mats(groups, masks, **kwargs)
+        if masks is None:
             return out
-        m1d, _ = get_meshgrids(groups, val_idx)
+        val_id, m1d, m2d = masks
         Kt = transpose_last_dims(kwargs['K'])
-        out['K'] = Kt[m1d]  # K is always a 2D matrix, so we can use m1d
+        out['K'] = transpose_last_dims(Kt[m1d])
         return out

     def _parse_kwargs(self,
@@ -70,13 +70,14 @@ def _parse_kwargs(self,
         if self.smoothing_matrix.expected_kwargs:
             smat_kwargs = {k: kwargs[k] for k in self.smoothing_matrix.expected_kwargs}
             used_keys |= set(smat_kwargs)
-        Ks = self.smoothing_matrix(smat_kwargs, num_groups=num_groups, num_times=num_timesteps)
-        update_kwargs['K'] = Ks.unbind(1)
-
-        if self.smoothing_matrix.expected_kwargs or self.measure_covariance.expected_kwargs:
+        if smat_kwargs:
+            Ks = self.smoothing_matrix(smat_kwargs, num_groups=num_groups, num_times=num_timesteps)
+            update_kwargs['K'] = Ks.unbind(1)
             predict_kwargs['cov1step'] = Ks @ torch.stack(measure_covs, 1) @ Ks.transpose(-1, -2)
         else:
-            K1 = update_kwargs['K'][0]
+            # faster if not time-varying:
+            K1 = self.smoothing_matrix(smat_kwargs, num_groups=num_groups, num_times=1).squeeze(1)
+            update_kwargs['K'] = [K1] * num_timesteps
             measure_cov = measure_covs[0]
             cov1step = K1 @ measure_cov @ K1.transpose(-1, -2)
             predict_kwargs['cov1step'] = [cov1step] * num_timesteps
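
The `_mask_mats` change above also fixes a bug: the old code returned `Kt[m1d]`, which is still in transposed (measure, state) layout, while the fix transposes back after indexing. A sketch of that pattern with illustrative shapes: it assumes K is (groups, state_rank, n_measures), consistent with `cov1step = K1 @ measure_cov @ K1.transpose(-1, -2)` above, and the local `transpose_last_dims` stands in for the library helper of the same name:

    import torch

    def transpose_last_dims(x: torch.Tensor) -> torch.Tensor:
        # stand-in for torchcast.internals.utils.transpose_last_dims
        return x.transpose(-1, -2)

    G, S, M = 5, 3, 2                     # groups, state-rank, measures (toy sizes)
    K = torch.randn(G, S, M)
    groups = torch.tensor([0, 2])         # groups sharing this nan-pattern
    valid_idx = torch.tensor([1])         # only measure 1 observed

    m1d = torch.meshgrid(groups, valid_idx, indexing='ij')
    Kt = transpose_last_dims(K)               # (G, M, S): measures become rows
    K_masked = transpose_last_dims(Kt[m1d])   # (2, S, 1): back to (state, measure)
    assert K_masked.shape == (2, S, 1)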

torchcast/internals/batch_design/measurement_model.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def _adjust_measure_mat(self,
             # apply measure-wide adjustment
             measure_mat[i] = self.measure_funs[measure].adjust_measure_mat(measure_mat[i], measured_mean[i])

-        return torch.stack(measure_mat, dim=-1)
+        return torch.stack(measure_mat, dim=-2)

     @cached_property
     def measure2idx(self) -> dict[str, int]:
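
The one-character change above fixes the orientation of the stacked measurement matrix: stacking the per-measure rows on `dim=-1` makes measures the columns, while `dim=-2` makes them the rows. A sketch with assumed toy shapes (the library's actual tensors carry extra leading dims):

    import torch

    num_groups, state_rank = 4, 3
    rows = [torch.randn(num_groups, state_rank) for _ in range(2)]   # one row per measure

    H_wrong = torch.stack(rows, dim=-1)   # (4, 3, 2): measures land in the columns
    H_right = torch.stack(rows, dim=-2)   # (4, 2, 3): one row per measure
    assert H_right.shape == (num_groups, 2, state_rank)
    assert torch.equal(H_wrong.transpose(-1, -2), H_right)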

torchcast/internals/batch_design/transition_model.py

Lines changed: 3 additions & 8 deletions
@@ -20,23 +20,18 @@ def __init__(self,
         )
         self.measures = measures

-        zeros = torch.zeros(
+        F = torch.zeros(
             (self.num_groups, self.num_timesteps, self.state_rank, self.state_rank),
             device=self.device,
             dtype=self.dtype
         )
-        F = []
         for pid, process in self.processes.items():
             if process.linear_transition:
                 pidx = self.process2slice[pid]
-                # note: as in other parts, assuming autograd makes it more efficient to create clones then sum vs.
-                # repeated masks on the same tensor. should verify that
-                thisF = zeros.clone()
-                thisF[:, :, pidx, pidx] = process.get_transition_matrix()
-                F.append(thisF)
+                F[:, :, pidx, pidx] = process.get_transition_matrix()
             else:
                 raise NotImplementedError
-        self._transition_mats = torch.stack(F, dim=0).sum(0)
+        self._transition_mats = F

     @cached_property
     def transition_mats(self) -> Sequence[torch.Tensor]:
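
The rewrite above drops the clone-a-zeros-template-per-process-then-sum pattern (and its "should verify that" caveat) in favor of writing each process's block directly into one zeros tensor. A sketch of the new pattern, with hypothetical slices standing in for `process2slice` and the per-process transition matrices:

    import torch

    G, T, R = 2, 5, 4                      # groups, timesteps, state-rank (toy sizes)
    F = torch.zeros(G, T, R, R)

    # hypothetical per-process slices and transition blocks
    blocks = {slice(0, 2): torch.eye(2), slice(2, 4): 0.9 * torch.eye(2)}
    for pidx, tmat in blocks.items():
        F[:, :, pidx, pidx] = tmat         # block-diagonal write in place, no clones

    assert torch.equal(F[0, 0, :2, :2], torch.eye(2))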

torchcast/internals/utils.py

Lines changed: 16 additions & 39 deletions
@@ -12,40 +12,6 @@ def get_subclasses(cls: Type) -> Iterable[Type]:
         yield subclass


-@functools.lru_cache(maxsize=100)
-def get_meshgrids(groups: torch.Tensor,
-                  val_idx: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
-    """
-    Returns meshgrids for the given groups and val_idx.
-    """
-    m1d = torch.meshgrid(groups, val_idx, indexing='ij')
-    m2d = torch.meshgrid(groups, val_idx, val_idx, indexing='ij')
-    return m1d, m2d
-
-
-def mask_mats(groups: torch.Tensor,
-              val_idx: Optional[torch.Tensor],
-              mats: Sequence[tuple[str, torch.Tensor, Collection[int]]]) -> dict[str, torch.Tensor]:
-    out = {}
-    if val_idx is None:
-        for nm, mat, _ in mats:
-            out[nm] = mat[groups]
-    else:
-        m1d, m2d = get_meshgrids(groups, val_idx)
-        for nm, mat, dim in mats:
-            dim = set(dim)
-            if dim == {-2}:
-                mat = transpose_last_dims(mat)
-                out[nm] = transpose_last_dims(mat[m1d])
-            elif dim == {-1}:
-                out[nm] = mat[m1d]
-            elif dim == {-2, -1}:
-                out[nm] = mat[m2d]
-            else:
-                raise ValueError(f"Invalid dim ({dim}), must be 0, 1, or 2")
-    return out
-
-
 def normalize_index(index: tuple) -> tuple:
     # Special-case early check for the batched pattern
     if isinstance(index, tuple) and _is_special_batched_pattern(index):
@@ -182,30 +148,41 @@ def transpose_last_dims(x: torch.Tensor) -> torch.Tensor:
     return x.permute(*args)


-def get_nan_groups(isnan: torch.Tensor) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+def get_nan_groups(isnan: torch.Tensor) -> List[Tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]]:
     """
     Iterable of (group_idx, valid_idx) tuples that can be passed to torch.meshgrid. If no valid, then not returned; if
     all valid then (group_idx, None) is returned; can skip call to meshgrid.
     """
     assert len(isnan.shape) == 2
     state_dim = isnan.shape[-1]
-    out: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = []
+
+    out = []
     if state_dim == 1:
         # shortcut for univariate
         group_idx = (~isnan.squeeze(-1)).nonzero().view(-1)
         out.append((group_idx, None))
         return out
-    for nan_combo in torch.unique(isnan, dim=0):
+
+    nan_combos = torch.unique(isnan, dim=0)
+    if len(nan_combos) == 1 and nan_combos[0].sum() == 0:
+        # shortcut for no nans
+        out.append((torch.arange(isnan.shape[0]), None))
+        return out
+
+    for nan_combo in nan_combos:
         num_nan = nan_combo.sum()
         if num_nan < state_dim:
             c1 = (isnan * nan_combo[None, :]).sum(1) == num_nan
             c2 = (~isnan * ~nan_combo[None, :]).sum(1) == (state_dim - num_nan)
             group_idx = (c1 & c2).nonzero().view(-1)
             if num_nan == 0:
-                valid_idx = None
+                out.append((group_idx, None))
             else:
                 valid_idx = (~nan_combo).nonzero().view(-1)
-            out.append((group_idx, valid_idx))
+                m1d = torch.meshgrid(group_idx, valid_idx, indexing='ij')
+                m2d = torch.meshgrid(group_idx, valid_idx, valid_idx, indexing='ij')
                masks = (valid_idx, m1d, m2d)
+                out.append((group_idx, masks))
     return out

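`get_meshgrids` and `mask_mats` are deleted because `get_nan_groups` now builds the meshgrids itself and ships them in the masks tuple. For reference, what the two meshgrids index, shown on toy tensors: `m1d` subsets tensors whose trailing dim is per-measure, while `m2d` pulls the square valid-by-valid block out of a per-group matrix:

    import torch

    group_idx = torch.tensor([0, 3])       # groups sharing a nan-pattern
    valid_idx = torch.tensor([1, 2])       # measures observed for those groups

    m1d = torch.meshgrid(group_idx, valid_idx, indexing='ij')
    m2d = torch.meshgrid(group_idx, valid_idx, valid_idx, indexing='ij')

    mean = torch.randn(5, 3)               # toy (groups, measures)
    cov = torch.randn(5, 3, 3)             # toy (groups, measures, measures)

    assert mean[m1d].shape == (2, 2)       # selected groups x valid measures
    assert cov[m2d].shape == (2, 2, 2)     # valid-by-valid block per selected group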

torchcast/kalman_filter/binomial_filter.py

Lines changed: 28 additions & 23 deletions
@@ -5,7 +5,6 @@
 from typing import Sequence, TYPE_CHECKING, Optional, Union

 from torchcast.covariance import Covariance
-from torchcast.internals.utils import get_meshgrids
 from torchcast.kalman_filter import KalmanFilter
 from torchcast.state_space import Predictions
 from torchcast.internals.batch_design import MeasurementModel, Sigmoid
@@ -101,20 +100,21 @@ def _generate_predictions(self,
             updates=updates,
             mc_white_noise=mc_white_noise,
             num_obs=num_obs,
-            observed_counts=observed_counts
+            observed_counts=observed_counts,
         )

     def _mask_mats(self,
                    groups: torch.Tensor,
-                   val_idx: Optional[torch.Tensor],
+                   masks: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
                    binary_idx: Optional[Sequence[int]] = None,
                    **kwargs) -> dict:
-        out = super()._mask_mats(groups, val_idx, **kwargs)
-        if val_idx is None or binary_idx is None:
+        out = super()._mask_mats(groups, masks, **kwargs)
+        if masks is None or binary_idx is None:
             return out
+        val_idx = masks[0]
         out['binary_idx'] = [i for i in binary_idx if i in val_idx]
         _binary_subset_idx = torch.tensor([i1 for i1, i2 in enumerate(binary_idx) if i2 in val_idx], dtype=torch.long)
-        m1d, _ = get_meshgrids(groups, _binary_subset_idx)
+        m1d = torch.meshgrid(groups, _binary_subset_idx, indexing='ij')
         out['num_obs'] = kwargs['num_obs'][m1d]
         return out

@@ -271,15 +271,13 @@ def __init__(self,
                  measure_covs: Union[Sequence[torch.Tensor], torch.Tensor],
                  num_obs: Sequence[torch.Tensor],
                  observed_counts: Optional[bool] = None,
-                 updates: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-                 mc_white_noise: Optional[torch.Tensor] = None):
+                 **kwargs):

         super().__init__(
             measurement_model=measurement_model,
             states=states,
             measure_covs=measure_covs,
-            updates=updates,
-            mc_white_noise=mc_white_noise
+            **kwargs
         )

         self.observed_counts = observed_counts
@@ -381,17 +379,17 @@ def _get_posterior_predict_samples(self) -> torch.Tensor:
         return samples


-def main(num_groups: int = 100, num_timesteps: int = 100, bias: float = -1, prop_common: float = 1.):
-    from torchcast.process import LocalLevel
+def main(num_groups: int = 50, num_timesteps: int = 365, bias: float = -1, prop_common: float = 1.):
+    from torchcast.process import LocalLevel, Season
     from torchcast.utils import TimeSeriesDataset
     from scipy.special import expit
     import pandas as pd
     from plotnine import geom_line, aes, ggtitle
     torch.manual_seed(1234)

     TOTAL_COUNT = 4
-    measures = ['dim1', 'dim2', 'dim3']
-    binary_measures = ['dim1']
+    measures = ['dim1', 'dim2']
+    binary_measures = []
     latent_common = torch.cumsum(.05 * torch.randn((num_groups, num_timesteps, 1)), dim=1)
     latent_ind = torch.cumsum(.05 * torch.randn((num_groups, num_timesteps, len(measures))), dim=1)
     assert 0 <= prop_common <= 1
@@ -424,22 +422,29 @@ def main(num_groups: int = 100, num_timesteps: int = 100, bias: float = -1, prop
     )

     bf = BinomialFilter(
-        processes=[LocalLevel(id=f'level_{m}', measure=m) for m in measures],
+        processes=[LocalLevel(id=f'level_{m}', measure=m) for m in measures]
+                  + [Season(id=f'season_{m}', measure=m, dt_unit='D', period=7, K=2) for m in measures],
         measures=measures,
         binary_measures=binary_measures,
         observed_counts=False
     )

     y = dataset.tensors[0]
     bf.fit(y,
-           stopping={'monitor_params': True},
-           num_obs=TOTAL_COUNT,
-           mc_samples=32)
+           stopping={
+               # 'max_iter': 10
+               # 'monitor_params': True
+           },
+           start_offsets=dataset.start_offsets,
+           mc_samples=32
+           )
     _kwargs = {}
-    if TOTAL_COUNT != 1:
-        _kwargs['num_obs'] = TOTAL_COUNT
+    # if TOTAL_COUNT != 1:
+    #     _kwargs['num_obs'] = TOTAL_COUNT
     preds = bf(
-        dataset.tensors[0], **_kwargs,
+        dataset.tensors[0],
+        start_offsets=dataset.start_offsets,
+        **_kwargs,
     )
     df_preds = preds.to_dataframe(dataset)
     if bf.observed_counts:
@@ -458,8 +463,8 @@ def main(num_groups: int = 100, num_timesteps: int = 100, bias: float = -1, prop
         + geom_line(aes(y='latent'), color='purple')
         + ggtitle(g)
     ).show()
-    # preds._white_noise = torch.zeros((1, len(binary_measures)))
-    # print(preds.log_prob(y).mean())
+    preds._white_noise = torch.zeros((1, len(binary_measures)))
+    print(preds.log_prob(y).mean())


 if __name__ == '__main__':
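
The `_mask_mats` override above re-maps `binary_idx` twice: once to the measure indices that survive masking, and once to their positions within the original `binary_idx`, which is the layout `num_obs` uses. A sketch with toy indices (not a fitted model):

    import torch

    binary_idx = [0, 2]                    # measures modeled as binomial (toy)
    val_idx = torch.tensor([1, 2])         # measures observed this timestep
    groups = torch.tensor([0, 3])

    kept = [i for i in binary_idx if i in val_idx]      # [2]: still-observed binary measures
    _subset = torch.tensor([i1 for i1, i2 in enumerate(binary_idx) if i2 in val_idx],
                           dtype=torch.long)            # [1]: their positions in binary_idx
    m1d = torch.meshgrid(groups, _subset, indexing='ij')

    num_obs = torch.randint(1, 5, (5, len(binary_idx)))   # toy (groups, binary measures)
    assert num_obs[m1d].shape == (2, 1)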

torchcast/kalman_filter/kalman_filter.py

Lines changed: 11 additions & 4 deletions
@@ -100,13 +100,20 @@ def _parse_kwargs(self,
         )

         # process-variance:
+        measure_scaling = torch.diag_embed(self._get_measure_scaling().unsqueeze(0))
         pcov_kwargs = {}
         if self.process_covariance.expected_kwargs:
             pcov_kwargs = {k: kwargs[k] for k in self.process_covariance.expected_kwargs}
             used_keys |= set(pcov_kwargs)
-        pcov_raw = self.process_covariance(pcov_kwargs, num_groups=num_groups, num_times=num_timesteps)
-        measure_scaling = torch.diag_embed(self._get_measure_scaling().unsqueeze(0).unsqueeze(0))
-        Qs = measure_scaling @ pcov_raw @ measure_scaling
-        predict_kwargs['Q'] = Qs.unbind(1)
+        if pcov_kwargs:
+            measure_scaling = measure_scaling.unsqueeze(0)
+            pcov_raw = self.process_covariance(pcov_kwargs, num_groups=num_groups, num_times=num_timesteps)
+            Qs = measure_scaling @ pcov_raw @ measure_scaling
+            predict_kwargs['Q'] = Qs.unbind(1)
+        else:
+            # faster if not time-varying
+            pcov_raw = self.process_covariance(pcov_kwargs, num_groups=num_groups, num_times=1)
+            Qs = measure_scaling @ pcov_raw.squeeze(1) @ measure_scaling
+            predict_kwargs['Q'] = [Qs] * num_timesteps

         return predict_kwargs, update_kwargs, used_keys
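
This is the same fast-path shape as in exp_smooth.py: when the process covariance takes no per-timestep inputs, build one Q and reuse it, rather than materializing a (groups, times, rank, rank) stack and unbinding it. A sketch with toy tensors (not the `Covariance` module API):

    import torch

    num_groups, num_timesteps, rank = 8, 200, 4                 # toy sizes
    scaling = torch.diag_embed(torch.rand(rank).unsqueeze(0))   # (1, rank, rank)
    pcov = torch.eye(rank).expand(num_groups, rank, rank)       # one covariance, no time dim

    Q = scaling @ pcov @ scaling           # built once
    Qs = [Q] * num_timesteps               # T references to one tensor,
                                           # vs. unbinding a (G, T, R, R) stack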
