From e74ef39f8a2737a6e4c6aeef85da347980d30984 Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 29 Jul 2024 12:11:26 +0300 Subject: [PATCH 1/4] Add option to specify support for SplitReparam. --- pyro/infer/reparam/split.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/pyro/infer/reparam/split.py b/pyro/infer/reparam/split.py index d5a389bc0e..0d902a5b48 100644 --- a/pyro/infer/reparam/split.py +++ b/pyro/infer/reparam/split.py @@ -6,10 +6,31 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine +from pyro.distributions.torch_distribution import TorchDistributionMixin from .reparam import Reparam +def same_support(fn: TorchDistributionMixin): + ''' + Returns support of the `fn` distribution. + + :param fn: distribution class + :returns: distribution support + ''' + return fn.support + + +def real_support(fn: TorchDistributionMixin): + ''' + Returns real support with same event dimension as that of the `fn` distribution. + + :param fn: distribution class + :returns: distribution support + ''' + return dist.constraints.independent(dist.constraints.real, fn.event_dim) + + class SplitReparam(Reparam): """ Reparameterizer to split a random variable along a dimension, similar to @@ -28,14 +49,20 @@ class SplitReparam(Reparam): each chunk. :type: list(int) :param int dim: Dimension along which to split. Defaults to -1. + :param callable support_fn: Function which derives the split support + from the site's sampling function. Default is :func:`same_support` + as the sampling function, but in some cases such as sampling functions + which are stacked transforms, you would have to explicitly specify + the support with :func:`real_support` """ - def __init__(self, sections, dim): + def __init__(self, sections, dim, support_fn=same_support): assert isinstance(dim, int) and dim < 0 assert isinstance(sections, list) assert all(isinstance(size, int) for size in sections) self.event_dim = -dim self.sections = sections + self.support_fn = support_fn def apply(self, msg): name = msg["name"] @@ -57,7 +84,7 @@ def apply(self, msg): event_shape = left_shape + (size,) + right_shape value_split[i] = pyro.sample( f"{name}_split_{i}", - dist.ImproperUniform(fn.support, fn.batch_shape, event_shape), + dist.ImproperUniform(self.support_fn(fn), fn.batch_shape, event_shape), obs=value_split[i], infer={"is_observed": is_observed}, ) From ba596fda5ba99f3a1aa9511f92d808a10bc89ebf Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 29 Jul 2024 12:25:08 +0300 Subject: [PATCH 2/4] Fix formatting. --- pyro/infer/reparam/split.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyro/infer/reparam/split.py b/pyro/infer/reparam/split.py index 0d902a5b48..09946d5939 100644 --- a/pyro/infer/reparam/split.py +++ b/pyro/infer/reparam/split.py @@ -12,22 +12,22 @@ def same_support(fn: TorchDistributionMixin): - ''' + """ Returns support of the `fn` distribution. :param fn: distribution class :returns: distribution support - ''' + """ return fn.support def real_support(fn: TorchDistributionMixin): - ''' + """ Returns real support with same event dimension as that of the `fn` distribution. :param fn: distribution class :returns: distribution support - ''' + """ return dist.constraints.independent(dist.constraints.real, fn.event_dim) From 60008124aa3a23cc4a1dcac95a89a922f3c3f59b Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 29 Jul 2024 17:51:19 +0300 Subject: [PATCH 3/4] Handle support of stacking and concatenation transforms in SplitReparam. --- pyro/infer/reparam/split.py | 55 +++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/pyro/infer/reparam/split.py b/pyro/infer/reparam/split.py index 09946d5939..83f2224263 100644 --- a/pyro/infer/reparam/split.py +++ b/pyro/infer/reparam/split.py @@ -11,9 +11,10 @@ from .reparam import Reparam -def same_support(fn: TorchDistributionMixin): +def same_support(fn: TorchDistributionMixin, *args): """ - Returns support of the `fn` distribution. + Returns support of the `fn` distribution. Used in :class:`SplitReparam` in + order to determine the support of the split value. :param fn: distribution class :returns: distribution support @@ -21,9 +22,10 @@ def same_support(fn: TorchDistributionMixin): return fn.support -def real_support(fn: TorchDistributionMixin): +def real_support(fn: TorchDistributionMixin, *args): """ Returns real support with same event dimension as that of the `fn` distribution. + Used in :class:`SplitReparam` in order to determine the support of the split value. :param fn: distribution class :returns: distribution support @@ -31,6 +33,34 @@ def real_support(fn: TorchDistributionMixin): return dist.constraints.independent(dist.constraints.real, fn.event_dim) +def default_support(fn: TorchDistributionMixin, slice, dim): + """ + Returns support of the `fn` distribution, corrected for split stacking and + concatenation transforms. Used in :class:`SplitReparam` in + order to determine the support of the split value. + + :param fn: distribution class + :param slice: slice for which to return support + :param dim: dimension for which to return support + :returns: distribution support + """ + support = fn.support + # Unwrap support + reinterpreted_batch_ndims_vec = [] + while isinstance(support, dist.constraints.independent): + reinterpreted_batch_ndims_vec.append(support.reinterpreted_batch_ndims) + support = support.base_constraint + # Slice concatenation and stacking transforms + if isinstance(support, dist.constraints.stack) and support.dim == dim: + support = dist.constraints.stack(support.cseq[slice], dim) + elif isinstance(support, dist.constraints.cat) and support.dim == dim: + support = dist.constraints.cat(support.cseq[slice], dim, support.lengths[slice]) + # Wrap support + for reinterpreted_batch_ndims in reinterpreted_batch_ndims_vec[::-1]: + support = dist.constraints.independent(support, reinterpreted_batch_ndims) + return support + + class SplitReparam(Reparam): """ Reparameterizer to split a random variable along a dimension, similar to @@ -50,13 +80,14 @@ class SplitReparam(Reparam): :type: list(int) :param int dim: Dimension along which to split. Defaults to -1. :param callable support_fn: Function which derives the split support - from the site's sampling function. Default is :func:`same_support` - as the sampling function, but in some cases such as sampling functions - which are stacked transforms, you would have to explicitly specify - the support with :func:`real_support` + from the site's sampling function, split size, and split dimension. + Default is :func:`default_support` which correctly handles stacking + and concatenation transforms. Other options are :func:`same_support` + which returns the same support as that of the sampling function, and + :func:`real_support` which returns a real support. """ - def __init__(self, sections, dim, support_fn=same_support): + def __init__(self, sections, dim, support_fn=default_support): assert isinstance(dim, int) and dim < 0 assert isinstance(sections, list) assert all(isinstance(size, int) for size in sections) @@ -80,14 +111,20 @@ def apply(self, msg): dim = fn.event_dim - self.event_dim left_shape = fn.event_shape[:dim] right_shape = fn.event_shape[1 + dim :] + start = 0 for i, size in enumerate(self.sections): event_shape = left_shape + (size,) + right_shape value_split[i] = pyro.sample( f"{name}_split_{i}", - dist.ImproperUniform(self.support_fn(fn), fn.batch_shape, event_shape), + dist.ImproperUniform( + self.support_fn(fn, slice(start, start + size), -self.event_dim), + fn.batch_shape, + event_shape, + ), obs=value_split[i], infer={"is_observed": is_observed}, ) + start += size # Combine parts into value. if value is None: From e46f4bb0c097e59f9150990360319a398443846b Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Tue, 30 Jul 2024 12:34:12 +0300 Subject: [PATCH 4/4] Test support of transformed distributions, with stacking and concatenation transforms, in SplitReparam. --- tests/infer/reparam/test_split.py | 81 +++++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/tests/infer/reparam/test_split.py b/tests/infer/reparam/test_split.py index 6337069ea0..0167f5778d 100644 --- a/tests/infer/reparam/test_split.py +++ b/tests/infer/reparam/test_split.py @@ -13,8 +13,7 @@ from .util import check_init_reparam - -@pytest.mark.parametrize( +event_shape_splits_dim = pytest.mark.parametrize( "event_shape,splits,dim", [ ((6,), [2, 1, 3], -1), @@ -31,7 +30,13 @@ ], ids=str, ) -@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) + + +batch_shape = pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) + + +@event_shape_splits_dim +@batch_shape def test_normal(batch_shape, event_shape, splits, dim): shape = batch_shape + event_shape loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() @@ -72,24 +77,8 @@ def model(): assert_close(actual_grads, expected_grads) -@pytest.mark.parametrize( - "event_shape,splits,dim", - [ - ((6,), [2, 1, 3], -1), - ( - ( - 2, - 5, - ), - [2, 3], - -1, - ), - ((4, 2), [1, 3], -2), - ((2, 3, 1), [1, 2], -2), - ], - ids=str, -) -@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@event_shape_splits_dim +@batch_shape def test_init(batch_shape, event_shape, splits, dim): shape = batch_shape + event_shape loc = torch.empty(shape).uniform_(-1.0, 1.0) @@ -100,3 +89,53 @@ def model(): return pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape))) check_init_reparam(model, SplitReparam(splits, dim)) + + +@batch_shape +def test_transformed_distribution(batch_shape): + num_samples = 10 + + transform = dist.transforms.StackTransform( + [ + dist.transforms.OrderedTransform(), + dist.transforms.DiscreteCosineTransform(), + dist.transforms.HaarTransform(), + ], + dim=-1, + ) + + num_transforms = len(transform.transforms) + + def model(): + scale_tril = pyro.sample("scale_tril", dist.LKJCholesky(num_transforms, 1)) + with pyro.plate_stack("plates", batch_shape): + x_dist = dist.TransformedDistribution( + dist.MultivariateNormal( + torch.zeros(num_samples, num_transforms), scale_tril=scale_tril + ).to_event(1), + [transform], + ) + return pyro.sample("x", x_dist) + + assert model().shape == batch_shape + (num_samples, num_transforms) + + pyro.clear_param_store() + guide = pyro.infer.autoguide.AutoMultivariateNormal(model) + guide_sites = guide() + + assert guide_sites["x"].shape == batch_shape + (num_samples, num_transforms) + + for sections in [[1, 1, 1], [1, 2], [2, 1]]: + split_model = pyro.poutine.reparam( + model, config={"x": SplitReparam(sections, -1)} + ) + + pyro.clear_param_store() + guide = pyro.infer.autoguide.AutoMultivariateNormal(split_model) + guide_sites = guide() + + for n, section in enumerate(sections): + assert guide_sites[f"x_split_{n}"].shape == batch_shape + ( + num_samples, + section, + )