23 changes: 23 additions & 0 deletions test/integration/test_integration.py
@@ -42,12 +42,14 @@
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
_replace_with_custom_fn_if_matches_filter,
quantize_,
)
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
dequantize_affine,
)
from torchao.quantization.smoothquant import (
@@ -1123,6 +1125,27 @@ def test_dynamic_quant(self):
# self.assertTrue(isinstance(m[0], DynamicallyPerAxisQuantizedLinear))


class TestStaticQuant(unittest.TestCase):
def test_static_quant(self):
M, K, N = 8, 16, 8
x = torch.randn(M, K)
m = nn.Sequential(nn.Linear(K, N))
block_size = [M, K] # per-tensor quantization
scale, _ = choose_qparams_affine(
x,
mapping_type=MappingType.SYMMETRIC,
block_size=block_size,
target_dtype=torch.int8,
)

y_ref = m(x)
quantize_(m, Int8StaticActivationInt8WeightConfig(scale))
y_test = m(x)

sqnr = compute_error(y_ref, y_test)
self.assertGreater(sqnr, 40.0)


class TestWeightOnlyInt8Quant(unittest.TestCase):
def test_weight_only_quant(self):
for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
133 changes: 98 additions & 35 deletions test/prototype/test_smoothquant.py
@@ -15,8 +15,15 @@
)
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.linear_activation_scale import (
WeightTensorWithLinearActivationScaleMetadata,
)
from torchao.quantization.quant_api import (
Int8DynamicActivationInt8WeightConfig,
Int8StaticActivationInt8WeightConfig,
)
from torchao.quantization.utils import (
compute_error as SQNR,
)


@@ -34,16 +41,19 @@ def example_inputs(
dtype=torch.bfloat16,
device="cuda",
):
return [
torch.randn(
1,
sequence_length,
self.linear1.in_features,
dtype=dtype,
device=device,
)
for j in range(batch_size)
]
# For SmoothQuant tests, we intentionally insert some outliers into the input features
x = torch.randn(
batch_size,
sequence_length,
self.linear1.in_features,
dtype=dtype,
device=device,
)
n_outliers = max(1, int(x.size(-1) * 0.1))
# Randomly select outlier features
outlier_indices = torch.randperm(x.size(-1))[:n_outliers]
x[:, :, outlier_indices] *= 10.0
return (x,)

def forward(self, x):
x = self.linear1(x)
@@ -52,7 +62,9 @@ def forward(self, x):
return x


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


@unittest.skipIf(torch.version.hip is not None, "Skipping tests in ROCm")
class TestSmoothQuant(unittest.TestCase):
"""SmoothQuant tests using only supported quantization configs."""
@@ -72,37 +84,27 @@ def setUpClass(cls):
# TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
],
)
@common_utils.parametrize("device", ["cpu", "cuda"])
@common_utils.parametrize("device", device_list)
@common_utils.parametrize("input_dtype", [torch.bfloat16])
def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
def test_smoothquant_dynamic_act_accuracy(
self, alpha, base_config, device, input_dtype
):
"""Test if SmoothQuant achieves lower loss than basic quantization."""
in_features = 64
out_features = 128

# Note: This is sanity check. For real run, consider Transformer model to reproduce.
X = torch.randn(16, in_features, dtype=input_dtype, device=device)
W = torch.randn(out_features, in_features, dtype=input_dtype, device=device)

# Create linear layer
linear = (
torch.nn.Linear(in_features, out_features, bias=False)
.to(device)
.to(input_dtype)
)
with torch.no_grad():
linear.weight.copy_(W)
m = ToyLinearModel().eval().to(device).to(input_dtype)
x = m.example_inputs(batch_size=16, dtype=input_dtype, device=device)

# Reference output
out_ref = linear(X)
out_ref = m(*x)

# Step 1. Basic quantization
basic_model = deepcopy(linear)
basic_model = deepcopy(m)
quantize_(basic_model, base_config)
out_basic = basic_model(X)
out_basic = basic_model(*x)
loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()

# SmoothQuant quantization
model = deepcopy(linear)
# Step 2. SmoothQuant
model = deepcopy(m)
config = SmoothQuantConfig(
base_config=base_config,
step=SmoothQuantStep.PREPARE,
@@ -111,23 +113,83 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
quantize_(model, config)

# Perform calibration with test data
model(X)
model(*x)

# Step 2. SmoothQuant
config.step = SmoothQuantStep.CONVERT
quantize_(model, config)
assert isinstance(
model.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
)
assert isinstance(
model.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
)

out_smoothquant = model(X)
out_smoothquant = model(*x)
loss_smoothquant = torch.nn.functional.mse_loss(out_smoothquant, out_ref).item()

assert loss_smoothquant < loss_base, (
f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})"
)

@common_utils.parametrize("alpha", [0.5, 0.25])
@common_utils.parametrize("device", device_list)
@common_utils.parametrize("input_dtype", [torch.bfloat16])
def test_smoothquant_static_act_accuracy(self, alpha, device, input_dtype):
"""Test if SmoothQuant with static quantization achieves lower loss than basic quantization."""
m = ToyLinearModel().eval().to(device).to(input_dtype)
x = m.example_inputs(batch_size=16, dtype=input_dtype, device=device)

# Output without quantization
out_ref = m(*x)

# Step 1. Reference with alpha=0
m_ref = deepcopy(m)
base_config = Int8StaticActivationInt8WeightConfig()
config = SmoothQuantConfig(
base_config=base_config,
step=SmoothQuantStep.PREPARE,
alpha=0.0,
)
with torch.no_grad():
quantize_(m_ref, config)
m_ref(*x) # calibration
config.step = SmoothQuantStep.CONVERT
quantize_(m_ref, config)
out_base = m_ref(*x)
loss_base = torch.nn.functional.mse_loss(out_base, out_ref).item()

# Step 2. SmoothQuant quantization
base_config = Int8StaticActivationInt8WeightConfig()
config = SmoothQuantConfig(
base_config=base_config,
step=SmoothQuantStep.PREPARE,
alpha=alpha,
)
with torch.no_grad():
quantize_(m, config)
m(*x) # calibration
config.step = SmoothQuantStep.CONVERT
quantize_(m, config)
out_sq = m(*x)
assert isinstance(
m.linear1.weight, WeightTensorWithLinearActivationScaleMetadata
)
assert isinstance(
m.linear2.weight, WeightTensorWithLinearActivationScaleMetadata
)
loss_smoothquant = torch.nn.functional.mse_loss(out_sq, out_ref).item()

assert loss_smoothquant < loss_base, (
f"SmoothQuant loss ({loss_smoothquant:.6f}) should not be higher than basic loss ({loss_base:.6f})"
)
# Make sure the result is reasonable
self.assertGreater(SQNR(out_ref, out_sq), 20.0)

@common_utils.parametrize(
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8StaticActivationInt8WeightConfig(),
# TODO: Check more quantization APIs
],
)
@@ -167,6 +229,7 @@ def test_observer_insertion(self, base_config):
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8StaticActivationInt8WeightConfig(),
# TODO: Check more quantization APIs
],
)
1 change: 1 addition & 0 deletions torchao/prototype/smoothquant/README.md
@@ -50,6 +50,7 @@ for data in calibration_dataset:
quant_config.step = SmoothQuantStep.CONVERT
quantize_(model, quant_config)
```
For static activation quantization, use `Int8StaticActivationInt8WeightConfig` instead of `Int8DynamicActivationInt8WeightConfig`. Static quantization generally gives better throughput at the cost of some accuracy (higher perplexity).
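
As an illustrative sketch (not part of this diff): the static-activation flow mirrors the dynamic example above, only swapping the base config. `model` and `calibration_dataset` are placeholders supplied by the caller, and the import paths follow `test/prototype/test_smoothquant.py` in this PR.

```python
import torch

from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig

# `model` and `calibration_dataset` are placeholders supplied by the caller.

# Prepare: insert SmoothQuant observers on the linear layers
quant_config = SmoothQuantConfig(
    base_config=Int8StaticActivationInt8WeightConfig(),
    step=SmoothQuantStep.PREPARE,
    alpha=0.5,
)
quantize_(model, quant_config)

# Calibrate: run representative data so the observers record activation ranges,
# which later become the static activation scale
with torch.no_grad():
    for data in calibration_dataset:
        model(data)

# Convert: fold the smoothing factors into the weights and bake in the static scale
quant_config.step = SmoothQuantStep.CONVERT
quantize_(model, quant_config)
```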

## Benchmarks

10 changes: 9 additions & 1 deletion torchao/prototype/smoothquant/api.py
@@ -15,6 +15,7 @@
)
from torchao.quantization.quant_api import (
_QUANTIZE_CONFIG_HANDLER,
Int8StaticActivationInt8WeightConfig,
_linear_extra_repr,
)
from torchao.quantization.transform_module import (
@@ -96,7 +97,12 @@ def _smooth_quant_transform(
raise ValueError(f"Unexpected step: {step}")

# Compute smoothed weight parameters
smoothing_factor = observed_linear.obs.calculate_qparams()
act_quant_min, act_quant_max = None, None
if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
act_quant_min, act_quant_max = -127, 127
smoothing_factor, act_scale = observed_linear.obs.calculate_qparams(
act_quant_min, act_quant_max
)
weight = observed_linear.weight * smoothing_factor

# Create new linear layer
@@ -111,6 +117,8 @@
linear.bias = observed_linear.bias

# Quantize weights
if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
base_config = Int8StaticActivationInt8WeightConfig(act_scale)
base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
dummy_mod = DummyModule(weight)
quant_mod = base_config_handler(dummy_mod, base_config)
10 changes: 7 additions & 3 deletions torchao/prototype/smoothquant/core.py
@@ -41,7 +41,7 @@ def forward(self, input: torch.Tensor):
self.inputs.append(input.to("cpu"))
return input

def calculate_qparams(self):
def calculate_qparams(self, act_quant_min=None, act_quant_max=None):
assert self.inputs and len(self.inputs) > 0, (
"calibrate observer first by running model on exemplar data"
)
@@ -54,15 +54,19 @@ def calculate_qparams(self):
# Calculate per-channel max values
x_abs_max = torch.max(torch.abs(acc), dim=0)[0]
w_abs_max = torch.max(torch.abs(self.weight), dim=0)[0]
act_scale = None
if act_quant_min is not None and act_quant_max is not None:
x_abs_max_t = acc.abs().max()
act_scale = x_abs_max_t / (act_quant_max - act_quant_min) / 2

# Calculate smoothing factor
if self.alpha is None:
return torch.ones_like(x_abs_max)
return torch.ones_like(x_abs_max), act_scale

eps = torch.finfo(torch.float32).eps
return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
w_abs_max + eps, 1 - self.alpha
)
), act_scale


class SmoothQuantObservedLinear(torch.nn.Linear):