Flux quantized with lora #10990

Open · wants to merge 6 commits into base: main
Changes from 1 commit
21 changes: 10 additions & 11 deletions src/diffusers/loaders/lora_pipeline.py
@@ -18,6 +18,7 @@
import torch
from huggingface_hub.utils import validate_hf_hub_args

+ from ..quantizers.bitsandbytes import dequantize_bnb_weight
Member:
Should this be guarded with is_bitsandbytes_available()?

Member Author:
Guard added. It is also guarded internally in quantizers.bitsandbytes, so it would be OK without it as well.
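
For reference, a minimal sketch of the guarded-import pattern being discussed, assuming is_bitsandbytes_available from diffusers.utils; the exact placement inside lora_pipeline.py may differ from the PR:

# Sketch only: guard the bitsandbytes-dependent import so the module still
# imports when bitsandbytes is not installed.
from ..utils import is_bitsandbytes_available

if is_bitsandbytes_available():
    from ..quantizers.bitsandbytes import dequantize_bnb_weight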

from ..utils import (
USE_PEFT_BACKEND,
deprecate,
@@ -1905,7 +1906,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):

for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear) and name in module_names:
- module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None

@@ -1919,7 +1919,6 @@
in_features,
out_features,
bias=bias,
- dtype=module_weight.dtype,
)

tmp_state_dict = {"weight": current_param_weight}
@@ -1970,7 +1969,11 @@ def _maybe_expand_transformer_param_shape_or_error_(
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
- module_weight = module.weight.data
+ module_weight = (
+     dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
+     if module.weight.__class__.__name__ == "Params4bit"
+     else module.weight.data
+ )
Member:
This looks good. But to be extra cautious, I would first determine if module.weight.__class__.__name__ == "Params4bit", and then do:

if not is_bitsandbytes_available():
    raise ...
else:
    ...

Also, I think the weight data needs to be on a device where bitsandbytes is known to be supported. Otherwise, it won't be able to dequantize.

Member Author:
Added the guard here
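
For illustration, roughly what the suggested guard could look like; this is a sketch only, and the error type, message, and device handling are assumptions rather than the exact code in this PR:

# is_bitsandbytes_available comes from diffusers.utils,
# dequantize_bnb_weight from diffusers.quantizers.bitsandbytes (as in the diff above).
weight = module.weight
if weight.__class__.__name__ == "Params4bit":
    if not is_bitsandbytes_available():
        # Illustrative error; the merged code may phrase this differently.
        raise ValueError(
            "Cannot expand a 4-bit quantized module without bitsandbytes installed."
        )
    # bitsandbytes dequantization is only supported on certain devices (e.g. CUDA),
    # so the quantized weight is expected to live on such a device at this point.
    module_weight = dequantize_bnb_weight(weight, state=weight.quant_state).data
else:
    module_weight = weight.data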

module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None

@@ -1994,7 +1997,7 @@ def _maybe_expand_transformer_param_shape_or_error_(

# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
Member Author:
cc @sayakpaul, it should be OK to remove these TODOs now.

Member:
Please go ahead and remove.

- module_out_features, module_in_features = module_weight.shape
+ module_out_features, module_in_features = module_weight_shape
debug_message = ""
if in_features > module_in_features:
debug_message += (
@@ -2018,17 +2021,13 @@
parent_module = transformer.get_submodule(parent_module_name)

with torch.device("meta"):
- expanded_module = torch.nn.Linear(
-     in_features, out_features, bias=bias, dtype=module_weight.dtype
- )
+ expanded_module = torch.nn.Linear(in_features, out_features, bias=bias)
# Only weights are expanded and biases are not. This is because only the input dimensions
# are changed while the output dimensions remain the same. The shape of the weight tensor
# is (out_features, in_features), while the shape of bias tensor is (out_features,), which
# explains the reason why only weights are expanded.
- new_weight = torch.zeros_like(
-     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
- )
- slices = tuple(slice(0, dim) for dim in module_weight.shape)
+ new_weight = torch.zeros_like(expanded_module.weight.data)
+ slices = tuple(slice(0, dim) for dim in module_weight_shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
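
As a standalone illustration of the expansion pattern in the hunk above (only the weight of shape (out_features, in_features) needs zero-padding, while the bias of shape (out_features,) is untouched), here is a toy sketch with made-up dimensions and variable names, not the PR code itself:

import torch

# Expand a Linear from in_features=64 to in_features=128.
old = torch.nn.Linear(64, 128, bias=True)

# Create the larger module on the meta device so no real memory is allocated yet.
with torch.device("meta"):
    expanded = torch.nn.Linear(128, 128, bias=True)

# Zero-pad the old weight into the larger shape.
new_weight = torch.zeros(expanded.weight.shape, dtype=old.weight.dtype)
slices = tuple(slice(0, dim) for dim in old.weight.shape)  # (0:128, 0:64)
new_weight[slices] = old.weight.data

# assign=True replaces the meta-device parameters with the real tensors.
expanded.load_state_dict({"weight": new_weight, "bias": old.bias.data}, assign=True)
assert expanded.weight.shape == (128, 128)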