adding `torch.nn.ConvTranspose2d` #3
Conversation
Thanks for putting this together @matthieutrs! My main concern is that there seems to be a bug when `stride != 1`. I tried poking at it briefly, but couldn't figure it out... any idea what might be necessary here? Would an implementation with `assert stride == 1` work for your use case?
Co-authored-by: Samuel Ainsworth <[email protected]>
Thanks a lot for the careful review!
Let me know if there are other things that need to be updated!
Figured out a fix such that it works for basically all configuration options to `torch.nn.ConvTranspose2d`! The one exception is `groups != 1`, but I'm not worried about that for now.
Additionally,
- Make the test case deterministic by avoiding reliance on PyTorch randomness for the weights
- Decrease `atol`
Co-authored-by: Samuel Ainsworth <[email protected]>
Thanks a lot! It indeed works on my architectures.
Thanks so much @matthieutrs! You get a special prize for being the first person to merge a PR on torch2jax! 🌟
Thanks for this very nice repo! I needed to convert models containing `torch.nn.ConvTranspose2d`, but this is not completely straightforward, as torch and lax do not perform the same transposed convolution. This PR is essentially a merge of jax-ml/jax#5772 to solve this issue.
At the moment it relies on numpy; if this is a problem, I think the dependency could be avoided, but I haven't had time to remove it yet.
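For context, the equivalence the PR relies on can be sketched in plain NumPy (hypothetical helper names, single channel, symmetric stride/padding, no `groups`): a `torch.nn.ConvTranspose2d` with kernel size `k`, stride `s`, and padding `p` is equivalent to a regular cross-correlation applied to an input dilated with `s - 1` zeros between elements and padded by `k - 1 - p`, using the spatially flipped kernel. This is the same computation that `jax.lax.conv_general_dilated` expresses via its `lhs_dilation` argument, which is why a plain `lax.conv_transpose` call does not match PyTorch directly.

```python
import numpy as np

def conv2d_valid(x, w):
    # Plain "valid" 2D cross-correlation, single channel (naive loops for clarity).
    kh, kw = w.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * w)
    return out

def conv_transpose2d(x, w, stride=1, padding=0):
    """Single-channel sketch of torch.nn.ConvTranspose2d semantics:
    1) dilate the input with (stride - 1) zeros between elements,
    2) zero-pad each side by (k - 1 - padding),
    3) cross-correlate with the spatially flipped kernel.
    Output size works out to (H - 1) * stride - 2 * padding + k.
    """
    kh, kw = w.shape
    h, wid = x.shape
    # Step 1: input dilation (the "fractional stride").
    dilated = np.zeros(((h - 1) * stride + 1, (wid - 1) * stride + 1))
    dilated[::stride, ::stride] = x
    # Step 2: padding derived from kernel size and the torch padding argument.
    ph, pw = kh - 1 - padding, kw - 1 - padding
    padded = np.pad(dilated, ((ph, ph), (pw, pw)))
    # Step 3: correlate with the flipped kernel.
    return conv2d_valid(padded, w[::-1, ::-1])

# A single input pixel of value 1 should reproduce the kernel itself.
w = np.array([[1., 2.], [3., 4.]])
print(conv_transpose2d(np.array([[1.]]), w, stride=2, padding=0))
```

This is only an illustration of the dilation/padding/flip recipe, not the multi-channel implementation the PR adds; the real code additionally has to transpose the weight's input/output channel axes.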