Skip to content

Commit

Permalink
Fix models not rendering after application of the equalize effect han…
Browse files Browse the repository at this point in the history
…dler (#3387)

* Fix models not rendering after application of the equalize effect handler.

* Ignore newly added mypy callability check.

---------

Co-authored-by: Ben Zickel <[email protected]>
  • Loading branch information
BenZickel and Ben Zickel authored Jul 23, 2024
1 parent b450623 commit e3091e3
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __get__(
if name not in obj.__dict__["_pyro_params"]:
init_value, constraint, event_dim = self
# bind method's self arg
init_value = functools.partial(init_value, obj) # type: ignore[arg-type]
init_value = functools.partial(init_value, obj) # type: ignore[arg-type,misc,operator]
setattr(obj, name, PyroParam(init_value, constraint, event_dim))
value: PyroParam = obj.__getattr__(name)
return value
Expand Down
2 changes: 1 addition & 1 deletion pyro/poutine/equalize_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,6 @@ def _process_message(self, msg: Message) -> None:
if self.value is not None and self._is_matching(msg): # type: ignore[unreachable]
msg["value"] = self.value # type: ignore[unreachable]
if msg["type"] == "sample":
msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim)
msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False)
msg["infer"] = {"_deterministic": True}
msg["is_observed"] = True
6 changes: 6 additions & 0 deletions tests/poutine/test_poutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,12 @@ def test_param_equalization(self):
assert_equal(tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"])
assert_not_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])

def test_render_model(self):
pyro.set_rng_seed(20240616)
pyro.clear_param_store()
model = poutine.equalize(self.model, ".+_std")
pyro.render_model(model)


@pytest.mark.parametrize("first_available_dim", [-1, -2, -3])
@pytest.mark.parametrize("depth", [0, 1, 2])
Expand Down

0 comments on commit e3091e3

Please sign in to comment.