WIP Port Gamma to CUDA #102
base: master
Conversation
CPU_tensor_apply2<scalar, double>(ret, alpha,
  [generator](scalar& ret_val, const double& alpha){
    auto sample = sample_gamma(alpha, generator);
    ret_val = sample > 0 ? sample : FLT_MIN;
Replace FLT_MIN with std::numeric_limits<scalar>::min().
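The point of the review comment can be illustrated numerically. This is a sketch (not code from the PR) showing why clamping a double-valued sample to `FLT_MIN`, the smallest normal *float*, is the wrong floor: a double can represent vastly smaller positive values, so `std::numeric_limits<scalar>::min()` is the type-appropriate choice.

```python
import numpy as np

# FLT_MIN is the smallest normal float32; a double-valued sampler can
# legitimately produce far smaller positive numbers. Clamping doubles to
# FLT_MIN inflates a huge range of valid small samples.
flt_min = np.finfo(np.float32).tiny   # ~1.18e-38, what FLT_MIN would give
dbl_min = np.finfo(np.float64).tiny   # ~2.23e-308, the double's own floor

sample = 1e-100                       # a tiny but perfectly valid double sample
clamped_with_flt_min = max(sample, flt_min)   # wrongly inflated to ~1.18e-38
clamped_with_dbl_min = max(sample, dbl_min)   # preserved as 1e-100
```

With `scalar = double`, `std::numeric_limits<scalar>::min()` picks the second behavior automatically; with `scalar = float` it degenerates back to `FLT_MIN`, which is exactly what templating on the scalar type is for.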
test/test_distributions.py
if multivariate:
    # Project onto a random axis.
    axis = np.random.normal(size=torch_samples.shape[-1])
    axis /= np.linalg.norm(axis)
    torch_samples = np.dot(torch_samples, axis)
    ref_samples = np.dot(ref_samples, axis)
samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
shuffle(samples)  # necessary to prevent stable sort from making uneven bins for discrete
samples.sort(key=lambda x: x[0])
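The idea behind this test can be sketched as a small self-contained check (the function name and the bin-balance criterion here are illustrative, not the test suite's actual helpers): project both multivariate sample sets onto a random axis, merge them with ±1 labels, sort, and verify that each bin of the merged order is roughly balanced between the two sources.

```python
import numpy as np
from random import shuffle

def projected_two_sample_check(torch_samples, ref_samples, num_bins=10, rtol=0.25):
    """Sketch of the review's idea: if both sets come from the same
    distribution, each bin of the merged sorted order should contain
    roughly equal counts of +1 (torch) and -1 (reference) labels."""
    axis = np.random.normal(size=torch_samples.shape[-1])
    axis /= np.linalg.norm(axis)          # project onto a random unit axis
    a = np.dot(torch_samples, axis)
    b = np.dot(ref_samples, axis)
    samples = [(x, +1) for x in a] + [(x, -1) for x in b]
    shuffle(samples)                      # break ties randomly before the stable sort
    samples.sort(key=lambda s: s[0])
    labels = np.array([label for _, label in samples])
    bins = np.array_split(labels, num_bins)
    # Each bin's label sum should be near zero relative to the bin size.
    return all(abs(chunk.sum()) <= rtol * len(chunk) for chunk in bins)
```

As the review notes, this only makes sense for continuous distributions; with discrete values the stable sort would otherwise group tied samples by label, which is why the shuffle precedes the sort.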
Oh, this wasn't intended to work for discrete distributions. I see you're implementing a different test for those.
@rachtsingh Where are you seeing NaNs?
When I run the tests, the Dirichlet tests fail because of bad samples. I'll post a log in a few hours (I'm away from my computer). Thanks for the numeric comment, that's exactly what I was looking for!
Could you be more specific and paste the Dirichlet test failure output? (There are like 10 Dirichlet tests 😉)
Yes, will do asap, sorry!
Sorry for the long delay. Here's the output (after fixing
Based on the error message, it looks like it's sampling
Ah, I figured this out. It's a casting issue - will upload the fix in a second.
Ok, yep, it's fixed. I will make the real PR to port this to CUDA (really, just a few lines of changes now) after the CUDA RNG changes are merged.
Ok, CUDA changes are in this branch now - I'm waiting on review for pytorch#4556 and then I can turn this into a PR. Unfortunately it looks like CUDA samplers will be lower accuracy in general (because of the Tesla / double thing).
Have you checked any Q-Q plots or probability plots for the single-precision Gamma sampler? I'm curious how small alpha can be before samples start being clamped to zero.
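The clamping concern can be estimated numerically. This is a rough sketch (not part of the PR): draw double-precision Gamma samples and cast them to float32 as a stand-in for a single-precision sampler, then measure what fraction underflows to zero. The helper name is illustrative.

```python
import numpy as np

def clamped_fraction(alpha, n=200_000, seed=0):
    """Fraction of Gamma(alpha, 1) samples that underflow to zero when the
    double-precision draw is cast to float32 - a rough stand-in for running
    the sampler in single precision (the real CUDA sampler computes in
    float throughout, so it clamps at least this often)."""
    rng = np.random.default_rng(seed)
    x = rng.gamma(alpha, size=n)              # float64 samples
    return np.mean(x.astype(np.float32) == 0.0)
```

For very small shape parameters the effect is dramatic: at alpha ≈ 0.01 roughly a third of the mass sits below the smallest positive float32, while by alpha ≈ 0.5 underflow is essentially never observed. That gives a concrete answer to "how small can alpha be": the single-precision sampler degrades well before alpha reaches the values Dirichlet tests commonly exercise.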
aten/src/TH/THRandom.c
        return scale * d * v;
    }
}
/* double THRandom_standard_gamma(THGenerator *_generator, double alpha) { */
Why is this commented out? Have you moved the CPU implementation?
Yes, it's been moved to https://github.com/probtorch/pytorch/pull/102/files#diff-6f5adabe13d89ad314ae10947a7f524aR250 - @apaszke brought up code duplication as an issue, so I'm combining the implementations here (this is one attempt that at least works; I'm very open to cleaner implementation ideas).
Basically, by specifying precision_t you can control the accuracy of the implementation used. For GPUs that's float, and for CPU it's double right now.
I haven't checked the Q-Q plots for this sampler (I checked mine a few months ago), but I've found that they can be hard to interpret. I suspect that there's a good opportunity here for someone to improve accuracy, but I'm not sure how.
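The precision_t idea discussed above - one sampler body whose working precision is a parameter - can be sketched in Python by parameterizing a Marsaglia-Tsang standard-gamma sampler on a numpy dtype. This is an illustration, not the PR's code: `np.float32` plays the role of the GPU path and `np.float64` the CPU path, and the alpha < 1 boosting step is omitted for brevity.

```python
import numpy as np

def standard_gamma(alpha, size, dtype=np.float64, seed=0):
    """Marsaglia-Tsang sampler with working precision as a parameter,
    mirroring the precision_t template idea. Assumes alpha >= 1."""
    rng = np.random.default_rng(seed)
    d = dtype(alpha - 1.0 / 3.0)
    c = dtype(1.0) / np.sqrt(dtype(9.0) * d)
    out = np.empty(size, dtype=dtype)
    i = 0
    while i < size:
        x = dtype(rng.standard_normal())
        v = (dtype(1.0) + c * x) ** 3
        if v <= 0:
            continue                      # reject: v must stay positive
        u = dtype(rng.random())
        # Acceptance test from Marsaglia & Tsang's method.
        if np.log(u) < dtype(0.5) * x * x + d - d * v + d * np.log(v):
            out[i] = d * v
            i += 1
    return out
```

Every intermediate stays in the requested dtype, so the same code path yields a double-accuracy CPU sampler or a float-accuracy GPU-like sampler, which is the trade-off behind the "lower accuracy in general" remark above.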
(force-pushed from 4aff41b to 2128818)
@fritzo should I move this to pytorch? Or did you want to take another pass on it? Also, there's some submodule gunk on this commit, but I'll clean it before upstreaming.
Looks great, thanks for implementing this!
Ready to move upstream.
(force-pushed from 63c1b4c to dbe16f8)
Additionally:
- add support for calling functions that are not methods in the Python frontend
- add an end-to-end test for the Python frontend
- add a capture_stdout helper for checking that `print` actually works
Signed-off-by: Edward Z. Yang <[email protected]>
I know this works because I had to squelch a bunch of ASAN errors in multiprocessing.
Signed-off-by: Edward Z. Yang <[email protected]>
PR is for discussion only
cc @rachtsingh