WIP Port Gamma to CUDA #102

Open
wants to merge 8 commits into master
Conversation

@fritzo commented Jan 20, 2018

PR is for discussion only

cc @rachtsingh

@fritzo added the WIP label Jan 20, 2018
CPU_tensor_apply2<scalar, double>(ret, alpha,
  [generator](scalar& ret_val, const double& alpha) {
    auto sample = sample_gamma(alpha, generator);
    ret_val = sample > 0 ? sample : FLT_MIN;
@fritzo (Author) commented on this diff:

Replace FLT_MIN with std::numeric_limits<scalar>::min()

if multivariate:
    # Project onto a random axis.
    axis = np.random.normal(size=torch_samples.shape[-1])
    axis /= np.linalg.norm(axis)
    torch_samples = np.dot(torch_samples, axis)
    ref_samples = np.dot(ref_samples, axis)
samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
samples.sort()
shuffle(samples)  # necessary to prevent stable sort from making uneven bins for discrete
@fritzo (Author) commented on this diff:

Oh, this wasn't intended to work for discrete distributions. I see you're implementing a different test for those.

@fritzo commented Jan 20, 2018

@rachtsingh Where are you seeing NaNs?

@rachtsingh commented:

When I run the tests, the Dirichlet tests fail because of bad samples. I'll post a log in a few hours (I'm away from my computer).

Thanks for the numeric comment, that's exactly what I was looking for!

@fritzo commented Jan 20, 2018

Could you be more specific and paste Dirichlet test failure output? (There are like 10 Dirichlet tests 😉).

@rachtsingh commented:

Yes, will do asap, sorry!

@rachtsingh commented:

Sorry for the long delay. Here's the output (after fixing std::numeric_limits):

............................/datadrive/build/pytorch/torch/distributions/distribution.py:70: UserWarning: sample_n will be deprecated. Use .sample((n,)) instead
  warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning)
....................................................../home/rachitsingh/venv/local/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in greater
  return (self.a < x) & (x < self.b)
/home/rachitsingh/venv/local/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in less
  return (self.a < x) & (x < self.b)
/home/rachitsingh/venv/local/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py:1738: RuntimeWarning: invalid value encountered in greater_equal
  cond2 = (x >= self.b) & cond0
/home/rachitsingh/venv/local/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py:876: RuntimeWarning: invalid value encountered in greater_equal
  return (self.a <= x) & (x <= self.b)
/home/rachitsingh/venv/local/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py:876: RuntimeWarning: invalid value encountered in less_equal
  return (self.a <= x) & (x <= self.b)
FF.....
======================================================================
FAIL: test_beta_wrt_alpha (__main__.TestRsample)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_distributions.py", line 1142, in test_beta_wrt_alpha
    'at x = {}'.format(x[rel_error.argmax()]),
AssertionError: Bad gradient dx/dalpha for x ~ Beta(0.01, 0.01)
x [  1.19208998e-07   1.19208998e-07   1.19208998e-07              nan
   1.19208998e-07   1.19208998e-07   1.19208998e-07   1.19208998e-07
   8.55534563e-06   9.99999881e-01   9.99999881e-01   9.99999881e-01
   9.99999881e-01   9.99999881e-01   9.99999881e-01   9.99999881e-01
              nan   9.99999881e-01   9.99999881e-01              nan]
expected [ 0.00078592  0.00078592  0.00078592         nan  0.00078592  0.00078592
  0.00078592  0.00078592  0.05274682  0.00059625  0.00059625  0.00059625
  0.00059625  0.00059625  0.00059625  0.00059625         nan  0.00059625
  0.00059625         nan]
actual [ 0.0007859   0.0007859   0.0007859          nan  0.0007859   0.0007859
  0.0007859   0.0007859   0.0527457   0.00059624  0.00059624  0.00059624
  0.00059624  0.00059624  0.00059624  0.00059624         nan  0.00059624
  0.00059624         nan]
rel error [  2.16805036e-05   2.16805036e-05   2.16805036e-05              nan
   2.16805036e-05   2.16805036e-05   2.16805036e-05   2.16805036e-05
   2.12543943e-05   2.05788282e-05   2.05788282e-05   2.05788282e-05
   2.05788282e-05   2.05788282e-05   2.05788282e-05   2.05788282e-05
              nan   2.05788282e-05   2.05788282e-05              nan]
max error nan
at x = nan

======================================================================
FAIL: test_beta_wrt_beta (__main__.TestRsample)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_distributions.py", line 1172, in test_beta_wrt_beta
    'at x = {!r}'.format(x[rel_error.argmax()]),
AssertionError: Bad gradient dx/dbeta for x ~ Beta(0.01, 0.01)
x [  1.19208998e-07   1.19208998e-07   1.19208998e-07              nan
   1.19208998e-07   1.19208998e-07   1.19208998e-07   1.19208998e-07
   8.55534563e-06   9.99999881e-01   9.99999881e-01   9.99999881e-01
   9.99999881e-01   9.99999881e-01   9.99999881e-01   9.99999881e-01
              nan   9.99999881e-01   9.99999881e-01              nan]
expected [-0.00059625 -0.00059625 -0.00059625         nan -0.00059625 -0.00059625
 -0.00059625 -0.00059625 -0.04279102 -0.00078592 -0.00078592 -0.00078592
 -0.00078592 -0.00078592 -0.00078592 -0.00078592         nan -0.00078592
 -0.00078592         nan]
actual [-0.00059624 -0.00059624 -0.00059624         nan -0.00059624 -0.00059624
 -0.00059624 -0.00059624 -0.04279014 -0.0007859  -0.0007859  -0.0007859
 -0.0007859  -0.0007859  -0.0007859  -0.0007859          nan -0.0007859
 -0.0007859          nan]
rel error [ -2.05756588e-05  -2.05756588e-05  -2.05756588e-05              nan
  -2.05756588e-05  -2.05756588e-05  -2.05756588e-05  -2.05756588e-05
  -2.06265899e-05  -2.40871728e-05  -2.40871728e-05  -2.40871728e-05
  -2.40871728e-05  -2.40871728e-05  -2.40871728e-05  -2.40871728e-05
              nan  -2.40871728e-05  -2.40871728e-05              nan]
max error nan
at x = nan

----------------------------------------------------------------------
Ran 89 tests in 8.604s

FAILED (failures=2)

Based on the error message, it looks like it's sampling nan at some point. I'll look more into it as well - I was mostly wondering if you knew a quick reason why this might fail (e.g. like the check for a positive sample).

@rachtsingh commented:

Ah, I figured this out. It's a casting issue - will upload the fix in a second.

@rachtsingh commented:

Ok, yep, it's fixed. I will make the real PR to port this to CUDA (really, just a few lines of changes now) after the CUDA RNG changes are merged.

@rachtsingh commented:

Ok, CUDA changes are in this branch now - I'm waiting on review for pytorch#4556 and then I can turn this into a PR?

Unfortunately it looks like CUDA samplers will be lower accuracy in general (because of the Tesla / double thing).

@fritzo (Author) left a review comment:

Have you checked any Q-Q plots or probabilities for the single-precision Gamma sampler? I'm curious how small alpha can be before samples start being clamped to zero.

return scale * d * v;
}
}
/* double THRandom_standard_gamma(THGenerator *_generator, double alpha) { */
@fritzo (Author) commented on this diff:

Why is this commented out? Have you moved the CPU implementation?

@rachtsingh commented Jan 23, 2018:

Yes, it's been moved to https://github.com/probtorch/pytorch/pull/102/files#diff-6f5adabe13d89ad314ae10947a7f524aR250 - @apaszke brought up code duplication as an issue so I'm combining the implementations here (this is one attempt that at least works; very open to other ideas that are cleaner implementations).

Basically, by specifying precision_t you can sort of control the level of accuracy of the implementation used. For GPUs, that's float, and for CPU that's double right now.

@rachtsingh commented:

I haven't checked the q-q plots for this sampler (I checked mine a few months ago), but I've found that they can be hard to interpret. I suspect that there's a good opportunity here for someone to improve accuracy, but I'm not sure how.

@rachtsingh commented Jan 25, 2018

@fritzo should I move this to pytorch? Or did you want to take another pass on it? Also, there's some submodule gunk on this commit, but I'll clean it before upstreaming.

@fritzo (Author) left a review comment:

Looks great, thanks for implementing this!

Ready to move upstream.

apaszke and others added 8 commits February 24, 2018 11:15
Additionally:
- add support for calling functions that are not methods in the Python frontend
- add an end-to-end test for the Python frontend
- add a capture_stdout helper for checking that `print` actually works
…h#5380)

* Ignore FileNotFoundError when shutting down in data_queue.get

* Address @apaszke comments
I know this works because I had to squelch a bunch of ASAN
errors in multiprocessing.

Signed-off-by: Edward Z. Yang <[email protected]>