Skip to content

Commit

Permalink
added note and assert that sbvm conc < 10k (#3412)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlaRonning authored Dec 4, 2024
1 parent b34ba6c commit d6ebf5a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
12 changes: 12 additions & 0 deletions pyro/distributions/sine_bivariate_von_mises.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class SineBivariateVonMises(TorchDistribution):
.. note:: In the context of :class:`~pyro.infer.SVI`, this distribution can be used as a likelihood but not for
latent variables.
.. note:: Normalization remains accurate up to concentrations of 10,000.
** References: **
1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002)
2. Protein Bioinformatics and Mixtures of Bivariate von Mises Distributions for Angular Data,
Expand Down Expand Up @@ -108,6 +110,16 @@ def __init__(
) = broadcast_all(
phi_loc, psi_loc, phi_concentration, psi_concentration, correlation
)

max_conc = torch.maximum(
torch.max(phi_concentration), torch.max(psi_concentration)
)
assrt_hstr = (
"Normalization of SineBiviateVonMises is inaccurate for"
f"current max concentration ({max_conc} > 10,000)."
)
assert max_conc <= torch.tensor(10_000.0), assrt_hstr

self.phi_loc = phi_loc
self.psi_loc = psi_loc
self.phi_concentration = phi_concentration
Expand Down
9 changes: 8 additions & 1 deletion tests/distributions/test_sine_bivariate_von_mises.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,15 @@ def guide(data):
assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2)


@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0])
@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10_000.0, 10_001.0])
def test_sine_bivariate_von_mises_norm(conc):
if conc > 10_000.0:
try:
dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
pytest.fail()
except AssertionError:
return

dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
num_samples = 500
x = torch.linspace(-torch.pi, torch.pi, num_samples)
Expand Down

0 comments on commit d6ebf5a

Please sign in to comment.