-
-
Notifications
You must be signed in to change notification settings - Fork 984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix #3255 (draft) #3265
base: dev
Are you sure you want to change the base?
Fix #3255 (draft) #3265
Conversation
The purpose of this pull request is to harmonize the behavior of masking between the different ways of estimating the gradient of the ELBO (most notably when I added some tests inspired from those that were already present. Just ignoring the fact that some sites of the model are not present in the guide gave the correct results right away. The main difficulties were:
I had to disable user warnings for The changes to |
@gui11aume can you please fix lint issues from |
Hi @ordabayevy! Thanks for taking the time to help me with this. There were issues in the file |
@@ -23,7 +23,7 @@ def _compute_log_r(model_trace, guide_trace): | |||
for name, model_site in model_trace.nodes.items(): | |||
if model_site["type"] == "sample": | |||
log_r_term = model_site["log_prob"] | |||
if not model_site["is_observed"]: | |||
if not model_site["is_observed"] and name in guide_trace.nodes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I may be forgetting something, but I thought Trace_ELBO
requires the guide to include all model sites that are not observed. If that's the case, we wouldn't want to keep the old version where Trace_ELBO
errors. Can you explain when a model site would be neither observed nor in the guide?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with you and I cannot think of useful cases of this. The point I make in the issue is that this triggers a warning when has_rsample
is true and an error when it is false. I think they should both trigger a warning or both trigger an error. The suggested changes try to make the behavior consistent with the case has_rsample = True
. But maybe it makes more sense to trigger an error everywhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, thanks for explaining, I think I better understand now.
I'd like to hear what other folks think (@martinjankowiak @eb8680 @fehiepsi). One argument for erroring more often is that there is a lot code in Pyro that tacitly assumes all sites are either observed or guided. I'm not sure what that code is, since we've only tacitly made that assumption, but it's worth thinking about: reparametrizers, Predictive
, AutoGuideMessenger
.
One argument for allowing "partial" guides is that it's just more general. But if we decide to support "partial" guides throughout Pyro, I think we'll need to adopt importance sampling semantics, so we'd need to replace pyro's basic Trace
data structure with a weighted trace, and replace sample sets with weighted sets of samples in a number of places e.g. Predictive
. This seems like a perfectly fine design choice for a PPL, but it is different from much of Pyro's existing usage, and I think we would need to make many small changes throughout the codebase including tutorials. 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right. For context, this happened to me on sites that are implicitly created by Pyro in the model (and that are therefore not in the guide), and that subsequently caused a failure because they were in the case has_rsample
is false. Figuring out why the code fails in this case is quite challenging.
@@ -25,6 +25,6 @@ def scale_and_mask(self, scale=1.0, mask=None): | |||
:type mask: torch.BoolTensor or None | |||
""" | |||
log_prob = scale_and_mask(self.log_prob, scale, mask) | |||
score_function = self.score_function # not scaled | |||
score_function = scale_and_mask(self.score_function, 1.0, mask) # not scaled |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you point out exactly the scenario that is fixed by this one line change? IIRC, score_function
would always be multiplied by another tensor that is masked, so the mask here would be redundant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here it's again for consistency. When a variable is partially observed, some parameters are created for the missing observations, and some dummy parameters are created for the non-missing ones. When has_rsample
is true, the dummy parameters are not updated: they retain their initial values because they do not contribute to the gradient. When has_rsample
is false, the gradient "leaks" through this line and the dummy parameters are updated during learning (but I found that inference on the non-dummy parameters was correct in the cases I checked). As above, this line does not really fix any bug, it just tries to make the behavior consistent between has_rsample = True
and has_rsample = False
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for explaining, I think I better understand now.
if node not in guide_trace.nodes: | ||
continue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto: this should never happen
@@ -108,7 +108,7 @@ def _differentiable_loss_particle(self, model_trace, guide_trace): | |||
if model_site["type"] == "sample": | |||
if model_site["is_observed"]: | |||
elbo_particle = elbo_particle + model_site["log_prob_sum"] | |||
else: | |||
elif name in guide_trace.nodes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto: this should never happen
@@ -214,6 +215,102 @@ def guide(subsample): | |||
assert_equal(actual_grads, expected_grads, prec=precision) | |||
|
|||
|
|||
# Not including the unobserved site in the guide triggers a warning | |||
# that can make the test fail if we do not deactivate UserWarning. | |||
@pytest.mark.filterwarnings("ignore::UserWarning") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto: this should never happen
No description provided.