From bb6fe733bb9c723a2ff3a0e2bfecbc75e947c7cc Mon Sep 17 00:00:00 2001 From: nabr Date: Fri, 9 Feb 2024 21:53:18 +0100 Subject: [PATCH 1/2] Distributions: Disable conditioning on mutable variables if not conditioning variable --- cuqi/distribution/_distribution.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cuqi/distribution/_distribution.py b/cuqi/distribution/_distribution.py index 1296f0a68..1867c5324 100644 --- a/cuqi/distribution/_distribution.py +++ b/cuqi/distribution/_distribution.py @@ -284,6 +284,12 @@ def _condition(self, *args, **kwargs): new_dist = self._make_copy() #New cuqi distribution conditioned on the kwargs processed_kwargs = set() # Keep track of processed (unique) elements in kwargs + # Check if kwargs contain any mutable variables that are not conditioning variables + # If so we raise an error since these are not allowed to be specified. + for kw_key in kwargs.keys(): + if kw_key in mutable_vars and kw_key not in cond_vars: + raise ValueError(f"The mutable variable \"{kw_key}\" is not a conditioning variable of this distribution.") + # Go through every mutable variable and assign value from kwargs if present for var_key in mutable_vars: From bfdbfe367969d327391615ba587fff14e9273072 Mon Sep 17 00:00:00 2001 From: nabr Date: Fri, 9 Feb 2024 21:58:27 +0100 Subject: [PATCH 2/2] Fix unit tests after not allowing conditioning on mutable that are not cond vars --- tests/test_abstract_distribution_density.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_abstract_distribution_density.py b/tests/test_abstract_distribution_density.py index 6bf984efb..133842c61 100644 --- a/tests/test_abstract_distribution_density.py +++ b/tests/test_abstract_distribution_density.py @@ -53,8 +53,8 @@ def test_conditioning_on_main_parameter(): def test_conditioning_kwarg_as_mutable_var(): """ This checks if we allow kwargs for a distribution that has no conditioning variables. """ x = cuqi.distribution.Gaussian(mean=1, cov=1) - x = x(cov=2) #This should be ok and not throw an error - assert x.cov == 2 + with pytest.raises(ValueError): + x = x(cov=2) #This should raise error since no cond vars def test_conditioning_both_args_kwargs(): """ This tests that we throw error if we accidentally provide arg and kwarg for same variable. """ @@ -172,7 +172,7 @@ def test_cond_positional_and_kwargs(): """ Test conditioning for both positional and kwargs """ x = cuqi.distribution.Gaussian(cov=lambda s:s, geometry=1) - logd = x(mean=3, cov=7).logd(13) + logd = x(mean=3, s=7).logd(13) # Conditioning full positional assert x(3, 7, 13).value == logd @@ -181,10 +181,10 @@ def test_cond_positional_and_kwargs(): assert x(3)(7)(13).value == logd # Conditioning full kwargs - assert x(mean=3, cov=7, x=13).value == logd - assert x(mean=3, cov=7)(x=13).value == logd - assert x(mean=3)(cov=7, x=13).value == logd - assert x(mean=3)(cov=7)(x=13).value == logd + assert x(mean=3, s=7, x=13).value == logd + assert x(mean=3, s=7)(x=13).value == logd + assert x(mean=3)(s=7, x=13).value == logd + assert x(mean=3)(s=7)(x=13).value == logd # Conditioning partial positional assert x(3, s=7, x=13).value == logd