Skip to content

Commit b1c9f04

Browse files
JasonKChowfacebook-github-bot
authored andcommitted
derivativeGP gpu support (facebookresearch#444)
Summary: Add gpu support for derivative GP. I noticed that this model isn’t actually like a normal model that can show up in a live experiment with a config, but we should still make it work for GPU. I did most of that but it did require some pretty arcane shenanigans with overriding GPyTorch’s underlying handling of train_inputs. This in turn made me do some arcane mypy stuff. Differential Revision: D65515631
1 parent 3b6bae0 commit b1c9f04

File tree

4 files changed

+53
-6
lines changed

4 files changed

+53
-6
lines changed

aepsych/means/constant_partial_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
2626
idx = input[..., -1].to(dtype=torch.long) > 0
2727
mean_fit = super(ConstantMeanPartialObsGrad, self).forward(input[..., ~idx, :])
2828
sz = mean_fit.shape[:-1] + torch.Size([input.shape[-2]])
29-
mean = torch.zeros(sz)
29+
mean = torch.zeros(sz).to(input)
3030
mean[~idx] = mean_fit
3131
return mean

aepsych/models/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class AEPsychMixin(GPyTorchModel):
116116

117117
extremum_solver = "Nelder-Mead"
118118
outcome_types: List[str] = []
119-
train_inputs: Optional[Tuple[torch.Tensor]]
119+
train_inputs: Optional[Tuple[torch.Tensor, ...]]
120120
train_targets: Optional[torch.Tensor]
121121

122122
@property
@@ -398,7 +398,7 @@ def p_below_threshold(
398398

399399

400400
class AEPsychModelDeviceMixin(AEPsychMixin):
401-
_train_inputs: Optional[Tuple[torch.Tensor]]
401+
_train_inputs: Optional[Tuple[torch.Tensor, ...]]
402402
_train_targets: Optional[torch.Tensor]
403403

404404
def set_train_data(self, inputs=None, targets=None, strict=False):
@@ -423,13 +423,17 @@ def device(self) -> torch.device:
423423
return next(self.parameters()).device
424424

425425
@property
426-
def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
426+
def train_inputs(self) -> Optional[Tuple[torch.Tensor, ...]]:
427427
if self._train_inputs is None:
428428
return None
429429

430430
# makes sure the tensors are on the right device, move in place
431+
_train_inputs = []
431432
for input in self._train_inputs:
432-
input.to(self.device)
433+
_train_inputs.append(input.to(self.device))
434+
435+
_tuple_inputs: Tuple[torch.Tensor, ...] = tuple(_train_inputs)
436+
self._train_inputs = _tuple_inputs
433437

434438
return self._train_inputs
435439

aepsych/models/derivative_gp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch
1414
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
1515
from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
16+
from aepsych.models.base import AEPsychModelDeviceMixin
1617
from botorch.models.gpytorch import GPyTorchModel
1718
from gpytorch.distributions import MultivariateNormal
1819
from gpytorch.kernels import Kernel
@@ -22,7 +23,9 @@
2223
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy
2324

2425

25-
class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, GPyTorchModel):
26+
class MixedDerivativeVariationalGP(
27+
gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel
28+
):
2629
"""A variational GP with mixed derivative observations.
2730
2831
For more on GPs with derivative observations, see e.g. Riihimaki & Vehtari 2010.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
# All rights reserved.
4+
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import torch
9+
from aepsych import Config, SequentialStrategy
10+
from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
11+
from botorch.fit import fit_gpytorch_mll
12+
from botorch.utils.testing import BotorchTestCase
13+
from gpytorch.likelihoods import BernoulliLikelihood
14+
from gpytorch.mlls.variational_elbo import VariationalELBO
15+
16+
17+
class TestDerivativeGP(BotorchTestCase):
18+
19+
def test_MixedDerivativeVariationalGP_gpu(self):
20+
train_x = torch.cat(
21+
(torch.tensor([1.0, 2.0, 3.0, 4.0]).unsqueeze(1), torch.zeros(4, 1)), dim=1
22+
)
23+
train_y = torch.tensor([1.0, 2.0, 3.0, 4.0])
24+
m = MixedDerivativeVariationalGP(
25+
train_x=train_x,
26+
train_y=train_y,
27+
inducing_points=train_x,
28+
fixed_prior_mean=0.5,
29+
).cuda()
30+
31+
self.assertEqual(m.mean_module.constant.item(), 0.5)
32+
self.assertEqual(
33+
m.covar_module.base_kernel.raw_lengthscale.shape, torch.Size([1, 1])
34+
)
35+
mll = VariationalELBO(
36+
likelihood=BernoulliLikelihood(), model=m, num_data=train_y.numel()
37+
).cuda()
38+
mll = fit_gpytorch_mll(mll)
39+
test_x = torch.tensor([[1.0, 0], [3.0, 1.0]]).cuda()
40+
m(test_x)

0 commit comments

Comments
 (0)