diff --git a/tests/unit/models/gpflux/test_sampler.py b/tests/unit/models/gpflux/test_sampler.py index 001f1586b..032b1f491 100644 --- a/tests/unit/models/gpflux/test_sampler.py +++ b/tests/unit/models/gpflux/test_sampler.py @@ -24,7 +24,7 @@ from __future__ import annotations -from typing import Callable, Tuple +from typing import Callable, Sequence, Tuple from unittest.mock import patch import gpflow.kernels @@ -508,8 +508,15 @@ def test_dgp_decoupled_layer_update_updates( evals_1 = decoupled_layer(xs) - original_W = decoupled_layer._feature_functions.W.value().numpy() - original_b = decoupled_layer._feature_functions.b.value().numpy() + def get_values(x: tf.Variable | Sequence[tf.Variable]) -> Sequence[tf.Tensor]: + # weights and biases are either a single variable or a list of variables + if isinstance(x, tf.Variable): + x = [x] + return [x.value().numpy() for x in x] + + original_W = get_values(decoupled_layer._feature_functions.W) + original_b = get_values(decoupled_layer._feature_functions.b) + for _ in range(5): x_train = tf.random.uniform([20, 2], minval=-10.0, maxval=10.0, dtype=tf.float64) y_train = tf.random.normal([20, 1], dtype=tf.float64) @@ -522,9 +529,7 @@ def test_dgp_decoupled_layer_update_updates( npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(evals_1 - evals_new))) # Check that RFF weights change - npt.assert_array_less( - 1e-2, tf.reduce_sum(tf.abs(original_b - decoupled_layer._feature_functions.b)) - ) - npt.assert_array_less( - 1e-2, tf.reduce_sum(tf.abs(original_W - decoupled_layer._feature_functions.W)) - ) + for old_b, new_b in zip(original_b, get_values(decoupled_layer._feature_functions.b)): + npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(old_b - new_b))) + for old_W, new_W in zip(original_W, get_values(decoupled_layer._feature_functions.W)): + npt.assert_array_less(1e-2, tf.reduce_sum(tf.abs(old_W - new_W))) diff --git a/trieste/models/gpflow/sampler.py b/trieste/models/gpflow/sampler.py index f3521d695..6370bdf59 100644 --- a/trieste/models/gpflow/sampler.py +++ b/trieste/models/gpflow/sampler.py @@ -811,11 +811,11 @@ def resample(self) -> None: b.assign(self._bias_init(tf.shape(b), dtype=self._dtype)) if isinstance(self.W, tf.Variable): - self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), dtype=self._dtype)) + self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), self._dtype)) else: tf.debugging.Assert(isinstance(self.W, list), []) for W, k in zip(self.W, cycle(self.sub_kernels)): - W.assign(self._weights_init(k)(tf.shape(W), dtype=self._dtype)) + W.assign(self._weights_init(k)(tf.shape(W), self._dtype)) class ResampleableDecoupledFeatureFunctions(ResampleableRandomFourierFeatureFunctions): diff --git a/trieste/models/gpflux/sampler.py b/trieste/models/gpflux/sampler.py index b724beaf1..5d5bffdca 100644 --- a/trieste/models/gpflux/sampler.py +++ b/trieste/models/gpflux/sampler.py @@ -453,11 +453,11 @@ def resample(self) -> None: b.assign(self._bias_init(tf.shape(b), dtype=self._dtype)) if isinstance(self.W, tf.Variable): - self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), dtype=self._dtype)) + self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), self._dtype)) else: tf.debugging.Assert(isinstance(self.W, list), []) for W, k in zip(self.W, cycle(self.sub_kernels)): - W.assign(self._weights_init(k)(tf.shape(W), dtype=self._dtype)) + W.assign(self._weights_init(k)(tf.shape(W), self._dtype)) def __call__(self, x: TensorType) -> TensorType: # [N, D] -> [N, L + M] or [P, N, L + M] """