Fix #3255 (draft) #3265

Open
wants to merge 2 commits into
base: dev
2 changes: 1 addition & 1 deletion pyro/distributions/score_parts.py
@@ -25,6 +25,6 @@ def scale_and_mask(self, scale=1.0, mask=None):
         :type mask: torch.BoolTensor or None
         """
         log_prob = scale_and_mask(self.log_prob, scale, mask)
-        score_function = self.score_function  # not scaled
+        score_function = scale_and_mask(self.score_function, 1.0, mask)  # not scaled
Member

Can you point out exactly the scenario that is fixed by this one line change? IIRC, score_function would always be multiplied by another tensor that is masked, so the mask here would be redundant.

Author

@gui11aume Oct 24, 2023

Again, this is for consistency. When a variable is partially observed, parameters are created for the missing observations and dummy parameters are created for the non-missing ones. When has_rsample is true, the dummy parameters are not updated: they retain their initial values because they do not contribute to the gradient. When has_rsample is false, the gradient "leaks" through this line and the dummy parameters are updated during learning (although inference on the non-dummy parameters was correct in the cases I checked). As above, this line does not really fix a bug; it just makes the behavior consistent between has_rsample = True and has_rsample = False.

Member

Thanks for explaining, I think I better understand now.

         entropy_term = scale_and_mask(self.entropy_term, scale, mask)
         return ScoreParts(log_prob, score_function, entropy_term)
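The "leak" described in the thread above can be sketched without Pyro. This is a minimal illustration in plain Python, not the library's implementation: `mask_values` is a hypothetical stand-in for the masking half of `scale_and_mask`, and the samples, costs, and parameter values are made up. It assumes a score-function surrogate of the form cost * log q(x), whose gradient with respect to each loc is cost * (x - loc) / scale**2.

```python
def mask_values(values, mask=None):
    """Zero out entries whose mask is False; pass everything through when mask is None.

    Hypothetical helper standing in for the masking part of scale_and_mask.
    """
    if mask is None:
        return list(values)
    return [v if keep else 0.0 for v, keep in zip(values, mask)]


def normal_score_wrt_loc(x, loc, scale):
    """Derivative of log N(x | loc, scale) with respect to loc."""
    return (x - loc) / scale**2


def surrogate_grads(samples, locs, scale, costs, mask=None):
    """Per-element gradient of sum(cost * log q(x)) with respect to each loc."""
    scores = [normal_score_wrt_loc(x, loc, scale) for x, loc in zip(samples, locs)]
    return [c * s for c, s in zip(costs, mask_values(scores, mask))]


# Made-up values: the second entry plays the role of a dummy parameter.
samples = [0.3, -1.2]
locs = [0.0, 0.0]
costs = [2.0, 2.0]
mask = [True, False]  # False marks the dummy entry

leaky = surrogate_grads(samples, locs, 1.0, costs)         # score left unmasked
masked = surrogate_grads(samples, locs, 1.0, costs, mask)  # score masked
```

With the score left unmasked, the dummy entry receives a nonzero gradient and would drift during training; with the mask applied, its gradient is exactly zero, which matches the behavior described for has_rsample = True.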
2 changes: 1 addition & 1 deletion pyro/infer/trace_elbo.py
@@ -23,7 +23,7 @@ def _compute_log_r(model_trace, guide_trace):
     for name, model_site in model_trace.nodes.items():
         if model_site["type"] == "sample":
             log_r_term = model_site["log_prob"]
-            if not model_site["is_observed"]:
+            if not model_site["is_observed"] and name in guide_trace.nodes:
Member

@fritzo Oct 24, 2023

I may be forgetting something, but I thought Trace_ELBO requires the guide to include all model sites that are not observed. If that's the case, wouldn't we want to keep the old version where Trace_ELBO errors? Can you explain when a model site would be neither observed nor in the guide?

Author

@gui11aume Oct 24, 2023

I agree with you and I cannot think of useful cases of this. The point I make in the issue is that this triggers a warning when has_rsample is true and an error when it is false. I think they should both trigger a warning or both trigger an error. The suggested changes try to make the behavior consistent with the case has_rsample = True. But maybe it makes more sense to trigger an error everywhere?

Member

Ah, thanks for explaining, I think I better understand now.

I'd like to hear what other folks think (@martinjankowiak @eb8680 @fehiepsi). One argument for erroring more often is that there is a lot of code in Pyro that tacitly assumes all sites are either observed or guided. I'm not sure exactly which code that is, since we've only tacitly made that assumption, but it's worth thinking about: reparametrizers, Predictive, AutoGuideMessenger.

One argument for allowing "partial" guides is that it's just more general. But if we decide to support "partial" guides throughout Pyro, I think we'll need to adopt importance sampling semantics, so we'd need to replace pyro's basic Trace data structure with a weighted trace, and replace sample sets with weighted sets of samples in a number of places e.g. Predictive. This seems like a perfectly fine design choice for a PPL, but it is different from much of Pyro's existing usage, and I think we would need to make many small changes throughout the codebase including tutorials. 🤔
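To make "weighted sets of samples" concrete, here is a minimal sketch in plain Python of self-normalized importance sampling. It is not Pyro code and does not reflect any existing Pyro data structure. It assumes a toy model z ~ N(0, 1), x ~ N(z, 1) with a single made-up observation x = 2.0, and a guide with no site for z, so z is drawn from the prior and each sample carries the log weight log p(x | z). All function names are hypothetical.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def weighted_prior_samples(x_obs, num_samples, seed=0):
    """Sample z from the prior N(0, 1) and attach the log weight log p(x_obs | z)."""
    rng = random.Random(seed)
    samples = [rng.gauss(0.0, 1.0) for _ in range(num_samples)]
    log_weights = [normal_log_prob(x_obs, z, 1.0) for z in samples]
    return samples, log_weights


def normalize(log_weights):
    """Convert log weights into weights that sum to one (stable log-sum-exp)."""
    top = max(log_weights)
    unnormalized = [math.exp(lw - top) for lw in log_weights]
    total = sum(unnormalized)
    return [w / total for w in unnormalized]


samples, log_weights = weighted_prior_samples(x_obs=2.0, num_samples=20000)
weights = normalize(log_weights)

# Ignoring the weights estimates the prior mean (0 in this toy model).
unweighted_mean = sum(samples) / len(samples)
# Using the weights estimates the posterior mean (1 in this toy model).
weighted_mean = sum(w * z for w, z in zip(weights, samples))
```

The gap between the two means is the point: once samples carry weights, any consumer that averages over them has to be weight-aware, which is why the change would touch many places.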

Author

@gui11aume Oct 24, 2023

Right. For context, this happened to me with sites that Pyro creates implicitly in the model (and that are therefore not in the guide). They then caused a failure because they fell under the has_rsample = False case. Figuring out why the code fails in this situation is quite challenging.

                 log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
             log_r.add((stacks[name], log_r_term.detach()))
     return log_r
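The two behaviors debated in this thread (error versus warning when an unobserved model site has no guide counterpart) can be sketched with plain dicts standing in for traces. This is a toy illustration, not Pyro's `Trace` API: the `strict` flag, the dict layout, and the scalar log-probabilities are all assumptions made for the example.

```python
import warnings


def compute_log_r(model_nodes, guide_nodes, strict=True):
    """Per-site log p - log q over sample sites, using plain dicts as toy traces."""
    log_r = {}
    for name, site in model_nodes.items():
        if site["type"] != "sample":
            continue
        term = site["log_prob"]
        if not site["is_observed"]:
            if name in guide_nodes:
                term = term - guide_nodes[name]["log_prob"]
            elif strict:
                # Mirrors the "always error" option discussed above.
                raise KeyError(f"site {name!r} is neither observed nor in the guide")
            else:
                # Mirrors the "always warn" option discussed above.
                warnings.warn(f"site {name!r} is missing from the guide", UserWarning)
        log_r[name] = term
    return log_r


# Made-up traces: "x_unobserved" is in the model but not in the guide.
model_nodes = {
    "z": {"type": "sample", "is_observed": False, "log_prob": -1.5},
    "x_unobserved": {"type": "sample", "is_observed": False, "log_prob": -0.7},
    "x": {"type": "sample", "is_observed": True, "log_prob": -2.0},
}
guide_nodes = {"z": {"type": "sample", "log_prob": -1.0}}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    lenient = compute_log_r(model_nodes, guide_nodes, strict=False)
```

Whichever policy is chosen, routing it through one explicit check gives a message that names the offending site, which addresses the debugging difficulty mentioned above.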
2 changes: 1 addition & 1 deletion pyro/infer/trace_mean_field_elbo.py
@@ -108,7 +108,7 @@ def _differentiable_loss_particle(self, model_trace, guide_trace):
             if model_site["type"] == "sample":
                 if model_site["is_observed"]:
                     elbo_particle = elbo_particle + model_site["log_prob_sum"]
-                else:
+                elif name in guide_trace.nodes:
Member

ditto: this should never happen

                     guide_site = guide_trace.nodes[name]
                     if is_validation_enabled():
                         check_fully_reparametrized(guide_site)
2 changes: 2 additions & 0 deletions pyro/infer/tracegraph_elbo.py
@@ -218,6 +218,8 @@ def _compute_elbo(model_trace, guide_trace):
     # we include only downstream costs to reduce variance
     # optionally include baselines to further reduce variance
     for node, downstream_cost in downstream_costs.items():
+        if node not in guide_trace.nodes:
+            continue
Comment on lines +221 to +222
Member

ditto: this should never happen

         guide_site = guide_trace.nodes[node]
         downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
         score_function = guide_site["score_parts"].score_function
97 changes: 97 additions & 0 deletions tests/infer/test_gradient.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging
+from collections import defaultdict

 import numpy as np
 import pytest
@@ -214,6 +215,102 @@ def guide(subsample):
     assert_equal(actual_grads, expected_grads, prec=precision)


+# Not including the unobserved site in the guide triggers a warning
+# that can make the test fail if we do not deactivate UserWarning.
+@pytest.mark.filterwarnings("ignore::UserWarning")
Member

ditto: this should never happen

+@pytest.mark.parametrize(
+    "with_x_unobserved",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "mask",
+    [[True, True], [True, False], [False, True]],
+)
+@pytest.mark.parametrize(
+    "reparameterized,has_rsample",
+    [(True, None), (True, False), (True, True), (False, None)],
+    ids=["reparam", "reparam-False", "reparam-True", "nonreparam"],
+)
+@pytest.mark.parametrize(
+    "Elbo,local_samples",
+    [
+        (Trace_ELBO, False),
+        (DiffTrace_ELBO, False),
+        (TraceGraph_ELBO, False),
+        (TraceMeanField_ELBO, False),
+        (TraceEnum_ELBO, False),
+        (TraceEnum_ELBO, True),
+    ],
+)
+def test_mask_gradient(
+    Elbo,
+    reparameterized,
+    has_rsample,
+    local_samples,
+    mask,
+    with_x_unobserved,
+):
+    pyro.clear_param_store()
+    data = torch.tensor([-0.5, 2.0])
+    precision = 0.08
+    Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal
+
+    def model(data, mask):
+        z = pyro.sample("z", Normal(0, 1))
+        with pyro.plate("data", len(data)):
+            pyro.sample("x", Normal(z, 1), obs=data, obs_mask=mask)
+
+    def guide(data, mask):
+        scale = pyro.param("scale", lambda: torch.tensor([1.0]))
+        loc = pyro.param("loc", lambda: torch.tensor([1.0]))
+        z_dist = Normal(loc, scale)
+        if has_rsample is not None:
+            z_dist.has_rsample_(has_rsample)
+        z = pyro.sample("z", z_dist)
+        if with_x_unobserved:
+            with pyro.plate("data", len(data)):
+                with pyro.poutine.mask(mask=~mask):
+                    pyro.sample("x_unobserved", Normal(z, 1))
+
+    num_particles = 50000
+    accumulation = 1
+    if local_samples:
+        # One has to limit the amount of samples in this
+        # test because the memory footprint is large.
+        guide = config_enumerate(guide, num_samples=5000)
+        accumulation = num_particles // 5000
+        num_particles = 1
+
+    optim = Adam({"lr": 0.1})
+    elbo = Elbo(
+        max_plate_nesting=1,  # set this to ensure rng agrees across runs
+        num_particles=num_particles,
+        vectorize_particles=True,
+        strict_enumeration_warning=False,
+    )
+    actual_grads = defaultdict(lambda: np.zeros(1))
+    for _ in range(accumulation):
+        inference = SVI(model, guide, optim, loss=elbo)
+        with xfail_if_not_implemented():
+            inference.loss_and_grads(model, guide, data=data, mask=torch.tensor(mask))
+    params = dict(pyro.get_param_store().named_parameters())
+    actual_grads = {
+        name: param.grad.detach().cpu().numpy() / accumulation
+        for name, param in params.items()
+    }
+
+    # grad(loc) = (n+1) * loc - (x1 + ... + xn)
+    # grad(scale) = (n+1) * scale - 1 / scale
+    expected_grads = {
+        "loc": sum(mask) + 1.0 - data[mask].sum(0, keepdim=True).numpy(),
+        "scale": sum(mask) + 1 - np.ones(1),
+    }
+    for name in sorted(params):
+        logger.info("expected {} = {}".format(name, expected_grads[name]))
+        logger.info("actual {} = {}".format(name, actual_grads[name]))
+    assert_equal(actual_grads, expected_grads, prec=precision)
+
+
 @pytest.mark.parametrize(
     "reparameterized", [True, False], ids=["reparam", "nonreparam"]
 )
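The expected gradients in the test's comments, grad(loc) = (n+1) * loc - (x1 + ... + xn) and grad(scale) = (n+1) * scale - 1 / scale, can be checked independently of Pyro. The sketch below assumes the test's model (z ~ N(0, 1), x_i ~ N(z, 1) over the n observed entries) and guide q(z) = N(loc, scale), for which the negative ELBO has the closed form -log(scale) + (n+1)/2 * scale^2 + loc^2/2 + 1/2 * sum((x_i - loc)^2) up to a constant. The function names are hypothetical and the finite-difference comparison is only a sanity check.

```python
import math


def negative_elbo(loc, scale, observed):
    """Closed-form negative ELBO (up to an additive constant) for the toy model."""
    n = len(observed)
    return (
        -math.log(scale)
        + 0.5 * (n + 1) * scale**2
        + 0.5 * loc**2
        + 0.5 * sum((x - loc) ** 2 for x in observed)
    )


def expected_grads(loc, scale, observed):
    """Gradients stated in the test's comments."""
    n = len(observed)
    return (n + 1) * loc - sum(observed), (n + 1) * scale - 1.0 / scale


def finite_difference_grads(loc, scale, observed, eps=1e-6):
    """Central differences of negative_elbo with respect to loc and scale."""
    d_loc = (
        negative_elbo(loc + eps, scale, observed)
        - negative_elbo(loc - eps, scale, observed)
    ) / (2 * eps)
    d_scale = (
        negative_elbo(loc, scale + eps, observed)
        - negative_elbo(loc, scale - eps, observed)
    ) / (2 * eps)
    return d_loc, d_scale


# Same data and one of the masks used in the test; loc = scale = 1 as in the guide.
data = [-0.5, 2.0]
mask = [True, False]
observed = [x for x, keep in zip(data, mask) if keep]

closed_form = expected_grads(1.0, 1.0, observed)
numeric = finite_difference_grads(1.0, 1.0, observed)
```

For this mask, n = 1 and the observed sum is -0.5, so the closed form gives 2.5 for loc and 1.0 for scale, matching `sum(mask) + 1.0 - data[mask].sum()` and `sum(mask) + 1 - 1` in the test.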