Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO #2893

Draft
wants to merge 62 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
b00948c
trace_elbo
Jul 3, 2021
da2f887
lint
Jul 3, 2021
f6c95e4
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Jul 5, 2021
3ec076d
test_gradient
Jul 5, 2021
a22ff4e
copy traceenum_elbo and add test model with poisson dist
Jul 16, 2021
d551fa2
lint
Jul 16, 2021
b68bb3f
use constant funsor
Jul 21, 2021
bfb13bf
working version
Jul 28, 2021
ca1a1fe
pass second test
Jul 28, 2021
6d6a9ed
clean up trace_elbo
Jul 29, 2021
0f23b42
add another test
Aug 8, 2021
91384ed
lazy eval
Aug 20, 2021
c18a8bd
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Sep 18, 2021
34d9a3c
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Sep 30, 2021
b0182c0
vectorize particles; update tests
Sep 30, 2021
dc31767
minor fixes; pin to funsor@normalize-logaddexp
Sep 30, 2021
5c0fe75
update docs/requirements
Sep 30, 2021
2b15fe1
combine Trace_ELBO and TraceEnum_ELBO
Sep 30, 2021
351090b
eager evaluation
Oct 1, 2021
7d029c7
rm file
Oct 1, 2021
1bb7380
lazy
Oct 1, 2021
42ad4fa
remove memoize
Oct 1, 2021
5b6afdb
merge TraceEnum_ELBO
Oct 10, 2021
33628aa
skip test
Oct 11, 2021
18a973b
fixes
Oct 12, 2021
2c3ead3
convert Tensor to Categorical
Oct 12, 2021
5fb1522
restore docs/requirements.txt
Oct 12, 2021
f907f93
pin funsor in docs/requirements
Oct 12, 2021
902e445
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Oct 12, 2021
0042f85
use funsor.optimizer.apply_optimizer; higher precision in the test
Oct 12, 2021
ee5a5ad
pin funsor to the latest commit
Oct 12, 2021
e4c6760
optimize logzq
Oct 12, 2021
aba300a
optimize logzq
Oct 13, 2021
d823153
restore TraceEnum_ELBO
Oct 13, 2021
c06e9e4
revert hmm changes
Oct 13, 2021
eee297d
_tensor_to_categorical helper function
Oct 13, 2021
d748efa
lazy to_funsor
Oct 13, 2021
a1970d6
reduce over particle_var
Oct 13, 2021
4c1ee9e
address comment in tests
Oct 13, 2021
5df30c8
import pyroapi
Oct 13, 2021
46ff6f4
compute expected grads using dice factors
Oct 14, 2021
d7ee7ee
add test with guide enumeration
Oct 15, 2021
49553c3
add two more tests
Oct 15, 2021
835f815
pin funsor
Oct 15, 2021
760eeb0
lint
Oct 15, 2021
ab3831c
remove breakpoint
Oct 15, 2021
0b46f3a
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Oct 29, 2021
b6ff8e0
Approximate(ops.sample, ...) based approach
Nov 3, 2021
b5bece7
Importance funsor based approach
Nov 4, 2021
d6e246e
fixes
Nov 4, 2021
6582d7d
Merge branch 'dev' into fix-funsor-traceelbo
Apr 6, 2022
714fd62
fix funsor model enumeration
Apr 9, 2022
2d2210e
Merge branch 'fix-model-enumeration-funsor' into fix-funsor-traceelbo
Apr 9, 2022
29bad7a
use Sampled funsor
Apr 11, 2022
9144be1
fixes
Apr 11, 2022
e4c8a47
git fixes
Apr 11, 2022
c147ad9
Merge branch 'dev' into fix-funsor-traceelbo
Apr 11, 2022
703a2fa
use Provenance funsor
Apr 11, 2022
3137b1b
clean up
Apr 12, 2022
88713f6
fixes
May 5, 2022
99a0647
Merge branch 'dev' into fix-funsor-traceelbo
May 5, 2022
14131ad
use provenance
Jun 22, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ observations>=0.1.4
opt_einsum>=2.3.2
pyro-api>=0.1.1
tqdm>=4.36
funsor[torch]
funsor[torch] @ git+https://github.com/pyro-ppl/funsor.git@sampled-funsor
3 changes: 2 additions & 1 deletion pyro/contrib/funsor/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from pyro.poutine.handlers import _make_handler

from .enum_messenger import EnumMessenger, queue # noqa: F401
from .enum_messenger import EnumMessenger, ProvenanceMessenger, queue # noqa: F401
from .named_messenger import MarkovMessenger, NamedMessenger
from .plate_messenger import PlateMessenger, VectorizedMarkovMessenger
from .replay_messenger import ReplayMessenger
Expand All @@ -26,6 +26,7 @@
MarkovMessenger,
NamedMessenger,
PlateMessenger,
ProvenanceMessenger,
ReplayMessenger,
TraceMessenger,
VectorizedMarkovMessenger,
Expand Down
57 changes: 55 additions & 2 deletions pyro/contrib/funsor/handlers/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from pyro.contrib.funsor.handlers.primitives import to_data, to_funsor
from pyro.contrib.funsor.handlers.replay_messenger import ReplayMessenger
from pyro.contrib.funsor.handlers.trace_messenger import TraceMessenger
from pyro.ops.provenance import detach_provenance, extract_provenance
from pyro.poutine.escape_messenger import EscapeMessenger
from pyro.poutine.reentrant_messenger import ReentrantMessenger
from pyro.poutine.subsample_messenger import _Subsample

funsor.set_backend("torch")
Expand Down Expand Up @@ -58,6 +60,13 @@ def _get_support_value_tensor(funsor_dist, name, **kwargs):
)


@_get_support_value.register(funsor.Provenance)
def _get_support_value_sampled(funsor_dist, name, **kwargs):
assert name in funsor_dist.inputs
value = _get_support_value(funsor_dist.term, name, **kwargs)
return funsor.Provenance(value, funsor_dist.provenance)


@_get_support_value.register(funsor.distribution.Distribution)
def _get_support_value_distribution(funsor_dist, name, expand=False):
assert name == funsor_dist.value.name
Expand Down Expand Up @@ -179,6 +188,49 @@ def enumerate_site(dist, msg):
raise ValueError("{} not valid enum strategy".format(msg))


@extract_provenance.register(funsor.Provenance)
def _extract_provenance_funsor(x):
return x.term, x.provenance


class ProvenanceMessenger(ReentrantMessenger):
"""
Adds provenance information for all sample sites that are not enumerated.
"""

def _pyro_sample(self, msg):
if (
msg["done"]
or msg["is_observed"]
or msg["infer"].get("enumerate") == "parallel"
or isinstance(msg["fn"], _Subsample)
):
return

if "funsor" not in msg:
msg["funsor"] = {}

with funsor.terms.lazy:
ordabayevy marked this conversation as resolved.
Show resolved Hide resolved
unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)(
value=msg["name"]
)
# TODO delegate to enumerate_site
log_measure = _enum_strategy_default(unsampled_log_measure, msg)
msg["funsor"]["log_measure"] = detach_provenance(log_measure)
support_value = _get_support_value(
log_measure,
msg["name"],
expand=msg["infer"].get("expand", False),
)
# TODO delegate to _get_support_value
msg["funsor"]["value"] = funsor.Provenance(
support_value,
frozenset([(msg["name"], detach_provenance(support_value))]),
)
msg["value"] = to_data(msg["funsor"]["value"])
msg["done"] = True


class EnumMessenger(NamedMessenger):
"""
This version of :class:`~EnumMessenger` uses :func:`~pyro.contrib.funsor.to_data`
Expand All @@ -200,9 +252,10 @@ def _pyro_sample(self, msg):
unsampled_log_measure = to_funsor(msg["fn"], output=funsor.Real)(
value=msg["name"]
)
msg["funsor"]["log_measure"] = enumerate_site(unsampled_log_measure, msg)
log_measure = enumerate_site(unsampled_log_measure, msg)
msg["funsor"]["log_measure"] = detach_provenance(log_measure)
msg["funsor"]["value"] = _get_support_value(
msg["funsor"]["log_measure"],
log_measure,
msg["name"],
expand=msg["infer"].get("expand", False),
)
Expand Down
7 changes: 6 additions & 1 deletion pyro/contrib/funsor/handlers/named_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from collections import OrderedDict
from contextlib import ExitStack

import funsor

from pyro.contrib.funsor.handlers.runtime import (
_DIM_STACK,
DimRequest,
Expand Down Expand Up @@ -64,7 +66,10 @@ def _pyro_to_data(msg):
name_to_dim = msg["kwargs"].setdefault("name_to_dim", OrderedDict())
dim_type = msg["kwargs"].setdefault("dim_type", DimType.LOCAL)

batch_names = tuple(funsor_value.inputs.keys())
if isinstance(funsor_value, funsor.Provenance):
batch_names = tuple(funsor_value.term.inputs.keys())
else:
batch_names = tuple(funsor_value.inputs.keys())

# interpret all names/dims as requests since we only run this function once
name_to_dim_request = name_to_dim.copy()
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/funsor/infer/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
log_prob = funsor.sum_product.sum_product(
sum_op,
prod_op,
terms["log_factors"] + terms["log_measures"],
terms["log_factors"] + list(terms["log_measures"].values()),
eliminate=terms["measure_vars"] | terms["plate_vars"],
plates=terms["plate_vars"],
)
Expand Down
115 changes: 95 additions & 20 deletions pyro/contrib/funsor/infer/trace_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import contextlib

import funsor
from funsor.sum_product import _partition

from pyro.contrib.funsor import to_data, to_funsor
from pyro.contrib.funsor.handlers import enum, plate, replay, trace
from pyro.contrib.funsor.infer import config_enumerate
from pyro.contrib.funsor.handlers import enum, plate, provenance, replay, trace
from pyro.distributions.util import copy_docs_from
from pyro.infer import Trace_ELBO as _OrigTrace_ELBO

Expand All @@ -18,32 +18,107 @@
@copy_docs_from(_OrigTrace_ELBO)
class Trace_ELBO(ELBO):
def differentiable_loss(self, model, guide, *args, **kwargs):
with enum(), plate(
size=self.num_particles
with enum(
first_available_dim=(-self.max_plate_nesting - 1)
if self.max_plate_nesting is not None
and self.max_plate_nesting != float("inf")
else None
), provenance(), plate(
name="num_particles_vectorized",
size=self.num_particles,
dim=-self.max_plate_nesting,
) if self.num_particles > 1 else contextlib.ExitStack():
guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace(
*args, **kwargs
)
guide_tr = trace(guide).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)

model_terms = terms_from_trace(model_tr)
guide_terms = terms_from_trace(guide_tr)

log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
log_factors = model_terms["log_factors"] + [
-f for f in guide_terms["log_factors"]
]
plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"]

elbo = funsor.Integrate(
sum(log_measures, to_funsor(0.0)),
sum(log_factors, to_funsor(0.0)),
measure_vars,
particle_var = (
frozenset({"num_particles_vectorized"})
if self.num_particles > 1
else frozenset()
)
elbo = elbo.reduce(funsor.ops.add, plate_vars)
plate_vars = (
guide_terms["plate_vars"] | model_terms["plate_vars"]
) - particle_var

return -to_data(elbo)
model_measure_vars = model_terms["measure_vars"] - guide_terms["measure_vars"]
with funsor.terms.lazy:
# identify and contract out auxiliary variables in the model with partial_sum_product
contracted_factors, uncontracted_factors = [], []
for f in model_terms["log_factors"]:
if model_measure_vars.intersection(f.inputs):
contracted_factors.append(f)
else:
uncontracted_factors.append(f)
contracted_costs = []
# incorporate the effects of subsampling and handlers.scale through a common scale factor
for group_factors, group_vars in _partition(
list(model_terms["log_measures"].values()) + contracted_factors,
model_terms["measure_vars"],
):
group_factor_vars = frozenset().union(
*[f.inputs for f in group_factors]
)
group_plates = model_terms["plate_vars"] & group_factor_vars
outermost_plates = frozenset.intersection(
*(frozenset(f.inputs) & group_plates for f in group_factors)
)
elim_plates = group_plates - outermost_plates
for f in funsor.sum_product.partial_sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
group_factors,
plates=group_plates,
eliminate=group_vars | elim_plates,
):
contracted_costs.append(model_terms["scale"] * f)

# accumulate costs from model (logp) and guide (-logq)
costs = contracted_costs + uncontracted_factors # model costs: logp
costs += [-f for f in guide_terms["log_factors"]] # guide costs: -logq

# compute log_measures corresponding to each cost term
# the goal is to achieve fine-grained Rao-Blackwellization
log_measures = dict()
for cost in costs:
if cost.input_vars not in log_measures:
log_probs = [
f
for name, f in guide_terms["log_measures"].items()
if name in cost.inputs
]
log_prob = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
log_probs,
plates=plate_vars,
eliminate=(plate_vars | guide_terms["measure_vars"])
- frozenset(cost.inputs),
)
log_measures[cost.input_vars] = funsor.optimizer.apply_optimizer(
log_prob
)

with funsor.terms.lazy:
# finally, integrate out guide variables in the elbo and all plates
elbo = to_funsor(0, output=funsor.Real)
for cost in costs:
log_measure = log_measures[cost.input_vars]
measure_vars = (frozenset(cost.inputs) - plate_vars) - particle_var
elbo_term = funsor.Integrate(
log_measure,
cost,
measure_vars,
)
elbo += elbo_term.reduce(
funsor.ops.add, plate_vars & frozenset(cost.inputs)
)
# average over Monte-Carlo particles
elbo = elbo.reduce(funsor.ops.mean, particle_var)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reducing over particle_var.


return -to_data(funsor.optimizer.apply_optimizer(elbo))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using funsor.optimizer.apply_optimizer instead of .traceenum_elbo.apply_optimizer since no issues are found using it.



class JitTrace_ELBO(Jit_ELBO, Trace_ELBO):
Expand Down
12 changes: 6 additions & 6 deletions pyro/contrib/funsor/infer/traceenum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def terms_from_trace(tr):
# of free variables as either product (plate) variables or sum (measure) variables
terms = {
"log_factors": [],
"log_measures": [],
"log_measures": {},
"scale": to_funsor(1.0),
"plate_vars": frozenset(),
"measure_vars": frozenset(),
Expand Down Expand Up @@ -62,7 +62,7 @@ def terms_from_trace(tr):
)
# grab the log-measure, found only at sites that are not replayed or observed
if node["funsor"].get("log_measure", None) is not None:
terms["log_measures"].append(node["funsor"]["log_measure"])
terms["log_measures"][name] = node["funsor"]["log_measure"]
# sum (measure) variables: the fresh non-plate variables at a site
terms["measure_vars"] |= (
frozenset(node["funsor"]["value"].inputs) | {name}
Expand Down Expand Up @@ -132,7 +132,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
for f in funsor.sum_product.dynamic_partial_sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
model_terms["log_measures"] + contracted_factors,
list(model_terms["log_measures"].values()) + contracted_factors,
plate_to_step=model_terms["plate_to_step"],
eliminate=model_terms["measure_vars"] | markov_dims,
)
Expand All @@ -149,7 +149,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
log_prob = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
guide_terms["log_measures"],
list(guide_terms["log_measures"].values()),
plates=plate_vars,
eliminate=(plate_vars | guide_terms["measure_vars"])
- frozenset(cost.inputs),
Expand Down Expand Up @@ -198,7 +198,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
contracted_costs = []
# incorporate the effects of subsampling and handlers.scale through a common scale factor
for group_factors, group_vars in _partition(
model_terms["log_measures"] + contracted_factors,
list(model_terms["log_measures"].values()) + contracted_factors,
model_terms["measure_vars"],
):
group_factor_vars = frozenset().union(
Expand Down Expand Up @@ -244,7 +244,7 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
logzq = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
guide_terms["log_measures"] + list(targets.values()),
list(guide_terms["log_measures"].values()) + list(targets.values()),
plates=plate_vars,
eliminate=(plate_vars | guide_terms["measure_vars"]),
)
Expand Down
4 changes: 3 additions & 1 deletion pyro/contrib/funsor/infer/tracetmc_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
model_terms = terms_from_trace(model_tr)
guide_terms = terms_from_trace(guide_tr)

log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
log_measures = list(guide_terms["log_measures"].values()) + list(
model_terms["log_measures"].values()
)
log_factors = model_terms["log_factors"] + [
-f for f in guide_terms["log_factors"]
]
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@
"horovod": ["horovod[pytorch]>=0.19"],
"funsor": [
# This must be a released version when Pyro is released.
# "funsor[torch] @ git+git://github.com/pyro-ppl/funsor.git@7bb52d0eae3046d08a20d1b288544e1a21b4f461",
"funsor[torch]==0.4.3",
"funsor[torch] @ git+https://github.com/pyro-ppl/funsor.git@sampled-funsor",
# "funsor[torch]==0.4.3",
],
},
python_requires=">=3.7",
Expand Down
Loading