More cudnn ops #178
Comments
You can already do conv2d backprop with the existing methods (see https://github.com/coreylowman/dfdx/blob/main/src/tensor_ops/conv2d/cudnn_kernel.rs#L91). Conv1d - I'm not sure this exists in cudnn? I'm not sure flash attention exists there either (or at least it didn't when I last checked cudnn). But yes, open to any contributions here!
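For reference, here is a plain CPU sketch of the input gradient that a conv2d backward-data kernel computes. It assumes a "valid" single-channel cross-correlation forward pass; the function name and layout are illustrative only, not the cudarc/dfdx API.

```rust
// CPU sketch of the conv2d input gradient: each output-gradient element is
// scattered back through the window it was produced from. Illustrative
// names/layout; not the cudnn or cudarc call itself.
fn conv2d_backprop_input(
    d_out: &[Vec<f32>],   // gradient w.r.t. the forward output
    kernel: &[Vec<f32>],  // the forward filter
    ih: usize,            // forward input height, to reconstruct the shape
    iw: usize,            // forward input width
) -> Vec<Vec<f32>> {
    let (kh, kw) = (kernel.len(), kernel[0].len());
    let mut d_in = vec![vec![0.0f32; iw]; ih];
    for (y, row) in d_out.iter().enumerate() {
        for (x, &g) in row.iter().enumerate() {
            for ky in 0..kh {
                for kx in 0..kw {
                    // This input pixel contributed to output (y, x) through
                    // kernel weight (ky, kx), so the gradient flows back here.
                    d_in[y + ky][x + kx] += g * kernel[ky][kx];
                }
            }
        }
    }
    d_in
}
```

With an all-ones 2x2 filter and an all-ones 2x2 output gradient on a 3x3 input, the corner pixels receive one contribution, the edge pixels two, and the center four.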
Ah, great that the conv2d backward step is already there, we'll add it to candle. For conv1d, there is some support for Nd convolution I think, e.g. in the cudarc ffi, so hopefully having a safe api around this would enable 1d convolutions. For flash attention, I meant this fused flash attn fprop, though I've actually never used it.
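The idea of getting conv1d out of an Nd convolution can be sketched on the CPU: treat the 1d signal as a 2d input with height 1 and run an ordinary 2d convolution. The helper below is a plain reference implementation (single channel, "valid" padding), not the cudnn descriptor API.

```rust
// Reference "valid" 2d cross-correlation; illustrative only.
fn conv2d_valid(input: &[Vec<f32>], kernel: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (ih, iw) = (input.len(), input[0].len());
    let (kh, kw) = (kernel.len(), kernel[0].len());
    let (oh, ow) = (ih - kh + 1, iw - kw + 1);
    let mut out = vec![vec![0.0f32; ow]; oh];
    for y in 0..oh {
        for x in 0..ow {
            for ky in 0..kh {
                for kx in 0..kw {
                    out[y][x] += input[y + ky][x + kx] * kernel[ky][kx];
                }
            }
        }
    }
    out
}

/// conv1d as a degenerate conv2d: wrap input and kernel in a height-1
/// dimension, convolve in 2d, then drop the unit dimension again.
fn conv1d_valid(input: &[f32], kernel: &[f32]) -> Vec<f32> {
    let out = conv2d_valid(&[input.to_vec()], &[kernel.to_vec()]);
    out.into_iter().next().unwrap()
}
```

This is the same trick a safe Nd-convolution wrapper could use internally: a 1d convolution is just an Nd convolution descriptor with one spatial extent set to 1.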
Oh sweet, I missed the Nd convolution, nice! Should be able to add that. If I'm understanding the flash attn thing correctly, it seems like that is something detected at runtime if you're using the cudnn graph API?
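The fusion being discussed can be illustrated with the online-softmax trick at the heart of fused attention kernels: the softmax-weighted sum of values is computed in one streaming pass, rescaling a running maximum, normaliser and accumulator as each score arrives, instead of materialising all scores first. A minimal single-query CPU sketch (not cudnn's fused kernel):

```rust
// Two-pass baseline: materialise all scores, then softmax-weight the values.
fn naive_attention(q: &[f32], keys: &[Vec<f32>], values: &[f32]) -> f32 {
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| k.iter().zip(q).map(|(a, b)| a * b).sum())
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    exps.iter().zip(values).map(|(e, v)| e * v / denom).sum()
}

// One-pass "fused" version: online softmax. When a new score raises the
// running maximum, the old normaliser and accumulator are rescaled by
// exp(old_max - new_max), so no score vector is ever stored.
fn fused_attention(q: &[f32], keys: &[Vec<f32>], values: &[f32]) -> f32 {
    let (mut m, mut denom, mut acc) = (f32::NEG_INFINITY, 0.0f32, 0.0f32);
    for (k, v) in keys.iter().zip(values) {
        let s: f32 = k.iter().zip(q).map(|(a, b)| a * b).sum();
        let m_new = m.max(s);
        let scale = (m - m_new).exp();
        let e = (s - m_new).exp();
        denom = denom * scale + e;
        acc = acc * scale + e * v;
        m = m_new;
    }
    acc / denom
}
```

Both functions return the same result; the fused form is what makes tiling over keys/values possible without ever holding the full attention row in memory.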
Pooling would be great to have as well. It seems like it could have a similar api to conv.
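To make the "similar api to conv" point concrete, here is a CPU sketch of 2d max pooling taking the same kind of (window, stride) hyperparameters a convolution descriptor takes. Again purely illustrative, not the cudnn pooling call:

```rust
// Reference 2d max pooling with a square window and stride, no padding.
// The (window, stride) pair mirrors the hyperparameters of a conv
// descriptor, which is why a pooling api can look much like the conv one.
fn max_pool2d(input: &[Vec<f32>], window: usize, stride: usize) -> Vec<Vec<f32>> {
    let (ih, iw) = (input.len(), input[0].len());
    let (oh, ow) = ((ih - window) / stride + 1, (iw - window) / stride + 1);
    let mut out = vec![vec![f32::NEG_INFINITY; ow]; oh];
    for y in 0..oh {
        for x in 0..ow {
            for wy in 0..window {
                for wx in 0..window {
                    let v = input[y * stride + wy][x * stride + wx];
                    if v > out[y][x] {
                        out[y][x] = v;
                    }
                }
            }
        }
    }
    out
}
```

Swapping the max for a windowed mean would give average pooling with the same signature, which is exactly the kind of mode flag a pooling descriptor usually carries.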
@coreylowman would you take a PR for adding nd pooling? |
@kckeiks of course, any and all prs welcome |
Thanks for this amazing crate, it's been instrumental to candle. We've recently added a feature to use the cudnn conv2d, which sped things up a lot compared to our handcrafted kernel, and we would like to have cudnn support for more ops. Mostly interested in:
Are there any plans to add these to the cudnn safe api? If not, would you be ok with people making PRs to add them?