WIP implementation of multivariate normal distribution #52
…ed covariance matrices at all.
fritzo
left a comment
Nice! Could you also add a section to docs/source/distributions.rst, and maybe run cd docs; make html to ensure the docs still build (I've caught my own typos this way).
You can also take a look at the recent OneHotCategorical tests, since that is also a "multivariate" distribution with a nontrivial event_shape.
test/test_distributions.py
Could you also add some example parameters in EXAMPLES below, ideally one that specifies cov and another that specifies scale_tril?
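For instance, entries along these lines might work (a rough sketch; the exact Example structure and current parameter names may differ):

```python
import torch
from torch.autograd import Variable

# hypothetical entries for the EXAMPLES list in test/test_distributions.py
Example(MultivariateNormal, [
    {'mean': Variable(torch.randn(3), requires_grad=True),
     'cov': Variable(torch.eye(3), requires_grad=True)},
    {'mean': Variable(torch.randn(3), requires_grad=True),
     'scale_tril': Variable(torch.eye(3) * 0.5, requires_grad=True)},
])
```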
Is it possible to support batched .rsample() by using torch.bmm() here instead of torch.matmul()? I'm not sure what's blocking batched covariance.
That might work for distributions specified via scale_tril as opposed to covariance_matrix. The primary blocker is a batched torch.potrf. We also need a batched solver (batched torch.gesv or otherwise) to compute the log probability.
I think we can maybe do this by using torch.btrifact and torch.btrisolve instead of potrf and gesv. Haven't looked into it yet.
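In the meantime, a loop-based fallback (the "python iteration under the hood" idea that comes up later in this thread) might look like the sketch below; the helper name _batch_potrf_lower is hypothetical:

```python
import torch

def _batch_potrf_lower(bmat):
    """Apply torch.potrf(..., upper=False) over leading batch dims by
    flattening them and looping in Python (a stopgap, not fast)."""
    n = bmat.size(-1)
    flat = bmat.contiguous().view(-1, n, n)
    chols = torch.stack([torch.potrf(m, upper=False) for m in flat])
    return chols.view(bmat.size())
```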
Unfortunately, I didn't realize that torch.btrifact doesn't actually support .backward() calls.
nit: It would be nice to stay maximally compatible with tensorflow.distributions and name this covariance_matrix.
Yeah, I just left this matching the scipy mean and cov because they were so much shorter. Happy to change to loc and covariance_matrix if that is what we've settled on.
Yeah I'd like to keep the interfaces similar if possible, but I defer to your judgement here.
@tbrx How's this going? I might have time this weekend to try to add some bits of our Pyro implementation into this branch. I think it's fine to provide a batching-complete interface even if we might need to do some python iteration under the hood for now, until pytorch#4612 merges.
I haven't touched it since the last push; I was waiting to see whether a batched torch.potrf would land. At the moment this actually works fine, with the caveat that a batch shape on the covariance matrix isn't supported. Adding a version of this which handles batches by using python loops or list comprehensions shouldn't be too difficult…
Oh great, if it already works could we merge it now, then add full batched support in a follow-up PR? It would be nice to help motivate batched linear algebra work in PyTorch by claiming that "if xxx operation were batched then torch.distributions.MultivariateNormal would get batched covariance support for free".
Okay, I think for that all we need to do is update / expand the tests. If there are any other updates that happened to the Pyro version it would be nice to merge them in too. Maybe it is worth implementing a slow version with batched covariance matrices first just to get the API correct, though. If the batch size is reasonably small it shouldn't be too slow.
@tbrx It would make @neerajprad's and my job easier if you could merge this PR soon, simply adding tests and pushing further enhancements to follow-up PRs. The Pyro team has already migrated to PyTorch distributions, and we're working around the lack of a multivariate normal.
That makes sense. The lack of a batch dimension on the covariance_matrix doesn't cause issues for you in Pyro, if I understand correctly? I can update this PR and add the remaining tests Monday morning my time.
What sort of constraint should we use here for the covariance_matrix and scale_tril parameters?
Yeah, I've been thinking about that. I think we should introduce new constraints, e.g. constraints.cholesky_lower for scale_tril and an analogous one for covariance_matrix.

Does that seem reasonable? They're simply symbolic placeholders, but we'll use them to register transforms.
BTW I've added an issue #99 for implementing a BivariateNormal.
```python
self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3))
self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3))

# check gradients
```
@jwvdm noted that we could generically retrieve params for a distribution if we specified a canonical set of parameters. I've been trying to do this by putting only a single canonical parameterization in the .params dict (e.g. either loc, covariance_matrix or loc, scale_tril, but not all three).

But I like what you've done here by adding them all. Maybe we should do that for all distributions and specify canonical_params or something in another field, or just let higher-level libraries like Pyro or ProbTorch do that. WDYT?
nit: You could simplify via

```python
if (covariance_matrix is None) == (scale_tril is None):
    raise ValueError(...)
```

in place of the current checks:

```python
raise ValueError("Either covariance matrix or scale_tril may be specified, not both.")
if covariance_matrix is None and scale_tril is None:
    raise ValueError("One of either covariance matrix or scale_tril must be specified")
if scale_tril is None:
```
Neeraj made this cool decorator called @lazy_property that could allow you to create scale_tril only if it does not exist, which would avoid unnecessary work in some cases. You could use it as follows:

```python
class MultivariateNormal(Distribution):
    def __init__(self, loc, covariance_matrix=None, scale_tril=None):
        ...
        if scale_tril is not None:
            self.scale_tril = scale_tril
            # leave .covariance_matrix unset
        else:
            self.covariance_matrix = covariance_matrix
            # leave .scale_tril unset
        ...

    @lazy_property
    def scale_tril(self):
        return torch.potrf(self.covariance_matrix, upper=False)

    @lazy_property
    def covariance_matrix(self):
        return torch.mm(self.scale_tril, self.scale_tril.t())
```
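For reference, a minimal sketch of how such a lazy_property decorator could work (the actual decorator may differ): a non-data descriptor that computes the value on first access and caches it on the instance, so the direct assignments in __init__ simply shadow it.

```python
class lazy_property(object):
    """Evaluate the wrapped method once, on first access, and cache the
    result as an ordinary instance attribute (shadowing this descriptor)."""

    def __init__(self, fget):
        self.fget = fget
        self.__name__ = fget.__name__
        self.__doc__ = fget.__doc__

    def __get__(self, instance, owner):
        if instance is None:
            return self
        value = self.fget(instance)
        setattr(instance, self.__name__, value)
        return value
```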
@tbrx In pytorch#4771 I've replaced …
@fritzo Actually… I like leaving it just as plain lower-triangular, with no sign constraint on the diagonal.

The main (potential) problem would be in computing the determinant of the covariance matrix. But that's actually fine: take a lower-triangular matrix with a negative entry on the diagonal; we can use it to get a covariance matrix, whose Cholesky decomposition is of course different. Say the determinant of that covariance matrix is 0.25. We can get this from our original factor as the squared product of its diagonal entries, but this is the same as the squared product of their absolute values (a numeric version is in the snippet below). That said, I believe the current MVN code actually handles this correctly by taking abs() of the diagonal.

It seems to me one nice use case of an unconstrained scale_tril is optimizing it directly as a free parameter.
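Concretely, a small numeric version of the argument above (the particular matrix is illustrative, chosen so the covariance determinant comes out to 0.25):

```python
import math
import torch

# a lower-triangular factor with a negative entry on the diagonal
L = torch.Tensor([[0.5, 0.0],
                  [0.3, -1.0]])
cov = L.mm(L.t())  # a perfectly valid covariance matrix

# det(cov) = det(L)^2 = (0.5 * -1.0)^2 = 0.25; taking abs() of the
# diagonal first gives the same answer:
log_det = L.diag().abs().log().sum()  # log |det(L)|
print(math.exp(2 * log_det))          # 0.25 == det(cov)
```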
In my very limited experience, it is important to ensure positive definiteness rather than merely semidefiniteness (sorry if I've messed this up):

```python
u = Variable(torch.Tensor(4, 4).normal_(), requires_grad=True)  # optimize this
scale_tril = u.tril(-1) + u.diag().exp().diag()
```

If you instead merely define

```python
scale_tril = u.tril()
```

then optimization will often pass through a hyperplane of singular matrices, i.e. where one of the diagonal entries crosses zero. I'm happy to add …
I'm not sure we should actually change it back! Just wanted to discuss. I actually agree with you that the PSD-vs-PD bit is probably more crucial. In that case I want to confirm we enforce it. I guess that I agree it is nice (generally) to have the diagonal constrained to be positive.

And actually, your code snippet may have convinced me that this isn't a problem: the exp() guarantees a strictly positive diagonal. My one remaining concern though is what happens when we (eventually) update the MVN to support batching for scale_tril.
There is no batch support yet, but a batched version of the transform could look something like this:

```python
def cholesky_lower_transform(x):
    if x.dim() == 2:
        return x.tril(-1) + x.diag().exp().diag()
    else:
        n = x.size(-1)
        diag = torch.eye(n, out=x.new(n, n))
        arange = torch.arange(n, out=x.new(n))
        tril = (arange.unsqueeze(-1) > arange.unsqueeze(0)).float()
        return x * tril + x.exp() * diag
```

This probably suffers from NaN issues, but the general idea should work.
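Hypothetical usage on a batch of unconstrained matrices (shapes illustrative):

```python
import torch

x = torch.Tensor(10, 4, 4).normal_()  # a batch of unconstrained matrices
L = cholesky_lower_transform(x)
# each (4, 4) slice of L is lower-triangular with a strictly positive
# (exp-transformed) diagonal
```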
BTW the to_constrained() factory should make this sort of thing easy to write:

```python
u = Variable(torch.Tensor(100, 100).normal_(), requires_grad=True)
scale_tril = to_constrained(constraints.cholesky_lower)(u)
```

or even

```python
scale_tril = to_constrained(dist.params['scale_tril'])(u)
```

I'm really looking forward to using this in Pyro 😄
…delay computation until after init
I believe the primary blocker to moving upstream at this point is the constraints, and a decision on the scale_tril parameterization. Alternatively, we could wait for pytorch#4771 and then include both this and the constraints together.
I'd recommend adding an implementation of entropy().
Is there anything I'm missing here (particularly in terms of test coverage…)? Otherwise, I'd be up for sending this PR upstream.
fritzo
left a comment
Looks ready to send upstream after one minor doc fix.
Re: testing, I think the strongest tests will be provided once we have a "by hand" bivariate normal distribution. We'll also be using this in Pyro right away; this should give us a little time to look for bugs and weird behavior before the PyTorch release.
```rst
:members:

:hidden:`MultivariateNormal`
~~~~~~~~~~~~~~~~~~~~~~~
```
nit: underline is too short and will break docs. You can run make -C docs html and open docs/build/html/index.html to check docs.
After merging in the latest master, I'm seeing a failure in one of the Monte Carlo entropy tests. Visually the results look "okay" for most entries; the max error reported is 0.349, which is on a value of 58.xxx.
@tbrx That failure is not expected. Can you make sure you've rebuilt with …?
So, it seems that the Monte Carlo test for Pareto entropy is just very sensitive… If I change the ordering of …
…helpers for multivariate normal.
In working on the BivariateNormal #99 I started writing helpers for working with torch linear algebra functions, and realized that it would actually be hardly more work to port and implement these here. So, I updated this to support actual batching on covariance matrices and scale_tril.

I would appreciate feedback, particularly on whether I handled the "batch-friendly" matrix constraints correctly, and whether I am missing anything with the linear algebra helpers I added. Obviously the current implementation is not ideal, speed-wise…
fritzo
left a comment
The helpers look reasonable, and I like that they abstract out the mess and make MultivariateNormal methods more readable.
I'd love to have this in master soon so we can "kick its tires" and get any fixes into the PyTorch 0.4 release. E.g. it would help to have other multivariate distributions for testing batch shapes of Transforms.
nit: Use `r"""` rather than `"""` to open docstrings that contain backslashes.
```python
dims = torch.arange(n, out=bmat.new(n)).long()
if isinstance(dims, Variable):
    dims = dims.data  # TODO: why can't I index with a Variable?
return bmat[..., dims, dims]
```
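Presumably this is the body of the _batch_diag helper used in entropy() below. Hypothetical usage:

```python
import torch

bmat = torch.Tensor(2, 3, 4, 4).normal_()
diags = _batch_diag(bmat)
print(diags.size())  # (2, 3, 4): one diagonal per batch element
```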
```python
return -0.5 * (M + self.loc.size(-1) * math.log(2 * math.pi)) - log_det
```

```python
def entropy(self):
    log_det = _batch_diag(self.scale_tril).abs().log().sum(-1)
```
Hmm, shouldn't this already have the correct shape? Why do you need the H.expand(self._batch_shape) below?
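As an aside, the closed form being computed here can be sanity-checked against scipy (which the tests already compare against). An illustrative check, not part of the PR:

```python
import numpy as np
from scipy.stats import multivariate_normal

n = 3
L = np.tril(np.random.randn(n, n))
L[np.diag_indices(n)] = np.abs(L[np.diag_indices(n)]) + 0.5  # keep well-conditioned
cov = L.dot(L.T)

# H = n/2 * (1 + log(2*pi)) + sum(log|diag(L)|)
H = 0.5 * n * (1 + np.log(2 * np.pi)) + np.log(np.abs(np.diag(L))).sum()
assert np.isclose(H, multivariate_normal(cov=cov).entropy())
```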
Before you send upstream, consider replacing the TODO comment above with something more diplomatic 😉
conform to torch.bmm which requires .dim() == 3
Great, thanks @fritzo! I'll (finally!) make a new pull request upstream.
For issue #1. @fritzo

- A mean argument, plus (either) cov or scale_tril.
- torch.gesv for computing log_prob; if requires_grad=False then we could do a (cheaper) torch.potrs… probably worth using a solver-helper here like @dwd31415 has in the Pyro PR (a sketch follows below).

Argument naming convention at the moment is: mean and cov to match scipy.stats.multivariate_normal, and scale_tril to match the Pyro PR.

Test coverage is spotty at the moment (in particular I had some issue with the _gradcheck_log_prob helper), but shapes seem okay and log_prob values match scipy.

One question is when we should compute the Cholesky decomposition if passed a cov argument instead of scale_tril. I opted to call it up front in the constructor -- we're ultimately going to need it no matter what, either for sampling, or for computing the log determinant in log_prob or entropy.
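A rough sketch of the solver-helper idea above (names and structure hypothetical; it assumes Variable inputs, and uses torch.potrs only when no gradient is required, as suggested above):

```python
import torch

def _solve_covariance(scale_tril, delta):
    """Solve (L L^T) x = delta for x, picking a solver based on
    whether gradients will be needed."""
    if scale_tril.requires_grad or delta.requires_grad:
        # differentiable path: general solve against the full covariance
        cov = torch.mm(scale_tril, scale_tril.t())
        x, _ = torch.gesv(delta, cov)
    else:
        # cheaper path: reuse the Cholesky factor directly
        x = torch.potrs(delta, scale_tril, upper=False)
    return x
```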