PyTorch compatibility #1
Comments
No concrete plans yet, but it ought not to be all that hard! It might be a fun idea to live-code/stream on Twitch at some point... I certainly won't have time for this until at least a month after the ICLR deadline. If someone can beat me to it, that would be cool as well!
I converted part of the code to PyTorch. I'll push it to a repo and share it here.
https://github.com/themrzmaster/git-re-basin-pytorch
Still working on some things, but it's there :)
How can this be applied to arbitrary PyTorch models?
@PythonNut has been working on a PyTorch tracer that will enable you to automatically get
There's an attempt at merging Stable Diffusion models here (just linking it for newcomers). Note: we haven't gotten it to work on the GPU yet, and it currently runs out of memory (OOMs) with 32 GB of RAM.
Did this work out in the end?
This is really @PythonNut's thing, but IIUC it's not yet in a state suitable for open-sourcing. That said, it shouldn't be too hard to hack together something that works with the symbolic tracing tools provided in https://pytorch.org/docs/stable/fx.html.
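A minimal sketch of that idea, assuming the model is a plain chain of `nn.Linear` layers separated only by elementwise activations: trace it with `torch.fx.symbolic_trace`, collect the `call_module` nodes that are `nn.Linear` in execution order, and give each hidden axis between consecutive layers its own permutation group. `chain_permutation_spec` and the `(axis-0 perm, axis-1 perm)` tuple format are hypothetical, not the API of either repo; residual connections, reshapes, and shared weights are out of scope.

```python
import torch.nn as nn
from torch.fx import symbolic_trace


def chain_permutation_spec(model: nn.Module) -> dict:
    """Hypothetical helper: permutation spec for a straight chain of nn.Linear layers.

    Assumes each Linear feeds the next one through elementwise ops only (e.g. ReLU),
    so the hidden axis between two consecutive Linears is a single permutation group.
    Residual connections, reshapes and weight sharing are NOT handled.
    """
    gm = symbolic_trace(model)
    # Qualified names of the nn.Linear submodules, in execution order.
    linears = [
        node.target
        for node in gm.graph.nodes
        if node.op == "call_module"
        and isinstance(gm.get_submodule(node.target), nn.Linear)
    ]
    spec = {}
    last = len(linears) - 1
    for i, name in enumerate(linears):
        p_out = f"P_{i}" if i < last else None  # network outputs are never permuted
        p_in = f"P_{i - 1}" if i > 0 else None  # network inputs are never permuted
        # nn.Linear.weight has shape (out_features, in_features).
        spec[f"{name}.weight"] = (p_out, p_in)
        if gm.get_submodule(name).bias is not None:
            spec[f"{name}.bias"] = (p_out,)
    return spec


model = nn.Sequential(
    nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)
)
spec = chain_permutation_spec(model)
for param, axes in spec.items():
    print(param, axes)
```

For this model, `2.weight` maps to `("P_1", "P_0")`: its rows follow the second hidden layer's permutation and its columns follow the first's. Supporting branching graphs would likely require following `node.args`/`node.users` rather than assuming a straight chain.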
I have an independent implementation of the matching algorithm in PyTorch with some optimizations over our original version. This is definitely something I could release, but the existence of themrzmaster/git-re-basin-pytorch makes it less urgent. For the tracer, I have code that basically works, but right now its utility is somewhat limited. It should be able to reproduce all of the permutation specs used in the paper, but not many new networks work with it out of the box. Currently, I assume that any given tensor axis can be represented by a single permutation group, but there are several ways this assumption might not hold:
We've worked out what the matching algorithm would look like in all of these cases. The main missing piece is a permutation-spec structure that can support all of them. If you have a particular use case in mind, we'd love to hear about it!
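To make the "one permutation group per tensor axis" assumption concrete, here is a minimal NumPy sketch of a spec that maps each parameter to one entry per axis (a permutation-group name or `None`), plus helpers to invert it and to apply a set of permutations. The format and helper names are hypothetical, not taken from either repo. The usage check relies on the fact that permuting an MLP's hidden units consistently across both layers leaves its outputs unchanged.

```python
import numpy as np

# Hypothetical spec: parameter name -> one entry per tensor axis, each either a
# permutation-group name or None. One name per axis encodes the assumption directly.
SPEC = {
    "fc1.weight": ("P_hidden", None),  # (out, in)
    "fc1.bias": ("P_hidden",),
    "fc2.weight": (None, "P_hidden"),  # (out, in)
    "fc2.bias": (None,),
}


def perm_to_axes(spec):
    """Invert the spec: permutation name -> list of (param, axis) pairs it acts on."""
    out = {}
    for param, axes in spec.items():
        for axis, perm in enumerate(axes):
            if perm is not None:
                out.setdefault(perm, []).append((param, axis))
    return out


def apply_permutation(spec, perms, params):
    """Reorder every permuted axis of every parameter by the index arrays in `perms`."""
    result = {}
    for name, w in params.items():
        for axis, perm in enumerate(spec[name]):
            if perm is not None:
                w = np.take(w, perms[perm], axis=axis)
        result[name] = w
    return result


def mlp(p, x):
    h = np.maximum(p["fc1.weight"] @ x + p["fc1.bias"], 0.0)
    return p["fc2.weight"] @ h + p["fc2.bias"]


rng = np.random.default_rng(0)
params = {
    "fc1.weight": rng.normal(size=(5, 3)),
    "fc1.bias": rng.normal(size=5),
    "fc2.weight": rng.normal(size=(2, 5)),
    "fc2.bias": rng.normal(size=2),
}
perms = {"P_hidden": rng.permutation(5)}
permuted = apply_permutation(SPEC, perms, params)
x = rng.normal(size=3)
# Same function, different weights: the permutation symmetry weight matching exploits.
assert np.allclose(mlp(params, x), mlp(permuted, x))
```

An axis that needs two groups at once (the cases listed above) simply has no slot in this tuple-of-names format, which is why a richer spec structure would be needed.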
Thank you! I'd say my use case was quite speculative, so I was hoping for a quick way of testing it out without investing time in writing correct specs for my model.
In a PyTorch port, SciPy's linear_sum_assignment does not run on the GPU, which creates a speed bottleneck. I've written a PyTorch implementation of linear_sum_assignment, referencing the JAX implementation.
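For context on where that step sits: in weight matching, each permutation is found by solving a linear assignment problem over a similarity matrix between the hidden units of the two models. A minimal sketch for the single hidden layer of a two-layer MLP, using NumPy and `scipy.optimize.linear_sum_assignment` (the dict layout and function names are hypothetical; deeper networks would need an iterative loop over all permutations, omitted here):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_hidden_units(a, b):
    """Hypothetical helper: align B's hidden units to A's for a one-hidden-layer MLP.

    `a` and `b` hold "w1" (hidden, in), "b1" (hidden,), "w2" (out, hidden).
    Returns `perm` such that B's unit perm[i] is matched to A's unit i.
    """
    # sim[i, j]: inner product of every weight touching A's unit i and B's unit j.
    sim = a["w1"] @ b["w1"].T + np.outer(a["b1"], b["b1"]) + a["w2"].T @ b["w2"]
    # SciPy solves this on the CPU; a PyTorch port would first need something
    # like sim.cpu().numpy(), which is the transfer a GPU solver avoids.
    _, cols = linear_sum_assignment(sim, maximize=True)
    return cols


def permute_hidden(params, perm):
    """Reorder the hidden units of `params` by `perm`."""
    return {"w1": params["w1"][perm], "b1": params["b1"][perm], "w2": params["w2"][:, perm]}


rng = np.random.default_rng(0)
hidden = 16
a = {
    "w1": rng.normal(size=(hidden, 8)),
    "b1": rng.normal(size=hidden),
    "w2": rng.normal(size=(3, hidden)),
}
p = rng.permutation(hidden)
b = permute_hidden(a, p)  # B is A with its hidden units shuffled
perm = match_hidden_units(a, b)
aligned = permute_hidden(b, perm)
assert all(np.array_equal(aligned[k], a[k]) for k in a)
```

Because B here is an exact shuffle of A, maximizing total similarity recovers the inverse shuffle; for two independently trained models the same call returns the best alignment rather than an exact match.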
Very interesting work! Are there plans to make weight matching applicable to PyTorch models?