[LoRA] support loading Flux Control LoRAs with bitsandbytes quantization #10588
Labels: bitsandbytes, quantization
Honestly, I don't have a good suggestion for how to tackle this. Probably it's best to ask the bnb devs what the recommended way is. In PEFT, we sometimes encounter similar situations. Not with the shapes changing, but we may have to create a new weight because we want to merge the LoRA weights into the base weight. This has always been a very brittle part of the code, often undocumented, no matter which quantization scheme is used. If I had a wish, I'd like an API like so:

```python
quantized_layer = ...
quantized_weight = quantized_layer.weight
new_data = ...
new_weight = quantized_weight.new(new_data)  # <= new data, but all other properties stay the same
```

Alas, AFAIK there is no such API.
Could you point me to a reference in the peft codebase for this? IIRC, it's through a dispatching system, right? Cc: @matthewdouglas in case you have more ideas.

It's not really a comparable situation, but I mean places like this:

```python
weight = self.get_base_layer().weight  # <= `self.get_base_layer()` is the bnb layer
kwargs = weight.__dict__
...
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
```

https://github.com/huggingface/peft/blob/93d80465a5dd63cda22e0ec1103dad35b7bc35c6/src/peft/tuners/lora/bnb.py#L359-L360

I think it's clear that this can be very brittle, e.g. when new arguments are introduced. Or just recently, I noticed it breaking with …
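Purely as an illustration of how one might defend against that brittleness (a sketch, not PEFT's actual code; `recreate_params4bit` is a hypothetical name, and it assumes the new data still needs to be quantized when moved back to the GPU), the copied attributes could be filtered against what `bnb.nn.Params4bit` actually accepts:

```python
import inspect

import bitsandbytes as bnb


def recreate_params4bit(old_weight, new_data):
    # Hypothetical helper: copy only the attributes that the Params4bit constructor
    # actually accepts, so newly introduced attributes on the old weight don't break
    # the call when bnb adds fields.
    accepted = set(inspect.signature(bnb.nn.Params4bit.__new__).parameters)
    kwargs = {k: v for k, v in old_weight.__dict__.items() if k in accepted}
    kwargs.pop("data", None)
    kwargs.pop("bnb_quantized", None)  # the new data is not quantized yet
    kwargs["requires_grad"] = False
    # As in the PEFT snippet above: construct on CPU, then move back to the original
    # device, which is where bnb quantizes the new data.
    return bnb.nn.Params4bit(new_data.to("cpu"), **kwargs).to(old_weight.device)
```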
We have such a system for LoRA, but it's only responsible for deciding which type of LoRA layer to use to wrap the base layer.

Sorry for not making it clear. I think it has overlaps. So, in our case, it would be deciding which type of …

I'd suggest not starting with a dispatching mechanism, to keep things simple. In PEFT, we got to a point where it became unwieldy, as we had a giant if...else for each combination of quantization and layer type (linear, conv2d, etc.). As long as it's not that bad in diffusers, I'd suggest keeping it as is.

Hi @sayakpaul, @BenjaminBossan, thanks for all the work on this! I was wondering if there's an estimated timeline for supporting Flux Control LoRAs with bitsandbytes quantization. Also, if I want to use it now, are there any specific changes I could make to the current code to get it working? Any guidance would be really helpful. Appreciate your time!
@chaewon-huh thanks for your patience. I plan to start working on this very soon (apologies for the delay).
The first point to address is using the appropriate layer when doing the expansion here: diffusers/src/diffusers/loaders/lora_pipeline.py, line 2020 in 9f5ad1d.

And then we have to consider dequantizing the params when expanding the LoRA state dicts here: diffusers/src/diffusers/loaders/lora_pipeline.py, lines 2094 to 2103 in 9f5ad1d.
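On the dequantization side, a minimal sketch (assuming 4-bit bnb weights on the GPU; `maybe_dequantize` is a made-up name, not a diffusers API) could fall back to bitsandbytes' functional API so the expansion logic only ever sees regular tensors:

```python
import torch
import bitsandbytes as bnb


def maybe_dequantize(weight: torch.nn.Parameter) -> torch.Tensor:
    # bnb 4-bit weights hold packed uint8 data plus a quant_state; dequantize_4bit
    # reconstructs a dense tensor with the original shape and dtype, which can then
    # be expanded like any other parameter.
    if isinstance(weight, bnb.nn.Params4bit):
        return bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
    return weight.data
```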
I've been investigating different approaches to enable Flux Control LoRA + bitsandbytes (4-bit) quantization without fully modifying … One of them is fusing the LoRA weights first and then quantizing.

Are any of these alternatives known to work reliably, or is modifying the bitsandbytes layer creation/expansion logic in … the recommended path? I'd be grateful for any guidance or suggestions you might have.
Well, fusion is definitely an option, and if that works for you, please go ahead, as it should already work out of the box. But if you wanted to change the LoRA scale, that won't be possible :/
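For anyone who wants to go the fusion route in the meantime, the flow could look roughly like this (a sketch only, not tested end to end; the repo ids point at the public Flux checkpoint and Canny Control LoRA, the save path is a placeholder, and the dtype/4-bit settings are illustrative):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxControlPipeline, FluxTransformer2DModel

# 1. Load in high precision, fuse the Control LoRA into the base weights, and drop
#    the LoRA modules again.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
pipe.fuse_lora()
pipe.unload_lora_weights()

# 2. Persist the fused transformer, then reload it with 4-bit bnb quantization.
pipe.transformer.save_pretrained("./fused-canny-transformer")  # placeholder path
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
pipe.transformer = FluxTransformer2DModel.from_pretrained(
    "./fused-canny-transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```

The trade-off is exactly the one mentioned above: once the LoRA is fused and quantized, its scale can no longer be adjusted.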
@chaewon-huh could you give #10990 a test?
#10578 fixed loading LoRAs into 4bit quantized models for Flux. #10576 added a test to ensure Flux LoRAs can be loaded when 8bit bitsandbytes quantization is applied. We still need to support all of this for Flux Control LoRAs, as we do quite a bit of expansion gymnastics as well as new layer assignments to make it all work.
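For context, the "expansion gymnastics" boil down to growing the input dimension of certain linear layers (e.g. the `x_embedder` going from 64 to 128 input features for Control LoRAs) and zero-filling the new columns. A minimal sketch of that idea (not the actual diffusers helper):

```python
import torch


def expand_linear_weight(weight: torch.Tensor, new_in_features: int) -> torch.Tensor:
    # A linear weight has shape (out_features, in_features); grow the input dim and
    # zero-fill the newly added columns so the existing behavior is preserved.
    out_features, in_features = weight.shape
    expanded = weight.new_zeros((out_features, new_in_features))
    expanded[:, :in_features] = weight
    return expanded
```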
Some stuff I wanted to discuss before attempting a PR:
(When I say quantization, I always mean quantization from bitsandbytes in this thread.)

diffusers/src/diffusers/loaders/lora_pipeline.py, line 2020 in c944f06: this creates an nn.Linear. It needs to be configured based on what quantization scheme we're using (4bit/8bit). The same goes for diffusers/src/diffusers/loaders/lora_pipeline.py, line 1917 in c944f06.
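Concretely, the replacement module would then have to be constructed per scheme instead of unconditionally as `nn.Linear`. A hedged sketch of what that branch could look like (the helper name is hypothetical, and carrying over all quantization settings from the original layer is simplified here):

```python
import torch.nn as nn
import bitsandbytes as bnb


def build_expanded_linear(original: nn.Linear, in_features: int, out_features: int) -> nn.Module:
    bias = original.bias is not None
    # Check the bnb subclasses first: both inherit from nn.Linear.
    if isinstance(original, bnb.nn.Linear4bit):
        return bnb.nn.Linear4bit(
            in_features, out_features, bias=bias, compute_dtype=original.compute_dtype
        )
    if isinstance(original, bnb.nn.Linear8bitLt):
        return bnb.nn.Linear8bitLt(
            in_features, out_features, bias=bias, has_fp16_weights=False
        )
    # Unquantized case, i.e. what the current code does unconditionally.
    return nn.Linear(in_features, out_features, bias=bias, dtype=original.weight.dtype)
```

The new layer's weight would then still need to be populated with the expanded (and re-quantized) data, which is where the dequantize/requantize discussion above comes in.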
@BenjaminBossan I wanted to pick your brains here to have a robust design for approaching the solution. Suggestions?