diff --git a/sab/clock_watch.py b/sab/clock_watch.py index 8efbd04..2b3bd10 100644 --- a/sab/clock_watch.py +++ b/sab/clock_watch.py @@ -97,11 +97,12 @@ def emit_clock_changes(): yield sm.value, mem.value, reason_txt class ThrottleMonitor: - def __init__(self, target_freq: int|None=None): + def __init__(self, target_freq: int|None=None, is_jetson: bool = False): self._throttle_detected = False self._target_freq = target_freq self._stop_thread = False self._thread = None + self._is_jetson = is_jetson def _check_for_throttling(self): # Get the generator @@ -160,6 +161,9 @@ def stop(self): # enable_persistence(False) # unlock_clocks() def __enter__(self): + if self._is_jetson: + return self + gpu_clock, mem_clock = get_max_clocks() enable_persistence(True) lock_clocks(gpu_clock, mem_clock) @@ -167,6 +171,9 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): + if self._is_jetson: + return + self.stop() enable_persistence(False) unlock_clocks() @@ -258,4 +265,4 @@ def main(): f"SM={sm} MHz MEM={mem} MHz | {reason_txt}") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/sab/models/utils.py b/sab/models/utils.py index 85d7d8c..76f4e2d 100644 --- a/sab/models/utils.py +++ b/sab/models/utils.py @@ -55,6 +55,7 @@ def __init__(self, max_images: int|None = None, graph_surgery_func: Callable[str, str]|None = None, max_dets: int = 100, + is_jetson: bool = False, ): self.onnx_path = onnx_path self.inference_class = inference_class @@ -64,6 +65,7 @@ def __init__(self, self.max_images = max_images self.graph_surgery_func = graph_surgery_func self.max_dets = max_dets + self.is_jetson = is_jetson def dump(self): return { @@ -100,7 +102,7 @@ def run_benchmark_on_artifact(artifact_request: ArtifactBenchmarkRequest, images if not os.path.exists(engine_path): print(f"Building engine for {artifact_request.onnx_path} and saving to {engine_path}...") - with ThrottleMonitor() as throttle_monitor: + with ThrottleMonitor(is_jetson=artifact_request.is_jetson) as throttle_monitor: build_engine(artifact_request.onnx_path, engine_path, use_fp16=artifact_request.needs_fp16) if throttle_monitor.did_throttle(): print("GPU throttled during engine build. This is expected and is a limitation of TensorRT.") @@ -127,7 +129,7 @@ def run_benchmark_on_artifact(artifact_request: ArtifactBenchmarkRequest, images else: print("CPU frequency stable during evaluation. Latency numbers should be reliable.") else: - with ThrottleMonitor() as throttle_monitor: + with ThrottleMonitor(is_jetson=artifact_request.is_jetson) as throttle_monitor: accuracy_stats = evaluate(inference, images_dir, annotations_file_path, inv_class_mapping, buffer_time=artifact_request.buffer_time, max_images=artifact_request.max_images, max_dets=artifact_request.max_dets) if throttle_monitor.did_throttle(): throttled = True @@ -230,4 +232,4 @@ def _fmt(x, width=6, prec=1): ar_s = _pct(stats, 9) ar_m = _pct(stats, 10) ar_l = _pct(stats, 11) - print(f"{model:30} {_fmt(ar1)} {_fmt(ar10)} {_fmt(armax_dets,13)} {_fmt(ar_s)} {_fmt(ar_m)} {_fmt(ar_l)}") \ No newline at end of file + print(f"{model:30} {_fmt(ar1)} {_fmt(ar10)} {_fmt(armax_dets,13)} {_fmt(ar_s)} {_fmt(ar_m)} {_fmt(ar_l)}") diff --git a/sab/trt_inference.py b/sab/trt_inference.py index 0e2e9ad..9c843bd 100644 --- a/sab/trt_inference.py +++ b/sab/trt_inference.py @@ -1,18 +1,42 @@ +import os import tensorrt as trt import torch import numpy as np +from pathlib import Path + from sab.profiler import CUDAProfiler +TIMING_CACHE_PATH = Path.home() / ".cache" / "tensorrt" / "timing.cache" + + +def _load_timing_cache(config): + if TIMING_CACHE_PATH.exists(): + print(f"Loading timing cache from {TIMING_CACHE_PATH}") + cache = config.create_timing_cache(TIMING_CACHE_PATH.read_bytes()) + else: + print("No existing timing cache found, creating new one") + cache = config.create_timing_cache(b"") + config.set_timing_cache(cache, ignore_mismatch=False) + return cache + + +def _save_timing_cache(cache): + TIMING_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True) + TIMING_CACHE_PATH.write_bytes(cache.serialize()) + print(f"Timing cache saved to {TIMING_CACHE_PATH}") + def build_engine(model_path, engine_path, use_fp16=False): logger = trt.Logger(trt.Logger.INFO) builder = trt.Builder(logger) - + config = builder.create_builder_config() if use_fp16: config.set_flag(trt.BuilderFlag.FP16) + timing_cache = _load_timing_cache(config) + EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) network = builder.create_network(EXPLICIT_BATCH) @@ -20,7 +44,7 @@ def build_engine(model_path, engine_path, use_fp16=False): with open(model_path, "rb") as f: model_data = f.read() - + if not parser.parse(model_data): print("Failed to parse ONNX model") for error in range(parser.num_errors): @@ -29,19 +53,19 @@ def build_engine(model_path, engine_path, use_fp16=False): # Create optimization profile to fix dynamic batch dimensions profile = builder.create_optimization_profile() - + # Handle dynamic input shapes - fix batch size to 1 for i in range(network.num_inputs): input_tensor = network.get_input(i) input_shape = input_tensor.shape print(f"Input {i} ({input_tensor.name}): {input_shape}") - + # Check if batch dimension is dynamic (typically -1) if input_shape[0] == -1: # Fix batch size to 1 fixed_shape = (1,) + tuple(input_shape[1:]) print(f" Setting fixed batch shape: {fixed_shape}") - + # Set min, optimal, and max shapes all to batch size 1 profile.set_shape(input_tensor.name, fixed_shape, fixed_shape, fixed_shape) @@ -50,13 +74,15 @@ def build_engine(model_path, engine_path, use_fp16=False): print(f"Building engine from {model_path} to {engine_path}") engine = builder.build_serialized_network(network, config) - + if engine is None: print("Failed to build engine") return None - + print(f"Engine built successfully") + _save_timing_cache(timing_cache) + with open(engine_path, "wb") as f: f.write(engine)