Skip to content

Commit 0dd5348

Browse files
JasonKChowfacebook-github-bot
authored andcommitted
derivativeGP gpu support (#444)
Summary: Pull Request resolved: #444 Add gpu support for derivative GP. I noticed that this model isn’t actually like a normal model that can show up in a live experiment with a config, but we should still make it work for GPU. I did most of that but it did require some pretty arcane shenanigans with overriding GPyTorch’s underlying handling of train_inputs. This in turn made me do some arcane mypy stuff. Differential Revision: D65515631
1 parent 1f21c9e commit 0dd5348

File tree

6 files changed

+59
-15
lines changed

6 files changed

+59
-15
lines changed

aepsych/kernels/pairwisekernel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ class PairwiseKernel(Kernel):
1616
"""
1717

1818
def __init__(
19-
self, latent_kernel: Kernel, is_partial_obs: bool=False, **kwargs
19+
self, latent_kernel: Kernel, is_partial_obs: bool = False, **kwargs
2020
) -> None:
2121
"""
22-
Args:
23-
latent_kernel (Kernel): The underlying kernel used to compute the covariance for the GP.
24-
is_partial_obs (bool): If the kernel should handle partial observations. Defaults to False.
22+
Args:
23+
latent_kernel (Kernel): The underlying kernel used to compute the covariance for the GP.
24+
is_partial_obs (bool): If the kernel should handle partial observations. Defaults to False.
2525
"""
2626
super(PairwiseKernel, self).__init__(**kwargs)
2727

@@ -40,11 +40,11 @@ def forward(
4040
x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
4141
x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
4242
diag (bool): Should the Kernel compute the whole covariance matrix or just the diagonal? Defaults to False.
43-
43+
4444
4545
Returns:
4646
torch.Tensor (or :class:`gpytorch.lazy.LazyTensor`) : A `b x n x m` or `n x m` tensor representing
47-
the covariance matrix between `x1` and `x2`.
47+
the covariance matrix between `x1` and `x2`.
4848
The exact size depends on the kernel's evaluation mode:
4949
* `full_covar`: `n x m` or `b x n x m`
5050
* `diag`: `n` or `b x n`

aepsych/kernels/rbf_partial_grad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ def forward(
3131
self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params: Any
3232
) -> torch.Tensor:
3333
"""Computes the covariance matrix between x1 and x2 based on the RBF
34-
34+
3535
Args:
3636
x1 (torch.Tensor): A `b x n x d` or `n x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
3737
x2 (torch.Tensor): A `b x m x d` or `m x d` tensor, where `d = 2k` and `k` is the dimension of the latent space.
3838
diag (bool): Should the Kernel compute the whole covariance matrix (False) or just the diagonal (True)? Defaults to False.
3939
40-
41-
40+
41+
4242
Returns:
4343
torch.Tensor: A `b x n x m` or `n x m` tensor representing the covariance matrix between `x1` and `x2`.
4444
The exact size depends on the kernel's evaluation mode:

aepsych/models/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class AEPsychMixin(GPyTorchModel):
116116

117117
extremum_solver = "Nelder-Mead"
118118
outcome_types: List[str] = []
119-
train_inputs: Optional[Tuple[torch.Tensor]]
119+
train_inputs: Optional[Tuple[torch.Tensor, ...]]
120120
train_targets: Optional[torch.Tensor]
121121

122122
@property
@@ -393,7 +393,7 @@ def p_below_threshold(
393393

394394

395395
class AEPsychModelDeviceMixin(AEPsychMixin):
396-
_train_inputs: Optional[Tuple[torch.Tensor]]
396+
_train_inputs: Optional[Tuple[torch.Tensor, ...]]
397397
_train_targets: Optional[torch.Tensor]
398398

399399
def set_train_data(self, inputs=None, targets=None, strict=False):
@@ -423,13 +423,17 @@ def device(self) -> torch.device:
423423
return torch.device("cpu")
424424

425425
@property
426-
def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
426+
def train_inputs(self) -> Optional[Tuple[torch.Tensor, ...]]:
427427
if self._train_inputs is None:
428428
return None
429429

430430
# makes sure the tensors are on the right device, move in place
431+
_train_inputs = []
431432
for input in self._train_inputs:
432-
input.to(self.device)
433+
_train_inputs.append(input.to(self.device))
434+
435+
_tuple_inputs: Tuple[torch.Tensor, ...] = tuple(_train_inputs)
436+
self._train_inputs = _tuple_inputs
433437

434438
return self._train_inputs
435439

aepsych/models/derivative_gp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
2424

2525

26-
class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel):
26+
class MixedDerivativeVariationalGP(
27+
gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel
28+
):
2729
"""A variational GP with mixed derivative observations.
2830
2931
For more on GPs with derivative observations, see e.g. Riihimaki & Vehtari 2010.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
# All rights reserved.
4+
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import torch
9+
from aepsych import Config, SequentialStrategy
10+
from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
11+
from botorch.fit import fit_gpytorch_mll
12+
from botorch.utils.testing import BotorchTestCase
13+
from gpytorch.likelihoods import BernoulliLikelihood
14+
from gpytorch.mlls.variational_elbo import VariationalELBO
15+
16+
17+
class TestDerivativeGP(BotorchTestCase):
18+
def test_MixedDerivativeVariationalGP_gpu(self):
19+
train_x = torch.cat(
20+
(torch.tensor([1.0, 2.0, 3.0, 4.0]).unsqueeze(1), torch.zeros(4, 1)), dim=1
21+
)
22+
train_y = torch.tensor([1.0, 2.0, 3.0, 4.0])
23+
m = MixedDerivativeVariationalGP(
24+
train_x=train_x,
25+
train_y=train_y,
26+
inducing_points=train_x,
27+
fixed_prior_mean=0.5,
28+
).cuda()
29+
30+
self.assertEqual(m.mean_module.constant.item(), 0.5)
31+
self.assertEqual(
32+
m.covar_module.base_kernel.raw_lengthscale.shape, torch.Size([1, 1])
33+
)
34+
mll = VariationalELBO(
35+
likelihood=BernoulliLikelihood(), model=m, num_data=train_y.numel()
36+
).cuda()
37+
mll = fit_gpytorch_mll(mll)
38+
test_x = torch.tensor([[1.0, 0], [3.0, 1.0]]).cuda()
39+
m(test_x)

tests_gpu/test_strategy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,5 @@ def test_gpu_no_model_generator_warn(self):
2727
)
2828

2929

30-
3130
if __name__ == "__main__":
3231
unittest.main()

0 commit comments

Comments
 (0)