Skip to content

Commit

Permalink
refactor: Simplify posterior shrinkage tests with pytest marks (#1197)
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls authored Jul 22, 2024
1 parent 515b40e commit ba19688
Showing 1 changed file with 58 additions and 38 deletions.
96 changes: 58 additions & 38 deletions tests/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,43 +195,63 @@ def test_mmd_squared_distance(test, sigma):
assert estimate > threshold, "Accepting 0-hypothesis even though q!=p."


def test_posterior_shrinkage():
prior_samples = np.array([2])
post_samples = np.array([3])
assert torch.isnan(posterior_shrinkage(prior_samples, post_samples)[0])

prior_samples = np.array([[1, 2], [2, 3]])
post_samples = np.array([[2, 3], [3, 4]])
expected_shrinkage = torch.tensor([0.0, 0.0])
assert torch.allclose(
posterior_shrinkage(prior_samples, post_samples), expected_shrinkage
)

prior_samples = torch.tensor([[1.0, 2.0], [2.0, 3.0]])
post_samples = torch.tensor([[2.0, 3.0], [3.0, 4.0]])
expected_shrinkage = torch.tensor([0.0, 0.0])
assert torch.allclose(
posterior_shrinkage(prior_samples, post_samples), expected_shrinkage
)

prior_samples = np.array([])
post_samples = np.array([])
with pytest.raises(ValueError):
posterior_shrinkage(prior_samples, post_samples)


def test_posterior_zscore():
true_theta = np.array([2, 3])
post_samples = np.array([[1, 2], [2, 3], [3, 4]])
expected_zscore = torch.tensor([0.0, 0.0])
assert torch.allclose(posterior_zscore(true_theta, post_samples), expected_zscore)
@pytest.mark.parametrize(
"prior_samples, post_samples, expected_shrinkage, raises_exception",
[
(np.array([2]), np.array([3]), None, False),
(
np.array([[1, 2], [2, 3]]),
np.array([[2, 3], [3, 4]]),
torch.tensor([0.0, 0.0]),
False,
),
(
torch.tensor([[1.0, 2.0], [2.0, 3.0]]),
torch.tensor([[2.0, 3.0], [3.0, 4.0]]),
torch.tensor([0.0, 0.0]),
False,
),
(np.array([]), np.array([]), None, True),
],
)
def test_posterior_shrinkage(
prior_samples, post_samples, expected_shrinkage, raises_exception
):
if raises_exception:
with pytest.raises(ValueError):
posterior_shrinkage(prior_samples, post_samples)
else:
if expected_shrinkage is not None:
assert torch.allclose(
posterior_shrinkage(prior_samples, post_samples), expected_shrinkage
)
else:
assert torch.isnan(posterior_shrinkage(prior_samples, post_samples)[0])

true_theta = torch.tensor([2.0, 3.0])
post_samples = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
expected_zscore = torch.tensor([0.0, 0.0])
assert torch.allclose(posterior_zscore(true_theta, post_samples), expected_zscore)

true_theta = np.array([])
post_samples = np.array([])
with pytest.raises(ValueError):
posterior_zscore(true_theta, post_samples)
@pytest.mark.parametrize(
"true_theta, post_samples, expected_zscore, raises_exception",
[
(
np.array([2, 3]),
np.array([[1, 2], [2, 3], [3, 4]]),
torch.tensor([0.0, 0.0]),
False,
),
(
torch.tensor([2.0, 3.0]),
torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]),
torch.tensor([0.0, 0.0]),
False,
),
(np.array([]), np.array([]), None, True),
],
)
def test_posterior_zscore(true_theta, post_samples, expected_zscore, raises_exception):
if raises_exception:
with pytest.raises(ValueError):
posterior_zscore(true_theta, post_samples)
else:
assert torch.allclose(
posterior_zscore(true_theta, post_samples), expected_zscore
)

0 comments on commit ba19688

Please sign in to comment.