From a1b630352871e8842e4f4b5607d261bdd444ee27 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 10 Jun 2026 00:23:46 +0000 Subject: [PATCH 1/5] Add 2:4 sparsity -> INT8 SmoothQuant PTQ -> ONNX -> TensorRT example New example examples/sparse_quant_trt/pipeline.py: an end-to-end flow for Qwen2.5-1.5B-Instruct that applies optional 2:4 structured sparsity and INT8 W8A8 SmoothQuant PTQ (with optional QAT), exports to ONNX, builds a strongly-typed TensorRT engine, validates structured-sparse INT8 kernel selection, and runs greedy text inference. Tested in nvcr.io/nvidia/pytorch:26.01-py3 (PyTorch 2.10.0a0+a36e1d39eb, ONNX 1.18.0, TensorRT 10.14.1.48, CUDA 13.1, ModelOpt 0.45.0rc0, transformers 5.9.0). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/sparse_quant_trt/pipeline.py | 852 ++++++++++++++++++++++++++ 1 file changed, 852 insertions(+) create mode 100644 examples/sparse_quant_trt/pipeline.py diff --git a/examples/sparse_quant_trt/pipeline.py b/examples/sparse_quant_trt/pipeline.py new file mode 100644 index 0000000000..a282f3b585 --- /dev/null +++ b/examples/sparse_quant_trt/pipeline.py @@ -0,0 +1,852 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end pipeline for Qwen2.5-1.5B-Instruct: + + [2:4 weight sparsity] -> INT8 W8A8 SmoothQuant PTQ -> [QAT] + -> finalize -> torch->ONNX export (opset 20) + -> TensorRT engine build (trtexec, structured sparsity enabled) + -> validate sparse INT8 kernels -> real inference (text-in -> text-out). + +Sparsity (--sparsity) and QAT (--qat) are OPTIONAL and OFF by default; the default run is +plain INT8 W8A8 SmoothQuant, which preserves accuracy (coherent generation). + +WARNING: one-shot 2:4 magnitude sparsity zeros half the weights and causes SEVERE accuracy +degradation by itself -- the model produces gibberish until recovered with QAT/SAT fine-tuning. +So --sparsity is only useful together with --qat (and realistically a longer recovery run than +this smoke-level QAT). When sparsity is on, ordering is the proven one +(examples/llm_sparsity/weight_sparsity): sparsify FIRST, then quantize so SmoothQuant calibrates +on the sparse weights, then QAT. + +Tested with (Docker container + library versions/commits): + - Docker container: nvcr.io/nvidia/pytorch:26.01-py3 + - PyTorch: 2.10.0a0+a36e1d39eb (git commit a36e1d39eb) + - ONNX: 1.18.0 + - TensorRT: 10.14.1.48 (trtexec v101401) + - CUDA: 13.1 + - NVIDIA ModelOpt: 0.45.0rc0 (this repository, installed editable) + - transformers: 5.9.0 (supported range >=4.56,<5.10) + - accelerate: 1.13.0 + - Model: Qwen/Qwen2.5-1.5B-Instruct + - GPU used: NVIDIA RTX 6000 Ada Generation (sm_89) + +ModelOpt is installed editable in the container; transformers/accelerate and the ONNX-export +helper deps (onnxruntime, onnx-graphsurgeon, onnxconverter-common, onnxslim) are pip-installed +without disturbing the container's torch 2.10 / onnx 1.18. +""" + +import argparse +import contextlib +import json +import os +import re +import subprocess +import time + +import torch +import torch.nn as nn + +# ----------------------------------------------------------------------------- helpers +BANNER = "=" * 92 + + +def log(msg: str) -> None: + print(f"\n{BANNER}\n# {msg}\n{BANNER}", flush=True) + + +def linear_weight_zero_fraction(model: nn.Module, needle: str = "q_proj"): + """Return zero-fraction of the first decoder Linear matching `needle` (effective weight).""" + import modelopt.torch.sparsity as mts + + for name, mod in model.named_modules(): + if ( + needle in name + and hasattr(mod, "weight") + and isinstance(getattr(mod, "weight"), torch.Tensor) + ): + w = mod.weight + mask = getattr(mod, "_weight_mask", None) + if isinstance(mod, getattr(mts, "SparseModule", ())) and mask is not None: + w = w * mask + return name, float((w == 0).float().mean().item()), tuple(w.shape) + return None, None, None + + +def count_enabled_weight_quantizers(model: nn.Module) -> int: + n = 0 + for _, mod in model.named_modules(): + q = getattr(mod, "weight_quantizer", None) + if q is not None and getattr(q, "is_enabled", False): + n += 1 + return n + + +# ----------------------------------------------------------------------------- stages +def load_model(model_dir: str): + from transformers import AutoModelForCausalLM, AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_dir) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + # fp32 for the torch stages: the post-training-sparsified (un-recovered) model has large + # intermediate activations that overflow to NaN in fp16 during SmoothQuant calibration. + # fp32 has full dynamic range -> stable calibration/QAT. ONNX export later converts weights + # to fp16 via weights_dtype="fp16" (matches the proven timm INT8 example). Eager attention + # keeps the graph plain matmul+softmax so it exports cleanly to ONNX (no flash/sdpa op). + model = ( + AutoModelForCausalLM.from_pretrained( + model_dir, dtype=torch.float32, attn_implementation="eager" + ) + .to("cuda") + .eval() + ) + model.config.use_cache = False + return model, tok + + +def build_calib_forward_loop(model, tok, num_samples, calib_seq, dataset_name, batch_size): + """SmoothQuant calibration forward_loop, mirroring examples/llm_ptq/hf_ptq.py: + + calib_dataloader = get_dataset_dataloader(dataset_name, tokenizer, batch_size, + num_samples, max_sample_length=calib_seq, device) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + + Supports cnn_dailymail / nemotron-* datasets (see get_supported_datasets()). Falls back to + synthetic prompts only if the dataset can't be fetched. + """ + from modelopt.torch.utils.dataset_utils import ( + create_forward_loop, + get_dataset_dataloader, + get_supported_datasets, + ) + + try: + supported = get_supported_datasets() + if dataset_name not in supported: + print( + f"[calib] '{dataset_name}' not in supported {supported}; using cnn_dailymail", + flush=True, + ) + dataset_name = "cnn_dailymail" + calib_dataloader = get_dataset_dataloader( + dataset_name=dataset_name, + tokenizer=tok, + batch_size=batch_size, + num_samples=num_samples, + max_sample_length=calib_seq, + device="cuda", + ) + forward_loop = create_forward_loop(dataloader=calib_dataloader) + # materialize once so a fetch failure surfaces here (and we can fall back) + n_batches = len(calib_dataloader) + print( + f"[calib] dataset={dataset_name} samples={num_samples} seq={calib_seq} " + f"batch_size={batch_size} -> {n_batches} batches", + flush=True, + ) + return forward_loop, dataset_name + except Exception as e: # pragma: no cover - network/dataset fallback + print( + f"[calib] dataset '{dataset_name}' unavailable ({e}); using synthetic prompts", + flush=True, + ) + prompts = _synthetic_prompts(num_samples) + enc = tok( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=calib_seq, + ) + ids = enc["input_ids"] + + def forward_loop(m): + with torch.no_grad(): + for i in range(0, ids.shape[0], batch_size): + m(input_ids=ids[i : i + batch_size].to("cuda")) + + return forward_loop, "synthetic" + + +def _synthetic_prompts(n): + base = [ + "The history of artificial intelligence spans many decades of research.", + "In quantum mechanics, the wave function describes the state of a system.", + "Climate change is driven by greenhouse gas emissions from human activity.", + "The recipe calls for flour, sugar, eggs, and a pinch of salt.", + "Neural networks learn representations from large amounts of data.", + "The stock market reacted sharply to the central bank's announcement.", + "Photosynthesis converts sunlight into chemical energy in plants.", + "She traveled across the country to visit the ancient ruins.", + ] + out = [] + while len(out) < n: + for i, b in enumerate(base): + out.append(f"{b} This is calibration example number {len(out)}, paragraph {i}.") + if len(out) >= n: + break + return out + + +def apply_sparsity(model): + import modelopt.torch.sparsity as mts + + out = mts.sparsify(model, "sparse_magnitude", config=None) + model = out[0] if isinstance(out, tuple) else out + name, zf, shape = linear_weight_zero_fraction(model, "q_proj") + print(f"[sparsity] sample {name} effective zero-frac={zf:.4f} shape={shape}", flush=True) + assert zf is not None and 0.45 <= zf <= 0.55, f"2:4 sparsity not applied (zero-frac={zf})" + return model + + +def apply_ptq(model, forward_loop, quant_output=False): + import copy + + import modelopt.torch.quantization as mtq + + cfg = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG) + if quant_output: + # Enable INT8 output quantizers so every GEMM has an INT8-out epilogue. Without this, + # each Linear's consumer (attention / SiLU / residual+layernorm) is unquantized, so the + # GEMM dequantizes to fp32 -> the dense-favorable epilogue. This tests whether the fp32 + # output (not the K=1536 shape) is what keeps TRT from choosing sparse INT8 kernels. + cfg["quant_cfg"].append( + {"quantizer_name": "*output_quantizer", "cfg": {"num_bits": 8, "axis": None}} + ) + print("[ptq] INT8 output quantizers ENABLED (GEMMs INT8-out)", flush=True) + model = mtq.quantize(model, cfg, forward_loop=forward_loop) + mtq.print_quant_summary(model) + nwq = count_enabled_weight_quantizers(model) + print(f"[ptq] enabled weight quantizers: {nwq}", flush=True) + assert nwq > 0, "no weight quantizers enabled after PTQ" + return model + + +def run_qat(model, tok, steps, seq_len, lr): + """Minimal EXAMPLE QAT (quantization-aware training) loop -- NOT production training. + + This is a smoke-level placeholder: a few steps of next-token cross-entropy on a handful of + cnn_dailymail samples, just to exercise the QAT code path (the quantizer amax values stay + frozen and ModelOpt's fake-quant + the 2:4 sparsity mask are applied on every forward). + + To actually recover accuracy -- which is REQUIRED after 2:4 sparsification -- replace the body + below with YOUR OWN dataset and training pipeline: your task-appropriate training/instruction + data, loss, optimizer/scheduler, batch size, sequence length, and a realistic number of + steps/epochs (and ideally a held-out eval). The surrounding pipeline (mtq.quantize before this, + finalize/export after) stays the same; only this training loop is meant to be swapped out. + Distributed training (DDP/FSDP) and HF Trainer also work -- see examples/llm_qat and + examples/llm_sparsity/weight_sparsity/finetune.py for fuller references. + """ + from modelopt.torch.utils.dataset_utils import get_dataset_dataloader + + try: + dl = get_dataset_dataloader( + dataset_name="cnn_dailymail", + tokenizer=tok, + batch_size=1, + num_samples=max(steps, 8), + max_sample_length=seq_len, + include_labels=True, + device="cuda", + ) + batches = list(dl) + except Exception as e: # pragma: no cover + print(f"[qat] dataset unavailable ({e}); using synthetic", flush=True) + enc = tok( + _synthetic_prompts(max(steps, 8)), + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=seq_len, + ) + batches = [ + {"input_ids": enc["input_ids"][i : i + 1]} for i in range(enc["input_ids"].shape[0]) + ] + + model.train() + opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr) + done = 0 + for batch in batches: + if done >= steps: + break + ids = batch["input_ids"].to("cuda") + attn = batch.get("attention_mask") + kwargs = {"input_ids": ids, "labels": ids} + if attn is not None: + kwargs["attention_mask"] = attn.to("cuda") + out = model(**kwargs) + loss = out.loss + if not torch.isfinite(loss): + print(f"[qat] step {done}: non-finite loss, skipping", flush=True) + opt.zero_grad(set_to_none=True) + continue + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + opt.step() + opt.zero_grad(set_to_none=True) + done += 1 + if done % 5 == 0 or done == 1: + print(f"[qat] step {done}/{steps} loss={loss.item():.4f}", flush=True) + model.eval() + print(f"[qat] completed {done} steps", flush=True) + return model + + +def finalize_for_export(model, expect_sparse=True): + """Prepare the sparse+quant model for ONNX export. + + We deliberately do NOT call ``mts.export``: ModelOpt's mode system unwinds modes LIFO, so + with ``quantize`` applied on top of ``sparse_magnitude`` the sparse export is blocked until + quantize is exported first — and exporting quantize would strip the fake-quant QDQ we need. + For the ONNX-QDQ -> TensorRT path we keep both modes live: the SparseModule's dynamic + ``weight*mask`` getter feeds the 2:4 zeros into the torch.onnx trace, and the TensorQuantizer + emits QuantizeLinear/DequantizeLinear. As insurance, fold the mask into the underlying weight + parameter so the structural zeros are guaranteed to land in the ONNX initializers even if the + exporter reads the raw parameter rather than the dynamic getter. + """ + import modelopt.torch.sparsity as mts + + baked = 0 + with torch.no_grad(): + for _, mod in model.named_modules(): + if isinstance(mod, getattr(mts, "SparseModule", ())): + mask = getattr(mod, "_weight_mask", None) + if mask is None: + continue + # underlying parameter lives in _parameters; the public ``mod.weight`` getter + # returns weight*mask. Write zeros straight into the stored parameter. + raw = mod._parameters.get("weight", None) + if raw is not None: + raw.data.mul_(mask.to(raw.dtype)) + baked += 1 + print(f"[finalize] folded 2:4 mask into {baked} underlying weight params", flush=True) + name, zf, shape = linear_weight_zero_fraction(model, "q_proj") + print(f"[finalize] effective {name} weight zero-frac={zf:.4f} shape={shape}", flush=True) + if expect_sparse: + assert zf is not None and 0.45 <= zf <= 0.55, ( + f"2:4 sparsity not present pre-export (zf={zf})" + ) + nwq = count_enabled_weight_quantizers(model) + print(f"[finalize] enabled weight quantizers (INT8 QDQ source): {nwq}", flush=True) + assert nwq > 0, "no INT8 quantizers present before ONNX export" + return model + + +class PrefillExportWrapper(nn.Module): + """input_ids -> logits, no KV cache / attention_mask, for a clean prefill ONNX graph.""" + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_ids): + return self.model(input_ids=input_ids, use_cache=False, return_dict=True).logits + + +@contextlib.contextmanager +def patch_transformers_for_export(): + """Make a single unpadded prefill graph export cleanly to ONNX/TensorRT. + + transformers' ``find_packed_sequence_indices`` skips its ``return None`` early-exit while + tracing (``is_tracing()`` is True), so it bakes a packed-sequence ``diff``+``cumsum``-on-bool + into the graph: ``torch.diff`` has no ONNX symbolic, and TRT's CumSum rejects bool inputs. + For one contiguous sequence there is genuinely no packing, so we force it to return None. + We also decompose torch.diff as a belt-and-suspenders safeguard. + """ + import transformers.masking_utils as mu + + orig_fps = mu.find_packed_sequence_indices + orig_diff = torch.diff + + def _diff(input, n=1, dim=-1, prepend=None, append=None): + parts = [] + if prepend is not None: + parts.append(prepend) + parts.append(input) + if append is not None: + parts.append(append) + x = torch.cat(parts, dim=dim) if len(parts) > 1 else input + for _ in range(n): + hi = [slice(None)] * x.dim() + lo = [slice(None)] * x.dim() + hi[dim] = slice(1, None) + lo[dim] = slice(None, -1) + x = x[tuple(hi)] - x[tuple(lo)] + return x + + mu.find_packed_sequence_indices = lambda position_ids: None + torch.diff = _diff + try: + yield + finally: + mu.find_packed_sequence_indices = orig_fps + torch.diff = orig_diff + + +def export_onnx(model, out_dir, model_name, seq_len, vocab_size, weights_dtype="fp16"): + import glob + + from modelopt.torch._deploy.utils.torch_onnx import OnnxBytes, get_onnx_bytes_and_metadata + + # The output weight dtype is set by the torch MODEL's dtype, not by the helper's weights_dtype + # arg. For fp16 we cast the model and export NATIVELY (torch emits a self-consistent fp16 graph, + # with explicit Casts where RMSNorm upcasts to fp32); fp32 keeps the graph fp32. Either way the + # graph is self-consistently typed, so TensorRT --stronglyTyped can parse it. + if weights_dtype == "fp16": + model = model.half() + # Always pass weights_dtype="fp32" to the helper. Inside get_onnx_bytes_and_metadata that arg + # ONLY gates onnxconverter's convert_float_to_float16 pass -- which we must NOT run: it + # block-lists Div and leaves fp32 islands around RMSNorm that strongly-typed REJECTS. So "fp32" + # here means "helper, don't do your own fp16 conversion" (the model.half() above already did it). + helper_dtype = "fp32" + wrapper = PrefillExportWrapper(model).eval() + dummy = torch.randint(0, vocab_size, (1, seq_len), dtype=torch.int64, device="cuda") + os.makedirs(out_dir, exist_ok=True) + # onnx 1.18 resolves the external-data `location` against CWD and refuses to overwrite an + # existing .onnx_data, so clear any stale artifacts from a previous run first. + for stale in glob.glob(os.path.join(out_dir, f"{model_name}.onnx*")): + os.remove(stale) + print( + f"[onnx] exporting weights_dtype={weights_dtype} " + f"(native {'fp16' if weights_dtype == 'fp16' else 'fp32'} graph for strongly-typed)", + flush=True, + ) + with patch_transformers_for_export(): + onnx_bytes, _ = get_onnx_bytes_and_metadata( + model=wrapper, + dummy_input=(dummy,), + model_name=model_name, + onnx_opset=20, + dynamo_export=False, # TorchScript path (dynamo avoided for quantized graphs on torch 2.10) + weights_dtype=helper_dtype, + dynamic_axes={"input_ids": {0: "batch", 1: "seq"}}, + ) + OnnxBytes.from_bytes(onnx_bytes).write_to_disk(out_dir, clean_dir=False) + onnx_path = os.path.join(out_dir, f"{model_name}.onnx") + assert os.path.isfile(onnx_path), f"ONNX not written to {onnx_path} (dir={os.listdir(out_dir)})" + print(f"[onnx] wrote {onnx_path} ({os.path.getsize(onnx_path) / 1e6:.1f} MB)", flush=True) + return onnx_path + + +def inspect_onnx(onnx_path, expect_sparse=True): + import onnx + from onnx import numpy_helper + + m = onnx.load(onnx_path, load_external_data=True) + op_counts = {} + for n in m.graph.node: + op_counts[n.op_type] = op_counts.get(n.op_type, 0) + 1 + q = op_counts.get("QuantizeLinear", 0) + dq = op_counts.get("DequantizeLinear", 0) + print( + f"[onnx] QuantizeLinear={q} DequantizeLinear={dq} MatMul={op_counts.get('MatMul', 0)} " + f"Gemm={op_counts.get('Gemm', 0)} Conv={op_counts.get('Conv', 0)} CumSum={op_counts.get('CumSum', 0)}", + flush=True, + ) + assert q + dq > 0, "no QDQ nodes in ONNX -> not an INT8 graph" + + # Verify the 2:4 sparsity zeros reached the ONNX weights. The quantized Linear weights are + # stored as 2D tensors feeding QuantizeLinear (as initializers or Constant-node values). Count + # how many large 2D weight tensors are ~50% zero (the 2:4 sparsified Linears) vs dense ones + # (embedding / lm_head are intentionally NOT sparsified). + def _tensors(): + yield from m.graph.initializer + for node in m.graph.node: + if node.op_type == "Constant": + for a in node.attribute: + if a.name == "value" and a.t.dims: + yield a.t + + sparse_2to4 = dense = 0 + example = None + for t in _tensors(): + if len(t.dims) == 2 and min(t.dims) >= 256: + arr = numpy_helper.to_array(t) + zf = float((arr == 0).mean()) + if 0.45 <= zf <= 0.55: + sparse_2to4 += 1 + if example is None: + example = (t.name, tuple(t.dims), zf) + elif zf < 0.05: + dense += 1 + ex = f" e.g. {example[0]} dims={example[1]} zf={example[2]:.4f}" if example else "" + print( + f"[onnx] 2D weights: {sparse_2to4} are 2:4-sparse (~50% zero), {dense} dense.{ex}", + flush=True, + ) + if expect_sparse: + assert sparse_2to4 >= 50, ( + f"2:4 zeros did not reach ONNX weights (only {sparse_2to4} sparse)" + ) + return q, dq + + +def build_trt( + onnx_path, + engine_path, + log_path, + layer_info_path, + seq_len, + timeout_s, + strongly_typed=True, + sparsity="enable", +): + trtexec = "/usr/src/tensorrt/bin/trtexec" + if not os.path.isfile(trtexec): + trtexec = ( + subprocess.run( + ["bash", "-lc", "command -v trtexec"], capture_output=True, text=True + ).stdout.strip() + or "trtexec" + ) + prec = ["--stronglyTyped"] if strongly_typed else ["--int8", "--fp16"] + cmd = [ + trtexec, + f"--onnx={onnx_path}", + *prec, + f"--sparsity={sparsity}", + f"--saveEngine={engine_path}", + "--minShapes=input_ids:1x1", + f"--optShapes=input_ids:1x{seq_len}", + f"--maxShapes=input_ids:1x{max(seq_len, 512)}", + "--builderOptimizationLevel=4", + "--profilingVerbosity=detailed", + f"--exportLayerInfo={layer_info_path}", + "--verbose", + ] + print(f"[trt] {' '.join(cmd)}", flush=True) + with open(log_path, "w") as f: + proc = subprocess.run(cmd, stdout=f, stderr=subprocess.STDOUT, timeout=timeout_s) + print(f"[trt] trtexec returncode={proc.returncode}; log -> {log_path}", flush=True) + return proc.returncode + + +def validate_sparse(log_path, layer_info_path): + with open(log_path, errors="ignore") as f: + text = f.read() + spars_lines = [ln for ln in text.splitlines() if "(Sparsity)" in ln] + print("[validate] trtexec sparsity report lines:", flush=True) + for ln in spars_lines: + print(" " + ln.split("] ")[-1].strip(), flush=True) + found = chose = 0 + for ln in spars_lines: + mf = re.search(r"Found (\d+) layer", ln) + mc = re.search(r"Chose (\d+) layer", ln) + # TRT emits several (Sparsity) lines (per myelin foreign node); take the max, not the last. + if mf: + found = max(found, int(mf.group(1))) + if mc: + chose = max(chose, int(mc.group(1))) + # secondary signal: sparse kernel markers in layer-info tactic metadata + sparse_tactics = 0 + try: + info = json.load(open(layer_info_path)) + for layer in info.get("Layers", []): + blob = json.dumps(layer).lower() + if any(k in blob for k in ("spars", "_2_1", "_4_2", "sp_mma", "spmma")): + sparse_tactics += 1 + except Exception as e: + print(f"[validate] layer_info parse skipped: {e}", flush=True) + print( + f"[validate] eligible(Found)={found} chosen(Chose)={chose} " + f"layer_info_sparse_markers={sparse_tactics}", + flush=True, + ) + return found, chose, sparse_tactics + + +_TRT_TO_TORCH = None + + +def _trt_dtype_map(): + import tensorrt as trt + + return { + trt.float32: torch.float32, + trt.float16: torch.float16, + trt.int32: torch.int32, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, + } + + +def run_inference(engine_path, model_dir, prompt, max_new_tokens): + """Real text-in -> text-out through the built engine (greedy decode). + + The engine is a PREFILL graph (input_ids -> logits, no KV cache), so we generate greedily by + re-running the full prefill each step (fine for a short demo). Mirrors generate.py. Also serves + as the engine sanity check: it deserializes, runs, and must produce finite logits. + """ + import tensorrt as trt + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(model_dir) + # instruct model -> render the chat template to text, then tokenize (robust across versions) + text = tok.apply_chat_template( + [{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False + ) + ids = list(tok(text)["input_ids"]) + n_prompt = len(ids) + + logger = trt.Logger(trt.Logger.WARNING) + with open(engine_path, "rb") as f: + engine = trt.Runtime(logger).deserialize_cuda_engine(f.read()) + assert engine is not None, "failed to deserialize engine" + ctx = engine.create_execution_context() + dmap = _trt_dtype_map() + names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)] + in_name = next(n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT) + out_name = next(n for n in names if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT) + in_dtype = dmap[engine.get_tensor_dtype(in_name)] + out_dtype = dmap[engine.get_tensor_dtype(out_name)] + max_seq = engine.get_tensor_profile_shape(in_name, 0)[2][1] # max seq from the build profile + stream = torch.cuda.Stream() + eos = {tok.eos_token_id} if tok.eos_token_id is not None else set() + print( + f"[infer] prompt={prompt!r} (prompt_tokens={n_prompt}, engine max_seq={max_seq})", + flush=True, + ) + + for step in range(max_new_tokens): + cur_len = len(ids) + if cur_len >= max_seq: + print(f"[infer] reached engine max_seq={max_seq}; stopping", flush=True) + break + inp = torch.tensor([ids], dtype=in_dtype, device="cuda") + ctx.set_input_shape(in_name, (1, cur_len)) + ctx.set_tensor_address(in_name, inp.data_ptr()) + out = torch.empty(tuple(ctx.get_tensor_shape(out_name)), dtype=out_dtype, device="cuda") + ctx.set_tensor_address(out_name, out.data_ptr()) + ok = ctx.execute_async_v3(stream.cuda_stream) + stream.synchronize() + assert ok, "execute_async_v3 returned False" + last = out[0, cur_len - 1].float() + if step == 0: + assert torch.isfinite(last).all(), "engine produced non-finite logits" + ids.append(int(last.argmax())) + if ids[-1] in eos: + break + + gen = tok.decode(ids[n_prompt:], skip_special_tokens=True) + print(f"[infer] generated={gen!r}", flush=True) + print("[infer] inference OK", flush=True) + return gen + + +# ----------------------------------------------------------------------------- main +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--model-dir", default="/models/Qwen2.5-1.5B-Instruct") + ap.add_argument("--out-dir", default="/workspace/out") + ap.add_argument("--model-name", default="qwen2_5_1_5b_int8_sparse") + # calibration (mirrors examples/llm_ptq/hf_ptq.py defaults: 1024 samples, calib_seq 512) + ap.add_argument( + "--calib-samples", + type=int, + default=1024, + help="number of calibration samples (hf_ptq.py default 1024)", + ) + ap.add_argument( + "--calib-seq", + type=int, + default=512, + help="max sequence length for calibration samples (hf_ptq.py default 512)", + ) + ap.add_argument( + "--calib-dataset", + default="cnn_dailymail", + help="calibration dataset, e.g. cnn_dailymail or a nemotron-* dataset", + ) + ap.add_argument("--calib-batch-size", type=int, default=4, help="calibration batch size") + # 2:4 sparsity is optional and OFF by default. WARNING: one-shot 2:4 magnitude pruning zeros + # half the weights and causes SEVERE accuracy degradation on its own; recovering quality + # requires QAT/SAT fine-tuning (--qat). Without recovery the model produces gibberish. + ap.add_argument( + "--sparsity", + action="store_true", + help="apply 2:4 structured sparsity before PTQ. WARNING: severe accuracy " + "degradation -- requires QAT (--qat) to recover quality.", + ) + # QAT is optional. The built-in loop is only a minimal EXAMPLE -- integrate your own dataset + # and training pipeline (see run_qat docstring) for real accuracy recovery. + ap.add_argument( + "--qat", + action="store_true", + help="run the EXAMPLE QAT fine-tune after PTQ (smoke-level placeholder; " + "swap in your own dataset + training loop for real recovery)", + ) + ap.add_argument("--qat-steps", type=int, default=20) + ap.add_argument("--qat-lr", type=float, default=1e-5) + ap.add_argument("--seq-len", type=int, default=128) + ap.add_argument("--trt-timeout", type=int, default=1800) + ap.add_argument( + "--weights-dtype", + choices=["fp32", "fp16"], + default="fp16", + help="ONNX/engine weight dtype. fp16 exports a NATIVE fp16 graph; fp32 keeps " + "it fp32. Both are self-consistently typed and build with TRT " + "--stronglyTyped (INT8 GEMMs via QDQ either way).", + ) + ap.add_argument( + "--reuse-onnx", + action="store_true", + help="skip torch stages 1-7 and build TRT from an already-exported ONNX", + ) + # final-stage real inference (text-in -> text-out) through the built engine + ap.add_argument( + "--prompt", + default="What is the capital of France? Answer in one word.", + help="prompt for the final real-inference step", + ) + ap.add_argument( + "--max-new-tokens", + type=int, + default=32, + help="number of tokens to greedily generate in the final inference step", + ) + args = ap.parse_args() + + os.makedirs(args.out_dir, exist_ok=True) + t0 = time.time() + print( + f"torch={torch.__version__} cuda_cc={torch.cuda.get_device_capability()} " + f"device={torch.cuda.get_device_name()}", + flush=True, + ) + + onnx_path = os.path.join(args.out_dir, f"{args.model_name}.onnx") + if args.reuse_onnx: + log("REUSE: skipping stages 1-7, building TRT from existing ONNX") + assert os.path.isfile(onnx_path), f"--reuse-onnx but {onnx_path} not found" + return _build_validate_infer(args, onnx_path, t0) + + log("STAGE 1/8: load Qwen2.5-1.5B-Instruct (fp32, eager attention)") + model, tok = load_model(args.model_dir) + vocab = model.config.vocab_size + print(f"vocab_size={vocab} layers={model.config.num_hidden_layers}", flush=True) + + log("STAGE 2/8: build SmoothQuant calibration forward_loop (hf_ptq.py style)") + forward_loop, src = build_calib_forward_loop( + model, tok, args.calib_samples, args.calib_seq, args.calib_dataset, args.calib_batch_size + ) + print(f"calibration source: {src}", flush=True) + + sparsity_on = args.sparsity + if sparsity_on: + log("STAGE 3/8: apply 2:4 structured sparsity (sparse_magnitude)") + if not args.qat: + print( + "[warn] 2:4 sparsity causes SEVERE accuracy degradation without recovery; " + "the model will likely produce gibberish. Add --qat (or do real SAT/QAT) to " + "recover quality.", + flush=True, + ) + model = apply_sparsity(model) + else: + log("STAGE 3/8: sparsity SKIPPED (default; pass --sparsity to enable -- INT8-only run)") + + # Enable INT8 output quantizers when sparsity is on: an INT8-out epilogue is what lets + # TensorRT actually CHOOSE the structured-sparse INT8 GEMM kernels (a dequant-to-fp32 epilogue + # makes dense faster). For the dense INT8-only path we leave outputs in fp16/fp32 for accuracy. + log( + "STAGE 4/8: INT8 W8A8 SmoothQuant PTQ" + + (" (+INT8 output quantizers)" if sparsity_on else "") + ) + model = apply_ptq(model, forward_loop, quant_output=sparsity_on) + + if args.qat: + log(f"STAGE 5/8: EXAMPLE QAT fine-tune ({args.qat_steps} steps, lr={args.qat_lr})") + print( + "[qat] NOTE: this is a minimal example loop -- integrate your own dataset and " + "training pipeline here for real accuracy recovery (see run_qat docstring).", + flush=True, + ) + model = run_qat(model, tok, args.qat_steps, args.seq_len, args.qat_lr) + else: + log("STAGE 5/8: QAT skipped (pass --qat to enable)") + + log("STAGE 6/8: finalize model for export (fold 2:4 zeros, keep INT8 QDQ)") + model = finalize_for_export(model, expect_sparse=sparsity_on) + + log("STAGE 7/8: export to ONNX (INT8 QDQ, opset 20)") + onnx_path = export_onnx( + model, args.out_dir, args.model_name, args.seq_len, vocab, weights_dtype=args.weights_dtype + ) + inspect_onnx(onnx_path, expect_sparse=sparsity_on) + del model + torch.cuda.empty_cache() + + return _build_validate_infer(args, onnx_path, t0) + + +def _build_validate_infer(args, onnx_path, t0): + # trtexec --sparsity follows the model: enable sparse tactics only when 2:4 sparsity was + # applied (--sparsity), otherwise disable (the weights are dense, so there's nothing to gain). + trt_sparsity = "enable" if args.sparsity else "disable" + log(f"STAGE 8/8: TensorRT engine build (--sparsity={trt_sparsity}) + validate + infer") + engine_path = os.path.join(args.out_dir, f"{args.model_name}.engine") + build_log = os.path.join(args.out_dir, "trtexec_build.log") + layer_info = os.path.join(args.out_dir, "layer_info.json") + # Always build a strongly-typed engine. Both --weights-dtype fp16 (native fp16 export) and + # fp32 produce a self-consistently typed ONNX that --stronglyTyped can parse, so no + # --int8/--fp16 fallback is needed. + mode = "stronglyTyped" + rc = build_trt( + onnx_path, + engine_path, + build_log, + layer_info, + args.seq_len, + args.trt_timeout, + strongly_typed=True, + sparsity=trt_sparsity, + ) + assert rc == 0, ( + f"strongly-typed trtexec build failed (rc={rc}); see {build_log}. " + f"The ONNX must be self-consistently typed for --stronglyTyped." + ) + print(f"[trt] engine built OK in mode: {mode}", flush=True) + + found, chose, markers = validate_sparse(build_log, layer_info) + + run_inference(engine_path, args.model_dir, args.prompt, args.max_new_tokens) + + log("PIPELINE SUMMARY") + print(f"ONNX: {onnx_path}", flush=True) + print(f"engine: {engine_path} (build mode: {mode})", flush=True) + print( + f"sparse INT8 kernels -> eligible(Found)={found} chosen(Chose)={chose} markers={markers}", + flush=True, + ) + if chose > 0: + print("RESULT: PASS - TensorRT selected structured-sparse INT8 kernels.", flush=True) + elif found > 0: + print( + "RESULT: PASS(weak) - 2:4 pattern detected & eligible; dense tactic was faster " + "for some/all GEMMs.", + flush=True, + ) + else: + print( + "RESULT: WARN - no sparse-eligible layers reported; check ONNX weight zero-fraction.", + flush=True, + ) + print(f"total wall time: {time.time() - t0:.1f}s", flush=True) + + +if __name__ == "__main__": + main() From 7d9d804871c33f9d940661aab9a667dd7900d33c Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 10 Jun 2026 00:30:56 +0000 Subject: [PATCH 2/5] Add README for the sparse + INT8 -> ONNX -> TensorRT example Documents examples/sparse_quant_trt: pipeline overview, tested Docker container and library versions, setup, usage and key flags, per-stage description, when TensorRT selects structured-sparse INT8 kernels (Found vs Chose), and the 2:4-sparsity accuracy caveat (requires QAT/SAT recovery). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/sparse_quant_trt/README.md | 137 ++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 examples/sparse_quant_trt/README.md diff --git a/examples/sparse_quant_trt/README.md b/examples/sparse_quant_trt/README.md new file mode 100644 index 0000000000..d727fcb949 --- /dev/null +++ b/examples/sparse_quant_trt/README.md @@ -0,0 +1,137 @@ + + +# Sparse + INT8 → ONNX → TensorRT example + +End-to-end example that takes a Hugging Face LLM (default +[`Qwen/Qwen2.5-1.5B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct)) +all the way to a running TensorRT engine: + +```text +[2:4 weight sparsity] -> INT8 W8A8 SmoothQuant PTQ -> [QAT] + -> finalize -> torch -> ONNX export (opset 20) + -> TensorRT engine build (trtexec, --stronglyTyped) + -> validate structured-sparse INT8 kernels -> real text generation +``` + +A single script, [`pipeline.py`](pipeline.py), drives the whole flow. + +Both **2:4 structured sparsity** (`--sparsity`) and **QAT** (`--qat`) are **optional and OFF by +default** — the default run is plain INT8 W8A8 SmoothQuant, which preserves accuracy and produces +coherent generations. + +## Tested environment + +Developed and tested inside the `nvcr.io/nvidia/pytorch:26.01-py3` Docker container: + +| Component | Version / commit | +| --- | --- | +| Docker container | `nvcr.io/nvidia/pytorch:26.01-py3` | +| PyTorch | `2.10.0a0+a36e1d39eb` (git `a36e1d39eb`) | +| ONNX | `1.18.0` | +| TensorRT | `10.14.1.48` (trtexec `v101401`) | +| CUDA | `13.1` | +| NVIDIA ModelOpt | `0.45.0rc0` (this repository, installed editable) | +| transformers | `5.9.0` (supported range `>=4.56,<5.10`) | +| accelerate | `1.13.0` | +| GPU | NVIDIA RTX 6000 Ada Generation (sm_89) | + +## Setup + +Run inside the container above. ModelOpt is installed editable; `transformers`/`accelerate` and the +ONNX-export helper dependencies are added without disturbing the container's `torch 2.10` / `onnx 1.18`: + +```bash +# editable ModelOpt (the container ships a pip constraint pinning an older version, so clear it) +PIP_CONSTRAINT= pip install -e . --no-deps +pip install --upgrade-strategy only-if-needed "transformers>=4.56,<5.10" "accelerate>=1.0.0" \ + omegaconf "pulp<4.0" "pydantic>=2.0" rich safetensors regex scipy +# ONNX-export helper deps used by modelopt.torch._deploy (do NOT upgrade onnx off 1.18) +pip install --upgrade-strategy only-if-needed onnxruntime onnx-graphsurgeon \ + "onnxconverter-common~=1.16.0" "onnxslim>=0.1.76" +``` + +## Usage + +```bash +# Default: INT8 W8A8 SmoothQuant -> ONNX -> strongly-typed TensorRT -> text generation +python pipeline.py + +# 2:4 sparsity + INT8 (TensorRT selects structured-sparse INT8 kernels). +# Add --qat to (start to) recover the accuracy that sparsity costs. +python pipeline.py --sparsity [--qat] + +# Iterate on the TensorRT build/inference only, reusing an already-exported ONNX +python pipeline.py --reuse-onnx +``` + +### Key options + +| Flag | Default | Description | +| --- | --- | --- | +| `--model-dir` | `/models/Qwen2.5-1.5B-Instruct` | HF model directory | +| `--sparsity` | off | Apply 2:4 structured sparsity before PTQ (auto-enables INT8 output quantizers) | +| `--qat` | off | Run the **example** QAT fine-tune after PTQ | +| `--weights-dtype` | `fp16` | ONNX/engine weight dtype (`fp16` or `fp32`); both build strongly-typed | +| `--calib-dataset` | `cnn_dailymail` | Calibration dataset (`cnn_dailymail` or a `nemotron-*` dataset) | +| `--calib-samples` / `--calib-seq` | `1024` / `512` | Calibration size / sequence length (mirrors `examples/llm_ptq/hf_ptq.py`) | +| `--seq-len` | `128` | Representative sequence length for ONNX export + trtexec optimization profile | +| `--prompt` / `--max-new-tokens` | *"What is the capital of France? ..."* / `32` | Final real-inference prompt and length | +| `--reuse-onnx` | off | Skip the torch stages and build TensorRT from an existing ONNX | + +Run `python pipeline.py --help` for the full list. + +## What each stage does + +1. **Load** the model in fp32 with eager attention (fp32 keeps SmoothQuant calibration of a sparse + model numerically stable; eager attention exports as plain matmul+softmax). +2. **Calibration loop** — `get_dataset_dataloader(...) -> create_forward_loop(...)`, mirroring + `examples/llm_ptq/hf_ptq.py`. +3. **2:4 sparsity** *(optional)* — `mts.sparsify(model, "sparse_magnitude")`. +4. **INT8 W8A8 SmoothQuant PTQ** — `mtq.quantize(model, mtq.INT8_SMOOTHQUANT_CFG, forward_loop)`. + With `--sparsity`, INT8 **output** quantizers are also enabled (see below). +5. **QAT** *(optional)* — a minimal **example** loop; replace it with your own dataset/training + pipeline for real recovery. +6. **Finalize + ONNX export** — opset 20, `dynamo=False`; for `fp16` the model is cast and exported + natively so the graph is self-consistently typed. +7. **TensorRT build** — `trtexec --stronglyTyped --sparsity={enable|disable}`. +8. **Validate + generate** — parse the trtexec sparsity report and run greedy text generation + through the engine. + +## When are structured-sparse INT8 kernels actually used? + +TensorRT's structured-sparse path validates the 2:4 pattern and reports, in the verbose build log: + +```text +(Sparsity) Found N layer(s) eligible to use sparse tactics: ... +(Sparsity) Chose M layer(s) using sparse tactics: ... +``` + +- **`Found`** = the weights pass the 2:4 pattern check (every 4 elements along the reduction axis + have ≥2 zeros). +- **`Chose`** = TensorRT actually selected a sparse INT8 kernel (its tactic timer found it fastest). + +A sparse INT8 GEMM is chosen only when its epilogue keeps data in **INT8** — which is why +`--sparsity` auto-enables INT8 output quantizers (otherwise each GEMM dequantizes to fp16/fp32 and +the dense kernel wins). With this, on the default model TensorRT reports roughly +`Found 196 / Chose 140` (sparse kernels selected for the q/o/gate/up/down projections across all +layers; the small k/v projections stay dense). The script prints a `PASS` when `Chose > 0`. + +## ⚠️ Accuracy note + +One-shot 2:4 **magnitude** sparsity zeros half the weights and causes **severe** accuracy +degradation on its own — the model produces gibberish until recovered with **QAT/SAT +fine-tuning**. The built-in `--qat` loop is a smoke-level **example**; recovering 50% sparsity +requires a real training run on a representative dataset. Use `--sparsity` to demonstrate the +sparse-kernel path; use the default (INT8-only) run for accuracy-preserving generation. + +## Outputs + +Written to `--out-dir` (default `/workspace/out`): + +- `.onnx` (+ `.onnx_data` external weights) — INT8 QDQ graph +- `.engine` — the TensorRT engine +- `trtexec_build.log`, `layer_info.json` — build log and per-layer info used for sparse-kernel + validation From 148102bd2169ea2d39fb37969055aee904cb578f Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 10 Jun 2026 00:49:29 +0000 Subject: [PATCH 3/5] Add FP16-baseline performance comparison to the sparse/INT8 TensorRT example Add --compare-baseline: build an FP16 (unquantized, dense) TensorRT engine from the same model and profile it against the optimized engine with trtexec, using the same profiling parameters as modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py (--warmUp / --avgRuns / --iterations / --noDataTransfers / --useCudaGraph / --useSpinWait). Reports throughput (qps) and median GPU-compute/latency for both engines plus the optimized engine's speedup. Documents the flag and method in the README. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/sparse_quant_trt/README.md | 21 +++++ examples/sparse_quant_trt/pipeline.py | 129 +++++++++++++++++++++++++- 2 files changed, 148 insertions(+), 2 deletions(-) diff --git a/examples/sparse_quant_trt/README.md b/examples/sparse_quant_trt/README.md index d727fcb949..0d1a112537 100644 --- a/examples/sparse_quant_trt/README.md +++ b/examples/sparse_quant_trt/README.md @@ -65,6 +65,9 @@ python pipeline.py --sparsity [--qat] # Iterate on the TensorRT build/inference only, reusing an already-exported ONNX python pipeline.py --reuse-onnx + +# Also build an FP16 (unquantized, dense) baseline engine and compare performance +python pipeline.py --compare-baseline ``` ### Key options @@ -80,9 +83,27 @@ python pipeline.py --reuse-onnx | `--seq-len` | `128` | Representative sequence length for ONNX export + trtexec optimization profile | | `--prompt` / `--max-new-tokens` | *"What is the capital of France? ..."* / `32` | Final real-inference prompt and length | | `--reuse-onnx` | off | Skip the torch stages and build TensorRT from an existing ONNX | +| `--compare-baseline` | off | Also build an FP16 (unquantized, dense) engine and profile both with trtexec | +| `--profiling-runs` | `1` | trtexec profiling runs for `--compare-baseline` (each run = 500 inferences) | Run `python pipeline.py --help` for the full list. +## Performance comparison (`--compare-baseline`) + +With `--compare-baseline`, the script additionally builds an **FP16 (unquantized, dense)** engine +from the same model and profiles both engines with `trtexec`, using the same profiling parameters +as ModelOpt's `modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py`: + +```text +trtexec --loadEngine= --shapes=input_ids:1x \ + --warmUp=500 --avgRuns=500 --iterations=500* \ + --noDataTransfers --useCudaGraph --useSpinWait +``` + +It parses `Throughput` (qps) and the median `GPU Compute Time` / `Latency` from the trtexec output +and prints a side-by-side table plus the optimized engine's throughput/latency speedup over the +FP16 baseline. Both engines are profiled at the same fixed shape for an apples-to-apples comparison. + ## What each stage does 1. **Load** the model in fp32 with eager attention (fp32 keeps SmoothQuant calibration of a sparse diff --git a/examples/sparse_quant_trt/pipeline.py b/examples/sparse_quant_trt/pipeline.py index a282f3b585..c27447c8b1 100644 --- a/examples/sparse_quant_trt/pipeline.py +++ b/examples/sparse_quant_trt/pipeline.py @@ -718,6 +718,19 @@ def main(): default=32, help="number of tokens to greedily generate in the final inference step", ) + # performance comparison vs an FP16 baseline engine + ap.add_argument( + "--compare-baseline", + action="store_true", + help="also build an FP16 (unquantized, dense) engine and profile both with trtexec, " + "reporting the optimized engine's throughput/latency speedup", + ) + ap.add_argument( + "--profiling-runs", + type=int, + default=1, + help="trtexec profiling runs for --compare-baseline (each run = 500 inferences)", + ) args = ap.parse_args() os.makedirs(args.out_dir, exist_ok=True) @@ -732,7 +745,10 @@ def main(): if args.reuse_onnx: log("REUSE: skipping stages 1-7, building TRT from existing ONNX") assert os.path.isfile(onnx_path), f"--reuse-onnx but {onnx_path} not found" - return _build_validate_infer(args, onnx_path, t0) + opt_engine = _build_validate_infer(args, onnx_path, t0) + if args.compare_baseline: + compare_engines(args, opt_engine) + return log("STAGE 1/8: load Qwen2.5-1.5B-Instruct (fp32, eager attention)") model, tok = load_model(args.model_dir) @@ -790,7 +806,9 @@ def main(): del model torch.cuda.empty_cache() - return _build_validate_infer(args, onnx_path, t0) + opt_engine = _build_validate_infer(args, onnx_path, t0) + if args.compare_baseline: + compare_engines(args, opt_engine) def _build_validate_infer(args, onnx_path, t0): @@ -846,6 +864,113 @@ def _build_validate_infer(args, onnx_path, t0): flush=True, ) print(f"total wall time: {time.time() - t0:.1f}s", flush=True) + return engine_path + + +def profile_engine(engine_path, seq_len, profiling_runs=1, timeout_s=900): + """Profile a built TensorRT engine with trtexec and return its throughput/latency. + + Uses the same profiling parameters as ModelOpt's + modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py (_get_profiling_params): + --warmUp=500 --avgRuns=500 --iterations=500*runs --noDataTransfers --useCudaGraph --useSpinWait + (WARMUP_TIME_MS=500, DEFAULT_NUM_INFERENCE_PER_RUN=500). The engine is loaded with --loadEngine + and profiled at a fixed shape so the baseline and optimized engines are compared apples-to-apples. + """ + trtexec = "/usr/src/tensorrt/bin/trtexec" + if not os.path.isfile(trtexec): + trtexec = ( + subprocess.run( + ["bash", "-lc", "command -v trtexec"], capture_output=True, text=True + ).stdout.strip() + or "trtexec" + ) + warmup_ms, n_inf = 500, 500 # engine_builder.py: WARMUP_TIME_MS, DEFAULT_NUM_INFERENCE_PER_RUN + cmd = [ + trtexec, + f"--loadEngine={engine_path}", + f"--shapes=input_ids:1x{seq_len}", + f"--warmUp={warmup_ms}", + f"--avgRuns={n_inf}", + f"--iterations={profiling_runs * n_inf}", + "--noDataTransfers", + "--useCudaGraph", + "--useSpinWait", + ] + print(f"[profile] {' '.join(cmd)}", flush=True) + out = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_s).stdout + + def _grab(pat): + m = re.search(pat, out) + return float(m.group(1)) if m else None + + return { + "throughput_qps": _grab(r"Throughput:\s*([\d.]+)\s*qps"), + "gpu_compute_ms": _grab(r"GPU Compute Time:.*?median\s*=\s*([\d.]+)\s*ms"), + "latency_ms": _grab(r"Latency:.*?median\s*=\s*([\d.]+)\s*ms"), + } + + +def build_fp16_baseline(args): + """Build an FP16 (unquantized, dense) TensorRT engine from the same model, as a perf baseline.""" + log("BASELINE: build FP16 unquantized/dense TensorRT engine for comparison") + model, _ = load_model(args.model_dir) + vocab = model.config.vocab_size + name = f"{args.model_name}_fp16_baseline" + # export_onnx with weights_dtype="fp16" casts to fp16 and exports natively; the model has no + # quantizers, so the ONNX is a plain fp16 graph (no QDQ). Build it strongly-typed, sparsity off. + onnx_path = export_onnx(model, args.out_dir, name, args.seq_len, vocab, weights_dtype="fp16") + del model + torch.cuda.empty_cache() + engine_path = os.path.join(args.out_dir, f"{name}.engine") + rc = build_trt( + onnx_path, + engine_path, + os.path.join(args.out_dir, f"{name}_build.log"), + os.path.join(args.out_dir, f"{name}_layer_info.json"), + args.seq_len, + args.trt_timeout, + strongly_typed=True, + sparsity="disable", + ) + assert rc == 0, f"FP16 baseline engine build failed (rc={rc})" + print(f"[baseline] FP16 engine built OK: {engine_path}", flush=True) + return engine_path + + +def compare_engines(args, optimized_engine_path): + """Build an FP16 baseline engine and profile it against the optimized engine with trtexec.""" + baseline_engine = build_fp16_baseline(args) + log(f"COMPARE: profiling FP16 baseline vs optimized engine (trtexec, shape 1x{args.seq_len})") + base = profile_engine(baseline_engine, args.seq_len, args.profiling_runs) + opt = profile_engine(optimized_engine_path, args.seq_len, args.profiling_runs) + + def _fmt(v): + return f"{v:.3f}" if isinstance(v, float) else "n/a" + + log("PERFORMANCE COMPARISON (FP16 baseline vs optimized)") + print(f" {'metric':<26}{'fp16-baseline':>16}{'optimized':>16}", flush=True) + print( + f" {'throughput (qps)':<26}{_fmt(base['throughput_qps']):>16}{_fmt(opt['throughput_qps']):>16}", + flush=True, + ) + print( + f" {'GPU compute median (ms)':<26}{_fmt(base['gpu_compute_ms']):>16}{_fmt(opt['gpu_compute_ms']):>16}", + flush=True, + ) + print( + f" {'latency median (ms)':<26}{_fmt(base['latency_ms']):>16}{_fmt(opt['latency_ms']):>16}", + flush=True, + ) + if base["gpu_compute_ms"] and opt["gpu_compute_ms"]: + print( + f" => optimized GPU-compute speedup: {base['gpu_compute_ms'] / opt['gpu_compute_ms']:.2f}x", + flush=True, + ) + if base["throughput_qps"] and opt["throughput_qps"]: + print( + f" => optimized throughput gain: {opt['throughput_qps'] / base['throughput_qps']:.2f}x", + flush=True, + ) if __name__ == "__main__": From 1b52de8dad99209b81982f86ff58fa7625b191c2 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 10 Jun 2026 01:11:03 +0000 Subject: [PATCH 4/5] Quantize attention math by default (q/k/v_bmm + softmax) INT8-quantize the attention BMMs and softmax in addition to the linear projections by default (--no-quant-attention reverts to a linears-only graph). Update the script docstring and the example README accordingly. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/sparse_quant_trt/README.md | 5 ++++- examples/sparse_quant_trt/pipeline.py | 28 ++++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/examples/sparse_quant_trt/README.md b/examples/sparse_quant_trt/README.md index 0d1a112537..9ebecb98d5 100644 --- a/examples/sparse_quant_trt/README.md +++ b/examples/sparse_quant_trt/README.md @@ -76,6 +76,7 @@ python pipeline.py --compare-baseline | --- | --- | --- | | `--model-dir` | `/models/Qwen2.5-1.5B-Instruct` | HF model directory | | `--sparsity` | off | Apply 2:4 structured sparsity before PTQ (auto-enables INT8 output quantizers) | +| `--quant-attention` / `--no-quant-attention` | on | INT8-quantize the attention math (q/k/v_bmm + softmax) in addition to the linear projections | | `--qat` | off | Run the **example** QAT fine-tune after PTQ | | `--weights-dtype` | `fp16` | ONNX/engine weight dtype (`fp16` or `fp32`); both build strongly-typed | | `--calib-dataset` | `cnn_dailymail` | Calibration dataset (`cnn_dailymail` or a `nemotron-*` dataset) | @@ -112,7 +113,9 @@ FP16 baseline. Both engines are profiled at the same fixed shape for an apples-t `examples/llm_ptq/hf_ptq.py`. 3. **2:4 sparsity** *(optional)* — `mts.sparsify(model, "sparse_magnitude")`. 4. **INT8 W8A8 SmoothQuant PTQ** — `mtq.quantize(model, mtq.INT8_SMOOTHQUANT_CFG, forward_loop)`. - With `--sparsity`, INT8 **output** quantizers are also enabled (see below). + By default the **attention math** (`q/k/v_bmm` + `softmax`) is INT8-quantized too, not just the + linear projections (`--no-quant-attention` for linears-only). With `--sparsity`, INT8 **output** + quantizers are also enabled (see below). 5. **QAT** *(optional)* — a minimal **example** loop; replace it with your own dataset/training pipeline for real recovery. 6. **Finalize + ONNX export** — opset 20, `dynamo=False`; for `fp16` the model is cast and exported diff --git a/examples/sparse_quant_trt/pipeline.py b/examples/sparse_quant_trt/pipeline.py index c27447c8b1..86f955489d 100644 --- a/examples/sparse_quant_trt/pipeline.py +++ b/examples/sparse_quant_trt/pipeline.py @@ -22,7 +22,9 @@ -> validate sparse INT8 kernels -> real inference (text-in -> text-out). Sparsity (--sparsity) and QAT (--qat) are OPTIONAL and OFF by default; the default run is -plain INT8 W8A8 SmoothQuant, which preserves accuracy (coherent generation). +INT8 W8A8 SmoothQuant, which preserves accuracy (coherent generation). The attention math +(q/k/v_bmm + softmax) is INT8-quantized by default in addition to the linear projections; +pass --no-quant-attention for a linears-only graph. WARNING: one-shot 2:4 magnitude sparsity zeros half the weights and causes SEVERE accuracy degradation by itself -- the model produces gibberish until recovered with QAT/SAT fine-tuning. @@ -212,7 +214,7 @@ def apply_sparsity(model): return model -def apply_ptq(model, forward_loop, quant_output=False): +def apply_ptq(model, forward_loop, quant_output=False, quant_attention=False): import copy import modelopt.torch.quantization as mtq @@ -227,6 +229,17 @@ def apply_ptq(model, forward_loop, quant_output=False): {"quantizer_name": "*output_quantizer", "cfg": {"num_bits": 8, "axis": None}} ) print("[ptq] INT8 output quantizers ENABLED (GEMMs INT8-out)", flush=True) + if quant_attention: + # Also quantize the attention math: the BMM quantizers (q/k/v_bmm) feed Q*K^T and P*V in + # INT8, and the softmax_quantizer quantizes the attention probabilities. These are inserted + # by ModelOpt's _QuantAttention plugin but disabled by INT8_SMOOTHQUANT_CFG's default. + cfg["quant_cfg"].append( + {"quantizer_name": "*_bmm_quantizer", "cfg": {"num_bits": 8, "axis": None}} + ) + cfg["quant_cfg"].append( + {"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": 8, "axis": None}} + ) + print("[ptq] INT8 attention quantizers ENABLED (q/k/v_bmm + softmax)", flush=True) model = mtq.quantize(model, cfg, forward_loop=forward_loop) mtq.print_quant_summary(model) nwq = count_enabled_weight_quantizers(model) @@ -718,6 +731,13 @@ def main(): default=32, help="number of tokens to greedily generate in the final inference step", ) + ap.add_argument( + "--quant-attention", + action=argparse.BooleanOptionalAction, + default=True, + help="INT8-quantize the attention math (q/k/v_bmm + softmax) in addition to the linear " + "projections (default: on; pass --no-quant-attention for linears-only)", + ) # performance comparison vs an FP16 baseline engine ap.add_argument( "--compare-baseline", @@ -782,7 +802,9 @@ def main(): "STAGE 4/8: INT8 W8A8 SmoothQuant PTQ" + (" (+INT8 output quantizers)" if sparsity_on else "") ) - model = apply_ptq(model, forward_loop, quant_output=sparsity_on) + model = apply_ptq( + model, forward_loop, quant_output=sparsity_on, quant_attention=args.quant_attention + ) if args.qat: log(f"STAGE 5/8: EXAMPLE QAT fine-tune ({args.qat_steps} steps, lr={args.qat_lr})") From dff1aa1a2011ceba02c60688046f218bb0e6544e Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 10 Jun 2026 01:13:52 +0000 Subject: [PATCH 5/5] Always quantize attention math; drop the --quant-attention flag INT8-quantize the attention BMMs and softmax (q/k/v_bmm + softmax) unconditionally as part of the default INT8 path, and remove the optional --quant-attention / --no-quant-attention flag. Update the docstring and README accordingly. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- examples/sparse_quant_trt/README.md | 6 ++-- examples/sparse_quant_trt/pipeline.py | 40 ++++++++++----------------- 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/examples/sparse_quant_trt/README.md b/examples/sparse_quant_trt/README.md index 9ebecb98d5..2c93bd8b16 100644 --- a/examples/sparse_quant_trt/README.md +++ b/examples/sparse_quant_trt/README.md @@ -76,7 +76,6 @@ python pipeline.py --compare-baseline | --- | --- | --- | | `--model-dir` | `/models/Qwen2.5-1.5B-Instruct` | HF model directory | | `--sparsity` | off | Apply 2:4 structured sparsity before PTQ (auto-enables INT8 output quantizers) | -| `--quant-attention` / `--no-quant-attention` | on | INT8-quantize the attention math (q/k/v_bmm + softmax) in addition to the linear projections | | `--qat` | off | Run the **example** QAT fine-tune after PTQ | | `--weights-dtype` | `fp16` | ONNX/engine weight dtype (`fp16` or `fp32`); both build strongly-typed | | `--calib-dataset` | `cnn_dailymail` | Calibration dataset (`cnn_dailymail` or a `nemotron-*` dataset) | @@ -113,9 +112,8 @@ FP16 baseline. Both engines are profiled at the same fixed shape for an apples-t `examples/llm_ptq/hf_ptq.py`. 3. **2:4 sparsity** *(optional)* — `mts.sparsify(model, "sparse_magnitude")`. 4. **INT8 W8A8 SmoothQuant PTQ** — `mtq.quantize(model, mtq.INT8_SMOOTHQUANT_CFG, forward_loop)`. - By default the **attention math** (`q/k/v_bmm` + `softmax`) is INT8-quantized too, not just the - linear projections (`--no-quant-attention` for linears-only). With `--sparsity`, INT8 **output** - quantizers are also enabled (see below). + The **attention math** (`q/k/v_bmm` + `softmax`) is INT8-quantized too, not just the linear + projections. With `--sparsity`, INT8 **output** quantizers are also enabled (see below). 5. **QAT** *(optional)* — a minimal **example** loop; replace it with your own dataset/training pipeline for real recovery. 6. **Finalize + ONNX export** — opset 20, `dynamo=False`; for `fp16` the model is cast and exported diff --git a/examples/sparse_quant_trt/pipeline.py b/examples/sparse_quant_trt/pipeline.py index 86f955489d..fb9fd54450 100644 --- a/examples/sparse_quant_trt/pipeline.py +++ b/examples/sparse_quant_trt/pipeline.py @@ -22,9 +22,8 @@ -> validate sparse INT8 kernels -> real inference (text-in -> text-out). Sparsity (--sparsity) and QAT (--qat) are OPTIONAL and OFF by default; the default run is -INT8 W8A8 SmoothQuant, which preserves accuracy (coherent generation). The attention math -(q/k/v_bmm + softmax) is INT8-quantized by default in addition to the linear projections; -pass --no-quant-attention for a linears-only graph. +INT8 W8A8 SmoothQuant, which preserves accuracy (coherent generation). INT8 quantization covers +both the linear projections AND the attention math (q/k/v_bmm + softmax). WARNING: one-shot 2:4 magnitude sparsity zeros half the weights and causes SEVERE accuracy degradation by itself -- the model produces gibberish until recovered with QAT/SAT fine-tuning. @@ -214,12 +213,23 @@ def apply_sparsity(model): return model -def apply_ptq(model, forward_loop, quant_output=False, quant_attention=False): +def apply_ptq(model, forward_loop, quant_output=False): import copy import modelopt.torch.quantization as mtq cfg = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG) + # Always INT8-quantize the attention math in addition to the linear projections: the BMM + # quantizers (q/k/v_bmm) feed Q*K^T and P*V in INT8 and the softmax_quantizer quantizes the + # attention probabilities. ModelOpt's _QuantAttention plugin inserts these quantizers but + # leaves them disabled in INT8_SMOOTHQUANT_CFG; enable them here. + cfg["quant_cfg"].append( + {"quantizer_name": "*_bmm_quantizer", "cfg": {"num_bits": 8, "axis": None}} + ) + cfg["quant_cfg"].append( + {"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": 8, "axis": None}} + ) + print("[ptq] INT8 attention quantizers ENABLED (q/k/v_bmm + softmax)", flush=True) if quant_output: # Enable INT8 output quantizers so every GEMM has an INT8-out epilogue. Without this, # each Linear's consumer (attention / SiLU / residual+layernorm) is unquantized, so the @@ -229,17 +239,6 @@ def apply_ptq(model, forward_loop, quant_output=False, quant_attention=False): {"quantizer_name": "*output_quantizer", "cfg": {"num_bits": 8, "axis": None}} ) print("[ptq] INT8 output quantizers ENABLED (GEMMs INT8-out)", flush=True) - if quant_attention: - # Also quantize the attention math: the BMM quantizers (q/k/v_bmm) feed Q*K^T and P*V in - # INT8, and the softmax_quantizer quantizes the attention probabilities. These are inserted - # by ModelOpt's _QuantAttention plugin but disabled by INT8_SMOOTHQUANT_CFG's default. - cfg["quant_cfg"].append( - {"quantizer_name": "*_bmm_quantizer", "cfg": {"num_bits": 8, "axis": None}} - ) - cfg["quant_cfg"].append( - {"quantizer_name": "*softmax_quantizer", "cfg": {"num_bits": 8, "axis": None}} - ) - print("[ptq] INT8 attention quantizers ENABLED (q/k/v_bmm + softmax)", flush=True) model = mtq.quantize(model, cfg, forward_loop=forward_loop) mtq.print_quant_summary(model) nwq = count_enabled_weight_quantizers(model) @@ -731,13 +730,6 @@ def main(): default=32, help="number of tokens to greedily generate in the final inference step", ) - ap.add_argument( - "--quant-attention", - action=argparse.BooleanOptionalAction, - default=True, - help="INT8-quantize the attention math (q/k/v_bmm + softmax) in addition to the linear " - "projections (default: on; pass --no-quant-attention for linears-only)", - ) # performance comparison vs an FP16 baseline engine ap.add_argument( "--compare-baseline", @@ -802,9 +794,7 @@ def main(): "STAGE 4/8: INT8 W8A8 SmoothQuant PTQ" + (" (+INT8 output quantizers)" if sparsity_on else "") ) - model = apply_ptq( - model, forward_loop, quant_output=sparsity_on, quant_attention=args.quant_attention - ) + model = apply_ptq(model, forward_loop, quant_output=sparsity_on) if args.qat: log(f"STAGE 5/8: EXAMPLE QAT fine-tune ({args.qat_steps} steps, lr={args.qat_lr})")