
[LoRA] support loading Flux Control LoRAs with bitsandbytes quantization #10588

Open
sayakpaul opened this issue Jan 15, 2025 · 10 comments
@sayakpaul
Member

#10578 fixed loading LoRAs into 4bit quantized models for Flux.

#10576 added a test to ensure Flux LoRAs can be loaded when 8bit bitsandbytes quantization is applied.

We still need to support all of this for Flux Control LoRAs, since they require quite a bit of shape-expansion gymnastics as well as new layer assignments to make everything work.

Some stuff I wanted to discuss before attempting a PR:

(whenever I say quantization in this thread, assume it means quantization from bitsandbytes)

expanded_module = torch.nn.Linear(
is responsible for initializing an expanded module. This is perfectly fine for non-quantized scenarios, but for quantized models we cannot use nn.Linear. It needs to be configured based on which quantization scheme we're using (4-bit/8-bit).

Same goes for:

original_module = torch.nn.Linear(
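For illustration, here is a rough sketch of what a quantization-aware replacement could look like, assuming bitsandbytes. The helper name and the way the scheme is passed in are hypothetical, and the 4-bit kwargs would need to mirror the model's actual quantization config:

import torch
import bitsandbytes as bnb

def create_linear(in_features, out_features, bias, dtype, quant_scheme=None):
    # Hypothetical helper: pick the layer class from the quantization scheme
    # instead of hard-coding torch.nn.Linear.
    if quant_scheme == "bnb_4bit":
        return bnb.nn.Linear4bit(
            in_features,
            out_features,
            bias=bias,
            compute_dtype=dtype,
            quant_type="nf4",  # should match the model's BitsAndBytesConfig
        )
    if quant_scheme == "bnb_8bit":
        return bnb.nn.Linear8bitLt(
            in_features,
            out_features,
            bias=bias,
            has_fp16_weights=False,
        )
    return torch.nn.Linear(in_features, out_features, bias=bias, dtype=dtype)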

@BenjaminBossan I wanted to pick your brain here to come up with a robust design for the solution. Suggestions?

@sayakpaul sayakpaul self-assigned this Jan 15, 2025
@sayakpaul sayakpaul changed the title [LoRA] support loading Flux Control LoRAs with bitsandbytes` quantization [LoRA] support loading Flux Control LoRAs with bitsandbytes quantization Jan 15, 2025
@BenjaminBossan
Member

Honestly, I don't have a good suggestion for how to tackle this. It's probably best to ask the bnb devs what the recommended way is.

In PEFT, we sometimes encounter similar situations. Not with the shapes changing, but we may have to create a new weight because we want to merge the LoRA weights into the base weight. This has always been a very brittle part of the code, often undocumented, no matter which quantization scheme is used.

If I had a wish, I'd like an API like so:

quantized_layer = ...
quantized_weight = quantized_layer.weight
new_data = ...
new_weight = quantized_weight.new(new_data)  # <= new data but all other properties stay the same

Alas, AFAIK there is no such API.
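The closest thing available today seems to be re-creating the parameter from the new data and letting bnb re-quantize it, roughly along these lines (a sketch only; the helper name is hypothetical and the attribute handling is simplified and version-dependent):

import torch
import bitsandbytes as bnb

def params4bit_like(old: bnb.nn.Params4bit, new_data: torch.Tensor) -> bnb.nn.Params4bit:
    # Carry over the quantization settings of `old` and re-quantize `new_data`
    # by moving the fresh parameter back to the original device.
    return bnb.nn.Params4bit(
        new_data.to("cpu"),
        requires_grad=False,
        quant_type=old.quant_type,
        compress_statistics=old.compress_statistics,
        blocksize=old.blocksize,
    ).to(old.device)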

@sayakpaul
Member Author

This has always been a very brittle part of the code, often undocumented, no matter which quantization scheme is used.

Could you point me to a reference in the peft codebase for this? IIRC, it's through a dispatching system, right?

Cc: @matthewdouglas in case you have more ideas.

@BenjaminBossan
Member

Could you point me to a reference in the peft codebase for this?

It's not really a comparable situation, but I mean places like this:

weight = self.get_base_layer().weight  # <= `self.get_base_layer()` is the bnb layer
kwargs = weight.__dict__
...
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)

https://github.com/huggingface/peft/blob/93d80465a5dd63cda22e0ec1103dad35b7bc35c6/src/peft/tuners/lora/bnb.py#L359-L360
https://github.com/huggingface/peft/blob/93d80465a5dd63cda22e0ec1103dad35b7bc35c6/src/peft/tuners/lora/bnb.py#L387-L389

I think it's clear that this can be very brittle, e.g. when new arguments are introduced. Just recently, for example, I noticed it breaking with torch.compile because it adds spurious attributes. For other quantization techniques, it's pretty much the same picture.

IIRC, it's through a dispatching system, right?

We have such a system for LoRA, but it's only responsible for deciding which type of LoRA layer to use to wrap the base layer.

@sayakpaul
Member Author

It's not really a comparable situation, but I mean places like this:

Sorry for not making that clear. I think it overlaps with our situation.

We have such a system for LoRA, but it's only responsible for deciding which type of LoRA layer to use to wrap the base layer.

So, in our case, it would be deciding which type of Linear layer to use based on the quantization scheme. I think we'd start with bitsandbytes first.

@BenjaminBossan
Member

So, in our case, it would be deciding which type of Linear layer to use based on the quantization scheme. I think we'd start with bitsandbytes first.

I'd suggest not starting with a dispatching mechanism, to keep things simple. In PEFT, we got to a point where it became unwieldy, as we had a giant if...else for each combination of quantization and layer type (linear, conv2d, etc.). As long as it's not that bad in diffusers, I'd suggest keeping it as is.

@chaewon-huh

Hi @sayakpaul , @BenjaminBossan ,

Thanks for all the work on this! I was wondering if there’s an estimated timeline for supporting Flux Control LoRAs with bitsandbytes quantization.

Also, if I want to use it now, are there any specific changes I could make to the current code to get it working? Any guidance would be really helpful.

Appreciate your time!

@sayakpaul
Member Author

@chaewon-huh thanks for your patience. I plan to start working on this very soon (apologies for the delay).

Also, if I want to use it now, are there any specific changes I could make to the current code to get it working? Any guidance would be really helpful.

The first point to address is using the appropriate layer when doing the expansion here:

expanded_module = torch.nn.Linear(

And then we have to consider dequantizing the params when expanding the LoRA state dicts here (see the sketch after the snippet below):

if base_module_shape[1] > lora_A_param.shape[1]:
    shape = (lora_A_param.shape[0], base_weight_param.shape[1])
    expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
    expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
    lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
    expanded_module_names.add(k)
elif base_module_shape[1] < lora_A_param.shape[1]:
    raise NotImplementedError(
        f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
    )
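With bnb 4-bit, base_weight_param above would be a packed Params4bit, so its .shape no longer matches the logical weight shape, and any copying may require dequantizing first. A minimal sketch of the two pieces this needs, assuming bitsandbytes (helper names are hypothetical):

import bitsandbytes as bnb

def logical_shape(weight):
    # Packed 4-bit params keep their original (unpacked) shape in quant_state.
    if isinstance(weight, bnb.nn.Params4bit):
        return tuple(weight.quant_state.shape)
    return tuple(weight.shape)

def dequantized_data(weight):
    # Materialize the full-precision tensor when it has to be copied/expanded.
    if isinstance(weight, bnb.nn.Params4bit):
        return bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
    return weight.data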

@chaewon-huh

@sayakpaul

I’ve been investigating different approaches to enable Flux Control LoRA + bitsandbytes (4-bit) quantization without fully modifying lora_pipeline.py in the ways you described. Specifically:

Fusing LoRA weights first and then quantizing:

  • One idea is to load the model in float/half precision, perform a full LoRA fusion, and only then apply bitsandbytes quantization to the final fused model.
  • Would that be a viable workaround to avoid the shape expansion logic in _maybe_expand_transformer_param_shape_or_error_ and _maybe_expand_lora_state_dict?

Are any of these alternatives known to work reliably, or is modifying the bitsandbytes layer creation/expansion logic in lora_pipeline.py still the recommended path?

I’d be grateful for any guidance or suggestions you might have.

@sayakpaul
Member Author

Well, fusion is definitely an option, and if that works for you, please go ahead, as it should already work out of the box.

But if you wanted to change the LoRA scale, that won't be possible :/
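For anyone who wants to try the fuse-then-quantize route in the meantime, a rough, untested sketch is below; the repo ids, the save path, and the config values are illustrative, and it assumes diffusers' FluxControlPipeline and BitsAndBytesConfig:

import torch
from diffusers import BitsAndBytesConfig, FluxControlPipeline, FluxTransformer2DModel

# 1. Load in bf16, load the Control LoRA, and fuse it into the base weights.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
pipe.fuse_lora()
pipe.unload_lora_weights()

# 2. Save the fused transformer and reload it with 4-bit quantization.
pipe.transformer.save_pretrained("flux-control-fused-transformer")
pipe.transformer = FluxTransformer2DModel.from_pretrained(
    "flux-control-fused-transformer",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
    ),
    torch_dtype=torch.bfloat16,
)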

@sayakpaul
Member Author

@chaewon-huh could you give #10990 a test?
