Allow compressed-tensors quantized model to be trained #34520

Draft · wants to merge 15 commits into main
src/transformers/quantizers/quantizer_compressed_tensors.py (12 additions & 5 deletions)

```diff
@@ -65,12 +65,19 @@ def _process_model_before_weight_loading(self, model, **kwargs):
         ct_quantization_config = self.compressor.quantization_config
         apply_quantization_config(model, ct_quantization_config, run_compressed=True)

-    def _process_model_after_weight_loading(self, model, **kwargs):
+    def _process_model_after_weight_loading(self, model, **kwargs) -> None:
         pass

     @property
-    def is_trainable(self):
-        return False
+    def is_trainable(self) -> bool:
+        """Models quantized using compressed tensors can be fine-tuned."""
+        return True

-    def is_serializable(self, safe_serialization=None):
-        return False
+    @property
+    def is_qat_trainable(self) -> bool:
+        """Loaded models can carry out quantization-aware training."""
+        return True
+
+    def is_serializable(self, safe_serialization=None) -> bool:
+        """Models quantized using compressed tensors can be saved to disk."""
+        return True
```
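Taken together, these flips mean a compressed-tensors checkpoint can now go straight into `Trainer` for full fine-tuning and be saved afterwards, with no PEFT adapters required. Below is a minimal sketch of that flow; the checkpoint path and the toy dataset are illustrative placeholders, not anything shipped in this PR, and it assumes the `compressed-tensors` package is installed.

```python
from torch.utils.data import Dataset

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# Hypothetical path: any checkpoint whose config.json carries a
# compressed-tensors quantization_config.
MODEL_ID = "path/to/compressed-tensors-model"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


class ToyDataset(Dataset):
    """Tiny stand-in dataset so the sketch is self-contained."""

    def __init__(self, tokenizer):
        enc = tokenizer("compressed tensors can now be fine-tuned", return_tensors="pt")
        self.item = {k: v.squeeze(0) for k, v in enc.items()}
        self.item["labels"] = self.item["input_ids"].clone()

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return self.item


# Before this PR, Trainer raised "You cannot perform fine-tuning on purely
# quantized models ..." here; with is_trainable=True it proceeds.
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", num_train_epochs=1, report_to="none"),
    train_dataset=ToyDataset(tokenizer),
)
trainer.train()

# is_serializable=True means the fine-tuned weights can be written back out.
trainer.save_model("out/finetuned")
```

Because the weights are already quantized, this is quantization-aware training rather than adapter tuning, which is what the separate `is_qat_trainable` property signals.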
src/transformers/trainer.py (5 additions & 0 deletions)

```diff
@@ -521,13 +521,18 @@ def __init__(
             getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
         )

+        _is_model_quantized_and_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
+            model.hf_quantizer, "is_qat_trainable", False
+        )
+
         # Filter out quantized + compiled models
         if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
             raise ValueError(
                 "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
             )

         # At this stage the model is already loaded
-        if _is_quantized_and_base_model and not _is_peft_model(model):
+        if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_trainable:
             raise ValueError(
                 "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
```
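The interaction between the two flags and the relaxed guard is easy to lose in the diff, so here is the decision logic distilled into a standalone sketch. `FakeQuantizer`, `FakeModel`, and the `_is_peft_model` stub are illustrative stand-ins, not transformers classes, and `_is_quantized_and_base_model` is simplified from its real definition.

```python
class FakeQuantizer:
    is_trainable = True       # flipped to True by this PR
    is_qat_trainable = True   # new property added by this PR


class FakeModel:
    is_quantized = True
    hf_quantizer = FakeQuantizer()


def _is_peft_model(model):
    """Stub for the real transformers helper."""
    return False


model = FakeModel()

# Simplified version of the Trainer.__init__ checks above.
_is_quantized_and_base_model = getattr(model, "is_quantized", False)
_is_model_quantized_and_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
    model.hf_quantizer, "is_qat_trainable", False
)

# With the relaxed guard, a QAT-trainable quantized model no longer raises.
if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_trainable:
    raise ValueError("You cannot perform fine-tuning on purely quantized models ...")
print("guard passed: quantized model accepted for direct fine-tuning")
```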
src/transformers/utils/quantization_config.py (16 additions & 6 deletions)

```diff
@@ -1092,6 +1092,8 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
         do not override, should be compressed-tensors
     """

+    QUANTIZATION_NAME = "compressed-tensors"
+
     def __init__(
         self,
         config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None,  # noqa: F821
@@ -1150,6 +1152,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
         Returns:
             [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
         """
+
         if "quantization_config" in config_dict:
             config_dict = dict(
                 sparsity_config=config_dict.get("sparsity_config"),
@@ -1160,16 +1163,23 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):

     def to_dict(self) -> Dict[str, Any]:
         """
+        Quantization config to be added to config.json
+
         Serializes this instance to a Python dictionary. Returns:
             `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
         """
-        quantization_config = self.quantization_config.dict() if self.quantization_config is not None else None
-        sparsity_config = self.sparsity_config.dict() if self.sparsity_config is not None else None
+        quantization_config = {}
+        if self.quantization_config is not None:
+            quantization_config = self.quantization_config.dict()
+        else:
+            quantization_config["quant_method"] = self.QUANTIZATION_NAME

-        return {
-            "quantization_config": quantization_config,
-            "sparsity_config": sparsity_config,
-        }
+        if self.sparsity_config is not None:
+            quantization_config["sparsity_config"] = self.sparsity_config.dict()
+        else:
+            quantization_config["sparsity_config"] = {}
+
+        return quantization_config

     def to_diff_dict(self) -> Dict[str, Any]:
         """
```
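One consequence of the new `to_dict` worth spelling out: the output is no longer nested under a `quantization_config` key. The quantization fields sit at the top level, with `sparsity_config` folded into the same dictionary and `quant_method` injected from `QUANTIZATION_NAME` when no quantization config is attached. A quick sketch of the default case, assuming the `compressed-tensors` package is installed (the import path follows the file edited above):

```python
from transformers.utils.quantization_config import CompressedTensorsConfig

# No config_groups, so quantization_config stays None and the
# else-branch injects quant_method from QUANTIZATION_NAME.
cfg = CompressedTensorsConfig()
print(cfg.to_dict())
# Expected with the new code path:
# {'quant_method': 'compressed-tensors', 'sparsity_config': {}}
```

When a quantization config is attached, the first branch returns `self.quantization_config.dict()` instead, so the serialized keys come from the compressed-tensors schema rather than the else-branch default.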