diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ad0d4acdfac..342d8f4c1aa 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ Changelog - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. - Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``. +- Add support for ONNX Q/DQ node placement for DLA via the new flag ``--target_dla``. 0.45 (2026-06-xx) ^^^^^^^^^^^^^^^^^ diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py index 4671b99139c..1cbcaf857c0 100644 --- a/modelopt/onnx/quantization/__main__.py +++ b/modelopt/onnx/quantization/__main__.py @@ -300,6 +300,14 @@ def get_parser() -> argparse.ArgumentParser: "if certain operations require a higher version." ), ) + argparser.add_argument( + "--target_dla", + action="store_true", + help=( + "If set, enables Q/DQ nodes to be placed in all tensors for optimal DLA deployment. This only has " + "effect in INT8 quantization. Note that this may cause accuracy degradation, proceed with caution." 
+ ), + ) argparser.add_argument( "--autotune", nargs="?", @@ -494,6 +502,7 @@ def main(): calibrate_per_node=args.calibrate_per_node, direct_io_types=args.direct_io_types, opset=args.opset, + target_dla=args.target_dla, autotune=autotune_enabled, autotune_output_dir=args.autotune_output_dir, autotune_num_schemes_per_region=args.autotune_schemes_per_region, diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py index 5b1ad5efe4e..170446dab5a 100755 --- a/modelopt/onnx/quantization/int8.py +++ b/modelopt/onnx/quantization/int8.py @@ -163,7 +163,7 @@ def quantize( return onnx_model enable_gemv_detection_for_trt = kwargs.get("enable_gemv_detection_for_trt", True) - if enable_gemv_detection_for_trt and not autotune: + if enable_gemv_detection_for_trt and not (autotune or kwargs.get("target_dla", False)): # Either of m or n in matmul is 1, this matmul cannot utilize TensorCores. # The perf of adding Q/DQ layers is not good in TRT. Thus, in this case, # do not add Q/DQ layers to this matmul. 
@@ -183,11 +183,13 @@ def quantize( # Collect node names to exclude from quantization nodes_to_exclude = find_nodes_to_exclude(graph, nodes_to_exclude, op_types_to_exclude) # type: ignore[arg-type] - if not autotune: + if not (autotune or kwargs.get("target_dla", False)): nodes_to_exclude.extend(find_nodes_from_convs_to_exclude(graph, quantize_mode="int8")) # Change the default configuration of ORT quantization op_types_to_quantize = op_types_to_quantize or [] + if kwargs.get("target_dla", False) and not op_types_to_quantize: + op_types_to_quantize = list({node.op_type for node in onnx_model.graph.node}) if op_types_to_quantize: op_types_to_quantize.extend(custom_ops_to_quantize) op_types = {node.op for node in graph.nodes} diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 3484140a57c..6bbc3299129 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -360,6 +360,7 @@ def quantize( input_shapes_profile: Sequence[dict[str, str]] | None = None, direct_io_types: bool = False, opset: int | None = None, + target_dla: bool = False, autotune: bool = False, autotune_output_dir: str | None = None, autotune_num_schemes_per_region: int = 50, @@ -498,6 +499,9 @@ def quantize( Target ONNX opset version for the quantized model. If None, uses required minimum opset (19 for int8/fp8, 21 for int4, 23 for nvfp4). If the specified opset is lower than the required minimum, a warning will be issued and the opset will be upgraded to the required minimum. + target_dla: + If True, enable Q/DQ nodes to be placed in all tensors for optimal DLA deployment. This only has + effect in INT8 quantization. Note that this may cause accuracy degradation, proceed with caution. autotune: If True, detect optimal Q/DQ node placements according to the TensorRT version and platform available. If False, use the default pattern-based quantization approach. 
def build_small_grouped_conv_model():
    """Create a small ONNX graph with depthwise (grouped) Convs using 1x1 and 2x2 kernels.

    Topology:
        input_0 -> Conv(256->128, 3x3) -> Relu -> Resize(2x nearest) -+
                                                                      |
            +---------------------------------------------------------+
            |
            +-> DWConv(128, 2x2) ----------------------> Mul -> output_0
            |                                             ^
            +-> DWConv(128, 1x1) -> DWConv(128, 2x2) -----+

    Weights are drawn from ``np.random.default_rng(0)``, so the model is reproducible.

    Returns:
        The shape-inferred model (opset 17, IR version 10) after it passed ``onnx.checker``.
    """
    channels = 128
    float_type = onnx.TensorProto.FLOAT

    def _depthwise_conv(name, input_names, output_name, kernel, pads):
        # Every depthwise Conv here uses one group per channel with unit dilation and stride.
        return helper.make_node(
            op_type="Conv",
            inputs=input_names,
            outputs=[output_name],
            name=name,
            dilations=[1, 1],
            group=channels,
            kernel_shape=kernel,
            pads=pads,
            strides=[1, 1],
        )

    graph_inputs = [helper.make_tensor_value_info("input_0", float_type, (1, 256, 36, 52))]
    graph_outputs = [helper.make_tensor_value_info("output_0", float_type, (1, channels, 72, 104))]

    stem_conv = helper.make_node(
        op_type="Conv",
        inputs=["input_0", "conv1_weights", "conv1_bias"],
        outputs=["conv1_out"],
        name="conv1",
        dilations=[1, 1],
        group=1,
        kernel_shape=[3, 3],
        pads=[1, 1, 1, 1],
        strides=[1, 1],
    )
    activation = helper.make_node(
        op_type="Relu",
        inputs=["conv1_out"],
        outputs=["relu1_out"],
        name="relu1",
    )
    upsample = helper.make_node(
        op_type="Resize",
        inputs=["relu1_out", "resize1_roi", "resize1_scales"],
        outputs=["resize1_out"],
        name="resize1",
        coordinate_transformation_mode="half_pixel",
        mode="nearest",
        nearest_mode="round_prefer_ceil",
    )
    # Main upsample path: a single depthwise 2x2 Conv.
    main_path_conv = _depthwise_conv(
        "dw_conv1", ["resize1_out", "dw_conv1_weights"], "dw_conv1_out", [2, 2], [0, 0, 1, 1]
    )
    # Scaling path: a 1x1 depthwise Conv with all-zero weights plus a bias, so it emits the bias only ...
    scale_path_conv_1x1 = _depthwise_conv(
        "dw_conv2",
        ["resize1_out", "dw_conv2_weights", "dw_conv2_bias"],
        "dw_conv2_out",
        [1, 1],
        [0, 0, 0, 0],
    )
    # ... followed by a depthwise 2x2 Conv applied to that bias-only tensor.
    scale_path_conv_2x2 = _depthwise_conv(
        "dw_conv3", ["dw_conv2_out", "dw_conv3_weights"], "dw_conv3_out", [2, 2], [0, 0, 1, 1]
    )
    merge = helper.make_node(
        op_type="Mul",
        inputs=["dw_conv1_out", "dw_conv3_out"],
        outputs=["output_0"],
        name="mul1",
    )
    nodes = [
        stem_conv,
        activation,
        upsample,
        main_path_conv,
        scale_path_conv_1x1,
        scale_path_conv_2x2,
        merge,
    ]

    rng = np.random.default_rng(0)

    def _random_values(count):
        # Draws consume the shared generator, so the call order below fixes the weights.
        return rng.standard_normal(count).astype(np.float32).tolist()

    def _float_tensor(name, dims, values):
        return helper.make_tensor(name, float_type, dims, values)

    initializers = [
        _float_tensor("conv1_weights", [channels, 256, 3, 3], _random_values(channels * 256 * 3 * 3)),
        _float_tensor("conv1_bias", [channels], _random_values(channels)),
        # Empty ROI: Resize is driven by the scales tensor only.
        _float_tensor("resize1_roi", [0], []),
        _float_tensor("resize1_scales", [4], [1.0, 1.0, 2.0, 2.0]),
        _float_tensor("dw_conv1_weights", [channels, 1, 2, 2], _random_values(channels * 1 * 2 * 2)),
        _float_tensor(
            "dw_conv2_weights",
            [channels, 1, 1, 1],
            np.zeros(channels).astype(np.float32).tolist(),
        ),
        _float_tensor("dw_conv2_bias", [channels], _random_values(channels)),
        _float_tensor("dw_conv3_weights", [channels, 1, 2, 2], _random_values(channels * 1 * 2 * 2)),
    ]

    model = helper.make_model(
        helper.make_graph(
            nodes, "small_grouped_conv", graph_inputs, graph_outputs, initializer=initializers
        ),
        opset_imports=[helper.make_opsetid("", 17)],
    )
    model.ir_version = 10

    inferred_model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(inferred_model)

    return inferred_model


def build_matmul_1xn_model():
    """Create a single-node MatMul graph whose activation has m=1 (a GEMV).

    The int8 quantizer's GEMV detection skips Q/DQ insertion for MatMuls where m or n is 1,
    and that detection is bypassed when ``target_dla=True``; this model exercises both paths.

    The input is deliberately 2D ``[m, k]`` so that the output is 2D ``[m, n]`` and the weight
    is an initializer rather than a graph variable.
    NOTE(review): per the original author's note this targets the
    ``len(dims) < 3 and any(dim == 1)`` branch of ``_exclude_matmuls_by_shape_inference``;
    that helper is not visible from here, so re-confirm if it changes.

    Topology:
        input [1, 32] -> MatMul([32, 64]) -> output [1, 64]

    Returns:
        The shape-inferred model (opset 17, IR version 10) after it passed ``onnx.checker``.
    """
    m, k, n = 1, 32, 64
    float_type = onnx.TensorProto.FLOAT

    graph_inputs = [helper.make_tensor_value_info("input_0", float_type, (m, k))]
    graph_outputs = [helper.make_tensor_value_info("output_0", float_type, (m, n))]

    matmul = helper.make_node(
        op_type="MatMul",
        inputs=["input_0", "matmul_weights"],
        outputs=["output_0"],
        name="matmul1",
    )

    rng = np.random.default_rng(0)
    weights = helper.make_tensor(
        "matmul_weights",
        float_type,
        [k, n],
        rng.standard_normal(k * n).astype(np.float32).tolist(),
    )

    model = helper.make_model(
        helper.make_graph(
            [matmul], "matmul_1xn", graph_inputs, graph_outputs, initializer=[weights]
        ),
        opset_imports=[helper.make_opsetid("", 17)],
    )
    model.ir_version = 10

    inferred_model = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(inferred_model)

    return inferred_model
def _quantize_int8_and_import(model, tmp_path, target_dla):
    """Save ``model`` under ``tmp_path``, int8-quantize it, and return the quantized graph.

    Asserts that the quantized model file was written next to the input model.
    """
    onnx_path = os.path.join(tmp_path, "model.onnx")
    onnx.save(model, onnx_path)

    quantize(onnx_path, quantize_mode="int8", high_precision_dtype="fp16", target_dla=target_dla)

    # The quantized model is expected alongside the input, with a ".quant.onnx" suffix.
    quantized_path = onnx_path.replace(".onnx", ".quant.onnx")
    assert os.path.isfile(quantized_path)

    return gs.import_onnx(onnx.load(quantized_path))


@pytest.mark.parametrize("target_dla", [False, True])
def test_target_dla_conv(tmp_path, target_dla):
    """Check Q/DQ placement on the grouped-Conv model with and without ``target_dla``."""
    graph = _quantize_int8_and_import(build_small_grouped_conv_model(), tmp_path, target_dla)

    convs = [node for node in graph.nodes if "Conv" in node.op]
    muls = [node for node in graph.nodes if "Mul" in node.op]

    if not target_dla:
        # Default mode: the first Conv must be quantized and the Mul must stay unquantized.
        assert assert_nodes_are_quantized([convs[0]])
        assert assert_nodes_are_not_quantized(muls)
        return

    # DLA mode: every Conv and every Mul must be quantized.
    assert assert_nodes_are_quantized(convs)
    assert assert_nodes_are_quantized(muls)


@pytest.mark.parametrize("target_dla", [False, True])
def test_target_dla_matmul(tmp_path, target_dla):
    """Check that the m=1 MatMul is quantized only when ``target_dla`` is enabled."""
    graph = _quantize_int8_and_import(build_matmul_1xn_model(), tmp_path, target_dla)

    matmuls = [node for node in graph.nodes if node.op == "MatMul"]

    if not target_dla:
        # GEMV detection excludes the MatMul (m=1) from quantization.
        assert assert_nodes_are_not_quantized(matmuls)
        return

    # DLA mode bypasses GEMV detection, so the MatMul must be quantized.
    assert assert_nodes_are_quantized(matmuls)