diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py new file mode 100644 index 000000000..25b56c100 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pattern-Based Q/DQ Autotuning for ONNX Models. + +This package provides automated optimization of Quantize/Dequantize (Q/DQ) node placement +in ONNX computation graphs to minimize TensorRT inference latency. It uses pattern-based +region analysis to efficiently explore and optimize Q/DQ insertion strategies. + +**Key Features:** + +- **Automated Region Discovery**: Hierarchical decomposition of computation graphs into + LEAF and COMPOSITE regions with automatic pattern identification + +- **Pattern-Based Optimization**: Groups structurally-similar regions and optimizes them + together, making the process efficient and consistent + +- **TensorRT Performance Measurement**: Direct integration with TensorRT Python API for + accurate latency profiling of each Q/DQ configuration + +- **State Management**: Checkpoint/resume capability for long-running optimizations with + incremental state saving after each region + +- **Pattern Cache**: Warm-start optimization using learned schemes from previous runs, + enabling transfer learning across models + +**Core Components:** + +Autotuner Classes: + - QDQAutotuner: Main autotuner with automatic hierarchical region discovery + - QDQAutotunerBase: Base class for custom region identification strategies + +Region Management: + - Region: Hierarchical subgraph representation (nodes + children) + - RegionType: Enumeration (LEAF, COMPOSITE, ROOT) + - CombinedRegionSearch: Two-phase region discovery (partitioning + refinement) + - RegionPattern: Structural pattern analysis and matching for region grouping + +Q/DQ Insertion Points: + - InsertionScheme: Collection of Q/DQ insertion points for a region pattern + - NodeInputInsertionPoint: Q/DQ insertion at specific node inputs + - ChildRegionInputInsertionPoint: Q/DQ insertion at child region input boundaries + - RegionOutputInsertionPoint: Q/DQ insertion at region output boundaries + +Configuration & State: + - Config: Autotuning parameters (quant type, thresholds, verbosity) + - PatternCache: Top-performing schemes indexed by pattern (warm-start) + - PatternSchemes: Scheme collection and measurement results for a pattern + +Benchmarking: + - Benchmark: Abstract base class for model benchmarking + - TensorRTPyBenchmark: Benchmark using TensorRT Python API (recommended) + - TrtExecBenchmark: Benchmark using trtexec command-line tool (legacy) + +**Quick Start:** + + >>> from modelopt.onnx.quantization.autotune import QDQAutotuner, Config + >>> import onnx + >>> # Load model and initialize autotuner + >>> model = onnx.load("model.onnx") + >>> autotuner = QDQAutotuner(model) + >>> # Configure 
autotuning parameters + >>> config = Config(default_quant_type="int8") + >>> autotuner.initialize(config) + >>> # Generate and test Q/DQ schemes + >>> # (see workflows.region_pattern_autotuning_workflow for complete example) + +**Command-Line Interface:** + + The package can be run directly as a module: + + $ python -m modelopt.onnx.quantization.autotune --model model.onnx --output ./output + $ python -m modelopt.onnx.quantization.autotune --model model.onnx --quant-type fp8 + +**See Also:** + + - workflows.region_pattern_autotuning_workflow: Complete end-to-end optimization + - QDQAutotuner: Main autotuner class documentation + - RegionPattern: Pattern matching and signature computation +""" + +# Autotuner classes +from .autotuner import QDQAutotuner, QDQAutotunerBase + +# Benchmark classes +from .benchmark import Benchmark, TensorRTPyBenchmark, TrtExecBenchmark + +# Core data structures +from .common import ( + AutotunerError, + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionError, + RegionType, +) + +# Insertion points (from dedicated module) +from .insertion_points import ( + ChildRegionInputInsertionPoint, + NodeInputInsertionPoint, + RegionOutputInsertionPoint, + ResolvedInsertionPoint, +) + +# Pattern analysis +from .region_pattern import RegionPattern + +# Region search +from .region_search import CombinedRegionSearch + +# Public API +__all__ = [ + # Exceptions + "AutotunerError", + "AutotunerNotInitializedError", + # Benchmark classes + "Benchmark", + "TensorRTPyBenchmark", + "TrtExecBenchmark", + # Configuration and state + "Config", + # Q/DQ insertion + "InsertionScheme", + "InvalidSchemeError", + "NodeInputInsertionPoint", + "ChildRegionInputInsertionPoint", + "RegionOutputInsertionPoint", + "ResolvedInsertionPoint", + # Main autotuner classes + "QDQAutotuner", + "QDQAutotunerBase", + # Region classes + "Region", + "RegionError", + "RegionPattern", + "RegionType", + "PatternCache", + "PatternSchemes", + "CombinedRegionSearch", +] diff --git a/modelopt/onnx/quantization/autotune/__main__.py b/modelopt/onnx/quantization/autotune/__main__.py new file mode 100644 index 000000000..127ee6ae6 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/__main__.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ONNX Q/DQ Autotuning Command-Line Interface. + +This module provides a command-line interface for automated Q/DQ (Quantize/Dequantize) +optimization of ONNX models. It uses pattern-based region analysis and TensorRT performance +measurement to find optimal Q/DQ insertion points that minimize inference latency. 
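The CLI wraps the same workflow exposed by the package's Python API. A minimal sketch of the equivalent programmatic loop is shown below; `benchmark_fn`, the file names, and the per-region budget of 10 schemes are illustrative placeholders, not part of this module:

    >>> import onnx
    >>> from modelopt.onnx.quantization.autotune import QDQAutotuner, Config
    >>> autotuner = QDQAutotuner(onnx.load("model.onnx"))
    >>> autotuner.initialize(Config(default_quant_type="int8"))
    >>> autotuner.export_onnx("baseline.onnx", insert_qdq=False)
    >>> autotuner.submit(benchmark_fn("baseline.onnx"))  # first submit() records the baseline
    >>> for region in autotuner.regions:
    ...     autotuner.set_profile_region(region)
    ...     if autotuner.current_profile_pattern_schemes is None:
    ...         continue  # pattern already profiled, region was skipped
    ...     for _ in range(10):
    ...         if autotuner.generate() < 0:
    ...             break  # no more unique schemes for this pattern
    ...         autotuner.export_onnx("candidate.onnx")
    ...         autotuner.submit(benchmark_fn("candidate.onnx"))
    >>> autotuner.set_profile_region(None)  # commit the last region's results
    >>> autotuner.export_onnx("optimized_final.onnx")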
+ +**Usage Examples:** + + # Basic usage - automatic region discovery and optimization + python -m modelopt.onnx.quantization.autotune --model model.onnx + + # INT8 vs FP8 quantization + python -m modelopt.onnx.quantization.autotune --model model.onnx --quant-type fp8 + + # Warm-start from pattern cache (transfer learning) + python -m modelopt.onnx.quantization.autotune \\ + --model model.onnx \\ + --pattern-cache ./output/pattern_cache.yaml + + # Import patterns from pre-quantized baseline model + python -m modelopt.onnx.quantization.autotune \\ + --model model.onnx \\ + --qdq-baseline quantized_baseline.onnx + + # Full example with all optimization options + python -m modelopt.onnx.quantization.autotune \\ + --model model.onnx \\ + --schemes-per-region 50 \\ + --pattern-cache pattern_cache.yaml \\ + --qdq-baseline baseline.onnx \\ + --output ./results \\ + --quant-type int8 \\ + --verbose + + # Use custom TensorRT plugins for model-specific operations + python -m modelopt.onnx.quantization.autotune \\ + --model model.onnx \\ + --plugin-libraries /path/to/plugin1.so /path/to/plugin2.so + +**Output Files:** + + output_dir/ + ├── autotuner_state.yaml # Checkpoint for resume capability + ├── baseline.onnx # Unquantized baseline model + ├── optimized_final.onnx # Final optimized model with Q/DQ + ├── logs/ # TensorRT build logs per scheme + │ ├── baseline.log + │ ├── region_*_scheme_*.log + │ └── final.log + └── region_models/ # Best model per region + └── region_*_level_*.onnx +""" + +import sys + +from modelopt.onnx.quantization.autotune.cli import get_autotune_parser, run_autotune + + +def main(): + """Command-line entry point for ONNX Q/DQ autotuning. + + Parses command-line arguments and executes the autotuning workflow. + + Returns: + Exit code from run_autotune (0 for success, non-zero for errors) + """ + parser = get_autotune_parser() + args = parser.parse_args() + + # Run autotuning + return run_autotune(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py new file mode 100644 index 000000000..c51fbf964 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -0,0 +1,2108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""QDQ Autotuner - Automatic Q/DQ Insertion Optimization for ONNX Models. + +This module provides pattern-based automatic optimization of Quantize/Dequantize +(Q/DQ) node placement in ONNX computation graphs using iterative profiling and +performance measurement. + +**Core Functionality:** +- Identifies regions around compute-intensive operations (Conv, MatMul, Gemm, etc.) 
+- Generates and tests multiple Q/DQ insertion schemes per region pattern +- Measures performance and selects optimal configurations +- Applies best schemes to all regions matching each pattern + +**Pattern-Based Optimization:** +- Regions with identical structure share the same pattern signature +- Each pattern gets multiple InsertionScheme candidates tested +- Schemes use pattern-relative addressing (portable across matching regions) +- Best scheme per pattern applies to all regions with that structure + +**Typical Workflow:** +1. Initialize autotuner with ONNX model → regions discovered automatically +2. Measure baseline performance (optional but recommended) +3. For each region: generate schemes → export → measure → submit results +4. Export optimized model with best schemes applied + +**Classes:** +- QDQAutotuner: Default autotuner with automatic region discovery (use this) +- QDQAutotunerBase: Base class for custom region identification strategies +""" + +import copy +import logging +import os +import random +from collections import deque +from datetime import datetime, timezone + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import yaml + +from modelopt.onnx.quantization.autotune.common import ( + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionType, +) +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + merge_resolved_insertion_points, +) +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch +from modelopt.onnx.quantization.fp8 import int8_to_fp8 +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +# Module logger +logger = logging.getLogger(__name__) + + +class QDQAutotunerBase: + """Base class for pattern-based Q/DQ node insertion optimization in ONNX models. + + This base class provides core functionality for optimizing Quantize/Dequantize (Q/DQ) + node placement in ONNX models. It orchestrates scheme generation, performance profiling, + and model export, using pattern-based optimization where regions with identical structure + share insertion schemes. 
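As a concrete illustration of the pattern sharing described above (a sketch only; the region indices are arbitrary and assume at least four discovered regions):

    >>> from modelopt.onnx.quantization.autotune import RegionPattern
    >>> pattern_a = RegionPattern.from_region(autotuner.regions[0], autotuner.graph)
    >>> pattern_b = RegionPattern.from_region(autotuner.regions[3], autotuner.graph)
    >>> if pattern_a.signature == pattern_b.signature:
    ...     # Structurally identical regions: profiling one pattern covers both,
    ...     # and export_onnx() re-resolves the pattern-relative points per region.
    ...     pass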
+ + **Design:** + - Subclasses must populate `self.regions` with Region objects (e.g., via region search) + - Pattern-relative addressing: Schemes use indices relative to region structure + - Performance-driven selection: Measures and compares scheme latencies + - Best scheme per pattern: Optimal configuration applies to all matching regions + + **Key Attributes:** + - graph: ONNX GraphSurgeon representation of the model (clean copy) + - onnx_model: Original ONNX protobuf model + - regions: List of regions to optimize (populated by subclass) + - profiled_patterns: Patterns with tested schemes and performance results + - current_profile_region: Currently active region being tested + - current_profile_pattern_schemes: Currently active pattern schemes for the region + - config: Configuration for Q/DQ parameters and autotuning behavior + - pattern_cache: Pattern cache data for seeding schemes and tracking best results + + **Public API:** + - initialize(): Set up configuration and prepare for profiling + - set_profile_region(): Select region to profile and generate schemes + - generate(): Create new insertion scheme for current region + - export_onnx(): Export model with Q/DQ nodes (test scheme or best schemes) + - submit(): Record performance measurement for current scheme + - save_state(): Persist profiling results to file + - load_state(): Resume from previous session + + **Typical Workflow:** + 1. Subclass populates regions during/after initialization + 2. Measure baseline performance without Q/DQ + 3. For each region: set_profile_region() → generate() → export() → submit() + 4. Export final optimized model with best schemes + + **Note:** + Most users should use QDQAutotuner (subclass) which automatically searches for + regions based on common operations. Use QDQAutotunerBase directly only for + custom region identification strategies. + """ + + # ========================================================================= + # Initialization + # ========================================================================= + + def __init__(self, model: onnx.ModelProto | gs.Graph): + """Initialize the autotuner with an ONNX model. + + Creates a clean copy of the model graph and initializes internal state. + After construction, call initialize() to configure the autotuner, then + use a subclass strategy to populate regions (e.g., QDQAutotuner does this + automatically during initialize()). + + Args: + model: ONNX model (onnx.ModelProto) or graph (gs.Graph) to optimize. + A clean copy is created internally, leaving the original unchanged. 
+ + Raises: + TypeError: If model is neither onnx.ModelProto nor gs.Graph + + Example: + >>> # Most users should use QDQAutotuner subclass + >>> autotuner = QDQAutotuner(model) + >>> autotuner.initialize() + """ + # Store ONNX model representation (needed for graph copying) + if isinstance(model, onnx.ModelProto): + self.onnx_model = model + elif isinstance(model, gs.Graph): + self.onnx_model = gs.export_onnx(model) + else: + raise TypeError(f"Expected onnx.ModelProto or gs.Graph, got {type(model)}") + + # Create clean graph copy (modifications won't affect original) + self.graph = self._copy_graph() + self.graph.tensor_users_map = get_tensor_consumer_node_indices(self.graph) + + # Region state (populated by subclass during/after initialize) + self.regions: list[Region] = [] + self.current_profile_region: Region | None = None + + # Pattern profiling state + self.profiled_patterns: list[PatternSchemes] = [] + self.current_profile_pattern_schemes: PatternSchemes | None = None + + # Current insertion scheme index (for generating new schemes) + self.current_insertion_scheme_index: int | None = None + + # Configuration (set properly in initialize()) + self.config = Config() + + # Session state + self.initialized = False + self.baseline_latency_ms: float | None = None + + # Pattern cache data (set in initialize()) + self.pattern_cache: PatternCache | None = None + + logger.debug(f"Initialized autotuner with model type: {type(model).__name__}") + + def initialize( + self, config: Config | None = None, pattern_cache: PatternCache | None = None + ) -> None: + """Initialize autotuning session with configuration and pattern cache. + + Prepares the autotuner for profiling by setting configuration parameters + and optionally loading pattern cache data. This base method resets all profiling + state and sets up the pattern cache storage. + + **Note:** This base class does NOT populate regions. Subclasses (e.g., QDQAutotuner) + should override this method to add region discovery after calling super().initialize(). + + After initialization, populate self.regions (in subclass), then use + set_profile_region() to begin testing schemes. + + Args: + config: Autotuning configuration parameters. If None, uses default Config(). + Controls Q/DQ parameters, performance thresholds, and scheme generation. + pattern_cache: Optional PatternCache object for seeding with known-good schemes. + If None, creates a new empty pattern cache for tracking best schemes. + If provided, uses existing schemes to warm-start optimization. 
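A hedged sketch of the warm-start path for `pattern_cache` (the cache file name is illustrative; `PatternCache.load()` is the same loader used by `load_state()`):

    >>> from modelopt.onnx.quantization.autotune import Config, PatternCache
    >>> cache = PatternCache.load("previous_run_pattern_cache.yaml")
    >>> autotuner.initialize(Config(default_quant_type="int8"), pattern_cache=cache)
    >>> # Seeded schemes are re-profiled by generate() before new schemes are mutated.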
+ + Raises: + None (safe to call multiple times - will reset state each time) + + Example: + >>> # In subclass + >>> def initialize(self, config=None, pattern_cache=None): + >>> super().initialize(config, pattern_cache) + >>> # Add region discovery here + >>> self._search_regions() + """ + # Apply user configuration + if config is not None: + self.config = config + + # Set up pattern cache (for seeding schemes and tracking best results) + if pattern_cache is None: + # Create empty pattern cache with config settings + self.pattern_cache = PatternCache( + minimum_distance=self.config.pattern_cache_minimum_distance, + max_entries_per_pattern=self.config.pattern_cache_max_entries_per_pattern, + ) + else: + # Use provided pattern cache for warm-start + self.pattern_cache = pattern_cache + logger.debug( + f"Loaded pattern cache with {pattern_cache.num_patterns} patterns and " + f"{pattern_cache.total_schemes} schemes" + ) + + # Reset all profiling state (safe to call multiple times) + self.initialized = False + self.baseline_latency_ms = None + self.profiled_patterns.clear() + self.regions.clear() + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + + logger.info("Initializing autotuner") + logger.debug( + f"Configuration: q_scale={self.config.default_q_scale}, " + f"q_zero_point={self.config.default_q_zero_point}, quant_type={self.config.default_quant_type}" + ) + + # Mark as initialized + self.initialized = True + + def set_profile_region( + self, region: Region | None, commit: bool = True, per_region: bool = False + ) -> None: + """Set the target region for profiling and scheme generation. + + This method manages the profiling workflow: + 1. If commit=True: Saves current schemes to profiled_patterns + 2. Creates a RegionPattern from the new region's structure + 3. For pattern-based: tries to seed schemes from pattern cache if available + 4. Sets as current for generate() and submit() calls + + Pass region=None to clear the current profile target without setting a new one. + + **Workflow Pattern:** + - Call with commit=True (default) when moving between regions + - This commits previous region's results before starting new one + - Call with commit=False during initialization to avoid empty commits + + **Pattern Cache:** + - Automatically seeds from pattern cache if available for this pattern (pattern-based only) + - Remove profile result of seeded schemes (need profiling) + - If pattern already profiled, skips it + + Args: + region: The region to profile next (None to clear current target) + commit: If True, commit current schemes to profiled_patterns + before switching. Set to False during initialization. + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + + Example: + >>> # Pattern-based optimization (default) + >>> region = autotuner.regions[0] + >>> autotuner.set_profile_region(region) + >>> autotuner.generate() # Creates schemes for pattern + >>> # Per-region optimization + >>> region = autotuner.regions[1] + >>> autotuner.set_profile_region(region, per_region=True) + >>> autotuner.generate() # Creates schemes for this region only + """ + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + # Commit current pattern if requested + if commit: + if self.current_profile_pattern_schemes is not None: + num_schemes = len(self.current_profile_pattern_schemes.schemes) + best_scheme = self.current_profile_pattern_schemes.best_scheme + best_latency = best_scheme.latency_ms if best_scheme else float("inf") + + # Compute convergence metrics + samples_before_best, time_to_best = self._compute_convergence_metrics( + self.current_profile_pattern_schemes.schemes, best_scheme + ) + + logger.info( + f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" + ) + logger.debug( + f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" + ) + if samples_before_best is not None: + logger.debug(f"Convergence: best found at sample {samples_before_best}") + if time_to_best is not None: + logger.debug(f"Time to best: {time_to_best:.2f}s") + self.profiled_patterns.append(self.current_profile_pattern_schemes) + + if commit or region is None: + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + if region is None: + return + + # Validate region + if region not in self.regions: + raise ValueError(f"Region {region.id} not found in regions") + + # Create pattern for this region + region_pattern = RegionPattern.from_region(region, self.graph) + + # Check if pattern is already profiled - skip if so + if self._is_region_profiled(region): + logger.info(f"Skipping region {region.id} (pattern already profiled)") + logger.debug(f"Pattern signature: {region_pattern.signature}") + return + + # Try to seed from pattern cache + pattern_schemes = None + num_seeded = 0 + + if self.pattern_cache is not None: + cache_schemes = self.pattern_cache.get_pattern_schemes(region_pattern.signature) + + if cache_schemes is not None and len(cache_schemes.schemes) > 0: + # Create new PatternSchemes and seed it + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = region_pattern + + # Copy schemes from pattern cache and erase profile data + for cached_scheme in cache_schemes.schemes: + scheme_copy = copy.deepcopy(cached_scheme) + scheme_copy.latency_ms = float("inf") + scheme_copy.error = False + pattern_schemes.schemes.append(scheme_copy) + num_seeded += 1 + + logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache") + else: + logger.debug("No pattern cache entries for this region") + + # Create empty PatternSchemes if not seeded from pattern cache + if pattern_schemes is None: + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = region_pattern + logger.debug("Initialized with empty scheme collection") + + # Set current region + self.current_profile_region = region + + # Set pattern schemes + self.current_profile_pattern_schemes = pattern_schemes + mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh" + logger.info( + f"Profiling region {region.id} [pattern mode, level {region.get_level()}, " + f"size {region.get_size()}, {mode_info}]" + ) + logger.debug(f"Pattern signature: {region_pattern.signature}") + + def generate(self) -> int: + """Generate a new Q/DQ insertion scheme for the current pattern or region. + + Creates a new InsertionScheme by mutating the top-performing schemes: + 1. Checks if there are any cached schemes (error=False, latency_ms=inf) + 2. If cached schemes exist, picks one to re-profile + 3. Otherwise, generates a new scheme by mutation + 4. Selects a random scheme from the top 10 performers + 5. 
Mutates it by adding/removing insertion points + 6. Ensures the new scheme is unique (different from existing schemes) + 7. Adds the scheme to current_profile_pattern_schemes + + The generated scheme includes both: + - node_inputs: Q/DQ at node inputs + - child_region_inputs: Q/DQ at child region boundaries (COMPOSITE only) + + After calling generate(), use export_onnx() to create a test model and + submit() to record its performance. + + Returns: + Index of the newly generated scheme in the active schemes collection, + or -1 if unable to generate a unique scheme after 100 attempts + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + InvalidSchemeError: If no region is currently set for profiling + (call set_profile_region() first) + + Example: + >>> autotuner.set_profile_region(region) + >>> # Generate and test multiple schemes + >>> for i in range(10): + >>> scheme_idx = autotuner.generate() + >>> if scheme_idx < 0: + >>> print("No more unique schemes") + >>> break + >>> autotuner.export_onnx(f"test_{i}.onnx") + >>> latency = benchmark(f"test_{i}.onnx") + >>> autotuner.submit(latency) + """ + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + # Determine which schemes collection is active (mutually exclusive) + if self.current_profile_pattern_schemes is not None: + schemes_collection = self.current_profile_pattern_schemes + else: + raise InvalidSchemeError( + "No pattern or region selected. Call set_profile_region() first." + ) + + pattern_schemes = schemes_collection + + # Check if there are any cached schemes (from pattern cache or previous runs) + cached_schemes = [ + (idx, scheme) + for idx, scheme in enumerate(pattern_schemes.schemes) + if not scheme.is_profiled + ] + + if cached_schemes: + # Re-profile a cached scheme + scheme_index, cached_scheme_data = cached_schemes[0] + + num_node_points = len(cached_scheme_data.node_inputs) + num_region_composite_points = len(cached_scheme_data.child_region_inputs) + num_region_output_points = len(cached_scheme_data.region_outputs) + total_points = num_node_points + num_region_composite_points + num_region_output_points + logger.info( + f"Scheme #{scheme_index + 1}: profiling cached scheme ({total_points} Q/DQ points)" + ) + logger.debug( + f"Cached scheme breakdown: {num_node_points} node input, " + f"{num_region_composite_points} region composite, " + f"{num_region_output_points} region output points ({len(cached_schemes)} cached schemes remaining)" + ) + + self.current_insertion_scheme_index = scheme_index + return self.current_insertion_scheme_index + + # Generate a new scheme by mutation + # Collect known scheme hashes to avoid duplicates + known_schemes = {scheme.hash for scheme in pattern_schemes.schemes} + logger.debug(f"Generating new scheme ({len(pattern_schemes.schemes)} schemes exist)") + + max_attempts = getattr(self.config, "maximum_generation_attempts", 100) + + for attempts in range(max_attempts): + new_scheme = self._generate_next_insertion_sample() + + if new_scheme.hash not in known_schemes and not new_scheme.error: + # Found a unique, valid scheme + pattern_schemes.schemes.append(new_scheme) + scheme_index = len(pattern_schemes.schemes) - 1 + + num_node_points = len(new_scheme.node_inputs) + num_region_composite_points = len(new_scheme.child_region_inputs) + num_region_output_points = len(new_scheme.region_outputs) + total_points = ( + num_node_points + num_region_composite_points + num_region_output_points + ) + 
logger.info( + f"Scheme #{scheme_index + 1}: generated new scheme ({total_points} Q/DQ points)" + ) + logger.debug( + f"Scheme breakdown: {num_node_points} node input, " + f"{num_region_composite_points} region composite, " + f"{num_region_output_points} region output points " + f"(hash: {new_scheme.hash[:16]}..., attempts: {attempts + 1})" + ) + + self.current_insertion_scheme_index = scheme_index + return self.current_insertion_scheme_index + + # Failed to generate unique scheme + logger.warning(f"Could not generate unique scheme after {max_attempts} attempts") + return -1 + + def export_onnx( + self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False + ) -> bytes: + """Export ONNX model with Q/DQ nodes inserted according to tested schemes. + + This method creates a modified version of the model by: + 1. For each region, finding the matching pattern + 2. Applying the best scheme for profiled patterns + 3. Applying the current scheme for the active profile pattern + 4. Resolving pattern-relative insertion points to actual tensor names + 5. Inserting Q/DQ pairs at the resolved locations + 6. Converting to FP8 if needed (always creates INT8 first, then converts) + + **Scheme Selection Logic:** + - Profiled patterns: Uses best_scheme (lowest latency) + - Current profile pattern: Uses most recently generated scheme + - Unmatched patterns: No Q/DQ insertion + + **Parent-Child Coordination:** + - If a region's parent is profiled, skip inserting Q/DQ at region inputs + - Parent will handle boundary Q/DQ via CompositeRegionInsertionPoints + - Prevents duplicate Q/DQ at region boundaries + + Args: + output_path: Optional file path where the modified ONNX model will be saved. + If None, the model is not saved to disk and only bytes are returned. + insert_qdq: If True, insert Q/DQ nodes. If False, export unmodified model + (useful for baseline measurements) + + Returns: + bytes: Serialized ONNX model as bytes + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + + Example: + >>> # Export baseline (no Q/DQ) to file + >>> model_bytes = autotuner.export_onnx("baseline.onnx", insert_qdq=False) + >>> # Export with current test scheme to file + >>> autotuner.generate() + >>> model_bytes = autotuner.export_onnx("test.onnx", insert_qdq=True) + >>> # Export only to bytes without saving to file + >>> model_bytes = autotuner.export_onnx(None, insert_qdq=True) + >>> # Export final optimized model (all best schemes) + >>> model_bytes = autotuner.export_onnx("optimized.onnx", insert_qdq=True) + """ + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + output_desc = output_path if output_path is not None else "" + logger.debug( + f"Exporting model to {output_desc} (insert_qdq={insert_qdq}, " + f"regions={len(self.regions)}, profiled_patterns={len(self.profiled_patterns)})" + ) + + # Save original quant type for potential FP8 conversion + original_quant_type = self.config.default_quant_type + needs_fp8_conversion = insert_qdq and original_quant_type == "fp8" + + # Temporarily set quant type to int8 if FP8 is requested + if needs_fp8_conversion: + logger.debug("FP8 conversion: creating INT8 model first") + self.config.default_quant_type = "int8" + + # ===================================================================== + # Collect Q/DQ Insertion Points from Profiled Schemes + # ===================================================================== + # For each region, find the matching pattern and apply its best scheme. + # Pattern matching uses structural signatures, so the same pattern can + # apply to multiple regions with identical structure. + resolved_insertion_points = set() + + if insert_qdq: + logger.debug(f"Resolving Q/DQ insertion points from {len(self.regions)} regions") + matched_regions = 0 + + for region in self.regions: + # Create pattern signature for this region + pattern = RegionPattern.from_region(region, self.graph) + logger.debug(f"Region {region.id} (level {region.level})") + logger.debug(f" → Pattern signature: {pattern.signature}") + + current_scheme = None + for pattern_index, pattern_schemes in enumerate(self.profiled_patterns): + if pattern_schemes.pattern == pattern: + current_scheme = pattern_schemes.best_scheme + if current_scheme: + logger.debug( + f" → Matched profiled pattern #{pattern_index} " + f"(latency={current_scheme.latency_ms:.3f} ms)" + ) + else: + logger.debug( + f" → Matched profiled pattern #{pattern_index} but no valid schemes" + ) + break + + if current_scheme is None: + if ( + self.current_profile_pattern_schemes is None + or pattern != self.current_profile_pattern_schemes.pattern + ): + pass + elif best: + current_scheme = self.current_profile_pattern_schemes.best_scheme + else: + scheme_index = self.current_insertion_scheme_index + if scheme_index is None: + pass + else: + assert scheme_index < len( + self.current_profile_pattern_schemes.schemes + ), f"Invalid scheme index: {scheme_index}" + current_scheme = self.current_profile_pattern_schemes.schemes[ + scheme_index + ] + logger.debug(f" → Using current pattern scheme #{scheme_index}") + + if current_scheme is None and self.pattern_cache is not None: + pattern_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) + if pattern_schemes is not None: + schemes = pattern_schemes.schemes + if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: + current_scheme = schemes[0] + logger.debug(" → Using imported pattern from cache") + + # ------------------------------------------------------------- + # No matching pattern: skip this region + # ------------------------------------------------------------- + if current_scheme is None: + logger.debug(" → No scheme available, skipping") + continue + + # Remove these tensors if they were already added by profiled patterns/regions + # Current profile pattern has higher priority (more recent results) + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + assert full_insertion_scheme is not None + all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) + assert isinstance(all_region_ips, set) # matches returns set 
when scheme provided + excluded_tensors = all_region_ips - resolved_insertion_points + if excluded_tensors: + logger.debug( + f" → Excluded {len(excluded_tensors)} overlapping insertion points" + ) + + resolved_insertion_points.difference_update(all_region_ips) + + # ------------------------------------------------------------- + # Resolve pattern-relative insertion points to tensor names + # ------------------------------------------------------------- + # Pattern insertion points are relative (e.g., "node 2, input 0"). + # Resolve them to actual tensor names for this specific region. + new_ips = pattern.matches(region, self.graph, current_scheme) + assert isinstance(new_ips, set) # matches returns set when scheme provided + if new_ips: + resolved_insertion_points.update(new_ips) + matched_regions += 1 + logger.debug(f" → Added {len(new_ips)} insertion points") + + logger.debug( + f"Matched {matched_regions}/{len(self.regions)} regions, " + f"total {len(resolved_insertion_points)} unique insertion points" + ) + + # ===================================================================== + # Create Modified Graph with Q/DQ Nodes + # ===================================================================== + unique_tensors = len(resolved_insertion_points) + logger.debug(f"Inserting {unique_tensors} Q/DQ pairs into graph") + + # Create fresh graph copy (preserves original) + graph_copy = self._copy_graph() + + # Insert Q/DQ pairs at all collected tensor locations + if insert_qdq and resolved_insertion_points: + self._insert_qdq_at_tensors(graph_copy, resolved_insertion_points) + + # ===================================================================== + # Export to ONNX Format + # ===================================================================== + logger.debug("Serializing to ONNX format") + model = gs.export_onnx(graph_copy) + + # --------------------------------------------------------------------- + # Fix INT8 Zero-Point Initializers + # --------------------------------------------------------------------- + # ONNX requires INT8 zero_point to use int32_data field (4-byte aligned) + # instead of raw_data. This is a quirk of the ONNX format and required + # for correct INT8 and FP8 conversion. + if insert_qdq and resolved_insertion_points: + self._fix_zero_point_initializers(model) + + # --------------------------------------------------------------------- + # Convert INT8 to FP8 if Requested + # --------------------------------------------------------------------- + # FP8 quantization is a two-step process: + # 1. Create INT8 Q/DQ model (all tools understand INT8) + # 2. Convert INT8 to FP8 (specialized conversion utility) + # This approach ensures compatibility with ONNX tooling that may not + # natively support FP8 yet. 
+ if needs_fp8_conversion: + logger.debug("Converting INT8 to FP8") + model = int8_to_fp8(model) + + # Restore original quantization type in config + self.config.default_quant_type = original_quant_type + + # Serialize to bytes + model_bytes = model.SerializeToString() + + # Save to file if output_path is provided + quant_type_str = "baseline" + if insert_qdq: + quant_type_str = f"{original_quant_type.upper()}" if needs_fp8_conversion else "INT8" + if output_path is not None: + onnx.save(model, output_path) + logger.info( + f"Exported {quant_type_str} model with {unique_tensors} Q/DQ pairs → {output_path}" + ) + else: + logger.info(f"Exported {quant_type_str} model with {unique_tensors} Q/DQ pairs") + + return model_bytes + + def submit(self, latency_ms: float, success: bool = True) -> None: + """Submit performance measurement for the most recently generated scheme. + + This method records the measured latency and manages the optimization state: + + **First Submission (Baseline):** + - Sets baseline_latency_ms for speedup calculations + - Does not modify any schemes + + **Subsequent Submissions:** + - Updates the most recently generated scheme's latency_ms + - Sets scheme's error flag based on success parameter + - Computes speedup relative to baseline (if successful) + - Sorts all schemes by latency (best schemes first) + - Logs results if config.verbose is True + + **Scheme Sorting:** + - Schemes are sorted by latency_ms (ascending) + - Unmeasured schemes (latency_ms = 0) go to the end + - This ensures best_scheme property returns optimal configuration + + **Optimization Mode:** + - Automatically detects whether pattern-based or per-region mode is active + - Commits to the appropriate collection (profiled_patterns) + - Mode is determined by set_profile_region(per_region=True/False) + + Args: + latency_ms: Measured latency in milliseconds (must be > 0) + success: Whether the measurement succeeded. If False, sets scheme.error=True, + logs a warning, and skips speedup calculation. + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + InvalidSchemeError: If no pattern or region is set, or no schemes have been generated + + Example: + >>> # Submit baseline + >>> autotuner.export_onnx("baseline.onnx", insert_qdq=False) + >>> autotuner.submit(benchmark("baseline.onnx")) # Sets baseline + >>> # Submit test measurements + >>> autotuner.set_profile_region(region) + >>> autotuner.generate() + >>> autotuner.export_onnx("test.onnx") + >>> latency = benchmark("test.onnx") + >>> autotuner.submit(latency) # Records to profiled_patterns + """ + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + # Handle baseline (first measurement establishes baseline) + if self.baseline_latency_ms is None: + self.baseline_latency_ms = latency_ms + logger.info(f"Baseline latency: {latency_ms:.3f} ms") + logger.debug("Baseline set for speedup calculations") + return + + # Determine which schemes collection is active (mutually exclusive) + if self.current_profile_pattern_schemes is not None: + schemes_collection = self.current_profile_pattern_schemes + else: + raise InvalidSchemeError( + "No pattern or region selected. Call set_profile_region() first." + ) + + # Check if there are schemes + if not schemes_collection.schemes: + raise InvalidSchemeError("No schemes available. 
Call generate() first.") + + pattern_schemes = schemes_collection + + # Update the scheme's latency + # Use current_insertion_scheme_index if set (handles both new and re-profiled schemes) + if ( + hasattr(self, "current_insertion_scheme_index") + and self.current_insertion_scheme_index is not None + ): + scheme_index = self.current_insertion_scheme_index + if scheme_index >= len(pattern_schemes.schemes): + raise InvalidSchemeError(f"Invalid scheme index: {scheme_index}") + scheme = pattern_schemes.schemes[scheme_index] + else: + # Fallback: use the last scheme (for backward compatibility) + scheme = pattern_schemes.schemes[-1] + scheme_index = len(pattern_schemes.schemes) - 1 + + scheme.latency_ms = latency_ms + scheme.error = not success + scheme.profile_timestamp = datetime.now(timezone.utc).isoformat() + # Display index is 1-based + display_index = scheme_index + 1 + + if not success: + logger.warning( + f"Scheme #{display_index}: measurement failed (latency={latency_ms:.3f} ms)" + ) + logger.debug("Marking scheme with error flag") + return + + # Compute speedup + speedup = self.baseline_latency_ms / latency_ms if latency_ms > 0 else 0.0 + + # Log results + logger.info(f"Scheme #{display_index}: {latency_ms:.3f} ms ({speedup:.2f}x speedup)") + logger.debug(f"Compared to baseline: {self.baseline_latency_ms:.3f} ms") + + # Sort schemes by latency (best first) + # Unmeasured schemes (latency_ms <= 0) go to the end + old_best = ( + pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf") + ) + pattern_schemes.schemes.sort( + key=lambda s: s.latency_ms if s.latency_ms > 0 else float("inf") + ) + new_best = ( + pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf") + ) + + if new_best < old_best: + new_speedup = self.baseline_latency_ms / new_best if new_best > 0 else 0.0 + logger.info(f" ★ New best: {new_best:.3f} ms ({new_speedup:.2f}x speedup)") + logger.debug(f"Previous best: {old_best:.3f} ms") + + # Update pattern cache with best schemes (only for pattern-based mode) + if self.current_profile_pattern_schemes is not None and self.pattern_cache is not None: + self.pattern_cache.add_pattern_schemes(pattern_schemes) + logger.debug( + f"Pattern cache updated: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + # ========================================================================= + # State Management + # ========================================================================= + + def save_state(self, output_path: str) -> None: + """Save complete autotuner state to a YAML file for later reuse. + + Serializes all optimization results including: + - Baseline latency measurement + - All profiled patterns with their signatures + - All generated schemes with insertion points and latencies + - Configuration parameters + - Current profiling state + + Also saves pattern cache to a separate file with the suffix "_pattern_cache.yaml" + containing only the best schemes per pattern (if any patterns were profiled). + + The saved state can be loaded with load_state() to: + - Resume an interrupted optimization session + - Reuse results on a similar model + - Analyze optimization history + + **Note:** The state file contains pattern signatures and performance data, + but not the actual ONNX model or graph structure. + + Args: + output_path: File path where the YAML state file will be written. 
+ Pattern cache will be saved to _pattern_cache.yaml + + Example: + >>> # Save after profiling some regions + >>> autotuner.save_state("checkpoint.yaml") + >>> # Creates: checkpoint.yaml and checkpoint_pattern_cache.yaml + >>> # Save final results + >>> autotuner.save_state("final_state.yaml") + >>> # Creates: final_state.yaml and final_state_pattern_cache.yaml + """ + # Save current_profile_pattern_schemes as pattern signature + current_pattern_sig = None + if self.current_profile_pattern_schemes is not None: + current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature + + state = { + "baseline_latency_ms": self.baseline_latency_ms, + "current_profile_pattern_schemes_signature": current_pattern_sig, + "config": { + "default_q_scale": self.config.default_q_scale, + "default_q_zero_point": self.config.default_q_zero_point, + "default_quant_type": self.config.default_quant_type, + "verbose": self.config.verbose, + }, + "patterns": [pattern_schemes.to_dict() for pattern_schemes in self.profiled_patterns], + } + + with open(output_path, "w") as f: + yaml.dump(state, f, default_flow_style=False, sort_keys=False) + + num_patterns = len(self.profiled_patterns) + total_schemes = sum(len(p.schemes) for p in self.profiled_patterns) + logger.info( + f"Saved state → {output_path} ({num_patterns} patterns, {total_schemes} schemes)" + ) + logger.debug(f"State: baseline={self.baseline_latency_ms:.3f} ms") + + # Save pattern cache to separate file if it has patterns + if self.pattern_cache is not None and self.pattern_cache.num_patterns > 0: + # Generate pattern cache path: _pattern_cache.yaml + base_path, ext = os.path.splitext(output_path) + cache_path = f"{base_path}_pattern_cache{ext}" + + self.pattern_cache.save(cache_path) + logger.info(f"Saved pattern cache → {cache_path}") + logger.debug( + f"Cache: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def load_state(self, input_path: str) -> None: + """Load autotuner state from a previously saved YAML file. + + Restores optimization results from a previous session: + 1. Matches saved patterns to current model's patterns by signature + 2. Loads all schemes with their insertion points and latencies (including unmeasured ones) + 3. Restores baseline latency and configuration + + **Requirements:** + - Autotuner must be initialized first (model loaded, regions built) + - Saved patterns must match current model's structure + - Pattern matching is done by signature, not by index + + **Use Cases:** + - Resume interrupted optimization + - Apply previous results to similar model + - Start from checkpoint instead of scratch + + **Compatibility:** + - Skips patterns from saved state that don't match current model + - Warns about mismatched pattern sizes + - Backward compatible with older state file formats + + Args: + input_path: File path to the YAML state file to load + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + FileNotFoundError: If the input_path doesn't exist + + Example: + >>> # Load and resume from checkpoint + >>> autotuner = QDQAutotunerBase(model) + >>> autotuner.initialize() + >>> autotuner.load_state("checkpoint.yaml") + >>> # Continue profiling where you left off + >>> # Reuse results on similar model + >>> autotuner2 = QDQAutotunerBase(similar_model) + >>> autotuner2.initialize() + >>> autotuner2.load_state("final_state.yaml") + """ + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. 
Call initialize() first." + ) + + with open(input_path) as f: + state = yaml.safe_load(f) + + # Load baseline latency + if state.get("baseline_latency_ms") is not None: + self.baseline_latency_ms = state["baseline_latency_ms"] + logger.debug(f"Baseline latency: {self.baseline_latency_ms:.3f} ms") + + # Load config (optional, merge with existing) + if "config" in state: + config_data = state["config"] + if "default_q_scale" in config_data: + self.config.default_q_scale = config_data["default_q_scale"] + if "default_q_zero_point" in config_data: + self.config.default_q_zero_point = config_data["default_q_zero_point"] + if "default_quant_type" in config_data: + self.config.default_quant_type = config_data["default_quant_type"] + if "verbose" in config_data: + self.config.verbose = config_data["verbose"] + logger.debug(f"Config merged: quant_type={self.config.default_quant_type}") + + # Load profiled patterns + if "patterns" in state: + num_loaded_patterns = 0 + num_loaded_schemes = 0 + + for pattern_data in state["patterns"]: + try: + pattern_schemes = PatternSchemes.from_dict(pattern_data) + + if pattern_schemes.schemes: # Only add if it has schemes + self.profiled_patterns.append(pattern_schemes) + num_loaded_patterns += 1 + num_loaded_schemes += len(pattern_schemes.schemes) + else: + logger.debug( + f"Skipped empty pattern {pattern_schemes.pattern_signature[:16]}..." + ) + + except Exception as e: # noqa: PERF203 + logger.warning(f"Failed to load pattern: {e}") + continue + + logger.info( + f"Loaded state from {input_path} ({num_loaded_patterns} patterns, " + f"{num_loaded_schemes} schemes)" + ) + + # Try to load pattern cache if it exists + base_path, ext = os.path.splitext(input_path) + cache_path = f"{base_path}_pattern_cache{ext}" + + if os.path.exists(cache_path): + try: + loaded_cache = PatternCache.load(cache_path) + if self.pattern_cache is None: + self.pattern_cache = loaded_cache + else: + # Merge with existing pattern cache + for pattern_schemes in loaded_cache.pattern_schemes: + self.pattern_cache.add_pattern_schemes(pattern_schemes) + logger.info( + f"Loaded pattern cache from {cache_path} ({loaded_cache.num_patterns} patterns, " + f"{loaded_cache.total_schemes} schemes)" + ) + except Exception as e: + logger.warning(f"Failed to load pattern cache: {e}") + else: + logger.debug(f"No pattern cache file at {cache_path}") + + def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: + """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache. + + Analyzes the current model's regions against the provided quantized tensors + to extract Q/DQ insertion patterns. For each region, creates a pattern cache + entry that captures which insertion points correspond to the quantized tensors. + These cached patterns can then be used as seeds for future autotuning sessions. + + **Use Cases:** + - Import quantization strategy from an existing quantized model + - Seed pattern cache with known-good configurations + - Transfer quantization patterns across similar models + - Bootstrap autotuning with expert knowledge + + **Workflow:** + 1. Convert input to set for efficient lookup + 2. Iterate through all discovered regions (both LEAF and COMPOSITE) + 3. For each region, call pattern_cache.add_pattern_from_region() + 4. 
Pattern cache automatically handles deduplication and merging + + **Requirements:** + - Autotuner must be initialized first (regions must be discovered) + - Quantized tensors should correspond to actual tensor names in the graph + + Args: + quantized_tensors: Set or list of tensor names that are quantized + (i.e., tensors that have Q/DQ nodes applied to them) + + Raises: + AutotunerNotInitializedError: If initialize() hasn't been called + + Example: + >>> # Import from an existing quantized model + >>> import onnx + >>> # Load quantized model and extract quantized tensor names + >>> # (e.g., by finding first inputs of QuantizeLinear nodes) + >>> quantized_model = onnx.load("quantized_model.onnx") + >>> quantized_tensors = set() + >>> for node in quantized_model.graph.node: + ... if node.op_type == "QuantizeLinear": + ... quantized_tensors.add(node.input[0]) + >>> # Initialize autotuner on new model and import patterns + >>> autotuner = QDQAutotuner(new_model) + >>> autotuner.initialize() + >>> autotuner.import_insertion_points(quantized_tensors) + >>> # Pattern cache now contains learned insertion patterns + >>> print(f"Imported {autotuner.pattern_cache.num_patterns} patterns") + >>> autotuner.pattern_cache.save("imported_patterns.yaml") + """ + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + # Convert to set for efficient lookup + if isinstance(quantized_tensors, list): + quantized_tensors = set(quantized_tensors) + + logger.info(f"Importing insertion points from {len(quantized_tensors)} quantized tensors") + logger.debug(f"Processing {len(self.regions)} regions") + + if self.pattern_cache is None: + logger.warning("Pattern cache not initialized, skipping import") + return + + # Track statistics + patterns_before = self.pattern_cache.num_patterns + schemes_before = self.pattern_cache.total_schemes + + # Process all regions (both LEAF and COMPOSITE) + for region in self.regions: + self.pattern_cache.add_pattern_from_region(region, self.graph, quantized_tensors) + + # Log results + patterns_added = self.pattern_cache.num_patterns - patterns_before + schemes_added = self.pattern_cache.total_schemes - schemes_before + + logger.info( + f"Import complete: {patterns_added} patterns, {schemes_added} schemes added to cache" + ) + logger.debug( + f"Total cache: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + # ========================================================================= + # Private Helper Methods + # ========================================================================= + + def _compute_convergence_metrics( + self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None + ) -> tuple[int | None, float | None]: + """Compute convergence metrics for a collection of schemes. + + Analyzes when the best scheme was discovered during the profiling process + by sorting schemes by their profile timestamps and finding the position + of the best scheme. 
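For example (illustrative numbers): if three schemes were profiled at t = 0 s, 20 s, and 45 s, and the scheme measured at 20 s turned out to be the best, the method returns `samples_before_best = 1` (0-based position in timestamp order) and `time_to_best = 20.0` seconds.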
+ + Args: + schemes: List of insertion schemes with profile timestamps + best_scheme: The best performing scheme (lowest latency) + + Returns: + Tuple of (samples_before_best, time_to_best) where: + - samples_before_best: Number of samples tested before finding best (0-based index) + - time_to_best: Time in seconds from first sample to best sample + Both values are None if metrics cannot be computed (e.g., missing timestamps) + """ + samples_before_best = None + time_to_best = None + + if not best_scheme or not best_scheme.profile_timestamp: + return samples_before_best, time_to_best + + # Get schemes with timestamps, sorted by timestamp + schemes_with_time = [s for s in schemes if s.profile_timestamp is not None] + + if not schemes_with_time: + return samples_before_best, time_to_best + + from datetime import datetime + + schemes_with_time.sort(key=lambda s: s.profile_timestamp or "") + + # Find position of best scheme in time-sorted list + try: + best_position = next( + i for i, s in enumerate(schemes_with_time) if s.hash == best_scheme.hash + ) + samples_before_best = best_position + + # Compute time difference + first_ts = schemes_with_time[0].profile_timestamp + best_ts = best_scheme.profile_timestamp + assert first_ts is not None and best_ts is not None + first_timestamp = datetime.fromisoformat(first_ts) + best_timestamp = datetime.fromisoformat(best_ts) + time_to_best = (best_timestamp - first_timestamp).total_seconds() + except (StopIteration, ValueError): + pass + + return samples_before_best, time_to_best + + def _is_region_profiled(self, region: Region) -> bool: + """Check if a region's pattern has already been fully profiled.""" + + def match_pattern(pattern: PatternSchemes, region: Region) -> bool: + """Check if a pattern matches a region.""" + if pattern.pattern is None or not pattern.pattern.matches(region, self.graph): + return False + return not any(not scheme.is_profiled for scheme in pattern.schemes) + + return any(match_pattern(pattern, region) for pattern in self.profiled_patterns) + + # --- Scheme Generation --- + + def _mutate_insertion_points( + self, base_points, all_points, point_type: str, max_mutations: int + ) -> list: + """Mutate a set of insertion points by adding, removing, or both. 
+ + Args: + base_points: Set of tuples representing current insertion points + all_points: List of all possible insertion point objects + point_type: Type of insertion points (for logging) + max_mutations: Maximum number of mutations per operation + + Returns: + List of mutated insertion point objects + """ + current_points = set(base_points) + initial_count = len(current_points) + + # Randomly choose mutation type + mutation_type = random.choice(["add", "remove", "both"]) + + # Add points + if mutation_type in ["add", "both"] and len(current_points) < len(all_points): + # Get point keys based on type + all_keys: set[tuple] = set() + if point_type == "node input points": + all_keys = {(p.node_index, p.input_index) for p in all_points} + elif point_type == "region composite points": + all_keys = {(p.region_index, p.input_index) for p in all_points} + elif point_type == "region output points": + all_keys = {(p.region_index, p.node_index, p.output_index) for p in all_points} + + available_keys = all_keys - current_points + if available_keys: + max_add = min(max_mutations, len(available_keys)) + num_to_add = random.randint(1, max_add) + to_add = random.sample(list(available_keys), num_to_add) + current_points.update(to_add) + + # Remove points + if mutation_type in ["remove", "both"] and current_points: + max_remove = min(max_mutations, len(current_points)) + num_to_remove = random.randint(1, max_remove) if len(current_points) > 1 else 1 + num_to_remove = min(num_to_remove, len(current_points)) + to_remove = random.sample(list(current_points), num_to_remove) + for p in to_remove: + current_points.discard(p) + + logger.debug( + f"Mutated {point_type}: {initial_count} → {len(current_points)} ({mutation_type})" + ) + + # Convert back to insertion point objects based on type + if point_type == "node input points": + return [p for p in all_points if (p.node_index, p.input_index) in current_points] + elif point_type == "region composite points": + return [p for p in all_points if (p.region_index, p.input_index) in current_points] + elif point_type == "region output points": + return [ + p + for p in all_points + if (p.region_index, p.node_index, p.output_index) in current_points + ] + else: + return [] + + def _generate_next_insertion_sample(self) -> InsertionScheme: + """Generate a new insertion scheme by mutating top performers. + + This is the core scheme generation algorithm: + 1. Identifies top schemes by latency + 2. Randomly selects one as the base + 3. Mutates node input insertion points (add, remove, or both) + 4. Mutates region composite insertion points (child boundaries) + 5. Mutates region output insertion points + 6. Returns new unique scheme + + **Mutation Strategy:** + - Node input points: Add/remove 1-3 insertion points + - Region composite points: Add/remove 1-3 boundary points + - Region output points: Add/remove 1-3 output points + - Mutation type chosen randomly: 'add', 'remove', or 'both' + + **Baseline Case:** + If no schemes exist yet, returns an empty baseline scheme. + + Returns: + New InsertionScheme with mutated insertion points. + Returns empty scheme if no region is set or no candidates exist. 
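To make the top-performer pool size concrete (using the getattr fallbacks in this method, `top_percent_to_mutate = 0.1` and `minimum_schemes_to_mutate = 1`): with 25 measured schemes the pool is `max(int(25 * 0.1), min(1, 25)) = 2`, so the mutation base is drawn at random from the two fastest schemes.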
+ """ + # Validate current profile region is set + if self.current_profile_region is None: + return InsertionScheme() + + # Determine which schemes collection is active (mutually exclusive) + if self.current_profile_pattern_schemes is not None: + schemes_collection = self.current_profile_pattern_schemes + else: + return InsertionScheme() + + region = self.current_profile_region + pattern_schemes = schemes_collection + + # Get the pattern + pattern = None + if isinstance(schemes_collection, PatternSchemes): + pattern = schemes_collection.pattern + if pattern is None: + return InsertionScheme() + # Get all possible insertion points for this region + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + + logger.debug( + f"Available insertion points: {len(full_insertion_scheme.node_inputs)} node input, " + f"{len(full_insertion_scheme.child_region_inputs)} region composite, " + f"{len(full_insertion_scheme.region_outputs)} region output" + ) + + # Get top-performing schemes + top_percent = getattr(self.config, "top_percent_to_mutate", 0.1) + minimum_schemes = getattr(self.config, "minimum_schemes_to_mutate", 1) + + # Filter measured schemes + measured_schemes = [s for s in pattern_schemes.schemes if s.latency_ms > 0 and not s.error] + measured_schemes.sort(key=lambda s: s.latency_ms) + + num_top_schemes = max( + int(len(measured_schemes) * top_percent), min(minimum_schemes, len(measured_schemes)) + ) + top_schemes = measured_schemes[:num_top_schemes] + + # Return empty baseline if no schemes exist + if len(top_schemes) == 0: + logger.debug("No measured schemes yet, generating baseline (empty) scheme") + return InsertionScheme() + + # Select base scheme from top performers + base_scheme = random.choice(top_schemes) + total_base_points = ( + len(base_scheme.node_inputs) + + len(base_scheme.child_region_inputs) + + len(base_scheme.region_outputs) + ) + logger.debug( + f"Mutating from top {len(top_schemes)} schemes: " + f"selected base with {total_base_points} points (latency={base_scheme.latency_ms:.3f} ms)" + ) + + # Create new scheme + scheme = InsertionScheme() + + max_mutations = getattr(self.config, "maximum_mutations", 3) + + # Mutate node input insertion points + base_node_points = {(p.node_index, p.input_index) for p in base_scheme.node_inputs} + scheme.node_inputs = self._mutate_insertion_points( + base_node_points, full_insertion_scheme.node_inputs, "node input points", max_mutations + ) + + # Mutate region composite insertion points + base_region_composite_points = { + (p.region_index, p.input_index) for p in base_scheme.child_region_inputs + } + scheme.child_region_inputs = self._mutate_insertion_points( + base_region_composite_points, + full_insertion_scheme.child_region_inputs, + "region composite points", + max_mutations, + ) + + # Mutate region output insertion points + base_region_output_points = { + (p.region_index, p.node_index, p.output_index) for p in base_scheme.region_outputs + } + scheme.region_outputs = self._mutate_insertion_points( + base_region_output_points, + full_insertion_scheme.region_outputs, + "region output points", + max_mutations, + ) + + return scheme + + # --- Graph Manipulation --- + + def _copy_graph(self) -> gs.Graph: + """Create an independent copy of the computation graph. + + Exports the original model to ONNX and imports it back to create + a fresh graph instance. This ensures modifications don't affect + the original graph. 
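+        Concretely, the copy is produced by re-importing `self.onnx_model` with
+        `gs.import_onnx` and topologically sorting the result, so subsequent
+        modifications operate on a fresh set of node objects.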
+ + Returns: + New gs.Graph instance with identical structure to the original + """ + new_graph = gs.import_onnx(self.onnx_model) + new_graph.toposort() + return new_graph + + def _get_quant_dtype(self, quant_type: str) -> np.dtype: + """Get numpy dtype for quantization type. + + Args: + quant_type: Quantization type string ("int8", "fp8") + + Returns: + Numpy dtype for the quantization type + + Note: + FP8 support requires numpy >= 2.0. If not available, falls back to a + compatible representation. + """ + # Handle FP8 with version check + if quant_type == "fp8": + try: + # Try to get FP8 dtype (numpy >= 2.0) + return np.dtype(np.float8_e4m3fn) + except (AttributeError, TypeError): + logger.warning( + "FP8 dtype not available (requires numpy >= 2.0), " + "using uint8 as placeholder. Note: This may not produce " + "correct results without proper FP8 support." + ) + return np.uint8 + + dtype_map = { + "int8": np.int8, + "uint8": np.uint8, + } + + if quant_type not in dtype_map: + logger.warning(f"Unknown quantization type '{quant_type}', defaulting to int8") + return np.int8 + + return dtype_map[quant_type] + + def _get_dq_output_dtype(self, dtype_str: str) -> np.dtype: + """Convert DQ dtype string to numpy dtype. + + Args: + dtype_str: Dtype string ("float16", "float32", "bfloat16") + + Returns: + Numpy dtype for the DQ output type + """ + dtype_map = { + "float16": np.float16, + "float32": np.float32, + } + + # Handle bfloat16 if available + if hasattr(np, "bfloat16"): + dtype_map["bfloat16"] = np.bfloat16 + + if dtype_str not in dtype_map: + logger.warning(f"Unknown DQ dtype '{dtype_str}', defaulting to float32") + return np.float32 + + return dtype_map[dtype_str] + + def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]: + """Build mapping from tensor names to tensor objects. + + Args: + graph: Graph to extract tensors from + + Returns: + Dictionary mapping tensor names to tensor objects (Variables or Constants) + """ + tensor_map = {} + + # Map node outputs (Variables) + for node in graph.nodes: + for output in node.outputs: + if hasattr(output, "name") and output.name: + tensor_map[output.name] = output + + # Map graph inputs (Variables) + for input_tensor in graph.inputs: + if hasattr(input_tensor, "name") and input_tensor.name: + tensor_map[input_tensor.name] = input_tensor + + # Map initializers/constants + for node in graph.nodes: + for input_tensor in node.inputs: + if ( + isinstance(input_tensor, gs.Constant) + and hasattr(input_tensor, "name") + and input_tensor.name + ): + tensor_map[input_tensor.name] = input_tensor + + return tensor_map + + def _get_tensor_metadata( + self, tensor: gs.Tensor, is_constant: bool + ) -> tuple[tuple | None, np.dtype]: + """Extract shape and dtype metadata from a tensor. + + Args: + tensor: Tensor to extract metadata from + is_constant: Whether the tensor is a Constant + + Returns: + Tuple of (shape, dtype) where shape may be None if unknown + """ + default_dtype = self._get_dq_output_dtype(self.config.default_dq_dtype) + if is_constant and hasattr(tensor, "values") and tensor.values is not None: + return tensor.values.shape, tensor.values.dtype + elif hasattr(tensor, "shape"): + dtype = ( + tensor.dtype + if hasattr(tensor, "dtype") and tensor.dtype is not None + else default_dtype + ) + return tensor.shape, dtype + else: + return None, default_dtype + + def _fix_zero_point_initializers(self, model: onnx.ModelProto) -> None: + """Fix INT8 zero_point initializers to use int32_data instead of raw_data. 
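+
+        A minimal sketch of the rewrite (field names from the ONNX TensorProto spec;
+        the value shown is illustrative):
+
+            raw_data=b"\x00"  ->  int32_data=[0]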
+ + For INT8 tensors, ONNX stores the data in int32_data field with 4-byte alignment, + not in raw_data. This is needed because int8_to_fp8 expects zero_point.int32_data + to be populated. + + Args: + model: ONNX model to fix + """ + fixed_count = 0 + + for initializer in model.graph.initializer: + # Check if this is a zero_point tensor (q_zp_ or dq_zp_) + if ( + "_zp_" in initializer.name + and initializer.data_type == onnx.TensorProto.INT8 + and len(initializer.raw_data) > 0 + and len(initializer.int32_data) == 0 + ): + # Convert raw_data to int32_data (4-byte aligned) + np_array = onnx.numpy_helper.to_array(initializer) + # Store INT8 values in int32_data field (4-byte aligned) + int32_values = np_array.astype(np.int32).flatten().tolist() + + new_tensor = onnx.helper.make_tensor( + initializer.name, + onnx.TensorProto.INT8, + list(initializer.dims), + int32_values, # This populates int32_data instead of raw_data + ) + initializer.CopyFrom(new_tensor) + fixed_count += 1 + + if fixed_count > 0: + logger.debug(f"Fixed {fixed_count} zero_point initializers (int32_data format)") + + def _create_qdq_nodes( + self, + tensor_name: str, + qdq_input: gs.Tensor, + output_shape: tuple | None, + output_dtype: np.dtype, + quant_dtype: np.dtype, + quant_type: str, + q_scale: float, + ) -> tuple[gs.Node, gs.Node]: + """Create QuantizeLinear and DequantizeLinear node pair. + + Args: + tensor_name: Name of the tensor being quantized + qdq_input: Input tensor to the Q node + output_shape: Shape for Q/DQ outputs (may be None) + output_dtype: Dtype for DQ output (also used for scale dtype) + quant_dtype: Dtype for quantized values + quant_type: Quantization type string + q_scale: Quantization scale + + Returns: + Tuple of (q_node, dq_node) + """ + # Create unique names for Q/DQ nodes + q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_") + dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_") + + # Determine scale dtype from output_dtype (fp16/tf32/fp32) + # Scale should match the precision of the original I/O tensor + # Note: output_dtype can be either a numpy type class (np.float16) or dtype instance (dtype('float16')) + # Use np.dtype().name for consistent comparison + dtype_name = np.dtype(output_dtype).name + if dtype_name == "float16": + scale_dtype = np.float16 + elif dtype_name == "float32": + scale_dtype = np.float32 + elif dtype_name == "bfloat16" and hasattr(np, "bfloat16"): + scale_dtype = np.bfloat16 + else: + scale_dtype = np.float32 + + logger.debug( + f"Creating Q/DQ pair for '{tensor_name}' (scale_dtype={np.dtype(scale_dtype).name})" + ) + + # Build QuantizeLinear inputs: [input, scale, zero_point] + # Scale and zero_point must be proper ONNX initializers + q_scale_values = np.array([q_scale], dtype=scale_dtype) + q_zp_values = np.array([0], dtype=np.int8) + + q_inputs = [ + qdq_input, + gs.Constant(f"q_scale_{tensor_name}", values=q_scale_values), + gs.Constant(f"q_zp_{tensor_name}", values=q_zp_values), + ] + + q_node = gs.Node( + op="QuantizeLinear", + name=q_name, + inputs=q_inputs, + outputs=[ + gs.Variable(f"{tensor_name}_quantized", dtype=quant_dtype, shape=output_shape) + ], + ) + + # Build DequantizeLinear inputs: [quantized_input, scale, zero_point] + # Scale and zero_point must be proper ONNX initializers + dq_scale_values = np.array([q_scale], dtype=scale_dtype) + dq_zp_values = np.array([0], dtype=np.int8) + + dq_inputs = [ + q_node.outputs[0], + gs.Constant(f"dq_scale_{tensor_name}", values=dq_scale_values), + gs.Constant(f"dq_zp_{tensor_name}", 
values=dq_zp_values), + ] + + dq_node = gs.Node( + op="DequantizeLinear", + name=dq_name, + inputs=dq_inputs, + outputs=[ + gs.Variable(f"{tensor_name}_dequantized", dtype=output_dtype, shape=output_shape) + ], + ) + + return q_node, dq_node + + def _insert_qdq_at_tensors( + self, graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] + ) -> None: + """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations. + + This is the main entry point for Q/DQ insertion. It: + 1. Builds tensor map and tensor-to-users map for efficient lookup + 2. Processes each resolved insertion point to insert Q/DQ nodes + 3. Handles two insertion modes based on node_index + + **Insertion Modes:** + + 1. **Node-level insertion** (node_index is set, input_index is set): + - Inserts Q/DQ for a specific node's specific input connection + - Only rewires that one node-tensor connection + - Multiple Q/DQ pairs can be created for the same tensor at different nodes + - Naming: `{tensor_name}_n{node_index}_i{input_index}` + - Use case: Fine-grained control over quantization boundaries + + 2. **Tensor-level insertion** (node_index=None, input_index=None): + - Inserts one Q/DQ pair for the entire tensor + - Rewires ALL users of the tensor to use the same DQ output + - Only one Q/DQ pair created regardless of number of users + - Naming: `{tensor_name}_qdq` + - Use case: Quantize a tensor once when it feeds multiple nodes + + **Validation:** + - When node_index is set: input_index must also be set + - When node_index is None: input_index must be None + - All validations use assertions (failures indicate programming errors) + + **Handling for Constants:** + - Q/DQ nodes can be inserted directly on Constant tensors (weights, biases) + - No conversion needed since QuantizeLinear accepts Constant inputs + + **Quantization Parameters:** + - Uses config.default_quant_type for quantization type ("int8", "fp8") + - Uses config.default_q_scale for quantization scale + - Zero-point is always set to 0 (int8) for all quantization types + - Creates separate constants for each Q/DQ pair + + Args: + graph: Graph to modify in-place + resolved_insertion_points: Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ + """ + # Extract quantization parameters + q_scale = self.config.default_q_scale + quant_type = self.config.default_quant_type + quant_dtype = self._get_quant_dtype(quant_type) + + logger.debug(f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point=0") + + resolved_insertion_points = merge_resolved_insertion_points( + graph, resolved_insertion_points + ) + + # Build tensor name → tensor object mapping + tensor_map = self._build_tensor_map(graph) + tensor_users_map = get_tensor_consumer_node_indices(graph) + logger.debug( + f"Built tensor maps: {len(tensor_map)} tensors, {len(tensor_users_map)} with users" + ) + + # Process each resolved insertion point + for insertion_point in resolved_insertion_points: + tensor_name = insertion_point.tensor_name + node_index = insertion_point.node_index + input_index = insertion_point.input_index + + original_tensor = tensor_map[tensor_name] + # Validate input/output index + if node_index is not None: + assert node_index < len(graph.nodes), "Node index out of range" + target_node = graph.nodes[node_index] + assert input_index is not None, "Input index must be set when node index is set" + assert input_index < len(target_node.inputs), ( + f"Input index out of range for node {target_node.name}" + ) + original_tensor = 
target_node.inputs[input_index] + assert tensor_name == original_tensor.name, ( + f"Tensor name mismatch for node {target_node.name} input {input_index}" + ) + else: + assert tensor_name in tensor_map, f"Tensor {tensor_name} not found in tensor map" + assert input_index is None, "Input index must be None when node index is None" + + # Get node and tensor + is_constant = isinstance(original_tensor, gs.Constant) + + # Extract tensor metadata (shape, dtype) + output_shape, output_dtype = self._get_tensor_metadata(original_tensor, is_constant) + + # Create unique Q/DQ node pair for this specific insertion point + unique_suffix = "qdq" + if node_index is not None: + unique_suffix = f"n{node_index}_i{input_index}" + unique_tensor_name = f"{tensor_name}_{unique_suffix}" + + # Create Q/DQ node pair + q_node, dq_node = self._create_qdq_nodes( + unique_tensor_name, + original_tensor, + output_shape, + output_dtype, + quant_dtype, + quant_type, + q_scale, + ) + + # Add nodes to graph + graph.nodes.extend([q_node, dq_node]) + + # Rewire only the specific node-tensor connection + if node_index is not None: + # Insert QDQ between the producer and this specific input + target_node.inputs[input_index] = dq_node.outputs[0] + logger.debug( + f" Q/DQ inserted: tensor '{tensor_name}' → node #{node_index} " + f"({target_node.name}) input #{input_index}" + ) + else: + users = tensor_users_map[tensor_name] + for user_index in users: + user_node = graph.nodes[user_index] + for i, input_tensor in enumerate(user_node.inputs): + if hasattr(input_tensor, "name") and input_tensor.name == tensor_name: + user_node.inputs[i] = dq_node.outputs[0] + break + logger.debug(f" Q/DQ inserted: tensor '{tensor_name}' → {len(users)} users") + + # Cleanup and toposort + logger.debug("Running graph cleanup and topological sort") + try: + graph.cleanup().toposort() + logger.debug("Graph cleanup completed") + except Exception as e: + logger.warning(f"Graph cleanup failed: {e}") + logger.debug("Continuing anyway") + + +class QDQAutotuner(QDQAutotunerBase): + """Ready-to-use Q/DQ autotuner with automatic region discovery. + + This is the main class users should instantiate for Q/DQ optimization. + It extends QDQAutotunerBase by automatically searching for optimization + regions around compute-intensive operations during initialization. + + **Automatic Region Discovery:** + - Uses CombinedRegionSearch to identify regions automatically + - Focuses on Conv, MatMul, Gemm, and other compute-heavy operations + - Creates hierarchical region structure (COMPOSITE with LEAF children) + - Flattens hierarchy and prioritizes LEAF regions for profiling + + **Region Selection Strategy:** + The discovered regions are organized to optimize profiling efficiency: + 1. LEAF regions: Contain actual nodes, profiled first (most specific) + 2. Non-COMPOSITE regions: Profiled second (intermediate level) + 3. COMPOSITE regions: Skipped (only containers, no direct nodes) + + This ensures we test the most granular patterns first, which provides + better optimization opportunities and more reusable pattern cache entries. 
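+
+    The resulting ordering can be sanity-checked directly; a minimal sketch, assuming
+    an already-initialized `autotuner` instance:
+
+    ```python
+    types = [r.type for r in autotuner.regions]
+    # every LEAF region should appear before the first non-LEAF region
+    assert types == sorted(types, key=lambda t: t != RegionType.LEAF)
+    ```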
+ + **Usage Pattern:** + ```python + # Load model + model = onnx.load("model.onnx") + + # Create autotuner (regions discovered automatically) + autotuner = QDQAutotuner(model) + + # Initialize with configuration + config = Config(default_quant_type="fp8") + autotuner.initialize(config) + + # Measure baseline (optional but recommended) + baseline_bytes = autotuner.export_onnx("baseline.onnx", insert_qdq=False) + baseline_latency = benchmark("baseline.onnx") + autotuner.submit(baseline_latency) + + # Profile regions + for region in autotuner.regions[:10]: # Top 10 regions + autotuner.set_profile_region(region) + + # Generate and test multiple schemes + for i in range(5): + scheme_idx = autotuner.generate() + if scheme_idx < 0: + break # No more unique schemes + + # Export and benchmark + test_bytes = autotuner.export_onnx(f"test_{i}.onnx") + latency = benchmark(f"test_{i}.onnx") + autotuner.submit(latency) + + # Export final optimized model + autotuner.export_onnx("optimized.onnx") + + # Save results for reuse + autotuner.save_state("results.yaml") + ``` + + **Key Differences from Base Class:** + - Automatic region discovery (no manual region specification needed) + - Hierarchical region structure flattened for efficient profiling + - LEAF regions prioritized (contain actual nodes to optimize) + - Ready to use out of the box (no custom region strategy needed) + + **Region Discovery Details:** + Uses a two-phase search strategy: + 1. Bottom-up partitioning: Groups nodes by divergence/convergence patterns + 2. Top-down refinement: Creates hierarchical structure within regions + + See CombinedRegionSearch documentation for algorithm details. + + Attributes: + regions: List of discovered regions, ordered by priority (LEAF first) + graph: ONNX computation graph + config: Configuration parameters + profiled_patterns: Results from profiled regions + + Example: + >>> # Simple usage + >>> autotuner = QDQAutotuner(model) + >>> autotuner.initialize() + >>> print(f"Found {len(autotuner.regions)} regions to optimize") + >>> # With custom config + >>> config = Config(default_quant_type="fp8") + >>> autotuner = QDQAutotuner(model) + >>> autotuner.initialize(config) + """ + + def initialize( + self, config: Config | None = None, pattern_cache: PatternCache | None = None + ) -> None: + """Initialize autotuner and discover optimization regions automatically. + + Extends base class initialization by automatically searching for regions + after configuration is set up. Regions are discovered using pattern-based + search around compute-intensive operations. + + **Automatic Steps:** + 1. Calls base class initialize (sets up config, pattern cache) + 2. Runs region search (discovers optimization targets) + 3. Flattens region hierarchy and prioritizes LEAF regions + 4. Reassigns region IDs for clean indexing + + After this method completes, self.regions contains all discovered regions + ready for profiling via set_profile_region(). + + Args: + config: Optional configuration for Q/DQ parameters and profiling behavior. + If None, uses default Config() settings. + pattern_cache: Optional pattern cache for warm-starting with known schemes. + If None, creates empty cache. 
+ + Raises: + None (safe to call multiple times - resets state each time) + + Example: + >>> autotuner = QDQAutotuner(model) + >>> autotuner.initialize() + >>> print(f"Ready to profile {len(autotuner.regions)} regions") + >>> # With custom configuration + >>> config = Config(default_quant_type="fp8") + >>> autotuner.initialize(config) + """ + # Initialize base class (config, pattern cache, reset state) + super().initialize(config, pattern_cache) + + # Discover optimization regions automatically + self._search_regions() + + def _visit_region_recursively(self, region: Region) -> list[Region]: + """Recursively traverse region hierarchy and collect all regions. + + Performs depth-first traversal of the region tree starting from a given + region. Collects the root region and all descendant regions (children, + grandchildren, etc.) into a flat list. + + **Traversal Order:** + - Pre-order: Parent added before children + - Depth-first: Fully explores each branch before moving to next + + **Use Case:** + Used to flatten hierarchical region structure (COMPOSITE → LEAF) into + a single list for sequential profiling. This ensures all regions at all + levels are available for optimization. + + Args: + region: Root region to start traversal from + + Returns: + List of all regions in the subtree (including root), in pre-order DFS. + + Example: + >>> # Given hierarchy: COMPOSITE { LEAF{A, B}, COMPOSITE { LEAF{C} } } + >>> regions = _visit_region_recursively(composite_root) + >>> # Returns: [COMPOSITE_root, LEAF_AB, COMPOSITE_child, LEAF_C] + """ + # Start with the current region + regions = [region] + + # Recursively add all children and their descendants + for child in region.get_children(): + regions.extend(self._visit_region_recursively(child)) + + return regions + + def _reassign_region_ids(self, regions: list[Region]) -> None: + """Reassign sequential IDs to regions in breadth-first order. + + Traverses the region hierarchy (including children) and assigns new + sequential IDs starting from 0. This ensures clean, predictable region + numbering after region discovery and manipulation. + + **Traversal Strategy:** + - Breadth-first: Siblings get consecutive IDs before their children + - Sequential: IDs are 0, 1, 2, ... with no gaps + + **Why This Matters:** + - Clean logging: "Region 0", "Region 1", etc. + - Pattern cache: Region IDs appear in insertion point references + - Debugging: Predictable numbering aids in understanding results + + **Modifies:** + Updates region.id for each region in-place (including all descendants). + + Args: + regions: List of top-level regions (children will be processed too) + + Side Effects: + Modifies the .id attribute of all regions and their descendants + + Example: + >>> # Before: regions with IDs [5, 12, 8, ...] + >>> _reassign_region_ids(regions) + >>> # After: regions with IDs [0, 1, 2, ...] + """ + region_id = 0 + + # Use BFS to assign IDs level-by-level + queue = deque(regions) + + while queue: + region = queue.popleft() + + # Assign next sequential ID + region.id = region_id + region_id += 1 + + # Add children to queue for processing + queue.extend(region.get_children()) + + def _search_regions(self) -> None: + """Discover and organize optimization regions automatically. + + This is the core region discovery method that: + 1. Runs automatic region search to find optimization targets + 2. Flattens hierarchical structure into a list + 3. Prioritizes LEAF regions (contain actual nodes) + 4. 
Reassigns IDs for clean indexing + + **Search Strategy:** + Uses CombinedRegionSearch which performs: + - Phase 1: Bottom-up partitioning based on divergence/convergence + - Phase 2: Top-down refinement creating hierarchical structure + + **Region Organization:** + After discovery, regions are reorganized for optimal profiling: + ``` + Original: [COMPOSITE_1 { LEAF_A, LEAF_B }, COMPOSITE_2 { LEAF_C }] + + After flattening: [COMPOSITE_1, LEAF_A, LEAF_B, COMPOSITE_2, LEAF_C] + + After prioritization: [LEAF_A, LEAF_B, LEAF_C, COMPOSITE_1, COMPOSITE_2] + ``` + + **Why Prioritize LEAF Regions:** + - LEAF regions contain actual nodes (direct optimization targets) + - COMPOSITE regions are just containers (no direct nodes to optimize) + - Profiling LEAF first gives more specific, reusable patterns + - Pattern cache entries from LEAF regions apply to many models + + **Region Types:** + - LEAF: Contains graph nodes, profiled first (highest priority) + - COMPOSITE: Container for other regions, lower priority + - ROOT: Special container (typically not profiled directly) + + Side Effects: + Populates self.regions with discovered and organized regions + + Example: + >>> # Automatically called during initialize() + >>> _search_regions() + >>> # self.regions now contains all discovered regions + >>> leaf_count = sum(1 for r in self.regions if r.type == RegionType.LEAF) + >>> print(f"Discovered {leaf_count} LEAF regions for profiling") + """ + # ===================================================================== + # STEP 1: Run Automatic Region Discovery + # ===================================================================== + # Use CombinedRegionSearch to find regions around compute-intensive ops + # This creates a hierarchical structure: COMPOSITE → LEAF + logger.info("Discovering optimization regions") + search = CombinedRegionSearch( + self.graph, + maximum_sequence_region_size=self.config.maximum_sequence_region_size, + minimum_topdown_search_size=self.config.minimum_topdown_search_size, + ) + self.regions = search.search_regions() + + # Reassign IDs to top-level regions for clean indexing + self._reassign_region_ids(self.regions) + logger.debug(f"Found {len(self.regions)} top-level regions") + + # ===================================================================== + # STEP 2: Flatten Hierarchical Structure + # ===================================================================== + # Traverse the region tree and collect all regions at all levels + # This ensures we can profile both parent and child regions + all_regions = [] + for region in self.regions: + all_regions.extend(self._visit_region_recursively(region)) + + logger.debug(f"Flattened hierarchy to {len(all_regions)} total regions") + + # ===================================================================== + # STEP 3: Prioritize LEAF Regions + # ===================================================================== + # Organize regions to profile the most specific patterns first: + # 1. LEAF regions: Contain actual nodes, most specific patterns + # 2. Other non-COMPOSITE: Intermediate abstractions + # 3. 
COMPOSITE regions excluded: Just containers, no direct nodes + + # Extract LEAF regions (highest priority) + leaf_regions = [region for region in all_regions if region.type == RegionType.LEAF] + other_regions = [region for region in all_regions if region.type != RegionType.LEAF] + + # Combine: LEAF first, then others + # This ensures the most granular optimization targets are profiled first + all_regions = leaf_regions + other_regions + + # Update self.regions with prioritized list + self.regions = all_regions + + num_leaf = sum(1 for r in self.regions if r.type == RegionType.LEAF) + num_composite = sum(1 for r in self.regions if r.type == RegionType.COMPOSITE) + num_root = sum(1 for r in self.regions if r.type == RegionType.ROOT) + + logger.info( + f"Discovery complete: {len(self.regions)} regions " + f"({num_leaf} LEAF, {num_composite} COMPOSITE, {num_root} ROOT)" + ) + logger.debug("Regions prioritized: LEAF regions first for profiling") diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py new file mode 100644 index 000000000..4772b962b --- /dev/null +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -0,0 +1,936 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TensorRT Utilities and Benchmark Module. 
+ +This module provides comprehensive TensorRT utilities including: +- Benchmark framework for measuring TensorRT engine performance +- Graph utilities for tensor analysis + +**Benchmark Classes:** +- Benchmark: Abstract base class defining the benchmarking interface +- TrtExecBenchmark: Uses trtexec command-line tool for benchmarking +- TensorRTPyBenchmark: Uses TensorRT Python API for direct engine profiling + +**Features:** +- Timing cache management for faster subsequent builds +- File path or raw bytes as model input +- Configurable warmup and timing iterations +- Custom TensorRT plugin library loading +- Automatic cleanup of temporary resources + +Example Usage: + # Using TrtExec for benchmarking + benchmark = TrtExecBenchmark(trtexec_args=['--fp16']) + latency = benchmark.run('model.onnx') + + # Using Python API + benchmark = TensorRTPyBenchmark() + latency = benchmark.run(model_bytes) + + # Using Python API with custom plugins + benchmark = TensorRTPyBenchmark( + plugin_libraries=['/path/to/plugin.so'] + ) + latency = benchmark.run('model_with_custom_ops.onnx') +""" + +import ctypes +import logging +import os +import re +import shutil +import subprocess # nosec B404 +import tempfile +import time +from abc import ABC, abstractmethod +from pathlib import Path + +import numpy as np + +# Optional dependencies - gracefully handle missing packages +try: + import tensorrt as trt + + TRT_AVAILABLE = True +except ImportError: + TRT_AVAILABLE = False + +try: + import pycuda.autoinit # noqa: F401 # Automatically initializes CUDA (side-effect import) + import pycuda.driver as cuda + + PYCUDA_AVAILABLE = True +except ImportError: + PYCUDA_AVAILABLE = False + +# Try to import from modelopt logging, fallback to standard logging +try: + from modelopt.onnx.logging_config import logger +except ImportError: + logger = logging.getLogger(__name__) + + +# ============================================================================= +# Benchmark Framework +# ============================================================================= + + +class Benchmark(ABC): + """Abstract base class for TensorRT model benchmarking. + + This class defines the interface that all benchmark implementations must follow. + It provides a consistent API for measuring inference latency of ONNX models + when converted to TensorRT engines. + + Attributes: + timing_cache_file: Path to the TensorRT timing cache file. + logger: Logger instance for this benchmark. + + Subclasses must implement: + run(): Execute the benchmark and return latency in milliseconds. + """ + + def __init__(self, timing_cache_file: str | None = None): + """Initialize the benchmark. + + Args: + timing_cache_file: Path to timing cache file to accelerate engine builds. + If None, uses '/tmp/trtexec_timing.cache' as default. + """ + self.timing_cache_file = timing_cache_file or "/tmp/trtexec_timing.cache" # nosec B108 + self.logger = logging.getLogger(__name__) + + @abstractmethod + def run(self, path_or_bytes: str | bytes, log_file: str | None = None) -> float: + """Run benchmark on the given ONNX model. + + Args: + path_or_bytes: Path to the ONNX model (str) or raw model data (bytes) + log_file: Optional path to save benchmark logs + + Returns: + Measured latency in milliseconds, or float("inf") on failure + """ + + def __call__(self, path_or_bytes: str | bytes, log_file: str | None = None) -> float: + """Convenience method to call benchmark as a function. 
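+
+        Calling the instance is shorthand for `run()`; for example,
+        `benchmark("model.onnx")` returns the same latency as
+        `benchmark.run("model.onnx")`.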
+ + Args: + path_or_bytes: Path to the ONNX model (str) or raw model data (bytes) + log_file: Optional path to save benchmark logs + + Returns: + Measured latency in milliseconds + """ + return self.run(path_or_bytes, log_file) + + +class TrtExecBenchmark(Benchmark): + """TensorRT benchmark using trtexec command-line tool. + + This implementation uses the trtexec binary to build engines and measure + inference latency. It is the most straightforward method and closely + mirrors standard TensorRT workflows. + + Features: + - Uses subprocess to call trtexec binary + - Supports all trtexec command-line arguments + - Custom TensorRT plugin library loading + - Automatic temporary directory management for engines + - Timing cache persistence across benchmarks + - Supports both file paths and raw bytes as input + + Attributes: + trtexec_path: Path to the trtexec binary. + trtexec_args: Additional command-line arguments for trtexec. + warmup_runs: Number of warmup iterations before timing. + timing_runs: Number of iterations for latency measurement. + timeout: Maximum time in seconds for trtexec execution. + plugin_libraries: List of paths to plugin libraries. + engine_dir: Directory for storing temporary engine files. + engine_path: Path to the engine file. + temp_model_path: Path for temporary ONNX model (when using bytes). + + Examples: + >>> # Basic usage with default trtexec + >>> benchmark = TrtExecBenchmark() + >>> latency = benchmark.run("model.onnx") + + >>> # Enable FP16 and custom workspace + >>> benchmark = TrtExecBenchmark(trtexec_args=["--fp16", "--workspace=4096"]) + >>> latency = benchmark.run(model_bytes) + + >>> # With plugin libraries + >>> benchmark = TrtExecBenchmark(plugin_libraries=["/path/to/custom_plugin.so"]) + >>> latency = benchmark.run("model_with_custom_ops.onnx") + + >>> # Full customization + >>> benchmark = TrtExecBenchmark( + ... trtexec_path="/usr/local/bin/trtexec", + ... trtexec_args=["--fp16", "--verbose"], + ... timing_cache_file="./cache.bin", + ... warmup_runs=10, + ... timing_runs=50, + ... timeout=600, + ... plugin_libraries=["/path/to/plugin.so"], + ... ) + """ + + def __init__( + self, + trtexec_path: str = "trtexec", + trtexec_args: list | None = None, + timing_cache_file: str | None = None, + warmup_runs: int = 5, + timing_runs: int = 10, + timeout: int = 300, + plugin_libraries: list[str] | None = None, + ): + """Initialize the trtexec benchmark. + + Args: + trtexec_path: Path to trtexec binary. Defaults to 'trtexec' which + looks for the binary in PATH. + trtexec_args: Additional command-line arguments to pass to trtexec. + These are appended after the standard arguments. + Example: ['--fp16', '--workspace=4096', '--verbose'] + timing_cache_file: Path to TensorRT timing cache file for faster + subsequent builds. Defaults to '/tmp/trtexec_timing.cache'. + warmup_runs: Number of warmup iterations before timing measurements. + timing_runs: Number of iterations for latency measurement. Results + are averaged across these runs. + timeout: Maximum time in seconds for trtexec execution before timeout. + plugin_libraries: List of paths to TensorRT plugin shared libraries (.so files). + These plugins will be loaded by trtexec during engine building. + If None, no custom plugins are loaded. 
+ """ + super().__init__(timing_cache_file) + + # Store configuration + self.trtexec_path = trtexec_path + self.trtexec_args = trtexec_args or [] + self.warmup_runs = warmup_runs + self.timing_runs = timing_runs + self.timeout = timeout + self.plugin_libraries = plugin_libraries or [] + + # Create persistent temporary directory for engine and model files + # This directory persists for the lifetime of this benchmark object + self._temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_") + self.engine_dir = self._temp_dir + self.engine_path = os.path.join(self.engine_dir, "engine.trt") + self.temp_model_path = os.path.join(self.engine_dir, "temp_model.onnx") + self.logger.debug(f"Created temporary engine directory: {self.engine_dir}") + self.logger.debug(f"Temporary model path: {self.temp_model_path}") + + # Construct base trtexec command template + # The '--onnx' argument will be added dynamically in run() + self._base_cmd = [ + self.trtexec_path, + f"--avgRuns={self.timing_runs}", + f"--iterations={self.timing_runs}", + f"--warmUp={self.warmup_runs}", + "--stronglyTyped", # Enable strongly typed mode for Q/DQ ops + f"--saveEngine={self.engine_path}", + f"--timingCacheFile={self.timing_cache_file}", + ] + + # Add plugin libraries + for plugin_lib in self.plugin_libraries: + plugin_path = Path(plugin_lib).resolve() + if not plugin_path.exists(): + self.logger.warning(f"Plugin library not found: {plugin_path}") + else: + self._base_cmd.append(f"--staticPlugins={plugin_path}") + self.logger.debug(f"Added plugin library: {plugin_path}") + + # Append user-provided custom arguments + if self.trtexec_args: + self._base_cmd.extend(self.trtexec_args) + + self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}") + + def __del__(self): + """Cleanup temporary directory.""" + if hasattr(self, "_temp_dir"): + try: + shutil.rmtree(self._temp_dir, ignore_errors=True) + self.logger.debug(f"Cleaned up temporary directory: {self._temp_dir}") + except Exception as e: + self.logger.warning(f"Failed to cleanup temporary directory: {e}") + + def run( + self, + path_or_bytes: str | bytes, + log_file: str | None = None, + flush_timing_cache: bool = False, + ) -> float: + """Run benchmark using trtexec. 
+ + Args: + path_or_bytes: Path to the ONNX model (str) or raw model data (bytes) + log_file: Optional path to save trtexec logs + + Returns: + Measured median latency in milliseconds + """ + cache_exists = os.path.exists(self.timing_cache_file) + if cache_exists: + self.logger.debug(f"Using existing timing cache: {self.timing_cache_file}") + else: + self.logger.debug(f"Will create timing cache: {self.timing_cache_file}") + + try: + # If bytes provided, write to temporary model path + if isinstance(path_or_bytes, bytes): + with open(self.temp_model_path, "wb") as f: + f.write(path_or_bytes) + model_path = self.temp_model_path + self.logger.debug(f"Wrote model bytes to temporary file: {model_path}") + else: + model_path = path_or_bytes + + # Build complete command from base template + cmd = [self._base_cmd[0], f"--onnx={model_path}", *self._base_cmd[1:]] + + self.logger.debug(f"Running: {' '.join(cmd)}") + + # Run trtexec and capture output + result = subprocess.run(cmd, capture_output=True, text=True, timeout=self.timeout) # nosec B603 + + # Save logs if requested + if log_file is not None: + try: + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + with open(log_path, "w") as f: + f.write(f"Command: {' '.join(cmd)}\n") + f.write(f"Return code: {result.returncode}\n") + f.write("=" * 80 + "\n") + f.write("STDOUT:\n") + f.write("=" * 80 + "\n") + f.write(result.stdout) + f.write("\n" + "=" * 80 + "\n") + f.write("STDERR:\n") + f.write("=" * 80 + "\n") + f.write(result.stderr) + self.logger.debug(f"Saved trtexec logs to: {log_file}") + except Exception as e: + self.logger.warning(f"Failed to save logs to {log_file}: {e}") + + if result.returncode != 0: + self.logger.error(f"trtexec failed with return code {result.returncode}") + self.logger.error(f"stderr: {result.stderr}") + return float("inf") + + # Parse output to extract latency + # trtexec outputs lines like: + # "[I] Latency: min = X ms, max = Y ms, mean = Z ms, median = W ms, ..." + output = result.stdout + + # Look for median latency in the main "[I] Latency:" line + pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms" + + match = re.search(pattern, output, re.IGNORECASE) + if match: + latency = float(match.group(1)) + self.logger.info(f"TrtExec benchmark (median): {latency:.2f} ms") + return latency + + self.logger.warning("Could not parse median latency from trtexec output") + self.logger.debug(f"trtexec stdout:\n{output}") + return float("inf") + + except subprocess.TimeoutExpired: + self.logger.error(f"trtexec timed out after {self.timeout} seconds") + return float("inf") + except FileNotFoundError: + self.logger.error(f"trtexec binary not found: {self.trtexec_path}") + self.logger.error("Please ensure TensorRT is installed and trtexec path is correct") + return float("inf") + except Exception as e: + self.logger.error(f"Benchmark failed: {e}") + return float("inf") + + +class TensorRTPyBenchmark(Benchmark): + """TensorRT benchmark using Python API with plugin support. + + This implementation directly uses the TensorRT Python API to build engines + and measure inference latency. It provides more control than trtexec and + can be faster for certain workflows as it avoids subprocess overhead. 
+ + Features: + - Direct TensorRT Python API usage (no subprocess) + - Persistent Builder, Logger, and Runtime objects + - Custom TensorRT plugin library loading + - Automatic dynamic shape handling + - In-memory timing cache management + - CUDA memory management via PyCUDA + - Detailed latency statistics (min, max, mean, median) + + Requirements: + - tensorrt package + - pycuda package + - CUDA-capable GPU + + Attributes: + trt_logger: TensorRT Logger instance (persistent). + builder: TensorRT Builder instance (persistent). + runtime: TensorRT Runtime instance (persistent). + config: Builder configuration (recreated per run). + warmup_runs: Number of warmup iterations. + timing_runs: Number of timing iterations. + plugin_libraries: List of loaded plugin library paths. + _shape_configs: Dictionary storing custom shape configurations. + _plugin_registry: TensorRT PluginRegistry instance. + + Methods: + set_shapes(): Configure min/opt/max shapes for dynamic inputs. + run(): Execute the benchmark and return latency. + + Examples: + >>> # Basic usage + >>> benchmark = TensorRTPyBenchmark() + >>> latency = benchmark.run("model.onnx") + + >>> # With custom configuration + >>> benchmark = TensorRTPyBenchmark( + ... timing_cache_file="./cache.bin", warmup_runs=10, timing_runs=100 + ... ) + >>> latency = benchmark.run(model_bytes) + + >>> # With plugin libraries + >>> benchmark = TensorRTPyBenchmark( + ... plugin_libraries=["/path/to/custom_plugin.so", "/path/to/another_plugin.so"] + ... ) + >>> latency = benchmark.run("model_with_custom_ops.onnx") + + >>> # With custom dynamic shapes + >>> benchmark = TensorRTPyBenchmark() + >>> benchmark.set_shapes( + ... "input", + ... min_shape=[1, 3, 224, 224], + ... opt_shape=[4, 3, 224, 224], + ... max_shape=[8, 3, 224, 224], + ... ) + >>> latency = benchmark.run("model.onnx") + + >>> # Access detailed statistics from logs + >>> latency = benchmark.run("model.onnx", log_file="./benchmark.log") + + Note: + The Builder and Runtime are created once at initialization and + reused across multiple benchmark runs for efficiency. The config + is recreated for each run to ensure clean optimization profiles. + Plugin libraries are loaded once at initialization and remain + available for all subsequent engine builds. + """ + + def __init__( + self, + timing_cache_file: str | None = None, + warmup_runs: int = 5, + timing_runs: int = 20, + plugin_libraries: list[str] | None = None, + ): + """Initialize the TensorRT Python API benchmark. + + Creates persistent TensorRT objects (Logger, Builder, Runtime) and + loads the timing cache from disk if available. Optionally loads custom + TensorRT plugin libraries for models with custom operations. + + Args: + timing_cache_file: Path to TensorRT timing cache file. If None, + defaults to '/tmp/trtexec_timing.cache'. + warmup_runs: Number of warmup iterations before timing measurements. + timing_runs: Number of iterations for latency measurement. + plugin_libraries: List of paths to TensorRT plugin shared libraries (.so files). + These plugins will be loaded and registered for use during + engine building. If None, no custom plugins are loaded. + + Raises: + ImportError: If tensorrt or pycuda packages are not available. + FileNotFoundError: If a specified plugin library file does not exist. + RuntimeError: If plugin library loading fails. + + Example: + >>> # Without plugins + >>> benchmark = TensorRTPyBenchmark() + >>> # With custom plugins + >>> benchmark = TensorRTPyBenchmark( + ... plugin_libraries=[ + ... 
"/opt/tensorrt/lib/libnvinfer_plugin.so", + ... "/path/to/my_custom_plugin.so", + ... ] + ... ) + """ + super().__init__(timing_cache_file) + self.warmup_runs = warmup_runs + self.timing_runs = timing_runs + self.plugin_libraries = plugin_libraries or [] + + # Verify required dependencies + if not TRT_AVAILABLE: + raise ImportError("TensorRT Python API not available. Please install tensorrt package.") + if not PYCUDA_AVAILABLE: + raise ImportError("PyCUDA not available. Please install pycuda package.") + + self.trt_logger = trt.Logger(trt.Logger.WARNING) + self.builder = trt.Builder(self.trt_logger) + self.runtime = trt.Runtime(self.trt_logger) + # Load custom plugin libraries before initializing TensorRT plugins + self._loaded_plugin_handles = [] + if self.plugin_libraries: + self._load_plugin_libraries() + # Get plugin registry (must be done after loading plugin libraries) + # Initialize TensorRT's built-in plugins (CRITICAL!) + # This must be called after loading custom plugins and before using the plugin registry + trt.init_libnvinfer_plugins(self.trt_logger, "") + + self._plugin_registry = trt.get_plugin_registry() + + # Set network flag + self.network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + self.network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + + # Load timing cache from disk or create new one + self._timing_cache = None + self._load_timing_cache() + + # Storage for user-defined shape configurations + # Format: {input_name: (min_shape, opt_shape, max_shape)} + self._shape_configs = {} + + def _load_plugin_libraries(self): + """Load custom TensorRT plugin libraries from shared object files. + + This method loads plugin libraries using ctypes and initializes them + with the TensorRT plugin registry. Plugins must export the + initLibNvInferPlugins function to register their implementations. + + The loaded library handles are stored to prevent them from being + garbage collected during the benchmark's lifetime. + + Raises: + FileNotFoundError: If a plugin library file does not exist. + RuntimeError: If plugin initialization fails. 
+ """ + for plugin_lib in self.plugin_libraries: + plugin_path = Path(plugin_lib).resolve() + + if not plugin_path.exists(): + raise FileNotFoundError(f"Plugin library not found: {plugin_path}") + + self.logger.info(f"Loading TensorRT plugin: {plugin_path}") + + try: + # Load the shared library using ctypes + # RTLD_LAZY: Resolve symbols as needed + # RTLD_GLOBAL: Make symbols available for subsequently loaded libraries + # Use os.RTLD_* constants (available on Unix) or default mode + if hasattr(os, "RTLD_LAZY") and hasattr(os, "RTLD_GLOBAL"): + plugin_handle = ctypes.CDLL( + str(plugin_path), mode=os.RTLD_LAZY | os.RTLD_GLOBAL + ) + else: + # Fallback for platforms without RTLD flags (e.g., Windows) + plugin_handle = ctypes.CDLL(str(plugin_path)) + + # Store handle to prevent garbage collection + self._loaded_plugin_handles.append(plugin_handle) + + # Try to initialize plugin with TensorRT registry + # Most TensorRT plugins export initLibNvInferPlugins function + if hasattr(plugin_handle, "initLibNvInferPlugins"): + init_func = plugin_handle.initLibNvInferPlugins + # Function signature: bool initLibNvInferPlugins(void* logger, const char* namespace) + init_func.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + init_func.restype = ctypes.c_bool + + # Initialize with the TensorRT logger and default namespace + success = init_func(None, b"") + if not success: + self.logger.warning( + f"Plugin initialization returned false for: {plugin_path}" + ) + else: + self.logger.info(f"Successfully initialized plugin: {plugin_path.name}") + else: + self.logger.info( + f"Plugin loaded (no initLibNvInferPlugins function): {plugin_path.name}" + ) + + except Exception as e: + raise RuntimeError(f"Failed to load plugin library {plugin_path}: {e}") from e + + def set_shapes(self, input_name: str, min_shape: list, opt_shape: list, max_shape: list): + """Set custom min/opt/max shapes for a dynamic input. + + This method allows you to specify custom shape ranges for dynamic inputs + (inputs with -1 dimensions). If not specified, the benchmark will use + default shapes (all -1 dimensions become 1). + + Args: + input_name: Name of the input tensor to configure. + min_shape: Minimum shape for this input. List of integers. + opt_shape: Optimal/default shape for this input. List of integers. + max_shape: Maximum shape for this input. List of integers. + + Example: + >>> benchmark = TensorRTPyBenchmark() + >>> # Set shapes for a dynamic batch input + >>> benchmark.set_shapes( + ... "input", + ... min_shape=[1, 3, 224, 224], + ... opt_shape=[4, 3, 224, 224], + ... max_shape=[8, 3, 224, 224], + ... ) + >>> latency = benchmark.run("model.onnx") + + Note: + - All three shapes must have the same number of dimensions + - For each dimension: min_shape[i] <= opt_shape[i] <= max_shape[i] + - Shapes are applied when the model is built during run() + """ + if len(min_shape) != len(opt_shape) or len(opt_shape) != len(max_shape): + raise ValueError("min_shape, opt_shape, and max_shape must have the same length") + + for i, (min_dim, opt_dim, max_dim) in enumerate(zip(min_shape, opt_shape, max_shape)): + if not (min_dim <= opt_dim <= max_dim): + raise ValueError( + f"Invalid shape range at dimension {i}: " + f"min={min_dim}, opt={opt_dim}, max={max_dim}. 
" + f"Must satisfy min <= opt <= max" + ) + + self._shape_configs[input_name] = (min_shape, opt_shape, max_shape) + self.logger.debug( + f"Set shapes for input '{input_name}': " + f"min={min_shape}, opt={opt_shape}, max={max_shape}" + ) + + def run( + self, + path_or_bytes: str | bytes, + log_file: str | None = None, + flush_timing_cache: bool = False, + ) -> float: + """Run benchmark using TensorRT Python API. + + Args: + path_or_bytes: Path to the ONNX model (str) or raw model data (bytes) + log_file: Optional path to save benchmark logs + + Returns: + Measured median latency in milliseconds + """ + # Initialize resource tracking variables + config = None + network = None + parser = None + serialized_engine = None + engine = None + context = None + inputs = [] + outputs = [] + stream = None + + try: + self.logger.debug("Creating TensorRT builder...") + config = self.builder.create_builder_config() + config.set_flag(trt.BuilderFlag.DIRECT_IO) + if not config.set_timing_cache(self._timing_cache, ignore_mismatch=True): + self.logger.warning("Failed to set timing cache to builder config") + network = self.builder.create_network(self.network_flags) + # Create network and parser using the shared builder and logger + parser = trt.OnnxParser(network, self.trt_logger) + + # Parse ONNX model + if isinstance(path_or_bytes, bytes): + self.logger.debug(f"Parsing ONNX model from bytes (size: {len(path_or_bytes)})") + model_data = path_or_bytes + else: + self.logger.debug(f"Parsing ONNX model: {path_or_bytes}") + with open(path_or_bytes, "rb") as f: + model_data = f.read() + + if not parser.parse(model_data): + self.logger.error("Failed to parse ONNX model") + for error_idx in range(parser.num_errors): + self.logger.error(f" {parser.get_error(error_idx)}") + return float("inf") + + # Handle dynamic shapes + has_dynamic_shapes = False + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + shape = input_tensor.shape + if any(dim == -1 for dim in shape): + has_dynamic_shapes = True + break + + if has_dynamic_shapes: + profile = self.builder.create_optimization_profile() + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_name = input_tensor.name + shape = list(input_tensor.shape) + + # Check if user provided custom shape configuration + if input_name in self._shape_configs: + # Use user-provided shapes + min_shape, opt_shape, max_shape = self._shape_configs[input_name] + self.logger.debug( + f"Using custom shapes for input '{input_name}': " + f"min={min_shape}, opt={opt_shape}, max={max_shape}" + ) + else: + # Use default: replace -1 with concrete values (1) + min_shape = [1 if dim == -1 else dim for dim in shape] + opt_shape = [1 if dim == -1 else dim for dim in shape] + max_shape = [1 if dim == -1 else dim for dim in shape] + self.logger.debug( + f"Using default shapes for input '{input_name}': {opt_shape}" + ) + + profile.set_shape(input_name, min_shape, opt_shape, max_shape) + + config.add_optimization_profile(profile) + + # Build engine + self.logger.debug("Building TensorRT engine...") + build_start = time.perf_counter() + serialized_engine = self.builder.build_serialized_network(network, config) + build_time = time.perf_counter() - build_start + + if serialized_engine is None: + self.logger.error("Failed to build TensorRT engine") + return float("inf") + + self.logger.debug(f"Engine built successfully in {build_time:.2f}s") + + # Save timing cache after successful build + if flush_timing_cache: + self._save_timing_cache() + + # Deserialize 
engine using shared runtime + engine = self.runtime.deserialize_cuda_engine(serialized_engine) + + if engine is None: + self.logger.error("Failed to deserialize engine") + return float("inf") + + # Create execution context + context = engine.create_execution_context() + + # Allocate buffers + inputs = [] + outputs = [] + + for i in range(engine.num_io_tensors): + tensor_name = engine.get_tensor_name(i) + dtype = trt.nptype(engine.get_tensor_dtype(tensor_name)) + shape = context.get_tensor_shape(tensor_name) + + # Allocate host and device buffers + size = trt.volume(shape) + host_mem = cuda.pagelocked_empty(size, dtype) + device_mem = cuda.mem_alloc(host_mem.nbytes) + + if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: + # Fill with random data for benchmark + np.copyto(host_mem, np.random.randn(size).astype(dtype)) + inputs.append({"host": host_mem, "device": device_mem, "name": tensor_name}) + else: + outputs.append({"host": host_mem, "device": device_mem, "name": tensor_name}) + + context.set_tensor_address(tensor_name, int(device_mem)) + + # Create CUDA stream + stream = cuda.Stream() + + # Warmup runs + self.logger.debug(f"Running {self.warmup_runs} warmup iterations...") + for _ in range(self.warmup_runs): + # Copy inputs to device + for inp in inputs: + cuda.memcpy_htod_async(inp["device"], inp["host"], stream) + + # Execute + context.execute_async_v3(stream_handle=stream.handle) + + # Copy outputs to host + for out in outputs: + cuda.memcpy_dtoh_async(out["host"], out["device"], stream) + + stream.synchronize() + + # Timing runs + self.logger.debug(f"Running {self.timing_runs} timing iterations...") + latencies = [] + + for _ in range(self.timing_runs): + # Copy inputs to device + for inp in inputs: + cuda.memcpy_htod_async(inp["device"], inp["host"], stream) + + # Execute with timing + stream.synchronize() + start = time.perf_counter() + context.execute_async_v3(stream_handle=stream.handle) + stream.synchronize() + end = time.perf_counter() + + latency_ms = (end - start) * 1000.0 + latencies.append(latency_ms) + + # Copy outputs to host + for out in outputs: + cuda.memcpy_dtoh_async(out["host"], out["device"], stream) + + # Compute statistics + latencies = np.array(latencies) + median_latency = float(np.median(latencies)) + mean_latency = float(np.mean(latencies)) + std_latency = float(np.std(latencies)) + min_latency = float(np.min(latencies)) + max_latency = float(np.max(latencies)) + + self.logger.info("TensorRT Python API benchmark:") + self.logger.info( + f" min={min_latency:.3f}ms, max={max_latency:.3f}ms, " + f"mean={mean_latency:.3f}ms, std={std_latency:.3f}ms, median={median_latency:.3f}ms" + ) + + # Save logs if requested + if log_file is not None: + try: + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + with open(log_path, "w") as f: + f.write("TensorRT Python API Benchmark\n") + if isinstance(path_or_bytes, bytes): + f.write(f"Model: \n") + else: + f.write(f"Model: {path_or_bytes}\n") + f.write(f"Build time: {build_time:.2f}s\n") + f.write(f"Warmup runs: {self.warmup_runs}\n") + f.write(f"Timing runs: {self.timing_runs}\n") + f.write("\n") + f.write("Latency Statistics:\n") + f.write(f" Min: {min_latency:.3f} ms\n") + f.write(f" Max: {max_latency:.3f} ms\n") + f.write(f" Mean: {mean_latency:.3f} ms\n") + f.write(f" Std: {std_latency:.3f} ms\n") + f.write(f" Median: {median_latency:.3f} ms\n") + f.write("\n") + f.write(f"All latencies: {latencies.tolist()}\n") + self.logger.debug(f"Saved benchmark logs to: {log_file}") + 
                except Exception as e:
+                    self.logger.warning(f"Failed to save logs to {log_file}: {e}")
+            return median_latency
+        except Exception as e:
+            self.logger.error(f"Benchmark failed: {e}", exc_info=True)
+            return float("inf")
+        finally:
+            # Cleanup resources (runs whether exception occurred or not)
+            try:
+                # Free device memory
+                for inp in inputs:
+                    if "device" in inp:
+                        inp["device"].free()
+                    if "host" in inp:
+                        del inp["host"]  # Release page-locked host memory
+                for out in outputs:
+                    if "device" in out:
+                        out["device"].free()
+                    if "host" in out:
+                        del out["host"]  # Release page-locked host memory
+                inputs.clear()
+                outputs.clear()
+
+                # Explicitly delete TensorRT objects to free resources
+                if context is not None:
+                    del context
+                # Clean up CUDA stream
+                if stream is not None:
+                    del stream
+                if engine is not None:
+                    del engine
+                if serialized_engine is not None:
+                    del serialized_engine
+                if parser is not None:
+                    del parser
+                if network is not None:
+                    del network
+                if config is not None:
+                    del config
+            except Exception as cleanup_error:
+                self.logger.warning(f"Error during cleanup: {cleanup_error}")
+
+    def _load_timing_cache(self):
+        """Load timing cache from file or create a new one.
+
+        Attempts to load the timing cache from disk. If the file exists and
+        can be loaded, deserializes it into a timing cache object. If loading
+        fails or the file doesn't exist, creates a new empty timing cache.
+
+        The timing cache stores kernel timing data to accelerate subsequent
+        engine builds with similar configurations.
+        """
+        config = self.builder.create_builder_config()
+        if os.path.exists(self.timing_cache_file):
+            try:
+                with open(self.timing_cache_file, "rb") as f:
+                    timing_cache_data = f.read()
+                self._timing_cache = config.create_timing_cache(timing_cache_data)
+                self.logger.debug(f"Loaded timing cache from: {self.timing_cache_file}")
+            except Exception as e:
+                self.logger.warning(f"Failed to load timing cache: {e}")
+                self.logger.debug("Creating new timing cache")
+                self._timing_cache = None
+
+        if self._timing_cache is None:
+            # Create new empty timing cache
+            self._timing_cache = config.create_timing_cache(b"")
+            self.logger.debug("Created new timing cache")
+        del config
+
+    def _save_timing_cache(self):
+        """Save timing cache to file.
+
+        Serializes the current timing cache and writes it to disk. This
+        preserves the accumulated kernel timing data for future benchmarks.
+
+        Called after a successful engine build when a timing-cache flush is
+        requested, to incrementally update the cache with new timing
+        information.
+        """
+        config = None
+        try:
+            if self._timing_cache is not None:
+                config = self.builder.create_builder_config()
+                # Copy the accumulated timing data into a fresh cache before serializing
+                output_cache = config.create_timing_cache(b"")
+                output_cache.combine(self._timing_cache, ignore_mismatch=True)
+                timing_cache_data = output_cache.serialize()
+                with open(self.timing_cache_file, "wb") as f:
+                    f.write(timing_cache_data)
+                self.logger.debug(f"Saved timing cache to: {self.timing_cache_file}")
+        except Exception as e:
+            self.logger.warning(f"Failed to save timing cache: {e}")
+        finally:
+            if config is not None:
+                del config
diff --git a/modelopt/onnx/quantization/autotune/cli.py b/modelopt/onnx/quantization/autotune/cli.py
new file mode 100644
index 000000000..c76740e29
--- /dev/null
+++ b/modelopt/onnx/quantization/autotune/cli.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CLI argument parsing and execution for ONNX Q/DQ autotuning. + +This module provides: +- `get_autotune_parser`: Creates the argument parser for the CLI +- `run_autotune`: Executes the autotuning workflow with parsed arguments + +See `__main__.py` for usage examples and documentation. +""" + +import argparse +import logging +import sys +from pathlib import Path + +from modelopt.onnx.quantization.autotune.workflows import ( + init_benchmark_instance, + region_pattern_autotuning_workflow, +) + +logger = logging.getLogger(__name__) + +# Default values for CLI arguments +DEFAULT_OUTPUT_DIR = "./autotuner_output" +DEFAULT_NUM_SCHEMES = 30 +DEFAULT_QUANT_TYPE = "int8" +DEFAULT_DQ_DTYPE = "float32" +DEFAULT_TIMING_CACHE = "/tmp/trtexec_timing.cache" # nosec B108 +DEFAULT_WARMUP_RUNS = 5 +DEFAULT_TIMING_RUNS = 20 + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def validate_file_path(path: str | None, description: str) -> Path | None: + """Validate that a file path exists. + + Args: + path: Path string to validate (can be None) + description: Description of the file for error messages + + Returns: + Path object if valid, None if path is None + + Raises: + SystemExit: If path is provided but doesn't exist + """ + if path is None: + return None + + path_obj = Path(path) + if not path_obj.exists(): + logger.error(f"{description} not found: {path_obj}") + sys.exit(1) + + return path_obj + + +def log_benchmark_config(args): + """Log TensorRT benchmark configuration for transparency. + + Logs timing cache path, warmup/timing run counts, and any custom + plugin libraries that will be loaded. + + Args: + args: Parsed command-line arguments with benchmark configuration + """ + logger.info("Initializing TensorRT benchmark") + logger.info(f" Timing cache: {args.timing_cache}") + logger.info(f" Warmup runs: {args.warmup_runs}") + logger.info(f" Timing runs: {args.timing_runs}") + if args.plugin_libraries: + logger.info(f" Plugin libraries: {', '.join(args.plugin_libraries)}") + + +# ============================================================================= +# Command Handler +# ============================================================================= + + +def run_autotune(args) -> int: + """Execute the complete pattern-based Q/DQ autotuning workflow. + + This function orchestrates the entire optimization process: + 1. Validates input paths (model, baseline, output directory) + 2. Initializes TensorRT benchmark instance + 3. Runs pattern-based region autotuning workflow + 4. 
Handles interruptions gracefully with state preservation + + Args: + args: Parsed command-line arguments containing: + - model: Path to input ONNX model + - output: Output directory path + - num_schemes: Number of schemes to test per region + - pattern_cache_file: Optional pattern cache for warm-start + - state_file: Optional state file for resume capability + - quant_type: Quantization type (int8 or fp8) + - qdq_baseline: Optional baseline model for pattern import + - timing_cache, warmup_runs, timing_runs: TensorRT config + - verbose: Debug logging flag + + Returns: + Exit code: + - 0: Success + - 1: Autotuning failed (exception occurred) + - 130: Interrupted by user (Ctrl+C) + """ + # Validate input paths + model_path = validate_file_path(args.model, "Model file") + validate_file_path(args.qdq_baseline, "QDQ baseline model") + output_dir = Path(args.output) + + # Initialize TensorRT benchmark + log_benchmark_config(args) + init_benchmark_instance( + use_trtexec=args.use_trtexec, + plugin_libraries=args.plugin_libraries, + timing_cache_file=args.timing_cache, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + ) + + logger.info("Autotuning Mode: Pattern-Based") + + # Run autotuning workflow + try: + # Load node filter patterns from file if provided + node_filter_list = None + if args.node_filter_list: + filter_file = validate_file_path(args.node_filter_list, "Node filter list file") + if filter_file: + with open(filter_file) as f: + node_filter_list = [ + line.strip() + for line in f + if line.strip() and not line.strip().startswith("#") + ] + logger.info(f"Loaded {len(node_filter_list)} filter patterns from {filter_file}") + + region_pattern_autotuning_workflow( + model_path=str(model_path), + output_dir=output_dir, + num_schemes_per_region=args.num_schemes, + pattern_cache_file=args.pattern_cache_file, + state_file=args.state_file, + quant_type=args.quant_type, + default_dq_dtype=args.default_dq_dtype, + qdq_baseline_model=args.qdq_baseline, + node_filter_list=node_filter_list, + ) + + # Success message + logger.info("\n" + "=" * 70) + logger.info("✓ Autotuning completed successfully!") + logger.info(f"✓ Results: {output_dir}") + logger.info("=" * 70) + return 0 + + except KeyboardInterrupt: + logger.warning("\nInterrupted by user") + state_file = args.state_file or output_dir / "autotuner_state.yaml" + logger.info(f"Progress saved to: {state_file}") + return 130 + + except Exception as e: + logger.error(f"\nAutotuning failed: {e}", exc_info=args.verbose) + return 1 + + +# ============================================================================= +# Argument Parser +# ============================================================================= + + +def get_autotune_parser() -> argparse.ArgumentParser: + """Create and configure the command-line argument parser. 
+ + Sets up argument groups for: + - Model and Output: Input model and output directory + - Autotuning Strategy: Scheme count, pattern cache, baseline import, state file + - Quantization: Data type selection (int8/fp8) + - TensorRT Benchmark: Timing cache, warmup/timing runs, plugins + - Logging: Verbose debug mode + + Returns: + Configured ArgumentParser instance with all CLI options + """ + parser = argparse.ArgumentParser( + prog="modelopt.onnx.quantization.autotune", + description="ONNX Q/DQ Autotuning with TensorRT", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic usage + python -m modelopt.onnx.quantization.autotune --model model.onnx + + # Import patterns from QDQ baseline model + python -m modelopt.onnx.quantization.autotune \\ + --model model.onnx --qdq-baseline baseline.onnx + + # Use pattern cache for warm-start + python -m modelopt.onnx.quantization.autotune --model model.onnx --pattern-cache cache.yaml + + # Full example with all options + python -m modelopt.onnx.quantization.autotune \\ + --model model.onnx --schemes-per-region 50 \\ + --pattern-cache cache.yaml --qdq-baseline baseline.onnx \\ + --quant-type int8 --verbose + """, + ) + + # Model and Output + io_group = parser.add_argument_group("Model and Output") + io_group.add_argument("--model", "-m", type=str, required=True, help="Path to ONNX model file") + io_group.add_argument( + "--output", + "-o", + type=str, + default=DEFAULT_OUTPUT_DIR, + help=f"Output directory for results (default: {DEFAULT_OUTPUT_DIR})", + ) + + # Autotuning Strategy + strategy_group = parser.add_argument_group("Autotuning Strategy") + strategy_group.add_argument( + "--schemes-per-region", + "-s", + type=int, + default=DEFAULT_NUM_SCHEMES, + dest="num_schemes", + help=f"Number of schemes to test per region (default: {DEFAULT_NUM_SCHEMES})", + ) + strategy_group.add_argument( + "--pattern-cache", + type=str, + default=None, + dest="pattern_cache_file", + help="Path to pattern cache YAML for warm-start (optional)", + ) + strategy_group.add_argument( + "--qdq-baseline", + type=str, + default=None, + help="Path to QDQ baseline ONNX model to import quantization patterns (optional)", + ) + strategy_group.add_argument( + "--state-file", + type=str, + default=None, + help="State file path for resume capability (default: /autotuner_state.yaml)", + ) + strategy_group.add_argument( + "--node-filter-list", + type=str, + default=None, + help="Path to a file containing wildcard patterns to filter ONNX nodes (one pattern per line). 
" + "Regions without any matching nodes are skipped during autotuning.", + ) + + # Quantization + quant_group = parser.add_argument_group("Quantization") + quant_group.add_argument( + "--quant-type", + type=str, + default=DEFAULT_QUANT_TYPE, + choices=["int8", "fp8"], + help=f"Quantization data type (default: {DEFAULT_QUANT_TYPE})", + ) + quant_group.add_argument( + "--default-dq-dtype", + type=str, + default=DEFAULT_DQ_DTYPE, + choices=["float16", "float32", "bfloat16"], + help="Default DQ output dtype if cannot be deduced (optional)", + ) + + # TensorRT Benchmark + trt_group = parser.add_argument_group("TensorRT Benchmark") + trt_group.add_argument( + "--use-trtexec", + action="store_true", + help="Use trtexec for benchmarking (default: False)", + default=False, + ) + trt_group.add_argument( + "--timing-cache", + type=str, + default=DEFAULT_TIMING_CACHE, + help=f"TensorRT timing cache file (default: {DEFAULT_TIMING_CACHE})", + ) + trt_group.add_argument( + "--warmup-runs", + type=int, + default=DEFAULT_WARMUP_RUNS, + help=f"Number of warmup runs (default: {DEFAULT_WARMUP_RUNS})", + ) + trt_group.add_argument( + "--timing-runs", + type=int, + default=DEFAULT_TIMING_RUNS, + help=f"Number of timing runs (default: {DEFAULT_TIMING_RUNS})", + ) + trt_group.add_argument( + "--plugin-libraries", + "--plugins", + type=str, + nargs="+", + default=None, + dest="plugin_libraries", + help="TensorRT plugin libraries (.so files) to load (optional, space-separated)", + ) + + # Logging + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose DEBUG logging") + + return parser \ No newline at end of file diff --git a/modelopt/onnx/quantization/autotune/workflows.py b/modelopt/onnx/quantization/autotune/workflows.py new file mode 100644 index 000000000..dae06bc19 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/workflows.py @@ -0,0 +1,562 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ONNX Q/DQ Autotuning Workflows. + +This module provides high-level workflow functions for automated Q/DQ (Quantization/Dequantization) +optimization of ONNX models using pattern-based region analysis and TensorRT performance measurement. + +**Core Capabilities:** + +1. **Automated Region Discovery**: Discovers hierarchical regions in the computation graph + - LEAF regions: Contain actual graph nodes + - COMPOSITE regions: Contain child regions with hierarchical structure + +2. **Pattern-Based Optimization**: Groups regions by structural pattern + - Regions with identical patterns share optimization schemes + - One optimization applies to all matching regions simultaneously + +3. **TensorRT Benchmarking**: Measures actual inference performance + - Builds TensorRT engines for each Q/DQ configuration + - Measures median latency across multiple runs + - Caches timing data for faster iteration + +4. 
**Incremental State Management**: Supports crash recovery and resume + - Saves state after each region profiling + - Resumes from last checkpoint automatically + - Preserves baseline and all measurements + +5. **Pattern Cache Warm-Start**: Leverages previous optimization results + - Loads known-good schemes from cache + - Reduces exploration time for similar models + - Transfers learned patterns across runs + +**Key Functions:** + +- **benchmark_onnx_model()**: Benchmark ONNX model inference latency using TensorRT +- **init_benchmark_instance()**: Initialize global TensorRT benchmark instance +- **region_pattern_autotuning_workflow()**: Complete end-to-end Q/DQ optimization workflow + +**Workflow Overview:** + +1. Initialize autotuner with automatic region discovery +2. Measure baseline performance (no Q/DQ) +3. For each region pattern: + - Generate Q/DQ insertion schemes + - Benchmark each scheme with TensorRT + - Select best scheme for pattern + - Apply to all regions with matching pattern +4. Export final optimized model + +**Performance Optimization:** + +- Pattern-based approach reduces redundant evaluation +- TensorRT timing cache speeds up engine builds +- Incremental state saves enable long-running optimizations +- Pattern cache enables cross-model learning +""" + +import fnmatch +import logging +from pathlib import Path + +import onnx + +from modelopt.onnx.quantization.autotune.autotuner import QDQAutotuner +from modelopt.onnx.quantization.autotune.benchmark import TensorRTPyBenchmark, TrtExecBenchmark +from modelopt.onnx.quantization.autotune.common import Config, PatternCache +from modelopt.onnx.quantization.qdq_utils import get_quantized_tensors + +logger = logging.getLogger(__name__) + +# Global benchmark instance - will be initialized with timing cache +_benchmark_instance = None + + +# ============================================================================= +# Benchmarking +# ============================================================================= + + +def benchmark_onnx_model( + model_path: str | bytes, log_file: str | None = None, flush_timing_cache: bool = False +) -> float: + """Benchmark ONNX model inference latency using TensorRT Python API. + + Uses the global TensorRTPyBenchmark instance to build a TensorRT engine + and measure inference latency. The benchmark instance persists across calls + for efficiency (reuses Builder, Runtime, Logger, and timing cache). + + **Process:** + 1. Loads ONNX model (from file path or bytes) + 2. Builds optimized TensorRT engine (uses timing cache for speed) + 3. Runs warmup iterations to stabilize performance + 4. Measures latency across multiple timing iterations + 5. Returns median latency + + Args: + model_path: Path to ONNX model file, or bytes containing serialized model protobuf + log_file: Optional path to save detailed TensorRT build and benchmark logs + (default: None, no logging) + flush_timing_cache: If True, flushes TensorRT timing cache before building engine. + Useful for periodic cache refresh (default: False) + + Returns: + Measured median inference latency in milliseconds. + Returns float('inf') on failure (invalid model, build error, etc.) + + Raises: + No exceptions raised - errors are caught and logged, returning float('inf') + + Note: + Requires _benchmark_instance to be initialized via init_benchmark_instance() + before calling this function. Otherwise returns float('inf'). 
+ + Example: + >>> init_benchmark_instance("timing.cache", warmup_runs=5, timing_runs=20) + >>> latency = benchmark_onnx_model("model.onnx", log_file="build.log") + >>> print(f"Latency: {latency:.2f} ms") + """ + global _benchmark_instance + + if _benchmark_instance is None: + logger.error("Benchmark instance not initialized") + return float("inf") + + try: + # Run TensorRT benchmark + latency = _benchmark_instance.run( + model_path, log_file=log_file, flush_timing_cache=flush_timing_cache + ) + + if latency == float("inf"): + if isinstance(model_path, bytes): + logger.warning("Benchmark failed for model bytes") + else: + logger.warning(f"Benchmark failed: {model_path}") + return float("inf") + + logger.debug(f"Benchmark result: {latency:.2f} ms") + return latency + + except Exception as e: + logger.error(f"Benchmark error: {e}", exc_info=True) + return float("inf") + + +def init_benchmark_instance( + use_trtexec: bool = False, + plugin_libraries: list[str] | None = None, + timing_cache_file: str | None = None, + warmup_runs: int = 5, + timing_runs: int = 20, +): + """Initialize global TensorRT benchmark instance for model performance measurement. + + Creates and configures a TensorRTPyBenchmark instance that persists across + multiple benchmark_onnx_model() calls for efficiency. The instance reuses + TensorRT Builder, Runtime, Logger, and timing cache. + + **Benefits of Persistent Instance:** + - Avoids repeated initialization overhead + - Reuses timing cache across multiple models + - Maintains consistent benchmark configuration + + + Args: + use_trtexec: Whether to use trtexec for benchmarking. + plugin_libraries: List of paths to TensorRT plugin shared libraries (.so files). + These plugins will be loaded by trtexec or TensorRT Python API during engine building. + If None, no custom plugins are loaded. + timing_cache_file: Path to TensorRT timing cache file for faster engine builds. + If None, uses default "trtexec_timing.cache" (default: None) + warmup_runs: Number of warmup inference iterations before measurement. + Allows GPU to reach stable performance state (default: 5) + timing_runs: Number of timed inference iterations for latency measurement. + Higher values give more stable median (default: 20) + + Returns: + TensorRTPyBenchmark or TrtExecBenchmark instance if initialization succeeds, None on failure + + Example: + >>> # Initialize with default settings using TensorRT Python API + >>> benchmark = init_benchmark_instance(use_trtexec=False) + >>> if benchmark: + ... latency = benchmark_onnx_model("model.onnx") + ... 
print(f"Latency: {latency:.2f} ms") + + See Also: + benchmark_onnx_model(): Uses the initialized instance to benchmark models + """ + global _benchmark_instance + try: + if use_trtexec: + _benchmark_instance = TrtExecBenchmark( + timing_cache_file=timing_cache_file, + warmup_runs=warmup_runs, + timing_runs=timing_runs, + plugin_libraries=plugin_libraries, + ) + logger.info("Trtexec benchmark initialized") + else: + _benchmark_instance = TensorRTPyBenchmark( + timing_cache_file=timing_cache_file, + warmup_runs=warmup_runs, + timing_runs=timing_runs, + plugin_libraries=plugin_libraries, + ) + logger.info("TensorRT Python API benchmark initialized") + logger.debug( + f"Settings: warmup={warmup_runs}, timing={timing_runs}, " + f"cache={timing_cache_file or 'trtexec_timing.cache'}, plugin_libraries={plugin_libraries}" + ) + return _benchmark_instance + except Exception as e: + logger.error(f"TensorRT initialization failed: {e}", exc_info=True) + return None + + +def _region_matches_filter(region, graph, filter_patterns: list[str]) -> bool: + """Check if any node in the region matches any of the filter patterns. + + Args: + region: Region object to check + graph: ONNX graph (graphsurgeon) containing node information + filter_patterns: List of wildcard patterns to match against node names + + Returns: + True if at least one node in the region matches any pattern, False otherwise + """ + if not filter_patterns: + return True # No filter means all regions pass + + # Get all node indices in this region (including children) + node_indices = region.get_all_nodes_recursive() + + for node_idx in node_indices: + if node_idx < len(graph.nodes): + node_name = graph.nodes[node_idx].name + for pattern in filter_patterns: + if fnmatch.fnmatch(node_name, pattern): + return True + + return False + + +# ============================================================================= +# Autotuning Workflow +# ============================================================================= + + +def region_pattern_autotuning_workflow( + model_path: str, + output_dir: Path, + num_schemes_per_region: int = 30, + pattern_cache_file: str | None = None, + state_file: str | None = None, + quant_type: str = "int8", + default_dq_dtype: str = "float32", + qdq_baseline_model: str | None = None, + node_filter_list: list[str] | None = None, +) -> QDQAutotuner: + """Run automated Q/DQ (Quantization/Dequantization) optimization on an ONNX model. + + This workflow uses pattern-based region optimization to efficiently find optimal + Q/DQ insertion points. The key insight: regions with identical structural patterns + can share the same Q/DQ scheme. When a best scheme is found for a pattern, it + automatically applies to all regions matching that pattern, making optimization + both efficient and consistent. + + Automatically discovers regions, generates and tests Q/DQ insertion schemes, + and exports optimized model. Supports incremental state saving for crash recovery + and pattern cache-based warm-start. + + **Workflow Steps:** + 1. Load model and initialize autotuner with automatic hierarchical region discovery + 2. Resume from checkpoint if state file exists (crash recovery) + 3. Load pattern cache if provided (warm-start with known-good schemes) + 4. Import Q/DQ patterns from baseline model if provided (transfer learning) + 5. Measure baseline performance without Q/DQ insertions + 6. For each discovered region pattern: + a. Generate Q/DQ insertion schemes (pattern-relative) + b. 
Build TensorRT engine and measure latency for each scheme + c. Select best scheme for this pattern (applies to all matching regions) + d. Save checkpoint and intermediate model + 7. Export final optimized model with best Q/DQ scheme for each pattern + + **State Management (Crash Recovery):** + - Automatically saves checkpoint after profiling each region + - Resume from interruption by running same command (auto-detects state file) + - State file contains: + * Baseline latency measurement + * All profiled pattern schemes and their latencies + * Best scheme selection for each pattern + * Region discovery results and pattern assignments + - Enables long-running optimizations with fault tolerance + - Safe for cluster environments with preemption + + **Pattern Cache (Warm-Start Optimization):** + - Pattern cache files (YAML format) contain top-performing schemes indexed by pattern + - Stores results from previous optimization runs for reuse + - Used to seed scheme generation (warm-start vs cold-start) + - Benefits: + * Reduces exploration time by prioritizing known-good configurations + * Transfers learned schemes across similar models or model versions + * Accumulates knowledge across multiple optimization sessions + * Particularly effective for models with recurring pattern structures + - Cache is pattern-specific, not model-specific (generalizes well) + + **QDQ Baseline Model (Transfer Learning):** + - If provided, extracts Q/DQ insertion points from a pre-quantized model + - Identifies which tensors are quantized in the baseline model + - Maps these quantization points to region patterns in the current model + - Updates pattern cache with learned insertion strategies + - Enables warm-start from: + * Expert-tuned quantized models (manually optimized) + * Previous autotuning runs (transfer across model versions) + * Reference implementations (e.g., from framework exporters) + + Args: + model_path: Path to ONNX model file to optimize + output_dir: Directory for output files (state, logs, models). Created if doesn't exist. + num_schemes_per_region: Number of Q/DQ insertion schemes to test per region pattern. + Higher values explore more configurations but take longer (default: 30) + pattern_cache_file: Optional path to pattern cache YAML file containing known-good schemes + from previous runs. Enables warm-start optimization (default: None) + state_file: Optional path to state file for checkpoint/resume. If None, automatically + uses /autotuner_state.yaml (default: None) + quant_type: Quantization data type - "int8" for INT8 quantization (default), + "fp8" for FP8 quantization + qdq_baseline_model: Optional path to a pre-quantized ONNX model. 
If provided, + extracts Q/DQ insertion patterns and adds them to pattern cache + for warm-start (default: None) + + Returns: + Configured QDQAutotuner instance containing: + - All discovered regions and their patterns + - Profiled Q/DQ insertion schemes for each pattern + - Best scheme selections and performance measurements + - Complete optimization state (can be saved/loaded) + + The returned autotuner can be used for: + - Exporting optimized models with best Q/DQ schemes + - Analyzing per-pattern optimization results + - Further refinement or experimentation + - Pattern cache generation for future runs + + Example: + >>> # Initial run + >>> autotuner = region_pattern_autotuning_workflow("model.onnx", Path("./output")) + >>> # Resume from interruption + >>> autotuner = region_pattern_autotuning_workflow("model.onnx", Path("./output")) + >>> # With pattern cache warm-start + >>> autotuner = region_pattern_autotuning_workflow( + ... "model.onnx", Path("./output"), pattern_cache_file="./pattern_cache.yaml" + ... ) + >>> # With QDQ baseline model for pattern import + >>> autotuner = region_pattern_autotuning_workflow( + ... "model.onnx", Path("./output"), qdq_baseline_model="quantized_baseline.onnx" + ... ) + """ + # Setup directories + output_dir.mkdir(parents=True, exist_ok=True) + logs_dir = output_dir / "logs" + logs_dir.mkdir(exist_ok=True) + models_dir = output_dir / "region_models" + models_dir.mkdir(exist_ok=True) + + # Determine state file path + if state_file is None: + state_file = str(output_dir / "autotuner_state.yaml") + state_path = Path(state_file) + + # Load model + logger.info(f"Loading model: {model_path}") + model = onnx.load(model_path) + + # Load pattern cache if provided + pattern_cache = None + if pattern_cache_file: + pattern_cache_path = Path(pattern_cache_file) + if pattern_cache_path.exists(): + pattern_cache = PatternCache.load(str(pattern_cache_path)) + logger.info( + f"Loaded pattern cache: {pattern_cache.num_patterns} patterns, " + f"{pattern_cache.total_schemes} schemes" + ) + else: + logger.warning(f"Pattern cache not found: {pattern_cache_file}") + + # Initialize autotuner with config + logger.info( + f"Initializing autotuner (quant_type={quant_type}, default_dq_dtype={default_dq_dtype})" + ) + config = Config( + default_quant_type=quant_type, + default_dq_dtype=default_dq_dtype, + verbose=True, + ) + + autotuner = QDQAutotuner(model) + autotuner.initialize(config, pattern_cache) + + # Load previous state if exists (resume capability) + if state_path.exists(): + logger.info(f"Resuming from checkpoint: {state_path}") + autotuner.load_state(str(state_path)) + else: + logger.info("Starting new autotuning session") + + # Import quantization patterns from QDQ baseline model if provided + if qdq_baseline_model: + qdq_baseline_path = Path(qdq_baseline_model) + if qdq_baseline_path.exists(): + logger.info(f"Importing patterns from QDQ baseline: {qdq_baseline_model}") + qdq_model = onnx.load(str(qdq_baseline_path)) + + # Extract quantized tensors from baseline model + quantized_tensors = get_quantized_tensors(qdq_model) + logger.debug(f"Found {len(quantized_tensors)} quantized tensors in baseline") + + # Import insertion points into pattern cache + autotuner.import_insertion_points(quantized_tensors) + logger.info("Pattern import complete") + else: + logger.warning(f"QDQ baseline not found: {qdq_baseline_model}") + + # Get discovered regions + regions = autotuner.regions + logger.info(f"Ready to profile {len(regions)} regions") + + # Measure baseline (no Q/DQ) if 
not already measured + if autotuner.baseline_latency_ms is None: + logger.info("Measuring baseline (no Q/DQ)") + baseline_path = output_dir / "baseline.onnx" + autotuner.export_onnx(str(baseline_path), insert_qdq=False) + baseline_log = logs_dir / "baseline.log" + baseline_latency = benchmark_onnx_model(str(baseline_path), str(baseline_log)) + autotuner.submit(baseline_latency) + logger.info(f"Baseline: {baseline_latency:.2f} ms") + else: + baseline_latency = autotuner.baseline_latency_ms + logger.info(f"Using baseline from checkpoint: {baseline_latency:.2f} ms") + + # Profile regions + logger.info(f"Starting region profiling ({num_schemes_per_region} schemes per region)") + + iteration_count = 0 + + for region_idx, region in enumerate(regions): + logger.info( + f"Region {region_idx + 1}/{len(regions)} (ID={region.id}, level={region.get_level()})" + ) + + # Check if region matches node filter list + if node_filter_list and not _region_matches_filter( + region, autotuner.graph, node_filter_list + ): + logger.info(" Skipping (no nodes match filter patterns)") + continue + + # Set as current profile region + # Commit previous region's results (except for first region) + commit = region_idx > 0 + autotuner.set_profile_region(region, commit=commit) + + # Check if already profiled (from loaded state) + if autotuner.current_profile_pattern_schemes is None: + logger.info(" Skipping (already profiled)") + continue + + # Generate and test schemes for this region + schemes_tested = 0 + for scheme_num in range(num_schemes_per_region): + iteration_count += 1 + + # Generate new scheme + scheme_idx = autotuner.generate() + + if scheme_idx == -1: + logger.debug(f" Stopping at scheme {scheme_num + 1} (no more unique schemes)") + break + + schemes_tested += 1 + + # Export model with this scheme + best from previous regions + model_bytes = autotuner.export_onnx(None, insert_qdq=True) + + # Benchmark with TensorRT + test_log = logs_dir / f"region_{region.id}_scheme_{scheme_idx}.log" + flush_timing_cache = (iteration_count % 10) == 0 + latency = benchmark_onnx_model( + model_bytes, str(test_log), flush_timing_cache=flush_timing_cache + ) + + # Record result + autotuner.submit(latency, success=(latency != float("inf"))) + + # Display region summary + ps = autotuner.current_profile_pattern_schemes + if ps and ps.schemes: + best_scheme = ps.best_scheme + if best_scheme and best_scheme.latency_ms < float("inf") and baseline_latency > 0: + speedup = baseline_latency / best_scheme.latency_ms + logger.info( + f" Tested {schemes_tested} schemes: " + f"best {best_scheme.latency_ms:.2f} ms ({speedup:.3f}x speedup)" + ) + else: + logger.info(f" Tested {schemes_tested} schemes: no valid measurements") + else: + logger.info(f" Tested {schemes_tested} schemes") + + # Save best model for this region (before committing to next region) + region_model_path = models_dir / f"region_{region.id}_level_{region.get_level()}.onnx" + autotuner.export_onnx(str(region_model_path), insert_qdq=True, best=True) + logger.debug(f" Saved best model: {region_model_path.name}") + + # Save state after each region (incremental, crash recovery) + autotuner.save_state(str(state_path)) + logger.debug(" Checkpoint saved") + + # Commit final region + autotuner.set_profile_region(None, commit=True) + + # Export and measure final optimized model + logger.info("Exporting final optimized model") + final_model_path = output_dir / "optimized_final.onnx" + autotuner.export_onnx(str(final_model_path), insert_qdq=True) + final_log = logs_dir / "final.log" + 
final_latency = benchmark_onnx_model(str(final_model_path), str(final_log)) + + # Display results + if final_latency > 0 and final_latency != float("inf"): + speedup = baseline_latency / final_latency + logger.info( + f"Results: {baseline_latency:.2f} ms → {final_latency:.2f} ms ({speedup:.3f}x speedup)" + ) + else: + logger.info(f"Results: {baseline_latency:.2f} ms → failed (invalid measurement)") + + # Save final state + autotuner.save_state(str(state_path)) + + logger.info("Autotuning complete") + logger.info(f" Final model: {final_model_path}") + logger.info(f" State: {state_path}") + logger.debug(f" Logs: {logs_dir}") + logger.debug(f" Region models: {models_dir}") + + return autotuner diff --git a/tests/unit/onnx/quantization/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/test_autotuner.py new file mode 100644 index 000000000..411256a49 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_autotuner.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for QDQAutotuner class. + +Tests the main autotuner class public API. +Note: Full integration tests with TensorRT benchmarking should be in separate integration test files. +""" + +import os +import sys +import tempfile +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import onnx +import onnx_graphsurgeon as gs +from onnx import helper + +from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern +from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType + + +def create_simple_conv_model(): + """ + Create a simple ONNX model: Input -> Conv -> Relu -> Output. + + This is a minimal model for testing autotuner initialization. + """ + # Input + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + + # Output + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + # Conv node + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + + # Relu node + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node], + "simple_conv", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + return model + + +class TestQDQAutotuner(unittest.TestCase): + """Test QDQAutotuner functionality.""" + + @staticmethod + def _create_test_config(): + """ + Create a reasonable config for testing. 
+ + Uses sensible defaults suitable for unit tests: + - verbose=False: Keep test output clean + - maximum_sequence_region_size=50: Allow larger test regions + - Other parameters: Match Config defaults for typical behavior + """ + return Config( + # Logging + verbose=False, + # Performance Requirements + # Quantization Parameters + default_q_scale=0.1, + default_q_zero_point=0, + default_quant_type="int8", + # Region Builder Settings + maximum_sequence_region_size=50, + minimum_topdown_search_size=10, + # Scheme Generation Settings + top_percent_to_mutate=0.1, + minimum_schemes_to_mutate=10, + maximum_mutations=3, + maximum_generation_attempts=100, + # Pattern Cache Settings + pattern_cache_minimum_distance=4, + pattern_cache_max_entries_per_pattern=32, + ) + + def test_creation_with_onnx_model(self): + """Test creating autotuner with ONNX ModelProto.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + assert autotuner is not None + assert autotuner.onnx_model is not None + assert autotuner.graph is not None + print("✓ QDQAutotuner creation with ONNX model") + + def test_creation_with_gs_graph(self): + """Test creating autotuner with GraphSurgeon graph.""" + model = create_simple_conv_model() + gs_graph = gs.import_onnx(model) + + autotuner = QDQAutotuner(gs_graph) + + assert autotuner is not None + assert autotuner.graph is not None + print("✓ QDQAutotuner creation with GS graph") + + def test_initialize_with_default_config(self): + """Test initialization with default test config.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + autotuner.initialize(config) + + # Should have provided config + assert autotuner.config is not None + assert autotuner.config.maximum_sequence_region_size == 50 + + # Should have discovered regions + assert len(autotuner.regions) > 0 + print("✓ QDQAutotuner initialize with default config") + + def test_initialize_with_config(self): + """Test initialization with custom config (different from default).""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + # Create custom config with different values + config = Config( + verbose=True, + default_q_scale=0.05, + default_q_zero_point=128, + default_quant_type="fp8", + maximum_sequence_region_size=20, + minimum_topdown_search_size=5, + top_percent_to_mutate=0.2, + minimum_schemes_to_mutate=5, + maximum_mutations=5, + maximum_generation_attempts=50, + pattern_cache_minimum_distance=2, + pattern_cache_max_entries_per_pattern=16, + ) + autotuner.initialize(config) + + # Should use provided custom config values + assert autotuner.config.verbose + assert autotuner.config.default_q_scale == 0.05 + assert autotuner.config.default_q_zero_point == 128 + assert autotuner.config.default_quant_type == "fp8" + assert autotuner.config.maximum_sequence_region_size == 20 + assert autotuner.config.minimum_topdown_search_size == 5 + assert autotuner.config.top_percent_to_mutate == 0.2 + assert autotuner.config.minimum_schemes_to_mutate == 5 + assert autotuner.config.maximum_mutations == 5 + assert autotuner.config.maximum_generation_attempts == 50 + assert autotuner.config.pattern_cache_minimum_distance == 2 + assert autotuner.config.pattern_cache_max_entries_per_pattern == 16 + print("✓ QDQAutotuner initialize with config") + + def test_initialize_with_pattern_cache(self): + """Test initialization with pattern cache.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = 
self._create_test_config() + pattern_cache = PatternCache() + autotuner.initialize(config, pattern_cache=pattern_cache) + + assert autotuner.pattern_cache is not None + print("✓ QDQAutotuner initialize with pattern cache") + + def test_region_discovery(self): + """Test that regions are automatically discovered.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + autotuner.initialize(config) + + # Should discover at least one region + assert len(autotuner.regions) > 0 + + # Regions should be valid + for region in autotuner.regions: + assert region.get_id() is not None + assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] + + print("✓ QDQAutotuner region discovery") + + def test_export_baseline_model(self): + """Test exporting baseline model without Q/DQ.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + output_path = f.name + + try: + # Export baseline without Q/DQ insertion + autotuner.export_onnx(output_path, insert_qdq=False) + + # Verify file was created + assert os.path.exists(output_path) + + # Verify it's a valid ONNX model + exported_model = onnx.load(output_path) + assert exported_model is not None + print("✓ QDQAutotuner export baseline model") + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + def test_set_profile_region(self): + """Test setting a region for profiling.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + # Should set current profile region + assert autotuner.current_profile_region == region + assert autotuner.current_profile_pattern_schemes is not None + print("✓ QDQAutotuner set_profile_region") + else: + self.skipTest("No regions discovered") + + def test_generate_scheme(self): + """Test generating an insertion scheme.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + # Generate a scheme + scheme_idx = autotuner.generate() + + # Should return a valid index (>= 0) or -1 if no more unique schemes + assert isinstance(scheme_idx, int) + print("✓ QDQAutotuner generate scheme") + else: + self.skipTest("No regions discovered") + + def test_submit_latency(self): + """Test submitting performance measurement.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Submit baseline latency + autotuner.submit(10.5) + + # Baseline should be recorded + assert autotuner.baseline_latency_ms == 10.5 + print("✓ QDQAutotuner submit latency") + + def test_save_and_load_state(self): + """Test saving and loading autotuner state.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Submit some results + autotuner.submit(10.5) # baseline + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + state_path = f.name + + try: + # Save state + autotuner.save_state(state_path) + assert 
os.path.exists(state_path) + + # Create new autotuner and load state + autotuner2 = QDQAutotuner(model) + config2 = self._create_test_config() + autotuner2.initialize(config2) + autotuner2.load_state(state_path) + + # Baseline should match + assert autotuner2.baseline_latency_ms == 10.5 + print("✓ QDQAutotuner save and load state") + finally: + if os.path.exists(state_path): + os.unlink(state_path) + + def test_regions_prioritization(self): + """Test that LEAF regions are prioritized.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Check that LEAF regions come before non-LEAF + leaf_indices = [ + i for i, r in enumerate(autotuner.regions) if r.get_type() == RegionType.LEAF + ] + non_leaf_indices = [ + i for i, r in enumerate(autotuner.regions) if r.get_type() != RegionType.LEAF + ] + + if leaf_indices and non_leaf_indices: + # All LEAF should come before non-LEAF + assert max(leaf_indices) < min(non_leaf_indices) + print("✓ QDQAutotuner LEAF region prioritization") + else: + print("✓ QDQAutotuner regions (not enough for prioritization test)") + + def test_profiled_patterns_tracking(self): + """Test that profiled patterns are tracked.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Submit baseline latency first + autotuner.submit(10.0) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + # Generate and submit a scheme + scheme_idx = autotuner.generate() + if scheme_idx >= 0: + autotuner.submit(12.0) + autotuner.set_profile_region(None, commit=True) + + # Pattern should be tracked + pattern_sig = RegionPattern.from_region(region, autotuner.graph).signature + profiled_patterns = [p.pattern.signature for p in autotuner.profiled_patterns] + assert pattern_sig in profiled_patterns + print("✓ QDQAutotuner profiled patterns tracking") + else: + print("✓ QDQAutotuner (no schemes to test tracking)") + else: + self.skipTest("No regions discovered") + + +def run_tests(): + """Run all QDQAutotuner tests.""" + print("=" * 70) + print("QDQAutotuner Test Suite") + print("=" * 70) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTests(loader.loadTestsFromTestCase(TestQDQAutotuner)) + + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + print("\n" + "=" * 70) + print("Test Summary") + print("=" * 70) + print(f"Tests run: {result.testsRun}") + print(f"Successes: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f"Failures: {len(result.failures)}") + print(f"Errors: {len(result.errors)}") + + if result.wasSuccessful(): + print("\n✓ All QDQAutotuner tests passed!") + return 0 + else: + print("\n✗ Some tests failed") + return 1 + + +if __name__ == "__main__": + sys.exit(run_tests()) diff --git a/tests/unit/onnx/quantization/autotune/test_config.py b/tests/unit/onnx/quantization/autotune/test_config.py new file mode 100644 index 000000000..db6b02aa3 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_config.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the Config class in the autotuner. + +Tests configuration parameter validation and defaults. +""" + +import os +import sys +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from modelopt.onnx.quantization.autotune.common import Config + + +class TestConfig(unittest.TestCase): + """Test Config class functionality.""" + + def test_default_values(self): + """Test that Config has correct default values.""" + config = Config() + + # Logging + assert not config.verbose + + # Performance thresholds + + # Q/DQ defaults + assert config.default_q_scale == 0.1 + assert config.default_q_zero_point == 0 + assert config.default_quant_type == "int8" + + # Region builder settings + assert config.maximum_sequence_region_size == 10 + assert config.minimum_topdown_search_size == 10 + + # Scheme generation parameters + assert config.top_percent_to_mutate == 0.1 + assert config.minimum_schemes_to_mutate == 10 + assert config.maximum_mutations == 3 + assert config.maximum_generation_attempts == 100 + + # Pattern cache parameters + assert config.pattern_cache_minimum_distance == 4 + assert config.pattern_cache_max_entries_per_pattern == 32 + + print("✓ Config default values are correct") + + def test_custom_values(self): + """Test creating Config with custom values.""" + config = Config( + verbose=True, + default_q_scale=0.05, + default_q_zero_point=128, + default_quant_type="fp8", + maximum_sequence_region_size=20, + ) + + assert config.verbose + assert config.default_q_scale == 0.05 + assert config.default_q_zero_point == 128 + assert config.default_quant_type == "fp8" + assert config.maximum_sequence_region_size == 20 + print("✓ Config custom values work correctly") + + def test_region_size_validation(self): + """Test that region size parameters are positive.""" + config = Config(maximum_sequence_region_size=50, minimum_topdown_search_size=5) + assert config.maximum_sequence_region_size > 0 + assert config.minimum_topdown_search_size > 0 + print("✓ Config region size validation") + + def test_genetic_algorithm_params(self): + """Test genetic algorithm parameters.""" + config = Config( + top_percent_to_mutate=0.2, + minimum_schemes_to_mutate=2, + maximum_mutations=5, + maximum_generation_attempts=50, + ) + + assert config.top_percent_to_mutate == 0.2 + assert config.minimum_schemes_to_mutate == 2 + assert config.maximum_mutations == 5 + assert config.maximum_generation_attempts == 50 + print("✓ Config genetic algorithm parameters") + + def test_pattern_cache_params(self): + """Test pattern cache parameters.""" + config = Config(pattern_cache_minimum_distance=3, pattern_cache_max_entries_per_pattern=10) + + assert config.pattern_cache_minimum_distance == 3 + assert config.pattern_cache_max_entries_per_pattern == 10 + print("✓ Config pattern cache parameters") + + +def run_tests(): + """Run all Config tests.""" + print("=" * 70) + print("Config Class Test Suite") + print("=" * 70) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTests(loader.loadTestsFromTestCase(TestConfig)) 
+ + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + print("\n" + "=" * 70) + print("Test Summary") + print("=" * 70) + print(f"Tests run: {result.testsRun}") + print(f"Successes: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f"Failures: {len(result.failures)}") + print(f"Errors: {len(result.errors)}") + + if result.wasSuccessful(): + print("\n✓ All Config tests passed!") + return 0 + else: + print("\n✗ Some tests failed") + return 1 + + +if __name__ == "__main__": + sys.exit(run_tests())
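An illustrative sketch of the `--node-filter-list` option introduced in cli.py: the file name and wildcard patterns below are hypothetical and depend on the target model's node names, while the file format itself (one wildcard pattern per line, blank lines and `#`-prefixed comment lines ignored) follows the loader in run_autotune. Regions without any node matching a pattern are skipped during autotuning.

    # node_filters.txt (hypothetical example filter file)
    # Only regions containing at least one node whose name matches a pattern are autotuned.
    /encoder/layers.*/attention/*
    *Conv*

    python -m modelopt.onnx.quantization.autotune \
        --model model.onnx \
        --node-filter-list node_filters.txt \
        --output ./autotuner_output \
        --quant-type int8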