
Sampling from truncated Gaussian #790

Closed
bmazoure opened this issue Feb 20, 2018 · 12 comments

@bmazoure

Is there currently a way to sample from new distributions which are similar to the ones already implemented in Pyro?

For instance, in the case of a multivariate truncated Gaussian we would need to define how to compute the gradient as a piecewise function and I am not sure where to do so.

@martinjankowiak
Collaborator

Hello, you might take a look here:

probtorch/pytorch#121

@fritzo
Member

fritzo commented Feb 20, 2018

Hi @bmazoure, you could also define a Rejector distribution for a truncated normal. This requires implementing the total log probability of acceptance, log_scale, as a function of your truncation plane. For example, to truncate by ensuring sample[0] > min_x0, you should be able to define

import torch
import pyro.distributions as dist

class TruncatedMVN(dist.Rejector):
    def __init__(self, loc, covariance_matrix, min_x0):
        propose = dist.MultivariateNormal(loc, covariance_matrix)

        def log_prob_accept(x):
            # log(1) = 0 where the constraint holds, log(0) = -inf where it fails
            return (x[0] > min_x0).type_as(x).log()

        # Acceptance probability is P(x[0] > min_x0) under the marginal of x[0].
        scale_0 = torch.sqrt(covariance_matrix[0, 0])
        log_scale = torch.log(1 - dist.Normal(loc[0], scale_0).cdf(min_x0))
        super(TruncatedMVN, self).__init__(propose, log_prob_accept, log_scale)

(Note: this is available on the Pyro dev branch, but not in the 0.1.2 release.)
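The mechanics behind a Rejector-style truncated normal can be sketched in plain Python without Pyro; the helper below is illustrative only (not Pyro API) and shows both the rejection loop and the analytic acceptance probability that log_scale encodes:

```python
import random
from statistics import NormalDist

def truncated_normal_sample(mu, sigma, min_x, rng):
    # Propose from N(mu, sigma) and reject until the sample clears min_x,
    # which is what a rejection-based truncated normal does internally.
    while True:
        x = rng.gauss(mu, sigma)
        if x > min_x:
            return x

# The analytic acceptance probability, exp(log_scale) in the Pyro example:
# P(X > min_x) = 1 - Phi((min_x - mu) / sigma).
accept_prob = 1 - NormalDist(0.0, 1.0).cdf(1.0)

rng = random.Random(0)
samples = [truncated_normal_sample(0.0, 1.0, 1.0, rng) for _ in range(2000)]
```

With min_x one standard deviation above the mean, only about 16% of proposals are accepted, which is why the analytic log_scale term is needed to normalize log_prob.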

@bmazoure
Author

Will try these suggestions, thanks!

@dobos

dobos commented Mar 27, 2020

I'm trying to use a rejector to limit a Pareto distribution, which is wrapped from a torch distribution:

import torch
from pyro.distributions import Pareto, Rejector

class TruncatedPareto(Rejector):
    def __init__(self, scale, alpha, upper_limit, validate_args=None):
        propose = Pareto(scale, alpha, validate_args=validate_args)

        def log_prob_accept(x):
            return (x < upper_limit).type_as(x).log()

        # Acceptance probability is the Pareto CDF at the upper limit:
        # P(X < upper_limit) = 1 - (scale / upper_limit) ** alpha.
        log_scale = torch.log1p(-torch.tensor(scale / upper_limit) ** alpha)
        super(TruncatedPareto, self).__init__(propose, log_prob_accept, log_scale)

I can sample from it by calling .sample(), but inside an MCMC run I get this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-353-ebb20659df6e> in <module>
      2 kernel = NUTS(conditioned_model)
      3 mcmc = MCMC(kernel, num_samples=10, warmup_steps=2)
----> 4 mcmc.run(meta)
      5 posterior_samples = mcmc.get_samples()

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    355         z_flat_acc = [[] for _ in range(self.num_chains)]
    356         with pyro.validation_enabled(not self.disable_validation):
--> 357             for x, chain_id in self.sampler.run(*args, **kwargs):
    358                 if num_samples[chain_id] == 0:
    359                     num_samples[chain_id] += 1

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    164             logger = initialize_logger(logger, "", progress_bar)
    165             hook_w_logging = _add_logging_hook(logger, progress_bar, self.hook)
--> 166             for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
    167                                        i if self.num_chains > 1 else None,
    168                                        *args, **kwargs):

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
    108 
    109 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 110     kernel.setup(warmup_steps, *args, **kwargs)
    111     params = kernel.initial_params
    112     # yield structure (key, value.shape) of params

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    264         self._warmup_steps = warmup_steps
    265         if self.model is not None:
--> 266             self._initialize_model_properties(args, kwargs)
    267         potential_energy = self.potential_fn(self.initial_params)
    268         self._cache(self.initial_params, potential_energy, None)

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
    229 
    230     def _initialize_model_properties(self, model_args, model_kwargs):
--> 231         init_params, potential_fn, transforms, trace = initialize_model(
    232             self.model,
    233             model_args,

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains)
    371     model = poutine.enum(config_enumerate(model),
    372                          first_available_dim=-1 - max_plate_nesting)
--> 373     model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
    374     has_enumerable_sites = False
    375     prototype_samples = {}

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    183         Calls this poutine and returns its trace instead of the function's return value.
    184         """
--> 185         self(*args, **kwargs)
    186         return self.msngr.get_trace()

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError):
    167                 exc_type, exc_value, traceback = sys.exc_info()

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

<ipython-input-349-8dfe588a1e6e> in model(meta)
---> 25         M = pyro.sample('M', TruncatedPareto(0.1, xi[g], 1.5))

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    111             msg["is_observed"] = True
    112         # apply the stack and return its return value
--> 113         apply_stack(msg)
    114         return msg["value"]
    115 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    199 
    200     for frame in stack[-pointer:]:
--> 201         frame._postprocess_message(msg)
    202 
    203     cont = msg["continuation"]

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _postprocess_message(self, msg)
    139         method_name = "_pyro_post_{}".format(msg["type"])
    140         if hasattr(self, method_name):
--> 141             return getattr(self, method_name)(msg)
    142         return None
    143 

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/poutine/enum_messenger.py in _pyro_post_sample(self, msg)
    201         if value is None:
    202             return
--> 203         shape = value.shape[:value.dim() - msg["fn"].event_dim]
    204         dim_to_id = msg["infer"].setdefault("_dim_to_id", {})
    205         dim_to_id.update(self._param_dims.get(msg["name"], {}))

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/pyro/distributions/torch_distribution.py in event_dim(self)
     51         :rtype: int
     52         """
---> 53         return len(self.event_shape)
     54 
     55     def shape(self, sample_shape=torch.Size()):

/usr/local/miniconda3/envs/astro-pyro/lib/python3.8/site-packages/torch/distributions/distribution.py in event_shape(self)
     70         Returns the shape of a single sample (without batching).
     71         """
---> 72         return self._event_shape
     73 
     74     @property

AttributeError: 'TruncatedPareto' object has no attribute '_event_shape'

@fritzo
Member

fritzo commented Mar 27, 2020

@dobos It looks like Rejector is missing a call to super().__init__(...). This is a bug; I'll push a fix.

@GlastonburyC

Were truncated distributions ever added?

@fritzo
Member

fritzo commented Jun 8, 2020

@GlastonburyC I believe @alicanb was working on TruncatedDistribution in pytorch/pytorch#32377

@zoj613

zoj613 commented Jun 8, 2020

Is there a way to implement a Distribution for sampling from a multivariate Gaussian truncated on a hyperplane? For example, making sure that the sampled array sums to zero? @fritzo

@fritzo
Member

fritzo commented Jun 8, 2020

@zoj613 You could easily implement a multivariate Gaussian truncated along a single hyperplane passing through the center by generalizing FoldedDistribution to a MultivariateFoldedDistribution; however, I don't know of an easy way to allow multiple truncations, or to allow truncation along a single hyperplane that does not pass through the distribution's center.
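The folding idea can be sketched in plain Python: reflect each draw across the hyperplane through the center so that every sample lands on one side. The helper below is illustrative (not Pyro API) and assumes a unit normal vector with unit-variance components for simplicity:

```python
import random

def reflect_fold(x, loc, n):
    # Fold x across the hyperplane through `loc` with unit normal `n`:
    # if the centered point lies on the negative side, mirror it over.
    c = [xi - li for xi, li in zip(x, loc)]
    d = sum(ci * ni for ci, ni in zip(c, n))
    if d >= 0:
        return x
    return [xi - 2 * d * ni for xi, ni in zip(x, n)]

rng = random.Random(0)
loc, n = [1.0, -2.0], [1.0, 0.0]  # fold across the vertical line x0 = 1
samples = [reflect_fold([rng.gauss(mu, 1.0) for mu in loc], loc, n)
           for _ in range(1000)]
```

When the hyperplane passes through the center, the Gaussian is symmetric under the reflection, so the folded density is just 2·p(x); an off-center fold no longer coincides with truncation, which is the limitation of this approach.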

@GlastonburyC

@fritzo How can I sample from a normal with only negative support?

@fritzo
Member

fritzo commented Jul 7, 2020

@GlastonburyC For a negative Gaussian with zero mode, you could use a transformed HalfNormal:

import pyro.distributions as dist
from torch.distributions import constraints

class NegativeHalfNormal(dist.TransformedDistribution):
    support = constraints.less_than(0)

    def __init__(self, scale):
        # Negate a HalfNormal to move its support from [0, inf) to (-inf, 0].
        base_dist = dist.HalfNormal(scale)
        transform = dist.transforms.AffineTransform(0., -1.)
        super().__init__(base_dist, transform)

If your loc parameter is nonzero, you would need to wait for pytorch/pytorch#32377.

In practice I prefer FoldedDistribution over a truncated normal, since it is more numerically stable and has a qualitatively similar density. You could create a negatively-supported folded normal via

class NegativeFoldedNormal(dist.TransformedDistribution):
    support = constraints.less_than(0)

    def __init__(self, loc, scale):
        # |Normal(loc, scale)|, negated to give support (-inf, 0].
        base_dist = dist.FoldedDistribution(dist.Normal(loc, scale))
        transform = dist.transforms.AffineTransform(0., -1.)
        super().__init__(base_dist, transform)
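Sampling-wise, this NegativeFoldedNormal is just the negated absolute value of a normal draw, which is easy to sanity-check in plain Python (illustrative only, no Pyro needed):

```python
import random

def negative_folded_normal_sample(loc, scale, rng):
    # FoldedDistribution(Normal(loc, scale)) is |X| for X ~ N(loc, scale);
    # the AffineTransform(0., -1.) then negates it, giving support (-inf, 0].
    return -abs(rng.gauss(loc, scale))

rng = random.Random(0)
samples = [negative_folded_normal_sample(0.5, 1.0, rng) for _ in range(1000)]
```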

@GlastonburyC

Thanks @fritzo, that's super helpful of you. :)
