Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions sab/clock_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,12 @@ def emit_clock_changes():
yield sm.value, mem.value, reason_txt

class ThrottleMonitor:
def __init__(self, target_freq: int|None=None):
def __init__(self, target_freq: int|None=None, is_jetson: bool = False):
self._throttle_detected = False
self._target_freq = target_freq
self._stop_thread = False
self._thread = None
self._is_jetson = is_jetson

def _check_for_throttling(self):
# Get the generator
Expand Down Expand Up @@ -160,13 +161,19 @@ def stop(self):
# enable_persistence(False)
# unlock_clocks()
def __enter__(self):
if self._is_jetson:
return self

gpu_clock, mem_clock = get_max_clocks()
enable_persistence(True)
lock_clocks(gpu_clock, mem_clock)
self.monitor_throttling(gpu_clock)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self._is_jetson:
return

self.stop()
enable_persistence(False)
unlock_clocks()
Expand Down Expand Up @@ -258,4 +265,4 @@ def main():
f"SM={sm} MHz MEM={mem} MHz | {reason_txt}")

if __name__ == "__main__":
main()
main()
8 changes: 5 additions & 3 deletions sab/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self,
max_images: int|None = None,
graph_surgery_func: Callable[str, str]|None = None,
max_dets: int = 100,
is_jetson: bool = False,
):
self.onnx_path = onnx_path
self.inference_class = inference_class
Expand All @@ -64,6 +65,7 @@ def __init__(self,
self.max_images = max_images
self.graph_surgery_func = graph_surgery_func
self.max_dets = max_dets
self.is_jetson = is_jetson

def dump(self):
return {
Expand Down Expand Up @@ -100,7 +102,7 @@ def run_benchmark_on_artifact(artifact_request: ArtifactBenchmarkRequest, images

if not os.path.exists(engine_path):
print(f"Building engine for {artifact_request.onnx_path} and saving to {engine_path}...")
with ThrottleMonitor() as throttle_monitor:
with ThrottleMonitor(is_jetson=artifact_request.is_jetson) as throttle_monitor:
build_engine(artifact_request.onnx_path, engine_path, use_fp16=artifact_request.needs_fp16)
if throttle_monitor.did_throttle():
print("GPU throttled during engine build. This is expected and is a limitation of TensorRT.")
Expand All @@ -127,7 +129,7 @@ def run_benchmark_on_artifact(artifact_request: ArtifactBenchmarkRequest, images
else:
print("CPU frequency stable during evaluation. Latency numbers should be reliable.")
else:
with ThrottleMonitor() as throttle_monitor:
with ThrottleMonitor(is_jetson=artifact_request.is_jetson) as throttle_monitor:
accuracy_stats = evaluate(inference, images_dir, annotations_file_path, inv_class_mapping, buffer_time=artifact_request.buffer_time, max_images=artifact_request.max_images, max_dets=artifact_request.max_dets)
if throttle_monitor.did_throttle():
throttled = True
Expand Down Expand Up @@ -230,4 +232,4 @@ def _fmt(x, width=6, prec=1):
ar_s = _pct(stats, 9)
ar_m = _pct(stats, 10)
ar_l = _pct(stats, 11)
print(f"{model:30} {_fmt(ar1)} {_fmt(ar10)} {_fmt(armax_dets,13)} {_fmt(ar_s)} {_fmt(ar_m)} {_fmt(ar_l)}")
print(f"{model:30} {_fmt(ar1)} {_fmt(ar10)} {_fmt(armax_dets,13)} {_fmt(ar_s)} {_fmt(ar_m)} {_fmt(ar_l)}")
40 changes: 33 additions & 7 deletions sab/trt_inference.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,50 @@
import os
import tensorrt as trt
import torch
import numpy as np

from pathlib import Path

from sab.profiler import CUDAProfiler

TIMING_CACHE_PATH = Path.home() / ".cache" / "tensorrt" / "timing.cache"


def _load_timing_cache(config):
if TIMING_CACHE_PATH.exists():
print(f"Loading timing cache from {TIMING_CACHE_PATH}")
cache = config.create_timing_cache(TIMING_CACHE_PATH.read_bytes())
else:
print("No existing timing cache found, creating new one")
cache = config.create_timing_cache(b"")
config.set_timing_cache(cache, ignore_mismatch=False)
return cache


def _save_timing_cache(cache):
TIMING_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
TIMING_CACHE_PATH.write_bytes(cache.serialize())
print(f"Timing cache saved to {TIMING_CACHE_PATH}")


def build_engine(model_path, engine_path, use_fp16=False):
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)

config = builder.create_builder_config()
if use_fp16:
config.set_flag(trt.BuilderFlag.FP16)

timing_cache = _load_timing_cache(config)

EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(EXPLICIT_BATCH)

parser = trt.OnnxParser(network, logger)

with open(model_path, "rb") as f:
model_data = f.read()

if not parser.parse(model_data):
print("Failed to parse ONNX model")
for error in range(parser.num_errors):
Expand All @@ -29,19 +53,19 @@ def build_engine(model_path, engine_path, use_fp16=False):

# Create optimization profile to fix dynamic batch dimensions
profile = builder.create_optimization_profile()

# Handle dynamic input shapes - fix batch size to 1
for i in range(network.num_inputs):
input_tensor = network.get_input(i)
input_shape = input_tensor.shape
print(f"Input {i} ({input_tensor.name}): {input_shape}")

# Check if batch dimension is dynamic (typically -1)
if input_shape[0] == -1:
# Fix batch size to 1
fixed_shape = (1,) + tuple(input_shape[1:])
print(f" Setting fixed batch shape: {fixed_shape}")

# Set min, optimal, and max shapes all to batch size 1
profile.set_shape(input_tensor.name, fixed_shape, fixed_shape, fixed_shape)

Expand All @@ -50,13 +74,15 @@ def build_engine(model_path, engine_path, use_fp16=False):

print(f"Building engine from {model_path} to {engine_path}")
engine = builder.build_serialized_network(network, config)

if engine is None:
print("Failed to build engine")
return None

print(f"Engine built successfully")

_save_timing_cache(timing_cache)

with open(engine_path, "wb") as f:
f.write(engine)

Expand Down