
PyTorch compatibility #1

Open
superkirill opened this issue Sep 17, 2022 · 11 comments

@superkirill

Very interesting work! Are there plans to make weights matching applicable to PyTorch models?

@samuela
Owner

samuela commented Sep 17, 2022

no concrete plans yet, but it ought not to be all that hard! might be a fun idea to live code/stream on twitch at some point... i certainly won't have time for this until at least a month after the ICLR deadline. if someone can beat me to it, that would be cool as well!

@themrzmaster

I converted part of the code to PyTorch. Will push to a repo and share here

@themrzmaster

themrzmaster commented Sep 19, 2022

https://github.com/themrzmaster/git-re-basin-pytorch

Still working on some things, but it's there :)

@affableroots

How can this be applied to arbitrary PyTorch models?

@samuela
Owner

samuela commented Oct 28, 2022

@PythonNut has been working on a PyTorch tracer that will enable you to automatically get PermutationSpecs for most model architectures for free... but that's still a work in progress atm

@affableroots

There's an attempt at merging for Stable Diffusion here (just linking it for newcomers).

Note: we haven't gotten it to work on GPU yet, and it currently OOMs with 32 GB of RAM.

@markdjwilliams

@PythonNut has been working on a PyTorch tracer that will enable you to automatically get PermutationSpecs for most model architectures for free... but that's still a work in progress atm

Did this work out in the end?

@samuela
Owner

samuela commented Feb 15, 2023

This is really @PythonNut's thing, but IIUC I don't think it's in a state suitable for open-sourcing yet. That said, it shouldn't be too hard to hack together something that works with the symbolic execution tooling provided in https://pytorch.org/docs/stable/fx.html.
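As a hedged illustration of that torch.fx route (a minimal sketch of my own, not the in-progress tracer): symbolic tracing turns a module into a graph of nodes, and walking those nodes is the natural starting point for auto-generating a PermutationSpec.

```python
import torch
import torch.fx

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 8)
        self.fc2 = torch.nn.Linear(8, 3)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# symbolic_trace executes the forward pass with proxy values and records
# every operation into a Graph, without needing a concrete input tensor.
traced = torch.fx.symbolic_trace(MLP())

# Each node carries an opcode ("placeholder", "call_module",
# "call_function", "output", ...) and a target (e.g. the submodule name).
ops = [(node.op, str(node.target)) for node in traced.graph.nodes]
```

From a listing like `ops`, one could match `call_module` targets back to parameter shapes and chain permutation groups layer to layer.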

@PythonNut
Collaborator

@markdjwilliams

Did this work out in the end?

I have an independent implementation of the matching algorithm in PyTorch with some optimizations over our original version. This is definitely something I could release but the existence of themrzmaster/git-re-basin-pytorch makes this less urgent.
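The matching step itself is easy to sketch in plain NumPy (this is my own toy reconstruction of the paper's weight-matching objective for a single hidden layer, not the implementation discussed here): align model B's hidden units with model A's by solving one linear assignment problem over the summed similarity of the weight matrices touching the hidden axis.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)

# Model A: a two-layer MLP with 16 hidden units (weights only, biases omitted).
W1_a = rng.normal(size=(16, 64))   # hidden x in
W2_a = rng.normal(size=(32, 16))   # out x hidden

# Model B: model A with its hidden units shuffled by a known permutation.
true_perm = rng.permutation(16)
W1_b = W1_a[true_perm]
W2_b = W2_a[:, true_perm]

# cost[i, j] = similarity of A's hidden unit i to B's hidden unit j,
# summed over both weight matrices that touch the hidden axis.
cost = W1_a @ W1_b.T + W2_a.T @ W2_b

# Maximize total similarity; found_perm maps A's units onto B's units.
_, found_perm = linear_sum_assignment(cost, maximize=True)
```

With enough features per unit, the assignment recovers the planted shuffle exactly; in a real network this solve is repeated per permutation group until convergence.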

For the tracer, I have code that basically works, but right now the utility is somewhat limited. It should be able to reproduce all of the permutation specs used in the paper, but there are not that many new networks that work out of the box. Currently, I assume that any given tensor axis can be represented by a single permutation group, but there are several ways this assumption might not hold:

  1. A tensor axis that is a concatenation of multiple permutation groups (e.g. DenseNet)
  2. A tensor axis that has a hierarchical permutation structure (e.g. MultiHeadAttention)
  3. A tensor axis that is a "tensor product" of permutation groups (e.g. VGGs on ImageNet [CIFAR is fine])
  4. Probably more!

We've worked out what the matching algorithm would look like in all of these cases. The main missing piece is a structure for the permutation spec which can support all of these cases.
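For context, the flat structure in use so far can be sketched as a table mapping each parameter axis to at most one permutation group (a hypothetical PyTorch-flavored analogue; the names `PermutationSpec` and `P_hidden` here are illustrative, not the thread's actual code). It is exactly this one-group-per-axis shape that the cases above break.

```python
from typing import NamedTuple

class PermutationSpec(NamedTuple):
    # Parameter name -> one entry per tensor axis: the name of the
    # permutation group acting on that axis, or None if the axis is fixed.
    axes_to_perm: dict

# Flat spec for a two-layer MLP: a single group "P_hidden" permutes the
# hidden axis wherever it appears. nn.Linear stores weights as
# (out_features, in_features), hence the axis order below.
mlp_spec = PermutationSpec(axes_to_perm={
    "fc1.weight": ("P_hidden", None),
    "fc1.bias":   ("P_hidden",),
    "fc2.weight": (None, "P_hidden"),
    "fc2.bias":   (None,),
})
```

Supporting concatenated, hierarchical, or product axes would require each entry to hold something richer than a single group name.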

If you have a particular use-case in mind, we'd love to hear about it!

@markdjwilliams

Thank you! I'd say my use-case was quite speculative, so I was hoping for a quick way of testing it out without investing the time into writing correct specs for my model.

@MasanoriYamada

In a PyTorch implementation, SciPy's linear_sum_assignment does not run on the GPU, which creates a speed bottleneck.
For example, in STE, linear_sum_assignment sits inside the training loop, causing frequent memory transfers between the CPU and GPU and sacrificing speed.

I have written a PyTorch implementation of linear_sum_assignment.
It accepts a torch.Tensor on the GPU and returns results matching SciPy's.
However, the torch version runs slower than SciPy for some reason, and I'm having trouble working out why.
If anyone has run into similar problems, please advise.
https://gist.github.com/MasanoriYamada/72405515264749df02ba392f16810e12

Referenced JAX implementation
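The round trip being described looks like this (a minimal sketch of the common SciPy fallback, not the gist's kernel); the two explicit copies are the transfers that dominate when the solve sits inside a training loop:

```python
import torch
from scipy.optimize import linear_sum_assignment

def solve_assignment(cost: torch.Tensor) -> torch.Tensor:
    """Solve a linear assignment problem for a (possibly GPU) cost matrix."""
    # Device-to-host copy: SciPy only accepts NumPy arrays, so this round
    # trip happens on every call inside the loop.
    cost_np = cost.detach().cpu().numpy()
    _, col_ind = linear_sum_assignment(cost_np, maximize=True)
    # Host-to-device copy of the resulting permutation.
    return torch.as_tensor(col_ind, device=cost.device)

cost = torch.tensor([[0.0, 1.0],
                     [1.0, 0.0]])
perm = solve_assignment(cost)  # → tensor([1, 0])
```

A native GPU solver avoids both copies, but as noted above it also has to beat SciPy's highly tuned CPU Hungarian implementation to be a net win.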
