4 changes: 4 additions & 0 deletions CHANGELOG.rst
@@ -15,6 +15,10 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add support for registering a custom emulated quantization backend. See :meth:`register_quant_backend <modelopt.torch.quantization.nn.modules.tensor_quantizer.register_quant_backend>` for details and ``tests/unit/torch/quantization/test_custom_backend.py`` for an example.
- Add ``examples/llm_qad`` for QAD training with Megatron-LM.

**Bug Fixes**

- Synchronize MSE calibration amax (and calibrated static bias values) across distributed groups (DP/EP/TP) to keep quantization parameters consistent across ranks.

**Deprecations**

- Deprecate the ``num_query_groups`` parameter in Minitron pruning (``mcore_minitron``). Use ModelOpt 0.40.0 or earlier if you still need to prune it.
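The ``distributed_sync`` flag referenced in the bug-fix entry above is passed through the calibration algorithm config. A minimal usage sketch, assuming the same config keys exercised by the new test ``tests/gpu/torch/quantization/test_mse_calibrate_sync.py`` (illustrative only, not the authoritative config schema):

```python
import copy

import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq

# Start from the stock INT8 config and switch calibration to the MSE search,
# asking ModelOpt to synchronize the resulting amax across distributed ranks.
config = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
config["algorithm"] = {
    "method": "mse",
    "num_steps": 16,
    "start_multiplier": 0.001,
    "stop_multiplier": 4.0,
    "distributed_sync": True,  # the behavior exercised by this change
}

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()


def forward_loop(m):
    # Calibration data; in a multi-rank job each rank feeds its own shard.
    for _ in range(4):
        m(torch.randn(64, 16, device="cuda"))


model = mtq.quantize(model, config, forward_loop)
```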
96 changes: 94 additions & 2 deletions modelopt/torch/quantization/model_calib.py
@@ -93,7 +93,8 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
# TODO: create sync_bias_across_distributed_group
quantizer.sync_bias_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_bias_across_distributed_group(parallel_state.expert_model_parallel_group)

for name, module in model.named_modules():
if isinstance(module, QuantModule):
@@ -266,7 +267,98 @@ def quant_func(x, amax, quantizer=module):
# Step 4: Compute optimal amax and load it
finish_stats_collection(model, method="mse")

# TODO: Sync amax across distributed processes
if not distributed_sync:
return

def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp_ep(_q, parallel_state)
return
if getattr(quantizer, "_amax", None) is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
quantizer.sync_bias_across_distributed_group(parallel_state.data_parallel_group)
quantizer.sync_bias_across_distributed_group(parallel_state.expert_model_parallel_group)

for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp_ep(child, module.parallel_state)

def sync_quantizer_amax_across_tp(
quantizer: TensorQuantizer | SequentialQuantizer,
linear_name: str,
quantizer_type: str,
axes_for_sync: list,
parallel_state: ParallelState,
):
# Syncing amax across TP for sequential quantizer
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
# Syncing amax across TP for sequential quantizer
sync_quantizer_amax_across_tp(
_q, linear_name, quantizer_type, axes_for_sync, parallel_state
)
return
# sync is not needed for block quantization
if quantizer.block_sizes is not None:
if hasattr(quantizer, "_padding"):
warnings.warn(
f"Found block-quantized padded {quantizer_type} for {linear_name}, amax will"
" not be synced correctly."
)
# Skip amax sync for INT4 / W4A8 block quantization
# Sync amax for NVFP4 (dynamic per-block, static per-tensor quantized scale)
if quantizer.block_sizes.get("type", None) != "dynamic":
return

if quantizer.axis in axes_for_sync and quantizer.amax is not None:
quantizer.sync_amax_across_distributed_group(parallel_state.tensor_parallel_group)

for name, module in model.named_modules():
if getattr(module, "_parallel_state", None) is None:
continue

if is_quantized_column_parallel_linear(module):
sync_quantizer_amax_across_tp(
module.input_quantizer,
name,
"input_quantizer",
axes_for_sync=[None, -1],
parallel_state=module.parallel_state,
)

sync_quantizer_amax_across_tp(
module.weight_quantizer,
name,
"weight_quantizer",
axes_for_sync=[None, -1],
parallel_state=module.parallel_state,
)

if is_quantized_row_parallel_linear(module):
sync_quantizer_amax_across_tp(
module.input_quantizer,
name,
"input_quantizer",
axes_for_sync=[None],
parallel_state=module.parallel_state,
)

sync_quantizer_amax_across_tp(
module.weight_quantizer,
name,
"weight_quantizer",
axes_for_sync=[None, 0],
parallel_state=module.parallel_state,
)

for name, module in model.named_modules():
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()


def enable_stats_collection(model: nn.Module):
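The DP/EP/TP synchronization added above ultimately relies on collective reductions of each quantizer's calibrated ``amax``. A standalone sketch of that reduction pattern, using a hypothetical helper rather than the actual ``TensorQuantizer.sync_amax_across_distributed_group`` implementation (assumes ``torch.distributed`` is already initialized with a suitable backend):

```python
import torch
import torch.distributed as dist


def allreduce_amax(amax: torch.Tensor, group: dist.ProcessGroup | None) -> torch.Tensor:
    """Hypothetical helper: make ``amax`` identical on every rank of ``group``.

    MSE calibration can land on a different amax per rank when each rank sees
    different calibration data; a MAX all-reduce keeps the most conservative
    clipping range so every replica quantizes with the same parameters.
    """
    if group is None or not dist.is_initialized():
        return amax  # single-process run, nothing to synchronize
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax
```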
48 changes: 48 additions & 0 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -1169,6 +1169,54 @@ def sync_amax_across_distributed_group(self, parallel_group: DistributedProcessGroup):
"if happening during modelopt restore."
)

def sync_bias_across_distributed_group(self, parallel_group: DistributedProcessGroup):
"""Synchronize the bias across all ranks in the given group."""
if not parallel_group.is_initialized():
return
if self.bias_calibrator is None or self.bias_type != "static":
return

bias = self.bias_calibrator.compute_bias()
if bias is None:
return

try:
if self.bias_method == "mean":
cnt = float(getattr(self.bias_calibrator, "_cnt", 0))
bias_sum = bias.float() * cnt
cnt_tensor = torch.tensor(cnt, device=bias_sum.device, dtype=bias_sum.dtype)
dist.all_reduce(bias_sum, op=dist.ReduceOp.SUM, group=parallel_group.group)
dist.all_reduce(cnt_tensor, op=dist.ReduceOp.SUM, group=parallel_group.group)
if cnt_tensor.item() > 0:
bias_avg = (bias_sum / cnt_tensor).to(bias.dtype)
else:
bias_avg = bias
self.bias_value = bias_avg
self.bias_calibrator._calib_bias = bias_avg.detach().clone()
self.bias_calibrator._cnt = int(cnt_tensor.item())
elif self.bias_method == "max_min":
calib_max = getattr(self.bias_calibrator, "_calib_max", None)
calib_min = getattr(self.bias_calibrator, "_calib_min", None)
if calib_max is None:
calib_max = torch.full_like(bias, -float("inf"))
if calib_min is None:
calib_min = torch.full_like(bias, float("inf"))
dist.all_reduce(calib_max, op=dist.ReduceOp.MAX, group=parallel_group.group)
dist.all_reduce(calib_min, op=dist.ReduceOp.MIN, group=parallel_group.group)
bias_val = ((calib_max + calib_min) / 2).to(bias.dtype)
self.bias_value = bias_val
self.bias_calibrator._calib_max = calib_max.detach().clone()
self.bias_calibrator._calib_min = calib_min.detach().clone()
self.bias_calibrator._calib_bias = bias_val.detach().clone()
else:
warnings.warn(f"Unsupported bias method: {self.bias_method}; skipping bias sync.")
except RuntimeError as e:
warnings.warn(
f"Failed to synchronize bias: {e}, probably because the tensor is on a device which is not"
"supported by the current distributed backend. This warning can be ignored"
"if happening during modelopt restore."
)

@contextlib.contextmanager
def disable_pre_quant_scale(self):
"""Context manager to turn off pre_quant_scale inside this quantizer."""
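The bias synchronization above boils down to simple arithmetic once the collectives complete: a count-weighted average for ``method="mean"`` and the midpoint of globally reduced extrema for ``method="max_min"``. A single-process sketch of that math with made-up per-rank values (plain tensors stand in for the calibrator state; no ModelOpt APIs involved):

```python
import torch

# Pretend two ranks finished "mean" bias calibration with these running values.
rank_bias = [torch.tensor([0.05, -0.10]), torch.tensor([10.02, 9.95])]
rank_cnt = [256.0, 256.0]

# "mean": reduce (bias * cnt) and cnt with SUM, then divide -- the same math as
# the pair of SUM all_reduce calls in sync_bias_across_distributed_group.
bias_sum = sum(b * c for b, c in zip(rank_bias, rank_cnt))
cnt_sum = sum(rank_cnt)
print(bias_sum / cnt_sum)  # tensor([5.0350, 4.9250])

# Pretend the same two ranks tracked per-channel extrema for "max_min".
rank_max = [torch.tensor([0.9, 0.8]), torch.tensor([11.0, 10.7])]
rank_min = [torch.tensor([-0.8, -0.9]), torch.tensor([9.1, 9.2])]

# "max_min": MAX/MIN reductions followed by the midpoint, matching the
# MAX/MIN all_reduce pair above.
global_max = torch.stack(rank_max).max(dim=0).values
global_min = torch.stack(rank_min).min(dim=0).values
print((global_max + global_min) / 2)  # tensor([5.1000, 4.9000])
```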
205 changes: 205 additions & 0 deletions tests/gpu/torch/quantization/test_mse_calibrate_sync.py
@@ -0,0 +1,205 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from _test_utils.torch.distributed.utils import get_device_counts, spawn_multiprocess_job

import modelopt.torch.quantization as mtq


def _test_mse_calibrate_sync(distributed_sync: bool, rank: int, size: int) -> None:
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()

config = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
config["algorithm"] = {
"method": "mse",
"num_steps": 16,
"start_multiplier": 0.001,
"stop_multiplier": 4.0,
"distributed_sync": distributed_sync,
}

def forward_loop(model):
torch.manual_seed(1234 + rank)
scale = 1.0 if rank == 0 else 100.0
for _ in range(4):
model(torch.randn(64, 16, device="cuda") * scale)

model = mtq.quantize(model, config, forward_loop)

target = next(module for module in model.modules() if hasattr(module, "input_quantizer"))
amax_val = target.input_quantizer.amax.detach().float().max()

gathered = [torch.zeros_like(amax_val) for _ in range(size)]
dist.all_gather(gathered, amax_val)

if size < 2 or rank != 0:
return

values = torch.stack(gathered)
if distributed_sync:
assert torch.allclose(values, values[0], rtol=0, atol=0), (
"Expected amax values to be synchronized across ranks, but got "
f"{values.tolist()}"
)
else:
assert (values.max() - values.min()) > 10.0, (
"Expected amax values to differ across ranks when sync is disabled, but got "
f"{values.tolist()}"
)


def _test_mse_calibrate_bias_sync(distributed_sync: bool, rank: int, size: int) -> None:
for bias_method in ["mean", "max_min"]:
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()

config = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
config["quant_cfg"]["*input_quantizer"]["bias"] = {
0: None,
"type": "static",
"method": bias_method,
}
config["algorithm"] = {
"method": "mse",
"num_steps": 16,
"start_multiplier": 0.001,
"stop_multiplier": 4.0,
"distributed_sync": distributed_sync,
}

def forward_loop(model):
torch.manual_seed(4321 + rank)
offset = 0.0 if rank == 0 else 10.0
for _ in range(4):
model(torch.randn(64, 16, device="cuda") * 0.1 + offset)

model = mtq.quantize(model, config, forward_loop)

target = next(module for module in model.modules() if hasattr(module, "input_quantizer"))
bias_val = target.input_quantizer.bias_value.detach().float().mean()

gathered = [torch.zeros_like(bias_val) for _ in range(size)]
dist.all_gather(gathered, bias_val)

if size < 2 or rank != 0:
continue

values = torch.stack(gathered)
if distributed_sync:
assert torch.allclose(values, values[0], rtol=0, atol=0), (
f"Expected bias values to be synchronized across ranks for {bias_method}, but got "
f"{values.tolist()}"
)
else:
assert (values.max() - values.min()) > 5.0, (
f"Expected bias values to differ across ranks for {bias_method} when sync is disabled, "
f"but got {values.tolist()}"
)


def _test_max_calibrate_bias_sync(distributed_sync: bool, rank: int, size: int) -> None:
for bias_method in ["mean", "max_min"]:
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()

config = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
config["quant_cfg"]["*input_quantizer"]["bias"] = {
0: None,
"type": "static",
"method": bias_method,
}
config["algorithm"] = {"method": "max", "distributed_sync": distributed_sync}

def forward_loop(model):
torch.manual_seed(9876 + rank)
offset = 0.0 if rank == 0 else 10.0
for _ in range(4):
model(torch.randn(64, 16, device="cuda") * 0.1 + offset)

model = mtq.quantize(model, config, forward_loop)

target = next(module for module in model.modules() if hasattr(module, "input_quantizer"))
bias_val = target.input_quantizer.bias_value.detach().float().mean()

gathered = [torch.zeros_like(bias_val) for _ in range(size)]
dist.all_gather(gathered, bias_val)

if size < 2 or rank != 0:
continue

values = torch.stack(gathered)
if distributed_sync:
assert torch.allclose(values, values[0], rtol=0, atol=0), (
f"Expected bias values to be synchronized across ranks for {bias_method}, but got "
f"{values.tolist()}"
)
else:
assert (values.max() - values.min()) > 5.0, (
f"Expected bias values to differ across ranks for {bias_method} when sync is disabled, "
f"but got {values.tolist()}"
)


@pytest.mark.parametrize("device_count", get_device_counts())
def test_mse_calibrate_with_sync(device_count):
spawn_multiprocess_job(
size=device_count, job=partial(_test_mse_calibrate_sync, True), backend="nccl"
)


@pytest.mark.parametrize("device_count", get_device_counts())
def test_mse_calibrate_without_sync(device_count):
if device_count < 2:
pytest.skip("need 2 GPUs")
spawn_multiprocess_job(
size=device_count, job=partial(_test_mse_calibrate_sync, False), backend="nccl"
)


@pytest.mark.parametrize("device_count", get_device_counts())
def test_mse_calibrate_bias_with_sync(device_count):
spawn_multiprocess_job(
size=device_count, job=partial(_test_mse_calibrate_bias_sync, True), backend="nccl"
)


@pytest.mark.parametrize("device_count", get_device_counts())
def test_mse_calibrate_bias_without_sync(device_count):
if device_count < 2:
pytest.skip("need 2 GPUs")
spawn_multiprocess_job(
size=device_count, job=partial(_test_mse_calibrate_bias_sync, False), backend="nccl"
)


@pytest.mark.parametrize("device_count", get_device_counts())
def test_max_calibrate_bias_with_sync(device_count):
spawn_multiprocess_job(
size=device_count, job=partial(_test_max_calibrate_bias_sync, True), backend="nccl"
)


@pytest.mark.parametrize("device_count", get_device_counts())
def test_max_calibrate_bias_without_sync(device_count):
if device_count < 2:
pytest.skip("need 2 GPUs")
spawn_multiprocess_job(
size=device_count, job=partial(_test_max_calibrate_bias_sync, False), backend="nccl"
)
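For reference, the bias-sync tests above opt into static bias calibration purely through the quantizer config. A minimal sketch of that config shape, copied from the tests (the ``bias`` sub-dict entries are reproduced as the tests pass them; treat the quantization config docs as the authoritative schema):

```python
import copy

import modelopt.torch.quantization as mtq

config = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
# Enable static bias calibration on activation quantizers so that bias_value is
# populated during calibration and can be synchronized across ranks.
config["quant_cfg"]["*input_quantizer"]["bias"] = {
    0: None,           # axis entry, kept exactly as the new tests pass it
    "type": "static",
    "method": "mean",  # or "max_min"
}
config["algorithm"] = {"method": "max", "distributed_sync": True}
```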