"""Minimal benchmark: RF-DETR instance segmentation through inference-models,
run via InferencePipeline on a single video source.

Workflow has exactly one block — the segmentation model. No annotators, no
buffer strategies, no rate limiting.

The `--backend` flag (trt | onnx | torch) is parsed before importing
`inference` and pins the auto-loader by setting
`DISABLED_INFERENCE_MODELS_BACKENDS` to every backend except the chosen one,
so the benchmark numbers correspond unambiguously to a single execution path.

Defaults: rfdetr-seg-nano @ confidence 0.4 on the native TRT backend.
"""
import argparse
import os
from time import perf_counter

_ALL_BACKENDS = {
    "torch",
    "torch-script",
    "onnx",
    "trt",
    "hugging-face",
    "ultralytics",
    "mediapipe",
    "custom",
}

# Backends selectable from the command line; shared by the pre-import parser
# and the full parser in main() so the two can never disagree.
_SELECTABLE_BACKENDS = ("trt", "onnx", "torch")


def _select_backend_from_argv() -> str:
    """Return the value of ``--backend`` from ``sys.argv`` (default ``"trt"``).

    Uses ``parse_known_args`` so every other command-line option is ignored
    here and left for the full parser in ``main``.
    """
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument("--backend", choices=_SELECTABLE_BACKENDS, default="trt")
    args, _ = pre.parse_known_args()
    return args.backend


# NOTE: these environment variables are set at import time, before `inference`
# is imported inside main(), because the backend selection is read from the
# environment.
_BACKEND = _select_backend_from_argv()
os.environ.setdefault(
    "ONNXRUNTIME_EXECUTION_PROVIDERS",
    "[TensorrtExecutionProvider,CUDAExecutionProvider,CPUExecutionProvider]",
)
os.environ["DISABLED_INFERENCE_MODELS_BACKENDS"] = ",".join(
    sorted(_ALL_BACKENDS - {_BACKEND})
)


def build_workflow(model_id: str, confidence: float) -> dict:
    """Build a single-step workflow spec running one segmentation model.

    Args:
        model_id: Model identifier placed in the step's ``model_id`` field.
        confidence: Value placed in the step's ``custom_confidence`` field.

    Returns:
        Workflow specification dict with one image input, one segmentation
        step and one ``predictions`` output.
    """
    return {
        "version": "1.0",
        "inputs": [{"type": "WorkflowImage", "name": "image"}],
        "steps": [
            {
                "type": "roboflow_core/roboflow_instance_segmentation_model@v3",
                "name": "segmentation",
                "images": "$inputs.image",
                "model_id": model_id,
                "confidence_mode": "custom",
                "custom_confidence": confidence,
            },
        ],
        "outputs": [
            {
                "type": "JsonField",
                "name": "predictions",
                "selector": "$steps.segmentation.predictions",
            },
        ],
    }


FRAME_COUNT = 0
START_TIME = None
PROGRESS_EVERY = 50


def sink(predictions, _video_frames) -> None:
    """Count processed frames and print throughput every PROGRESS_EVERY frames.

    Args:
        predictions: A single prediction or a list of predictions; ``None``
            entries are not counted as frames.
        _video_frames: Unused; accepted to match the sink callback signature.
    """
    global FRAME_COUNT, START_TIME
    if not isinstance(predictions, list):
        predictions = [predictions]
    if START_TIME is None:
        # The clock starts at the first callback, so model warm-up that
        # happens before the first prediction is excluded from the FPS.
        START_TIME = perf_counter()
    previous_count = FRAME_COUNT
    FRAME_COUNT += sum(p is not None for p in predictions)
    # Report when a multiple of PROGRESS_EVERY has been crossed, not only when
    # it is hit exactly: a multi-frame batch may step over the multiple, and
    # a callback that adds no frames must not print (or re-print) a line.
    if FRAME_COUNT // PROGRESS_EVERY == previous_count // PROGRESS_EVERY:
        return
    elapsed = perf_counter() - START_TIME
    fps = FRAME_COUNT / elapsed if elapsed > 0 else 0.0
    print(f"[progress] frames={FRAME_COUNT} fps={fps:.2f}", flush=True)


def main() -> None:
    """Parse CLI arguments, run the pipeline to completion, print a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_reference", required=True)
    parser.add_argument("--model_id", default="rfdetr-seg-nano")
    parser.add_argument("--confidence", type=float, default=0.4)
    parser.add_argument(
        "--backend",
        choices=_SELECTABLE_BACKENDS,
        default="trt",
        help="inference-models backend (consumed pre-import via env var).",
    )
    args = parser.parse_args()

    # Imported here, after the module-level environment setup above has run.
    from inference import InferencePipeline

    pipeline = InferencePipeline.init_with_workflow(
        video_reference=args.video_reference,
        workflow_specification=build_workflow(args.model_id, args.confidence),
        on_prediction=sink,
    )
    pipeline.start()
    pipeline.join()

    # `is not None` rather than truthiness: a start time of 0.0 is valid.
    elapsed = perf_counter() - START_TIME if START_TIME is not None else 0.0
    fps = FRAME_COUNT / elapsed if elapsed > 0 else 0.0
    print(f"frames={FRAME_COUNT} elapsed={elapsed:.2f}s fps={fps:.2f}")


if __name__ == "__main__":
    main()
# Cache of host staging buffers for async DtoH, keyed by (name, dtype).
# Pinned memory lets torch's copy_(non_blocking=True) actually run async.
# Each entry is a flat 1-D buffer that grows when a larger request arrives
# and is otherwise reused; callers receive a contiguous view of it.
_PINNED_HOST_BUFFERS: dict = {}


def _get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor:
    """Return a reusable host tensor of the given shape and dtype.

    The returned tensor is a contiguous view over a cached flat buffer, so
    it aliases whatever was returned by the previous call with the same
    ``(name, dtype)``: consume or copy its contents before requesting the
    buffer again.

    Args:
        name: Logical buffer name; part of the cache key.
        shape: Iterable of dimension sizes (anything convertible to ints).
        dtype: Element type; part of the cache key.

    Returns:
        A tensor of exactly ``shape`` and ``dtype``. It is allocated in
        pinned memory when CUDA is available, and in ordinary host memory
        otherwise.
    """
    shape = tuple(int(s) for s in shape)
    numel = 1
    for dim in shape:
        numel *= dim

    key = (name, dtype)
    buf = _PINNED_HOST_BUFFERS.get(key)
    if buf is None or buf.numel() < numel:
        # Capacity is tracked in elements rather than per dimension, so a
        # request of a different rank or aspect never indexes past the cached
        # shape and never yields a strided view.
        buf = torch.empty(
            numel,
            dtype=dtype,
            # Pinned allocation needs a CUDA runtime; fall back to pageable
            # memory so the helper still works on CPU-only hosts.
            pin_memory=torch.cuda.is_available(),
        )
        _PINNED_HOST_BUFFERS[key] = buf
    return buf[:numel].view(shape)
Use pinned host buffers + + # non_blocking=True so both DtoH transfers pipeline on the copy + # engine, then sync once at the end. + combined_gpu = getattr(det, "_combined_gpu", None) + if combined_gpu is not None and det.mask.is_cuda: + mask_gpu = det.mask + combined_host = _get_pinned_buffer( + "combined", combined_gpu.shape, combined_gpu.dtype + ) + mask_host = _get_pinned_buffer( + "mask", mask_gpu.shape, mask_gpu.dtype + ) + combined_host.copy_(combined_gpu, non_blocking=True) + mask_host.copy_(mask_gpu, non_blocking=True) + torch.cuda.current_stream(combined_gpu.device).synchronize() + combined_cpu = combined_host.numpy() + xyxy = combined_cpu[:, :4] + confs = combined_cpu[:, 4].view(np.float32) + class_ids = combined_cpu[:, 5] + masks = mask_host.numpy() + elif combined_gpu is not None: + combined_cpu = combined_gpu.detach().cpu().numpy() + xyxy = combined_cpu[:, :4] + confs = combined_cpu[:, 4].view(np.float32) + class_ids = combined_cpu[:, 5] + masks = det.mask.detach().cpu().numpy() + else: + xyxy = det.xyxy.detach().cpu().numpy() + confs = det.confidence.detach().cpu().numpy() + class_ids = det.class_id.detach().cpu().numpy() + masks = det.mask.detach().cpu().numpy() polys = masks2poly(masks) - class_ids = det.class_id.detach().cpu().numpy() predictions: List[InstanceSegmentationPrediction] = [] diff --git a/inference/models/rfdetr/rfdetr.py b/inference/models/rfdetr/rfdetr.py index 400e7574cb..11cd3dbe21 100644 --- a/inference/models/rfdetr/rfdetr.py +++ b/inference/models/rfdetr/rfdetr.py @@ -47,6 +47,7 @@ run_session_via_iobinding, ) from inference.core.utils.postprocess import mask2poly +from inference.core.utils.environment import str2bool from inference.core.utils.preprocess import letterbox_image if USE_PYTORCH_FOR_PREPROCESSING: @@ -54,6 +55,29 @@ CUDA_IS_AVAILABLE = torch.cuda.is_available() +RFDETR_USE_TRITON_PREPROC = str2bool(os.getenv("RFDETR_USE_TRITON_PREPROC", False)) + +if RFDETR_USE_TRITON_PREPROC: + try: + import torch + + from 
def _ensure_preproc_cache(self, device) -> None:
    """Create (once per device) the tensors reused by torch preprocessing.

    Populates on ``self``:
        _preproc_scale_gpu:   ``1 / (255 * std)`` shaped (1, 3, 1, 1).
        _preproc_offset_gpu:  ``-mean / std`` shaped (1, 3, 1, 1).
        _input_buffer:        float32 (1, 3, img_size_h, img_size_w) buffer.
        _preproc_bgr2rgb_idx: long index tensor ``[2, 1, 0]``.

    With these, ``offset + x * scale`` equals ``(x / 255 - mean) / std``.

    Args:
        device: Device on which the cached tensors must live.
    """
    device = torch.device(device)
    # Rebuild when the requested device differs from the one the cache was
    # built for; otherwise later arithmetic would mix tensors from two
    # devices.
    if (
        self._preproc_cache_initialized
        and getattr(self, "_preproc_cache_device", None) == device
    ):
        return
    means = torch.tensor(self.preprocess_means, device=device, dtype=torch.float32)
    stds = torch.tensor(self.preprocess_stds, device=device, dtype=torch.float32)
    self._preproc_scale_gpu = (1.0 / (255.0 * stds)).view(1, 3, 1, 1).contiguous()
    self._preproc_offset_gpu = (-means / stds).view(1, 3, 1, 1).contiguous()
    self._input_buffer = torch.empty(
        (1, 3, self.img_size_h, self.img_size_w),
        dtype=torch.float32,
        device=device,
    )
    self._preproc_bgr2rgb_idx = torch.tensor(
        [2, 1, 0], device=device, dtype=torch.long
    )
    self._preproc_cache_device = device
    self._preproc_cache_initialized = True
def _try_triton_preprocess(
    self, image: Any
) -> Optional[Tuple["torch.Tensor", Tuple[int, int]]]:
    """Run the fused Triton letterbox preprocess when the input qualifies.

    Args:
        image: A numpy array or an ``InferenceRequestImage``; any other type
            is declined.

    Returns:
        ``(tensor, (orig_h, orig_w))`` where ``tensor`` is the output of
        ``triton_preprocess_rfdetr``, or ``None`` when this fast path does
        not apply and the caller should use the regular preprocessing.
    """
    # The fused kernel always letterboxes with one aspect-preserving scale,
    # so it is only equivalent to the reference path for "Fit (grey edges)
    # in". Any other resize method (e.g. "Stretch to") must fall back.
    if self.resize_method != "Fit (grey edges) in":
        return None
    if self._needs_nonsquare_preproc:
        return None

    if isinstance(image, np.ndarray):
        src = image
    elif isinstance(image, InferenceRequestImage):
        try:
            src, _ = load_image(image, disable_preproc_auto_orient=True)
        except Exception:
            # Best effort: a failed decode declines the fast path so the
            # regular path can handle (and report) the problem.
            return None
        if not isinstance(src, np.ndarray):
            return None
    else:
        return None

    # The kernel expects a uint8 HWC image with exactly three channels.
    if src.dtype != np.uint8 or src.ndim != 3 or src.shape[2] != 3:
        return None

    orig_h, orig_w = src.shape[0], src.shape[1]
    src_gpu = torch.from_numpy(np.ascontiguousarray(src)).cuda(non_blocking=True)
    out = triton_preprocess_rfdetr(
        src_gpu,
        target_h=self.img_size_h,
        target_w=self.img_size_w,
        means=tuple(self.preprocess_means),
        stds=tuple(self.preprocess_stds),
        pad_color=114,
    )
    return out, (orig_h, orig_w)
""" + if ( + _TRITON_PREPROC_READY + and USE_PYTORCH_FOR_PREPROCESSING + and self.resize_method in ("Fit (grey edges) in", "Stretch to") + ): + triton_out = self._try_triton_preprocess(image) + if triton_out is not None: + return triton_out + if isinstance(image, Image.Image) and USE_PYTORCH_FOR_PREPROCESSING: if CUDA_IS_AVAILABLE: np_image = torch.from_numpy(np.asarray(image, copy=False)).cuda() @@ -135,15 +219,12 @@ def preproc_image( ) preprocessed_image = preprocessed_image.float() - preprocessed_image /= 255.0 - - means = torch.tensor( - self.preprocess_means, device=preprocessed_image.device - ).view(3, 1, 1) - stds = torch.tensor( - self.preprocess_stds, device=preprocessed_image.device - ).view(3, 1, 1) - preprocessed_image = (preprocessed_image - means) / stds + self._ensure_preproc_cache(preprocessed_image.device) + preprocessed_image = torch.addcmul( + self._preproc_offset_gpu, + preprocessed_image, + self._preproc_scale_gpu, + ) else: preprocessed_image = preprocessed_image.astype(np.float32) preprocessed_image /= 255.0 @@ -224,14 +305,22 @@ def preproc_image( if isinstance(resized, np.ndarray): resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) else: - resized = resized[:, [2, 1, 0], :, :] + resized = resized.index_select(1, self._preproc_bgr2rgb_idx) if isinstance(resized, np.ndarray): img_in = np.transpose(resized, (2, 0, 1)) img_in = img_in.astype(np.float32) img_in = np.expand_dims(img_in, axis=0) elif USE_PYTORCH_FOR_PREPROCESSING: - img_in = resized.float() + if ( + not self.batching_enabled + and self._preproc_cache_initialized + and resized.shape == self._input_buffer.shape + ): + self._input_buffer.copy_(resized, non_blocking=True) + img_in = self._input_buffer + else: + img_in = resized else: raise ValueError( f"Received an image of unknown type, {type(resized)}; " diff --git a/inference/models/rfdetr/triton_preprocess.py b/inference/models/rfdetr/triton_preprocess.py new file mode 100644 index 0000000000..a544b23409 --- /dev/null +++ 
"""Fused Triton preprocessing kernel for RF-DETR.

Combines letterbox resize (bilinear) + BGR->RGB + normalize + CHW/NCHW
layout in a single CUDA kernel launch.

Reference (torch/numpy) pipeline this replaces:
    from_numpy(uint8 HWC BGR).cuda()
    -> permute(HWC -> CHW)
    -> .contiguous().float() / 255
    -> subtract means / divide stds
    -> interpolate (bilinear resize keeping aspect ratio)
    -> pad with grey 114 (letterbox)
    -> BGR -> RGB channel swap
    -> unsqueeze(0)

Each of those torch ops launches at least one CUDA kernel; fusing them
eliminates that overhead for small images where per-launch cost dominates.
"""

from typing import Optional, Tuple

import torch

try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except ImportError:  # pragma: no cover - optional dep
    triton = None
    tl = None
    TRITON_AVAILABLE = False


if TRITON_AVAILABLE:

    @triton.jit
    def _rfdetr_preprocess_kernel(
        src_ptr,
        dst_ptr,
        src_h,
        src_w,
        src_stride_h,
        src_stride_w,
        target_h,
        target_w,
        scale,
        pad_x,
        pad_y,
        mean_r,
        mean_g,
        mean_b,
        std_r,
        std_g,
        std_b,
        pad_r_norm,
        pad_g_norm,
        pad_b_norm,
        dst_stride_c,
        dst_stride_h,
        BLOCK_H: tl.constexpr,
        BLOCK_W: tl.constexpr,
    ):
        """One program per (tile_y, tile_x) over the target image; writes all
        3 channels in a single pass.

        Src layout: uint8 HWC BGR, strides are (src_stride_h, src_stride_w, 1).
        Dst layout: fp32 (1, 3, target_h, target_w) contiguous CHW RGB.
        """
        pid_y = tl.program_id(0)
        pid_x = tl.program_id(1)

        offs_y = pid_y * BLOCK_H + tl.arange(0, BLOCK_H)
        offs_x = pid_x * BLOCK_W + tl.arange(0, BLOCK_W)
        mask_y = offs_y < target_h
        mask_x = offs_x < target_w
        mask = mask_y[:, None] & mask_x[None, :]

        # Inverse letterbox: map output pixel -> source pixel coordinate.
        src_y_f = (offs_y.to(tl.float32) + 0.5 - pad_y) / scale - 0.5
        src_x_f = (offs_x.to(tl.float32) + 0.5 - pad_x) / scale - 0.5

        src_y_f_2d = src_y_f[:, None]
        src_x_f_2d = src_x_f[None, :]

        y0 = tl.floor(src_y_f_2d).to(tl.int32)
        x0 = tl.floor(src_x_f_2d).to(tl.int32)
        y1 = y0 + 1
        x1 = x0 + 1

        # Fractional position inside the 2x2 neighbourhood (bilinear weights).
        dy = src_y_f_2d - y0.to(tl.float32)
        dx = src_x_f_2d - x0.to(tl.float32)

        # Clamp to source bounds for the gather; pixels that fall in the pad
        # region are overwritten with the pad value at the end.
        y0c = tl.maximum(tl.minimum(y0, src_h - 1), 0)
        y1c = tl.maximum(tl.minimum(y1, src_h - 1), 0)
        x0c = tl.maximum(tl.minimum(x0, src_w - 1), 0)
        x1c = tl.maximum(tl.minimum(x1, src_w - 1), 0)

        # Output-pixel-in-pad-region iff the *center* maps outside the resized
        # image footprint. Using src_y_f_2d (post-center shift) is fine since
        # we compare against [-0.5, src_h - 0.5].
        in_bounds = (
            (src_y_f_2d >= -0.5)
            & (src_y_f_2d <= src_h.to(tl.float32) - 0.5)
            & (src_x_f_2d >= -0.5)
            & (src_x_f_2d <= src_w.to(tl.float32) - 0.5)
        )

        # Gather 4 corners for all three channels. Source is HWC BGR
        # (channel 0 = B, channel 1 = G, channel 2 = R); output CHW RGB.
        # Triton doesn't support nested function defs, so we inline the
        # gather for each channel.
        base_00 = y0c * src_stride_h + x0c * src_stride_w
        base_01 = y0c * src_stride_h + x1c * src_stride_w
        base_10 = y1c * src_stride_h + x0c * src_stride_w
        base_11 = y1c * src_stride_h + x1c * src_stride_w

        w_tl = (1.0 - dy) * (1.0 - dx)
        w_tr = (1.0 - dy) * dx
        w_bl = dy * (1.0 - dx)
        w_br = dy * dx

        # Channel 0 (B)
        p00_b = tl.load(src_ptr + base_00 + 0, mask=mask, other=0).to(tl.float32)
        p01_b = tl.load(src_ptr + base_01 + 0, mask=mask, other=0).to(tl.float32)
        p10_b = tl.load(src_ptr + base_10 + 0, mask=mask, other=0).to(tl.float32)
        p11_b = tl.load(src_ptr + base_11 + 0, mask=mask, other=0).to(tl.float32)
        b_val = p00_b * w_tl + p01_b * w_tr + p10_b * w_bl + p11_b * w_br

        # Channel 1 (G)
        p00_g = tl.load(src_ptr + base_00 + 1, mask=mask, other=0).to(tl.float32)
        p01_g = tl.load(src_ptr + base_01 + 1, mask=mask, other=0).to(tl.float32)
        p10_g = tl.load(src_ptr + base_10 + 1, mask=mask, other=0).to(tl.float32)
        p11_g = tl.load(src_ptr + base_11 + 1, mask=mask, other=0).to(tl.float32)
        g_val = p00_g * w_tl + p01_g * w_tr + p10_g * w_bl + p11_g * w_br

        # Channel 2 (R)
        p00_r = tl.load(src_ptr + base_00 + 2, mask=mask, other=0).to(tl.float32)
        p01_r = tl.load(src_ptr + base_01 + 2, mask=mask, other=0).to(tl.float32)
        p10_r = tl.load(src_ptr + base_10 + 2, mask=mask, other=0).to(tl.float32)
        p11_r = tl.load(src_ptr + base_11 + 2, mask=mask, other=0).to(tl.float32)
        r_val = p00_r * w_tl + p01_r * w_tr + p10_r * w_bl + p11_r * w_br

        r_norm = (r_val / 255.0 - mean_r) / std_r
        g_norm = (g_val / 255.0 - mean_g) / std_g
        b_norm = (b_val / 255.0 - mean_b) / std_b

        r_out = tl.where(in_bounds, r_norm, pad_r_norm)
        g_out = tl.where(in_bounds, g_norm, pad_g_norm)
        b_out = tl.where(in_bounds, b_norm, pad_b_norm)

        # Destination planes are written in R, G, B order (channel swap).
        out_row_offsets = offs_y[:, None] * dst_stride_h + offs_x[None, :]
        tl.store(dst_ptr + 0 * dst_stride_c + out_row_offsets, r_out, mask=mask)
        tl.store(dst_ptr + 1 * dst_stride_c + out_row_offsets, g_out, mask=mask)
        tl.store(dst_ptr + 2 * dst_stride_c + out_row_offsets, b_out, mask=mask)


def triton_preprocess_rfdetr(
    src: torch.Tensor,
    target_h: int,
    target_w: int,
    means: Tuple[float, float, float] = (0.485, 0.456, 0.406),
    stds: Tuple[float, float, float] = (0.229, 0.224, 0.225),
    pad_color: int = 114,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fused preprocess: uint8 HWC BGR -> fp32 (1,3,H,W) CHW RGB normalized
    with letterbox.

    Args:
        src: uint8 tensor of shape (H, W, 3) on CUDA, BGR channel order
            (as produced by cv2.imread / VideoFrame.image).
        target_h, target_w: output spatial dims.
        means, stds: imagenet normalization in RGB order.
        pad_color: uint8 value used in each channel for letterbox padding
            (applied *before* normalization). Default 114 matches RF-DETR
            "Fit (grey edges) in".
        out: optional preallocated fp32 tensor of shape (1, 3, target_h,
            target_w) to write into.

    Returns:
        fp32 tensor of shape (1, 3, target_h, target_w) on the same device
        as src.

    Raises:
        RuntimeError: if triton is not installed.
        ValueError: if ``src`` is not a non-empty CUDA uint8 (H, W, 3)
            tensor, or if ``out`` has the wrong shape, dtype or device.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError(
            "triton is not installed; cannot run triton_preprocess_rfdetr. "
            "Install the optional 'triton-preproc' extra."
        )
    if not src.is_cuda:
        raise ValueError(
            f"triton_preprocess_rfdetr requires a CUDA tensor, got device={src.device}"
        )
    if src.dtype != torch.uint8:
        raise ValueError(
            f"triton_preprocess_rfdetr expects uint8 input, got dtype={src.dtype}"
        )
    if src.ndim != 3 or src.shape[2] != 3:
        raise ValueError(
            f"triton_preprocess_rfdetr expects HWC 3-channel input, got shape={tuple(src.shape)}"
        )
    if src.shape[0] == 0 or src.shape[1] == 0:
        # Without this the scale computation below divides by zero.
        raise ValueError(
            f"triton_preprocess_rfdetr expects a non-empty image, got shape={tuple(src.shape)}"
        )

    src = src.contiguous()
    src_h, src_w = int(src.shape[0]), int(src.shape[1])
    # HWC strides: contiguous => (W*3, 3, 1). We pass (row_stride, col_stride)
    # in elements (uint8).
    src_stride_h = int(src.stride(0))
    src_stride_w = int(src.stride(1))

    # Single aspect-preserving scale; the remaining space is split evenly
    # between the two sides as padding.
    scale = min(target_h / src_h, target_w / src_w)
    scaled_h = int(src_h * scale)
    scaled_w = int(src_w * scale)
    pad_x = (target_w - scaled_w) / 2.0
    pad_y = (target_h - scaled_h) / 2.0

    if out is None:
        out = torch.empty(
            (1, 3, target_h, target_w), dtype=torch.float32, device=src.device
        )
    else:
        if tuple(out.shape) != (1, 3, target_h, target_w):
            raise ValueError(
                f"out has shape {tuple(out.shape)}, expected (1, 3, {target_h}, {target_w})"
            )
        if out.dtype != torch.float32:
            raise ValueError(f"out must be float32, got {out.dtype}")
        if not out.is_cuda or out.device != src.device:
            raise ValueError("out must be a CUDA tensor on the same device as src")

    # (1,3,H,W) contiguous => per-channel plane stride is H*W, row stride is W.
    dst_stride_c = target_h * target_w
    dst_stride_h = target_w

    # Pad value expressed in normalized space, one per output (RGB) channel.
    pad_norm_r = (pad_color / 255.0 - means[0]) / stds[0]
    pad_norm_g = (pad_color / 255.0 - means[1]) / stds[1]
    pad_norm_b = (pad_color / 255.0 - means[2]) / stds[2]

    BLOCK_H = 16
    BLOCK_W = 16
    grid = (
        (target_h + BLOCK_H - 1) // BLOCK_H,
        (target_w + BLOCK_W - 1) // BLOCK_W,
    )
    _rfdetr_preprocess_kernel[grid](
        src,
        out,
        src_h,
        src_w,
        src_stride_h,
        src_stride_w,
        target_h,
        target_w,
        float(scale),
        float(pad_x),
        float(pad_y),
        float(means[0]),
        float(means[1]),
        float(means[2]),
        float(stds[0]),
        float(stds[1]),
        float(stds[2]),
        float(pad_norm_r),
        float(pad_norm_g),
        float(pad_norm_b),
        dst_stride_c,
        dst_stride_h,
        BLOCK_H=BLOCK_H,
        BLOCK_W=BLOCK_W,
    )
    return out
List[torch.Tensor] execution_context: trt.IExecutionContext + # Consumers of output_buffers can record into this event; the next graph + # replay will wait on it to avoid overwriting the previous frame's outputs. + # None until first consumer records an event. + consumer_done_event: Optional["torch.cuda.Event"] = None class TRTCudaGraphCache: @@ -571,7 +575,17 @@ def infer_from_trt_engine( outputs=outputs, trt_cuda_graph_cache=trt_cuda_graph_cache, ) - stream.synchronize() + # Graph-replay path: record the event on the graph's own stream so + # downstream can wait on the correct stream (the graph captures on a + # dedicated stream, not `stream`). Non-graph path: record on `stream`. + if hasattr(results[0], "_trt_graph_state"): + graph_state = results[0]._trt_graph_state + produce_event = torch.cuda.Event() + produce_event.record(graph_state.cuda_stream) + else: + produce_event = torch.cuda.Event() + produce_event.record(stream) + results[0]._trt_produce_event = produce_event return results @@ -692,12 +706,20 @@ def _execute_trt_engine( if cache_key not in trt_cuda_graph_cache: LOGGER.debug("Capturing CUDA graph for shape %s", input_shape) + # If the caller hands us a stable pre-allocated buffer and + # promises to keep writing in-place to it, we can bake its + # address directly into the graph and skip the per-frame DtoD + # copy. Marker attribute set by the model's pre_process path. 
+ use_external = getattr( + pre_processed_images, "_trt_reuse_as_input_buffer", False + ) results, trt_cuda_graph = _capture_cuda_graph( pre_processed_images=pre_processed_images, engine=engine, device=device, input_name=input_name, outputs=outputs, + use_pre_processed_images_as_input_buffer=bool(use_external), ) trt_cuda_graph_cache[cache_key] = trt_cuda_graph return results @@ -705,11 +727,24 @@ def _execute_trt_engine( else: trt_cuda_graph_state = trt_cuda_graph_cache[cache_key] stream = trt_cuda_graph_state.cuda_stream + # Before re-running the graph, make sure any consumer of the + # previous frame's outputs is done. Post-process callers that read + # the output buffers directly must record into consumer_done_event. + consumer_done = trt_cuda_graph_state.consumer_done_event + if consumer_done is not None: + consumer_done.wait(stream) with torch.cuda.stream(stream): - trt_cuda_graph_state.input_buffer.copy_(pre_processed_images) + # Skip the DtoD copy if the caller is writing in-place to + # the graph's own input buffer (input_buffer is the same + # torch.Tensor object as pre_processed_images). + if trt_cuda_graph_state.input_buffer.data_ptr() != pre_processed_images.data_ptr(): + trt_cuda_graph_state.input_buffer.copy_(pre_processed_images) trt_cuda_graph_state.cuda_graph.replay() - results = [buf.clone() for buf in trt_cuda_graph_state.output_buffers] - stream.synchronize() + # Return graph-owned output buffers directly — no clone. Attach + # the graph state as a side-channel so consumers can record into + # consumer_done_event without plumbing the cache through. + results = list(trt_cuda_graph_state.output_buffers) + results[0]._trt_graph_state = trt_cuda_graph_state return results else: @@ -752,14 +787,21 @@ def _capture_cuda_graph( device: torch.device, input_name: str, outputs: List[str], + use_pre_processed_images_as_input_buffer: bool = False, ) -> Tuple[List[torch.Tensor], TRTCudaGraphState]: # Each CUDA graph needs its own execution context. 
Sharing a single context # across graphs for different input shapes causes TRT to reallocate internal # workspace buffers, invalidating GPU addresses baked into earlier graphs. graph_context = engine.create_execution_context() - input_buffer = torch.empty_like(pre_processed_images, device=device) - input_buffer.copy_(pre_processed_images) + if use_pre_processed_images_as_input_buffer: + # Zero-copy: TRT reads directly from the caller's buffer. The caller + # must keep this exact tensor alive for the lifetime of the graph + # and re-write its contents in-place before every replay. + input_buffer = pre_processed_images + else: + input_buffer = torch.empty_like(pre_processed_images, device=device) + input_buffer.copy_(pre_processed_images) status = graph_context.set_input_shape( input_name, tuple(pre_processed_images.shape) diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index 2b0e62067a..11591060fb 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -1,9 +1,46 @@ +import os from typing import List, Optional, Tuple, Union import torch from torchvision.transforms import functional from inference_models import InstanceDetections + +_RFDETR_TRITON_POSTPROC = os.getenv("RFDETR_TRITON_POSTPROC", "false").lower() in ( + "true", + "1", +) +if _RFDETR_TRITON_POSTPROC: + try: + from inference_models.models.rfdetr.triton_postprocess import ( + TRITON_AVAILABLE as _TRITON_POSTPROC_AVAILABLE, + triton_rfdetr_conf_filter, + ) + _TRITON_POSTPROC_READY = _TRITON_POSTPROC_AVAILABLE and torch.cuda.is_available() + except Exception: + _TRITON_POSTPROC_READY = False + triton_rfdetr_conf_filter = None +else: + _TRITON_POSTPROC_READY = False + triton_rfdetr_conf_filter = None + +_RFDETR_TRITON_FULLPOSTPROC = os.getenv("RFDETR_TRITON_FULLPOSTPROC", "false").lower() in ( + "true", + "1", +) +if _RFDETR_TRITON_FULLPOSTPROC: + 
try: + from inference_models.models.rfdetr.triton_fullpostproc import ( + TRITON_AVAILABLE as _TRITON_FULLPOST_AVAILABLE, + triton_rfdetr_fullpost, + ) + _TRITON_FULLPOST_READY = _TRITON_FULLPOST_AVAILABLE and torch.cuda.is_available() + except Exception: + _TRITON_FULLPOST_READY = False + triton_rfdetr_fullpost = None +else: + _TRITON_FULLPOST_READY = False + triton_rfdetr_fullpost = None from inference_models.entities import ImageDimensions from inference_models.errors import CorruptedModelPackageError from inference_models.models.common.roboflow.model_packages import ( @@ -51,35 +88,88 @@ def post_process_instance_segmentation_results( num_classes: int, classes_re_mapping: Optional[ClassesReMapping], ) -> List[InstanceDetections]: - logits_sigmoid = torch.nn.functional.sigmoid(logits) results = [] device = bboxes.device - if isinstance(threshold, torch.Tensor): - threshold = threshold.to(device=device, dtype=logits_sigmoid.dtype) - for image_bboxes, image_logits, image_masks, image_meta in zip( - bboxes, logits_sigmoid, masks, pre_processing_meta + # Try the full-fusion fast path first (batch=1, no static crop, + # no nonsquare-intermediate resize). Matches rfdetr-seg-nano default. 
+ if ( + _TRITON_FULLPOST_READY + and bboxes.is_cuda + and bboxes.shape[0] == 1 + and len(pre_processing_meta) == 1 + and pre_processing_meta[0].nonsquare_intermediate_size is None + and pre_processing_meta[0].static_crop_offset.offset_x == 0 + and pre_processing_meta[0].static_crop_offset.offset_y == 0 + and classes_re_mapping is not None ): - confidence, top_classes = image_logits.max(dim=1) - if classes_re_mapping is not None: - remapping_mask = torch.isin( - top_classes, classes_re_mapping.remaining_class_ids + meta = pre_processing_meta[0] + thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) + combined, mask_bin, mask_any = triton_rfdetr_fullpost( + bboxes=bboxes, + logits=logits, + masks=masks, + threshold=thr_arg, + num_classes=num_classes, + class_mapping=classes_re_mapping.class_mapping, + inference_size_wh=(meta.inference_size.width, meta.inference_size.height), + pad_ltrb=(meta.pad_left, meta.pad_top, meta.pad_right, meta.pad_bottom), + scale_wh=(meta.scale_width, meta.scale_height), + orig_size_wh=(meta.original_size.width, meta.original_size.height), + ) + # `combined` packs [xyxy|conf_i32|cls_id] as (n, 6) int32. Attach as + # side-channel so the adapter can do ONE .cpu() instead of three. 
+ detections = InstanceDetections( + xyxy=combined[:, :4], + confidence=combined[:, 4], # int32 bits; adapter bitcasts to fp32 + class_id=combined[:, 5], + mask=mask_bin, + ) + detections.__dict__["_combined_gpu"] = combined + results.append(detections) + return results + use_triton = _TRITON_POSTPROC_READY and bboxes.is_cuda + if isinstance(threshold, torch.Tensor): + threshold_dtype = logits.dtype if use_triton else torch.float32 + threshold = threshold.to(device=device, dtype=threshold_dtype) + if use_triton: + iterator = zip(bboxes, logits, masks, pre_processing_meta) + else: + logits_sigmoid = torch.nn.functional.sigmoid(logits) + if isinstance(threshold, torch.Tensor): + threshold = threshold.to(device=device, dtype=logits_sigmoid.dtype) + iterator = zip(bboxes, logits_sigmoid, masks, pre_processing_meta) + cmap = classes_re_mapping.class_mapping if classes_re_mapping is not None else None + for image_bboxes, image_logits, image_masks, image_meta in iterator: + if use_triton: + confidence, top_classes, keep = triton_rfdetr_conf_filter( + image_logits, threshold, num_classes, class_mapping=cmap ) - top_classes = classes_re_mapping.class_mapping[top_classes[remapping_mask]] - confidence = confidence[remapping_mask] - image_bboxes = image_bboxes[remapping_mask] - image_masks = image_masks[remapping_mask] + confidence = confidence[keep] + top_classes = top_classes[keep].long() + selected_boxes = image_bboxes[keep] + selected_masks = image_masks[keep] else: - # drop DETR no-object rows - named = top_classes < num_classes - confidence = confidence[named] - top_classes = top_classes[named] - image_bboxes = image_bboxes[named] - image_masks = image_masks[named] - confidence_mask = confidence > (threshold[top_classes.long()] if isinstance(threshold, torch.Tensor) else threshold) - confidence = confidence[confidence_mask] - top_classes = top_classes[confidence_mask] - selected_boxes = image_bboxes[confidence_mask] - selected_masks = image_masks[confidence_mask] + 
confidence, top_classes = image_logits.max(dim=1) + if classes_re_mapping is not None: + remapping_mask = torch.isin( + top_classes, classes_re_mapping.remaining_class_ids + ) + top_classes = classes_re_mapping.class_mapping[top_classes[remapping_mask]] + confidence = confidence[remapping_mask] + image_bboxes = image_bboxes[remapping_mask] + image_masks = image_masks[remapping_mask] + else: + # drop DETR no-object rows + named = top_classes < num_classes + confidence = confidence[named] + top_classes = top_classes[named] + image_bboxes = image_bboxes[named] + image_masks = image_masks[named] + confidence_mask = confidence > (threshold[top_classes.long()] if isinstance(threshold, torch.Tensor) else threshold) + confidence = confidence[confidence_mask] + top_classes = top_classes[confidence_mask] + selected_boxes = image_bboxes[confidence_mask] + selected_masks = image_masks[confidence_mask] confidence, sorted_indices = torch.sort(confidence, descending=True) top_classes = top_classes[sorted_indices] selected_boxes = selected_boxes[sorted_indices] diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py index 7affae1235..334fd0a186 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py @@ -1,3 +1,4 @@ +import os import threading from typing import List, Optional, Tuple, Union @@ -48,9 +49,31 @@ post_process_instance_segmentation_results, ) from inference_models.models.rfdetr.pre_processing import pre_process_network_input +from inference_models.entities import ImageDimensions as _ImageDimensions +from inference_models.models.common.roboflow.model_packages import ( + StaticCropOffset as _StaticCropOffset, +) from inference_models.models.common.roboflow.post_processing import ConfidenceFilter from 
def _try_fast_preprocess(
    self,
    images,
    input_color_format,
    image_size,
    pre_processing_overrides,
):
    """Run the fused Triton pre-processing path when the request qualifies.

    The fast path handles exactly one configuration: a single uint8 HWC
    BGR numpy frame, stretch-to resize to the training input size, RGB
    model input, mean/std normalization, and no static crop / contrast /
    grayscale step. Anything else returns ``None`` so that the caller can
    fall back to the generic pre-processing path.

    Args:
        images: A numpy array or a one-element list holding one.
        input_color_format: Caller-declared color order; ``None`` is
            treated as BGR.
        image_size: Requested override of the inference size; any
            non-``None`` value disables the fast path.
        pre_processing_overrides: Per-call overrides; any non-``None``
            value disables the fast path because the fused kernel cannot
            honor them.

    Returns:
        ``None`` when the fast path does not apply, otherwise a tuple
        ``(input_tensor, [metadata])``. ``input_tensor`` is the reusable
        ``self._fast_input_buffer``; work producing it is only enqueued on
        ``self._pre_process_stream`` and ``self._fast_preproc_event`` marks
        its completion.
    """
    if not _TRITON_READY:
        return None
    # The fused kernel always produces the training input size and knows
    # nothing about overrides, so either request must use the slow path.
    if image_size is not None or pre_processing_overrides is not None:
        return None
    ipp = self._inference_config.image_pre_processing
    for step in (ipp.static_crop, ipp.contrast, ipp.grayscale):
        if step is not None and step.enabled:
            return None
    ni = self._inference_config.network_input
    if ni.resize_mode != ResizeMode.STRETCH_TO:
        return None
    if ni.input_channels != 3:
        return None
    if ni.dataset_version_resize_dimensions is not None:
        return None
    # NOTE(review): the kernel always divides by 255; accepting
    # scaling_factor=None assumes the generic path does the same - confirm.
    if ni.scaling_factor not in (None, 255):
        return None
    if ni.normalization is None:
        return None
    means, stds = ni.normalization

    # Only a single numpy HWC uint8 frame is supported.
    if isinstance(images, list):
        if len(images) != 1:
            return None
        frame = images[0]
    else:
        frame = images
    if not isinstance(frame, np.ndarray):
        return None
    if frame.dtype != np.uint8 or frame.ndim != 3 or frame.shape[2] != 3:
        return None
    orig_h, orig_w = frame.shape[0], frame.shape[1]
    if orig_h == 0 or orig_w == 0:
        # An empty frame would divide by zero in the scale computation.
        return None

    # The kernel hard-codes BGR input -> RGB output.
    from inference_models.models.common.roboflow.model_packages import ColorMode

    caller_mode = (
        ColorMode(input_color_format)
        if input_color_format is not None
        else ColorMode.BGR
    )
    if caller_mode != ColorMode.BGR or ni.color_mode != ColorMode.RGB:
        return None

    target_h = ni.training_input_size.height
    target_w = ni.training_input_size.width

    if not getattr(self, "_fast_buffer_initialized", False):
        self._fast_input_buffer = torch.empty(
            (1, 3, target_h, target_w),
            dtype=torch.float32,
            device=self._device,
        )
        # Marker read elsewhere (per the original comment, by the TRT
        # CUDA-graph capture path) to adopt this tensor as the graph input
        # and skip a per-frame device-to-device copy.
        self._fast_input_buffer._trt_reuse_as_input_buffer = True
        self._fast_means = tuple(means)
        self._fast_stds = tuple(stds)
        # Staging buffers are allocated lazily below, once the frame
        # shape is known.
        self._fast_src_host_pinned = None
        self._fast_src_gpu = None
        self._fast_copy_done_event = None
        self._fast_buffer_initialized = True

    # Reuse a pinned host staging buffer so that copy_(non_blocking=True)
    # can be asynchronous. A shape change forces a reallocation.
    src_shape = frame.shape
    pinned = self._fast_src_host_pinned
    if pinned is None or tuple(pinned.shape) != src_shape:
        pinned = torch.empty(src_shape, dtype=torch.uint8, pin_memory=True)
        self._fast_src_host_pinned = pinned
        self._fast_src_gpu = torch.empty(
            src_shape, dtype=torch.uint8, device=self._device
        )

    # The previous frame's host-to-device copy may still be reading the
    # pinned buffer; wait for it before overwriting the host memory.
    copy_done = getattr(self, "_fast_copy_done_event", None)
    if copy_done is not None:
        copy_done.synchronize()
    np.copyto(pinned.numpy(), frame, casting="no")

    src_gpu = self._fast_src_gpu
    with torch.cuda.stream(self._pre_process_stream):
        # Asynchronous host-to-device upload of the raw BGR frame.
        src_gpu.copy_(pinned, non_blocking=True)
        if copy_done is None:
            copy_done = torch.cuda.Event()
            self._fast_copy_done_event = copy_done
        copy_done.record(self._pre_process_stream)
        triton_preprocess_rfdetr_stretch(
            src_gpu,
            target_h=target_h,
            target_w=target_w,
            means=self._fast_means,
            stds=self._fast_stds,
            out=self._fast_input_buffer,
        )
        # Event for a consumer stream to wait on without blocking the CPU.
        # NOTE(review): nothing here waits for a previous consumer of
        # _fast_input_buffer before it is overwritten - confirm that the
        # caller serializes pre-processing and inference per frame.
        self._fast_preproc_event = torch.cuda.Event()
        self._fast_preproc_event.record(self._pre_process_stream)
        self._fast_input_buffer.record_stream(self._pre_process_stream)

    size_after = _ImageDimensions(height=orig_h, width=orig_w)
    target = _ImageDimensions(height=target_h, width=target_w)
    metadata = PreProcessingMetadata(
        pad_left=0,
        pad_top=0,
        pad_right=0,
        pad_bottom=0,
        original_size=size_after,
        size_after_pre_processing=size_after,
        inference_size=target,
        scale_width=target_w / orig_w,
        scale_height=target_h / orig_h,
        static_crop_offset=_StaticCropOffset(
            offset_x=0, offset_y=0, crop_width=orig_w, crop_height=orig_h
        ),
    )
    return self._fast_input_buffer, [metadata]
+ produce_event = getattr(model_results[0], "_trt_produce_event", None) + graph_state = getattr(model_results[0], "_trt_graph_state", None) with torch.cuda.stream(self._post_process_stream): + if produce_event is not None: + produce_event.wait(self._post_process_stream) for result_element in model_results: result_element.record_stream(self._post_process_stream) bboxes, logits, masks = model_results @@ -290,7 +466,14 @@ def post_process( num_classes=len(self.class_names), classes_re_mapping=self._classes_re_mapping, ) - self._post_process_stream.synchronize() + # Record "consumer done" so the next TRT replay can wait on it + # before overwriting the graph-owned output buffers. + if graph_state is not None: + ev = graph_state.consumer_done_event + if ev is None: + ev = torch.cuda.Event() + graph_state.consumer_done_event = ev + ev.record(self._post_process_stream) return results @property diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py new file mode 100644 index 0000000000..cc54097714 --- /dev/null +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -0,0 +1,467 @@ +"""Experimental: full RF-DETR instance-segmentation post-processing in Triton. + +Fuses the entire post-TRT chain into two Triton kernels of fixed grid: + + _rfdetr_fullpost_filter_kernel (grid = num_queries) + sigmoid(logits) -> argmax-over-classes -> class remap -> conf > threshold + -> xywh -> xyxy -> multiply by inference_size -> subtract padding + -> divide by scale -> clip to orig image bounds + Emits padded fixed-shape outputs (num_queries rows). Rows that don't pass + the filter get `keep=False`; downstream consumers skip them. + + _rfdetr_fullpost_mask_kernel (grid = num_queries * tile_y * tile_x) + Inverse-letterbox bilinear upsample masks 78x78 -> orig_h x orig_w, + threshold > 0, emit as uint8. Skips work when keep[q] is False. 
"""Experimental: fused RF-DETR instance-segmentation post-processing in Triton.

Two kernels cover the whole post-engine chain:

``_rfdetr_fullpost_filter_kernel`` (grid = num_queries)
    argmax over classes -> optional class remap -> sigmoid -> confidence
    threshold -> cxcywh to xyxy -> scale to inference size -> remove padding
    -> undo resize scale -> clip to the original image -> round to int.
    Survivors are written *compactly*: each one reserves an output slot with
    an atomic add on a device counter, so survivor order is not
    deterministic.

``_rfdetr_fullpost_mask_kernel_compact``
    (grid = num_queries x tiles_y x tiles_x, over-launched)
    Bilinear upsample of each survivor's mask from (mask_h, mask_w) to
    (orig_h, orig_w), thresholded at ``> 0`` and stored as 0/1 bytes.
    Programs whose survivor index is >= the counter exit immediately.

Only batch size 1 is supported. The mask kernel maps the whole mask onto the
whole original image, i.e. it assumes no padding at mask resolution.
"""
from typing import Optional, Tuple

import torch

try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except ImportError:  # pragma: no cover
    triton = None
    tl = None
    TRITON_AVAILABLE = False


if TRITON_AVAILABLE:

    @triton.jit
    def _rfdetr_fullpost_filter_kernel(
        logits_ptr,            # (num_queries, num_classes_total) logits
        bboxes_ptr,            # (num_queries, 4) normalized cxcywh
        threshold_ptr,         # 1 element, or one element per output class id
        class_map_ptr,         # (num_classes_total,) int32; negative = drop
        combined_out_ptr,      # (num_queries, 6) int32: x1,y1,x2,y2,conf bits,cls
        survivor_idx_out_ptr,  # (num_queries,) int32: source query of each slot
        mask_any_out_ptr,      # (num_queries,) int32: reset to 0 per used slot
        counter_ptr,           # (1,) int32: number of survivors so far
        num_queries,
        num_classes,           # named classes; raw ids >= this are "no-object"
        num_classes_total,
        inference_w,
        inference_h,
        pad_left,
        pad_top,
        inv_scale_w,
        inv_scale_h,
        orig_w,
        orig_h,
        logits_stride_q,
        bboxes_stride_q,
        PER_CLASS: tl.constexpr,
        HAS_REMAPPING: tl.constexpr,
        BLOCK_C: tl.constexpr,
    ):
        """Filter one query and, if it survives, emit its compact record."""
        pid = tl.program_id(0)
        if pid >= num_queries:
            return
        offs_c = tl.arange(0, BLOCK_C)
        mask_c = offs_c < num_classes_total

        logits_row = tl.load(
            logits_ptr + pid * logits_stride_q + offs_c,
            mask=mask_c,
            other=-float("inf"),
        )
        # argmax as "smallest index that equals the row maximum".
        max_val = tl.max(logits_row, axis=0)
        BIG = 1 << 30
        is_max = logits_row == max_val
        idx_or_big = tl.where(is_max & mask_c, offs_c, BIG)
        raw_c = tl.min(idx_or_big, axis=0)
        # raw_c stays BIG when no element equals the maximum (e.g. a NaN
        # row); never use it as an index in that case.
        found = raw_c < num_classes_total
        safe_raw = tl.where(found, raw_c, 0)

        if HAS_REMAPPING:
            top_c = tl.load(class_map_ptr + safe_raw)
            valid = found & (top_c >= 0)
        else:
            top_c = raw_c
            # Drop "no-object" rows: only ids below num_classes are named.
            valid = raw_c < num_classes

        # Numerically stable sigmoid of the winning logit, computed in fp32
        # regardless of the logits dtype.
        max_f32 = max_val.to(tl.float32)
        z = tl.exp(-tl.abs(max_f32))
        sig_pos = 1.0 / (1.0 + z)
        sig_neg = z / (1.0 + z)
        conf = tl.where(max_f32 >= 0.0, sig_pos, sig_neg)

        if PER_CLASS:
            # Index 0 is a harmless placeholder for rows that are dropped.
            safe_c = tl.where(valid, top_c, 0)
            thr = tl.load(threshold_ptr + safe_c)
        else:
            thr = tl.load(threshold_ptr)
        keep = valid & (conf > thr)

        # Filtered queries neither compute a box nor reserve a slot.
        if not keep:
            return

        cx_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 0)
        cy_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 1)
        w_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 2)
        h_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 3)

        # Corners in normalized units first, then scale to inference pixels.
        x1_pct = cx_pct - 0.5 * w_pct
        y1_pct = cy_pct - 0.5 * h_pct
        x2_pct = cx_pct + 0.5 * w_pct
        y2_pct = cy_pct + 0.5 * h_pct

        x1 = x1_pct * inference_w
        y1 = y1_pct * inference_h
        x2 = x2_pct * inference_w
        y2 = y2_pct * inference_h

        x1 = x1 - pad_left
        y1 = y1 - pad_top
        x2 = x2 - pad_left
        y2 = y2 - pad_top

        # inv_scale_* is 1/scale, precomputed on the host; multiplying by a
        # reciprocal is not guaranteed to round like a true division.
        x1 = x1 * inv_scale_w
        y1 = y1 * inv_scale_h
        x2 = x2 * inv_scale_w
        y2 = y2 * inv_scale_h

        x1 = tl.maximum(tl.minimum(x1, orig_w), 0.0)
        y1 = tl.maximum(tl.minimum(y1, orig_h), 0.0)
        x2 = tl.maximum(tl.minimum(x2, orig_w), 0.0)
        y2 = tl.maximum(tl.minimum(y2, orig_h), 0.0)

        # Round half to even: round half up first, then step an odd result
        # back by one when the input was exactly on a .5 tie.
        x1_r = tl.floor(x1 + 0.5)
        y1_r = tl.floor(y1 + 0.5)
        x2_r = tl.floor(x2 + 0.5)
        y2_r = tl.floor(y2 + 0.5)
        x1_i = x1_r.to(tl.int32)
        y1_i = y1_r.to(tl.int32)
        x2_i = x2_r.to(tl.int32)
        y2_i = y2_r.to(tl.int32)
        x1_i = tl.where(((x1_r - x1) == 0.5) & ((x1_i & 1) != 0), x1_i - 1, x1_i)
        y1_i = tl.where(((y1_r - y1) == 0.5) & ((y1_i & 1) != 0), y1_i - 1, y1_i)
        x2_i = tl.where(((x2_r - x2) == 0.5) & ((x2_i & 1) != 0), x2_i - 1, x2_i)
        y2_i = tl.where(((y2_r - y2) == 0.5) & ((y2_i & 1) != 0), y2_i - 1, y2_i)

        # Reserve a compact output slot; slot order across survivors is
        # whatever order the atomics happen to land in.
        slot = tl.atomic_add(counter_ptr, 1)

        # Store the fp32 confidence as raw int32 bits so the whole record
        # lives in one int32 buffer; the consumer reinterprets that column.
        conf_i32 = conf.to(tl.int32, bitcast=True)
        base = slot * 6
        tl.store(combined_out_ptr + base + 0, x1_i)
        tl.store(combined_out_ptr + base + 1, y1_i)
        tl.store(combined_out_ptr + base + 2, x2_i)
        tl.store(combined_out_ptr + base + 3, y2_i)
        tl.store(combined_out_ptr + base + 4, conf_i32)
        tl.store(combined_out_ptr + base + 5, top_c)
        tl.store(survivor_idx_out_ptr + slot, pid.to(tl.int32))
        # Reset the flag that the mask kernel raises with atomic_max.
        tl.store(mask_any_out_ptr + slot, 0)

    @triton.jit
    def _rfdetr_fullpost_mask_kernel_compact(
        masks_ptr,         # (num_queries, mask_h, mask_w), unit stride on x
        survivor_idx_ptr,  # (num_queries,) int32: slot -> source query
        counter_ptr,       # (1,) int32: number of survivors
        out_ptr,           # (capacity, orig_h, orig_w) bytes, unit stride on x
        mask_any_ptr,      # (num_queries,) int32: 1 if any pixel is set
        mask_h,
        mask_w,
        orig_h,
        orig_w,
        mask_scale_y,      # mask_h / orig_h
        mask_scale_x,      # mask_w / orig_w
        masks_stride_q,
        masks_stride_h,
        out_stride_s,
        out_stride_h,
        BLOCK_H: tl.constexpr,
        BLOCK_W: tl.constexpr,
    ):
        """Upsample and binarize one output tile of one survivor's mask."""
        s = tl.program_id(0)
        tile_y = tl.program_id(1)
        tile_x = tl.program_id(2)

        # The grid is sized for the worst case; slots past the survivor
        # count have nothing to do.
        n_survivors = tl.load(counter_ptr)
        if s >= n_survivors:
            return

        q = tl.load(survivor_idx_ptr + s)

        offs_y = tile_y * BLOCK_H + tl.arange(0, BLOCK_H)
        offs_x = tile_x * BLOCK_W + tl.arange(0, BLOCK_W)
        mask_yy = offs_y < orig_h
        mask_xx = offs_x < orig_w
        m_outbox = mask_yy[:, None] & mask_xx[None, :]

        # Map output pixel centers back into mask coordinates.
        src_y_f = (offs_y.to(tl.float32) + 0.5) * mask_scale_y - 0.5
        src_x_f = (offs_x.to(tl.float32) + 0.5) * mask_scale_x - 0.5
        src_y_2d = src_y_f[:, None]
        src_x_2d = src_x_f[None, :]

        y0 = tl.floor(src_y_2d).to(tl.int32)
        x0 = tl.floor(src_x_2d).to(tl.int32)
        y1 = y0 + 1
        x1 = x0 + 1
        dy = src_y_2d - y0.to(tl.float32)
        dx = src_x_2d - x0.to(tl.float32)

        # Clamp taps to the mask so border pixels replicate the edge.
        y0c = tl.maximum(tl.minimum(y0, mask_h - 1), 0)
        y1c = tl.maximum(tl.minimum(y1, mask_h - 1), 0)
        x0c = tl.maximum(tl.minimum(x0, mask_w - 1), 0)
        x1c = tl.maximum(tl.minimum(x1, mask_w - 1), 0)

        base = q * masks_stride_q

        p00 = tl.load(masks_ptr + base + y0c * masks_stride_h + x0c, mask=m_outbox, other=0.0)
        p01 = tl.load(masks_ptr + base + y0c * masks_stride_h + x1c, mask=m_outbox, other=0.0)
        p10 = tl.load(masks_ptr + base + y1c * masks_stride_h + x0c, mask=m_outbox, other=0.0)
        p11 = tl.load(masks_ptr + base + y1c * masks_stride_h + x1c, mask=m_outbox, other=0.0)

        w_tl = (1.0 - dy) * (1.0 - dx)
        w_tr = (1.0 - dy) * dx
        w_bl = dy * (1.0 - dx)
        w_br = dy * dx
        val = p00 * w_tl + p01 * w_tr + p10 * w_bl + p11 * w_br
        bin_val = (val > 0.0).to(tl.int8)

        out_offsets = offs_y[:, None] * out_stride_h + offs_x[None, :]
        # Output rows are indexed by compact slot s, not by source query q.
        tl.store(out_ptr + s * out_stride_s + out_offsets, bin_val, mask=m_outbox)

        # One atomic per tile: raise mask_any[s] if this tile set any pixel.
        tile_any = tl.max(bin_val.to(tl.int32), axis=0)
        tile_any2 = tl.max(tile_any, axis=0)
        tl.atomic_max(mask_any_ptr + s, tile_any2)


def _next_pow2(n: int) -> int:
    """Return the smallest power of two >= ``n`` (1 for ``n`` <= 1)."""
    return 1 << (max(n, 1) - 1).bit_length()


# Module-level caches of small or reusable device tensors. They are plain
# dicts with no locking and are never evicted.
_THRESHOLD_CACHE: dict = {}
_EMPTY_INT32 = torch.empty((1,), dtype=torch.int32)
_MASK_BIN_BUFFER_CACHE: dict = {}
_DUMMY_CLASS_MAP_CACHE: dict = {}


def _get_mask_bin_buffer(
    capacity: int, orig_h: int, orig_w: int, device: torch.device
) -> torch.Tensor:
    """Return a cached ``(capacity, orig_h, orig_w)`` uint8 buffer.

    The same tensor is handed out on every call with the same key, so rows
    may hold stale data from earlier calls; callers must only read the rows
    they have just written.
    """
    key = (capacity, orig_h, orig_w, device)
    buf = _MASK_BIN_BUFFER_CACHE.get(key)
    if buf is None:
        buf = torch.empty(
            (capacity, orig_h, orig_w), dtype=torch.uint8, device=device
        )
        _MASK_BIN_BUFFER_CACHE[key] = buf
    return buf


def _get_dummy_class_map(device: torch.device) -> torch.Tensor:
    """Return a cached 1-element int32 tensor used as an unread placeholder."""
    buf = _DUMMY_CLASS_MAP_CACHE.get(device)
    if buf is None:
        buf = torch.zeros((1,), dtype=torch.int32, device=device)
        _DUMMY_CLASS_MAP_CACHE[device] = buf
    return buf


def _prepare_threshold(threshold, device: torch.device, num_classes: int):
    """Return ``(threshold_tensor, per_class)`` ready for the filter kernel.

    Tensor thresholds are treated as per-class and converted to contiguous
    fp32 on ``device`` when needed. Scalar thresholds are turned into a
    cached 1-element tensor keyed by ``(value, device)``. ``num_classes`` is
    accepted for signature stability and is not used.
    """
    if isinstance(threshold, torch.Tensor):
        t = threshold
        if t.dtype != torch.float32 or t.device != device or not t.is_contiguous():
            t = t.to(dtype=torch.float32, device=device).contiguous()
        return t, True
    key = (float(threshold), device)
    cached = _THRESHOLD_CACHE.get(key)
    if cached is None:
        cached = torch.tensor([float(threshold)], dtype=torch.float32, device=device)
        _THRESHOLD_CACHE[key] = cached
    return cached, False


def triton_rfdetr_fullpost(
    bboxes: torch.Tensor,
    logits: torch.Tensor,
    masks: torch.Tensor,
    threshold: "torch.Tensor | float",
    num_classes: int,
    class_mapping: Optional[torch.Tensor],
    inference_size_wh: Tuple[int, int],
    pad_ltrb: Tuple[int, int, int, int],
    scale_wh: Tuple[float, float],
    orig_size_wh: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the fused RF-DETR post-processing for a single image.

    Args:
        bboxes: ``(1, num_queries, 4)`` CUDA tensor, normalized cxcywh.
        logits: ``(1, num_queries, num_classes_total)`` CUDA tensor.
        masks: ``(1, num_queries, mask_h, mask_w)`` CUDA tensor.
        threshold: Scalar confidence threshold, or a tensor indexed by the
            output class id for per-class thresholds.
        num_classes: Number of named classes. Without ``class_mapping``,
            queries whose argmax is >= ``num_classes`` are dropped.
        class_mapping: Optional ``(num_classes_total,)`` tensor mapping raw
            class id to output class id; negative entries drop the query.
        inference_size_wh: ``(W, H)`` of the network input.
        pad_ltrb: ``(left, top, right, bottom)`` padding in inference
            pixels; only left and top are used.
        scale_wh: ``(scale_w, scale_h)`` applied during pre-processing.
        orig_size_wh: ``(W, H)`` of the original image.

    Returns:
        ``(combined, mask_bin, mask_any)``:
        ``combined`` is ``(n, 6)`` int32 rows of
        ``[x1, y1, x2, y2, confidence_bits, class_id]`` where
        ``confidence_bits`` is the fp32 confidence bit pattern;
        ``mask_bin`` is ``(n, orig_h, orig_w)`` uint8 and is a view into a
        cached buffer that the next call with the same geometry overwrites;
        ``mask_any`` is ``(n,)`` int32, 1 where the mask has any set pixel.
        Row order is not deterministic.

    Raises:
        AssertionError: If Triton is unavailable, an input is not on CUDA,
            or the batch size is not 1.
    """
    assert TRITON_AVAILABLE, "triton not available"
    assert bboxes.is_cuda and logits.is_cuda and masks.is_cuda
    assert bboxes.shape[0] == 1 and logits.shape[0] == 1 and masks.shape[0] == 1, "batch=1 only"

    device = bboxes.device
    num_queries, num_classes_total = logits.shape[1], logits.shape[2]
    _, _, mask_h, mask_w = masks.shape

    # .contiguous() returns the same tensor when no copy is needed.
    logits_2d = logits[0].contiguous()
    bboxes_2d = bboxes[0].contiguous()
    masks_3d = masks[0].contiguous()

    combined = torch.empty((num_queries, 6), dtype=torch.int32, device=device)
    survivor_idx = torch.empty((num_queries,), dtype=torch.int32, device=device)
    mask_any = torch.empty((num_queries,), dtype=torch.int32, device=device)
    # The filter kernel atomically increments this, so it must start at 0.
    counter = torch.zeros((1,), dtype=torch.int32, device=device)

    thr_tensor, per_class = _prepare_threshold(threshold, device, num_classes)

    has_remap = class_mapping is not None
    if has_remap:
        cmap = class_mapping
        if (
            cmap.dtype != torch.int32
            or cmap.device != device
            or not cmap.is_contiguous()
        ):
            cmap = cmap.to(dtype=torch.int32, device=device).contiguous()
    else:
        # Never dereferenced when HAS_REMAPPING is 0.
        cmap = _get_dummy_class_map(device)

    inf_w, inf_h = inference_size_wh
    pad_l, pad_t, _, _ = pad_ltrb
    sw, sh = scale_wh
    orig_w, orig_h = orig_size_wh

    BLOCK_C = max(32, _next_pow2(num_classes_total))
    _rfdetr_fullpost_filter_kernel[(num_queries,)](
        logits_2d,
        bboxes_2d,
        thr_tensor,
        cmap,
        combined,
        survivor_idx,
        mask_any,
        counter,
        num_queries,
        int(num_classes),
        num_classes_total,
        int(inf_w),
        int(inf_h),
        int(pad_l),
        int(pad_t),
        float(1.0 / sw),
        float(1.0 / sh),
        int(orig_w),
        int(orig_h),
        logits_2d.stride(0),
        bboxes_2d.stride(0),
        PER_CLASS=1 if per_class else 0,
        HAS_REMAPPING=1 if has_remap else 0,
        BLOCK_C=BLOCK_C,
    )

    # The mask kernel is launched for the worst case (num_queries slots) and
    # reads the survivor count on the device, so no host round-trip is
    # needed between the two launches.
    mask_bin_full = _get_mask_bin_buffer(num_queries, orig_h, orig_w, device)

    BLOCK_H = 16
    BLOCK_W = 16
    grid = (
        num_queries,
        (orig_h + BLOCK_H - 1) // BLOCK_H,
        (orig_w + BLOCK_W - 1) // BLOCK_W,
    )
    _rfdetr_fullpost_mask_kernel_compact[grid](
        masks_3d,
        survivor_idx,
        counter,
        mask_bin_full,
        mask_any,
        int(mask_h),
        int(mask_w),
        int(orig_h),
        int(orig_w),
        float(mask_h / orig_h),
        float(mask_w / orig_w),
        masks_3d.stride(0),
        masks_3d.stride(1),
        mask_bin_full.stride(0),
        mask_bin_full.stride(1),
        BLOCK_H=BLOCK_H,
        BLOCK_W=BLOCK_W,
    )

    # .item() blocks the host until the kernels queued above have finished.
    n_survivors = int(counter.item())

    if n_survivors == 0:
        empty_combined = torch.empty((0, 6), dtype=torch.int32, device=device)
        empty_mask = torch.empty((0, orig_h, orig_w), dtype=torch.uint8, device=device)
        empty_any = torch.empty((0,), dtype=torch.int32, device=device)
        return empty_combined, empty_mask, empty_any

    return (
        combined[:n_survivors],
        mask_bin_full[:n_survivors],
        mask_any[:n_survivors],
    )
"""Fused Triton kernel for the first stage of RF-DETR post-processing.

Replaces the sequence
    sigmoid(logits) -> argmax(class) -> named/remap filter -> threshold
with one kernel launch. Sorting, gathering, box de-normalization and mask
alignment are left to the caller.

For each query row of ``num_classes_total`` logits the kernel computes:
    top_cls[q] = argmax_c logits[q, c]   (optionally remapped)
    conf[q]    = sigmoid(max_c logits[q, c])
    keep[q]    = row is valid and conf[q] > threshold (scalar or per class)
A row is valid when its class is < ``num_classes`` (no remapping) or when
the class map gives a non-negative id (remapping).
"""
from typing import Optional, Tuple

import torch

try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except ImportError:  # pragma: no cover
    triton = None
    tl = None
    TRITON_AVAILABLE = False


if TRITON_AVAILABLE:

    @triton.jit
    def _rfdetr_conf_filter_kernel(
        logits_ptr,
        threshold_ptr,       # read only when PER_CLASS is set
        scalar_threshold,    # used only when PER_CLASS is not set
        class_map_ptr,       # (num_classes_total,) int32; negative = drop
        conf_out_ptr,
        top_class_out_ptr,
        keep_out_ptr,
        num_queries,
        num_classes,
        num_classes_total,
        logits_stride_q,
        PER_CLASS: tl.constexpr,
        HAS_REMAPPING: tl.constexpr,
        BLOCK_C: tl.constexpr,
    ):
        """Compute confidence, class and keep flag for one query row."""
        pid = tl.program_id(0)
        if pid >= num_queries:
            return
        offs_c = tl.arange(0, BLOCK_C)
        mask_c = offs_c < num_classes_total
        logits_row = tl.load(
            logits_ptr + pid * logits_stride_q + offs_c,
            mask=mask_c,
            other=-float("inf"),
        )
        # argmax as "smallest index that equals the row maximum".
        max_val = tl.max(logits_row, axis=0)
        BIG = 1 << 30
        is_max = logits_row == max_val
        idx_or_big = tl.where(is_max & mask_c, offs_c, BIG)
        raw_c = tl.min(idx_or_big, axis=0)
        # raw_c stays BIG when no element equals the maximum (e.g. a NaN
        # row); never use it as an index in that case.
        found = raw_c < num_classes_total
        safe_raw = tl.where(found, raw_c, 0)
        if HAS_REMAPPING:
            top_c = tl.load(class_map_ptr + safe_raw)
            valid = found & (top_c >= 0)
        else:
            top_c = raw_c
            valid = raw_c < num_classes
        # Numerically stable sigmoid, evaluated in fp32 for any logits dtype.
        max_f32 = max_val.to(tl.float32)
        z = tl.exp(-tl.abs(max_f32))
        sig_pos = 1.0 / (1.0 + z)
        sig_neg = z / (1.0 + z)
        conf = tl.where(max_f32 >= 0.0, sig_pos, sig_neg)
        if PER_CLASS:
            # Index 0 is a harmless placeholder for rows that are dropped.
            safe_c = tl.where(valid, top_c, 0)
            thr = tl.load(threshold_ptr + safe_c)
        else:
            thr = scalar_threshold
        keep = valid & (conf > thr)
        tl.store(conf_out_ptr + pid, conf)
        tl.store(top_class_out_ptr + pid, top_c)
        tl.store(keep_out_ptr + pid, keep.to(tl.int8))


def _next_pow2(n: int) -> int:
    """Return the smallest power of two >= ``n`` (1 for ``n`` <= 1)."""
    return 1 << (max(n, 1) - 1).bit_length()


def triton_rfdetr_conf_filter(
    logits: torch.Tensor,
    threshold: "torch.Tensor | float",
    num_classes: int,
    class_mapping: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fused sigmoid + argmax + named/remap filter + confidence threshold.

    Args:
        logits: ``(num_queries, num_classes_total)`` CUDA tensor for one
            image.
        threshold: Scalar, or a tensor indexed by the output class id.
        num_classes: Number of named classes; without ``class_mapping``,
            rows whose argmax is >= ``num_classes`` are not kept.
        class_mapping: Optional ``(num_classes_total,)`` integer tensor
            mapping raw class id to output id; negative entries drop the row.

    Returns:
        ``(conf, top_class, keep)``, each of length ``num_queries``: fp32
        confidences, int32 class ids and a bool mask. Entries of ``conf``
        and ``top_class`` are written for every row, kept or not.

    Raises:
        AssertionError: If ``logits`` is not a 2-D CUDA tensor.
        RuntimeError: If Triton is not installed.
    """
    assert logits.is_cuda and logits.ndim == 2
    if not TRITON_AVAILABLE:
        raise RuntimeError("triton is not installed")
    num_queries, num_classes_total = logits.shape
    logits_c = logits.contiguous()

    device = logits.device
    conf = torch.empty((num_queries,), dtype=torch.float32, device=device)
    top_c = torch.empty((num_queries,), dtype=torch.int32, device=device)
    keep = torch.empty((num_queries,), dtype=torch.int8, device=device)

    per_class = isinstance(threshold, torch.Tensor)
    if per_class:
        # The kernel dereferences this pointer on the logits device.
        thr_tensor = threshold.to(device=device).contiguous()
        scalar_thr = 0.0
    else:
        # Placeholder; never read because PER_CLASS is 0.
        thr_tensor = torch.empty((1,), dtype=torch.float32, device=device)
        scalar_thr = float(threshold)

    has_remap = class_mapping is not None
    if has_remap:
        cmap = class_mapping.to(dtype=torch.int32, device=device).contiguous()
    else:
        # Placeholder; never read because HAS_REMAPPING is 0.
        cmap = torch.empty((1,), dtype=torch.int32, device=device)

    BLOCK_C = max(32, _next_pow2(num_classes_total))
    _rfdetr_conf_filter_kernel[(num_queries,)](
        logits_c,
        thr_tensor,
        scalar_thr,
        cmap,
        conf,
        top_c,
        keep,
        num_queries,
        num_classes,
        num_classes_total,
        logits_c.stride(0),
        PER_CLASS=1 if per_class else 0,
        HAS_REMAPPING=1 if has_remap else 0,
        BLOCK_C=BLOCK_C,
    )
    return conf, top_c, keep.bool()
"""Fused Triton pre-processing kernel for RF-DETR seg (stretch-to resize).

One launch performs bilinear stretch-resize of a uint8 HWC BGR frame, the
BGR -> RGB channel swap, the HWC -> CHW layout change and the
``(pixel / 255 - mean) / std`` normalization.
"""
from typing import Optional, Tuple

import torch

try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except ImportError:  # pragma: no cover
    triton = None
    tl = None
    TRITON_AVAILABLE = False


if TRITON_AVAILABLE:

    @triton.jit
    def _rfdetr_stretch_preprocess_kernel(
        src_ptr,
        dst_ptr,
        src_h,
        src_w,
        src_stride_h,
        src_stride_w,
        target_h,
        target_w,
        scale_y,
        scale_x,
        inv_std_r_255,
        inv_std_g_255,
        inv_std_b_255,
        offset_r,
        offset_g,
        offset_b,
        dst_stride_c,
        dst_stride_h,
        BLOCK_H: tl.constexpr,
        BLOCK_W: tl.constexpr,
    ):
        """Stretch-resize + BGR->RGB + normalize for one output tile.

        Source is uint8 HWC BGR with unit channel stride; destination is
        fp32 CHW RGB with unit stride along x. The host pre-folds the
        normalization into ``pixel * inv_std_c_255 + offset_c`` where
        ``inv_std_c_255 = 1 / (255 * std_c)`` and
        ``offset_c = -mean_c / std_c``.
        """
        pid_y = tl.program_id(0)
        pid_x = tl.program_id(1)

        offs_y = pid_y * BLOCK_H + tl.arange(0, BLOCK_H)
        offs_x = pid_x * BLOCK_W + tl.arange(0, BLOCK_W)
        mask_y = offs_y < target_h
        mask_x = offs_x < target_w
        mask = mask_y[:, None] & mask_x[None, :]

        # Pixel-center bilinear sampling (same coordinate convention as
        # cv2.resize; bit-exact parity is not established here).
        src_y_f = (offs_y.to(tl.float32) + 0.5) * scale_y - 0.5
        src_x_f = (offs_x.to(tl.float32) + 0.5) * scale_x - 0.5

        src_y_f_2d = src_y_f[:, None]
        src_x_f_2d = src_x_f[None, :]

        y0 = tl.floor(src_y_f_2d).to(tl.int32)
        x0 = tl.floor(src_x_f_2d).to(tl.int32)
        y1 = y0 + 1
        x1 = x0 + 1

        dy = src_y_f_2d - y0.to(tl.float32)
        dx = src_x_f_2d - x0.to(tl.float32)

        # Clamp taps to the source so border pixels replicate the edge.
        y0c = tl.maximum(tl.minimum(y0, src_h - 1), 0)
        y1c = tl.maximum(tl.minimum(y1, src_h - 1), 0)
        x0c = tl.maximum(tl.minimum(x0, src_w - 1), 0)
        x1c = tl.maximum(tl.minimum(x1, src_w - 1), 0)

        base_00 = y0c * src_stride_h + x0c * src_stride_w
        base_01 = y0c * src_stride_h + x1c * src_stride_w
        base_10 = y1c * src_stride_h + x0c * src_stride_w
        base_11 = y1c * src_stride_h + x1c * src_stride_w

        w_tl = (1.0 - dy) * (1.0 - dx)
        w_tr = (1.0 - dy) * dx
        w_bl = dy * (1.0 - dx)
        w_br = dy * dx

        # Source channel order is B, G, R at offsets 0, 1, 2.
        p00_b = tl.load(src_ptr + base_00 + 0, mask=mask, other=0).to(tl.float32)
        p01_b = tl.load(src_ptr + base_01 + 0, mask=mask, other=0).to(tl.float32)
        p10_b = tl.load(src_ptr + base_10 + 0, mask=mask, other=0).to(tl.float32)
        p11_b = tl.load(src_ptr + base_11 + 0, mask=mask, other=0).to(tl.float32)
        b_val = p00_b * w_tl + p01_b * w_tr + p10_b * w_bl + p11_b * w_br

        p00_g = tl.load(src_ptr + base_00 + 1, mask=mask, other=0).to(tl.float32)
        p01_g = tl.load(src_ptr + base_01 + 1, mask=mask, other=0).to(tl.float32)
        p10_g = tl.load(src_ptr + base_10 + 1, mask=mask, other=0).to(tl.float32)
        p11_g = tl.load(src_ptr + base_11 + 1, mask=mask, other=0).to(tl.float32)
        g_val = p00_g * w_tl + p01_g * w_tr + p10_g * w_bl + p11_g * w_br

        p00_r = tl.load(src_ptr + base_00 + 2, mask=mask, other=0).to(tl.float32)
        p01_r = tl.load(src_ptr + base_01 + 2, mask=mask, other=0).to(tl.float32)
        p10_r = tl.load(src_ptr + base_10 + 2, mask=mask, other=0).to(tl.float32)
        p11_r = tl.load(src_ptr + base_11 + 2, mask=mask, other=0).to(tl.float32)
        r_val = p00_r * w_tl + p01_r * w_tr + p10_r * w_bl + p11_r * w_br

        r_out = r_val * inv_std_r_255 + offset_r
        g_out = g_val * inv_std_g_255 + offset_g
        b_out = b_val * inv_std_b_255 + offset_b

        # Destination planes are written in R, G, B order.
        out_row_offsets = offs_y[:, None] * dst_stride_h + offs_x[None, :]
        tl.store(dst_ptr + 0 * dst_stride_c + out_row_offsets, r_out, mask=mask)
        tl.store(dst_ptr + 1 * dst_stride_c + out_row_offsets, g_out, mask=mask)
        tl.store(dst_ptr + 2 * dst_stride_c + out_row_offsets, b_out, mask=mask)


def triton_preprocess_rfdetr_stretch(
    src: torch.Tensor,
    target_h: int,
    target_w: int,
    means: Tuple[float, float, float] = (0.485, 0.456, 0.406),
    stds: Tuple[float, float, float] = (0.229, 0.224, 0.225),
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Resize, recolor and normalize one frame with a single kernel launch.

    Args:
        src: uint8 CUDA tensor of shape ``(H, W, 3)`` in BGR order.
        target_h: Output height in pixels.
        target_w: Output width in pixels.
        means: Per-channel means in R, G, B order, on the 0..1 scale.
        stds: Per-channel standard deviations in R, G, B order.
        out: Optional contiguous fp32 CUDA tensor of shape
            ``(1, 3, target_h, target_w)`` on the same device as ``src``
            to write into; allocated when omitted.

    Returns:
        The fp32 ``(1, 3, target_h, target_w)`` CHW RGB tensor (``out`` when
        it was supplied).

    Raises:
        RuntimeError: If Triton is not installed.
        ValueError: If any argument fails the checks described above.
    """
    if not TRITON_AVAILABLE:
        raise RuntimeError("triton is not installed")
    if not src.is_cuda:
        raise ValueError(f"expected CUDA tensor, got device={src.device}")
    if src.dtype != torch.uint8:
        raise ValueError(f"expected uint8, got {src.dtype}")
    if src.ndim != 3 or src.shape[2] != 3:
        raise ValueError(f"expected HWC 3-channel, got shape={tuple(src.shape)}")
    if src.shape[0] == 0 or src.shape[1] == 0:
        # The kernel clamps taps to [0, size - 1]; an empty source would
        # make it read out of bounds.
        raise ValueError(f"expected non-empty image, got shape={tuple(src.shape)}")
    if target_h <= 0 or target_w <= 0:
        raise ValueError(
            f"target size must be positive, got ({target_h}, {target_w})"
        )
    if len(means) != 3 or len(stds) != 3:
        raise ValueError("means and stds must each have 3 entries")

    src = src.contiguous()
    src_h, src_w = int(src.shape[0]), int(src.shape[1])
    src_stride_h = int(src.stride(0))
    src_stride_w = int(src.stride(1))

    scale_y = src_h / target_h
    scale_x = src_w / target_w

    if out is None:
        out = torch.empty(
            (1, 3, target_h, target_w), dtype=torch.float32, device=src.device
        )
    else:
        if tuple(out.shape) != (1, 3, target_h, target_w):
            raise ValueError(
                f"out has shape {tuple(out.shape)}, expected (1, 3, {target_h}, {target_w})"
            )
        if out.dtype != torch.float32 or not out.is_cuda:
            raise ValueError("out must be fp32 CUDA tensor")
        if out.device != src.device:
            raise ValueError(
                f"out is on {out.device} but src is on {src.device}"
            )
        if not out.is_contiguous():
            # The kernel addresses the output as a dense CHW block.
            raise ValueError("out must be contiguous")

    dst_stride_c = target_h * target_w
    dst_stride_h = target_w

    # Fold (pixel / 255 - mean) / std into one multiply-add per channel.
    inv_std_r_255 = 1.0 / (255.0 * stds[0])
    inv_std_g_255 = 1.0 / (255.0 * stds[1])
    inv_std_b_255 = 1.0 / (255.0 * stds[2])
    offset_r = -means[0] / stds[0]
    offset_g = -means[1] / stds[1]
    offset_b = -means[2] / stds[2]

    BLOCK_H = 16
    BLOCK_W = 16
    grid = (
        (target_h + BLOCK_H - 1) // BLOCK_H,
        (target_w + BLOCK_W - 1) // BLOCK_W,
    )
    _rfdetr_stretch_preprocess_kernel[grid](
        src,
        out,
        src_h,
        src_w,
        src_stride_h,
        src_stride_w,
        target_h,
        target_w,
        float(scale_y),
        float(scale_x),
        float(inv_std_r_255),
        float(inv_std_g_255),
        float(inv_std_b_255),
        float(offset_r),
        float(offset_g),
        float(offset_b),
        dst_stride_c,
        dst_stride_h,
        BLOCK_H=BLOCK_H,
        BLOCK_W=BLOCK_W,
    )
    return out
The cross-stream batching path in `trt.py:607-675` cannot engage — with `max_batch_size=1`, any N-frame batch is split into N serial single-frame forwards in a Python loop (`trt.py:644-675`), each holding the `threading.Lock()` at `rfdetr_instance_segmentation_trt.py:253`. This confirms the static-analysis hypothesis: Opportunity #1 is real and blocked by engine config, not code. + +**Target hardware from the engine blob: Tesla T4 (sm_75, Turing).** Confirmed by dumping the TRT engine inspector JSON on `rfdetr-seg-nano-t4-trt/unpacked/engine.plan` — every kernel tactic in the engine is prefixed `sm75_xmma_*` (e.g., `sm75_xmma_gemm_f16f16_f16f16_f16_...`). `hardware_compatibility_level` is `NONE`, meaning the engine is locked to sm_75 and will not load on Orin NX (sm_87, Ampere). The `-t4-trt` filename suffix matches: this package is T4-only. + +**Implication for Orin NX deployment.** There is **no on-device rebuild path invoked at runtime** — `RUNS_ON_JETSON` appears in the codebase only at `inference/core/env.py:701` and `video_source.py:184` (for V4L2 capture-backend selection). The TRT compile entry points in `inference_models/development/compile_rfdetr.py` and `compilation/engine_builder.py` are developer scripts, not runtime hooks. Therefore, the Orin NX deployment must be pulling a **separate, Orin-built `.plan`** that is not present in this repo checkout. The `trt_config.json` we can see (`static_batch_size: 1`) is almost certainly the same config used for the Orin build, but this cannot be confirmed without reading the Orin model cache directly. + +**Other caveats:** +- Training input size shown is 312×312, not 560×560 as in `development/compile_rfdetr.py:22`. The compile script is a different variant (base, not nano). Nano-seg genuinely runs at 312×312 — so inference is on a substantially smaller tensor than the 560×560 assumption in earlier notes. +- This is a test asset snapshot (2026-04-21). The production Orin package could carry a different config. 
+ +## Target Hardware + +**Jetson Orin NX** (confirmed by Brad). Implications for the analysis below: +- Single NVDEC engine — easily handles 4× 1080p30 H.264/H.265 on paper, so NVDEC is a real lever (§1). +- 2 DLA cores available (the same count as Xavier NX's 2, versus Nano's 0), but DLA value for RF-DETR seg is still gated by the transformer head (§8). +- 8 GB or 16 GB unified LPDDR5, shared with CPU — makes zero-copy / mapped-memory particularly attractive and makes large-batch engines memory-constrained (§3, §6). +- 6 or 8 Cortex-A78AE cores — CPU decode + CPU preprocessing + Python GIL is a harder ceiling here than on AGX Orin (§1, §2, §7). +- Current observed per-stream FPS and target FPS: **TBD — to be established by runtime measurement.** + +## Executive Summary + +- **Current architecture (one paragraph).** `InferencePipeline` spawns one decode thread per `VideoSource` using OpenCV `cv2.VideoCapture` (CPU decode). A multiplexer collects frames from N sources into a single batch list and hands them to a single inference thread, which invokes a compiled Workflow. The one model block (`RFDetrForInstanceSegmentationTRT`) holds a single TensorRT engine, a single `IExecutionContext`, and a `threading.Lock()` that serializes every forward pass. Pre-process runs on one CUDA stream, the forward runs on a second, post-process on a third, each followed by a `synchronize()`. With N parallel streams, only the decode stage is actually parallel; model execution is strictly sequential. +- **Top 3 throughput opportunities, ranked.** + 1. **Batched forward across streams.** High-confidence. N streams arrive as a list but batch size 1 engines are the default observed path; moving from N sequential forwards to one batched forward is the biggest potential lever. (`_infer_from_trt_engine_with_batch_size_boundaries`, `trt.py:588-604`). + 2. **CPU video decode + CPU-side preprocessing.** High-confidence from code, size pending runtime evidence. 
Decode is pure `cv2.VideoCapture` (`video_source.py:140-142`); the numpy preprocessing path does `cv2.cvtColor`/`cv2.resize` on CPU before H2D copy (`pre_processing.py:973,1052-1053`). On Jetson this competes with the Python GIL and inference. + 3. **Per-stage `stream.synchronize()` scaffolding.** Medium-confidence hypothesis. Each of pre/forward/post ends with a blocking sync (`rfdetr_instance_segmentation_trt.py:243, 293` plus the lock at 253); event-based chaining could overlap stages across frames. +- **Negatives worth flagging early:** No NVDEC, GStreamer, DeepStream, VPI, or DLA code paths anywhere in the repo (see §1, §8, §10). INT8 flag is supported in builder but no calibration code is committed (§3). + +## Current Implementation Map + +Tracing one frame from a single stream: + +1. A `VideoSource` is constructed around `CV2VideoFrameProducer` (`inference/core/interfaces/camera/video_source.py:136-142`). On Jetson camera devices it uses `cv2.CAP_V4L2`; otherwise the default FFmpeg backend. No GStreamer/NVDEC pipeline string is ever passed. +2. Each source runs a dedicated decode thread started at `video_source.py:635` (`_consume_video`), pushing decoded CPU `ndarray` frames into a per-source `frames_buffer` Queue. +3. `InferencePipeline._generate_frames()` (`inference_pipeline.py:1031-1044`) delegates to `multiplex_videos()`, which calls `VideoSourcesManager.retrieve_frames_from_sources()` (`camera/utils.py:143-175`). That method loops over sources serially, appending whatever frame is ready within `batch_collection_timeout`. +4. The inference thread (`inference_pipeline.py:914-922`) runs `predictions = self._on_video_frame(video_frames)` on the batch, then pushes `(predictions, video_frames)` onto a dispatch queue. +5. `_on_video_frame` is a partial wrapping `WorkflowRunner.run_workflow` (`model_handlers/workflows.py:10-60`), which wraps each frame as `{"type": "numpy_object", "value": ndarray, ...}` and calls `execution_engine.run(...)`. +6. 
The Workflow is compiled once at pipeline init (`inference_pipeline.py` init_with_workflow path; compilation cached at `workflows/execution_engine/v1/compiler/core.py:60-69`). The single-block Workflow executes the RF-DETR seg block, which calls `model.infer([...])`. +7. Inside `RFDetrForInstanceSegmentationTRT`: `pre_process` (line 225) → `forward` (246) → `post_process` (268). Forward acquires `self._lock` and pushes the primary CUDA context (line 253-254). Under the hood, `_execute_trt_engine` either replays a cached CUDA graph (`trt.py:706-713`) or sets dynamic input shape and runs `execute_async_v3` (`trt.py:716-740`). + +**What changes with N parallel streams.** Decode remains N-way parallel (N OS threads, CPU). The multiplexer serializes retrieval into one list. The inference thread processes that list as one workflow invocation; the SIMD-batch workflow path (`workflows/execution_engine/v1/executor/core.py:317-378`) passes the full list to the block's single `run()` call. The block, however, receives a list of numpy arrays, and preprocessing loops over them one-by-one on CPU (`pre_processing.py:843-865`) before concatenating tensors. The TRT forward is thus the only place where batching across streams could pay off, and only if the engine was built with `dynamic_batch_size_max ≥ N`. + +## Findings by Category + +### 1. Stream ingestion and decode path + +- **What the code does:** All decoding goes through `cv2.VideoCapture(video)` or `cv2.VideoCapture(video, cv2.CAP_V4L2)` (`video_source.py:140-142`). There is no GStreamer pipeline string, no PyAV, no DeepStream, no Jetson Multimedia API usage anywhere in `inference/` or `inference_models/`. `retrieve()` returns a host `ndarray` (`video_source.py:150-151`). +- **Why it matters:** On Jetson, FFmpeg-backed `VideoCapture` does CPU H.264/H.265 decode by default, bypassing NVDEC. For N 30fps streams this is N separate CPU decoders fighting the GIL alongside inference. 
+- **Confidence:** verified from code for the decoder path; "FFmpeg CPU decode" is the default runtime behavior of OpenCV unless a GStreamer pipeline is passed, which it never is here. +- **Estimated opportunity:** medium-to-large — NVDEC on Orin can decode many 1080p30 streams at negligible CPU cost. + +### 2. Preprocessing pipeline + +- **What the code does:** The numpy path (`pre_process_numpy_images_list`, `pre_processing.py:843-865`) calls `pre_process_numpy_image` per image in a Python `for` loop, concatenating tensors at the end. Within each iteration, grayscale/contrast transforms and resize use `cv2.cvtColor` (`:973`) and `cv2.resize` (`:1052`, `:1104`, `:1254`) on the CPU. Only after the resize does the code do `torch.from_numpy(...).to(target_device)` (`:1053, :1105, :1269`) — one H2D copy per image. +- **Why it matters:** On N streams this is N CPU resizes plus N H2D copies per multiplex cycle. For RF-DETR nano seg at 312×312 input, this is non-trivial but small per-frame; it becomes the bottleneck when decode+preprocess time exceeds the available Python runtime budget between forwards. +- **Confidence:** verified from code. +- **Estimated opportunity:** medium — batching resize on GPU (e.g., via torchvision, VPI, or a single concat-then-resize) would remove the per-image Python overhead. + +### 3. TensorRT engine instantiation and batching (CRITICAL) + +- **Single context, lock-serialized.** `engine.create_execution_context()` is called once in `from_pretrained` (`rfdetr_instance_segmentation_trt.py:159`), stored as `self._execution_context`. Every `forward()` acquires `self._lock = threading.Lock()` (set at `:216`) before invoking the engine (`:253-265`). This lock is the choke point for parallel streams — even if the workflow hands the block a list of N frames, the forward is still a single lock-protected call. 
+- **Dynamic shapes and batching are supported but splitting/padding happens automatically.** `_infer_from_trt_engine_with_batch_size_boundaries` (`trt.py:607-675`) pads a batch up to `min_batch_size` with zeros (`:619-631`) or splits anything larger than `max_batch_size` into serial sub-batches in a Python for loop (`:644-675`). The engine build profile reads `dynamic_batch_size_min/opt/max` from `trt_config.json` (`model_packages.py:146-196`). +- **CUDA graphs, when enabled, require fixed shapes.** A separate `graph_context = engine.create_execution_context()` is created per captured shape (`trt.py:759`, comment at `:756-758`). The replay path (`:706-713`) is the fast path and avoids the per-call `set_input_shape`/`set_tensor_address` work at `:716-722`. +- **Precision.** Builder supports FP32/FP16/INT8 (`engine_builder.py:80-87`). The RF-DETR compile script at `inference_models/development/compile_rfdetr.py:22` requests FP16/FP32 for 560×560 inputs, with 15 GB workspace (`:11`). No committed INT8 calibration code. +- **Why it matters:** With `max_batch_size ≥ N`, N streams could be one forward. But if the `.plan` shipped to the Jetson was built with `static_batch_size=1` or `max=1`, the code at `trt.py:644-675` degrades to a Python loop of single-frame forwards. The practical throughput for parallel streams therefore depends entirely on what `trt_config.json` ships inside the RF-DETR seg model package — which we cannot see from this repo alone. +- **Confidence:** single-context + global lock is verified from code; actual batch capabilities of the shipped `.plan` are a runtime artifact we cannot read. +- **Estimated opportunity:** large — if the engine is batch-1, rebuilding with a dynamic profile unlocks the single biggest lever. + +### 4. 
CUDA stream usage and async execution + +- **What the code does:** Three CUDA streams per model instance — one inference stream on the class (`rfdetr_instance_segmentation_trt.py:217`), plus per-thread pre- and post-process streams via `threading.local()` (`:218`, `:297-310`). The forward uses `execute_async_v3(stream_handle=stream.cuda_stream)` (`trt.py:740`). However, each stage ends with an explicit `stream.synchronize()` (`:243`, `:293`, `trt.py:712`). +- **Why it matters:** The streams are real, but the synchronize barriers between stages mean pre-process finishes → sync → forward → sync → post → sync, serially. Within a single forward call, the async execution is in principle overlapping with any prior post-process work on a different stream; in practice the `synchronize()` at end of each stage blocks until completion and the downstream stage waits. +- **Confidence:** verified from code for the sync points; whether they are required for correctness or just defensive is a code-reading judgment — the post-process sync at `:293` is plausibly redundant because consumers of the returned `InstanceDetections` will cause their own syncs. +- **Estimated opportunity:** small-to-medium — cross-frame pipelining would require a redesign of the model handler; the per-frame overhead is likely tens of milliseconds total but won't multiply with N. + +### 5. Postprocessing + +- **What the code does:** RF-DETR seg post-processing runs entirely on GPU tensors (`inference_models/models/rfdetr/common.py:45-129`): sigmoid, top-k, gather on masks, cxcywh→xyxy. `align_instance_segmentation_results` is GPU-based. No `.cpu()`, `.numpy()`, or `.item()` calls inside the post_process method at `rfdetr_instance_segmentation_trt.py:268-294`. Final `.cpu()` transfers happen at the workflow boundary, not inside the model. +- **Why it matters:** Post-processing is already well-structured. It is not a major target. +- **Confidence:** verified from code. +- **Estimated opportunity:** small. 
+ +### 6. Memory management + +- **What the code does:** CUDA graphs reuse `input_buffer`/`output_buffers` allocated at capture time (`trt.py:761, 779-789`), cached in an LRU keyed by input shape (`TRTCudaGraphCache`, `trt.py:83-272`, default size 8). The non-graph path, however, allocates a fresh output tensor every forward (`trt.py:732-736`: `result = torch.empty(tuple(output_tensor_shape), ...)`). No `cudaHostAlloc` with mapped flag, no `cudaMallocManaged`, no pinned-memory pool — the only pinning comes from PyTorch's default allocator. +- **Why it matters:** Jetson's physically unified memory means CPU→GPU copies are logical, not over PCIe, but Torch still does a staged copy by default. Zero-copy of the decoded frame into a GPU-addressable buffer would cut a full frame-sized H2D per stream per frame. This repo has no such path. +- **Confidence:** verified from code for absence of zero-copy. +- **Estimated opportunity:** medium on Jetson; same change on discrete GPU would be small. + +### 7. Threading / process model in InferencePipeline + +- **What the code does:** N decode threads (one per `VideoSource`) → shared queues → one inference thread (`inference_pipeline.py:906-922`) → one dispatch thread. Workflow block execution can use a `ThreadPoolExecutor` but for a single-block workflow there is nothing to parallelize. +- **Why it matters:** This topology is actually good for batching — the multiplexer already produces a per-iteration list of frames, one per active source. The missed opportunity is downstream: the model handler serializes that list through one lock. +- **Confidence:** verified from code. +- **Estimated opportunity:** small in isolation (the topology is already set up to feed a batch); the value comes from pairing with batched inference (§3). + +### 8. DLA usage + +- **What the code does:** Nothing. `grep` for `DLA`, `useDLACore`, `setDeviceType`, `kDLA` across `inference_models/` returns no matches (confirmed by the TRT exploration). 
The only builder config beyond precision is `hardware_compatibility_level = trt.HardwareCompatibilityLevel.SAME_COMPUTE_CAPABILITY` (`engine_builder.py:91`). +- **Why it matters:** RF-DETR's transformer blocks almost certainly cannot target DLA, but the CNN backbone could if split. No such path exists. +- **Confidence:** verified from code (absence). +- **Estimated opportunity:** unknown — likely small for seg because of the transformer head, and would require engine-graph surgery. + +### 9. Workflow / single-block overhead + +- **What the code does:** Workflow compilation is cached (`workflows/execution_engine/v1/compiler/core.py:60-69`). Per frame, however, `assemble_runtime_parameters` pydantic-validates inputs (`runtime_input_assembler.py:16-42`), fresh `ExecutionCache`/`DynamicBatchesManager`/`BranchingManager` instances are created in `ExecutionDataManager.init()` (`execution_data_manager/manager.py:48-66`), and step-input assembly traverses the compound input graph. For a single-block SIMD workflow the block is called once per batch with the full list, so the overhead is amortized across all N frames in the batch. +- **Why it matters:** Python orchestration is per-*batch* not per-*frame*, which is a relief at higher stream counts. But per-batch it is still measurable. +- **Confidence:** verified from code. +- **Estimated opportunity:** small — would only show up at very high N or very fast inference (where Python overhead approaches forward time). + +### 10. Jetson-specific code paths + +- **What the code does:** Jetson detection exists (`inference/core/devices/utils.py:68-86`, reading `/proc/device-tree/serial-number`) and is used for device-id reporting and the V4L2 capture-backend choice (`video_source.py:183-188`). The adaptive buffer-filling strategy is the default for streams (`video_source.py:932` per agent report). No `nvpmodel`, `jetson_clocks`, JetPack version check, VPI library usage, or DeepStream integration anywhere. 
+- **Why it matters:** There is essentially no Jetson-specific optimization beyond "use V4L2 for USB cameras." RF-DETR on Jetson runs the same code path as on desktop CUDA. +- **Confidence:** verified from code. +- **Estimated opportunity:** medium — Jetson-specific hooks (NVDEC, VPI, mapped-memory frames) are untouched ground. + +## Top Opportunities, Ranked + +1. **Batched TRT forward across streams.** Feed the multiplexer's N-frame list as a single batched tensor to one `execute_async_v3`. *Impact (static-analysis estimate): 1.5–2x on N=4 if the engine supports it.* **Runtime evidence needed:** the `trt_config.json` shipped in the RF-DETR seg model package (look at `static_batch_size` and `dynamic_batch_size_max`), and a `trtexec --loadEngine --shapes` run at batch 1 vs 4 to confirm that amortized kernel time scales sublinearly. +2. **GPU video decode via NVDEC / GStreamer pipeline string into `VideoCapture`.** Replace `cv2.VideoCapture(video)` at `video_source.py:140-142` with a GStreamer appsink pipeline that uses `nvv4l2decoder`/`nvvidconv` on Jetson. *Impact: 10–30% CPU headroom at N=4, 30fps 1080p; also removes GIL contention with the inference thread.* **Runtime evidence:** `tegrastats` during a 4-stream run showing CPU saturation; Nsight Systems trace showing decode threads blocking inference. +3. **GPU-side batched preprocessing (skip the per-image Python loop).** The current `pre_process_numpy_images_list` loop (`pre_processing.py:843-865`) with its per-image `cv2.resize` → `from_numpy` → `.to(device)` chain is a serial CPU gauntlet. *Impact: 1.1–1.3x at N=4.* **Runtime evidence:** Nsight kernel timeline showing gaps between H2D copies; CPU sampling profiler (py-spy) showing time in `cv2.resize` and `from_numpy`. +4. **Eliminate mid-pipeline `stream.synchronize()` barriers.** The three stage-ending syncs (`rfdetr_instance_segmentation_trt.py:243, 293`, `trt.py:712`) plus the global lock mean no overlap across frames. 
*Impact: 5–20% at N=4, depending on how close pre/post times are to forward time.* **Runtime evidence:** Nsight Systems trace showing gaps between kernels on the inference stream. +5. **Zero-copy or pinned-memory input buffer on Jetson.** The non-CUDA-graph path allocates output tensors every forward (`trt.py:732-736`); the input also transits a staged copy. *Impact: 5–15% at N=4.* **Runtime evidence:** Nsight memory trace showing H2D/D2H transfer time relative to compute. + +## What Static Analysis Cannot Tell Us + +- **The actual batch capabilities of the shipped `.plan` file.** `dynamic_batch_size_max` lives in the model package `trt_config.json`, not in this repo. *Resolves with:* reading `trt_config.json` from the cached model package directory, or `trtexec --loadEngine=engine.plan --verbose` dumping profile ranges. +- **Where wall-clock time is actually being spent.** Decode vs preprocess vs forward vs post vs Python overhead. *Resolves with:* Nsight Systems CPU+GPU timeline for a 4-stream run. +- **Whether `execute_async_v3` kernels on the three streams actually overlap, or serialize on a single engine queue.** TRT execution contexts cannot run two enqueues concurrently on one context regardless of stream. *Resolves with:* Nsight GPU kernel timeline showing kernel concurrency or lack thereof. +- **Whether CUDA graph capture is actually in use for the RF-DETR seg workload.** The cache is set up (`establish_trt_cuda_graph_cache`) but activation depends on `disable_cuda_graphs` and input-shape stability. *Resolves with:* a log line or trace confirming graph replay vs full enqueue. +- **Whether the OpenCV build on this Jetson is compiled with GStreamer support.** Without GStreamer in OpenCV, no pipeline-string backdoor is possible. *Resolves with:* `cv2.getBuildInformation()` on the target. +- **Python GIL contention between decode threads and the inference thread.** *Resolves with:* py-spy sampled profile during a 4-stream run. 
+- **Whether the `nvpmodel` / `jetson_clocks` state on the deployment unit has the GPU clocked to max.** *Resolves with:* `sudo nvpmodel -q` and `sudo jetson_clocks --show`. + +## Open Questions for Brad / Pawel + +1. **[OPEN — answer not yet known] What is the `.plan` file's dynamic shape profile** for RF-DETR nano seg as currently shipped from the inference-models cache? Specifically: what are `dynamic_batch_size_min/opt/max` in the shipped `trt_config.json`? This changes which optimization is first. Can be answered by reading the cached model package on an Orin NX that has already pulled the engine. +2. **[PARTIALLY RESOLVED — code supports both patterns, production choice still unknown]** The codebase supports two deployment patterns: + - **Pattern A (single process, multiple streams):** `InferencePipeline.init_with_workflow(video_reference=[...])` — `inference_pipeline.py:90-91` accepts a list; docs at `docs/workflows/video_processing/overview.md:71` demonstrate it. Example in-repo: `development/stream_interface/yolo_world_demo.py:19`. One model instance, one lock, one TRT context. + - **Pattern B (multiple processes):** Enterprise Stream Manager spawns one `InferencePipelineManager(Process)` per pipeline (`inference/enterprise/stream_management/manager/inference_pipeline_manager.py:44`, spawned from `manager/app.py:146`). Each process independent: own engine, own context, no shared lock. On Jetson without MPS, contexts time-slice the GPU. + - **Impact on opportunities:** batched-forward (#1) only helps Pattern A; CUDA MPS and per-process engine-sharing tricks only help Pattern B; NVDEC decode, stage-sync removal, and GPU preprocessing help both. Needs product confirmation which is in use for the Orin NX deployment. +3. **[ANSWERED] Target Jetson:** Orin NX. Captured in the "Target Hardware" section above. +4. 
**[TO BE MEASURED] Actual observed per-stream FPS today on Orin NX, and expected FPS at 2x.** Needs a profiling run; gives us the budget to compare against the static-analysis estimates above. +5. **[OPEN — answer not yet known] Was the `.plan` built on-device via the `RUNS_ON_JETSON` compile path, or shipped prebuilt?** Relevant because engines built with different `max_batch` may exist in the model cache, and rebuilding with a new profile is a different task from swapping an engine. Folds into Q1 — both resolve by inspecting what's in the model cache on a live Orin NX.