Re-basin and Stable Diffusion Tensor Flow weights #5
Hi @ogkalu2! The exact framework that the model runs in is not super important. For example, we have used our JAX code to align the weights of two PyTorch models in the past. The only important parts are that you can load the weights into Python/JAX and that you have a correct … In general, writing down the … Would be very cool to see if this works on Stable Diffusion models! Do let me know if you get it working!
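A minimal sketch of the information a permutation spec carries, assuming a plain Python representation: each parameter name maps to a tuple with one entry per array axis, holding the name of the permutation acting on that axis, or `None` if the axis is not permuted. The names `axes_to_perm`, `invert_spec`, and the layer keys are hypothetical. The inverted mapping is what a matching loop needs in order to find every axis a given permutation touches.

```python
from collections import defaultdict

# Hypothetical spec for a 2-layer MLP in PyTorch layout (weight is [out, in]).
# Each tuple has one entry per axis: a permutation name, or None (not permuted).
axes_to_perm = {
    "fc1.weight": ("P_0", None),  # rows are layer-1 hidden units; inputs stay fixed
    "fc1.bias": ("P_0",),
    "fc2.weight": (None, "P_0"),  # columns consume layer-1 hidden units; outputs stay fixed
    "fc2.bias": (None,),
}


def invert_spec(axes_to_perm):
    """Map each permutation name to the list of (param_name, axis) pairs it acts on."""
    perm_to_axes = defaultdict(list)
    for name, axis_perms in axes_to_perm.items():
        for axis, perm in enumerate(axis_perms):
            if perm is not None:
                perm_to_axes[perm].append((name, axis))
    return dict(perm_to_axes)


perm_to_axes = invert_spec(axes_to_perm)
print(perm_to_axes)  # {'P_0': [('fc1.weight', 0), ('fc1.bias', 0), ('fc2.weight', 1)]}
```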
Hi @samuela, thanks for responding. Quite a bit has happened since the last comment. I found a repo that had converted the code to PyTorch. I also got down to writing the permutation spec for Stable Diffusion today. I think I'm on the right track, but I'm not too sure. For instance, I'm not quite sure what the correct p_in and p_out values should be. Running it right now gives a couple of different errors each time. Sometimes I get something like File "/content/drive/MyDrive/SD_rebasin/weight_matching.py", line 487, in weight_matching, or "addmm_impl_cpu_" not implemented for 'Half'. Can you take a quick look here and see what you think I might be doing wrong?
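The `"addmm_impl_cpu_" not implemented for 'Half'` message usually means a float16 matrix multiply was run on the CPU, where older PyTorch builds have no half-precision kernel for it. A common workaround is to upcast floating-point parameters to float32 before matching. The sketch below assumes the weights are held as a dict of NumPy arrays; with PyTorch tensors, `tensor.float()` plays the same role. `upcast_half` is a hypothetical helper, and upcasting doubles the memory those weights take.

```python
import numpy as np


def upcast_half(state_dict):
    """Return a copy of state_dict with every float16 array cast to float32."""
    out = {}
    for key, value in state_dict.items():
        if isinstance(value, np.ndarray) and value.dtype == np.float16:
            out[key] = value.astype(np.float32)
        else:
            out[key] = value  # other dtypes and non-array entries pass through unchanged
    return out


sd = {"w": np.ones((2, 2), dtype=np.float16), "step": np.array(3)}
sd32 = upcast_half(sd)
print(sd32["w"].dtype)  # float32
```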
This looks to be an error in your permutation spec. I would try debugging which weight arrays it's occurring on.
I'll try that, thanks. I also get "RuntimeError: INDICES element is out of DATA bounds, id=256 axis_dim=256".
I've tried a couple of different tests now, removing everything besides a few lines of layers that were explicitly labelled in the state dict, to see if that was the issue and whether I could build from there. Examples here: https://imgur.com/a/IbAEfdP and https://imgur.com/a/o6SgMk8. But I still get those weird errors.
@samuela I feel so close, but I've hit another wall on the apply-permutation line. It's giving me "KeyError: 'betas'". The good news is that I'm pretty sure I know what's happening here. betas is the first key in Stable Diffusion's state dict. It seems to be stuck trying to apply the permutation to betas, but betas wasn't defined in the permutation_spec list of layers or blocks to alter. More importantly, it's not the only key in the state dict I thought best left undisturbed. Is there any way I can get the apply-permutation function to skip keys that it doesn't have altered values for?
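One way to get that behavior, sketched under the assumption that parameters live in a flat dict of NumPy arrays and the spec maps each permuted parameter name to per-axis permutation names: copy through any key the spec does not mention (betas, alphas_cumprod, and so on) instead of looking it up. `apply_perm_skip_missing` is a hypothetical helper, not the repository's apply-permutation function.

```python
import numpy as np


def apply_perm_skip_missing(axes_to_perm, perms, params):
    """Permute every array named in axes_to_perm; copy all other entries through unchanged."""
    out = {}
    for key, w in params.items():
        if key not in axes_to_perm:
            out[key] = w  # e.g. schedule buffers such as 'betas': nothing to permute
            continue
        for axis, p in enumerate(axes_to_perm[key]):
            if p is not None:
                w = np.take(w, perms[p], axis=axis)  # reorder entries along this axis
        out[key] = w
    return out


axes_to_perm = {"fc.weight": ("P_0", None), "fc.bias": ("P_0",)}
params = {
    "betas": np.linspace(1e-4, 2e-2, 5),
    "fc.weight": np.arange(6.0).reshape(3, 2),
    "fc.bias": np.array([10.0, 20.0, 30.0]),
}
perms = {"P_0": np.array([2, 0, 1])}
new = apply_perm_skip_missing(axes_to_perm, perms, params)
print(new["fc.bias"])  # [30. 10. 20.]
```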
Hmm, I'm not familiar with the Stable Diffusion architecture... is there a reason not to model permutations on the betas? You could always add them to your … (By the way, since you're using the third-party PyTorch implementation, some things may be different! That's a different codebase.)
You can skip betas; it's not a layer of the actual architecture but a stored parameter used for sampling and generating timestep embeddings. The same goes for alphas_cumprod, sqrt_alphas, and a dozen others. The UNet is the interesting part, and it resides in … EDIT: … and these are the ones you can skip if you retain the same VAE and text encoder.
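A small sketch of restricting matching to chosen sub-networks by key prefix, leaving schedule buffers such as betas and any skipped component (here the VAE) out. The prefixes used (`model.diffusion_model.` for the UNet, `first_stage_model.` for the VAE, `cond_stage_model.` for the text encoder) follow the usual layout of original Stable Diffusion v1 checkpoints; treat them as an assumption and check them against the keys in your own file. `split_by_prefix` and the example keys are hypothetical.

```python
def split_by_prefix(state_dict, prefixes):
    """Split a flat state dict into (selected, rest) by key prefix."""
    prefixes = tuple(prefixes)  # str.startswith accepts a tuple of prefixes
    selected = {k: v for k, v in state_dict.items() if k.startswith(prefixes)}
    rest = {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}
    return selected, rest


# Assumed SD v1-style prefixes; verify them against your own checkpoint's keys.
UNET = "model.diffusion_model."
VAE = "first_stage_model."
TEXT_ENCODER = "cond_stage_model."

# Illustrative keys with placeholder values.
sd = {
    "betas": 0,
    "alphas_cumprod": 1,
    UNET + "input_blocks.0.0.weight": 2,
    VAE + "encoder.conv_in.weight": 3,
    TEXT_ENCODER + "example.weight": 4,
}
to_match, untouched = split_by_prefix(sd, [UNET, TEXT_ENCODER])  # VAE stays out of matching
print(sorted(untouched))  # ['alphas_cumprod', 'betas', 'first_stage_model.encoder.conv_in.weight']
```

The untouched entries can then be copied into the merged checkpoint as they are.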
Mostly 3 reasons
File "/content/drive/MyDrive/SD_rebasin/weight_matching.py", line 293, in weight_matching
For axes, you mean the P_bgx and P_bgy values, right? So None and None, then? I have a question on that too. There are a lot of layers, and I'm unsure how to label them all correctly. For the layers with two P values, do I just keep going sequentially? I reach P_bg50 or so that way. For the layers with only one, how would that work exactly? It's a bit hard to tell when I need to go from, say, P_bg1 to P_bg2 to P_bg3 and so on. The architecture is divided into three parts: the input blocks, the middle blocks, and the output blocks. So I'm wondering, is it P_bg1, P_bg2, P_bg3 for those sets of blocks, or something else?
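For a purely sequential chain of dense layers, one consistent convention is to give each hidden boundary its own permutation and reuse that name as the next layer's input permutation, leaving the first layer's input and the last layer's output as `None`, since those must stay in their original order. The sketch below generates such names; `chain_spec` is hypothetical, and the `P_bg{i}` naming only mirrors the question. The real UNet is not a pure chain (for example, skip connections force several layers to share one permutation), so this covers chained segments only.

```python
def chain_spec(layer_names, prefix="P_bg"):
    """Build an axes-to-permutation map for a purely sequential chain of dense layers.

    Weights are in PyTorch [out, in] layout. Layer i's output axis gets f"{prefix}{i}",
    and layer i+1's input axis reuses that same name. The first input axis and the
    last output axis are left as None (not permuted).
    """
    spec = {}
    n = len(layer_names)
    for i, name in enumerate(layer_names):
        p_in = None if i == 0 else f"{prefix}{i - 1}"
        p_out = None if i == n - 1 else f"{prefix}{i}"
        spec[f"{name}.weight"] = (p_out, p_in)
        spec[f"{name}.bias"] = (p_out,)
    return spec


spec = chain_spec(["fc1", "fc2", "fc3"])
print(spec["fc2.weight"])  # ('P_bg1', 'P_bg0')
```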
Thanks. Lots of DreamBooth repos also train the text encoder now, so I won't skip them. I'm uncertain about the type of certain layers. The layers that have … Do you have any idea? Are they norm, conv, dense, or neither?
Keep in mind that the different attention layer types might be either all conv or all dense, but they are not just sequentially chained.
Thanks for the response; it's helped a lot. Yes, I think I'm going to skip the attention layers, at least for the first go-around. Forgot to ask: what are time_embed (I'm assuming dense now), model.out.0 and out.2 (norm and conv, I think), and .op (I think conv)?
UNet time_embed: linear, SiLU, linear, so it's dense. It's all readily available in the implementation, so I suggest you read it yourself to get a better understanding.
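The layer type matters because it decides which axes a permutation touches. A sketch of per-type helpers, assuming PyTorch parameter layouts (`nn.Linear.weight` is `[out, in]`, `nn.Conv2d.weight` is `[out, in, kH, kW]`, norm layers have 1-D `weight` and `bias`). The helper names `dense`, `conv`, and `norm` are hypothetical. One caveat: for GroupNorm, sharing a channel permutation is only exactly function-preserving if it keeps each channel inside its group.

```python
def dense(name, p_in, p_out, bias=True):
    """nn.Linear: weight is [out_features, in_features], bias is [out_features]."""
    spec = {f"{name}.weight": (p_out, p_in)}
    if bias:
        spec[f"{name}.bias"] = (p_out,)
    return spec


def conv(name, p_in, p_out, bias=True):
    """nn.Conv2d: weight is [out, in, kH, kW]; the kernel axes are never permuted."""
    spec = {f"{name}.weight": (p_out, p_in, None, None)}
    if bias:
        spec[f"{name}.bias"] = (p_out,)
    return spec


def norm(name, p):
    """Norm layers: per-channel weight and bias follow the channel permutation.

    Caveat: for GroupNorm this is exact only if p keeps channels within their groups.
    """
    return {f"{name}.weight": (p,), f"{name}.bias": (p,)}


# Illustrative: time_embed is linear -> SiLU -> linear, assumed here to be
# nn.Sequential indices 0 and 2 (SiLU at index 1 has no parameters). The final
# output is left unpermuted for simplicity.
spec = {
    **dense("time_embed.0", None, "P_t"),
    **dense("time_embed.2", "P_t", None),
}
print(spec["time_embed.2.weight"])  # (None, 'P_t')
```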
Just an update. At first I had trouble building it up. Some layers would work, most wouldn't, and, most irritatingly, it felt inconsistent which would work without error and which wouldn't. Today I figured out the issue was the axes and the torch sizes of the layers. Not all layers can or should be connected by an axis, and the torch size helps tell which ones can or should be. Anyway, I can get pretty much every layer to permute now, so I'll finish that and finally test this.
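The size check described here can be automated. A hedged sketch, assuming flat dicts of NumPy arrays and a spec that maps each parameter name to a tuple of per-axis permutation names (or `None`): it reports every permutation whose axes disagree in length. A mismatch like this is one plausible source of index-out-of-bounds errors such as the `id=256 axis_dim=256` message earlier in the thread. `check_spec_sizes` is a hypothetical helper.

```python
from collections import defaultdict

import numpy as np


def check_spec_sizes(axes_to_perm, params):
    """Return {perm_name: {size: [(key, axis), ...]}} for each permutation whose axes disagree."""
    sizes = defaultdict(lambda: defaultdict(list))
    for key, axis_perms in axes_to_perm.items():
        if key not in params:
            raise KeyError(f"spec names {key!r}, which is not in the state dict")
        shape = params[key].shape
        if len(axis_perms) != len(shape):
            raise ValueError(f"{key!r}: spec has {len(axis_perms)} axes, array has {len(shape)}")
        for axis, p in enumerate(axis_perms):
            if p is not None:
                sizes[p][shape[axis]].append((key, axis))
    # A permutation is consistent only if every axis it touches has the same length.
    return {p: dict(by_size) for p, by_size in sizes.items() if len(by_size) > 1}


params = {
    "a.weight": np.zeros((256, 4)),
    "b.weight": np.zeros((8, 320)),  # wrongly wired: its input axis has 320 entries
}
spec = {"a.weight": ("P_0", None), "b.weight": (None, "P_0")}
print(check_spec_sizes(spec, params))
# {'P_0': {256: [('a.weight', 0)], 320: [('b.weight', 1)]}}
```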
@ogkalu2 So the idea is to go PyTorch -> TF -> rebasin -> PyTorch? This sounds huge, by the way; thanks for doing it.
Also, you've probably seen this, but there's a PyTorch version. I think you'd need to come up with a PermutationSpec, though: https://github.com/themrzmaster/git-re-basin-pytorch
I'm pretty much just using the PyTorch implementation now. I already knew about the one you linked. I ended up using JAX for flattening and unflattening the parameters, but that's about it. No problem, it's my pleasure. I'm done with the UNet and working on the text encoder. There's no doubt it'll run now; the only question is whether it merges as hoped. Fingers crossed.
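Flattening turns a nested parameter tree into a single-level dict keyed by paths, and unflattening reverses it. A pure-Python sketch of the idea, assuming nested dicts and a `.` separator; the helper names are hypothetical (Flax users often reach for `flax.traverse_util.flatten_dict` and `unflatten_dict` instead). A PyTorch state dict is already flat with dot-separated keys, so only the nested side needs converting.

```python
def flatten(tree, prefix=""):
    """Turn {'a': {'b': x}} into {'a.b': x}."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=f"{path}."))
        else:
            flat[path] = value
    return flat


def unflatten(flat):
    """Inverse of flatten: turn {'a.b': x} back into {'a': {'b': x}}."""
    tree = {}
    for path, value in flat.items():
        *parents, leaf = path.split(".")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


nested = {"unet": {"conv_in": {"weight": 1, "bias": 2}}, "betas": 3}
flat = flatten(nested)
print(flat)  # {'unet.conv_in.weight': 1, 'unet.conv_in.bias': 2, 'betas': 3}
print(unflatten(flat) == nested)  # True
```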
@ogkalu2 Do you have a repository for this where I could take a look?
@lopho No. I wanted to finish things and see the results of a merged model before I uploaded anything to a repo.
Although I did upload my first attempt here. A few things have changed to make it work, mostly the axes, but I added a bias option for the conv and added the dense emb layers in the easyblock.
Hi @samuela, would this run much faster on a GPU?
Yes, it should run quite a bit faster on a GPU, since that will speed up the matrix multiplies, but the linear assignment problem is still solved on the CPU, so I don't think the speedup would be anything too crazy... I've never tried running on CPU only.
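The split described here shows up in a single matching step for one permutation: building the similarity matrix is a matrix multiply (the part a GPU accelerates), while picking the best permutation is a linear assignment problem. In practice that step is usually handed to `scipy.optimize.linear_sum_assignment`, a CPU routine; to keep this sketch dependency-light it brute-forces the assignment instead, which is only feasible for a handful of units. `best_assignment` and the shapes are illustrative.

```python
import itertools

import numpy as np


def best_assignment(sim):
    """Brute-force linear assignment: perm maximizing sum_i sim[i, perm[i]] (tiny n only)."""
    n = sim.shape[0]
    rows = np.arange(n)
    best = max(itertools.permutations(range(n)), key=lambda p: sim[rows, list(p)].sum())
    return np.array(best)


rng = np.random.default_rng(0)
w_a = rng.normal(size=(4, 3))      # 4 hidden units of model A, 3 inputs each
shuffle = np.array([1, 3, 0, 2])
w_b = w_a[shuffle]                 # model B: the same units stored in another order

sim = w_a @ w_b.T                  # similarity of A's unit i and B's unit j: the matmul part
perm = best_assignment(sim)        # the combinatorial, CPU-bound part
print(perm, np.allclose(w_b[perm], w_a))  # [2 0 3 1] True
```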
Ah, I see. If I don't specify the device as CPU, I get File "/notebooks/weight_matching.py", line 798, in weight_matching
@lopho @affableroots @samuela
Hi @samuela, how many iterations does weight matching typically run? I know max iterations is 100, but I don't think it usually goes that high?
It totally depends on the model and initialization. I've seen it take as few as 3 and as many as 50. It is guaranteed to terminate, though, so don't worry: it can't run forever!
I keep OOMing on 32 GB of RAM. Any tips on what I can delete and when, or maybe on running the merge in parts?
Oh wow. I know it can't run forever, but the first SD iteration took ~12 hours, so I was curious. Ah well. What do the NewL - OldL values indicate exactly? I see most of them are 0.0.
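A sketch of the kind of coordinate-descent loop weight matching runs, assuming flat dicts of NumPy weights, a spec mapping names to per-axis permutation names, and brute-force assignment (tiny layers only). Each step re-solves one permutation with the others held fixed; `new_l - old_l` is how much that step raised the alignment objective (the summed inner products between model A's weights and the permuted model B's), so 0.0 means that permutation was already the best choice given the others. The loop stops after a full pass with no gain; since no step can lower the objective and there are finitely many permutations, it cannot run forever. That this matches the exact meaning of the NewL - OldL values printed by the code in this thread is an assumption, and all names here are hypothetical.

```python
import itertools

import numpy as np


def best_assignment(sim):
    """Brute-force linear assignment (tiny sizes only): perm maximizing sum_i sim[i, perm[i]]."""
    n = sim.shape[0]
    rows = np.arange(n)
    return np.array(max(itertools.permutations(range(n)), key=lambda p: sim[rows, list(p)].sum()))


def weight_matching_sketch(axes_to_perm, params_a, params_b, max_iter=100, seed=0):
    """Coordinate descent over permutations; prints the gain new_l - old_l of every step."""
    perm_to_axes = {}
    for key, axis_perms in axes_to_perm.items():
        for axis, p in enumerate(axis_perms):
            if p is not None:
                perm_to_axes.setdefault(p, []).append((key, axis))
    sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in perm_to_axes.items()}
    perms = {p: np.arange(n) for p, n in sizes.items()}  # start from the identity
    names = sorted(perms)
    rng = np.random.default_rng(seed)

    def permuted_b(key, skip_axis):
        """Model B's weight with every current permutation applied except on skip_axis."""
        w = params_b[key]
        for axis, p in enumerate(axes_to_perm[key]):
            if p is not None and axis != skip_axis:
                w = np.take(w, perms[p], axis=axis)
        return w

    for it in range(max_iter):
        progress = False
        for idx in rng.permutation(len(names)):  # visit permutations in random order
            p = names[idx]
            n = sizes[p]
            sim = np.zeros((n, n))
            for key, axis in perm_to_axes[p]:
                w_a = np.moveaxis(params_a[key], axis, 0).reshape(n, -1)
                w_b = np.moveaxis(permuted_b(key, axis), axis, 0).reshape(n, -1)
                sim += w_a @ w_b.T
            rows = np.arange(n)
            old_l = sim[rows, perms[p]].sum()  # objective with the current choice
            perms[p] = best_assignment(sim)
            new_l = sim[rows, perms[p]].sum()  # objective after re-solving this permutation
            print(f"{it}/{p}: {new_l - old_l}")
            progress = progress or new_l > old_l + 1e-12
        if not progress:  # a full pass with no gain: stop
            break
    return perms


rng = np.random.default_rng(1)
params_a = {
    "fc1.weight": rng.normal(size=(4, 3)),
    "fc2.weight": rng.normal(size=(4, 4)),
    "fc3.weight": rng.normal(size=(2, 4)),
}
spec = {"fc1.weight": ("P_0", None), "fc2.weight": ("P_1", "P_0"), "fc3.weight": (None, "P_1")}
# Model B: model A with both hidden layers reshuffled (a function-preserving change).
s0, s1 = np.array([3, 1, 0, 2]), np.array([1, 2, 3, 0])
params_b = {
    "fc1.weight": params_a["fc1.weight"][s0],
    "fc2.weight": params_a["fc2.weight"][s1][:, s0],
    "fc3.weight": params_a["fc3.weight"][:, s1],
}
perms = weight_matching_sketch(spec, params_a, params_b)
```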
Really? Huh. I'm just running on vast right now. You can't really run it in parts right now. As for what to delete, it's possible to skip some layers, but I honestly don't know exactly what I can skip yet. The VAE layers would be the first thing I'd remove, but I don't know beyond that. I'll look into it. The OOM errors seem odd, though. Do you actually get those errors on your console/terminal, or does your system freeze up or something?
EDIT: skipping the VAE makes sense, that's a good idea.
Oh, I see. The test I have running uses two DreamBooth models pruned to 2 GB. The bigger the models, the higher the RAM usage. I didn't realize 4 GB models were currently too much for 32 GB RAM systems. The problem is the linear sum assignment: it can only run on the CPU.
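A rough way to see why RAM scales with checkpoint size is to add up the bytes of every array that has to be held at once. The sketch assumes a flat dict of NumPy arrays; it ignores framework overhead, similarity matrices, and copies made while permuting, so it gives a lower bound. `state_dict_bytes` is a hypothetical helper.

```python
import numpy as np


def state_dict_bytes(state_dict):
    """Total bytes of all NumPy arrays in a flat state dict (other entries are ignored)."""
    return sum(v.nbytes for v in state_dict.values() if isinstance(v, np.ndarray))


sd = {"w": np.zeros((1000, 1000), dtype=np.float16), "b": np.zeros(1000, dtype=np.float16)}
print(state_dict_bytes(sd))    # 2002000: 1,001,000 values x 2 bytes
sd32 = {k: v.astype(np.float32) for k, v in sd.items()}
print(state_dict_bytes(sd32))  # 4004000: upcasting to float32 doubles it
```

Matching needs at least two such models in memory at the same time, plus any float32 copies.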
@samuela Anyway, I have a new problem now. The perm spec runs fine and the parameters get updated fine, but the line I wrote earlier to save the model won't work. After defining the state dicts as state_a = model_a["state_dict"], I tried to save the model with … but get hit with "state_dict": state_b(updated_params)
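If the aim is to write the permuted parameters back into a checkpoint with the original layout, one approach is to copy the original checkpoint dict and swap in a new state dict that starts from all original keys and overwrites only the matched ones. The sketch assumes the checkpoint is a dict with a `"state_dict"` entry (as the `model_a["state_dict"]` line suggests) and that `updated_params` is a flat dict of the matched parameters; `rebuild_checkpoint` is hypothetical. Note that an expression like `state_b(updated_params)` calls a dict as if it were a function, which raises a `TypeError` in Python. Writing the file would then typically be `torch.save(new_ckpt, path)`, after converting any JAX or NumPy arrays back to tensors (for example with `torch.from_numpy(np.asarray(x))`).

```python
def rebuild_checkpoint(original_ckpt, updated_params):
    """Return a new checkpoint dict: same top-level entries, permuted parameters swapped in."""
    new_ckpt = dict(original_ckpt)             # shallow copy keeps e.g. 'global_step'
    state = dict(original_ckpt["state_dict"])  # start from every original key (betas, ...)
    state.update(updated_params)               # overwrite only the keys that were matched
    new_ckpt["state_dict"] = state
    return new_ckpt


model_b = {"global_step": 10, "state_dict": {"betas": [0.1], "w": [1.0, 2.0]}}
new_ckpt = rebuild_checkpoint(model_b, {"w": [2.0, 1.0]})
print(new_ckpt["state_dict"])  # {'betas': [0.1], 'w': [2.0, 1.0]}
```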
Hi @samuela, so a merge with apply_permutation(permutation_spec, final_permutation, model_a state dict) just produces a model that is basically model A, and a merge with apply_permutation(permutation_spec, final_permutation, model_b state dict) just produces a model that is basically model B. Any idea what the issue might be?
Hi @ogkalu2, how are you measuring the difference between the permuted model and the original? Have you inspected the …
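That outcome is what the method predicts: a permutation only reorders hidden units, so a permuted model computes the same function as the model it came from. In the Git Re-Basin setting, the merge comes from the next step: interpolating model A's weights with the permuted model B's, i.e. `(1 - lam) * theta_a + lam * perm(theta_b)`. A small NumPy sketch of both facts on a toy ReLU MLP; all names, sizes, and the fixed permutation (standing in for one found by matching) are illustrative.

```python
import numpy as np


def mlp(params, x):
    """Toy two-layer ReLU MLP; weights in PyTorch [out, in] layout."""
    h = np.maximum(params["fc1.weight"] @ x + params["fc1.bias"], 0.0)
    return params["fc2.weight"] @ h + params["fc2.bias"]


def permute_hidden(params, perm):
    """Reorder the hidden units: rows of fc1 and columns of fc2 move together."""
    return {
        "fc1.weight": params["fc1.weight"][perm],
        "fc1.bias": params["fc1.bias"][perm],
        "fc2.weight": params["fc2.weight"][:, perm],
        "fc2.bias": params["fc2.bias"],
    }


def lerp(params_a, params_b, lam):
    """Weight-space interpolation: lam=0 gives A, lam=1 gives B."""
    return {k: (1 - lam) * params_a[k] + lam * params_b[k] for k in params_a}


def make(rng):
    return {"fc1.weight": rng.normal(size=(5, 3)), "fc1.bias": rng.normal(size=5),
            "fc2.weight": rng.normal(size=(2, 5)), "fc2.bias": rng.normal(size=2)}


rng = np.random.default_rng(0)
model_a, model_b = make(rng), make(rng)
x = rng.normal(size=3)

perm = np.array([4, 2, 0, 1, 3])  # stands in for the permutation found by matching
model_b_perm = permute_hidden(model_b, perm)
print(np.allclose(mlp(model_b_perm, x), mlp(model_b, x)))  # True: same function as B

merged = lerp(model_a, model_b_perm, 0.5)  # the actual merge
```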
Hello, I wonder if there is an update on the automatic tracer in PyTorch for the permutation spec? Thanks!
Hi Samuel. Thank you for being willing to look at this. Basically, I'm trying to see whether it's possible to merge Stable Diffusion models I've fine-tuned with DreamBooth using your method.
The first hurdle, of course, is that your implementation is not yet compatible with PyTorch, as far as I know. But the PyTorch weights can be successfully converted to TensorFlow weights; this has been done before. I don't mind doing this if I have to.
The second hurdle will be successfully using your implementation on the TF models. I'm not sure how feasible this all is, or how I would use your code on the SD TensorFlow weights.