Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torchao weights only compability #34355

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,50 @@ def validate_environment(self, *args, **kwargs):
)
else:
self.offload = True
if self.pre_quantized:
safe_globals = []
Copy link
Contributor

@jerryzh168 jerryzh168 Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we do import torchao, I think we should get everything here (classes etc. being added to safeglobals)? otherwise we'd need to fix torchao

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using torchao 0.5.0 and it's not working on my side. I can try with the latest tomorrow !

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, it's not expected I think, I think it should be fixed in torchao side, I feel 0.5 should have this functionality already actually. if you can have a standalone repro that will be very helpful for us. I remember I have tested in https://huggingface.co/docs/transformers/main/en/quantization/torchao

Copy link
Member Author

@SunMarc SunMarc Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we ran into this issue with @MekkCyber on the example you shared in the docs.
Here's a the reproducer, let us know if you also have this issue :

from transformers import TorchAoConfig, AutoTokenizer, AutoModelForCausalLM
import torch

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = TorchAoConfig("int4_weight_only", group_size=32)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    quantization_config=quant_config,
)
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained(output_dir, safe_serialization=False)

loaded_quantized_model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="cuda:0")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK will test and report back

if self.quantization_config.quant_type == "int4_weight_only":
from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
TensorCoreTiledAQTLayout,
TensorCoreTiledLayoutType,
ZeroPointDomain,
)

safe_globals += [
AffineQuantizedTensor,
TensorCoreTiledAQTLayout,
TensorCoreTiledLayoutType,
ZeroPointDomain,
]
elif self.quantization_config.quant_type == "int8_weight_only":
from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
PlainAQTLayout,
PlainLayoutType,
ZeroPointDomain,
)

safe_globals += [PlainAQTLayout, AffineQuantizedTensor, PlainLayoutType, ZeroPointDomain]
SunMarc marked this conversation as resolved.
Show resolved Hide resolved
elif self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
PlainAQTLayout,
PlainLayoutType,
ZeroPointDomain,
)
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import _int8_symm_per_token_reduced_range_quant

safe_globals += [
LinearActivationQuantizedTensor,
AffineQuantizedTensor,
PlainAQTLayout,
PlainLayoutType,
ZeroPointDomain,
_int8_symm_per_token_reduced_range_quant,
]
torch.serialization.add_safe_globals(safe_globals)

def update_torch_dtype(self, torch_dtype):
if self.quantization_config.quant_type == "int4_weight_only":
Expand All @@ -85,6 +129,10 @@ def update_torch_dtype(self, torch_dtype):
"Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
)
torch_dtype = torch.bfloat16
if self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
if torch_dtype is None:
# we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
torch_dtype = torch.float32
return torch_dtype

def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
Expand Down Expand Up @@ -172,6 +220,12 @@ def is_serializable(self, safe_serialization=None):
)
if not _is_torchao_serializable:
logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
if self.offload and self.quantization_config.modules_to_not_convert is None:
logger.warning(
"The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
"If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
)
return False
return _is_torchao_serializable

@property
Expand Down
56 changes: 56 additions & 0 deletions tests/quantization/torchao_integration/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import gc
import tempfile
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
Expand Down Expand Up @@ -209,5 +210,60 @@ def test_int4wo_offload(self):
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)


@require_torch_gpu
@require_torchao
class TorchAoSerializationTest(unittest.TestCase):
input_text = "What are we having for dinner?"
max_new_tokens = 10
EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = TorchAoConfig("int4_weight_only", group_size=32)

# called only once for all test in this class
@classmethod
def setUpClass(cls):
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
torch_dtype=torch.bfloat16,
device_map="cuda:0",
quantization_config=cls.quant_config,
)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)

def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()

def test_original_model_expected_output(self):
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)

self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

def test_serialization_weight_only(self):
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, torch_dtype=torch.bfloat16, device_map="cuda:0"
)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
SunMarc marked this conversation as resolved.
Show resolved Hide resolved

output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
# TODO: investigate why we don't have the same output as the original model
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SunMarc marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)


class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SunMarc marked this conversation as resolved.
Show resolved Hide resolved


class TorchAoSerializationW8Test(TorchAoSerializationTest):
quant_config = TorchAoConfig("int8_weight_only")
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"


if __name__ == "__main__":
unittest.main()
Loading