From f88a16ea903347426aab747a0345faab52f4df25 Mon Sep 17 00:00:00 2001
From: gui11aume
Date: Wed, 30 Aug 2023 18:53:22 -0400
Subject: [PATCH 1/2] Fix #3255 (draft)

---
 pyro/distributions/score_parts.py   |  2 +-
 pyro/infer/trace_elbo.py            |  2 +-
 pyro/infer/trace_mean_field_elbo.py |  2 +-
 pyro/infer/tracegraph_elbo.py       |  2 +
 tests/infer/test_gradient.py        | 95 ++++++++++++++++++++++++++++-
 5 files changed, 99 insertions(+), 4 deletions(-)

diff --git a/pyro/distributions/score_parts.py b/pyro/distributions/score_parts.py
index 15d39156d7..ca4b7929a7 100644
--- a/pyro/distributions/score_parts.py
+++ b/pyro/distributions/score_parts.py
@@ -25,6 +25,6 @@ def scale_and_mask(self, scale=1.0, mask=None):
         :type mask: torch.BoolTensor or None
         """
         log_prob = scale_and_mask(self.log_prob, scale, mask)
-        score_function = self.score_function  # not scaled
+        score_function = scale_and_mask(self.score_function, 1.0, mask)  # not scaled
         entropy_term = scale_and_mask(self.entropy_term, scale, mask)
         return ScoreParts(log_prob, score_function, entropy_term)
diff --git a/pyro/infer/trace_elbo.py b/pyro/infer/trace_elbo.py
index 93041c92cc..457e64dba8 100644
--- a/pyro/infer/trace_elbo.py
+++ b/pyro/infer/trace_elbo.py
@@ -23,7 +23,7 @@ def _compute_log_r(model_trace, guide_trace):
     for name, model_site in model_trace.nodes.items():
         if model_site["type"] == "sample":
             log_r_term = model_site["log_prob"]
-            if not model_site["is_observed"]:
+            if not model_site["is_observed"] and name in guide_trace.nodes:
                 log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
             log_r.add((stacks[name], log_r_term.detach()))
     return log_r
diff --git a/pyro/infer/trace_mean_field_elbo.py b/pyro/infer/trace_mean_field_elbo.py
index 5d1f38c89f..33bf17b833 100644
--- a/pyro/infer/trace_mean_field_elbo.py
+++ b/pyro/infer/trace_mean_field_elbo.py
@@ -108,7 +108,7 @@ def _differentiable_loss_particle(self, model_trace, guide_trace):
             if model_site["type"] == "sample":
                 if model_site["is_observed"]:
                     elbo_particle = elbo_particle + model_site["log_prob_sum"]
-                else:
+                elif name in guide_trace.nodes:
                     guide_site = guide_trace.nodes[name]
                     if is_validation_enabled():
                         check_fully_reparametrized(guide_site)
diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py
index fe817d020f..ac00587e80 100644
--- a/pyro/infer/tracegraph_elbo.py
+++ b/pyro/infer/tracegraph_elbo.py
@@ -218,6 +218,8 @@ def _compute_elbo(model_trace, guide_trace):
     # we include only downstream costs to reduce variance
     # optionally include baselines to further reduce variance
     for node, downstream_cost in downstream_costs.items():
+        if node not in guide_trace.nodes:
+            continue
         guide_site = guide_trace.nodes[node]
         downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
         score_function = guide_site["score_parts"].score_function
diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py
index 69501cf561..fb53507e31 100644
--- a/tests/infer/test_gradient.py
+++ b/tests/infer/test_gradient.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
+from collections import defaultdict
 
 import numpy as np
 import pytest
@@ -30,7 +31,6 @@
 
 logger = logging.getLogger(__name__)
 
-
 def DiffTrace_ELBO(*args, **kwargs):
     return Trace_ELBO(*args, **kwargs).differentiable_loss
 
@@ -214,6 +214,99 @@ def guide(subsample):
     assert_equal(actual_grads, expected_grads, prec=precision)
 
 
+# Not including the unobserved site in the guide triggers a warning
+# that can make the test fail if we do not deactivate UserWarning.
+@pytest.mark.filterwarnings("ignore::UserWarning")
+@pytest.mark.parametrize(
+    "with_x_unobserved",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "mask",
+    [[True, True], [True, False], [False, True]],
+)
+@pytest.mark.parametrize(
+    "reparameterized,has_rsample",
+    [(True, None), (True, False), (True, True), (False, None)],
+    ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
+    )
+@pytest.mark.parametrize(
+    "Elbo,local_samples",
+    [
+        (Trace_ELBO, False),
+        (DiffTrace_ELBO, False),
+        (TraceGraph_ELBO, False),
+        (TraceMeanField_ELBO, False),
+        (TraceEnum_ELBO, False),
+        (TraceEnum_ELBO, True),
+    ],
+)
+def test_mask_gradient(
+    Elbo, reparameterized, has_rsample, local_samples, mask, with_x_unobserved,
+):
+    pyro.clear_param_store()
+    data = torch.tensor([-0.5, 2.0])
+    precision = 0.08
+    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
+
+    def model(data, mask):
+        z = pyro.sample("z", Normal(0, 1))
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", Normal(z, 1), obs=data, obs_mask=mask)
+
+    def guide(data, mask):
+        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
+        loc = pyro.param("loc", lambda: torch.tensor([1.0]))
+        z_dist = Normal(loc, scale)
+        if has_rsample is not None:
+            z_dist.has_rsample_(has_rsample)
+        z = pyro.sample("z", z_dist)
+        if with_x_unobserved:
+            with pyro.plate("data", len(data)):
+                with pyro.poutine.mask(mask=~mask):
+                    pyro.sample("x_unobserved", Normal(z, 1))
+
+    num_particles = 50000
+    accumulation = 1
+    if local_samples:
+        # One has to limit the number of samples in this
+        # test because the memory footprint is large.
+        guide = config_enumerate(guide, num_samples=5000)
+        accumulation = num_particles // 5000
+        num_particles = 1
+
+    optim = Adam({"lr": 0.1})
+    elbo = Elbo(
+        max_plate_nesting=1,  # set this to ensure rng agrees across runs
+        num_particles=num_particles,
+        vectorize_particles=True,
+        strict_enumeration_warning=False,
+    )
+    actual_grads = defaultdict(lambda: np.zeros(1))
+    for _ in range(accumulation):
+        inference = SVI(model, guide, optim, loss=elbo)
+        with xfail_if_not_implemented():
+            inference.loss_and_grads(
+                model, guide, data=data, mask=torch.tensor(mask)
+            )
+        params = dict(pyro.get_param_store().named_parameters())
+        actual_grads = {
+            name: param.grad.detach().cpu().numpy() / accumulation
+            for name, param in params.items()
+        }
+
+    # grad(loc) = (n+1) * loc - (x1 + ... + xn)
+    # grad(scale) = (n+1) * scale - 1 / scale
+    expected_grads = {
+        "loc": sum(mask) + 1. - data[mask].sum(0, keepdim=True).numpy(),
+        "scale": sum(mask) + 1 - np.ones(1)
+    }
+    for name in sorted(params):
+        logger.info("expected {} = {}".format(name, expected_grads[name]))
+        logger.info("actual {} = {}".format(name, actual_grads[name]))
+    assert_equal(actual_grads, expected_grads, prec=precision)
+
+
 @pytest.mark.parametrize(
     "reparameterized", [True, False], ids=["reparam", "nonreparam"]
 )

From 930e32ac254eac0eb3927c55475335a762152d1d Mon Sep 17 00:00:00 2001
From: gui11aume
Date: Wed, 4 Oct 2023 18:19:36 -0400
Subject: [PATCH 2/2] Fix linting issues

---
 tests/infer/test_gradient.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py
index fb53507e31..00d9df5218 100644
--- a/tests/infer/test_gradient.py
+++ b/tests/infer/test_gradient.py
@@ -31,6 +31,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 def DiffTrace_ELBO(*args, **kwargs):
     return Trace_ELBO(*args, **kwargs).differentiable_loss
 
@@ -229,7 +230,7 @@ def guide(subsample):
     "reparameterized,has_rsample",
     [(True, None), (True, False), (True, True), (False, None)],
     ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
-    )
+)
 @pytest.mark.parametrize(
     "Elbo,local_samples",
     [
@@ -242,7 +243,12 @@ def guide(subsample):
     ],
 )
 def test_mask_gradient(
-    Elbo, reparameterized, has_rsample, local_samples, mask, with_x_unobserved,
+    Elbo,
+    reparameterized,
+    has_rsample,
+    local_samples,
+    mask,
+    with_x_unobserved,
 ):
     pyro.clear_param_store()
     data = torch.tensor([-0.5, 2.0])
     precision = 0.08
     Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
@@ -286,9 +292,7 @@ def guide(data, mask):
     for _ in range(accumulation):
         inference = SVI(model, guide, optim, loss=elbo)
         with xfail_if_not_implemented():
-            inference.loss_and_grads(
-                model, guide, data=data, mask=torch.tensor(mask)
-            )
+            inference.loss_and_grads(model, guide, data=data, mask=torch.tensor(mask))
         params = dict(pyro.get_param_store().named_parameters())
         actual_grads = {
             name: param.grad.detach().cpu().numpy() / accumulation
@@ -298,8 +302,8 @@ def guide(data, mask):
     # grad(loc) = (n+1) * loc - (x1 + ... + xn)
    # grad(scale) = (n+1) * scale - 1 / scale
     expected_grads = {
-        "loc": sum(mask) + 1. - data[mask].sum(0, keepdim=True).numpy(),
-        "scale": sum(mask) + 1 - np.ones(1)
+        "loc": sum(mask) + 1.0 - data[mask].sum(0, keepdim=True).numpy(),
+        "scale": sum(mask) + 1 - np.ones(1),
     }
     for name in sorted(params):
         logger.info("expected {} = {}".format(name, expected_grads[name]))
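
For reference, the gradients asserted in test_mask_gradient follow from a closed form: with the guide initialised at loc = scale = 1, the unobserved "x" sites cancel between model and guide, and the negative ELBO reduces (up to an additive constant) to (n + 1) * (loc^2 + scale^2) / 2 - loc * (x1 + ... + xn) - log(scale), where n is the number of observed points. The sketch below is not taken from the patches; it is a standalone check, assuming only PyTorch, that re-derives grad(loc) = (n + 1) * loc - (x1 + ... + xn) and grad(scale) = (n + 1) * scale - 1 / scale for the mask = [True, False] case used in the test.

# Standalone sketch (hypothetical helper, not part of the patches): re-derive the
# gradients that test_mask_gradient expects, using plain PyTorch autograd.
import torch

data = torch.tensor([-0.5, 2.0])
mask = torch.tensor([True, False])
loc = torch.tensor([1.0], requires_grad=True)    # same init as pyro.param("loc", ...)
scale = torch.tensor([1.0], requires_grad=True)  # same init as pyro.param("scale", ...)

n = int(mask.sum())  # number of observed sites
x_obs = data[mask]   # observed data points

# Closed-form negative ELBO (up to an additive constant) for
#   z ~ Normal(0, 1), x_i ~ Normal(z, 1) observed where mask[i] is True,
#   guide z ~ Normal(loc, scale); unobserved x sites cancel between model and guide.
neg_elbo = (n + 1) * (loc ** 2 + scale ** 2) / 2 - loc * x_obs.sum() - torch.log(scale)
neg_elbo.sum().backward()

# grad(loc)   = (n + 1) * loc   - (x1 + ... + xn)  -> 2.5 for this mask
# grad(scale) = (n + 1) * scale - 1 / scale        -> 1.0 for this mask
print(loc.grad, scale.grad)

At loc = scale = 1 this gives exactly the expected_grads dictionary above: sum(mask) + 1 - data[mask].sum() for "loc" and sum(mask) + 1 - 1 for "scale".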