Flux quantized with lora #10990
base: main
Changes from 1 commit
Commits: 36eb48b · 695ad14 · bc912fc · f950380 · 67bc7c0 · 316d52f
@@ -18,6 +18,7 @@
 import torch
 from huggingface_hub.utils import validate_hf_hub_args

+from ..quantizers.bitsandbytes import dequantize_bnb_weight
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
@@ -1905,7 +1906,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):

 for name, module in transformer.named_modules():
     if isinstance(module, torch.nn.Linear) and name in module_names:
         module_weight = module.weight.data
         module_bias = module.bias.data if module.bias is not None else None
         bias = module_bias is not None
@@ -1919,7 +1919,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
     in_features,
     out_features,
     bias=bias,
-    dtype=module_weight.dtype,
 )

 tmp_state_dict = {"weight": current_param_weight}
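Context for why the explicit dtype (and, further down, the raw .shape) of the module weight stops being trusted: for a bitsandbytes 4-bit layer, module.weight is a Params4bit whose .dtype is the packed torch.uint8 storage type and whose .shape is the flattened packed shape; the logical dtype and (out_features, in_features) live on weight.quant_state. A small inspection sketch, assuming bitsandbytes is installed and a CUDA device is available (layer sizes are arbitrary):

import torch
import bitsandbytes as bnb

# Build a tiny 4-bit Linear; the weight is actually quantized when it moves to GPU.
layer = bnb.nn.Linear4bit(64, 32, bias=False, compute_dtype=torch.bfloat16)
layer = layer.cuda()

print(type(layer.weight).__name__)            # "Params4bit"
print(layer.weight.dtype)                     # torch.uint8 -> packed storage, not a usable nn.Linear dtype
print(tuple(layer.weight.shape))              # flattened packed shape, not (out_features, in_features)
print(layer.weight.quant_state.dtype)         # original (pre-quantization) dtype tracked by bitsandbytes
print(tuple(layer.weight.quant_state.shape))  # the logical (32, 64) shape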
@@ -1970,7 +1969,11 @@ def _maybe_expand_transformer_param_shape_or_error_(
 is_peft_loaded = getattr(transformer, "peft_config", None) is not None
 for name, module in transformer.named_modules():
     if isinstance(module, torch.nn.Linear):
-        module_weight = module.weight.data
+        module_weight = (
+            dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
+            if module.weight.__class__.__name__ == "Params4bit"
+            else module.weight.data
+        )
Comment: This looks good. But to be extra cautious, I would first determine

if not is_bitsandbytes_available():
    raise ...
else:
    ...

Also I think the data device needs to be on a device where …

Reply: Added the guard here.
         module_bias = module.bias.data if module.bias is not None else None
         bias = module_bias is not None
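Following the reviewer's suggestion, a minimal sketch of what the availability guard around the dequantization could look like. This is an illustration only, not the exact code added to the loader; the helper name is hypothetical, and it relies on diffusers' is_bitsandbytes_available and dequantize_bnb_weight utilities:

import torch
from diffusers.utils import is_bitsandbytes_available


def weight_data_for_shape_checks(module: torch.nn.Linear) -> torch.Tensor:
    # Hypothetical helper: return a dense view of the weight so its shape/dtype can be inspected.
    if module.weight.__class__.__name__ == "Params4bit":
        if not is_bitsandbytes_available():
            raise ValueError(
                "The model appears to contain bitsandbytes 4-bit weights, but bitsandbytes is not installed."
            )
        from diffusers.quantizers.bitsandbytes import dequantize_bnb_weight

        # Dequantization runs on the weight's device; bnb 4-bit weights live on GPU.
        return dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data
    return module.weight.data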
@@ -1994,7 +1997,7 @@ def _maybe_expand_transformer_param_shape_or_error_(

 # TODO (sayakpaul): We still need to consider if the module we're expanding is
 # quantized and handle it accordingly if that is the case.

Comment: cc @sayakpaul should be ok to remove these TODO now

Reply: Please go ahead and remove.

-module_out_features, module_in_features = module_weight.shape
+module_out_features, module_in_features = module_weight_shape
 debug_message = ""
 if in_features > module_in_features:
     debug_message += (
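For context on module_weight_shape: once the weight may be quantized, module_weight.shape no longer necessarily equals the logical (out_features, in_features). A hedged sketch of how such a shape could be computed without fully dequantizing, assuming the bitsandbytes quant_state carries the original shape; the function name is hypothetical and is not the PR's actual helper:

import torch


def logical_weight_shape(module: torch.nn.Linear) -> tuple:
    # Hypothetical helper: the logical (out_features, in_features) of a Linear weight.
    weight = module.weight
    if weight.__class__.__name__ == "Params4bit" and getattr(weight, "quant_state", None) is not None:
        # bnb 4-bit weights are stored packed and flattened; the original shape is kept on quant_state.
        return tuple(weight.quant_state.shape)
    return tuple(weight.shape)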
@@ -2018,17 +2021,13 @@ def _maybe_expand_transformer_param_shape_or_error_(
 parent_module = transformer.get_submodule(parent_module_name)

 with torch.device("meta"):
-    expanded_module = torch.nn.Linear(
-        in_features, out_features, bias=bias, dtype=module_weight.dtype
-    )
+    expanded_module = torch.nn.Linear(in_features, out_features, bias=bias)
 # Only weights are expanded and biases are not. This is because only the input dimensions
 # are changed while the output dimensions remain the same. The shape of the weight tensor
 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
 # explains the reason why only weights are expanded.
-new_weight = torch.zeros_like(
-    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
-)
-slices = tuple(slice(0, dim) for dim in module_weight.shape)
+new_weight = torch.zeros_like(expanded_module.weight.data)
+slices = tuple(slice(0, dim) for dim in module_weight_shape)
 new_weight[slices] = module_weight
 tmp_state_dict = {"weight": new_weight}
 if module_bias is not None:
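To illustrate the expansion mechanic itself, independent of quantization: only the input dimension of the Linear grows, so the old weight is copied into the top-left slice of a larger zero-initialized weight. A standalone sketch with arbitrary sizes (chosen here to resemble a Flux x_embedder-like layer whose input width doubles):

import torch

old = torch.nn.Linear(64, 3072, bias=False)
in_features, out_features = 128, 3072

expanded = torch.nn.Linear(in_features, out_features, bias=False)
new_weight = torch.zeros_like(expanded.weight.data)

# Weight shape is (out_features, in_features): copy the original block into the top-left corner,
# leaving the columns for the new input channels at zero. The bias has shape (out_features,),
# so it never needs to be expanded.
slices = tuple(slice(0, dim) for dim in old.weight.shape)
new_weight[slices] = old.weight.data
expanded.weight.data.copy_(new_weight)

assert torch.allclose(expanded.weight.data[:, :64], old.weight.data)
assert expanded.weight.data[:, 64:].abs().sum() == 0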
Comment: Should this be guarded with is_bitsandbytes_available()?

Reply: Guard added. It is guarded internally in quantizers.bitsandbytes, so it would also be OK without.
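For reference, a rough end-to-end sketch of the scenario this PR targets: loading a LoRA on top of a 4-bit bitsandbytes-quantized Flux transformer. The model ID and LoRA repo below are placeholders, and this assumes a recent diffusers build with bitsandbytes support and a CUDA device:

import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"  # placeholder model id

# Quantize the transformer to 4-bit NF4 with bitsandbytes.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16
)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# Loading a LoRA onto the quantized transformer is the code path exercised by this change.
pipe.load_lora_weights("some-user/some-flux-lora", adapter_name="example")  # placeholder LoRA repo

image = pipe("a photo of a cat wearing a tiny hat", num_inference_steps=28).images[0]
image.save("flux_lora_4bit.png")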