add FP8Config #1442

Open · wants to merge 3 commits into main
@@ -44,6 +44,7 @@
MixedPrecisionConfig,
BitsAndBytesConfig,
SmoothQuantConfig,
FP8Config,
RtnConfig,
AwqConfig,
TeqConfig,
@@ -42,6 +42,7 @@
BitsAndBytesConfig,
MixedPrecisionConfig,
SmoothQuantConfig,
FP8Config,
RtnConfig,
AwqConfig,
TeqConfig,
@@ -71,6 +72,8 @@
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
from neural_compressor.torch.quantization import quantize
from neural_compressor.torch.quantization import FP8Config as INCFP8Config
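# aliased to avoid clashing with ITREX's own FP8Config defined in utils/config.py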
from threading import Thread
from transformers.configuration_utils import PretrainedConfig
from transformers import AutoConfig
@@ -355,6 +358,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
device_map = kwargs.get("device_map", "cpu")
use_cpu = True if device_map == torch.device("cpu") or device_map == "cpu" else False
use_xpu = True if device_map == torch.device("xpu") or device_map == "xpu" else False
use_xpu = True if device_map == torch.device("hpu") or device_map == "hpu" else False

config = kwargs.pop("config", None)
model_hub = kwargs.pop("model_hub", "huggingface")
@@ -374,20 +378,20 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

quantization_config = kwargs.pop("quantization_config", None)
if kwargs.get("use_llm_runtime", None) is not None:
use_neural_speed = kwargs.pop("use_llm_runtime", True) and not use_xpu
use_neural_speed = kwargs.pop("use_llm_runtime", True) and not use_xpu and not use_hpu
logger.warning(
"use_llm_runtime is deprecated in version 1.3.2, please use_neural_speed instead."
)
elif kwargs.get("use_neural_speed", None) is not None:
use_neural_speed = kwargs.pop("use_neural_speed", True) and not use_xpu
use_neural_speed = kwargs.pop("use_neural_speed", True) and not use_xpu and not use_hpu
else:
if not hasattr(config, "model_type"):
logger.error(
"Can't get model_type from the model config. Please make sure the config contains a valid model_type."
)
exit(0)

if config.model_type in cls.model_type_list and not use_xpu:
if config.model_type in cls.model_type_list and not use_xpu and not use_hpu:
if (
isinstance(quantization_config, GPTQConfig)
and config.model_type not in cls.model_type_list_for_gptq
@@ -446,6 +450,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
and is_bitsandbytes_available()
and not use_cpu
and not use_xpu
and not use_hpu
):
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
@@ -644,6 +649,68 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

model.save_pretrained = types.MethodType(save_low_bit, model)
logger.info("WeightOnlyQuant done.")
elif isinstance(quantization_config, FP8Config) and use_hpu:
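# FP8 quantization path, taken only on HPU devices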
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
)
if quantization_config.approach == "dynamic":
from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
model = quantize_dynamic(model, quantization_config.precision, inplace=True)
elif quantization_config.approach == "static":
qconfig = INCFP8Config(
w_dtype=quantization_config.precision,
act_dtype=quantization_config.precision,
approach="static",
)
if quantization_config.skip_lm_head:
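# leave the lm_head module in fp32 when skip_lm_head is set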
fp32_config = INCFP8Config(w_dtype="fp32", act_dtype="fp32")
qconfig.set_local("lm_head", fp32_config)

# calibration function
calib_func = quantization_config.calib_func
tokenizer = quantization_config.tokenizer
if calib_func is None:
if quantization_config.tokenizer is None:
logger.error(
"Please provide a tokenizer, or provide calib_func directly."
+ " A tokenizer can be obtained as follows: \n"
+ " from transformers import AutoTokenizer \n"
+ " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) \n"
)
exit(0)

calib_dataset = quantization_config.calib_dataset
calib_shuffle = quantization_config.calib_shuffle
calib_iters = quantization_config.calib_iters
calib_padding = quantization_config.calib_padding
calib_len = quantization_config.calib_len

# load the calibration dataset and keep the first 100 samples
from datasets import load_dataset
calib_dataset = load_dataset(calib_dataset, split="train").select(range(100))
if calib_shuffle:
calib_dataset = calib_dataset.shuffle(seed=42)
calib_data = []
for examples in calib_dataset:
calib_data.append(
tokenizer(
examples["text"],
return_tensors="pt",
max_length=calib_len,
padding="max_length",
truncation=True
)
)

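# default calibration function: feed up to calib_iters tokenized samples to the model on HPU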
def calib_func(model):
for i, calib_input in enumerate(calib_data):
if i >= calib_iters:
break
model(
input_ids=calib_input["input_ids"].to('hpu'),
attention_mask=calib_input["attention_mask"].to('hpu'),
)
model = quantize(model, qconfig, calib_func, inplace=True)
logger.info("FP8 Quantization done.")
elif isinstance(quantization_config, SmoothQuantConfig):
try:
import intel_extension_for_pytorch as ipex
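For context, here is a minimal usage sketch of the new FP8 path. The model name and the exact import locations are illustrative assumptions, not part of this diff:

```python
# Illustrative sketch (assumed API surface): exercise the new FP8Config branch on HPU.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, FP8Config

model_name = "facebook/opt-125m"  # placeholder model

# Static quantization calibrates on the default NeelNanda/pile-10k dataset,
# so a tokenizer must be supplied.
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = FP8Config(precision="fp8_e4m3", approach="static", tokenizer=tokenizer)

# Dynamic quantization needs no calibration:
# config = FP8Config(precision="fp8_e4m3", approach="dynamic")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=config,
    device_map="hpu",  # routes through the use_hpu branch added above
)
```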
@@ -21,6 +21,7 @@
MixedPrecisionConfig,
BitsAndBytesConfig,
SmoothQuantConfig,
FP8Config,
SparsityConfig,
RtnConfig,
AwqConfig,
intel_extension_for_transformers/transformers/utils/config.py (76 additions, 0 deletions)
@@ -633,6 +633,82 @@ def get_config_dict(
)


class FP8Config(ITREXQuantizationConfigMixin):
"""
This is a wrapper class about all possible attributes and features that you can play with a model that has been
loaded using `auto-awq` library awq quantization relying on auto_awq backend.

Args:
precision (`str`, *optional*, defaults to fp8_e4m3):
The data type of weight and activation.
approach (`str`, *optional*, defaults to static):
The approach for quantization.
"""

def __init__(
self,
precision: str = "fp8_e4m3",
approach: str = "static",
**kwargs,
):
self.precision = precision
self.approach = approach
self.device = kwargs.get("device", "hpu")
self.calib_dataloader = kwargs.get("calib_dataloader", None)
self.calib_dataset = kwargs.get("calib_dataset", "NeelNanda/pile-10k")
self.calib_func = kwargs.get("calib_func", None)
self.calib_padding = kwargs.get("calib_padding", False)
self.calib_len = kwargs.get("calib_len", 64)
self.calib_shuffle = kwargs.get("calib_shuffle", True)
self.calib_iters = kwargs.get("calib_iters", 100)
self.skip_lm_head = kwargs.get("skip_lm_head", False)
self.tokenizer = kwargs.get("tokenizer", None)
self.post_init_fp8()

def post_init_fp8(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if self.precision is not None and self.precision not in ["fp8_e5m2", "fp8_e4m3"]:
raise ValueError("precision must be in ['fp8_e5m2', 'fp8_e4m3'].")
elif self.precision is None:
self.precision = "fp8_e4m3"

if self.approach is None:
self.approach = "static"
elif self.approach not in ["static", "dynamic"]:
raise ValueError(
f"Only support 'static' and 'dynamic' approach but found {self.approach}"
)

if self.device is not None and self.device not in ["hpu", torch.device("hpu")]:
raise ValueError(f"Only support hpu device but found {self.device}")
elif self.device is None:
self.device = "hpu"

def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
serializes to a Python dictionary.

Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, containing only the values that differ from the default configuration.
"""
config_dict = self.to_dict()

# get the default config dict
default_config_dict = FP8Config().to_dict()

serializable_config_dict = {}

# only serialize values that differ from the default config
for key, value in config_dict.items():
if value != default_config_dict[key]:
serializable_config_dict[key] = value

return serializable_config_dict

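# Example (not part of this diff): to_diff_dict serializes only non-default fields.
# config = FP8Config(precision="fp8_e5m2", calib_iters=50)
# config.to_diff_dict()  # -> {'precision': 'fp8_e5m2', 'calib_iters': 50}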

class RtnConfig(ITREXQuantizationConfigMixin):
def __init__(
self,