
Model Optimization Using Quantization-Aware Training (QAT)

FMS Model Optimizer supports quantization of models, which enables the use of reduced-precision numerical formats and specialized hardware to accelerate inference (i.e., make "calling a model" faster).

Generally speaking, matrix multiplication (matmul) is the dominant operation in a neural network. The goal of quantization is to convert a floating-point (FP) matmul into an integer (INT) matmul, which runs faster and consumes less energy. A simplified example would be:

$$X@W \approx \left\lfloor \frac{X}{s_x} \right\rceil @ \left\lfloor \frac{W}{s_w} \right\rceil \cdot s_x s_w$$

  • where $X$ and $W$ are FP tensors whose elements all lie within a certain range, e.g. $[-5.0, 5.0]$; $@$ is the matmul operation; $\lfloor \cdot \rceil$ is the rounding operation; and the scaling factors $s_x$ and $s_w$ are in this case simply $5/127$.
  • On the right-hand side, after scaling and rounding, the tensors contain only integers in the range $[-127, 127]$, which can be stored as 8-bit integers.
  • We can now use an INT8 matmul instead of an FP32 matmul to perform the task, and then multiply the result by the scaling factors afterward.
  • Important: The benefit from the INT matmul must outweigh the overhead from scaling, rounding, and descaling. However, rounding inevitably introduces approximation errors. Luckily, we can mitigate these errors by taking the quantization-related operations into account during training, hence quantization-aware training (QAT)!
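To make the equation concrete, here is a minimal NumPy sketch of the simplified scheme above: symmetric per-tensor INT8 quantization with $s_x = s_w = 5/127$, an integer matmul, and a single rescale. The helper names (quantize_sym, int8_matmul_approx) are illustrative only and are not part of FMS Model Optimizer; real kernels run the integer matmul on the GPU rather than in NumPy.

```python
import numpy as np


def quantize_sym(t, clip, n_bits=8):
    """Symmetric per-tensor quantization of values in [-clip, clip].

    Returns the integer tensor q and its scale s = clip / qmax
    (5/127 for the example above), so that t ~ q * s.
    """
    qmax = 2 ** (n_bits - 1) - 1                      # 127 for INT8
    scale = clip / qmax
    q = np.clip(np.round(t / scale), -qmax, qmax).astype(np.int8)
    return q, scale


def int8_matmul_approx(x, w, clip=5.0):
    """Approximate x @ w with an integer matmul followed by one rescale."""
    qx, sx = quantize_sym(x, clip)
    qw, sw = quantize_sym(w, clip)
    # Accumulate in INT32: a sum of many INT8*INT8 products overflows INT8.
    acc = qx.astype(np.int32) @ qw.astype(np.int32)
    return acc * (sx * sw)                            # descale back to FP


rng = np.random.default_rng(0)
X = rng.uniform(-5.0, 5.0, size=(4, 16)).astype(np.float32)
W = rng.uniform(-5.0, 5.0, size=(16, 8)).astype(np.float32)

max_err = float(np.abs(X @ W - int8_matmul_approx(X, W)).max())
print(f"max |FP32 - INT8| error: {max_err:.4f}")
```

Each rounded element is off by at most half a step ($s/2$), so with values bounded by 5 each of the 16 products in an output element is off by at most $5 s$; the observed error is typically far below that worst case. This rounding error is exactly what QAT lets the model adapt to.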

In the following example, we first fine-tune a model in FP16, and then quantize it from FP16 to INT8 using QAT. Once the model is fine-tuned and QAT'ed, you can observe its accuracy and its acceleration at inference time.

Requirements

  • FMS Model Optimizer requirements
  • The inference step requires NVIDIA GPUs with compute capability >= 8.0 (A100 family or newer)
  • NVIDIA CUTLASS (clone the source; do not pip install it). Preferably place it in the user's home directory: cd ~ && git clone https://github.com/NVIDIA/cutlass.git
  • Ninja
  • PyTorch 2.3.1 (newer versions cause issues with the custom CUDA kernel)
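A quick pre-flight check can save a failed kernel build in step 3. The sketch below is a hypothetical helper (check_requirements is not part of FMS Model Optimizer); it assumes CUTLASS was cloned to ~/cutlass as suggested above and that Ninja is found on PATH. Only torch.__version__, torch.cuda.is_available and torch.cuda.get_device_capability come from PyTorch itself.

```python
import os
import shutil


def check_requirements(torch_version, capability, cutlass_dir="~/cutlass"):
    """Return a list of human-readable problems (empty list = all good).

    Hypothetical helper; the expected values mirror the Requirements list.
    """
    problems = []
    base = torch_version.split("+")[0]          # strip local tag, e.g. "+cu121"
    if base != "2.3.1":
        problems.append(f"PyTorch {torch_version} found; this example expects 2.3.1")
    if capability is None or tuple(capability) < (8, 0):
        problems.append(f"GPU compute capability {capability} is below 8.0 (or no GPU found)")
    if not os.path.isdir(os.path.expanduser(cutlass_dir)):
        problems.append(f"CUTLASS source not found at {cutlass_dir}")
    if shutil.which("ninja") is None:
        problems.append("ninja not found on PATH")
    return problems


if __name__ == "__main__":
    try:
        import torch
        cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
        issues = check_requirements(torch.__version__, cap)
    except ImportError:
        issues = ["PyTorch is not installed"]
    print("\n".join(issues) or "All requirements look satisfied")
```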

QuickStart

Note

This example is based on the HuggingFace Transformers Question answering example.

The example consists of three main steps:

1. Fine-tune a model with 16-bit floating point (FP16) precision:

export CUDA_VISIBLE_DEVICES=0

python run_qa_no_trainer_qat.py \
  --model_name_or_path google-bert/bert-base-uncased \
  --dataset_name squad \
  --per_device_train_batch_size 12 \
  --learning_rate 3e-5 \
  --num_train_epochs 2 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir ./fp16_ft_squad/ \
  --with_tracking \
  --report_to tensorboard \
  --attn_impl eager

Tip

The script can take up to 40 mins to run (on a single A100). By default, it is configured for detailed logging. You can disable the logging by removing the --with_tracking and --report_to flags from the command, which can reduce the runtime by around 20 mins.

2. Apply QAT to the fine-tuned model to prepare it for 8-bit integer (INT8) precision:

python run_qa_no_trainer_qat.py \
  --model_name_or_path ./fp16_ft_squad/ \
  --dataset_name squad \
  --per_device_train_batch_size 12 \
  --learning_rate 3e-5 \
  --num_train_epochs 2 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir ./qat_on_fp16ft \
  --with_tracking \
  --report_to tensorboard \
  --attn_impl eager \
  --do_qat \
  --pact_a_lr 1e-3

Tip

The script can take up to 1.5 hours to run (on a single A100). Removing the --with_tracking and --report_to flags can reduce the runtime to about 40 mins.

3. Compare the accuracy and inference speed of 16-bit floating point (FP16) and 8-bit integer (INT8) precision models:

export TOKENIZERS_PARALLELISM=false

python run_qa_no_trainer_qat.py \
  --model_name_or_path ./qat_on_fp16ft/ \
  --dataset_name squad \
  --per_device_train_batch_size 128 \
  --per_device_eval_batch_size 128 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --attn_impl eager \
  --do_lowering

This script uses an "external kernel" instead of the torch.matmul kernel to perform real INT8 matmuls. The kernel is written with NVIDIA's CUDA/CUTLASS libraries and is compiled once, just before the run. The compiled artifacts are usually stored in ~/.cache/torch_extensions/; delete this folder to force a fresh recompile of the kernel.
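If you prefer to clear the kernel cache from Python, a small sketch follows. It assumes the usual Linux default location mentioned above; PyTorch's JIT extension builder (torch.utils.cpp_extension) also honors the TORCH_EXTENSIONS_DIR environment variable, and the default can differ on other setups (e.g. when XDG_CACHE_HOME is set). The helper names are hypothetical.

```python
import os
import shutil


def torch_extensions_dir():
    """Best guess at where PyTorch stores JIT-compiled extension builds.

    Assumption: TORCH_EXTENSIONS_DIR if set, else ~/.cache/torch_extensions.
    """
    return os.environ.get("TORCH_EXTENSIONS_DIR") or os.path.expanduser(
        "~/.cache/torch_extensions"
    )


def clear_kernel_cache(path=None):
    """Delete the cached build so the INT8 kernel is recompiled on the next run.

    Returns True if a directory was removed, False if there was nothing to remove.
    """
    path = path or torch_extensions_dir()
    if os.path.isdir(path):
        shutil.rmtree(path)
        return True
    return False
```

Calling clear_kernel_cache() before running step 3 forces a rebuild; it removes every cached extension in that folder, not just the INT8 kernel.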

Check out the Example Test Results below to compare against your results.

Example Test Results

For comparison, here are some of the results we obtained when testing with PyTorch 2.3.1:

Note

Accuracy can vary by about ±0.2 from run to run.

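| model | batch size | torch.compile | accuracy (F1) | inference time (msec) |
|-------|------------|---------------|---------------|-----------------------|
| FP16  | 128        | eager         | 88.21 (as fine-tuned) | 126.38 |
|       | 128        | Inductor      |               | 71.59 |
|       | 128        | CUDAGRAPH     |               | 71.13 |
| INT8  | 128        | eager         | 88.33         | 329.45 [1] |
|       | 128        | Inductor      | 88.42         | 67.87 [2] |
|       | 128        | CUDAGRAPH     | --            | -- [3] |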

[1] INT8 matmuls are ~2x faster than FP16 matmuls. However, INT8 models carry additional overhead compared to FP16 models, for example, converting FP tensors to INT before each INT matmul.

[2] Each of these additional quantization operations is relatively 'cheap', but the overhead of launching each kernel is not negligible. torch.compile can fuse these ops and reduce the total number of kernels launched.

[3] CUDAGRAPH (CUDA Graphs) is the most effective way to minimize kernel launch overhead and can achieve a ~2x end-to-end speed-up in this case. However, there appear to be bugs associated with this option at the moment; investigation is ongoing.
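It helps to read the latencies against more than one baseline. The snippet below only divides numbers copied from the results table (no new measurements); the INT8 CUDAGRAPH row is omitted because it has no result.

```python
# Measured latencies (msec) from the results table: batch size 128, PyTorch 2.3.1.
LATENCY_MS = {
    ("fp16", "eager"): 126.38,
    ("fp16", "inductor"): 71.59,
    ("fp16", "cudagraph"): 71.13,
    ("int8", "eager"): 329.45,
    ("int8", "inductor"): 67.87,
}


def speedup(baseline, candidate):
    """How many times faster `candidate` is than `baseline` (> 1 means faster)."""
    return LATENCY_MS[baseline] / LATENCY_MS[candidate]


comparisons = [
    (("fp16", "eager"), ("int8", "inductor")),
    (("fp16", "inductor"), ("int8", "inductor")),
    (("fp16", "eager"), ("int8", "eager")),
]
for base, cand in comparisons:
    print(f"{cand} vs {base}: {speedup(base, cand):.2f}x")
```

This prints roughly 1.86x, 1.05x and 0.38x: compiled INT8 is clearly faster than eager FP16 but only slightly faster than compiled FP16, and eager INT8 is about 2.6x slower than eager FP16, which is consistent with the launch-overhead explanation in notes [1] and [2].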

Code Walk-through

In this section, we take a deeper look at what happens during each step of the example.

There are three parts to the example:

1. Fine-tune a model with 16-bit floating point (FP16) precision

This step fine-tunes a BERT model on the SQuAD question answering dataset. It is based on the HuggingFace Transformers Question answering example, modified to collect additional training information in case we want to tweak the hyperparameters later.

2. Apply Quantization using QAT

For INT8 quantization, we can achieve accuracy comparable to FP16 by using either quantization-aware training (QAT) or post-training quantization (PTQ). In this example we use QAT.

In a nutshell, QAT quantizes the weight and activation tensors before matrix multiplications (matmuls) so that quantization errors are taken into account during the training/loss optimization process. The code below is an example of preparing a model for QAT prior to fine-tuning:

from fms_mo import qmodel_prep, qconfig_init

# Create a config dict using a default recipe and CLI args.
# If the same item exists in both, args take precedence over the recipe.
qcfg = qconfig_init(recipe='qat_int8', args=args)

# Prepare a list of "ready-to-run" batches for calibration.
# Reuse a single iterator so each calibration batch is a distinct batch
# (calling iter() inside the loop would restart the dataloader every time).
train_iter = iter(train_dataloader)
exam_inp = [next(train_iter) for _ in range(qcfg['qmodel_calibration'])]

logger.info(f"--- Accuracy of {args.model_name_or_path} before QAT/PTQ")
squad_eval(model)  # A function modified from the original script that checks accuracy

qmodel_prep(model, exam_inp, qcfg, optimizer, use_dynamo=True)

# (Continue with the original fine-tuning script)
...

The resulting model is saved in the qat_on_fp16ft folder. Be aware that the weights are now different from the original FP16 checkpoint in Step 1, but not yet converted to real INT8!
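To illustrate what "taking quantization errors into account during training" means, here is a minimal NumPy sketch of the general QAT idea on a toy linear-regression problem: the forward pass uses fake-quantized (quantize-then-dequantize) weights, and the backward pass uses the straight-through estimator (STE), i.e. rounding is treated as the identity so gradients reach the full-precision weights. This is not how fms_mo implements QAT; fake_quant, qat_linear_regression and the max-abs per-tensor scale are illustrative assumptions, and fms_mo's actual quantizers (e.g. the ones tuned via --pact_a_lr in step 2) may differ.

```python
import numpy as np


def fake_quant(t, n_bits=8):
    """Quantize-then-dequantize with a symmetric per-tensor max-abs scale.

    The result stays floating point but only takes values k * scale,
    with integer k in [-qmax, qmax].
    """
    qmax = 2 ** (n_bits - 1) - 1
    scale = max(float(np.abs(t).max()), 1e-8) / qmax
    return np.clip(np.round(t / scale), -qmax, qmax) * scale


def qat_linear_regression(X, y, n_bits=8, lr=0.1, steps=200, seed=0):
    """Fit y ~ X @ w while the forward pass only ever sees fake-quantized weights.

    Backward uses the straight-through estimator: the gradient w.r.t. the
    quantized weights is applied directly to the full-precision weights w.
    """
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=0.1, size=X.shape[1])
    for _ in range(steps):
        w_q = fake_quant(w, n_bits)      # forward pass with quantized weights
        err = X @ w_q - y                # residuals of the quantized model
        grad = X.T @ err / len(y)        # dL/dw_q for L = mean(err**2) / 2
        w = w - lr * grad                # STE: reuse dL/dw_q as dL/dw
    return w, fake_quant(w, n_bits)


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    X = rng.normal(size=(256, 8))
    y = X @ rng.normal(size=8)
    w_fp, w_q = qat_linear_regression(X, y)
    print("loss with quantized weights:", 0.5 * np.mean((X @ w_q - y) ** 2))
```

The full-precision weights w play the role of the "shadow" weights kept during QAT; only their quantized version is used for deployment, which is why the checkpoint saved after step 2 is still not real INT8.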

3. Evaluate Inference Accuracy and Speed

Note

This step will compile an external kernel for INT matmul, which currently only works with PyTorch 2.3.1.

Here is an example code snippet used for evaluation:

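import torch
from fms_mo import qmodel_prep, qconfig_init
from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
# ...

def _parent_name(target):
    """Split a dotted module name 'a.b.c' into ('a.b', 'c'); top-level names get parent ''.

    Stand-in for the script's own `_parent_name` helper (assumed behavior).
    """
    parent, _, name = target.rpartition('.')
    return parent, name

# Only 1 batch (not a list) is needed this time; it is also used by `torch.compile`.
exam_inp = next(iter(train_dataloader))

qcfg = qconfig_init(recipe='qat_int8', args=args)
qcfg['qmodel_calibration'] = 0  # <----------- NOTE 1
qmodel_prep(model, exam_inp, qcfg, optimizer, use_dynamo=True,
            ckpt_reload=args.model_name_or_path)  # <----------- NOTE 2

# ----------- NOTE 3
# Swap every QLinear module for its INT8 deployment counterpart.
mod2swap = [n for n, m in model.named_modules() if isinstance(m, QLinear)]
for name in mod2swap:
    parent_name, module_name = _parent_name(name)
    parent_mod = model.get_submodule(parent_name)  # '' returns the model itself
    qmod = getattr(parent_mod, module_name)
    setattr(parent_mod, module_name, QLinearINT8Deploy.from_fms_mo(qmod))

# ...

with torch.no_grad():
    model = torch.compile(model)  # add mode='reduce-overhead' to use CUDA graphs <----- NOTE 4
    model(**exam_inp)  # The first call triggers compilation

# ...

return  # Stop the run here; no further training loop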

In this example:

  • NOTE 1: By default, QAT runs calibration on a small amount of training data to initialize the quantization-related parameters. At the end of QAT, these parameters are saved with the checkpoint, because we DO NOT want to run calibration at the deployment stage. Hence, qcfg['qmodel_calibration'] = 0.
  • NOTE 2: Quantization-related parameters are not loaded automatically by the HuggingFace loading method, as they are not part of the original BERT model. Hence we call qmodel_prep(..., ckpt_reload=[path to QAT ckpt]).
  • NOTE 3: Replacing the QLinear layers with QLinearINT8Deploy makes the model call the external kernel instead of torch.matmul.
  • NOTE 4: torch.compile with the reduce-overhead option uses CUDA graphs and achieves the largest speed-up. However, some models may not be fully compatible with this option.
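The calibrate-once, reuse-at-deployment pattern behind NOTE 1 can be sketched in a few lines. This is a hypothetical illustration: it assumes a simple max-abs calibration for a single activation scale and uses a plain dict as the "checkpoint"; fms_mo's actual calibration method and checkpoint format may differ, and in real QAT the calibrated values are typically refined during training before being saved (omitted here).

```python
import numpy as np

QMAX = 127  # symmetric INT8


def calibrate_scale(batches):
    """Hypothetical max-abs calibration: one per-tensor scale from a few batches."""
    amax = max(float(np.abs(b).max()) for b in batches)
    return amax / QMAX


def quantize(t, scale):
    """Quantize with a fixed, previously calibrated scale (no new statistics)."""
    return np.clip(np.round(t / scale), -QMAX, QMAX).astype(np.int8)


rng = np.random.default_rng(0)

# QAT stage (cf. qcfg['qmodel_calibration'] > 0): calibrate once, save with the checkpoint.
calib_batches = [rng.normal(size=(4, 16)) for _ in range(5)]
checkpoint = {"act_scale": calibrate_scale(calib_batches)}

# Deployment stage (cf. qcfg['qmodel_calibration'] = 0): reload the scale, do not recalibrate.
act_scale = checkpoint["act_scale"]
x_int8 = quantize(rng.normal(size=(4, 16)), act_scale)
print(act_scale, x_int8.dtype)
```

Deployment inputs that fall outside the calibrated range are simply clipped to [-127, 127], which is one reason the calibration data should resemble real inputs.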