Skip to content

Commit

Permalink
Add a few simple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jan 3, 2025
1 parent e284c87 commit 7fd07e6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/unit/models/gpflow/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ def test_independent_reparametrization_sampler_sample_raises_for_negative_jitter
sampler.sample(tf.constant([[0.0]]), jitter=-1e-6)


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("var", [0.0, 0.1, 1.0])
def test_independent_reparametrization_sampler_sample_caps_jitter(qmc: bool, var: float) -> None:
sampler = IndependentReparametrizationSampler(
100, QuadraticMeanAndRBFKernel(kernel_amplitude=var), qmc=qmc
)

def sample_var(var: float) -> float:
return tf.math.reduce_variance(sampler.sample(tf.constant([[1.0]]), jitter=float(var)))

npt.assert_allclose(sample_var(var), sample_var(0) * 2, rtol=1e-5)
assert sample_var(0) <= sample_var(var / 2) <= sample_var(var)
npt.assert_allclose(sample_var(var), sample_var(var + 1), rtol=1e-5) # capped


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("sample_size", [0, -2])
def test_independent_reparametrization_sampler_raises_for_invalid_sample_size(
Expand Down Expand Up @@ -426,6 +441,21 @@ def test_batch_reparametrization_sampler_sample_raises_for_negative_jitter(qmc:
sampler.sample(tf.constant([[0.0]]), jitter=-1e-6)


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("var", [0.1, 1.0])
def test_batch_reparametrization_sampler_sample_caps_jitter(qmc: bool, var: float) -> None:
sampler = BatchReparametrizationSampler(
100, QuadraticMeanAndRBFKernel(kernel_amplitude=var), qmc=qmc
)

def sample_var(var: float) -> float:
return tf.math.reduce_variance(sampler.sample(tf.constant([[1.0]]), jitter=float(var)))

npt.assert_allclose(sample_var(var), sample_var(0) * 2, rtol=1e-5)
assert sample_var(0) <= sample_var(var / 2) <= sample_var(var)
npt.assert_allclose(sample_var(var), sample_var(var + 1), rtol=1e-5) # capped


@pytest.mark.parametrize("qmc", [True, False])
def test_batch_reparametrization_sampler_sample_raises_for_inconsistent_batch_size(
qmc: bool,
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/models/gpflux/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,19 @@ def test_dgp_reparam_sampler_sample_is_continuous(
npt.assert_array_less(tf.abs(sampler.sample(xs + 1e-20) - sampler.sample(xs)), 1e-20)


@random_seed
def test_dgp_reparam_sampler_sample_caps_jitter() -> None:
_, model = _build_dataset_and_train_deep_gp(simple_two_layer_dgp_model)

sampler = DeepGaussianProcessReparamSampler(100, model)
xs = tf.random.uniform([100, 2], minval=-10.0, maxval=10.0, dtype=tf.float64)[:, None, :]
sample_var_0 = tf.math.reduce_variance(sampler.sample(xs, jitter=0.0))
sample_var_10 = tf.math.reduce_variance(sampler.sample(xs, jitter=10.0))
sample_var_1000 = tf.math.reduce_variance(sampler.sample(xs, jitter=1000.0))
assert sample_var_0 < sample_var_10
npt.assert_allclose(sample_var_10, sample_var_1000)


def test_dgp_reparam_sampler_sample_is_repeatable(
two_layer_model: Callable[[TensorType], DeepGP]
) -> None:
Expand Down

0 comments on commit 7fd07e6

Please sign in to comment.