Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding torch.nn.ConvTranspose2d #3

Merged
merged 10 commits into from
Oct 31, 2023
Merged

adding torch.nn.ConvTranspose2d #3

merged 10 commits into from
Oct 31, 2023

Conversation

matthieutrs
Copy link
Contributor

Thanks for this very nice repo! I needed to convert models containing torch.nn.ConvTranspose2d but this is not completely straightforward as torch and lax do not perform the same transposed conv.

This PR is essentially a merging of jax-ml/jax#5772 to solve this issue.

Atm it relies on numpy; if this is a problem I think this could be avoided but I didn't have time to remove the dependency yet.

Copy link
Owner

@samuela samuela left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for putting this together @matthieutrs! My main concern is that there seems to be a bug when stride != 1. I tried poking on it briefly, but couldn't figure it out... any idea what might be necessary here?

Would an implementation with assert stride == 1 work for your usecase?

torch2jax/__init__.py Outdated Show resolved Hide resolved
torch2jax/__init__.py Outdated Show resolved Hide resolved
torch2jax/__init__.py Outdated Show resolved Hide resolved
torch2jax/__init__.py Outdated Show resolved Hide resolved
torch2jax/__init__.py Outdated Show resolved Hide resolved
.gitignore Outdated Show resolved Hide resolved
tests/test_all_the_things.py Outdated Show resolved Hide resolved
torch2jax/__init__.py Outdated Show resolved Hide resolved
@matthieutrs
Copy link
Contributor Author

Thanks a lot for the careful review!

  1. The dependency to numpy has been removed;
  2. The failling tests for some strides/kernel sizes was due to jax assuming a certain output_padding in torch.nn.ConvTranspose2d, I've added an assertion in the def of conv_transpose2d (see above);

Let me know if there are other things that need to be updated!

Copy link
Owner

@samuela samuela left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Figured out a fix such that it works for basically all configuration options to torch.nn.ConvTranspose2d! The one exception being groups != 1, but I'm not worried about that for now.

Additionally,

  • Make the test case deterministic by skirting reliance on PyTorch randomness for weights
  • Decrease atol

torch2jax/__init__.py Outdated Show resolved Hide resolved
tests/test_all_the_things.py Outdated Show resolved Hide resolved
matthieutrs and others added 2 commits October 30, 2023 12:25
Co-authored-by: Samuel Ainsworth <[email protected]>
@matthieutrs
Copy link
Contributor Author

Thanks a lot! It indeed works on my architectures.

@samuela samuela merged commit 4ed99f5 into samuela:main Oct 31, 2023
1 check passed
@samuela
Copy link
Owner

samuela commented Oct 31, 2023

Thanks so much @matthieutrs! You get a special prize for being the first person to merge a PR on torch2jax! 🌟

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants