Wishart / InverseWishart / LKJ priors #1692
Comments
You actually don't need a |
I was trying to use HMC (NUTS), which I think requires sampling (I do get a NotImplementedError). On second thought, though, this might belong in |
That's just for setting the initial value - you can hack around that for now by defining a dummy sample method for your prior that just returns an arbitrary appropriately shaped value that's in the prior's support. |
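For what it's worth, a minimal sketch of that hack might look like the following; the prior class it gets attached to is hypothetical, and an identity matrix is used only because it lies in the support of a positive-definite-valued prior:

```python
import torch

def _dummy_sample(self, sample_shape=torch.Size()):
    """Return an arbitrary value inside the support (the identity is positive definite)."""
    shape = self._extended_shape(sample_shape)  # batch_shape + event_shape
    return torch.eye(shape[-1]).expand(shape)

# Patch it onto whatever prior class is missing .sample (name is hypothetical):
# MyInverseWishartPrior.sample = _dummy_sample
```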
Got it. I just tried it with the patched method, and ran into two issues:
which makes sense. Actually, modeling-wise, the LKJ prior would be more useful, but I tried that and it also runs into the same error. Here's my attempt:

```python
# Imports assumed for this snippet; `_is_valid_correlation_matrix` and
# `_batch_form_diag` are helpers defined elsewhere in my code.
import math
from numbers import Number

import torch
from torch.distributions import Distribution, constraints

from pyro.distributions.torch_distribution import TorchDistributionMixin


class LKJCorr(Distribution, TorchDistributionMixin):
    arg_constraints = {"n": constraints.positive_integer, "eta": constraints.positive}
    support = constraints.positive_definite
    _validate_args = True

    def __init__(self, n, eta, validate_args=False):
        if not isinstance(n, int) or n < 1:
            raise ValueError("n must be a positive integer")
        if isinstance(eta, Number):
            eta = torch.tensor(float(eta))
        self.n = torch.tensor(n, dtype=torch.long, device=eta.device)
        batch_shape = eta.shape
        event_shape = torch.Size([n, n])
        i = torch.arange(n, dtype=eta.dtype, device=eta.device)
        C = (((2 * eta.view(-1, 1) - 2 + i) * i).sum(1) * math.log(2)).view_as(eta)
        C += n * torch.sum(2 * torch.lgamma(i / 2 + 1) - torch.lgamma(i + 2))
        self.eta = eta
        self.C = C
        super(LKJCorr, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def log_prob(self, X):
        if any(s != self.n for s in X.shape[-2:]):
            raise ValueError("Correlation matrix is not of size n={}".format(self.n.item()))
        if not _is_valid_correlation_matrix(X):
            raise ValueError("Input is not a valid correlation matrix")
        log_diag_sum = torch.stack(
            [p.cholesky(upper=True).diag().log().sum() for p in X.view(-1, *X.shape[-2:])]
        )
        return self.C + (self.eta - 1) * 2 * log_diag_sum

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        return torch.eye(self.n.item()).expand(shape)


class LKJCov(Distribution, TorchDistributionMixin):
    arg_constraints = {"n": constraints.positive_integer, "eta": constraints.positive}
    support = constraints.positive_definite
    _validate_args = True

    def __init__(self, n, eta, sd_prior, validate_args=False):
        correlation_prior = LKJCorr(n=n, eta=eta, validate_args=validate_args)
        self.correlation_prior = correlation_prior
        self.sd_prior = sd_prior
        super(LKJCov, self).__init__(self.correlation_prior._batch_shape,
                                     self.correlation_prior._event_shape,
                                     self.correlation_prior._validate_args)

    def log_prob(self, X):
        marginal_var = torch.diagonal(X, dim1=-2, dim2=-1)
        if not torch.all(marginal_var >= 0):
            raise ValueError("Variance(s) cannot be negative")
        marginal_sd = marginal_var.sqrt()
        sd_diag_mat = _batch_form_diag(1 / marginal_sd)
        correlations = torch.matmul(torch.matmul(sd_diag_mat, X), sd_diag_mat)
        log_prob_corr = self.correlation_prior.log_prob(correlations)
        log_prob_sd = self.sd_prior.log_prob(marginal_sd)
        return log_prob_corr + log_prob_sd

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        return torch.eye(self.correlation_prior.n.item()).expand(shape)
```
 |
How about using the LKJ prior for the Cholesky factor? According to the Stan reference, an LKJ prior on the covariance will cause "the code to run slower and consume more memory with more risk of numerical errors". I think that LowerCholeskyTransform has an inverse method. There still needs to be an implementation for |
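For reference, the existing PyTorch machinery mentioned here can be poked at like this (a small sketch, assuming a reasonably recent torch where `transform_to(constraints.lower_cholesky)` resolves to `LowerCholeskyTransform`):

```python
import torch
from torch.distributions import constraints, transform_to

t = transform_to(constraints.lower_cholesky)  # LowerCholeskyTransform under the hood
x = torch.randn(3, 3)                         # unconstrained square matrix
L = t(x)                                      # lower triangular, positive diagonal
x_back = t.inv(L)                             # the inverse mentioned above
```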
Huh, that sounded right, but I think I'm missing something. Shouldn't the domain of that transform be

When that's figured out, the implementation looks like:

```python
class LKJCholesky(LKJCorr):
    support = constraints.lower_cholesky

    def log_prob(self, L):
        log_diag_sum = torch.diagonal(L, dim1=-2, dim2=-1).log().sum(-1)
        return self.C + (self.eta - 1) * 2 * log_diag_sum


class LKJCholeskyCov(Distribution, TorchDistributionMixin):
    arg_constraints = {"n": constraints.positive_integer, "eta": constraints.positive}
    support = constraints.lower_cholesky

    def __init__(self, n, eta, sd_prior, validate_args=False):
        self.correlation_prior = LKJCholesky(n=n, eta=eta, validate_args=validate_args)
        self.sd_prior = sd_prior
        super(LKJCholeskyCov, self).__init__(self.correlation_prior._batch_shape,
                                             self.correlation_prior._event_shape,
                                             self.correlation_prior._validate_args)

    def log_prob(self, L):
        # we essentially have the (LD^{1/2}) part of the LDL decomposition
        marginal_std = torch.diagonal(L, dim1=-2, dim2=-1)
        sd_diag_mat = _batch_form_diag(1 / marginal_std)
        correlation_L = torch.matmul(L, sd_diag_mat)  # or is that backwards??
        log_prob_corr = self.correlation_prior.log_prob(correlation_L)
        log_prob_sd = self.sd_prior.log_prob(marginal_std)
        return log_prob_corr + log_prob_sd

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        return torch.eye(self.correlation_prior.n.item()).expand(shape)
```
 |
https://github.com/pytorch/pytorch/blob/master/torch/distributions/transforms.py#L514 transforms a matrix with garbage in the upper triangle to a matrix with zeros in the upper triangle, hence the use of |
At some point, we should find a clean way to define an unconstrained domain (maybe a vector instead of a matrix) for LowerCholeskyTransform. It would reduce the number of parameters to optimize for these priors. |
Just tossing in a +1 on this... I'm experimenting with Pyro by converting a Stan model, and the lack of priors for covariance matrices is kind of an impediment. |
I guess the most complicated part of the LKJ prior (which is more numerically stable than Wishart/InverseWishart) is defining a transform from the unconstrained space to the space of correlation matrices. The Stan reference gives a nice derivation of such a transform, based on this paper: https://www.sciencedirect.com/science/article/pii/S0047259X09000876. The tricky part (which requires loops) is transforming the canonical partial correlations to the Cholesky factor of the correlation matrix (transform from |
Does it need to be coded to efficiently support batches? If someone has a template (i.e., that shows what functions need to be filled in), I can help. |
@elbamos Here is a template which I came up with:

Step 1: define a constraint, as in this script
Step 2: define a transform, as in this script
Step 3: register a bijection, as in this script
Step 4: define the distribution (as in the PyTorch/Pyro distributions)

Hope that helps! |
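Concretely, a minimal skeleton of those four steps might look like the sketch below. This is not the linked scripts: the names `_CorrCholesky`, `corr_cholesky`, and `CorrCholeskyTransform` are made up, and the transform's method bodies are elided.

```python
from torch.distributions import biject_to, constraints, transform_to, transforms

# Step 1: a constraint describing the support
# (lower triangular, positive diagonal, rows of unit norm).
class _CorrCholesky(constraints.Constraint):
    def check(self, value):
        unit_rows = ((value.pow(2).sum(-1) - 1).abs() < 1e-6).all(-1)
        lower_tri = (value.triu(1).abs().sum(-1) == 0).all(-1)
        pos_diag = (value.diagonal(dim1=-2, dim2=-1) > 0).all(-1)
        return unit_rows & lower_tri & pos_diag

corr_cholesky = _CorrCholesky()

# Step 2: a transform from the unconstrained space to that support.
class CorrCholeskyTransform(transforms.Transform):
    domain = constraints.real
    codomain = corr_cholesky
    bijective = True

    def _call(self, x):
        ...  # row-by-row construction from the Stan manual

    def _inverse(self, y):
        ...  # reverse transformation

    def log_abs_det_jacobian(self, x, y):
        ...  # needed by HMC/NUTS

# Step 3: register the bijection so it can be looked up from the constraint.
@biject_to.register(corr_cholesky)
@transform_to.register(corr_cholesky)
def _corr_cholesky_factory(constraint):
    return CorrCholeskyTransform()

# Step 4: a distribution declaring `support = corr_cholesky` then picks this up.
```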
Maybe I'm missing something here, but it seems to me we don't need to define a constraint on correlation matrices or a transformation from z to x.
Don't we only need the transformation from an unconstrained vector to the lower Cholesky factor of the matrix? This can follow the algorithm in the Stan manual (and we also need its inverse and log abs det jacobian). The constraint on the domain should be `lower_cholesky`, which is already defined in pytorch.
If someone wants to turn it into a correlation matrix (I guess to look at it after the fact), they can use x'x.
No?
Also, I think the relevant Stan page is this one: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
(Also, if you do try to implement the algorithm on https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html, be aware that the equation for the reverse transformation is missing a factor.)
|
@elbamos I think what you're describing are arbitrary covariance matrices. Correlation matrices have the additional constraint that the diagonal is all ones and the off-diagonal entries are all in [-1, 1].

EDIT: While I don't think we can use PyTorch's |
@fritzo and @fehiepsi: The most convenient way to work with covariance matrices, in practice, is usually by separately generating a scale vector (theta) and the lower Cholesky factor of a correlation matrix (Omega). If you multiply diag(sqrt(theta)) * Omega, you get the lower Cholesky factor of a covariance matrix, which is the most efficient parameterization for multivariate distributions. So what's missing from Pyro is distributions for generating the lower Cholesky factor of a correlation matrix (probably via the LKJ prior).

The question is what Transforms and Constraints need to be coded. Ordinarily in HMC, sampling takes place in an unconstrained space. Variables are then transformed into a constrained space. If that's how Pyro works too, what we'd need is an (invertible) transformation from the unconstrained space of

Since the Stan manual helpfully provides an algorithm for that transformation, I'm not sure what Constraint actually needs to be coded. It seems that all we need is the Transform. Does this help explain it?

If I'm right about what's required here, then I've already got a prototype implementation, based on the Stan source code, of everything except log_prob. (If it's necessary to backprop through the Transform, this becomes tricky. It is probably actually easier to implement the Transform as a PyTorch function so we can provide a custom grad function.) |
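As a tiny illustration of the composition described above (shapes and values here are arbitrary, and `Omega` is just a stand-in identity factor):

```python
import torch

K = 3
theta = torch.rand(K) + 0.1               # marginal variances (scale vector)
Omega = torch.eye(K)                      # lower Cholesky factor of a correlation matrix
L_cov = torch.diag(theta.sqrt()) @ Omega  # lower Cholesky factor of the covariance
mvn = torch.distributions.MultivariateNormal(torch.zeros(K), scale_tril=L_cov)
```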
@elbamos thanks for explaining, yes that makes sense. I was confused because

I think you're right, we can follow the Stan code to develop a new

@elbamos For context, PyTorch's transforms differ from Stan's and TensorFlow's in that they also include non-bijective transforms, including projections. These are often cheaper and more stable than bijective transforms. While HMC and NUTS require bijections, SVI and MAP inference can allow such overparameterized transforms. This is why |
@elbamos It is great to hear that you have already come up with a prototype (I was intending to make this implementation)! I agree that an LKJ prior for the Cholesky factor is enough. An LKJ prior for the correlation matrix (which is more popular for small models) would mostly be based on its Cholesky version (except the log_prob method, where we have to Cholesky-decompose the correlation matrix, which is inefficient for large matrices), so we can do it later if necessary. To keep our discussion consistent, I will only discuss the LKJ prior for the Cholesky factor. Here is the template of the LKJ prior for the Cholesky factor (quite similar to the correlation version):
You are right about how Pyro's HMC works. About the constraint: each distribution is required to define its support. Under the hood, HMC will first look at the constraint of the support of the distribution of a latent variable. Then it will check whether there is a bijection registered (step 3 in my template) for that constraint. In this case, if it sees that the constraint is

Of course, we can just define a transform and specify it in an HMC instance through

@fritzo I don't have a trick to do that. And I think that we can just keep event_shape for the distribution is

As I mentioned above, we can make a "loop" version as a first step. If we only need to convert a vector to a lower triangular matrix, then I know the following way.
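The snippet originally posted here didn't survive extraction; as a hedged reconstruction, one common way to do just the vector-to-lower-triangle scatter (assuming `torch.tril_indices` is available) is:

```python
import torch

def vec_to_tril(vec, n):
    """Scatter a length n*(n+1)/2 vector into the lower triangle of an n x n matrix."""
    idx = torch.tril_indices(n, n)  # 2 x (n*(n+1)/2) row/column indices
    out = vec.new_zeros(n, n)
    out[idx[0], idx[1]] = vec       # index assignment; gradients flow back to `vec`
    return out

# vec_to_tril(torch.arange(6.), 3) -> a 3 x 3 lower-triangular matrix
```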
But I don't think that it will help in this situation. I don't know how to make a non-loop version to transform from |
I already have code for unrolling the vector. The issue is whether there will ever be a need to backprop through that code, because there doesn't appear to be a purely vectorized way of doing it. If backprop'ing is going to be necessary, then we will have to write a PyTorch function implementing a custom grad. The code that needs to be written is for the

If I'm understanding you correctly: the function of the constraint is that during sampling, Pyro will automatically select a transform that matches the domains, which is the Transform(s) I've written. (I've actually coded up two of them, one from the unconstrained space and one from [-1, 1], the difference being the application of tanh.)

I set up a branch: https://github.com/elbamos/pyro/blob/lkj/pyro/distributions/lkj.py |
Also - I took a look at the code for the |
@elbamos Could you let me know which lines in your code you are worried about for gradient backpropagation? I can't identify them. Overall, your code looks great. It will be better if we modify it to support "batching".
Please let me know which parts you need me to add to your code base (to avoid duplicated work). I'm happy to work on this with you. |
@fehiepsi I've tested the algorithm code in there separately, but I haven't tested any of it in place in those classes. What I think will cause backprop problems is the in-place modification of the tensors when transforming the vector into a matrix.

You're right about the batching... Maybe someone else can modify the code to do that? I just find it very hard to imagine someone generating correlation matrices in batches that way, but I guess other people have use cases very different from mine.

Regarding the log_prob, I think our best source for it is actually the Stan source code: https://github.com/stan-dev/math/blob/master/stan/math/prim/mat/prob/lkj_corr_cholesky_lpdf.hpp

Actually, it's very interesting to try to implement the same model in Stan and Pyro. Stan is currently running the model about 100x faster than Pyro, and the GPU (mine is a 1080Ti) actually makes things worse rather than better. Presumably this is because my Pyro implementation isn't efficiently written yet. |
@elbamos Yes, we can use the Stan source code to verify our implementation. I'll take care of "batching". If you want to convert a vector to a matrix, you can use the trick:
Back-propagation should be fine with this version.

About the performance of HMC: we did profiling for various models and observed that most of the time is spent computing the potential energy, which calls

The GPU is suitable for large vectors/batches. It does not seem to give a computational advantage for the models we get from statistical textbooks, so you don't have to worry about it for now.

Could you please allow me to push commits to your branch and add some tests, so I can expand it and verify that "batching" works correctly? Thanks! |
@fritzo Should I start the PR now or wait until we're further along? I'm pretty sure you can open PRs against the branch now. I'm not sure what the correct GitHub etiquette is these days.

@fehiepsi Yeah... The model that I'm porting over does a lot of matrix slicing and reassembling, which I had hoped would be faster in Pyro because it has more powerful vectorized functions. But it's turning out that simple

Interestingly, the per-iteration performance seems to decline over the course of inference. I suspect this is a combination of things. One thing I learned when I was spending a lot of time building PyTorch neural networks is that to get good performance out of it, you have to optimize your code to re-use buffers. Otherwise PyTorch spends a lot of its time reallocating and destroying memory, especially on the GPU. I'm wondering if Pyro isn't optimizing for that well?

The other thing that I suspect is going on is the limited support for constrained and truncated distributions in Pyro. For example, I don't see a way to tell Pyro that one of my variables has to be constrained to be positive, and another to be positive-ordered, etc. See the discussion at the bottom here: https://mc-stan.org/docs/2_18/reference-manual/reject-statements-section.html. There's some additional discussion about it on the Stan mailing list. There are two problems. One is that when a distribution is truncated, it doesn't integrate to 1, and if that isn't taken into account it'll confound the posterior. The second, related problem is that because of this, a Hamiltonian sampler will tend to keep pushing against the improperly enforced constraint. This slows sampling down considerably, and you end up with a posterior bunched up around the constraint because the sampler keeps trying to explore that part of the space and it can't. |
@elbamos It is fine with me to make PRs to your repo. Could you please enable the Issues tab, so I can write a to-do list there? About performance: we haven't paid attention to reallocating/destroying buffers; @neerajprad might have better ideas about that than me. About truncated distributions, I am not sure I understand what you mean correctly. We have a PR at probtorch/pytorch#121, in case you want to follow it. |
@fehiepsi I've enabled the issues tab. I'll have more time on this project this weekend. |
Regarding performance, for most models I would expect the GPU to be slower since HMC/NUTS is heavily sequential - we need to take many steps in the integrator, and unless the time saved from parallelizing the gradient computation within each step is significant, we are unlikely to realize any benefits.
Do you see a difference even after warmup? During warmup, the performance might vary as we adapt the step size.
You are right about NN training, but for HMC this shouldn't be an issue, since we don't deal with mini-batches and the data is transferred to the GPU all at once. |
Just to update folks on the current status of this - I have code up at https://github.com/elbamos/pyro/blob/lkj/pyro/distributions/lkj.py, which runs when isolated, seems to be correct, passes tests (although I'm working on tests for the derivative and log_prob), etc. But when I try to run it, I get an error that

I'd appreciate some advice/suggestions/help on tracking it down - my understanding of the innards of Pyro is quite basic, so I'm not quite sure where to start. @fritzo Could I trouble you to take a peek? Thanks. |
@elbamos Hope that you don't mind if I take care of this issue separately. |
This is an error that is thrown when HMC tries to find an initial trace to begin sampling from. It repeatedly samples from the prior until it finds a trace with a non-nan value for potential energy, and throws this error if it doesn't succeed in 100 trials. My guess is that the |
Yeah, I forgot about it. This should be the case. |
Thanks, @neerajprad, I was able to trace it to an issue in the conversion, which I've resolved. Now I'm dealing with an issue being thrown by the multivariate normal while building a good example. I should have it worked through soon. |
+1 |
Ok, I've tracked down the issue that I'm seeing... It relates to an intersection between Pyro's sampler and precision limits on the

To build the lower Cholesky factor L of a correlation matrix, we have to go from an unconstrained vector of appropriate size to a lower-triangular matrix where each row has unit norm and each entry on the diagonal is positive. The transformation works on a row-by-row basis. It fills in the row up to the diagonal; the diagonal element is then filled in by whatever is necessary so the row has unit norm. This works fine as long as the input is in the range

The first step in the transformation is therefore to take the input and pass it through

The problem that is arising is that when Pyro samples from the unconstrained space, it apparently begins by sampling from a very wide range. For example, in the test I just ran, the first unconstrained values provided by Pyro were

On values in this range, at double precision, pytorch's

In fact, in every test I've run so far, Pyro's initial unconstrained sample produced vectors that, after

Advice? Suggestions? |
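For concreteness, here is a rough sketch (my reconstruction, not the code in the branch above) of the row-by-row construction just described, including the tanh step; the in-place row assignments are exactly the part flagged earlier as a possible backprop concern:

```python
import math
import torch

def unconstrained_to_corr_cholesky(y):
    """Map an unconstrained vector of length K*(K-1)/2 to the lower Cholesky
    factor of a K x K correlation matrix (plain loop version, no batching)."""
    z = torch.tanh(y)  # canonical partial correlations in (-1, 1)
    K = int(round((1 + math.sqrt(1 + 8 * y.numel())) / 2))
    L = y.new_zeros(K, K)
    L[0, 0] = 1.0
    k = 0
    for i in range(1, K):
        for j in range(i):
            # fill the row up to the diagonal
            L[i, j] = z[k] * torch.sqrt(1 - L[i, :j].pow(2).sum())
            k += 1
        # the diagonal entry is whatever makes the row have unit norm
        L[i, i] = torch.sqrt(1 - L[i, :i].pow(2).sum())
    return L
```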
Not sure if it helps, but you can clamp the output of |
@fehiepsi It will run if I do that. But I'd think that would cause two problems:
@neerajprad Do you think @fehiepsi 's solution will cause problems in the HMC/NUTS geometry? How big a problem is it if the inverse transform is broken? If the solution works, then I think this is basically done, and I'll make the PR. If not, we need another solution. |
@elbamos - Great job debugging this! I think @fehiepsi's solution should be fine. We have had to add these epsilon factors inside many distributions for numerical stability reasons around boundaries. The only thing I would suggest is to choose
My guess is that we will be sampling extreme values during the initial stages when we take large steps and adjust the step size and mass matrix, and not when we enter the typical set. I think that this should be safe because the Metropolis correction step would end up correcting for this (by rejecting most such proposals at the boundaries), but it might also be the case that we are stuck with values at the boundaries and don't end up in the typical set at all. Why don't you open up a PR, and we can discuss this in more detail? @fritzo - Should we open a PR directly in |
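A minimal sketch of the kind of clamping being discussed, with the epsilon tied to the floating-point precision of the input (the function name here is mine):

```python
import torch

def clamped_tanh(x):
    eps = torch.finfo(x.dtype).eps
    # keep the canonical partial correlations strictly inside (-1, 1)
    return torch.tanh(x).clamp(min=-1 + eps, max=1 - eps)
```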
I think that we should move this distribution to PyTorch (as mentioned by @ssnl at probtorch/pytorch#150) later, after testing the correctness in jit/cpu/gpu. In the meantime, I'll open a parallel PR with tests to address the correctness (should be top priority) and performance (not important right now) of the implementations. |
@neerajprad Ok. And good point on the inverse mapping. I'm curious about the decision to start sampling from wide values rather than narrow ones as Stan does. Is there any discussion about the rationale somewhere online that I can read through? |
@fehiepsi and @elbamos - I looked at the PR in question and mostly what I saw was the sort of reasonable but inevitable miscommunication that happens when smart, well-intentioned strangers discuss complicated math on the internet. I've removed off-topic comments and @neerajprad and others will review #1746 . Please do not derail this issue. |
I think the best place for verbose code review of statistical functions is either Pyro or https://github.com/probtorch/pytorch . I'm fine starting in Pyro and later moving to PyTorch |
Would it be possible to implement Wishart / InverseWishart / LKJ priors?
gpytorch has them already, but when I tried mixing in TorchDistributionMixin to get something usable in Pyro, I realized that they don't have a .sample method. I don't think it's easy to get efficient samplers for the InverseWishart (I think trying to build it via a TransformedDistribution might be too slow), but I'm curious to see other approaches. There's a TensorFlow Probability tutorial on this here, which explains some of the underlying ideas, as well as a note here.
Great work by the way! Pyro is so easy to use it's incredible.