From cdece6a46aeca0effe5ba470f6e2a4959776c952 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 4 May 2026 19:31:00 +0000 Subject: [PATCH 01/25] perf(rfdetr-seg): fused Triton post-processing behind RFDETR_TRITON_FULLPOSTPROC Adds two fused Triton kernels that replace the post-TRT chain for the common rfdetr-seg-nano path (batch=1, no static crop, STRETCH_TO resize, class remapping present): - _rfdetr_fullpost_filter_kernel: sigmoid/argmax/class remap/conf threshold + cxcywh->xyxy + letterbox denormalize + clip + banker's rounding; atomic_add into a counter to reserve a compact output slot. - _rfdetr_fullpost_mask_kernel_compact: bilinear upsample masks to orig size + threshold + uint8 emit, with GPU-side early exit against the counter so no CPU sync is needed between the two kernels. Dispatch is gated on ``RFDETR_TRITON_FULLPOSTPROC=true``; callers keep the torch reference path when Triton is unavailable or the eligibility checks fail. The adapter reads a 4-byte counter via a pinned DtoH to learn ``n_survivors`` and then async-DtoHs the compact combined/mask slices. --- .../core/models/inference_models_adapters.py | 89 +++- .../inference_models/models/rfdetr/common.py | 63 +++ .../models/rfdetr/triton_fullpostproc.py | 417 ++++++++++++++++++ 3 files changed, 556 insertions(+), 13 deletions(-) create mode 100644 inference_models/inference_models/models/rfdetr/triton_fullpostproc.py diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index 9e6f43b4d2..9b4bbb594e 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -9,6 +9,21 @@ from PIL import Image, ImageDraw, ImageFont from pycocotools import mask as mask_utils +# Pinned host buffers for async DtoH on the full-postproc Triton fast path. +# Keyed by (name, dtype); reused across frames provided the cached buffer is +# at least as large as the requested shape in every dimension. 
+_PINNED_HOST_BUFFERS: dict = {} + + +def _get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor: + key = (name, dtype) + buf = _PINNED_HOST_BUFFERS.get(key) + if buf is not None and all(buf.shape[i] >= shape[i] for i in range(len(shape))): + return buf[tuple(slice(0, s) for s in shape)] + buf = torch.empty(shape, dtype=dtype, pin_memory=True) + _PINNED_HOST_BUFFERS[key] = buf + return buf + from inference.core.entities.requests import ( ClassificationInferenceRequest, InferenceRequest, @@ -320,22 +335,70 @@ def postprocess( H = preproc_metadata.original_size.height W = preproc_metadata.original_size.width - xyxy = det.xyxy.detach().cpu().numpy() - confs = det.confidence.detach().cpu().numpy() - if isinstance(det.mask, torch.Tensor): - masks = det.mask.detach().cpu().numpy() - if return_in_rle: - polys_or_rles = [ - torch_mask_to_coco_rle(mask=mask) for mask in masks - ] + # Fast path: RF-DETR full-postproc Triton fusion emits an + # unsliced (num_queries, 6) int32 record plus a GPU counter and a + # completion event. We DtoH the 4-byte counter (single sync) to + # learn n_survivors, then async-DtoH the compact slices. 
+ combined_gpu = getattr(det, "_combined_gpu", None) + counter_gpu = getattr(det, "_counter_gpu", None) + done_event = getattr(det, "_postproc_done_event", None) + if ( + not return_in_rle + and combined_gpu is not None + and counter_gpu is not None + and done_event is not None + and isinstance(det.mask, torch.Tensor) + and det.mask.is_cuda + ): + device = combined_gpu.device + stream = torch.cuda.current_stream(device) + done_event.wait(stream) + + counter_host = _get_pinned_buffer("counter", (1,), torch.int32) + counter_host.copy_(counter_gpu, non_blocking=True) + stream.synchronize() + n_survivors = int(counter_host[0].item()) + + if n_survivors == 0: + xyxy = np.empty((0, 4), dtype=np.int32) + confs = np.empty((0,), dtype=np.float32) + class_ids = np.empty((0,), dtype=np.int32) + polys_or_rles = [] else: - polys_or_rles = masks2poly(masks) + combined_slice = combined_gpu[:n_survivors] + mask_slice = det.mask[:n_survivors] + combined_host = _get_pinned_buffer( + "combined", combined_slice.shape, combined_slice.dtype + ) + mask_host = _get_pinned_buffer( + "mask", mask_slice.shape, mask_slice.dtype + ) + combined_host.copy_(combined_slice, non_blocking=True) + mask_host.copy_(mask_slice, non_blocking=True) + stream.synchronize() + combined_cpu = combined_host.numpy() + xyxy = combined_cpu[:, :4] + # combined[:, 4] holds fp32 conf bits stored as int32. 
+ confs = combined_cpu[:, 4].view(np.float32) + class_ids = combined_cpu[:, 5] + polys_or_rles = masks2poly(mask_host.numpy()) else: - if return_in_rle: - polys_or_rles = det.mask.to_coco_rle_masks() + xyxy = det.xyxy.detach().cpu().numpy() + confs = det.confidence.detach().cpu().numpy() + if isinstance(det.mask, torch.Tensor): + masks = det.mask.detach().cpu().numpy() + if return_in_rle: + polys_or_rles = [ + torch_mask_to_coco_rle(mask=mask) for mask in masks + ] + else: + polys_or_rles = masks2poly(masks) else: - polys_or_rles = rle_masks2poly(det.mask) - class_ids = det.class_id.detach().cpu().numpy() + if return_in_rle: + polys_or_rles = det.mask.to_coco_rle_masks() + else: + polys_or_rles = rle_masks2poly(det.mask) + class_ids = det.class_id.detach().cpu().numpy() predictions: List[ Union[InstanceSegmentationPrediction, InstanceSegmentationRLEPrediction] diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index a3ae26cd29..3c9d1b4ff6 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Tuple, Union import torch @@ -19,6 +20,43 @@ from inference_models.models.rfdetr.post_processor import select_topk_predictions from inference_models.utils.file_system import read_json +_RFDETR_TRITON_FULLPOSTPROC = os.getenv("RFDETR_TRITON_FULLPOSTPROC", "false").lower() in ( + "true", + "1", +) +if _RFDETR_TRITON_FULLPOSTPROC: + try: + from inference_models.models.rfdetr.triton_fullpostproc import ( + TRITON_AVAILABLE as _TRITON_FULLPOST_AVAILABLE, + triton_rfdetr_fullpost, + ) + _TRITON_FULLPOST_READY = _TRITON_FULLPOST_AVAILABLE and torch.cuda.is_available() + except Exception: + _TRITON_FULLPOST_READY = False + triton_rfdetr_fullpost = None +else: + _TRITON_FULLPOST_READY = False + triton_rfdetr_fullpost = None + + +def _fullpost_eligible( + bboxes: 
torch.Tensor, + pre_processing_meta: List[PreProcessingMetadata], + classes_re_mapping: Optional[ClassesReMapping], +) -> bool: + if not _TRITON_FULLPOST_READY or not bboxes.is_cuda: + return False + if bboxes.shape[0] != 1 or len(pre_processing_meta) != 1: + return False + meta = pre_processing_meta[0] + if meta.nonsquare_intermediate_size is not None: + return False + if meta.static_crop_offset.offset_x != 0 or meta.static_crop_offset.offset_y != 0: + return False + if classes_re_mapping is None: + return False + return True + def parse_model_type(config_path: str) -> str: try: @@ -132,6 +170,31 @@ def post_process_instance_segmentation_results( num_classes: int, classes_re_mapping: Optional[ClassesReMapping], ) -> List[InstanceDetections]: + if _fullpost_eligible(bboxes, pre_processing_meta, classes_re_mapping): + meta = pre_processing_meta[0] + thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) + combined, mask_bin, mask_any, counter, done_event = triton_rfdetr_fullpost( + bboxes=bboxes, + logits=logits, + masks=masks, + threshold=thr_arg, + num_classes=num_classes, + class_mapping=classes_re_mapping.class_mapping, + inference_size_wh=(meta.inference_size.width, meta.inference_size.height), + pad_ltrb=(meta.pad_left, meta.pad_top, meta.pad_right, meta.pad_bottom), + scale_wh=(meta.scale_width, meta.scale_height), + orig_size_wh=(meta.original_size.width, meta.original_size.height), + ) + detections = InstanceDetections( + xyxy=combined[:, :4], + confidence=combined[:, 4], + class_id=combined[:, 5], + mask=mask_bin, + ) + detections.__dict__["_combined_gpu"] = combined + detections.__dict__["_counter_gpu"] = counter + detections.__dict__["_postproc_done_event"] = done_event + return [detections] logits_sigmoid = torch.nn.functional.sigmoid(logits) results = [] device = bboxes.device diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py 
b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py new file mode 100644 index 0000000000..b80d076e9e --- /dev/null +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -0,0 +1,417 @@ +"""Fused RF-DETR instance-segmentation post-processing in Triton. + +Two kernels replace the post-TRT chain for the common rfdetr-seg-nano path +(batch=1, no static crop, STRETCH_TO resize, class remapping active): + + _rfdetr_fullpost_filter_kernel (grid = num_queries) + sigmoid argmax + class remap + conf threshold + cxcywh->xyxy + + letterbox-denormalize + clip + banker's rounding; atomic_add into a + counter to reserve a compact output slot. + + _rfdetr_fullpost_mask_kernel_compact (grid = num_queries * tile_y * tile_x) + Bilinear upsample masks (e.g. 78x78 -> orig_h x orig_w) + threshold > 0 + + uint8 emit. Early-exits on s >= counter[0] without an intermediate sync. +""" +from typing import Optional, Tuple + +import torch + +try: + import triton + import triton.language as tl + + TRITON_AVAILABLE = True +except ImportError: # pragma: no cover + triton = None + tl = None + TRITON_AVAILABLE = False + + +if TRITON_AVAILABLE: + + @triton.jit + def _rfdetr_fullpost_filter_kernel( + logits_ptr, + bboxes_ptr, + threshold_ptr, + class_map_ptr, + combined_out_ptr, + survivor_idx_out_ptr, + mask_any_out_ptr, + counter_ptr, + num_queries, + num_classes_total, + inference_w, + inference_h, + pad_left, + pad_top, + inv_scale_w, + inv_scale_h, + orig_w, + orig_h, + logits_stride_q, + bboxes_stride_q, + PER_CLASS: tl.constexpr, + HAS_REMAPPING: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + pid = tl.program_id(0) + if pid >= num_queries: + return + offs_c = tl.arange(0, BLOCK_C) + mask_c = offs_c < num_classes_total + + logits_row = tl.load( + logits_ptr + pid * logits_stride_q + offs_c, + mask=mask_c, + other=-float("inf"), + ) + max_val = tl.max(logits_row, axis=0) + BIG = 1 << 30 + is_max = logits_row == max_val + idx_or_big = tl.where(is_max 
& mask_c, offs_c, BIG) + raw_c = tl.min(idx_or_big, axis=0) + + if HAS_REMAPPING: + top_c = tl.load(class_map_ptr + raw_c) + valid = top_c >= 0 + else: + top_c = raw_c + valid = raw_c < num_classes_total + + abs_max = tl.abs(max_val) + z = tl.exp(-abs_max) + sig_pos = 1.0 / (1.0 + z) + sig_neg = z / (1.0 + z) + conf = tl.where(max_val >= 0.0, sig_pos, sig_neg) + + if PER_CLASS: + safe_c = tl.where(valid, top_c, 0) + thr = tl.load(threshold_ptr + safe_c) + else: + thr = tl.load(threshold_ptr) + keep = valid & (conf > thr) + + if not keep: + return + + # Match the non-Triton path's FP32 evaluation order for bit-parity. + cx_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 0) + cy_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 1) + w_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 2) + h_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 3) + + x1_pct = cx_pct - 0.5 * w_pct + y1_pct = cy_pct - 0.5 * h_pct + x2_pct = cx_pct + 0.5 * w_pct + y2_pct = cy_pct + 0.5 * h_pct + + x1 = x1_pct * inference_w + y1 = y1_pct * inference_h + x2 = x2_pct * inference_w + y2 = y2_pct * inference_h + + x1 = x1 - pad_left + y1 = y1 - pad_top + x2 = x2 - pad_left + y2 = y2 - pad_top + + x1 = x1 * inv_scale_w + y1 = y1 * inv_scale_h + x2 = x2 * inv_scale_w + y2 = y2 * inv_scale_h + + x1 = tl.maximum(tl.minimum(x1, orig_w), 0.0) + y1 = tl.maximum(tl.minimum(y1, orig_h), 0.0) + x2 = tl.maximum(tl.minimum(x2, orig_w), 0.0) + y2 = tl.maximum(tl.minimum(y2, orig_h), 0.0) + + # Banker's rounding (half-to-even) matches torch.round().int(). 
+ x1_r = tl.floor(x1 + 0.5) + y1_r = tl.floor(y1 + 0.5) + x2_r = tl.floor(x2 + 0.5) + y2_r = tl.floor(y2 + 0.5) + x1_i = x1_r.to(tl.int32) + y1_i = y1_r.to(tl.int32) + x2_i = x2_r.to(tl.int32) + y2_i = y2_r.to(tl.int32) + x1_i = tl.where(((x1_r - x1) == 0.5) & ((x1_i & 1) != 0), x1_i - 1, x1_i) + y1_i = tl.where(((y1_r - y1) == 0.5) & ((y1_i & 1) != 0), y1_i - 1, y1_i) + x2_i = tl.where(((x2_r - x2) == 0.5) & ((x2_i & 1) != 0), x2_i - 1, x2_i) + y2_i = tl.where(((y2_r - y2) == 0.5) & ((y2_i & 1) != 0), y2_i - 1, y2_i) + + slot = tl.atomic_add(counter_ptr, 1) + + # Bitcast conf (fp32) as int32 so the whole record writes with int32 + # stores. Host views the same memory as int32 and extracts via + # numpy.view(np.float32). + conf_bits = conf.to(tl.float32, bitcast=False) + conf_i32 = conf_bits.to(tl.int32, bitcast=True) + base = slot * 6 + tl.store(combined_out_ptr + base + 0, x1_i) + tl.store(combined_out_ptr + base + 1, y1_i) + tl.store(combined_out_ptr + base + 2, x2_i) + tl.store(combined_out_ptr + base + 3, y2_i) + tl.store(combined_out_ptr + base + 4, conf_i32) + tl.store(combined_out_ptr + base + 5, top_c) + tl.store(survivor_idx_out_ptr + slot, pid.to(tl.int32)) + tl.store(mask_any_out_ptr + slot, 0) + + + @triton.jit + def _rfdetr_fullpost_mask_kernel_compact( + masks_ptr, + survivor_idx_ptr, + counter_ptr, + out_ptr, + mask_any_ptr, + mask_h, + mask_w, + orig_h, + orig_w, + mask_scale_y, + mask_scale_x, + masks_stride_q, + masks_stride_h, + out_stride_s, + out_stride_h, + BLOCK_H: tl.constexpr, + BLOCK_W: tl.constexpr, + ): + s = tl.program_id(0) + tile_y = tl.program_id(1) + tile_x = tl.program_id(2) + + # GPU-side early exit — skip programs past the live survivor count. 
+ n_survivors = tl.load(counter_ptr) + if s >= n_survivors: + return + + q = tl.load(survivor_idx_ptr + s) + + offs_y = tile_y * BLOCK_H + tl.arange(0, BLOCK_H) + offs_x = tile_x * BLOCK_W + tl.arange(0, BLOCK_W) + mask_yy = offs_y < orig_h + mask_xx = offs_x < orig_w + m_outbox = mask_yy[:, None] & mask_xx[None, :] + + src_y_f = (offs_y.to(tl.float32) + 0.5) * mask_scale_y - 0.5 + src_x_f = (offs_x.to(tl.float32) + 0.5) * mask_scale_x - 0.5 + src_y_2d = src_y_f[:, None] + src_x_2d = src_x_f[None, :] + + y0 = tl.floor(src_y_2d).to(tl.int32) + x0 = tl.floor(src_x_2d).to(tl.int32) + y1 = y0 + 1 + x1 = x0 + 1 + dy = src_y_2d - y0.to(tl.float32) + dx = src_x_2d - x0.to(tl.float32) + + y0c = tl.maximum(tl.minimum(y0, mask_h - 1), 0) + y1c = tl.maximum(tl.minimum(y1, mask_h - 1), 0) + x0c = tl.maximum(tl.minimum(x0, mask_w - 1), 0) + x1c = tl.maximum(tl.minimum(x1, mask_w - 1), 0) + + base = q * masks_stride_q + + p00 = tl.load(masks_ptr + base + y0c * masks_stride_h + x0c, mask=m_outbox, other=0.0) + p01 = tl.load(masks_ptr + base + y0c * masks_stride_h + x1c, mask=m_outbox, other=0.0) + p10 = tl.load(masks_ptr + base + y1c * masks_stride_h + x0c, mask=m_outbox, other=0.0) + p11 = tl.load(masks_ptr + base + y1c * masks_stride_h + x1c, mask=m_outbox, other=0.0) + + w_tl = (1.0 - dy) * (1.0 - dx) + w_tr = (1.0 - dy) * dx + w_bl = dy * (1.0 - dx) + w_br = dy * dx + val = p00 * w_tl + p01 * w_tr + p10 * w_bl + p11 * w_br + bin_val = (val > 0.0).to(tl.int8) + + out_offsets = offs_y[:, None] * out_stride_h + offs_x[None, :] + tl.store(out_ptr + s * out_stride_s + out_offsets, bin_val, mask=m_outbox) + + tile_any = tl.max(bin_val.to(tl.int32), axis=0) + tile_any2 = tl.max(tile_any, axis=0) + tl.atomic_max(mask_any_ptr + s, tile_any2) + + +def _next_pow2(n: int) -> int: + p = 1 + while p < n: + p <<= 1 + return p + + +_THRESHOLD_CACHE: dict = {} +_EMPTY_INT32 = torch.empty((1,), dtype=torch.int32) +_MASK_BIN_BUFFER_CACHE: dict = {} +_SCRATCH_CACHE: dict = {} 
+_CLASS_MAPPING_INT32_CACHE: dict = {} + + +def _get_scratch_buffers(num_queries: int, device: torch.device): + key = (num_queries, device) + cached = _SCRATCH_CACHE.get(key) + if cached is None: + combined = torch.empty((num_queries, 6), dtype=torch.int32, device=device) + survivor_idx = torch.empty((num_queries,), dtype=torch.int32, device=device) + mask_any = torch.empty((num_queries,), dtype=torch.int32, device=device) + counter = torch.zeros((1,), dtype=torch.int32, device=device) + cached = (combined, survivor_idx, mask_any, counter) + _SCRATCH_CACHE[key] = cached + return cached + + +def _get_class_mapping_int32(class_mapping: torch.Tensor, device: torch.device) -> torch.Tensor: + if class_mapping.dtype == torch.int32 and class_mapping.device == device and class_mapping.is_contiguous(): + return class_mapping + key = (id(class_mapping), device) + cached = _CLASS_MAPPING_INT32_CACHE.get(key) + if cached is not None: + return cached + cached = class_mapping.to(dtype=torch.int32, device=device).contiguous() + _CLASS_MAPPING_INT32_CACHE[key] = cached + return cached + + +def _get_mask_bin_buffer( + capacity: int, orig_h: int, orig_w: int, device: torch.device +) -> torch.Tensor: + # Rows beyond n_survivors may hold stale data from prior frames; callers + # must size their read by the atomic counter. 
+ key = (capacity, orig_h, orig_w, device) + buf = _MASK_BIN_BUFFER_CACHE.get(key) + if buf is None: + buf = torch.empty( + (capacity, orig_h, orig_w), dtype=torch.uint8, device=device + ) + _MASK_BIN_BUFFER_CACHE[key] = buf + return buf + + +def _prepare_threshold(threshold, device: torch.device, num_classes: int): + if isinstance(threshold, torch.Tensor): + t = threshold + if t.dtype != torch.float32 or t.device != device or not t.is_contiguous(): + t = t.to(dtype=torch.float32, device=device).contiguous() + return t, True + key = (float(threshold), device) + cached = _THRESHOLD_CACHE.get(key) + if cached is None: + cached = torch.tensor([float(threshold)], dtype=torch.float32, device=device) + _THRESHOLD_CACHE[key] = cached + return cached, False + + +def triton_rfdetr_fullpost( + bboxes: torch.Tensor, + logits: torch.Tensor, + masks: torch.Tensor, + threshold: "torch.Tensor | float", + num_classes: int, + class_mapping: Optional[torch.Tensor], + inference_size_wh: Tuple[int, int], + pad_ltrb: Tuple[int, int, int, int], + scale_wh: Tuple[float, float], + orig_size_wh: Tuple[int, int], +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + "torch.cuda.Event", +]: + """Returns (combined, mask_bin, mask_any, counter, done_event). Buffers + are unsliced — the caller DtoH's ``counter`` to learn n_survivors and + slices to ``[:n_survivors]``. 
``combined[:, 4]`` holds fp32 conf as + int32 bits; use ``numpy.view(np.float32)`` on the host.""" + assert TRITON_AVAILABLE, "triton not available" + assert bboxes.is_cuda and logits.is_cuda and masks.is_cuda + assert bboxes.shape[0] == 1 and logits.shape[0] == 1 and masks.shape[0] == 1, "batch=1 only" + + device = bboxes.device + num_queries, num_classes_total = logits.shape[1], logits.shape[2] + _, _, mask_h, mask_w = masks.shape + + logits_2d = logits[0] if logits[0].is_contiguous() else logits[0].contiguous() + bboxes_2d = bboxes[0] if bboxes[0].is_contiguous() else bboxes[0].contiguous() + masks_3d = masks[0] if masks[0].is_contiguous() else masks[0].contiguous() + + combined, survivor_idx, mask_any, counter = _get_scratch_buffers( + num_queries, device + ) + counter.zero_() + + thr_tensor, per_class = _prepare_threshold(threshold, device, num_classes) + + if class_mapping is not None: + has_remap = True + cmap = _get_class_mapping_int32(class_mapping, device) + else: + has_remap = False + cmap = _EMPTY_INT32.to(device, non_blocking=True) + + inf_w, inf_h = inference_size_wh + pad_l, pad_t, _, _ = pad_ltrb + sw, sh = scale_wh + orig_w, orig_h = orig_size_wh + + BLOCK_C = max(32, _next_pow2(num_classes_total)) + _rfdetr_fullpost_filter_kernel[(num_queries,)]( + logits_2d, + bboxes_2d, + thr_tensor, + cmap, + combined, + survivor_idx, + mask_any, + counter, + num_queries, + num_classes_total, + int(inf_w), + int(inf_h), + int(pad_l), + int(pad_t), + float(1.0 / sw), + float(1.0 / sh), + int(orig_w), + int(orig_h), + logits_2d.stride(0), + bboxes_2d.stride(0), + PER_CLASS=1 if per_class else 0, + HAS_REMAPPING=1 if has_remap else 0, + BLOCK_C=BLOCK_C, + ) + + mask_bin_full = _get_mask_bin_buffer(num_queries, orig_h, orig_w, device) + + BLOCK_H = 16 + BLOCK_W = 16 + grid = ( + num_queries, + (orig_h + BLOCK_H - 1) // BLOCK_H, + (orig_w + BLOCK_W - 1) // BLOCK_W, + ) + _rfdetr_fullpost_mask_kernel_compact[grid]( + masks_3d, + survivor_idx, + counter, + 
mask_bin_full, + mask_any, + int(mask_h), + int(mask_w), + int(orig_h), + int(orig_w), + float(mask_h / orig_h), + float(mask_w / orig_w), + masks_3d.stride(0), + masks_3d.stride(1), + mask_bin_full.stride(0), + mask_bin_full.stride(1), + BLOCK_H=BLOCK_H, + BLOCK_W=BLOCK_W, + ) + + done_event = torch.cuda.Event() + done_event.record(torch.cuda.current_stream(device)) + + return combined, mask_bin_full, mask_any, counter, done_event From 986dbdd9f6060950c27e796f24057d209f118595 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 7 May 2026 00:43:28 +0000 Subject: [PATCH 02/25] moving to rle path --- .../inference_models/models/rfdetr/common.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index 3c9d1b4ff6..d5b1dbb9e3 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -11,6 +11,7 @@ PreProcessingMetadata, StaticCropOffset, ) +from inference_models.models.common.rle_utils import torch_mask_to_coco_rle from inference_models.models.common.roboflow.post_processing import ( align_instance_segmentation_results, align_instance_segmentation_results_to_rle_masks, @@ -290,6 +291,60 @@ def post_process_instance_segmentation_results_to_rle_masks( num_classes: int, classes_re_mapping: Optional[ClassesReMapping], ) -> List[InstanceDetections]: + if _fullpost_eligible(bboxes, pre_processing_meta, classes_re_mapping): + meta = pre_processing_meta[0] + thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) + combined, mask_bin, _mask_any, counter, done_event = triton_rfdetr_fullpost( + bboxes=bboxes, + logits=logits, + masks=masks, + threshold=thr_arg, + num_classes=num_classes, + class_mapping=classes_re_mapping.class_mapping, + inference_size_wh=(meta.inference_size.width, meta.inference_size.height), + pad_ltrb=(meta.pad_left, 
meta.pad_top, meta.pad_right, meta.pad_bottom), + scale_wh=(meta.scale_width, meta.scale_height), + orig_size_wh=(meta.original_size.width, meta.original_size.height), + ) + done_event.wait(torch.cuda.current_stream(bboxes.device)) + n_survivors = int(counter.item()) + orig_h = meta.original_size.height + orig_w = meta.original_size.width + if n_survivors == 0: + empty_xyxy = torch.empty( + (0, 4), dtype=torch.int32, device=bboxes.device + ) + empty_conf = torch.empty((0,), dtype=torch.float32, device=bboxes.device) + empty_cls = torch.empty((0,), dtype=torch.int32, device=bboxes.device) + return [ + InstanceDetections( + xyxy=empty_xyxy, + confidence=empty_conf, + class_id=empty_cls, + mask=InstancesRLEMasks.from_coco_rle_masks( + image_size=(orig_h, orig_w), masks=[] + ), + ) + ] + combined_slice = combined[:n_survivors] + mask_slice = mask_bin[:n_survivors].to(dtype=torch.bool) + rle_masks = [ + torch_mask_to_coco_rle(mask=mask_slice[i]) for i in range(n_survivors) + ] + instances_masks = InstancesRLEMasks.from_coco_rle_masks( + image_size=(orig_h, orig_w), masks=rle_masks + ) + xyxy = combined_slice[:, :4] + confidence = combined_slice[:, 4].view(torch.float32) + class_id = combined_slice[:, 5] + return [ + InstanceDetections( + xyxy=xyxy, + confidence=confidence, + class_id=class_id, + mask=instances_masks, + ) + ] logits_sigmoid = torch.nn.functional.sigmoid(logits) final_results = [] device = bboxes.device From 922990f25463b77ba80b32a5f78180c06cfb0839 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Fri, 15 May 2026 17:44:01 +0000 Subject: [PATCH 03/25] fix(rfdetr-seg): slice fullpost outputs to n_survivors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The non-RLE fullpost path was returning the full num_queries-row scratch buffer to InstanceDetections, exposing uninitialized rows past the survivor counter and leaving the conf column as int32 bits. 
Wait on done_event, slice combined and mask_bin to [:n_survivors], and reinterpret the conf column with .view(torch.float32) — mirroring the RLE variant. Adds temp/detection_parity_full.py to study the fused path against the torch reference across coco/val2017. --- .../inference_models/models/rfdetr/common.py | 11 +- temp/detection_parity_full.py | 254 ++++++++++++++++++ 2 files changed, 261 insertions(+), 4 deletions(-) create mode 100644 temp/detection_parity_full.py diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index d5b1dbb9e3..d530d7d545 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -186,11 +186,14 @@ def post_process_instance_segmentation_results( scale_wh=(meta.scale_width, meta.scale_height), orig_size_wh=(meta.original_size.width, meta.original_size.height), ) + done_event.wait(torch.cuda.current_stream(bboxes.device)) + n_survivors = int(counter.item()) + combined_slice = combined[:n_survivors] detections = InstanceDetections( - xyxy=combined[:, :4], - confidence=combined[:, 4], - class_id=combined[:, 5], - mask=mask_bin, + xyxy=combined_slice[:, :4], + confidence=combined_slice[:, 4].view(torch.float32), + class_id=combined_slice[:, 5], + mask=mask_bin[:n_survivors].to(dtype=torch.bool), ) detections.__dict__["_combined_gpu"] = combined detections.__dict__["_counter_gpu"] = counter diff --git a/temp/detection_parity_full.py b/temp/detection_parity_full.py new file mode 100644 index 0000000000..f75c14fff6 --- /dev/null +++ b/temp/detection_parity_full.py @@ -0,0 +1,254 @@ +"""Full coco/val2017 detection + mask parity: fused-postproc Triton path vs reference. + +Driven by RFDETR_TRITON_FULLPOSTPROC (true/false). 
The env var is read once at +import time, so we run two subprocesses (fullpost on, fullpost off), stream +per-image detections to a pickle file (one pickle.dump per record), then +compare in lockstep. + +Usage: + python temp/detection_parity_full.py # driver: runs both passes + python temp/detection_parity_full.py --mode run --out /tmp/full_on.pkl + python temp/detection_parity_full.py --mode compare --on /tmp/full_on.pkl --off /tmp/full_off.pkl +""" +import argparse +import os +import pickle +import subprocess +import sys +import time +from pathlib import Path + +COCO = Path("/home/ubuntu/inference/coco/val2017") +CONFIDENCE = 0.4 +PY = sys.executable +SELF = Path(__file__).resolve() +OUT_ON = "/tmp/det_parity_full_on.pkl" +OUT_OFF = "/tmp/det_parity_full_off.pkl" +# Cap images per pass: fullpost caches a per-(orig_h, orig_w) GPU buffer in +# _get_mask_bin_buffer, so the working set grows with distinct image sizes and +# OOMs near ~5k images on a 14 GiB card. 1500 covers a wide variety of shapes +# while staying well under that ceiling. 
+MAX_IMAGES = int(os.environ.get("PARITY_MAX_IMAGES", "1500")) + + +def _iter_records(path): + with open(path, "rb") as f: + while True: + try: + yield pickle.load(f) + except EOFError: + return + + +def do_run(out_path): + os.environ.setdefault( + "DISABLED_INFERENCE_MODELS_BACKENDS", + "torch,torch-script,onnx,hugging-face,ultralytics,mediapipe,custom", + ) + import cv2 + import numpy as np + import torch + + from inference_models import AutoModel + import inference_models.models.rfdetr.common as common_mod + + fullpost_calls = {"count": 0} + original_fp = getattr(common_mod, "triton_rfdetr_fullpost", None) + if original_fp is not None: + def counting_fp(*a, **kw): + fullpost_calls["count"] += 1 + return original_fp(*a, **kw) + common_mod.triton_rfdetr_fullpost = counting_fp + + fp_flag = os.environ.get("RFDETR_TRITON_FULLPOSTPROC", "") + print( + f"[run] RFDETR_TRITON_FULLPOSTPROC={fp_flag} " + f"(module ready={getattr(common_mod, '_TRITON_FULLPOST_READY', False)})" + ) + + paths = sorted(COCO.glob("*.jpg"))[:MAX_IMAGES] + model = AutoModel.from_pretrained("rfdetr-seg-nano") + + n_records = 0 + t0 = time.perf_counter() + with open(out_path, "wb") as f: + # placeholder header — rewritten at end + pickle.dump({"_kind": "header", "fullpost_calls": -1, "n_records": -1}, f) + for idx, p in enumerate(paths): + im = cv2.imread(str(p), cv2.IMREAD_COLOR) + if im is None: + continue + pre, meta = model.pre_process(im) + out = model.forward(pre) + det = model.post_process(out, meta, confidence=CONFIDENCE)[0] + n = int(det.class_id.numel()) + if n and det.mask is not None: + m_np = det.mask.cpu().to(torch.bool).numpy() + packed = np.packbits(m_np.reshape(n, -1), axis=1) + mask_shape = m_np.shape[1:] + else: + packed = None + mask_shape = None + rec = { + "_kind": "rec", + "path": str(p), + "xyxy": det.xyxy.cpu().numpy() if n else None, + "conf": det.confidence.cpu().numpy() if n else None, + "cls": det.class_id.cpu().numpy() if n else None, + "mask_packed": packed, + 
"mask_shape": mask_shape, + } + pickle.dump(rec, f) + n_records += 1 + if (idx + 1) % 500 == 0: + print(f" [fp={fp_flag}] {idx+1}/{len(paths)} ({time.perf_counter()-t0:.0f}s)", flush=True) + + # append a footer with the totals (header is ignored on read; iterator just walks records + footer) + with open(out_path, "ab") as f: + pickle.dump({"_kind": "footer", "fullpost_calls": fullpost_calls["count"], "n_records": n_records}, f) + print(f"[run] fullpost_kernel_calls={fullpost_calls['count']} records={n_records} saved -> {out_path}") + + +def iou_box(a, b): + x0 = max(a[0], b[0]); y0 = max(a[1], b[1]) + x1 = min(a[2], b[2]); y1 = min(a[3], b[3]) + iw = max(0, x1 - x0); ih = max(0, y1 - y0) + inter = iw * ih + area_a = max(0, a[2] - a[0]) * max(0, a[3] - a[1]) + area_b = max(0, b[2] - b[0]) * max(0, b[3] - b[1]) + u = area_a + area_b - inter + return inter / u if u > 0 else 0.0 + + +def _unpack_masks(rec): + import numpy as np + if rec["mask_packed"] is None: + return None + n = len(rec["mask_packed"]) + h, w = rec["mask_shape"] + flat = np.unpackbits(rec["mask_packed"], axis=1, count=h * w) + return flat.reshape(n, h, w).astype(bool) + + +def do_compare(on_path, off_path): + import numpy as np + + tot_on = tot_off = matched = class_disagree = count_mm = pixel_identical = 0 + ious, dscores, mask_iou = [], [], [] + on_fp_calls = off_fp_calls = -1 + n_imgs = 0 + + on_iter = _iter_records(on_path) + off_iter = _iter_records(off_path) + + for r_on, r_off in zip(on_iter, off_iter): + if r_on.get("_kind") == "header": + r_on = next(on_iter) + if r_off.get("_kind") == "header": + r_off = next(off_iter) + if r_on.get("_kind") == "footer" or r_off.get("_kind") == "footer": + on_fp_calls = r_on.get("fullpost_calls", on_fp_calls) + off_fp_calls = r_off.get("fullpost_calls", off_fp_calls) + break + + assert r_on["path"] == r_off["path"], (r_on["path"], r_off["path"]) + n_imgs += 1 + nf = 0 if r_on["xyxy"] is None else len(r_on["xyxy"]) + nr = 0 if r_off["xyxy"] is None else 
len(r_off["xyxy"]) + tot_on += nf; tot_off += nr + if nf != nr: + count_mm += 1 + if nf == 0 and nr == 0: + continue + bf = r_on["xyxy"] if nf else np.zeros((0, 4)) + br = r_off["xyxy"] if nr else np.zeros((0, 4)) + sf = r_on["conf"] if nf else np.zeros(0) + sr = r_off["conf"] if nr else np.zeros(0) + cf = r_on["cls"] if nf else np.zeros(0, dtype=int) + cr = r_off["cls"] if nr else np.zeros(0, dtype=int) + mf = _unpack_masks(r_on) if nf else None + mr_m = _unpack_masks(r_off) if nr else None + + used = set() + for j in range(nr): + best_i, best_iou = -1, 0.5 + for i in range(nf): + if i in used: + continue + iou = iou_box(bf[i], br[j]) + if iou > best_iou: + best_iou, best_i = iou, i + if best_i >= 0: + used.add(best_i) + matched += 1 + ious.append(best_iou) + dscores.append(abs(float(sf[best_i]) - float(sr[j]))) + if int(cf[best_i]) != int(cr[j]): + class_disagree += 1 + if mf is not None and mr_m is not None: + a = mf[best_i]; b = mr_m[j] + inter = np.logical_and(a, b).sum() + u = np.logical_or(a, b).sum() + mask_iou.append(float(inter) / float(u) if u else 0.0) + if np.array_equal(a, b): + pixel_identical += 1 + + # drain footers if not already pulled + for it, current_calls_attr in ((on_iter, "on_fp_calls"), (off_iter, "off_fp_calls")): + for r in it: + if r.get("_kind") == "footer": + if current_calls_attr == "on_fp_calls": + on_fp_calls = r["fullpost_calls"] + else: + off_fp_calls = r["fullpost_calls"] + + print() + print(f"==== full coco/val2017 parity: fullpost=true vs fullpost=false ({n_imgs} images) ====") + print(f" fullpost calls (fp=true) : {on_fp_calls}") + print(f" fullpost calls (fp=false) : {off_fp_calls}") + print(f" dets fp=true / fp=false : {tot_on} / {tot_off}") + print(f" matched (IoU>0.5) : {matched} ({100*matched/max(1,tot_off):.2f}% of fp=false)") + print(f" count-mismatch images : {count_mm}") + print(f" class-id disagreements : {class_disagree}") + if ious: + print(f" mean box IoU : {np.mean(ious):.6f}") + if dscores: + print(f" mean / 
max |Δscore| : {np.mean(dscores):.3e} / {np.max(dscores):.3e}") + if mask_iou: + a = np.array(mask_iou) + print(f" mean / min mask IoU : {a.mean():.6f} / {a.min():.6f}") + print(f" pixel-identical masks : {pixel_identical}/{len(mask_iou)}") + print() + expected = n_imgs + ok_on = "[PASS]" if on_fp_calls == expected else "[FAIL]" + ok_off = "[PASS]" if off_fp_calls == 0 else "[FAIL]" + print(f" {ok_on} fp=true -> fullpost fired {on_fp_calls}/{expected}") + print(f" {ok_off} fp=false -> fullpost fired {off_fp_calls} (expected 0)") + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--mode", choices=("driver", "run", "compare"), default="driver") + ap.add_argument("--out") + ap.add_argument("--on") + ap.add_argument("--off") + args = ap.parse_args() + + if args.mode == "run": + do_run(args.out) + return + if args.mode == "compare": + do_compare(args.on, args.off) + return + + for fp_value, out in (("true", OUT_ON), ("false", OUT_OFF)): + env = os.environ.copy() + env["RFDETR_TRITON_FULLPOSTPROC"] = fp_value + print(f"\n---- child: fullpost={fp_value} out={out} ----", flush=True) + subprocess.run([PY, str(SELF), "--mode", "run", "--out", out], check=True, env=env) + + do_compare(OUT_ON, OUT_OFF) + + +if __name__ == "__main__": + main() From a38ca7445553194c63788800f467b551b0b1044e Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Fri, 15 May 2026 19:14:53 +0000 Subject: [PATCH 04/25] fix(rfdetr-seg): emit one fullpost row per (query, class) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The filter kernel was per-query argmax, but the torch reference does top-k-flat over the (Q*C) sigmoid grid (num_select == num_queries) — a single query can contribute multiple detections, one per class that survives remap + threshold. Per-query argmax silently dropped the secondary classes. 
Reshape the kernel grid to (num_queries, num_classes_total) so each (q, c) pair is processed independently: load one logit, remap, sigmoid, threshold, emit. Cap output at num_queries to mirror the reference's top-K cap; host clamps n_survivors to combined.shape[0] since the atomic counter increments before the slot guard. Validated on coco/val2017 (1500 images): det counts match exactly 8037/8037 vs 7995/8037 before, and on the same-logits parity script all 1663 matched detections agree to fp32 epsilon when the matcher is class-aware. --- .../inference_models/models/rfdetr/common.py | 8 ++- .../models/rfdetr/triton_fullpostproc.py | 55 +++++++++---------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index d530d7d545..999f836bd2 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -187,7 +187,9 @@ def post_process_instance_segmentation_results( orig_size_wh=(meta.original_size.width, meta.original_size.height), ) done_event.wait(torch.cuda.current_stream(bboxes.device)) - n_survivors = int(counter.item()) + # Counter is incremented unconditionally before the slot-cap guard, so + # cap by combined's row count (num_queries). + n_survivors = min(int(counter.item()), combined.shape[0]) combined_slice = combined[:n_survivors] detections = InstanceDetections( xyxy=combined_slice[:, :4], @@ -310,7 +312,9 @@ def post_process_instance_segmentation_results_to_rle_masks( orig_size_wh=(meta.original_size.width, meta.original_size.height), ) done_event.wait(torch.cuda.current_stream(bboxes.device)) - n_survivors = int(counter.item()) + # Counter is incremented unconditionally before the slot-cap guard, so + # cap by combined's row count (num_queries). 
+ n_survivors = min(int(counter.item()), combined.shape[0]) orig_h = meta.original_size.height orig_w = meta.original_size.width if n_survivors == 0: diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index b80d076e9e..a75b05f61a 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -53,37 +53,31 @@ def _rfdetr_fullpost_filter_kernel( bboxes_stride_q, PER_CLASS: tl.constexpr, HAS_REMAPPING: tl.constexpr, - BLOCK_C: tl.constexpr, ): - pid = tl.program_id(0) - if pid >= num_queries: + # One program per (query, class). The reference path does top-k-flat + # over the (Q*C) sigmoid grid (`num_select == num_queries`), so a + # single query can contribute multiple detections — once per class + # that survives remap + threshold. Per-query argmax would silently + # drop the others. + pid_q = tl.program_id(0) + pid_c = tl.program_id(1) + if pid_q >= num_queries or pid_c >= num_classes_total: return - offs_c = tl.arange(0, BLOCK_C) - mask_c = offs_c < num_classes_total - logits_row = tl.load( - logits_ptr + pid * logits_stride_q + offs_c, - mask=mask_c, - other=-float("inf"), - ) - max_val = tl.max(logits_row, axis=0) - BIG = 1 << 30 - is_max = logits_row == max_val - idx_or_big = tl.where(is_max & mask_c, offs_c, BIG) - raw_c = tl.min(idx_or_big, axis=0) + logit = tl.load(logits_ptr + pid_q * logits_stride_q + pid_c) if HAS_REMAPPING: - top_c = tl.load(class_map_ptr + raw_c) + top_c = tl.load(class_map_ptr + pid_c) valid = top_c >= 0 else: - top_c = raw_c - valid = raw_c < num_classes_total + top_c = pid_c + valid = pid_c < num_classes_total - abs_max = tl.abs(max_val) - z = tl.exp(-abs_max) + abs_l = tl.abs(logit) + z = tl.exp(-abs_l) sig_pos = 1.0 / (1.0 + z) sig_neg = z / (1.0 + z) - conf = tl.where(max_val >= 0.0, sig_pos, sig_neg) + conf = tl.where(logit >= 
0.0, sig_pos, sig_neg) if PER_CLASS: safe_c = tl.where(valid, top_c, 0) @@ -96,10 +90,10 @@ def _rfdetr_fullpost_filter_kernel( return # Match the non-Triton path's FP32 evaluation order for bit-parity. - cx_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 0) - cy_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 1) - w_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 2) - h_pct = tl.load(bboxes_ptr + pid * bboxes_stride_q + 3) + cx_pct = tl.load(bboxes_ptr + pid_q * bboxes_stride_q + 0) + cy_pct = tl.load(bboxes_ptr + pid_q * bboxes_stride_q + 1) + w_pct = tl.load(bboxes_ptr + pid_q * bboxes_stride_q + 2) + h_pct = tl.load(bboxes_ptr + pid_q * bboxes_stride_q + 3) x1_pct = cx_pct - 0.5 * w_pct y1_pct = cy_pct - 0.5 * h_pct @@ -142,6 +136,11 @@ def _rfdetr_fullpost_filter_kernel( slot = tl.atomic_add(counter_ptr, 1) + # Cap output at num_queries — mirrors the reference's flat top-K + # cap. Host slices to min(counter, num_queries) to ignore overflow. + if slot >= num_queries: + return + # Bitcast conf (fp32) as int32 so the whole record writes with int32 # stores. Host views the same memory as int32 and extracts via # numpy.view(np.float32). 
@@ -154,7 +153,7 @@ def _rfdetr_fullpost_filter_kernel( tl.store(combined_out_ptr + base + 3, y2_i) tl.store(combined_out_ptr + base + 4, conf_i32) tl.store(combined_out_ptr + base + 5, top_c) - tl.store(survivor_idx_out_ptr + slot, pid.to(tl.int32)) + tl.store(survivor_idx_out_ptr + slot, pid_q.to(tl.int32)) tl.store(mask_any_out_ptr + slot, 0) @@ -355,8 +354,7 @@ def triton_rfdetr_fullpost( sw, sh = scale_wh orig_w, orig_h = orig_size_wh - BLOCK_C = max(32, _next_pow2(num_classes_total)) - _rfdetr_fullpost_filter_kernel[(num_queries,)]( + _rfdetr_fullpost_filter_kernel[(num_queries, num_classes_total)]( logits_2d, bboxes_2d, thr_tensor, @@ -379,7 +377,6 @@ def triton_rfdetr_fullpost( bboxes_2d.stride(0), PER_CLASS=1 if per_class else 0, HAS_REMAPPING=1 if has_remap else 0, - BLOCK_C=BLOCK_C, ) mask_bin_full = _get_mask_bin_buffer(num_queries, orig_h, orig_w, device) From 8ee09364a5976a3df60b40fbc6cfe92ad77b0e0a Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Fri, 15 May 2026 21:09:00 +0000 Subject: [PATCH 05/25] fix(rfdetr-seg): use F.interpolate for fullpost mask upsample MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The custom Triton mask kernel implemented bilinear+threshold with `antialias=False` semantics, but the reference path (`align_instance_segmentation_results`) calls `functional.resize(BILINEAR)` which defaults to `antialias=True` — producing a different fp32 reduction order on boundary pixels and flipping pixels in ~0.3 % of masks (26/8037 across 1500 coco/val2017 images at conf=0.4). Drop the mask kernel and replace it with a host-side gather + `F.interpolate(bilinear, antialias=True, align_corners=False)` + `> 0` that matches the reference bit-for-bit. Same eligibility window guarantees no static crop and `size_after_pre_processing == inference_size`, so the new path can skip the canvas/static-crop branches. Validated on coco/val2017 (1500 images): 8037/8037 pixel-identical masks (was 8011/8037).
FPS unchanged within noise (filter kernel dominates the postproc cost; the mask path was always cuDNN-bound). --- .../inference_models/models/rfdetr/common.py | 66 ++++++- .../models/rfdetr/triton_fullpostproc.py | 186 +++--------------- 2 files changed, 88 insertions(+), 164 deletions(-) diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index 999f836bd2..f39f9791f8 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -40,6 +40,49 @@ triton_rfdetr_fullpost = None +def _fullpost_upsample_masks( + masks: torch.Tensor, + survivor_q: torch.Tensor, + inference_size_wh: Tuple[int, int], + pad_ltrb: Tuple[int, int, int, int], + orig_size_hw: Tuple[int, int], +) -> torch.Tensor: + """Replicate ``align_instance_segmentation_results``'s mask path + bit-for-bit (when ``size_after_pre_processing == inference_size`` and no + static crop applies — both guaranteed by ``_fullpost_eligible``). + + Gathers surviving query rows from ``masks`` (shape ``(1, Q, mh, mw)``), + crops the letterbox padding in mask coordinates, bilinear-resizes to + ``(orig_h, orig_w)`` with ``antialias=True`` (matching torchvision's + ``functional.resize`` default), and thresholds at 0. 
+ """ + selected = masks[0].index_select(0, survivor_q.long()) + if selected.shape[0] == 0: + orig_h, orig_w = orig_size_hw + return torch.empty( + (0, orig_h, orig_w), dtype=torch.bool, device=masks.device + ) + _, mh, mw = selected.shape + inf_w, inf_h = inference_size_wh + pad_l, pad_t, pad_r, pad_b = pad_ltrb + mh_scale = mh / inf_h + mw_scale = mw / inf_w + mpt = round(mh_scale * pad_t) + mpb = round(mh_scale * pad_b) + mpl = round(mw_scale * pad_l) + mpr = round(mw_scale * pad_r) + selected = selected[:, mpt: mh - mpb, mpl: mw - mpr] + orig_h, orig_w = orig_size_hw + upsampled = torch.nn.functional.interpolate( + selected.unsqueeze(1), + size=(orig_h, orig_w), + mode="bilinear", + antialias=True, + align_corners=False, + ).squeeze(1) + return upsampled > 0.0 + + def _fullpost_eligible( bboxes: torch.Tensor, pre_processing_meta: List[PreProcessingMetadata], @@ -174,10 +217,9 @@ def post_process_instance_segmentation_results( if _fullpost_eligible(bboxes, pre_processing_meta, classes_re_mapping): meta = pre_processing_meta[0] thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) - combined, mask_bin, mask_any, counter, done_event = triton_rfdetr_fullpost( + combined, survivor_idx, counter, done_event = triton_rfdetr_fullpost( bboxes=bboxes, logits=logits, - masks=masks, threshold=thr_arg, num_classes=num_classes, class_mapping=classes_re_mapping.class_mapping, @@ -191,11 +233,18 @@ def post_process_instance_segmentation_results( # cap by combined's row count (num_queries). 
n_survivors = min(int(counter.item()), combined.shape[0]) combined_slice = combined[:n_survivors] + mask_bin = _fullpost_upsample_masks( + masks=masks, + survivor_q=survivor_idx[:n_survivors], + inference_size_wh=(meta.inference_size.width, meta.inference_size.height), + pad_ltrb=(meta.pad_left, meta.pad_top, meta.pad_right, meta.pad_bottom), + orig_size_hw=(meta.original_size.height, meta.original_size.width), + ) detections = InstanceDetections( xyxy=combined_slice[:, :4], confidence=combined_slice[:, 4].view(torch.float32), class_id=combined_slice[:, 5], - mask=mask_bin[:n_survivors].to(dtype=torch.bool), + mask=mask_bin, ) detections.__dict__["_combined_gpu"] = combined detections.__dict__["_counter_gpu"] = counter @@ -299,10 +348,9 @@ def post_process_instance_segmentation_results_to_rle_masks( if _fullpost_eligible(bboxes, pre_processing_meta, classes_re_mapping): meta = pre_processing_meta[0] thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) - combined, mask_bin, _mask_any, counter, done_event = triton_rfdetr_fullpost( + combined, survivor_idx, counter, done_event = triton_rfdetr_fullpost( bboxes=bboxes, logits=logits, - masks=masks, threshold=thr_arg, num_classes=num_classes, class_mapping=classes_re_mapping.class_mapping, @@ -334,7 +382,13 @@ def post_process_instance_segmentation_results_to_rle_masks( ) ] combined_slice = combined[:n_survivors] - mask_slice = mask_bin[:n_survivors].to(dtype=torch.bool) + mask_slice = _fullpost_upsample_masks( + masks=masks, + survivor_q=survivor_idx[:n_survivors], + inference_size_wh=(meta.inference_size.width, meta.inference_size.height), + pad_ltrb=(meta.pad_left, meta.pad_top, meta.pad_right, meta.pad_bottom), + orig_size_hw=(orig_h, orig_w), + ) rle_masks = [ torch_mask_to_coco_rle(mask=mask_slice[i]) for i in range(n_survivors) ] diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py 
index a75b05f61a..1ef2ce9e71 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -1,16 +1,19 @@ """Fused RF-DETR instance-segmentation post-processing in Triton. -Two kernels replace the post-TRT chain for the common rfdetr-seg-nano path -(batch=1, no static crop, STRETCH_TO resize, class remapping active): - - _rfdetr_fullpost_filter_kernel (grid = num_queries) - sigmoid argmax + class remap + conf threshold + cxcywh->xyxy + - letterbox-denormalize + clip + banker's rounding; atomic_add into a - counter to reserve a compact output slot. - - _rfdetr_fullpost_mask_kernel_compact (grid = num_queries * tile_y * tile_x) - Bilinear upsample masks (e.g. 78x78 -> orig_h x orig_w) + threshold > 0 + - uint8 emit. Early-exits on s >= counter[0] without an intermediate sync. +For the common rfdetr-seg-nano path (batch=1, no static crop, STRETCH_TO +resize, class remapping active): + + _rfdetr_fullpost_filter_kernel (grid = num_queries * num_classes_total) + One program per (q, c) pair: sigmoid + class remap + conf threshold + + cxcywh->xyxy + letterbox-denormalize + clip + banker's rounding; + atomic_add into a counter to reserve a compact output slot. + + Mask upsample uses ``F.interpolate(bilinear, antialias=True)`` followed by + a ``> 0`` threshold — bit-for-bit identical to the reference path's + ``align_instance_segmentation_results`` mask handling. We reuse the + reference here (rather than a custom kernel) because PyTorch's antialiased + bilinear is hard to match in fp32 and the kernel-level win is in the + filter step, not the upsample.
""" from typing import Optional, Tuple @@ -37,7 +40,6 @@ def _rfdetr_fullpost_filter_kernel( class_map_ptr, combined_out_ptr, survivor_idx_out_ptr, - mask_any_out_ptr, counter_ptr, num_queries, num_classes_total, @@ -154,95 +156,10 @@ def _rfdetr_fullpost_filter_kernel( tl.store(combined_out_ptr + base + 4, conf_i32) tl.store(combined_out_ptr + base + 5, top_c) tl.store(survivor_idx_out_ptr + slot, pid_q.to(tl.int32)) - tl.store(mask_any_out_ptr + slot, 0) - - - @triton.jit - def _rfdetr_fullpost_mask_kernel_compact( - masks_ptr, - survivor_idx_ptr, - counter_ptr, - out_ptr, - mask_any_ptr, - mask_h, - mask_w, - orig_h, - orig_w, - mask_scale_y, - mask_scale_x, - masks_stride_q, - masks_stride_h, - out_stride_s, - out_stride_h, - BLOCK_H: tl.constexpr, - BLOCK_W: tl.constexpr, - ): - s = tl.program_id(0) - tile_y = tl.program_id(1) - tile_x = tl.program_id(2) - - # GPU-side early exit — skip programs past the live survivor count. - n_survivors = tl.load(counter_ptr) - if s >= n_survivors: - return - - q = tl.load(survivor_idx_ptr + s) - - offs_y = tile_y * BLOCK_H + tl.arange(0, BLOCK_H) - offs_x = tile_x * BLOCK_W + tl.arange(0, BLOCK_W) - mask_yy = offs_y < orig_h - mask_xx = offs_x < orig_w - m_outbox = mask_yy[:, None] & mask_xx[None, :] - - src_y_f = (offs_y.to(tl.float32) + 0.5) * mask_scale_y - 0.5 - src_x_f = (offs_x.to(tl.float32) + 0.5) * mask_scale_x - 0.5 - src_y_2d = src_y_f[:, None] - src_x_2d = src_x_f[None, :] - - y0 = tl.floor(src_y_2d).to(tl.int32) - x0 = tl.floor(src_x_2d).to(tl.int32) - y1 = y0 + 1 - x1 = x0 + 1 - dy = src_y_2d - y0.to(tl.float32) - dx = src_x_2d - x0.to(tl.float32) - - y0c = tl.maximum(tl.minimum(y0, mask_h - 1), 0) - y1c = tl.maximum(tl.minimum(y1, mask_h - 1), 0) - x0c = tl.maximum(tl.minimum(x0, mask_w - 1), 0) - x1c = tl.maximum(tl.minimum(x1, mask_w - 1), 0) - - base = q * masks_stride_q - - p00 = tl.load(masks_ptr + base + y0c * masks_stride_h + x0c, mask=m_outbox, other=0.0) - p01 = tl.load(masks_ptr + base + y0c * 
masks_stride_h + x1c, mask=m_outbox, other=0.0) - p10 = tl.load(masks_ptr + base + y1c * masks_stride_h + x0c, mask=m_outbox, other=0.0) - p11 = tl.load(masks_ptr + base + y1c * masks_stride_h + x1c, mask=m_outbox, other=0.0) - - w_tl = (1.0 - dy) * (1.0 - dx) - w_tr = (1.0 - dy) * dx - w_bl = dy * (1.0 - dx) - w_br = dy * dx - val = p00 * w_tl + p01 * w_tr + p10 * w_bl + p11 * w_br - bin_val = (val > 0.0).to(tl.int8) - - out_offsets = offs_y[:, None] * out_stride_h + offs_x[None, :] - tl.store(out_ptr + s * out_stride_s + out_offsets, bin_val, mask=m_outbox) - - tile_any = tl.max(bin_val.to(tl.int32), axis=0) - tile_any2 = tl.max(tile_any, axis=0) - tl.atomic_max(mask_any_ptr + s, tile_any2) - - -def _next_pow2(n: int) -> int: - p = 1 - while p < n: - p <<= 1 - return p _THRESHOLD_CACHE: dict = {} _EMPTY_INT32 = torch.empty((1,), dtype=torch.int32) -_MASK_BIN_BUFFER_CACHE: dict = {} _SCRATCH_CACHE: dict = {} _CLASS_MAPPING_INT32_CACHE: dict = {} @@ -253,9 +170,8 @@ def _get_scratch_buffers(num_queries: int, device: torch.device): if cached is None: combined = torch.empty((num_queries, 6), dtype=torch.int32, device=device) survivor_idx = torch.empty((num_queries,), dtype=torch.int32, device=device) - mask_any = torch.empty((num_queries,), dtype=torch.int32, device=device) counter = torch.zeros((1,), dtype=torch.int32, device=device) - cached = (combined, survivor_idx, mask_any, counter) + cached = (combined, survivor_idx, counter) _SCRATCH_CACHE[key] = cached return cached @@ -272,21 +188,6 @@ def _get_class_mapping_int32(class_mapping: torch.Tensor, device: torch.device) return cached -def _get_mask_bin_buffer( - capacity: int, orig_h: int, orig_w: int, device: torch.device -) -> torch.Tensor: - # Rows beyond n_survivors may hold stale data from prior frames; callers - # must size their read by the atomic counter. 
- key = (capacity, orig_h, orig_w, device) - buf = _MASK_BIN_BUFFER_CACHE.get(key) - if buf is None: - buf = torch.empty( - (capacity, orig_h, orig_w), dtype=torch.uint8, device=device - ) - _MASK_BIN_BUFFER_CACHE[key] = buf - return buf - - def _prepare_threshold(threshold, device: torch.device, num_classes: int): if isinstance(threshold, torch.Tensor): t = threshold @@ -304,7 +205,6 @@ def _prepare_threshold(threshold, device: torch.device, num_classes: int): def triton_rfdetr_fullpost( bboxes: torch.Tensor, logits: torch.Tensor, - masks: torch.Tensor, threshold: "torch.Tensor | float", num_classes: int, class_mapping: Optional[torch.Tensor], @@ -316,28 +216,28 @@ def triton_rfdetr_fullpost( torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, "torch.cuda.Event", ]: - """Returns (combined, mask_bin, mask_any, counter, done_event). Buffers - are unsliced — the caller DtoH's ``counter`` to learn n_survivors and - slices to ``[:n_survivors]``. ``combined[:, 4]`` holds fp32 conf as - int32 bits; use ``numpy.view(np.float32)`` on the host.""" + """Filter step only — returns (combined, survivor_idx, counter, done_event). + + Buffers are unsliced — the caller waits on ``done_event``, reads + ``counter`` to learn n_survivors, and slices to ``[:n_survivors]``. + ``combined[:, 4]`` holds fp32 conf as int32 bits; reinterpret on the + host with ``.view(torch.float32)``. The mask upsample is intentionally + left to the caller (torch ``F.interpolate``) so the result is bit-exact + with the non-fullpost reference path. 
+ """ assert TRITON_AVAILABLE, "triton not available" - assert bboxes.is_cuda and logits.is_cuda and masks.is_cuda - assert bboxes.shape[0] == 1 and logits.shape[0] == 1 and masks.shape[0] == 1, "batch=1 only" + assert bboxes.is_cuda and logits.is_cuda + assert bboxes.shape[0] == 1 and logits.shape[0] == 1, "batch=1 only" device = bboxes.device num_queries, num_classes_total = logits.shape[1], logits.shape[2] - _, _, mask_h, mask_w = masks.shape logits_2d = logits[0] if logits[0].is_contiguous() else logits[0].contiguous() bboxes_2d = bboxes[0] if bboxes[0].is_contiguous() else bboxes[0].contiguous() - masks_3d = masks[0] if masks[0].is_contiguous() else masks[0].contiguous() - combined, survivor_idx, mask_any, counter = _get_scratch_buffers( - num_queries, device - ) + combined, survivor_idx, counter = _get_scratch_buffers(num_queries, device) counter.zero_() thr_tensor, per_class = _prepare_threshold(threshold, device, num_classes) @@ -361,7 +261,6 @@ def triton_rfdetr_fullpost( cmap, combined, survivor_idx, - mask_any, counter, num_queries, num_classes_total, @@ -379,36 +278,7 @@ def triton_rfdetr_fullpost( HAS_REMAPPING=1 if has_remap else 0, ) - mask_bin_full = _get_mask_bin_buffer(num_queries, orig_h, orig_w, device) - - BLOCK_H = 16 - BLOCK_W = 16 - grid = ( - num_queries, - (orig_h + BLOCK_H - 1) // BLOCK_H, - (orig_w + BLOCK_W - 1) // BLOCK_W, - ) - _rfdetr_fullpost_mask_kernel_compact[grid]( - masks_3d, - survivor_idx, - counter, - mask_bin_full, - mask_any, - int(mask_h), - int(mask_w), - int(orig_h), - int(orig_w), - float(mask_h / orig_h), - float(mask_w / orig_w), - masks_3d.stride(0), - masks_3d.stride(1), - mask_bin_full.stride(0), - mask_bin_full.stride(1), - BLOCK_H=BLOCK_H, - BLOCK_W=BLOCK_W, - ) - done_event = torch.cuda.Event() done_event.record(torch.cuda.current_stream(device)) - return combined, mask_bin_full, mask_any, counter, done_event + return combined, survivor_idx, counter, done_event From 29c7627b7282dd79d290d886964a0b04dbaad060 Mon 
Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 18 May 2026 23:52:20 +0000 Subject: [PATCH 06/25] workflow script --- .../rfdetr_nano_seg_trt_workflow.py | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 development/stream_interface/rfdetr_nano_seg_trt_workflow.py diff --git a/development/stream_interface/rfdetr_nano_seg_trt_workflow.py b/development/stream_interface/rfdetr_nano_seg_trt_workflow.py new file mode 100644 index 0000000000..9c213a8639 --- /dev/null +++ b/development/stream_interface/rfdetr_nano_seg_trt_workflow.py @@ -0,0 +1,116 @@ +"""Minimal benchmark: RF-DETR instance segmentation through inference-models, +run via InferencePipeline on a single video source. + +Workflow has exactly one block — the segmentation model. No annotators, no +buffer strategies, no rate limiting. + +The `--backend` flag (trt | onnx | torch) is parsed before importing +`inference` and pins the auto-loader by setting +`DISABLED_INFERENCE_MODELS_BACKENDS` to every backend except the chosen one, +so the benchmark numbers correspond unambiguously to a single execution path. + +Defaults: rfdetr-seg-nano @ confidence 0.4 on the native TRT backend. 
+""" +import argparse +import os + +_ALL_BACKENDS = { + "torch", + "torch-script", + "onnx", + "trt", + "hugging-face", + "ultralytics", + "mediapipe", + "custom", +} +def _select_backend_from_argv() -> str: + pre = argparse.ArgumentParser(add_help=False) + pre.add_argument("--backend", choices=("trt", "onnx", "torch"), default="trt") + args, _ = pre.parse_known_args() + return args.backend + + +_BACKEND = _select_backend_from_argv() +os.environ.setdefault( + "ONNXRUNTIME_EXECUTION_PROVIDERS", + "[TensorrtExecutionProvider,CUDAExecutionProvider,CPUExecutionProvider]", +) +os.environ["DISABLED_INFERENCE_MODELS_BACKENDS"] = ",".join( + sorted(_ALL_BACKENDS - {_BACKEND}) +) + +from time import perf_counter + +from inference import InferencePipeline + + +def build_workflow(model_id: str, confidence: float) -> dict: + return { + "version": "1.0", + "inputs": [{"type": "WorkflowImage", "name": "image"}], + "steps": [ + { + "type": "roboflow_core/roboflow_instance_segmentation_model@v3", + "name": "segmentation", + "images": "$inputs.image", + "model_id": model_id, + "confidence_mode": "custom", + "custom_confidence": confidence, + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "predictions", + "selector": "$steps.segmentation.predictions", + }, + ], + } + +FRAME_COUNT = 0 +START_TIME = None +PROGRESS_EVERY = 50 + + +def sink(predictions, _video_frames) -> None: + global FRAME_COUNT, START_TIME + del _video_frames + if not isinstance(predictions, list): + predictions = [predictions] + FRAME_COUNT += sum(p is not None for p in predictions) + if START_TIME is None: + START_TIME = perf_counter() + if FRAME_COUNT % PROGRESS_EVERY == 0: + fps = FRAME_COUNT / (perf_counter() - START_TIME) + print(f"[progress] frames={FRAME_COUNT} fps={fps:.2f}", flush=True) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--video_reference", required=True) + parser.add_argument("--model_id", default="rfdetr-seg-nano") + 
parser.add_argument("--confidence", type=float, default=0.4) + parser.add_argument( + "--backend", + choices=("trt", "onnx", "torch"), + default="trt", + help="inference-models backend (consumed pre-import via env var).", + ) + args = parser.parse_args() + + pipeline = InferencePipeline.init_with_workflow( + video_reference=args.video_reference, + workflow_specification=build_workflow(args.model_id, args.confidence), + on_prediction=sink, + ) + pipeline.start() + pipeline.join() + + elapsed = perf_counter() - START_TIME if START_TIME else 0.0 + fps = FRAME_COUNT / elapsed if elapsed > 0 else 0.0 + print(f"frames={FRAME_COUNT} elapsed={elapsed:.2f}s fps={fps:.2f}") + + +if __name__ == "__main__": + main() From 1804cd3835445ea18f1711431f0eb98c6a020992 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 00:53:45 +0000 Subject: [PATCH 07/25] move env var to inference_models/configuration.py --- .../core/models/inference_models_adapters.py | 21 ++++++++++--------- .../inference_models/configuration.py | 7 +++++++ .../inference_models/models/rfdetr/common.py | 7 ++----- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index 9b4bbb594e..f526a7761c 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -1,4 +1,5 @@ import base64 +from functools import lru_cache import io from io import BytesIO from time import perf_counter @@ -12,16 +13,16 @@ # Pinned host buffers for async DtoH on the full-postproc Triton fast path. # Keyed by (name, dtype); reused across frames provided the cached buffer is # at least as large as the requested shape in every dimension. 
-_PINNED_HOST_BUFFERS: dict = {} +PINNED_HOST_BUFFERS: dict = {} -def _get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor: +def get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor: key = (name, dtype) - buf = _PINNED_HOST_BUFFERS.get(key) + buf = PINNED_HOST_BUFFERS.get(key) if buf is not None and all(buf.shape[i] >= shape[i] for i in range(len(shape))): return buf[tuple(slice(0, s) for s in shape)] buf = torch.empty(shape, dtype=dtype, pin_memory=True) - _PINNED_HOST_BUFFERS[key] = buf + PINNED_HOST_BUFFERS[key] = buf return buf from inference.core.entities.requests import ( @@ -335,7 +336,7 @@ def postprocess( H = preproc_metadata.original_size.height W = preproc_metadata.original_size.width - # Fast path: RF-DETR full-postproc Triton fusion emits an + # RF-DETR postproc Triton fusion emits an # unsliced (num_queries, 6) int32 record plus a GPU counter and a # completion event. We DtoH the 4-byte counter (single sync) to # learn n_survivors, then async-DtoH the compact slices. 
@@ -354,7 +355,7 @@ def postprocess( stream = torch.cuda.current_stream(device) done_event.wait(stream) - counter_host = _get_pinned_buffer("counter", (1,), torch.int32) + counter_host = get_pinned_buffer("counter", (1,), torch.int32) counter_host.copy_(counter_gpu, non_blocking=True) stream.synchronize() n_survivors = int(counter_host[0].item()) @@ -367,11 +368,11 @@ def postprocess( else: combined_slice = combined_gpu[:n_survivors] mask_slice = det.mask[:n_survivors] - combined_host = _get_pinned_buffer( - "combined", combined_slice.shape, combined_slice.dtype + combined_host = get_pinned_buffer( + "combined", tuple(combined_slice.shape), combined_slice.dtype ) - mask_host = _get_pinned_buffer( - "mask", mask_slice.shape, mask_slice.dtype + mask_host = get_pinned_buffer( + "mask", tuple(mask_slice.shape), mask_slice.dtype ) combined_host.copy_(combined_slice, non_blocking=True) mask_host.copy_(mask_slice, non_blocking=True) diff --git a/inference_models/inference_models/configuration.py b/inference_models/inference_models/configuration.py index 0ffa1c0b64..7b01528a72 100644 --- a/inference_models/inference_models/configuration.py +++ b/inference_models/inference_models/configuration.py @@ -461,3 +461,10 @@ "ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND" ) DEFAULT_ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND = False + +DEFAULT_RFDETR_TRITON_FULLPOSTPROC = False + +RFDETR_TRITON_FULLPOSTPROC = get_boolean_from_env( + variable_name="RFDETR_TRITON_FULLPOSTPROC", + default=DEFAULT_RFDETR_TRITON_FULLPOSTPROC, +) diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index f39f9791f8..2ce81af91a 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -20,12 +20,9 @@ from inference_models.models.rfdetr.class_remapping import ClassesReMapping from inference_models.models.rfdetr.post_processor import select_topk_predictions from 
inference_models.utils.file_system import read_json +from inference_models.configuration import RFDETR_TRITON_FULLPOSTPROC -_RFDETR_TRITON_FULLPOSTPROC = os.getenv("RFDETR_TRITON_FULLPOSTPROC", "false").lower() in ( - "true", - "1", -) -if _RFDETR_TRITON_FULLPOSTPROC: +if RFDETR_TRITON_FULLPOSTPROC: try: from inference_models.models.rfdetr.triton_fullpostproc import ( TRITON_AVAILABLE as _TRITON_FULLPOST_AVAILABLE, From fe46c5740b9a0755fd0df8742d199ed5a7e05df6 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 01:11:26 +0000 Subject: [PATCH 08/25] refactoring to change var names --- .../core/models/inference_models_adapters.py | 1 - .../inference_models/configuration.py | 8 ++-- .../inference_models/models/rfdetr/common.py | 47 ++++++++++-------- .../models/rfdetr/triton_fullpostproc.py | 20 ++++---- temp/detection_parity_full.py | 48 +++++++++---------- 5 files changed, 63 insertions(+), 61 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index f526a7761c..520d7c08f6 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -1,5 +1,4 @@ import base64 -from functools import lru_cache import io from io import BytesIO from time import perf_counter diff --git a/inference_models/inference_models/configuration.py b/inference_models/inference_models/configuration.py index 7b01528a72..bec6d06038 100644 --- a/inference_models/inference_models/configuration.py +++ b/inference_models/inference_models/configuration.py @@ -462,9 +462,9 @@ ) DEFAULT_ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND = False -DEFAULT_RFDETR_TRITON_FULLPOSTPROC = False +DEFAULT_RFDETR_TRITON_POSTPROC = False -RFDETR_TRITON_FULLPOSTPROC = get_boolean_from_env( - variable_name="RFDETR_TRITON_FULLPOSTPROC", - default=DEFAULT_RFDETR_TRITON_FULLPOSTPROC, +RFDETR_TRITON_POSTPROC = get_boolean_from_env( + variable_name="RFDETR_TRITON_POSTPROC", + 
default=DEFAULT_RFDETR_TRITON_POSTPROC, ) diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index 2ce81af91a..78e618724e 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -20,24 +20,24 @@ from inference_models.models.rfdetr.class_remapping import ClassesReMapping from inference_models.models.rfdetr.post_processor import select_topk_predictions from inference_models.utils.file_system import read_json -from inference_models.configuration import RFDETR_TRITON_FULLPOSTPROC +from inference_models.configuration import RFDETR_TRITON_POSTPROC -if RFDETR_TRITON_FULLPOSTPROC: +if RFDETR_TRITON_POSTPROC: try: - from inference_models.models.rfdetr.triton_fullpostproc import ( - TRITON_AVAILABLE as _TRITON_FULLPOST_AVAILABLE, - triton_rfdetr_fullpost, + from inference_models.models.rfdetr.triton_postprocproc import ( + TRITON_AVAILABLE as _TRITON_POSTPROC_AVAILABLE, + rfdetr_triton_postproc, ) - _TRITON_FULLPOST_READY = _TRITON_FULLPOST_AVAILABLE and torch.cuda.is_available() + _TRITON_POSTPROC_READY = _TRITON_POSTPROC_AVAILABLE and torch.cuda.is_available() except Exception: - _TRITON_FULLPOST_READY = False - triton_rfdetr_fullpost = None + _TRITON_POSTPROC_READY = False + rfdetr_triton_postproc = None else: - _TRITON_FULLPOST_READY = False - triton_rfdetr_fullpost = None + _TRITON_POSTPROC_READY = False + rfdetr_triton_postproc = None -def _fullpost_upsample_masks( +def triton_postproc_upsample_masks( masks: torch.Tensor, survivor_q: torch.Tensor, inference_size_wh: Tuple[int, int], @@ -46,7 +46,7 @@ def _fullpost_upsample_masks( ) -> torch.Tensor: """Replicate ``align_instance_segmentation_results``'s mask path bit-for-bit (when ``size_after_pre_processing == inference_size`` and no - static crop applies — both guaranteed by ``_fullpost_eligible``). 
+ static crop applies — both guaranteed by ``post_triton_eligible``). Gathers surviving query rows from ``masks`` (shape ``(1, Q, mh, mw)``), crops the letterbox padding in mask coordinates, bilinear-resizes to @@ -80,14 +80,19 @@ def _fullpost_upsample_masks( return upsampled > 0.0 -def _fullpost_eligible( +def post_triton_eligible( bboxes: torch.Tensor, + logits: torch.Tensor, pre_processing_meta: List[PreProcessingMetadata], classes_re_mapping: Optional[ClassesReMapping], ) -> bool: - if not _TRITON_FULLPOST_READY or not bboxes.is_cuda: + if not _TRITON_POSTPROC_READY: + return False + if not bboxes.is_cuda or not logits.is_cuda: + return False + if bboxes.device != logits.device: return False - if bboxes.shape[0] != 1 or len(pre_processing_meta) != 1: + if bboxes.shape[0] != 1 or logits.shape[0] != 1 or len(pre_processing_meta) != 1: return False meta = pre_processing_meta[0] if meta.nonsquare_intermediate_size is not None: @@ -211,10 +216,10 @@ def post_process_instance_segmentation_results( num_classes: int, classes_re_mapping: Optional[ClassesReMapping], ) -> List[InstanceDetections]: - if _fullpost_eligible(bboxes, pre_processing_meta, classes_re_mapping): + if post_triton_eligible(bboxes, logits, pre_processing_meta, classes_re_mapping): meta = pre_processing_meta[0] thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) - combined, survivor_idx, counter, done_event = triton_rfdetr_fullpost( + combined, survivor_idx, counter, done_event = rfdetr_triton_postproc( bboxes=bboxes, logits=logits, threshold=thr_arg, @@ -230,7 +235,7 @@ def post_process_instance_segmentation_results( # cap by combined's row count (num_queries). 
n_survivors = min(int(counter.item()), combined.shape[0]) combined_slice = combined[:n_survivors] - mask_bin = _fullpost_upsample_masks( + mask_bin = triton_postproc_upsample_masks( masks=masks, survivor_q=survivor_idx[:n_survivors], inference_size_wh=(meta.inference_size.width, meta.inference_size.height), @@ -342,10 +347,10 @@ def post_process_instance_segmentation_results_to_rle_masks( num_classes: int, classes_re_mapping: Optional[ClassesReMapping], ) -> List[InstanceDetections]: - if _fullpost_eligible(bboxes, pre_processing_meta, classes_re_mapping): + if post_triton_eligible(bboxes, logits, pre_processing_meta, classes_re_mapping): meta = pre_processing_meta[0] thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) - combined, survivor_idx, counter, done_event = triton_rfdetr_fullpost( + combined, survivor_idx, counter, done_event = rfdetr_triton_postproc( bboxes=bboxes, logits=logits, threshold=thr_arg, @@ -379,7 +384,7 @@ def post_process_instance_segmentation_results_to_rle_masks( ) ] combined_slice = combined[:n_survivors] - mask_slice = _fullpost_upsample_masks( + mask_slice = triton_postproc_upsample_masks( masks=masks, survivor_q=survivor_idx[:n_survivors], inference_size_wh=(meta.inference_size.width, meta.inference_size.height), diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index 1ef2ce9e71..f356b1f717 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -3,7 +3,7 @@ For the common rfdetr-seg-nano path (batch=1, no static crop, STRETCH_TO resize, class remapping active): - _rfdetr_fullpost_filter_kernel (grid = num_queries * num_classes_total) + rfdetr_postproc_triton_kernel (grid = num_queries * num_classes_total) One program per (q, c) pair: sigmoid + class remap + conf threshold + cxcywh->xyxy + 
letterbox-denormalize + clip + banker's rounding; atomic_add into a counter to reserve a compact output slot. @@ -33,7 +33,7 @@ if TRITON_AVAILABLE: @triton.jit - def _rfdetr_fullpost_filter_kernel( + def rfdetr_postproc_triton_kernel( logits_ptr, bboxes_ptr, threshold_ptr, @@ -202,7 +202,7 @@ def _prepare_threshold(threshold, device: torch.device, num_classes: int): return cached, False -def triton_rfdetr_fullpost( +def rfdetr_triton_postproc( bboxes: torch.Tensor, logits: torch.Tensor, threshold: "torch.Tensor | float", @@ -225,12 +225,8 @@ def triton_rfdetr_fullpost( ``combined[:, 4]`` holds fp32 conf as int32 bits; reinterpret on the host with ``.view(torch.float32)``. The mask upsample is intentionally left to the caller (torch ``F.interpolate``) so the result is bit-exact - with the non-fullpost reference path. + with the non-postproc reference path. """ - assert TRITON_AVAILABLE, "triton not available" - assert bboxes.is_cuda and logits.is_cuda - assert bboxes.shape[0] == 1 and logits.shape[0] == 1, "batch=1 only" - device = bboxes.device num_queries, num_classes_total = logits.shape[1], logits.shape[2] @@ -253,8 +249,10 @@ def triton_rfdetr_fullpost( pad_l, pad_t, _, _ = pad_ltrb sw, sh = scale_wh orig_w, orig_h = orig_size_wh + per_class_constexpr = 1 if per_class else 0 + has_remap_constexpr = 1 if has_remap else 0 - _rfdetr_fullpost_filter_kernel[(num_queries, num_classes_total)]( + rfdetr_postproc_triton_kernel[(num_queries, num_classes_total)]( logits_2d, bboxes_2d, thr_tensor, @@ -274,8 +272,8 @@ def triton_rfdetr_fullpost( int(orig_h), logits_2d.stride(0), bboxes_2d.stride(0), - PER_CLASS=1 if per_class else 0, - HAS_REMAPPING=1 if has_remap else 0, + PER_CLASS=tl.constexpr(per_class_constexpr), + HAS_REMAPPING=tl.constexpr(has_remap_constexpr), ) done_event = torch.cuda.Event() diff --git a/temp/detection_parity_full.py b/temp/detection_parity_full.py index f75c14fff6..c0d4f4b6a1 100644 --- a/temp/detection_parity_full.py +++ 
b/temp/detection_parity_full.py @@ -1,7 +1,7 @@ """Full coco/val2017 detection + mask parity: fused-postproc Triton path vs reference. -Driven by RFDETR_TRITON_FULLPOSTPROC (true/false). The env var is read once at -import time, so we run two subprocesses (fullpost on, fullpost off), stream +Driven by RFDETR_TRITON_POSTPROC (true/false). The env var is read once at +import time, so we run two subprocesses (postproc on, postproc off), stream per-image detections to a pickle file (one pickle.dump per record), then compare in lockstep. @@ -24,7 +24,7 @@ SELF = Path(__file__).resolve() OUT_ON = "/tmp/det_parity_full_on.pkl" OUT_OFF = "/tmp/det_parity_full_off.pkl" -# Cap images per pass: fullpost caches a per-(orig_h, orig_w) GPU buffer in +# Cap images per pass: postproc caches a per-(orig_h, orig_w) GPU buffer in # _get_mask_bin_buffer, so the working set grows with distinct image sizes and # OOMs near ~5k images on a 14 GiB card. 1500 covers a wide variety of shapes # while staying well under that ceiling. 
@@ -52,18 +52,18 @@ def do_run(out_path): from inference_models import AutoModel import inference_models.models.rfdetr.common as common_mod - fullpost_calls = {"count": 0} - original_fp = getattr(common_mod, "triton_rfdetr_fullpost", None) + postproc_calls = {"count": 0} + original_fp = getattr(common_mod, "rfdetr_triton_postproc", None) if original_fp is not None: def counting_fp(*a, **kw): - fullpost_calls["count"] += 1 + postproc_calls["count"] += 1 return original_fp(*a, **kw) - common_mod.triton_rfdetr_fullpost = counting_fp + common_mod.rfdetr_triton_postproc = counting_fp - fp_flag = os.environ.get("RFDETR_TRITON_FULLPOSTPROC", "") + fp_flag = os.environ.get("RFDETR_TRITON_POSTPROC", "") print( - f"[run] RFDETR_TRITON_FULLPOSTPROC={fp_flag} " - f"(module ready={getattr(common_mod, '_TRITON_FULLPOST_READY', False)})" + f"[run] RFDETR_TRITON_POSTPROC={fp_flag} " + f"(module ready={getattr(common_mod, '_TRITON_POSTPROC_READY', False)})" ) paths = sorted(COCO.glob("*.jpg"))[:MAX_IMAGES] @@ -73,7 +73,7 @@ def counting_fp(*a, **kw): t0 = time.perf_counter() with open(out_path, "wb") as f: # placeholder header — rewritten at end - pickle.dump({"_kind": "header", "fullpost_calls": -1, "n_records": -1}, f) + pickle.dump({"_kind": "header", "postproc_calls": -1, "n_records": -1}, f) for idx, p in enumerate(paths): im = cv2.imread(str(p), cv2.IMREAD_COLOR) if im is None: @@ -105,8 +105,8 @@ def counting_fp(*a, **kw): # append a footer with the totals (header is ignored on read; iterator just walks records + footer) with open(out_path, "ab") as f: - pickle.dump({"_kind": "footer", "fullpost_calls": fullpost_calls["count"], "n_records": n_records}, f) - print(f"[run] fullpost_kernel_calls={fullpost_calls['count']} records={n_records} saved -> {out_path}") + pickle.dump({"_kind": "footer", "postproc_calls": postproc_calls["count"], "n_records": n_records}, f) + print(f"[run] postproc_kernel_calls={postproc_calls['count']} records={n_records} saved -> {out_path}") def 
iou_box(a, b): @@ -147,8 +147,8 @@ def do_compare(on_path, off_path): if r_off.get("_kind") == "header": r_off = next(off_iter) if r_on.get("_kind") == "footer" or r_off.get("_kind") == "footer": - on_fp_calls = r_on.get("fullpost_calls", on_fp_calls) - off_fp_calls = r_off.get("fullpost_calls", off_fp_calls) + on_fp_calls = r_on.get("postproc_calls", on_fp_calls) + off_fp_calls = r_off.get("postproc_calls", off_fp_calls) break assert r_on["path"] == r_off["path"], (r_on["path"], r_off["path"]) @@ -198,14 +198,14 @@ def do_compare(on_path, off_path): for r in it: if r.get("_kind") == "footer": if current_calls_attr == "on_fp_calls": - on_fp_calls = r["fullpost_calls"] + on_fp_calls = r["postproc_calls"] else: - off_fp_calls = r["fullpost_calls"] + off_fp_calls = r["postproc_calls"] print() - print(f"==== full coco/val2017 parity: fullpost=true vs fullpost=false ({n_imgs} images) ====") - print(f" fullpost calls (fp=true) : {on_fp_calls}") - print(f" fullpost calls (fp=false) : {off_fp_calls}") + print(f"==== full coco/val2017 parity: postproc=true vs postproc=false ({n_imgs} images) ====") + print(f" postproc calls (fp=true) : {on_fp_calls}") + print(f" postproc calls (fp=false) : {off_fp_calls}") print(f" dets fp=true / fp=false : {tot_on} / {tot_off}") print(f" matched (IoU>0.5) : {matched} ({100*matched/max(1,tot_off):.2f}% of fp=false)") print(f" count-mismatch images : {count_mm}") @@ -222,8 +222,8 @@ def do_compare(on_path, off_path): expected = n_imgs ok_on = "[PASS]" if on_fp_calls == expected else "[FAIL]" ok_off = "[PASS]" if off_fp_calls == 0 else "[FAIL]" - print(f" {ok_on} fp=true -> fullpost fired {on_fp_calls}/{expected}") - print(f" {ok_off} fp=false -> fullpost fired {off_fp_calls} (expected 0)") + print(f" {ok_on} fp=true -> postproc fired {on_fp_calls}/{expected}") + print(f" {ok_off} fp=false -> postproc fired {off_fp_calls} (expected 0)") def main(): @@ -243,8 +243,8 @@ def main(): for fp_value, out in (("true", OUT_ON), ("false", OUT_OFF)): 
env = os.environ.copy() - env["RFDETR_TRITON_FULLPOSTPROC"] = fp_value - print(f"\n---- child: fullpost={fp_value} out={out} ----", flush=True) + env["RFDETR_TRITON_POSTPROC"] = fp_value + print(f"\n---- child: postproc={fp_value} out={out} ----", flush=True) subprocess.run([PY, str(SELF), "--mode", "run", "--out", out], check=True, env=env) do_compare(OUT_ON, OUT_OFF) From c9db9103e814cc82c797a89bbcbe8411c9d32d49 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 05:14:48 +0000 Subject: [PATCH 09/25] Fuse full RF-DETR Triton post-processing --- .../core/models/inference_models_adapters.py | 60 +- .../models/common/rle_utils.py | 131 ++- .../inference_models/models/rfdetr/common.py | 166 ++-- .../rfdetr_instance_segmentation_onnx.py | 1 + .../rfdetr_instance_segmentation_pytorch.py | 1 + .../rfdetr_instance_segmentation_trt.py | 1 + .../models/rfdetr/triton_fullpostproc.py | 750 ++++++++++++++---- .../models/common/test_rle_utils.py | 71 ++ 8 files changed, 911 insertions(+), 270 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index 520d7c08f6..177b9b18ab 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -326,6 +326,8 @@ def postprocess( ) -> List[InstanceSegmentationInferenceResponse]: return_in_rle = kwargs.get("response_mask_format") == "rle" mapped_kwargs = self.map_inference_kwargs(kwargs) + if "response_mask_format" in kwargs: + mapped_kwargs["response_mask_format"] = kwargs["response_mask_format"] detections_list = self._model.post_process( predictions, preprocess_return_metadata, **mapped_kwargs ) @@ -335,29 +337,25 @@ def postprocess( H = preproc_metadata.original_size.height W = preproc_metadata.original_size.width - # RF-DETR postproc Triton fusion emits an - # unsliced (num_queries, 6) int32 record plus a GPU counter and a - # completion event. 
We DtoH the 4-byte counter (single sync) to - # learn n_survivors, then async-DtoH the compact slices. - combined_gpu = getattr(det, "_combined_gpu", None) - counter_gpu = getattr(det, "_counter_gpu", None) + mask_gpu = getattr(det, "_mask_gpu", None) + mask_cpu = getattr(det, "_mask_cpu", None) done_event = getattr(det, "_postproc_done_event", None) if ( not return_in_rle - and combined_gpu is not None - and counter_gpu is not None and done_event is not None - and isinstance(det.mask, torch.Tensor) - and det.mask.is_cuda + and isinstance(mask_gpu, torch.Tensor) + and mask_gpu.is_cuda + and isinstance(det.xyxy, torch.Tensor) + and det.xyxy.is_cuda + and isinstance(det.confidence, torch.Tensor) + and det.confidence.is_cuda + and isinstance(det.class_id, torch.Tensor) + and det.class_id.is_cuda ): - device = combined_gpu.device + device = mask_gpu.device stream = torch.cuda.current_stream(device) done_event.wait(stream) - - counter_host = get_pinned_buffer("counter", (1,), torch.int32) - counter_host.copy_(counter_gpu, non_blocking=True) - stream.synchronize() - n_survivors = int(counter_host[0].item()) + n_survivors = int(det.xyxy.shape[0]) if n_survivors == 0: xyxy = np.empty((0, 4), dtype=np.int32) @@ -365,23 +363,33 @@ def postprocess( class_ids = np.empty((0,), dtype=np.int32) polys_or_rles = [] else: - combined_slice = combined_gpu[:n_survivors] - mask_slice = det.mask[:n_survivors] - combined_host = get_pinned_buffer( - "combined", tuple(combined_slice.shape), combined_slice.dtype + mask_slice = mask_gpu[:n_survivors] + xyxy_host = get_pinned_buffer( + "xyxy", tuple(det.xyxy.shape), det.xyxy.dtype + ) + conf_host = get_pinned_buffer( + "conf", tuple(det.confidence.shape), det.confidence.dtype + ) + class_host = get_pinned_buffer( + "class_id", tuple(det.class_id.shape), det.class_id.dtype ) mask_host = get_pinned_buffer( "mask", tuple(mask_slice.shape), mask_slice.dtype ) - combined_host.copy_(combined_slice, non_blocking=True) + xyxy_host.copy_(det.xyxy, 
non_blocking=True) + conf_host.copy_(det.confidence, non_blocking=True) + class_host.copy_(det.class_id, non_blocking=True) mask_host.copy_(mask_slice, non_blocking=True) stream.synchronize() - combined_cpu = combined_host.numpy() - xyxy = combined_cpu[:, :4] - # combined[:, 4] holds fp32 conf bits stored as int32. - confs = combined_cpu[:, 4].view(np.float32) - class_ids = combined_cpu[:, 5] + xyxy = xyxy_host.numpy() + confs = conf_host.numpy() + class_ids = class_host.numpy() polys_or_rles = masks2poly(mask_host.numpy()) + elif not return_in_rle and isinstance(mask_cpu, np.ndarray): + xyxy = det.xyxy.detach().cpu().numpy() + confs = det.confidence.detach().cpu().numpy() + class_ids = det.class_id.detach().cpu().numpy() + polys_or_rles = masks2poly(mask_cpu) else: xyxy = det.xyxy.detach().cpu().numpy() confs = det.confidence.detach().cpu().numpy() diff --git a/inference_models/inference_models/models/common/rle_utils.py b/inference_models/inference_models/models/common/rle_utils.py index 36525ebeba..5348e01a14 100644 --- a/inference_models/inference_models/models/common/rle_utils.py +++ b/inference_models/inference_models/models/common/rle_utils.py @@ -7,6 +7,11 @@ from inference_models.models.base.types import InstancesRLEMasks +def counts_to_coco_rle(counts: list, image_size: tuple) -> dict: + h, w = image_size + return mask_utils.frPyObjects({"counts": counts, "size": [h, w]}, h, w) + + def torch_mask_to_coco_rle(mask: torch.Tensor) -> dict: # Convert to uncompressed run length encoding in GPU # coco tools expect fortran order (column-wise) @@ -17,10 +22,128 @@ def torch_mask_to_coco_rle(mask: torch.Tensor) -> dict: if values[0] == 1: counts.insert(0, 0) - h, w = mask.shape - # compress - rle = mask_utils.frPyObjects({"counts": counts, "size": [h, w]}, h, w) - return rle + return counts_to_coco_rle(counts=counts, image_size=tuple(mask.shape)) + + +def numpy_mask_to_coco_rle(mask: np.ndarray) -> dict: + mask_bool = np.asarray(mask, dtype=bool) + mask_flat = 
np.ravel(mask_bool, order="F") + if mask_flat.size == 0: + return counts_to_coco_rle(counts=[], image_size=tuple(mask_bool.shape)) + transitions = np.flatnonzero(mask_flat[1:] != mask_flat[:-1]) + 1 + counts = np.diff( + np.concatenate( + ( + np.array([0], dtype=np.int64), + transitions.astype(np.int64, copy=False), + np.array([mask_flat.size], dtype=np.int64), + ) + ) + ).tolist() + if mask_flat[0]: + counts.insert(0, 0) + return counts_to_coco_rle(counts=counts, image_size=tuple(mask_bool.shape)) + + +class LazyInstancesRLEMasks(InstancesRLEMasks): + """Materializes COCO RLE counts only when a caller actually needs them.""" + + def __init__( + self, + image_size: tuple, + mask_gpu: Optional[torch.Tensor] = None, + mask_cpu: Optional[np.ndarray] = None, + rle_counts_gpu: Optional[torch.Tensor] = None, + rle_lengths_gpu: Optional[torch.Tensor] = None, + rle_counts_cpu: Optional[np.ndarray] = None, + rle_lengths_cpu: Optional[np.ndarray] = None, + done_event: Optional["torch.cuda.Event"] = None, + ): + self.image_size = image_size + self._masks: list = [] + self._materialized = False + self._mask_gpu = mask_gpu + self._mask_cpu = mask_cpu + self._rle_counts_gpu = rle_counts_gpu + self._rle_lengths_gpu = rle_lengths_gpu + self._rle_counts_cpu = rle_counts_cpu + self._rle_lengths_cpu = rle_lengths_cpu + self._done_event = done_event + + @property + def masks(self) -> list: + self._ensure_materialized() + return self._masks + + @masks.setter + def masks(self, value: list) -> None: + self._masks = value + self._materialized = True + + def _ensure_mask_cpu(self) -> np.ndarray: + if self._mask_cpu is not None: + return self._mask_cpu + if self._mask_gpu is None: + self._mask_cpu = np.empty( + (0, self.image_size[0], self.image_size[1]), dtype=bool + ) + return self._mask_cpu + device = self._mask_gpu.device + stream = torch.cuda.current_stream(device) + if self._done_event is not None: + self._done_event.wait(stream) + mask_cpu = self._mask_gpu.cpu().numpy() + if 
mask_cpu.dtype == np.uint8: + mask_cpu = mask_cpu.view(np.bool_) + else: + mask_cpu = mask_cpu.astype(bool, copy=False) + self._mask_cpu = mask_cpu + return self._mask_cpu + + def _ensure_rle_cpu(self) -> None: + if self._rle_counts_cpu is not None and self._rle_lengths_cpu is not None: + return + if self._rle_counts_gpu is None or self._rle_lengths_gpu is None: + return + device = self._rle_lengths_gpu.device + stream = torch.cuda.current_stream(device) + if self._done_event is not None: + self._done_event.wait(stream) + lengths_cpu = self._rle_lengths_gpu.cpu().numpy().astype(np.int32, copy=False) + if lengths_cpu.size == 0: + counts_cpu = np.empty((0, 0), dtype=np.int32) + else: + max_len = int(lengths_cpu.max()) + counts_slice = self._rle_counts_gpu[:, :max_len] + counts_cpu = counts_slice.cpu().numpy().astype(np.int32, copy=False) + self._rle_lengths_cpu = lengths_cpu + self._rle_counts_cpu = counts_cpu + + def _ensure_materialized(self) -> None: + if self._materialized: + return + self._ensure_rle_cpu() + if self._rle_counts_cpu is not None and self._rle_lengths_cpu is not None: + self._masks = [ + counts_to_coco_rle( + counts=self._rle_counts_cpu[i, : int(self._rle_lengths_cpu[i])] + .astype(np.int64, copy=False) + .tolist(), + image_size=self.image_size, + )["counts"] + for i in range(self._rle_lengths_cpu.shape[0]) + ] + else: + mask_cpu = self._ensure_mask_cpu() + self._masks = [ + numpy_mask_to_coco_rle(mask=mask_cpu[i])["counts"] + for i in range(mask_cpu.shape[0]) + ] + self._materialized = True + + def to_coco_rle_masks(self) -> list: + self._ensure_materialized() + return super().to_coco_rle_masks() def coco_rle_masks_to_numpy_mask(instances_masks: InstancesRLEMasks) -> np.ndarray: diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index 78e618724e..a6f322b550 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ 
b/inference_models/inference_models/models/rfdetr/common.py @@ -1,17 +1,13 @@ -import os -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch -from torchvision.transforms import functional from inference_models import Detections, InstanceDetections, InstancesRLEMasks -from inference_models.entities import ImageDimensions from inference_models.errors import CorruptedModelPackageError from inference_models.models.common.roboflow.model_packages import ( PreProcessingMetadata, - StaticCropOffset, ) -from inference_models.models.common.rle_utils import torch_mask_to_coco_rle +from inference_models.models.common.rle_utils import LazyInstancesRLEMasks from inference_models.models.common.roboflow.post_processing import ( align_instance_segmentation_results, align_instance_segmentation_results_to_rle_masks, @@ -24,7 +20,11 @@ if RFDETR_TRITON_POSTPROC: try: - from inference_models.models.rfdetr.triton_postprocproc import ( + from inference_models.models.rfdetr.triton_fullpostproc import ( + FASTPATH_MASK_H, + FASTPATH_MASK_W, + FASTPATH_NUM_CLASSES_TOTAL, + FASTPATH_NUM_QUERIES, TRITON_AVAILABLE as _TRITON_POSTPROC_AVAILABLE, rfdetr_triton_postproc, ) @@ -37,68 +37,47 @@ rfdetr_triton_postproc = None -def triton_postproc_upsample_masks( - masks: torch.Tensor, - survivor_q: torch.Tensor, - inference_size_wh: Tuple[int, int], - pad_ltrb: Tuple[int, int, int, int], - orig_size_hw: Tuple[int, int], -) -> torch.Tensor: - """Replicate ``align_instance_segmentation_results``'s mask path - bit-for-bit (when ``size_after_pre_processing == inference_size`` and no - static crop applies — both guaranteed by ``post_triton_eligible``). - - Gathers surviving query rows from ``masks`` (shape ``(1, Q, mh, mw)``), - crops the letterbox padding in mask coordinates, bilinear-resizes to - ``(orig_h, orig_w)`` with ``antialias=True`` (matching torchvision's - ``functional.resize`` default), and thresholds at 0. 
- """ - selected = masks[0].index_select(0, survivor_q.long()) - if selected.shape[0] == 0: - orig_h, orig_w = orig_size_hw - return torch.empty( - (0, orig_h, orig_w), dtype=torch.bool, device=masks.device - ) - _, mh, mw = selected.shape - inf_w, inf_h = inference_size_wh - pad_l, pad_t, pad_r, pad_b = pad_ltrb - mh_scale = mh / inf_h - mw_scale = mw / inf_w - mpt = round(mh_scale * pad_t) - mpb = round(mh_scale * pad_b) - mpl = round(mw_scale * pad_l) - mpr = round(mw_scale * pad_r) - selected = selected[:, mpt: mh - mpb, mpl: mw - mpr] - orig_h, orig_w = orig_size_hw - upsampled = torch.nn.functional.interpolate( - selected.unsqueeze(1), - size=(orig_h, orig_w), - mode="bilinear", - antialias=True, - align_corners=False, - ).squeeze(1) - return upsampled > 0.0 - - def post_triton_eligible( bboxes: torch.Tensor, logits: torch.Tensor, + masks: torch.Tensor, pre_processing_meta: List[PreProcessingMetadata], classes_re_mapping: Optional[ClassesReMapping], ) -> bool: if not _TRITON_POSTPROC_READY: return False - if not bboxes.is_cuda or not logits.is_cuda: + if not bboxes.is_cuda or not logits.is_cuda or not masks.is_cuda: return False - if bboxes.device != logits.device: + if bboxes.device != logits.device or bboxes.device != masks.device: return False - if bboxes.shape[0] != 1 or logits.shape[0] != 1 or len(pre_processing_meta) != 1: + if ( + bboxes.shape[0] != 1 + or logits.shape[0] != 1 + or masks.shape[0] != 1 + or len(pre_processing_meta) != 1 + ): + return False + if ( + bboxes.shape[1] != FASTPATH_NUM_QUERIES + or logits.shape[1] != FASTPATH_NUM_QUERIES + or logits.shape[2] != FASTPATH_NUM_CLASSES_TOTAL + or masks.shape[1] != FASTPATH_NUM_QUERIES + or masks.shape[2] != FASTPATH_MASK_H + or masks.shape[3] != FASTPATH_MASK_W + ): return False meta = pre_processing_meta[0] if meta.nonsquare_intermediate_size is not None: return False if meta.static_crop_offset.offset_x != 0 or meta.static_crop_offset.offset_y != 0: return False + if meta.pad_left != 0 or 
meta.pad_top != 0 or meta.pad_right != 0 or meta.pad_bottom != 0: + return False + if ( + meta.size_after_pre_processing.height < FASTPATH_MASK_H + or meta.size_after_pre_processing.width < FASTPATH_MASK_W + ): + return False if classes_re_mapping is None: return False return True @@ -216,39 +195,34 @@ def post_process_instance_segmentation_results( num_classes: int, classes_re_mapping: Optional[ClassesReMapping], ) -> List[InstanceDetections]: - if post_triton_eligible(bboxes, logits, pre_processing_meta, classes_re_mapping): + if post_triton_eligible( + bboxes, logits, masks, pre_processing_meta, classes_re_mapping + ): meta = pre_processing_meta[0] thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) - combined, survivor_idx, counter, done_event = rfdetr_triton_postproc( + combined, mask_bin, counter, done_event, _, _ = rfdetr_triton_postproc( bboxes=bboxes, logits=logits, + masks=masks, threshold=thr_arg, num_classes=num_classes, class_mapping=classes_re_mapping.class_mapping, inference_size_wh=(meta.inference_size.width, meta.inference_size.height), - pad_ltrb=(meta.pad_left, meta.pad_top, meta.pad_right, meta.pad_bottom), scale_wh=(meta.scale_width, meta.scale_height), orig_size_wh=(meta.original_size.width, meta.original_size.height), ) done_event.wait(torch.cuda.current_stream(bboxes.device)) - # Counter is incremented unconditionally before the slot-cap guard, so - # cap by combined's row count (num_queries). 
- n_survivors = min(int(counter.item()), combined.shape[0]) + n_survivors = int(counter.item()) combined_slice = combined[:n_survivors] - mask_bin = triton_postproc_upsample_masks( - masks=masks, - survivor_q=survivor_idx[:n_survivors], - inference_size_wh=(meta.inference_size.width, meta.inference_size.height), - pad_ltrb=(meta.pad_left, meta.pad_top, meta.pad_right, meta.pad_bottom), - orig_size_hw=(meta.original_size.height, meta.original_size.width), - ) + mask_gpu = mask_bin[:n_survivors].view(torch.bool) detections = InstanceDetections( xyxy=combined_slice[:, :4], confidence=combined_slice[:, 4].view(torch.float32), class_id=combined_slice[:, 5], - mask=mask_bin, + mask=mask_gpu, ) detections.__dict__["_combined_gpu"] = combined + detections.__dict__["_mask_gpu"] = mask_gpu detections.__dict__["_counter_gpu"] = counter detections.__dict__["_postproc_done_event"] = done_event return [detections] @@ -346,25 +320,27 @@ def post_process_instance_segmentation_results_to_rle_masks( threshold: Union[float, torch.Tensor], num_classes: int, classes_re_mapping: Optional[ClassesReMapping], + emit_in_kernel_rle: bool = False, ) -> List[InstanceDetections]: - if post_triton_eligible(bboxes, logits, pre_processing_meta, classes_re_mapping): + if post_triton_eligible( + bboxes, logits, masks, pre_processing_meta, classes_re_mapping + ): meta = pre_processing_meta[0] thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) - combined, survivor_idx, counter, done_event = rfdetr_triton_postproc( + combined, mask_bin, counter, done_event, rle_counts, rle_lengths = rfdetr_triton_postproc( bboxes=bboxes, logits=logits, + masks=masks, threshold=thr_arg, num_classes=num_classes, class_mapping=classes_re_mapping.class_mapping, inference_size_wh=(meta.inference_size.width, meta.inference_size.height), - pad_ltrb=(meta.pad_left, meta.pad_top, meta.pad_right, meta.pad_bottom), scale_wh=(meta.scale_width, meta.scale_height), 
orig_size_wh=(meta.original_size.width, meta.original_size.height), + emit_rle=emit_in_kernel_rle, ) done_event.wait(torch.cuda.current_stream(bboxes.device)) - # Counter is incremented unconditionally before the slot-cap guard, so - # cap by combined's row count (num_queries). - n_survivors = min(int(counter.item()), combined.shape[0]) + n_survivors = int(counter.item()) orig_h = meta.original_size.height orig_w = meta.original_size.width if n_survivors == 0: @@ -384,30 +360,34 @@ def post_process_instance_segmentation_results_to_rle_masks( ) ] combined_slice = combined[:n_survivors] - mask_slice = triton_postproc_upsample_masks( - masks=masks, - survivor_q=survivor_idx[:n_survivors], - inference_size_wh=(meta.inference_size.width, meta.inference_size.height), - pad_ltrb=(meta.pad_left, meta.pad_top, meta.pad_right, meta.pad_bottom), - orig_size_hw=(orig_h, orig_w), - ) - rle_masks = [ - torch_mask_to_coco_rle(mask=mask_slice[i]) for i in range(n_survivors) - ] - instances_masks = InstancesRLEMasks.from_coco_rle_masks( - image_size=(orig_h, orig_w), masks=rle_masks + instances_masks = LazyInstancesRLEMasks( + image_size=(orig_h, orig_w), + mask_gpu=( + mask_bin[:n_survivors].view(torch.bool) + if not emit_in_kernel_rle + else None + ), + rle_counts_gpu=rle_counts[:n_survivors] + if emit_in_kernel_rle and rle_counts is not None + else None, + rle_lengths_gpu=rle_lengths[:n_survivors] + if emit_in_kernel_rle and rle_lengths is not None + else None, + done_event=done_event, ) xyxy = combined_slice[:, :4] confidence = combined_slice[:, 4].view(torch.float32) class_id = combined_slice[:, 5] - return [ - InstanceDetections( - xyxy=xyxy, - confidence=confidence, - class_id=class_id, - mask=instances_masks, - ) - ] + detections = InstanceDetections( + xyxy=xyxy, + confidence=confidence, + class_id=class_id, + mask=instances_masks, + ) + if not emit_in_kernel_rle: + detections.__dict__["_mask_gpu"] = mask_bin[:n_survivors].view(torch.bool) + 
detections.__dict__["_postproc_done_event"] = done_event + return [detections] logits_sigmoid = torch.nn.functional.sigmoid(logits) final_results = [] device = bboxes.device diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py index 42ccdbe86a..2f93cccbc0 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py @@ -263,5 +263,6 @@ def post_process( threshold=confidence_filter.get_threshold(self.class_names), num_classes=len(self.class_names), classes_re_mapping=self._classes_re_mapping, + emit_in_kernel_rle=kwargs.get("response_mask_format") == "rle", ) return results diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py index efdb84665c..920ca7cf59 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py @@ -469,5 +469,6 @@ def post_process( threshold=confidence_filter.get_threshold(self.class_names), num_classes=len(self.class_names), classes_re_mapping=self._classes_re_mapping, + emit_in_kernel_rle=kwargs.get("response_mask_format") == "rle", ) return results diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py index e8196f8eae..903a6ee6d0 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py @@ -318,6 +318,7 @@ def post_process( 
threshold=confidence_filter.get_threshold(self.class_names), num_classes=len(self.class_names), classes_re_mapping=self._classes_re_mapping, + emit_in_kernel_rle=kwargs.get("response_mask_format") == "rle", ) self._post_process_stream.synchronize() return results diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index f356b1f717..584139ebee 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -1,19 +1,22 @@ -"""Fused RF-DETR instance-segmentation post-processing in Triton. - -For the common rfdetr-seg-nano path (batch=1, no static crop, STRETCH_TO -resize, class remapping active): - - rfdetr_postproc_triton_kernel (grid = num_queries * num_classes_total) - One program per (q, c) pair: sigmoid + class remap + conf threshold + - cxcywh->xyxy + letterbox-denormalize + clip + banker's rounding; - atomic_add into a counter to reserve a compact output slot. - - Mask upsample uses ``F.interpolate(bilinear, antialias=True)`` followed by - a ``> 0`` threshold — bit-for-bit identical to the reference path's - ``align_instance_segmentation_results`` mask handling. We reuse the - reference here (rather than a custom kernel) because cuDNN's antialiased - bilinear is hard to match in fp32 and the kernel-level win is in the - filter step, not the upsample. +"""Single-kernel RF-DETR instance-segmentation post-processing in Triton. 
+ +Fast path scope: +- batch size == 1 +- RF-DETR seg TRT tensor shapes: Q=100, C=91, Mh=Mw=78 +- no static crop, no letterbox padding, no nonsquare intermediate size +- output mask resize is an upsample (the common benchmark / parity case) + +Within one Triton launch, each program owns one output detection rank and: +- performs flat top-k over the (Q, C) sigmoid grid +- applies class remap + confidence filtering +- denormalizes / rescales / rounds the selected box +- resizes the selected 78x78 mask to the original image size and thresholds it + +The resize uses cached 2-tap closed-form bilinear tables, so the per-inference +hot path remains a single Triton launch without any CUDA bootstrap probes. + +Host-side work is limited to slicing the preallocated buffers once the kernel +completes and wrapping them in ``InstanceDetections`` / RLE containers. """ from typing import Optional, Tuple @@ -30,99 +33,162 @@ TRITON_AVAILABLE = False +FASTPATH_NUM_QUERIES = 100 +FASTPATH_NUM_CLASSES_TOTAL = 91 +FASTPATH_MASK_H = 78 +FASTPATH_MASK_W = 78 +_TOPK_PAD = 128 +_CLASS_BLOCK = 128 +_MASK_TILE_H = 8 +_MASK_TILE_W = 64 +_RLE_TILE_H = 32 +_RLE_TILE_W = 16 +_RLE_MERGE_TILE = 32 +_MAX_U32 = 0xFFFFFFFF + + if TRITON_AVAILABLE: @triton.jit - def rfdetr_postproc_triton_kernel( + def rfdetr_fullpostproc_triton_kernel( logits_ptr, bboxes_ptr, + masks_ptr, threshold_ptr, class_map_ptr, + y_indices_ptr, + y_weights_ptr, + x_indices_ptr, + x_weights_ptr, + rle_counts_ptr, + rle_lengths_ptr, combined_out_ptr, - survivor_idx_out_ptr, + mask_out_ptr, counter_ptr, - num_queries, - num_classes_total, - inference_w, - inference_h, - pad_left, - pad_top, - inv_scale_w, - inv_scale_h, orig_w, orig_h, logits_stride_q, bboxes_stride_q, + masks_stride_q, + masks_stride_h, + masks_stride_w, + rle_counts_stride_q, + rle_lengths_stride_q, + mask_out_stride_q, + mask_out_stride_h, + mask_out_stride_w, PER_CLASS: tl.constexpr, HAS_REMAPPING: tl.constexpr, + EMIT_RLE: tl.constexpr, + NUM_QUERIES: 
tl.constexpr, + NUM_CLASSES_TOTAL: tl.constexpr, + MASK_H: tl.constexpr, + MASK_W: tl.constexpr, + TOPK_PAD: tl.constexpr, + CLASS_BLOCK: tl.constexpr, + MASK_TILE_H: tl.constexpr, + MASK_TILE_W: tl.constexpr, + RLE_TILE_H: tl.constexpr, + RLE_TILE_W: tl.constexpr, + RLE_MERGE_TILE: tl.constexpr, ): - # One program per (query, class). The reference path does top-k-flat - # over the (Q*C) sigmoid grid (`num_select == num_queries`), so a - # single query can contribute multiple detections — once per class - # that survives remap + threshold. Per-query argmax would silently - # drop the others. - pid_q = tl.program_id(0) - pid_c = tl.program_id(1) - if pid_q >= num_queries or pid_c >= num_classes_total: + pid_det = tl.program_id(0) + + # Maintain the reference flat top-k exactly: top 100 scores over the + # 100x91 sigmoid grid, before class remap / thresholding. + top_packed = tl.zeros((TOPK_PAD,), dtype=tl.int64) + class_offsets = tl.arange(0, CLASS_BLOCK) + rank_offsets = tl.arange(0, TOPK_PAD) + top_limit = tl.full((), NUM_QUERIES, tl.int32) + num_classes_total = tl.full((), NUM_CLASSES_TOTAL, tl.int32) + + for q in tl.range(0, NUM_QUERIES, num_stages=1): + valid_class = class_offsets < NUM_CLASSES_TOTAL + logit = tl.load( + logits_ptr + q * logits_stride_q + class_offsets, + mask=valid_class, + other=-float("inf"), + ) + abs_l = tl.abs(logit) + z = tl.exp(-abs_l) + sig_pos = 1.0 / (1.0 + z) + sig_neg = z / (1.0 + z) + conf = tl.where(logit >= 0.0, sig_pos, sig_neg) + conf_bits = conf.to(tl.float32, bitcast=False).to(tl.int32, bitcast=True) + flat_idx = q * NUM_CLASSES_TOTAL + class_offsets + packed = tl.where( + valid_class, + (conf_bits.to(tl.int64) << 32) | flat_idx.to(tl.int64), + tl.zeros((CLASS_BLOCK,), dtype=tl.int64), + ) + merged = tl.reshape(tl.join(top_packed, packed), (TOPK_PAD + CLASS_BLOCK,)) + top_packed = tl.topk(merged, k=TOPK_PAD) + + selected_q = tl.full((), 0, tl.int32) + selected_c = tl.full((), 0, tl.int32) + selected_conf = tl.full((), 0.0, 
tl.float32) + keep_count = tl.full((), 0, tl.int32) + + for rank in tl.range(0, NUM_QUERIES, num_stages=1): + packed = tl.sum( + tl.where( + rank_offsets == rank, + top_packed, + tl.zeros((TOPK_PAD,), dtype=tl.int64), + ), + axis=0, + ) + conf_bits = (packed >> 32).to(tl.int32) + conf = conf_bits.to(tl.float32, bitcast=True) + flat_idx = (packed & tl.full((), 0xFFFFFFFF, tl.int64)).to(tl.int32) + query_idx = flat_idx // num_classes_total + raw_class = flat_idx - query_idx * num_classes_total + + if HAS_REMAPPING: + mapped_class = tl.load(class_map_ptr + raw_class) + valid = mapped_class >= 0 + else: + mapped_class = raw_class + valid = raw_class < top_limit + + if PER_CLASS: + safe_class = tl.where(valid, mapped_class, 0) + threshold = tl.load(threshold_ptr + safe_class) + else: + threshold = tl.load(threshold_ptr) + + keep = valid & (conf > threshold) + select_now = keep & (keep_count == pid_det) + selected_q = tl.where(select_now, query_idx, selected_q) + selected_c = tl.where(select_now, mapped_class, selected_c) + selected_conf = tl.where(select_now, conf, selected_conf) + keep_count += keep.to(tl.int32) + + if pid_det == 0: + tl.store(counter_ptr, keep_count) + + active = pid_det < keep_count + if not active: return - logit = tl.load(logits_ptr + pid_q * logits_stride_q + pid_c) - - if HAS_REMAPPING: - top_c = tl.load(class_map_ptr + pid_c) - valid = top_c >= 0 - else: - top_c = pid_c - valid = pid_c < num_classes_total - - abs_l = tl.abs(logit) - z = tl.exp(-abs_l) - sig_pos = 1.0 / (1.0 + z) - sig_neg = z / (1.0 + z) - conf = tl.where(logit >= 0.0, sig_pos, sig_neg) - - if PER_CLASS: - safe_c = tl.where(valid, top_c, 0) - thr = tl.load(threshold_ptr + safe_c) - else: - thr = tl.load(threshold_ptr) - keep = valid & (conf > thr) - - if not keep: - return - - # Match the non-Triton path's FP32 evaluation order for bit-parity. 
- cx_pct = tl.load(bboxes_ptr + pid_q * bboxes_stride_q + 0) - cy_pct = tl.load(bboxes_ptr + pid_q * bboxes_stride_q + 1) - w_pct = tl.load(bboxes_ptr + pid_q * bboxes_stride_q + 2) - h_pct = tl.load(bboxes_ptr + pid_q * bboxes_stride_q + 3) + cx_pct = tl.load(bboxes_ptr + selected_q * bboxes_stride_q + 0) + cy_pct = tl.load(bboxes_ptr + selected_q * bboxes_stride_q + 1) + w_pct = tl.load(bboxes_ptr + selected_q * bboxes_stride_q + 2) + h_pct = tl.load(bboxes_ptr + selected_q * bboxes_stride_q + 3) x1_pct = cx_pct - 0.5 * w_pct y1_pct = cy_pct - 0.5 * h_pct x2_pct = cx_pct + 0.5 * w_pct y2_pct = cy_pct + 0.5 * h_pct - x1 = x1_pct * inference_w - y1 = y1_pct * inference_h - x2 = x2_pct * inference_w - y2 = y2_pct * inference_h - - x1 = x1 - pad_left - y1 = y1 - pad_top - x2 = x2 - pad_left - y2 = y2 - pad_top + orig_w_f = orig_w.to(tl.float32) + orig_h_f = orig_h.to(tl.float32) + x1 = x1_pct * orig_w_f + y1 = y1_pct * orig_h_f + x2 = x2_pct * orig_w_f + y2 = y2_pct * orig_h_f - x1 = x1 * inv_scale_w - y1 = y1 * inv_scale_h - x2 = x2 * inv_scale_w - y2 = y2 * inv_scale_h - - x1 = tl.maximum(tl.minimum(x1, orig_w), 0.0) - y1 = tl.maximum(tl.minimum(y1, orig_h), 0.0) - x2 = tl.maximum(tl.minimum(x2, orig_w), 0.0) - y2 = tl.maximum(tl.minimum(y2, orig_h), 0.0) - - # Banker's rounding (half-to-even) matches torch.round().int(). + # Match torch.round(...).int() with half-to-even tie handling. x1_r = tl.floor(x1 + 0.5) y1_r = tl.floor(y1 + 0.5) x2_r = tl.floor(x2 + 0.5) @@ -136,48 +202,378 @@ def rfdetr_postproc_triton_kernel( x2_i = tl.where(((x2_r - x2) == 0.5) & ((x2_i & 1) != 0), x2_i - 1, x2_i) y2_i = tl.where(((y2_r - y2) == 0.5) & ((y2_i & 1) != 0), y2_i - 1, y2_i) - slot = tl.atomic_add(counter_ptr, 1) - - # Cap output at num_queries — mirrors the reference's flat top-K - # cap. Host slices to min(counter, num_queries) to ignore overflow. - if slot >= num_queries: - return - - # Bitcast conf (fp32) as int32 so the whole record writes with int32 - # stores. 
Host views the same memory as int32 and extracts via - # numpy.view(np.float32). - conf_bits = conf.to(tl.float32, bitcast=False) - conf_i32 = conf_bits.to(tl.int32, bitcast=True) - base = slot * 6 + conf_bits_out = selected_conf.to(tl.float32, bitcast=False).to( + tl.int32, bitcast=True + ) + base = pid_det * 6 tl.store(combined_out_ptr + base + 0, x1_i) tl.store(combined_out_ptr + base + 1, y1_i) tl.store(combined_out_ptr + base + 2, x2_i) tl.store(combined_out_ptr + base + 3, y2_i) - tl.store(combined_out_ptr + base + 4, conf_i32) - tl.store(combined_out_ptr + base + 5, top_c) - tl.store(survivor_idx_out_ptr + slot, pid_q.to(tl.int32)) + tl.store(combined_out_ptr + base + 4, conf_bits_out) + tl.store(combined_out_ptr + base + 5, selected_c) + + mask_base = masks_ptr + selected_q * masks_stride_q + if EMIT_RLE: + row_offsets = tl.arange(0, RLE_TILE_H) + col_offsets = tl.arange(0, RLE_TILE_W) + merge_offsets = tl.arange(0, RLE_MERGE_TILE) + lengths_row_ptr = rle_lengths_ptr + pid_det * rle_lengths_stride_q + counts_stride_col = orig_h + 1 + + for out_x in tl.range(0, orig_w, RLE_TILE_W, num_stages=1): + x = out_x + col_offsets + x_mask = x < orig_w + x_table_offset = x * 2 + x_idx_a = tl.load( + x_indices_ptr + x_table_offset + 0, mask=x_mask, other=0 + ) + x_idx_b = tl.load( + x_indices_ptr + x_table_offset + 1, mask=x_mask, other=0 + ) + wx_a = tl.load( + x_weights_ptr + x_table_offset + 0, mask=x_mask, other=0.0 + ) + wx_b = tl.load( + x_weights_ptr + x_table_offset + 1, mask=x_mask, other=0.0 + ) + col_base = ( + rle_counts_ptr + pid_det * rle_counts_stride_q + x * counts_stride_col + ) + run_length_vec = tl.zeros((RLE_TILE_W,), dtype=tl.int32) + prev_value_vec = tl.zeros((RLE_TILE_W,), dtype=tl.int32) + counts_idx_vec = tl.zeros((RLE_TILE_W,), dtype=tl.int32) + + for out_y in tl.range(0, orig_h, RLE_TILE_H, num_stages=1): + y = out_y + row_offsets + y_mask = y < orig_h + tile_mask = y_mask[:, None] & x_mask[None, :] + y_table_offset = y * 2 + y_idx_a = 
tl.load( + y_indices_ptr + y_table_offset + 0, mask=y_mask, other=0 + ) + y_idx_b = tl.load( + y_indices_ptr + y_table_offset + 1, mask=y_mask, other=0 + ) + wy_a = tl.load( + y_weights_ptr + y_table_offset + 0, mask=y_mask, other=0.0 + ) + wy_b = tl.load( + y_weights_ptr + y_table_offset + 1, mask=y_mask, other=0.0 + ) + + m00 = tl.load( + mask_base + + y_idx_a[:, None] * masks_stride_h + + x_idx_a[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + m01 = tl.load( + mask_base + + y_idx_a[:, None] * masks_stride_h + + x_idx_b[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + m10 = tl.load( + mask_base + + y_idx_b[:, None] * masks_stride_h + + x_idx_a[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + m11 = tl.load( + mask_base + + y_idx_b[:, None] * masks_stride_h + + x_idx_b[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + bits = ( + ( + (wy_a[:, None] * wx_a[None, :]) * m00 + + (wy_a[:, None] * wx_b[None, :]) * m01 + + (wy_b[:, None] * wx_a[None, :]) * m10 + + (wy_b[:, None] * wx_b[None, :]) * m11 + ) + > 0.0 + ).to(tl.int32) + + for local_y in tl.static_range(0, RLE_TILE_H): + valid_row = out_y + local_y < orig_h + row_mask = (row_offsets[:, None] == local_y).to(tl.int32) + bit_row = tl.sum(bits * row_mask, axis=0) + update_mask = x_mask & valid_row + change_row = update_mask & (bit_row != prev_value_vec) + tl.store(col_base + counts_idx_vec, run_length_vec, mask=change_row) + counts_idx_vec += change_row.to(tl.int32) + prev_value_vec = tl.where(update_mask, bit_row, prev_value_vec) + run_length_vec = tl.where( + update_mask, + tl.where(change_row, 1, run_length_vec + 1), + run_length_vec, + ) + + tl.store(col_base + counts_idx_vec, run_length_vec, mask=x_mask) + tl.store( + lengths_row_ptr + x, + counts_idx_vec + 1, + mask=x_mask, + ) + + final_len = tl.full((), 0, tl.int32) + prev_end_value = tl.full((), 0, tl.int32) + final_counts_ptr = rle_counts_ptr + pid_det * rle_counts_stride_q + + for out_x in 
tl.range(0, orig_w, RLE_TILE_W, num_stages=1): + x = out_x + col_offsets + x_mask = x < orig_w + col_lengths = tl.load(lengths_row_ptr + x, mask=x_mask, other=0) + + for local_x in tl.static_range(0, RLE_TILE_W): + col_x = out_x + local_x + valid_col = col_x < orig_w + col_mask = (col_offsets == local_x).to(tl.int32) + col_len = tl.sum(col_lengths * col_mask, axis=0) + col_counts_ptr = ( + rle_counts_ptr + + pid_det * rle_counts_stride_q + + col_x * counts_stride_col + ) + + is_first_col = final_len == 0 + merge_with_zero = (~is_first_col) & (prev_end_value == 0) + do_merge = valid_col & merge_with_zero + first_count = tl.load(col_counts_ptr + 0, mask=valid_col, other=0) + prev_last = tl.load( + final_counts_ptr + final_len - 1, + mask=do_merge, + other=0, + ) + tl.store( + final_counts_ptr + final_len - 1, + prev_last + first_count, + mask=do_merge, + ) + src_start = tl.where(merge_with_zero, 1, 0) + dst_start = tl.where(is_first_col, 0, final_len) + copy_len = tl.where(merge_with_zero, col_len - 1, col_len) + + for merge_off in tl.range( + 0, counts_stride_col, RLE_MERGE_TILE, num_stages=1 + ): + idx = merge_off + merge_offsets + copy_mask = valid_col & (idx < copy_len) + vals = tl.load( + col_counts_ptr + src_start + idx, + mask=copy_mask, + other=0, + ) + tl.store( + final_counts_ptr + dst_start + idx, + vals, + mask=copy_mask, + ) + + updated_final_len = tl.where( + is_first_col, + copy_len, + final_len + copy_len, + ) + final_len = tl.where(valid_col, updated_final_len, final_len) + prev_end_value = tl.where(valid_col, (col_len - 1) & 1, prev_end_value) + + tl.store(lengths_row_ptr + 0, final_len) + else: + row_offsets = tl.arange(0, MASK_TILE_H) + col_offsets = tl.arange(0, MASK_TILE_W) + + for out_y in tl.range(0, orig_h, MASK_TILE_H, num_stages=1): + y = out_y + row_offsets + y_mask = y < orig_h + y_table_offset = y * 2 + y_idx_a = tl.load( + y_indices_ptr + y_table_offset + 0, mask=y_mask, other=0 + ) + y_idx_b = tl.load( + y_indices_ptr + y_table_offset 
+ 1, mask=y_mask, other=0 + ) + wy_a = tl.load( + y_weights_ptr + y_table_offset + 0, mask=y_mask, other=0.0 + ) + wy_b = tl.load( + y_weights_ptr + y_table_offset + 1, mask=y_mask, other=0.0 + ) + + for out_x in tl.range(0, orig_w, MASK_TILE_W, num_stages=1): + x = out_x + col_offsets + x_mask = x < orig_w + tile_mask = y_mask[:, None] & x_mask[None, :] + x_table_offset = x * 2 + x_idx_a = tl.load( + x_indices_ptr + x_table_offset + 0, mask=x_mask, other=0 + ) + x_idx_b = tl.load( + x_indices_ptr + x_table_offset + 1, mask=x_mask, other=0 + ) + wx_a = tl.load( + x_weights_ptr + x_table_offset + 0, mask=x_mask, other=0.0 + ) + wx_b = tl.load( + x_weights_ptr + x_table_offset + 1, mask=x_mask, other=0.0 + ) + + m00 = tl.load( + mask_base + + y_idx_a[:, None] * masks_stride_h + + x_idx_a[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + m01 = tl.load( + mask_base + + y_idx_a[:, None] * masks_stride_h + + x_idx_b[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + m10 = tl.load( + mask_base + + y_idx_b[:, None] * masks_stride_h + + x_idx_a[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + m11 = tl.load( + mask_base + + y_idx_b[:, None] * masks_stride_h + + x_idx_b[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + interp = ( + (wy_a[:, None] * wx_a[None, :]) * m00 + + (wy_a[:, None] * wx_b[None, :]) * m01 + + (wy_b[:, None] * wx_a[None, :]) * m10 + + (wy_b[:, None] * wx_b[None, :]) * m11 + ) + out_ptr = ( + mask_out_ptr + + pid_det * mask_out_stride_q + + y[:, None] * mask_out_stride_h + + x[None, :] * mask_out_stride_w + ) + tl.store(out_ptr, (interp > 0.0).to(tl.uint8), mask=tile_mask) _THRESHOLD_CACHE: dict = {} _EMPTY_INT32 = torch.empty((1,), dtype=torch.int32) _SCRATCH_CACHE: dict = {} _CLASS_MAPPING_INT32_CACHE: dict = {} +_AA_RESIZE_CACHE: dict = {} +_RLE_SCRATCH_CACHE: dict = {} + + +def _build_resize_axis_tables( + in_size: int, + out_size: int, + device: torch.device, + horizontal: bool, +) -> 
Tuple[torch.Tensor, torch.Tensor]: + """Returns 2-tap resize tables without invoking CUDA bootstrap kernels.""" + + del horizontal + coords = torch.arange(out_size, dtype=torch.float64) + scale = torch.tensor( + float(in_size) / float(out_size), dtype=torch.float32 + ).item() + src = (coords + 0.5) * scale - 0.5 + src.clamp_(0.0, float(in_size - 1)) + lo = torch.floor(src).to(torch.int32) + hi = torch.clamp(lo + 1, max=in_size - 1) + frac = src - lo.to(torch.float64) + w_lo = (1.0 - frac).to(torch.float32) + w_hi = frac.to(torch.float32) + indices = torch.stack((lo, hi), dim=1).contiguous() + weights = torch.stack((w_lo, w_hi), dim=1).contiguous() + return indices.to(device=device, non_blocking=True), weights.to( + device=device, non_blocking=True + ) -def _get_scratch_buffers(num_queries: int, device: torch.device): - key = (num_queries, device) - cached = _SCRATCH_CACHE.get(key) - if cached is None: +def _get_resize_tables( + orig_h: int, + orig_w: int, + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + cached = _AA_RESIZE_CACHE.get(device) + shape = (orig_h, orig_w) + if cached is None or cached[0] != shape: + y_indices, y_weights = _build_resize_axis_tables( + in_size=FASTPATH_MASK_H, + out_size=orig_h, + device=device, + horizontal=False, + ) + x_indices, x_weights = _build_resize_axis_tables( + in_size=FASTPATH_MASK_W, + out_size=orig_w, + device=device, + horizontal=True, + ) + cached = (shape, y_indices, y_weights, x_indices, x_weights) + _AA_RESIZE_CACHE[device] = cached + _, y_indices, y_weights, x_indices, x_weights = cached + return y_indices, y_weights, x_indices, x_weights + + +def _get_scratch_buffers( + num_queries: int, + orig_h: int, + orig_w: int, + device: torch.device, +): + cached = _SCRATCH_CACHE.get(device) + shape = (num_queries, orig_h, orig_w) + if cached is None or cached[0] != shape: combined = torch.empty((num_queries, 6), dtype=torch.int32, device=device) - survivor_idx = 
torch.empty((num_queries,), dtype=torch.int32, device=device) + mask_bin = torch.empty( + (num_queries, orig_h, orig_w), dtype=torch.uint8, device=device + ) counter = torch.zeros((1,), dtype=torch.int32, device=device) - cached = (combined, survivor_idx, counter) - _SCRATCH_CACHE[key] = cached - return cached - - -def _get_class_mapping_int32(class_mapping: torch.Tensor, device: torch.device) -> torch.Tensor: - if class_mapping.dtype == torch.int32 and class_mapping.device == device and class_mapping.is_contiguous(): + cached = (shape, combined, mask_bin, counter) + _SCRATCH_CACHE[device] = cached + _, combined, mask_bin, counter = cached + return combined, mask_bin, counter + + +def _get_rle_buffers( + num_queries: int, + orig_h: int, + orig_w: int, + device: torch.device, +): + cached = _RLE_SCRATCH_CACHE.get(device) + max_counts = orig_w * (orig_h + 1) + shape = (num_queries, max_counts, orig_w) + if cached is None or cached[0] != shape: + counts = torch.empty((num_queries, max_counts), dtype=torch.int32, device=device) + lengths = torch.empty((num_queries, orig_w), dtype=torch.int32, device=device) + cached = (shape, counts, lengths) + _RLE_SCRATCH_CACHE[device] = cached + _, counts, lengths = cached + return counts, lengths + + +def _get_class_mapping_int32( + class_mapping: torch.Tensor, device: torch.device +) -> torch.Tensor: + if ( + class_mapping.dtype == torch.int32 + and class_mapping.device == device + and class_mapping.is_contiguous() + ): return class_mapping key = (id(class_mapping), device) cached = _CLASS_MAPPING_INT32_CACHE.get(key) @@ -190,10 +586,14 @@ def _get_class_mapping_int32(class_mapping: torch.Tensor, device: torch.device) def _prepare_threshold(threshold, device: torch.device, num_classes: int): if isinstance(threshold, torch.Tensor): - t = threshold - if t.dtype != torch.float32 or t.device != device or not t.is_contiguous(): - t = t.to(dtype=torch.float32, device=device).contiguous() - return t, True + tensor = threshold + if ( + 
tensor.dtype != torch.float32 + or tensor.device != device + or not tensor.is_contiguous() + ): + tensor = tensor.to(dtype=torch.float32, device=device).contiguous() + return tensor, True key = (float(threshold), device) cached = _THRESHOLD_CACHE.get(key) if cached is None: @@ -205,39 +605,77 @@ def _prepare_threshold(threshold, device: torch.device, num_classes: int): def rfdetr_triton_postproc( bboxes: torch.Tensor, logits: torch.Tensor, + masks: torch.Tensor, threshold: "torch.Tensor | float", num_classes: int, class_mapping: Optional[torch.Tensor], inference_size_wh: Tuple[int, int], - pad_ltrb: Tuple[int, int, int, int], scale_wh: Tuple[float, float], orig_size_wh: Tuple[int, int], + emit_rle: bool = False, ) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, "torch.cuda.Event", + Optional[torch.Tensor], + Optional[torch.Tensor], ]: - """Filter step only — returns (combined, survivor_idx, counter, done_event). - - Buffers are unsliced — the caller waits on ``done_event``, reads - ``counter`` to learn n_survivors, and slices to ``[:n_survivors]``. - ``combined[:, 4]`` holds fp32 conf as int32 bits; reinterpret on the - host with ``.view(torch.float32)``. The mask upsample is intentionally - left to the caller (torch ``F.interpolate``) so the result is bit-exact - with the non-postproc reference path. + """Returns fast-path scratch buffers and completion event. + + ``combined`` is ``(Q, 6)`` int32 where column 4 is fp32 confidence bits. + ``mask_bin`` is ``(Q, H, W)`` uint8 whose bytes are reinterpreted as bool + on the host without an extra copy. ``counter`` stores the number of kept + detections from the reference flat top-k output. When ``emit_rle`` is true, + ``rle_counts`` and ``rle_lengths`` hold COCO-style uncompressed run-length + counts for each surviving detection. 
""" + device = bboxes.device num_queries, num_classes_total = logits.shape[1], logits.shape[2] + mask_h, mask_w = masks.shape[2], masks.shape[3] + if ( + num_queries != FASTPATH_NUM_QUERIES + or num_classes_total != FASTPATH_NUM_CLASSES_TOTAL + or mask_h != FASTPATH_MASK_H + or mask_w != FASTPATH_MASK_W + ): + raise ValueError( + "RF-DETR Triton fullpost fast path only supports the fixed TRT " + f"shape (Q={FASTPATH_NUM_QUERIES}, C={FASTPATH_NUM_CLASSES_TOTAL}, " + f"Mh={FASTPATH_MASK_H}, Mw={FASTPATH_MASK_W}), got " + f"{(num_queries, num_classes_total, mask_h, mask_w)}." + ) logits_2d = logits[0] if logits[0].is_contiguous() else logits[0].contiguous() bboxes_2d = bboxes[0] if bboxes[0].is_contiguous() else bboxes[0].contiguous() + masks_3d = masks[0] if masks[0].is_contiguous() else masks[0].contiguous() - combined, survivor_idx, counter = _get_scratch_buffers(num_queries, device) + orig_w, orig_h = orig_size_wh + combined, mask_bin, counter = _get_scratch_buffers( + num_queries=num_queries, + orig_h=orig_h, + orig_w=orig_w, + device=device, + ) + y_indices, y_weights, x_indices, x_weights = _get_resize_tables( + orig_h=orig_h, + orig_w=orig_w, + device=device, + ) counter.zero_() + if emit_rle: + rle_counts, rle_lengths_scratch = _get_rle_buffers( + num_queries=num_queries, + orig_h=orig_h, + orig_w=orig_w, + device=device, + ) + else: + rle_counts = None + rle_lengths_scratch = None thr_tensor, per_class = _prepare_threshold(threshold, device, num_classes) - if class_mapping is not None: has_remap = True cmap = _get_class_mapping_int32(class_mapping, device) @@ -245,38 +683,56 @@ def rfdetr_triton_postproc( has_remap = False cmap = _EMPTY_INT32.to(device, non_blocking=True) - inf_w, inf_h = inference_size_wh - pad_l, pad_t, _, _ = pad_ltrb - sw, sh = scale_wh - orig_w, orig_h = orig_size_wh - per_class_constexpr = 1 if per_class else 0 - has_remap_constexpr = 1 if has_remap else 0 + _ = inference_size_wh + _ = scale_wh + dummy_int32 = _EMPTY_INT32.to(device, 
non_blocking=True) - rfdetr_postproc_triton_kernel[(num_queries, num_classes_total)]( + rfdetr_fullpostproc_triton_kernel[(num_queries,)]( logits_2d, bboxes_2d, + masks_3d, thr_tensor, cmap, + y_indices, + y_weights, + x_indices, + x_weights, + rle_counts if rle_counts is not None else dummy_int32, + rle_lengths_scratch if rle_lengths_scratch is not None else dummy_int32, combined, - survivor_idx, + mask_bin, counter, - num_queries, - num_classes_total, - int(inf_w), - int(inf_h), - int(pad_l), - int(pad_t), - float(1.0 / sw), - float(1.0 / sh), int(orig_w), int(orig_h), logits_2d.stride(0), bboxes_2d.stride(0), - PER_CLASS=tl.constexpr(per_class_constexpr), - HAS_REMAPPING=tl.constexpr(has_remap_constexpr), + masks_3d.stride(0), + masks_3d.stride(1), + masks_3d.stride(2), + (rle_counts.stride(0) if rle_counts is not None else 1), + (rle_lengths_scratch.stride(0) if rle_lengths_scratch is not None else 1), + mask_bin.stride(0), + mask_bin.stride(1), + mask_bin.stride(2), + PER_CLASS=1 if per_class else 0, + HAS_REMAPPING=1 if has_remap else 0, + EMIT_RLE=1 if emit_rle else 0, + NUM_QUERIES=FASTPATH_NUM_QUERIES, + NUM_CLASSES_TOTAL=FASTPATH_NUM_CLASSES_TOTAL, + MASK_H=FASTPATH_MASK_H, + MASK_W=FASTPATH_MASK_W, + TOPK_PAD=_TOPK_PAD, + CLASS_BLOCK=_CLASS_BLOCK, + MASK_TILE_H=_MASK_TILE_H, + MASK_TILE_W=_MASK_TILE_W, + RLE_TILE_H=_RLE_TILE_H, + RLE_TILE_W=_RLE_TILE_W, + RLE_MERGE_TILE=_RLE_MERGE_TILE, + num_warps=4, + num_stages=1, ) done_event = torch.cuda.Event() done_event.record(torch.cuda.current_stream(device)) - - return combined, survivor_idx, counter, done_event + rle_lengths = rle_lengths_scratch[:, 0] if rle_lengths_scratch is not None else None + return combined, mask_bin, counter, done_event, rle_counts, rle_lengths diff --git a/inference_models/tests/unit_tests/models/common/test_rle_utils.py b/inference_models/tests/unit_tests/models/common/test_rle_utils.py index 53d9dafa18..cc268adaf3 100644 --- 
a/inference_models/tests/unit_tests/models/common/test_rle_utils.py +++ b/inference_models/tests/unit_tests/models/common/test_rle_utils.py @@ -6,6 +6,7 @@ from inference_models.models.base.types import InstancesRLEMasks from inference_models.models.common.rle_utils import ( + LazyInstancesRLEMasks, coco_rle_masks_to_numpy_mask, coco_rle_masks_to_torch_mask, torch_mask_to_coco_rle, @@ -30,6 +31,25 @@ def _rle_from_tensors(masks: List[torch.Tensor]) -> InstancesRLEMasks: return InstancesRLEMasks.from_coco_rle_masks(image_size=(h, w), masks=encoded) +def _uncompressed_counts_from_mask(mask: torch.Tensor) -> np.ndarray: + mask_flat = np.ravel(mask.detach().cpu().numpy().astype(bool), order="F") + if mask_flat.size == 0: + return np.empty((0,), dtype=np.int32) + transitions = np.flatnonzero(mask_flat[1:] != mask_flat[:-1]) + 1 + counts = np.diff( + np.concatenate( + ( + np.array([0], dtype=np.int64), + transitions.astype(np.int64, copy=False), + np.array([mask_flat.size], dtype=np.int64), + ) + ) + ).astype(np.int32, copy=False) + if mask_flat[0]: + counts = np.concatenate((np.array([0], dtype=np.int32), counts)) + return counts + + def test_torch_mask_to_coco_rle_returns_dict_with_size_and_counts() -> None: mask = _make_rectangle_mask(20, 30, (5, 10, 15, 25)) rle = torch_mask_to_coco_rle(mask) @@ -307,3 +327,54 @@ def test_roundtrip_many_masks_preserves_all() -> None: assert decoded.shape == (n, h, w) for i, m in enumerate(masks): assert torch.equal(decoded[i], m) + + +def test_lazy_instances_rle_masks_materializes_from_cpu_counts() -> None: + masks = [ + _make_rectangle_mask(24, 32, (2, 3, 18, 20)), + _make_rectangle_mask(24, 32, (8, 10, 20, 28)), + ] + counts = [_uncompressed_counts_from_mask(mask) for mask in masks] + lengths = np.array([len(c) for c in counts], dtype=np.int32) + counts_cpu = np.zeros((len(counts), int(lengths.max())), dtype=np.int32) + for i, count in enumerate(counts): + counts_cpu[i, : len(count)] = count + + wrapped = LazyInstancesRLEMasks( + 
image_size=(24, 32), + rle_counts_cpu=counts_cpu, + rle_lengths_cpu=lengths, + ) + + decoded = coco_rle_masks_to_torch_mask(wrapped) + for i, mask in enumerate(masks): + assert torch.equal(decoded[i], mask) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_lazy_instances_rle_masks_materializes_from_cuda_counts() -> None: + masks = [ + _make_rectangle_mask(20, 20, (0, 0, 8, 8)), + _make_rectangle_mask(20, 20, (10, 10, 20, 20)), + ] + counts = [_uncompressed_counts_from_mask(mask) for mask in masks] + lengths = np.array([len(c) for c in counts], dtype=np.int32) + counts_cpu = np.zeros((len(counts), int(lengths.max())), dtype=np.int32) + for i, count in enumerate(counts): + counts_cpu[i, : len(count)] = count + + counts_gpu = torch.from_numpy(counts_cpu).to(device="cuda", dtype=torch.int32) + lengths_gpu = torch.from_numpy(lengths).to(device="cuda", dtype=torch.int32) + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + + wrapped = LazyInstancesRLEMasks( + image_size=(20, 20), + rle_counts_gpu=counts_gpu, + rle_lengths_gpu=lengths_gpu, + done_event=event, + ) + + decoded = coco_rle_masks_to_torch_mask(wrapped) + for i, mask in enumerate(masks): + assert torch.equal(decoded[i], mask) From d3030bfbbd539c1027f1168871a19c35df542116 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 05:20:31 +0000 Subject: [PATCH 10/25] Stop forwarding response mask format in adapter --- inference/core/models/inference_models_adapters.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index 177b9b18ab..01382c69bb 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -326,8 +326,6 @@ def postprocess( ) -> List[InstanceSegmentationInferenceResponse]: return_in_rle = kwargs.get("response_mask_format") == "rle" mapped_kwargs = 
self.map_inference_kwargs(kwargs) - if "response_mask_format" in kwargs: - mapped_kwargs["response_mask_format"] = kwargs["response_mask_format"] detections_list = self._model.post_process( predictions, preprocess_return_metadata, **mapped_kwargs ) From 2ffb62be7f43fa2fefe2c793d9eb346ade505ed6 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 05:48:33 +0000 Subject: [PATCH 11/25] Pack fused detections into one host copy --- .../core/models/inference_models_adapters.py | 64 ++++++++++++------- .../inference_models/models/rfdetr/common.py | 1 + 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index 01382c69bb..b3722f1098 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -335,6 +335,7 @@ def postprocess( H = preproc_metadata.original_size.height W = preproc_metadata.original_size.width + combined_gpu = getattr(det, "_combined_gpu", None) mask_gpu = getattr(det, "_mask_gpu", None) mask_cpu = getattr(det, "_mask_cpu", None) done_event = getattr(det, "_postproc_done_event", None) @@ -343,12 +344,6 @@ def postprocess( and done_event is not None and isinstance(mask_gpu, torch.Tensor) and mask_gpu.is_cuda - and isinstance(det.xyxy, torch.Tensor) - and det.xyxy.is_cuda - and isinstance(det.confidence, torch.Tensor) - and det.confidence.is_cuda - and isinstance(det.class_id, torch.Tensor) - and det.class_id.is_cuda ): device = mask_gpu.device stream = torch.cuda.current_stream(device) @@ -362,27 +357,48 @@ def postprocess( polys_or_rles = [] else: mask_slice = mask_gpu[:n_survivors] - xyxy_host = get_pinned_buffer( - "xyxy", tuple(det.xyxy.shape), det.xyxy.dtype - ) - conf_host = get_pinned_buffer( - "conf", tuple(det.confidence.shape), det.confidence.dtype - ) - class_host = get_pinned_buffer( - "class_id", tuple(det.class_id.shape), det.class_id.dtype - ) mask_host = 
get_pinned_buffer( "mask", tuple(mask_slice.shape), mask_slice.dtype ) - xyxy_host.copy_(det.xyxy, non_blocking=True) - conf_host.copy_(det.confidence, non_blocking=True) - class_host.copy_(det.class_id, non_blocking=True) - mask_host.copy_(mask_slice, non_blocking=True) - stream.synchronize() - xyxy = xyxy_host.numpy() - confs = conf_host.numpy() - class_ids = class_host.numpy() - polys_or_rles = masks2poly(mask_host.numpy()) + if ( + isinstance(combined_gpu, torch.Tensor) + and combined_gpu.is_cuda + and tuple(combined_gpu.shape) + == (n_survivors, det.xyxy.shape[1] + 2) + ): + combined_slice = combined_gpu[:n_survivors] + combined_host = get_pinned_buffer( + "combined", + tuple(combined_slice.shape), + combined_slice.dtype, + ) + combined_host.copy_(combined_slice, non_blocking=True) + mask_host.copy_(mask_slice, non_blocking=True) + stream.synchronize() + combined_np = combined_host.numpy() + xyxy = combined_np[:, :4] + confs = combined_np[:, 4].view(np.float32) + class_ids = combined_np[:, 5] + polys_or_rles = masks2poly(mask_host.numpy()) + else: + xyxy_host = get_pinned_buffer( + "xyxy", tuple(det.xyxy.shape), det.xyxy.dtype + ) + conf_host = get_pinned_buffer( + "conf", tuple(det.confidence.shape), det.confidence.dtype + ) + class_host = get_pinned_buffer( + "class_id", tuple(det.class_id.shape), det.class_id.dtype + ) + xyxy_host.copy_(det.xyxy, non_blocking=True) + conf_host.copy_(det.confidence, non_blocking=True) + class_host.copy_(det.class_id, non_blocking=True) + mask_host.copy_(mask_slice, non_blocking=True) + stream.synchronize() + xyxy = xyxy_host.numpy() + confs = conf_host.numpy() + class_ids = class_host.numpy() + polys_or_rles = masks2poly(mask_host.numpy()) elif not return_in_rle and isinstance(mask_cpu, np.ndarray): xyxy = det.xyxy.detach().cpu().numpy() confs = det.confidence.detach().cpu().numpy() diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index 
a6f322b550..4944995ea5 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -385,6 +385,7 @@ def post_process_instance_segmentation_results_to_rle_masks( mask=instances_masks, ) if not emit_in_kernel_rle: + detections.__dict__["_combined_gpu"] = combined_slice detections.__dict__["_mask_gpu"] = mask_bin[:n_survivors].view(torch.bool) detections.__dict__["_postproc_done_event"] = done_event return [detections] From 3330429385cd4b46ab76938116da4c247bce874b Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 05:57:21 +0000 Subject: [PATCH 12/25] Drop redundant Triton postproc counter reset --- .../inference_models/models/rfdetr/triton_fullpostproc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index 584139ebee..5ae30e8013 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -541,7 +541,7 @@ def _get_scratch_buffers( mask_bin = torch.empty( (num_queries, orig_h, orig_w), dtype=torch.uint8, device=device ) - counter = torch.zeros((1,), dtype=torch.int32, device=device) + counter = torch.empty((1,), dtype=torch.int32, device=device) cached = (shape, combined, mask_bin, counter) _SCRATCH_CACHE[device] = cached _, combined, mask_bin, counter = cached @@ -663,7 +663,6 @@ def rfdetr_triton_postproc( orig_w=orig_w, device=device, ) - counter.zero_() if emit_rle: rle_counts, rle_lengths_scratch = _get_rle_buffers( num_queries=num_queries, From 9bfb5b15417e4fb38da40709481efc9403b80d99 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 06:23:26 +0000 Subject: [PATCH 13/25] Defer RF-DETR survivor count to adapter --- .../core/models/inference_models_adapters.py | 160 +++++++++++------- 
.../inference_models/models/rfdetr/common.py | 82 ++++++--- .../rfdetr_instance_segmentation_onnx.py | 1 + .../rfdetr_instance_segmentation_pytorch.py | 1 + .../rfdetr_instance_segmentation_trt.py | 1 + .../models/rfdetr/triton_fullpostproc.py | 42 +++-- 6 files changed, 196 insertions(+), 91 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index b3722f1098..befe7aac4d 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -9,21 +9,6 @@ from PIL import Image, ImageDraw, ImageFont from pycocotools import mask as mask_utils -# Pinned host buffers for async DtoH on the full-postproc Triton fast path. -# Keyed by (name, dtype); reused across frames provided the cached buffer is -# at least as large as the requested shape in every dimension. -PINNED_HOST_BUFFERS: dict = {} - - -def get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor: - key = (name, dtype) - buf = PINNED_HOST_BUFFERS.get(key) - if buf is not None and all(buf.shape[i] >= shape[i] for i in range(len(shape))): - return buf[tuple(slice(0, s) for s in shape)] - buf = torch.empty(shape, dtype=dtype, pin_memory=True) - PINNED_HOST_BUFFERS[key] = buf - return buf - from inference.core.entities.requests import ( ClassificationInferenceRequest, InferenceRequest, @@ -102,6 +87,21 @@ def get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor: "#FF39C9", ] +# Pinned host buffers for async DtoH on the full-postproc Triton fast path. +# Keyed by (name, dtype); reused across frames provided the cached buffer is +# at least as large as the requested shape in every dimension. 
+PINNED_HOST_BUFFERS: dict = {} + + +def get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor: + key = (name, dtype) + buf = PINNED_HOST_BUFFERS.get(key) + if buf is not None and all(buf.shape[i] >= shape[i] for i in range(len(shape))): + return buf[tuple(slice(0, s) for s in shape)] + buf = torch.empty(shape, dtype=dtype, pin_memory=True) + PINNED_HOST_BUFFERS[key] = buf + return buf + class InferenceModelsObjectDetectionAdapter(Model): def __init__(self, model_id: str, api_key: str = None, **kwargs): @@ -326,6 +326,7 @@ def postprocess( ) -> List[InstanceSegmentationInferenceResponse]: return_in_rle = kwargs.get("response_mask_format") == "rle" mapped_kwargs = self.map_inference_kwargs(kwargs) + mapped_kwargs["defer_count_to_adapter"] = not return_in_rle detections_list = self._model.post_process( predictions, preprocess_return_metadata, **mapped_kwargs ) @@ -338,6 +339,7 @@ def postprocess( combined_gpu = getattr(det, "_combined_gpu", None) mask_gpu = getattr(det, "_mask_gpu", None) mask_cpu = getattr(det, "_mask_cpu", None) + defer_count_to_adapter = getattr(det, "_defer_count_to_adapter", False) done_event = getattr(det, "_postproc_done_event", None) if ( not return_in_rle @@ -348,57 +350,99 @@ def postprocess( device = mask_gpu.device stream = torch.cuda.current_stream(device) done_event.wait(stream) - n_survivors = int(det.xyxy.shape[0]) - if n_survivors == 0: - xyxy = np.empty((0, 4), dtype=np.int32) - confs = np.empty((0,), dtype=np.float32) - class_ids = np.empty((0,), dtype=np.int32) - polys_or_rles = [] - else: - mask_slice = mask_gpu[:n_survivors] - mask_host = get_pinned_buffer( - "mask", tuple(mask_slice.shape), mask_slice.dtype + if ( + defer_count_to_adapter + and isinstance(combined_gpu, torch.Tensor) + and combined_gpu.is_cuda + ): + combined_host = get_pinned_buffer( + "combined_full", + tuple(combined_gpu.shape), + combined_gpu.dtype, + ) + combined_host.copy_(combined_gpu, non_blocking=True) + stream.synchronize() + 
combined_np = combined_host.numpy() + class_column = combined_np[:, 5] + inactive_indices = np.flatnonzero(class_column < 0) + n_survivors = ( + int(inactive_indices[0]) + if inactive_indices.size > 0 + else int(class_column.shape[0]) ) - if ( - isinstance(combined_gpu, torch.Tensor) - and combined_gpu.is_cuda - and tuple(combined_gpu.shape) - == (n_survivors, det.xyxy.shape[1] + 2) - ): - combined_slice = combined_gpu[:n_survivors] - combined_host = get_pinned_buffer( - "combined", - tuple(combined_slice.shape), - combined_slice.dtype, + if n_survivors == 0: + xyxy = np.empty((0, 4), dtype=np.int32) + confs = np.empty((0,), dtype=np.float32) + class_ids = np.empty((0,), dtype=np.int32) + polys_or_rles = [] + else: + combined_slice = combined_np[:n_survivors] + mask_slice = mask_gpu[:n_survivors] + mask_host = get_pinned_buffer( + "mask", tuple(mask_slice.shape), mask_slice.dtype ) - combined_host.copy_(combined_slice, non_blocking=True) mask_host.copy_(mask_slice, non_blocking=True) stream.synchronize() - combined_np = combined_host.numpy() - xyxy = combined_np[:, :4] - confs = combined_np[:, 4].view(np.float32) - class_ids = combined_np[:, 5] + xyxy = combined_slice[:, :4] + confs = combined_slice[:, 4].view(np.float32) + class_ids = combined_slice[:, 5] polys_or_rles = masks2poly(mask_host.numpy()) + else: + n_survivors = int(det.xyxy.shape[0]) + if n_survivors == 0: + xyxy = np.empty((0, 4), dtype=np.int32) + confs = np.empty((0,), dtype=np.float32) + class_ids = np.empty((0,), dtype=np.int32) + polys_or_rles = [] else: - xyxy_host = get_pinned_buffer( - "xyxy", tuple(det.xyxy.shape), det.xyxy.dtype + mask_slice = mask_gpu[:n_survivors] + mask_host = get_pinned_buffer( + "mask", tuple(mask_slice.shape), mask_slice.dtype ) - conf_host = get_pinned_buffer( - "conf", tuple(det.confidence.shape), det.confidence.dtype - ) - class_host = get_pinned_buffer( - "class_id", tuple(det.class_id.shape), det.class_id.dtype - ) - xyxy_host.copy_(det.xyxy, non_blocking=True) - 
conf_host.copy_(det.confidence, non_blocking=True) - class_host.copy_(det.class_id, non_blocking=True) - mask_host.copy_(mask_slice, non_blocking=True) - stream.synchronize() - xyxy = xyxy_host.numpy() - confs = conf_host.numpy() - class_ids = class_host.numpy() - polys_or_rles = masks2poly(mask_host.numpy()) + if ( + isinstance(combined_gpu, torch.Tensor) + and combined_gpu.is_cuda + and tuple(combined_gpu.shape) + == (n_survivors, det.xyxy.shape[1] + 2) + ): + combined_slice = combined_gpu[:n_survivors] + combined_host = get_pinned_buffer( + "combined", + tuple(combined_slice.shape), + combined_slice.dtype, + ) + combined_host.copy_(combined_slice, non_blocking=True) + mask_host.copy_(mask_slice, non_blocking=True) + stream.synchronize() + combined_np = combined_host.numpy() + xyxy = combined_np[:, :4] + confs = combined_np[:, 4].view(np.float32) + class_ids = combined_np[:, 5] + polys_or_rles = masks2poly(mask_host.numpy()) + else: + xyxy_host = get_pinned_buffer( + "xyxy", tuple(det.xyxy.shape), det.xyxy.dtype + ) + conf_host = get_pinned_buffer( + "conf", + tuple(det.confidence.shape), + det.confidence.dtype, + ) + class_host = get_pinned_buffer( + "class_id", + tuple(det.class_id.shape), + det.class_id.dtype, + ) + xyxy_host.copy_(det.xyxy, non_blocking=True) + conf_host.copy_(det.confidence, non_blocking=True) + class_host.copy_(det.class_id, non_blocking=True) + mask_host.copy_(mask_slice, non_blocking=True) + stream.synchronize() + xyxy = xyxy_host.numpy() + confs = conf_host.numpy() + class_ids = class_host.numpy() + polys_or_rles = masks2poly(mask_host.numpy()) elif not return_in_rle and isinstance(mask_cpu, np.ndarray): xyxy = det.xyxy.detach().cpu().numpy() confs = det.confidence.detach().cpu().numpy() diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index 4944995ea5..3553ce15d3 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ 
b/inference_models/inference_models/models/rfdetr/common.py @@ -28,7 +28,10 @@ TRITON_AVAILABLE as _TRITON_POSTPROC_AVAILABLE, rfdetr_triton_postproc, ) - _TRITON_POSTPROC_READY = _TRITON_POSTPROC_AVAILABLE and torch.cuda.is_available() + + _TRITON_POSTPROC_READY = ( + _TRITON_POSTPROC_AVAILABLE and torch.cuda.is_available() + ) except Exception: _TRITON_POSTPROC_READY = False rfdetr_triton_postproc = None @@ -71,7 +74,12 @@ def post_triton_eligible( return False if meta.static_crop_offset.offset_x != 0 or meta.static_crop_offset.offset_y != 0: return False - if meta.pad_left != 0 or meta.pad_top != 0 or meta.pad_right != 0 or meta.pad_bottom != 0: + if ( + meta.pad_left != 0 + or meta.pad_top != 0 + or meta.pad_right != 0 + or meta.pad_bottom != 0 + ): return False if ( meta.size_after_pre_processing.height < FASTPATH_MASK_H @@ -321,32 +329,56 @@ def post_process_instance_segmentation_results_to_rle_masks( num_classes: int, classes_re_mapping: Optional[ClassesReMapping], emit_in_kernel_rle: bool = False, + defer_count_to_adapter: bool = False, ) -> List[InstanceDetections]: if post_triton_eligible( bboxes, logits, masks, pre_processing_meta, classes_re_mapping ): meta = pre_processing_meta[0] thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) - combined, mask_bin, counter, done_event, rle_counts, rle_lengths = rfdetr_triton_postproc( - bboxes=bboxes, - logits=logits, - masks=masks, - threshold=thr_arg, - num_classes=num_classes, - class_mapping=classes_re_mapping.class_mapping, - inference_size_wh=(meta.inference_size.width, meta.inference_size.height), - scale_wh=(meta.scale_width, meta.scale_height), - orig_size_wh=(meta.original_size.width, meta.original_size.height), - emit_rle=emit_in_kernel_rle, + combined, mask_bin, counter, done_event, rle_counts, rle_lengths = ( + rfdetr_triton_postproc( + bboxes=bboxes, + logits=logits, + masks=masks, + threshold=thr_arg, + num_classes=num_classes, + 
class_mapping=classes_re_mapping.class_mapping, + inference_size_wh=( + meta.inference_size.width, + meta.inference_size.height, + ), + scale_wh=(meta.scale_width, meta.scale_height), + orig_size_wh=(meta.original_size.width, meta.original_size.height), + emit_rle=emit_in_kernel_rle, + ) ) done_event.wait(torch.cuda.current_stream(bboxes.device)) - n_survivors = int(counter.item()) orig_h = meta.original_size.height orig_w = meta.original_size.width - if n_survivors == 0: - empty_xyxy = torch.empty( - (0, 4), dtype=torch.int32, device=bboxes.device + if defer_count_to_adapter and not emit_in_kernel_rle: + empty_xyxy = torch.empty((0, 4), dtype=torch.int32, device=bboxes.device) + empty_conf = torch.empty((0,), dtype=torch.float32, device=bboxes.device) + empty_cls = torch.empty((0,), dtype=torch.int32, device=bboxes.device) + detections = InstanceDetections( + xyxy=empty_xyxy, + confidence=empty_conf, + class_id=empty_cls, + mask=LazyInstancesRLEMasks( + image_size=(orig_h, orig_w), + mask_gpu=mask_bin.view(torch.bool), + done_event=done_event, + ), ) + detections.__dict__["_combined_gpu"] = combined + detections.__dict__["_mask_gpu"] = mask_bin.view(torch.bool) + detections.__dict__["_defer_count_to_adapter"] = True + detections.__dict__["_postproc_done_event"] = done_event + return [detections] + + n_survivors = int(counter.item()) + if n_survivors == 0: + empty_xyxy = torch.empty((0, 4), dtype=torch.int32, device=bboxes.device) empty_conf = torch.empty((0,), dtype=torch.float32, device=bboxes.device) empty_cls = torch.empty((0,), dtype=torch.int32, device=bboxes.device) return [ @@ -367,12 +399,16 @@ def post_process_instance_segmentation_results_to_rle_masks( if not emit_in_kernel_rle else None ), - rle_counts_gpu=rle_counts[:n_survivors] - if emit_in_kernel_rle and rle_counts is not None - else None, - rle_lengths_gpu=rle_lengths[:n_survivors] - if emit_in_kernel_rle and rle_lengths is not None - else None, + rle_counts_gpu=( + rle_counts[:n_survivors] + if 
emit_in_kernel_rle and rle_counts is not None + else None + ), + rle_lengths_gpu=( + rle_lengths[:n_survivors] + if emit_in_kernel_rle and rle_lengths is not None + else None + ), done_event=done_event, ) xyxy = combined_slice[:, :4] diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py index 2f93cccbc0..04ffd14309 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py @@ -264,5 +264,6 @@ def post_process( num_classes=len(self.class_names), classes_re_mapping=self._classes_re_mapping, emit_in_kernel_rle=kwargs.get("response_mask_format") == "rle", + defer_count_to_adapter=kwargs.get("defer_count_to_adapter", False), ) return results diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py index 920ca7cf59..c00caf1107 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py @@ -470,5 +470,6 @@ def post_process( num_classes=len(self.class_names), classes_re_mapping=self._classes_re_mapping, emit_in_kernel_rle=kwargs.get("response_mask_format") == "rle", + defer_count_to_adapter=kwargs.get("defer_count_to_adapter", False), ) return results diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py index 903a6ee6d0..f4725b1e59 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py @@ -319,6 
+319,7 @@ def post_process( num_classes=len(self.class_names), classes_re_mapping=self._classes_re_mapping, emit_in_kernel_rle=kwargs.get("response_mask_format") == "rle", + defer_count_to_adapter=kwargs.get("defer_count_to_adapter", False), ) self._post_process_stream.synchronize() return results diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index 5ae30e8013..7cb069e748 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -18,6 +18,7 @@ Host-side work is limited to slicing the preallocated buffers once the kernel completes and wrapping them in ``InstanceDetections`` / RLE containers. """ + from typing import Optional, Tuple import torch @@ -167,8 +168,15 @@ def rfdetr_fullpostproc_triton_kernel( if pid_det == 0: tl.store(counter_ptr, keep_count) + base = pid_det * 6 active = pid_det < keep_count if not active: + tl.store(combined_out_ptr + base + 0, 0) + tl.store(combined_out_ptr + base + 1, 0) + tl.store(combined_out_ptr + base + 2, 0) + tl.store(combined_out_ptr + base + 3, 0) + tl.store(combined_out_ptr + base + 4, 0) + tl.store(combined_out_ptr + base + 5, -1) return cx_pct = tl.load(bboxes_ptr + selected_q * bboxes_stride_q + 0) @@ -205,7 +213,6 @@ def rfdetr_fullpostproc_triton_kernel( conf_bits_out = selected_conf.to(tl.float32, bitcast=False).to( tl.int32, bitcast=True ) - base = pid_det * 6 tl.store(combined_out_ptr + base + 0, x1_i) tl.store(combined_out_ptr + base + 1, y1_i) tl.store(combined_out_ptr + base + 2, x2_i) @@ -238,7 +245,9 @@ def rfdetr_fullpostproc_triton_kernel( x_weights_ptr + x_table_offset + 1, mask=x_mask, other=0.0 ) col_base = ( - rle_counts_ptr + pid_det * rle_counts_stride_q + x * counts_stride_col + rle_counts_ptr + + pid_det * rle_counts_stride_q + + x * counts_stride_col ) run_length_vec = 
tl.zeros((RLE_TILE_W,), dtype=tl.int32) prev_value_vec = tl.zeros((RLE_TILE_W,), dtype=tl.int32) @@ -306,7 +315,9 @@ def rfdetr_fullpostproc_triton_kernel( bit_row = tl.sum(bits * row_mask, axis=0) update_mask = x_mask & valid_row change_row = update_mask & (bit_row != prev_value_vec) - tl.store(col_base + counts_idx_vec, run_length_vec, mask=change_row) + tl.store( + col_base + counts_idx_vec, run_length_vec, mask=change_row + ) counts_idx_vec += change_row.to(tl.int32) prev_value_vec = tl.where(update_mask, bit_row, prev_value_vec) run_length_vec = tl.where( @@ -382,7 +393,9 @@ def rfdetr_fullpostproc_triton_kernel( final_len + copy_len, ) final_len = tl.where(valid_col, updated_final_len, final_len) - prev_end_value = tl.where(valid_col, (col_len - 1) & 1, prev_end_value) + prev_end_value = tl.where( + valid_col, (col_len - 1) & 1, prev_end_value + ) tl.store(lengths_row_ptr + 0, final_len) else: @@ -469,6 +482,7 @@ def rfdetr_fullpostproc_triton_kernel( _THRESHOLD_CACHE: dict = {} _EMPTY_INT32 = torch.empty((1,), dtype=torch.int32) +_EMPTY_INT32_DEVICE_CACHE: dict = {} _SCRATCH_CACHE: dict = {} _CLASS_MAPPING_INT32_CACHE: dict = {} _AA_RESIZE_CACHE: dict = {} @@ -485,9 +499,7 @@ def _build_resize_axis_tables( del horizontal coords = torch.arange(out_size, dtype=torch.float64) - scale = torch.tensor( - float(in_size) / float(out_size), dtype=torch.float32 - ).item() + scale = torch.tensor(float(in_size) / float(out_size), dtype=torch.float32).item() src = (coords + 0.5) * scale - 0.5 src.clamp_(0.0, float(in_size - 1)) lo = torch.floor(src).to(torch.int32) @@ -558,7 +570,9 @@ def _get_rle_buffers( max_counts = orig_w * (orig_h + 1) shape = (num_queries, max_counts, orig_w) if cached is None or cached[0] != shape: - counts = torch.empty((num_queries, max_counts), dtype=torch.int32, device=device) + counts = torch.empty( + (num_queries, max_counts), dtype=torch.int32, device=device + ) lengths = torch.empty((num_queries, orig_w), dtype=torch.int32, device=device) 
cached = (shape, counts, lengths) _RLE_SCRATCH_CACHE[device] = cached @@ -602,6 +616,14 @@ def _prepare_threshold(threshold, device: torch.device, num_classes: int): return cached, False +def _get_empty_int32_on_device(device: torch.device) -> torch.Tensor: + cached = _EMPTY_INT32_DEVICE_CACHE.get(device) + if cached is None: + cached = torch.empty((1,), dtype=torch.int32, device=device) + _EMPTY_INT32_DEVICE_CACHE[device] = cached + return cached + + def rfdetr_triton_postproc( bboxes: torch.Tensor, logits: torch.Tensor, @@ -680,11 +702,11 @@ def rfdetr_triton_postproc( cmap = _get_class_mapping_int32(class_mapping, device) else: has_remap = False - cmap = _EMPTY_INT32.to(device, non_blocking=True) + cmap = _get_empty_int32_on_device(device) _ = inference_size_wh _ = scale_wh - dummy_int32 = _EMPTY_INT32.to(device, non_blocking=True) + dummy_int32 = _get_empty_int32_on_device(device) rfdetr_fullpostproc_triton_kernel[(num_queries,)]( logits_2d, From 90cde1607468bf515f04a66d7590daa7baba40ee Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 06:48:42 +0000 Subject: [PATCH 14/25] Bit-pack RF-DETR workflow mask transfer --- .../core/models/inference_models_adapters.py | 39 ++++++++--- inference/core/utils/postprocess.py | 17 +++++ .../models/common/rle_utils.py | 29 ++++++++ .../inference_models/models/rfdetr/common.py | 6 +- .../models/rfdetr/triton_fullpostproc.py | 70 ++++++++++++++----- .../models/common/test_rle_utils.py | 37 ++++++++++ 6 files changed, 169 insertions(+), 29 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index befe7aac4d..a09f2665a1 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -42,7 +42,7 @@ from inference.core.models.base import Model from inference.core.roboflow_api import get_extra_weights_provider_headers from inference.core.utils.image_utils import load_image_bgr, 
load_image_rgb -from inference.core.utils.postprocess import mask2poly, masks2poly +from inference.core.utils.postprocess import bitpacked_masks2poly, mask2poly, masks2poly from inference.core.utils.visualisation import draw_detection_predictions from inference.models.aliases import resolve_roboflow_model_alias from inference_models import ( @@ -338,16 +338,20 @@ def postprocess( combined_gpu = getattr(det, "_combined_gpu", None) mask_gpu = getattr(det, "_mask_gpu", None) + mask_packed_gpu = getattr(det, "_mask_packed_gpu", None) mask_cpu = getattr(det, "_mask_cpu", None) defer_count_to_adapter = getattr(det, "_defer_count_to_adapter", False) done_event = getattr(det, "_postproc_done_event", None) + dense_mask_cuda = isinstance(mask_gpu, torch.Tensor) and mask_gpu.is_cuda + packed_mask_cuda = ( + isinstance(mask_packed_gpu, torch.Tensor) and mask_packed_gpu.is_cuda + ) if ( not return_in_rle and done_event is not None - and isinstance(mask_gpu, torch.Tensor) - and mask_gpu.is_cuda + and (dense_mask_cuda or packed_mask_cuda) ): - device = mask_gpu.device + device = mask_gpu.device if dense_mask_cuda else mask_packed_gpu.device stream = torch.cuda.current_stream(device) done_event.wait(stream) @@ -378,16 +382,29 @@ def postprocess( polys_or_rles = [] else: combined_slice = combined_np[:n_survivors] - mask_slice = mask_gpu[:n_survivors] - mask_host = get_pinned_buffer( - "mask", tuple(mask_slice.shape), mask_slice.dtype - ) - mask_host.copy_(mask_slice, non_blocking=True) - stream.synchronize() xyxy = combined_slice[:, :4] confs = combined_slice[:, 4].view(np.float32) class_ids = combined_slice[:, 5] - polys_or_rles = masks2poly(mask_host.numpy()) + if packed_mask_cuda: + packed_slice = mask_packed_gpu[:n_survivors] + packed_host = get_pinned_buffer( + "mask_packed", + tuple(packed_slice.shape), + packed_slice.dtype, + ) + packed_host.copy_(packed_slice, non_blocking=True) + stream.synchronize() + polys_or_rles = bitpacked_masks2poly( + packed_host.numpy(), width=W + 
) + else: + mask_slice = mask_gpu[:n_survivors] + mask_host = get_pinned_buffer( + "mask", tuple(mask_slice.shape), mask_slice.dtype + ) + mask_host.copy_(mask_slice, non_blocking=True) + stream.synchronize() + polys_or_rles = masks2poly(mask_host.numpy()) else: n_survivors = int(det.xyxy.shape[0]) if n_survivors == 0: diff --git a/inference/core/utils/postprocess.py b/inference/core/utils/postprocess.py index 0ceafb8e75..c32ede44cd 100644 --- a/inference/core/utils/postprocess.py +++ b/inference/core/utils/postprocess.py @@ -61,6 +61,23 @@ def masks2poly(masks: np.ndarray) -> List[np.ndarray]: return segments +def bitpacked_masks2poly(bitpacked_masks: np.ndarray, width: int) -> List[np.ndarray]: + """Convert bit-packed masks with 8 pixels per byte into polygons.""" + segments = [] + for packed_mask in bitpacked_masks: + packed = ( + packed_mask + if packed_mask.flags.c_contiguous + else np.ascontiguousarray(packed_mask) + ) + unpacked = np.unpackbits(packed, axis=-1, bitorder="little")[..., :width] + if not np.any(unpacked): + segments.append(np.zeros((0, 2), dtype=np.float32)) + continue + segments.append(mask2poly(unpacked)) + return segments + + def masks2multipoly(masks: np.ndarray) -> List[np.ndarray]: """Converts binary masks to polygonal segments. diff --git a/inference_models/inference_models/models/common/rle_utils.py b/inference_models/inference_models/models/common/rle_utils.py index 5348e01a14..dc6970a720 100644 --- a/inference_models/inference_models/models/common/rle_utils.py +++ b/inference_models/inference_models/models/common/rle_utils.py @@ -45,6 +45,16 @@ def numpy_mask_to_coco_rle(mask: np.ndarray) -> dict: return counts_to_coco_rle(counts=counts, image_size=tuple(mask_bool.shape)) +def unpack_bitpacked_masks_numpy(bitpacked_masks: np.ndarray, width: int) -> np.ndarray: + packed = np.asarray(bitpacked_masks, dtype=np.uint8) + if packed.ndim != 3: + raise ValueError( + f"Expected bitpacked masks with shape (N, H, Wbytes), got {packed.shape}." 
+ ) + unpacked = np.unpackbits(np.ascontiguousarray(packed), axis=-1, bitorder="little") + return np.ascontiguousarray(unpacked[..., :width]) + + class LazyInstancesRLEMasks(InstancesRLEMasks): """Materializes COCO RLE counts only when a caller actually needs them.""" @@ -52,6 +62,8 @@ def __init__( self, image_size: tuple, mask_gpu: Optional[torch.Tensor] = None, + mask_packed_gpu: Optional[torch.Tensor] = None, + mask_packed_width: Optional[int] = None, mask_cpu: Optional[np.ndarray] = None, rle_counts_gpu: Optional[torch.Tensor] = None, rle_lengths_gpu: Optional[torch.Tensor] = None, @@ -63,6 +75,8 @@ def __init__( self._masks: list = [] self._materialized = False self._mask_gpu = mask_gpu + self._mask_packed_gpu = mask_packed_gpu + self._mask_packed_width = mask_packed_width self._mask_cpu = mask_cpu self._rle_counts_gpu = rle_counts_gpu self._rle_lengths_gpu = rle_lengths_gpu @@ -83,6 +97,21 @@ def masks(self, value: list) -> None: def _ensure_mask_cpu(self) -> np.ndarray: if self._mask_cpu is not None: return self._mask_cpu + if self._mask_packed_gpu is not None: + device = self._mask_packed_gpu.device + stream = torch.cuda.current_stream(device) + if self._done_event is not None: + self._done_event.wait(stream) + packed_cpu = self._mask_packed_gpu.cpu().numpy() + width = ( + self._mask_packed_width + if self._mask_packed_width is not None + else self.image_size[1] + ) + self._mask_cpu = unpack_bitpacked_masks_numpy(packed_cpu, width=width).view( + np.bool_ + ) + return self._mask_cpu if self._mask_gpu is None: self._mask_cpu = np.empty( (0, self.image_size[0], self.image_size[1]), dtype=bool diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index 3553ce15d3..a1814f642f 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -351,6 +351,7 @@ def 
post_process_instance_segmentation_results_to_rle_masks( scale_wh=(meta.scale_width, meta.scale_height), orig_size_wh=(meta.original_size.width, meta.original_size.height), emit_rle=emit_in_kernel_rle, + pack_dense_masks=defer_count_to_adapter and not emit_in_kernel_rle, ) ) done_event.wait(torch.cuda.current_stream(bboxes.device)) @@ -366,12 +367,13 @@ def post_process_instance_segmentation_results_to_rle_masks( class_id=empty_cls, mask=LazyInstancesRLEMasks( image_size=(orig_h, orig_w), - mask_gpu=mask_bin.view(torch.bool), + mask_packed_gpu=mask_bin, + mask_packed_width=orig_w, done_event=done_event, ), ) detections.__dict__["_combined_gpu"] = combined - detections.__dict__["_mask_gpu"] = mask_bin.view(torch.bool) + detections.__dict__["_mask_packed_gpu"] = mask_bin detections.__dict__["_defer_count_to_adapter"] = True detections.__dict__["_postproc_done_event"] = done_event return [detections] diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index 7cb069e748..ad504b36dd 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -11,6 +11,7 @@ - applies class remap + confidence filtering - denormalizes / rescales / rounds the selected box - resizes the selected 78x78 mask to the original image size and thresholds it +- optionally bit-packs the dense mask sidecar for lower DtoH transfer volume The resize uses cached 2-tap closed-form bilinear tables, so the per-inference hot path remains a single Triton launch without any CUDA bootstrap probes. 
@@ -81,6 +82,7 @@ def rfdetr_fullpostproc_triton_kernel( PER_CLASS: tl.constexpr, HAS_REMAPPING: tl.constexpr, EMIT_RLE: tl.constexpr, + PACK_DENSE_MASKS: tl.constexpr, NUM_QUERIES: tl.constexpr, NUM_CLASSES_TOTAL: tl.constexpr, MASK_H: tl.constexpr, @@ -401,6 +403,9 @@ def rfdetr_fullpostproc_triton_kernel( else: row_offsets = tl.arange(0, MASK_TILE_H) col_offsets = tl.arange(0, MASK_TILE_W) + if PACK_DENSE_MASKS: + packed_col_offsets = tl.arange(0, MASK_TILE_W // 8) + bit_weights = (1 << tl.arange(0, 8)).to(tl.int32) for out_y in tl.range(0, orig_h, MASK_TILE_H, num_stages=1): y = out_y + row_offsets @@ -471,13 +476,40 @@ def rfdetr_fullpostproc_triton_kernel( + (wy_b[:, None] * wx_a[None, :]) * m10 + (wy_b[:, None] * wx_b[None, :]) * m11 ) - out_ptr = ( - mask_out_ptr - + pid_det * mask_out_stride_q - + y[:, None] * mask_out_stride_h - + x[None, :] * mask_out_stride_w - ) - tl.store(out_ptr, (interp > 0.0).to(tl.uint8), mask=tile_mask) + bits = (interp > 0.0).to(tl.int32) + if PACK_DENSE_MASKS: + packed = tl.sum( + tl.reshape(bits, (MASK_TILE_H, MASK_TILE_W // 8, 8)) + * bit_weights[None, None, :], + axis=2, + ).to(tl.uint8) + byte_mask = ( + tl.sum( + tl.reshape(x_mask.to(tl.int32), (MASK_TILE_W // 8, 8)), + axis=1, + ) + > 0 + ) + out_ptr = ( + mask_out_ptr + + pid_det * mask_out_stride_q + + y[:, None] * mask_out_stride_h + + (out_x // 8 + packed_col_offsets)[None, :] + * mask_out_stride_w + ) + tl.store( + out_ptr, + packed, + mask=y_mask[:, None] & byte_mask[None, :], + ) + else: + out_ptr = ( + mask_out_ptr + + pid_det * mask_out_stride_q + + y[:, None] * mask_out_stride_h + + x[None, :] * mask_out_stride_w + ) + tl.store(out_ptr, bits.to(tl.uint8), mask=tile_mask) _THRESHOLD_CACHE: dict = {} @@ -545,17 +577,20 @@ def _get_scratch_buffers( orig_h: int, orig_w: int, device: torch.device, + pack_dense_masks: bool, ): - cached = _SCRATCH_CACHE.get(device) - shape = (num_queries, orig_h, orig_w) + key = (device, pack_dense_masks) + cached = 
_SCRATCH_CACHE.get(key) + mask_w = (orig_w + 7) // 8 if pack_dense_masks else orig_w + shape = (num_queries, orig_h, mask_w) if cached is None or cached[0] != shape: combined = torch.empty((num_queries, 6), dtype=torch.int32, device=device) mask_bin = torch.empty( - (num_queries, orig_h, orig_w), dtype=torch.uint8, device=device + (num_queries, orig_h, mask_w), dtype=torch.uint8, device=device ) counter = torch.empty((1,), dtype=torch.int32, device=device) cached = (shape, combined, mask_bin, counter) - _SCRATCH_CACHE[device] = cached + _SCRATCH_CACHE[key] = cached _, combined, mask_bin, counter = cached return combined, mask_bin, counter @@ -635,6 +670,7 @@ def rfdetr_triton_postproc( scale_wh: Tuple[float, float], orig_size_wh: Tuple[int, int], emit_rle: bool = False, + pack_dense_masks: bool = False, ) -> Tuple[ torch.Tensor, torch.Tensor, @@ -646,11 +682,11 @@ def rfdetr_triton_postproc( """Returns fast-path scratch buffers and completion event. ``combined`` is ``(Q, 6)`` int32 where column 4 is fp32 confidence bits. - ``mask_bin`` is ``(Q, H, W)`` uint8 whose bytes are reinterpreted as bool - on the host without an extra copy. ``counter`` stores the number of kept - detections from the reference flat top-k output. When ``emit_rle`` is true, - ``rle_counts`` and ``rle_lengths`` hold COCO-style uncompressed run-length - counts for each surviving detection. + ``mask_bin`` is uint8 scratch: either ``(Q, H, W)`` dense bytes or + ``(Q, H, ceil(W / 8))`` bit-packed bytes when ``pack_dense_masks`` is true. + ``counter`` stores the number of kept detections from the reference flat + top-k output. When ``emit_rle`` is true, ``rle_counts`` and ``rle_lengths`` + hold COCO-style uncompressed run-length counts for each surviving detection. 
""" device = bboxes.device @@ -679,6 +715,7 @@ def rfdetr_triton_postproc( orig_h=orig_h, orig_w=orig_w, device=device, + pack_dense_masks=pack_dense_masks and not emit_rle, ) y_indices, y_weights, x_indices, x_weights = _get_resize_tables( orig_h=orig_h, @@ -738,6 +775,7 @@ def rfdetr_triton_postproc( PER_CLASS=1 if per_class else 0, HAS_REMAPPING=1 if has_remap else 0, EMIT_RLE=1 if emit_rle else 0, + PACK_DENSE_MASKS=1 if (pack_dense_masks and not emit_rle) else 0, NUM_QUERIES=FASTPATH_NUM_QUERIES, NUM_CLASSES_TOTAL=FASTPATH_NUM_CLASSES_TOTAL, MASK_H=FASTPATH_MASK_H, diff --git a/inference_models/tests/unit_tests/models/common/test_rle_utils.py b/inference_models/tests/unit_tests/models/common/test_rle_utils.py index cc268adaf3..64292c9827 100644 --- a/inference_models/tests/unit_tests/models/common/test_rle_utils.py +++ b/inference_models/tests/unit_tests/models/common/test_rle_utils.py @@ -10,6 +10,7 @@ coco_rle_masks_to_numpy_mask, coco_rle_masks_to_torch_mask, torch_mask_to_coco_rle, + unpack_bitpacked_masks_numpy, ) @@ -329,6 +330,15 @@ def test_roundtrip_many_masks_preserves_all() -> None: assert torch.equal(decoded[i], m) +@pytest.mark.parametrize("h,w,n", [(1, 1, 1), (7, 13, 3), (20, 32, 2)]) +def test_unpack_bitpacked_masks_numpy_roundtrip(h: int, w: int, n: int) -> None: + rng = np.random.default_rng(seed=1234) + masks = rng.integers(0, 2, size=(n, h, w), dtype=np.uint8) + packed = np.packbits(masks, axis=-1, bitorder="little") + unpacked = unpack_bitpacked_masks_numpy(packed, width=w) + np.testing.assert_array_equal(unpacked, masks) + + def test_lazy_instances_rle_masks_materializes_from_cpu_counts() -> None: masks = [ _make_rectangle_mask(24, 32, (2, 3, 18, 20)), @@ -378,3 +388,30 @@ def test_lazy_instances_rle_masks_materializes_from_cuda_counts() -> None: decoded = coco_rle_masks_to_torch_mask(wrapped) for i, mask in enumerate(masks): assert torch.equal(decoded[i], mask) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA 
required") +def test_lazy_instances_rle_masks_materializes_from_cuda_bitpacked_masks() -> None: + masks = torch.stack( + [ + _make_rectangle_mask(21, 29, (2, 3, 18, 20)), + _make_rectangle_mask(21, 29, (8, 10, 20, 28)), + ], + dim=0, + ) + packed = np.packbits(masks.numpy().astype(np.uint8), axis=-1, bitorder="little") + packed_gpu = torch.from_numpy(np.ascontiguousarray(packed)).to( + device="cuda", dtype=torch.uint8 + ) + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + + wrapped = LazyInstancesRLEMasks( + image_size=(21, 29), + mask_packed_gpu=packed_gpu, + mask_packed_width=29, + done_event=event, + ) + + decoded = coco_rle_masks_to_torch_mask(wrapped) + assert torch.equal(decoded.cpu(), masks) From 0af52c838510f6873ab2f02923f6529935d93f7f Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 16:20:45 +0000 Subject: [PATCH 15/25] restore parity --- .../models/rfdetr/triton_fullpostproc.py | 93 ++++++++++++------- 1 file changed, 61 insertions(+), 32 deletions(-) diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index ad504b36dd..e98566f9ed 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -301,13 +301,10 @@ def rfdetr_fullpostproc_triton_kernel( mask=tile_mask, other=0.0, ) + interp_top = wx_a[None, :] * m00 + wx_b[None, :] * m01 + interp_bottom = wx_a[None, :] * m10 + wx_b[None, :] * m11 bits = ( - ( - (wy_a[:, None] * wx_a[None, :]) * m00 - + (wy_a[:, None] * wx_b[None, :]) * m01 - + (wy_b[:, None] * wx_a[None, :]) * m10 - + (wy_b[:, None] * wx_b[None, :]) * m11 - ) + (wy_a[:, None] * interp_top + wy_b[:, None] * interp_bottom) > 0.0 ).to(tl.int32) @@ -470,12 +467,9 @@ def rfdetr_fullpostproc_triton_kernel( mask=tile_mask, other=0.0, ) - interp = ( - (wy_a[:, None] * wx_a[None, :]) * m00 - + (wy_a[:, None] 
* wx_b[None, :]) * m01 - + (wy_b[:, None] * wx_a[None, :]) * m10 - + (wy_b[:, None] * wx_b[None, :]) * m11 - ) + interp_top = wx_a[None, :] * m00 + wx_b[None, :] * m01 + interp_bottom = wx_a[None, :] * m10 + wx_b[None, :] * m11 + interp = wy_a[:, None] * interp_top + wy_b[:, None] * interp_bottom bits = (interp > 0.0).to(tl.int32) if PACK_DENSE_MASKS: packed = tl.sum( @@ -527,20 +521,55 @@ def _build_resize_axis_tables( device: torch.device, horizontal: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Returns 2-tap resize tables without invoking CUDA bootstrap kernels.""" - - del horizontal - coords = torch.arange(out_size, dtype=torch.float64) - scale = torch.tensor(float(in_size) / float(out_size), dtype=torch.float32).item() - src = (coords + 0.5) * scale - 0.5 - src.clamp_(0.0, float(in_size - 1)) - lo = torch.floor(src).to(torch.int32) - hi = torch.clamp(lo + 1, max=in_size - 1) - frac = src - lo.to(torch.float64) - w_lo = (1.0 - frac).to(torch.float32) - w_hi = frac.to(torch.float32) - indices = torch.stack((lo, hi), dim=1).contiguous() - weights = torch.stack((w_lo, w_hi), dim=1).contiguous() + """Returns exact 2-tap tables extracted from the reference CUDA resize op.""" + + basis = torch.eye(in_size, dtype=torch.float32, device=device) + if horizontal: + resized = torch.nn.functional.interpolate( + basis[:, None, None, :], + size=(1, out_size), + mode="bilinear", + align_corners=False, + antialias=True, + )[:, 0, 0, :] + else: + resized = torch.nn.functional.interpolate( + basis[:, None, :, None], + size=(out_size, 1), + mode="bilinear", + align_corners=False, + antialias=True, + )[:, 0, :, 0] + + resized_cpu = resized.cpu() + indices = torch.empty((out_size, 2), dtype=torch.int32) + weights = torch.zeros((out_size, 2), dtype=torch.float32) + for out_idx in range(out_size): + support = torch.nonzero( + resized_cpu[:, out_idx].abs() > 0, as_tuple=False + ).flatten() + if support.numel() == 0: + raise ValueError( + f"Reference bilinear AA resize produced no 
support for axis " + f"{out_idx} of shape {in_size}->{out_size}." + ) + if support.numel() > 2: + raise ValueError( + "RF-DETR Triton fullpost fast path only supports 2-tap " + "upsample resize tables, but the reference resize produced " + f"{support.numel()} taps for shape {in_size}->{out_size}." + ) + + idx_a = int(support[0]) + indices[out_idx, 0] = idx_a + weights[out_idx, 0] = resized_cpu[idx_a, out_idx] + if support.numel() == 1: + indices[out_idx, 1] = idx_a + weights[out_idx, 1] = 0.0 + else: + idx_b = int(support[1]) + indices[out_idx, 1] = idx_b + weights[out_idx, 1] = resized_cpu[idx_b, out_idx] return indices.to(device=device, non_blocking=True), weights.to( device=device, non_blocking=True ) @@ -551,9 +580,9 @@ def _get_resize_tables( orig_w: int, device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - cached = _AA_RESIZE_CACHE.get(device) - shape = (orig_h, orig_w) - if cached is None or cached[0] != shape: + key = (device, orig_h, orig_w) + cached = _AA_RESIZE_CACHE.get(key) + if cached is None: y_indices, y_weights = _build_resize_axis_tables( in_size=FASTPATH_MASK_H, out_size=orig_h, @@ -566,9 +595,9 @@ def _get_resize_tables( device=device, horizontal=True, ) - cached = (shape, y_indices, y_weights, x_indices, x_weights) - _AA_RESIZE_CACHE[device] = cached - _, y_indices, y_weights, x_indices, x_weights = cached + cached = (y_indices, y_weights, x_indices, x_weights) + _AA_RESIZE_CACHE[key] = cached + y_indices, y_weights, x_indices, x_weights = cached return y_indices, y_weights, x_indices, x_weights From bc84b9d0307580d29fcd246b5e428b8b141c69a3 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 16:38:37 +0000 Subject: [PATCH 16/25] Add RF-DETR postprocess microbench --- temp/postproc_microbench.py | 125 ++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 temp/postproc_microbench.py diff --git a/temp/postproc_microbench.py b/temp/postproc_microbench.py new 
file mode 100644 index 0000000000..9b857214ad --- /dev/null +++ b/temp/postproc_microbench.py @@ -0,0 +1,125 @@ +"""Isolated RF-DETR segmentation post_process() microbenchmark. + +Benchmarks only model.post_process() on frozen TensorRT outputs for a few +representative original-image sizes. Run once with +RFDETR_TRITON_POSTPROC=true and once with RFDETR_TRITON_POSTPROC=false. +""" + +import argparse +import os +import time +from pathlib import Path + +import cv2 +import torch + +os.environ.setdefault( + "DISABLED_INFERENCE_MODELS_BACKENDS", + "torch,torch-script,onnx,hugging-face,ultralytics,mediapipe,custom", +) + +from inference_models import AutoModel + + +DEFAULT_SIZES = ("176x312", "720x1280", "1080x1920") +DEFAULT_VIDEO = Path("/home/ubuntu/inference/vehicles_312px.mp4") + + +def _parse_hw(spec: str) -> tuple[int, int]: + try: + h, w = spec.lower().split("x", 1) + return int(h), int(w) + except ValueError as exc: + raise argparse.ArgumentTypeError(f"invalid size '{spec}', expected HxW") from exc + + +def _read_seed_frame(video_path: Path): + cap = cv2.VideoCapture(str(video_path)) + ok, frame = cap.read() + cap.release() + if not ok or frame is None: + raise RuntimeError(f"failed to read a frame from {video_path}") + return frame + + +def _sync_detection(det) -> None: + done_event = getattr(det, "_postproc_done_event", None) + if done_event is not None: + done_event.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + +def _prepare_case(model, frame, height: int, width: int): + resized = cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR) + with torch.inference_mode(): + preprocessed, metadata = model.pre_process(resized) + outputs = model.forward(preprocessed) + return outputs, metadata + + +def _benchmark_case( + model, + outputs, + metadata, + confidence: float, + warmup: int, + iterations: int, +): + with torch.inference_mode(): + for _ in range(warmup): + det = model.post_process(outputs, metadata, 
confidence=confidence)[0] + _sync_detection(det) + + start = time.perf_counter() + det_count = 0 + for _ in range(iterations): + det = model.post_process(outputs, metadata, confidence=confidence)[0] + _sync_detection(det) + det_count += int(det.class_id.numel()) + elapsed = time.perf_counter() - start + return (elapsed * 1000.0) / iterations, det_count // max(1, iterations) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--video_reference", type=Path, default=DEFAULT_VIDEO) + parser.add_argument("--model_id", default="rfdetr-seg-nano") + parser.add_argument("--confidence", type=float, default=0.4) + parser.add_argument("--warmup", type=int, default=25) + parser.add_argument("--iterations", type=int, default=200) + parser.add_argument("--size", dest="sizes", action="append") + args = parser.parse_args() + + specs = args.sizes if args.sizes else list(DEFAULT_SIZES) + sizes = [_parse_hw(spec) for spec in specs] + flag = os.environ.get("RFDETR_TRITON_POSTPROC", "") + + print( + f"[setup] RFDETR_TRITON_POSTPROC={flag} " + f"video={args.video_reference} warmup={args.warmup} iterations={args.iterations}", + flush=True, + ) + + frame = _read_seed_frame(args.video_reference) + model = AutoModel.from_pretrained(args.model_id) + cases = [] + for height, width in sizes: + outputs, metadata = _prepare_case(model, frame, height, width) + cases.append((height, width, outputs, metadata)) + + print("size,detections,mean_ms", flush=True) + for height, width, outputs, metadata in cases: + mean_ms, detections = _benchmark_case( + model=model, + outputs=outputs, + metadata=metadata, + confidence=args.confidence, + warmup=args.warmup, + iterations=args.iterations, + ) + print(f"{height}x{width},{detections},{mean_ms:.4f}", flush=True) + + +if __name__ == "__main__": + main() From fcd23f2a3542bde62b4ccce30f6152f8bd33bdbd Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 17:22:41 +0000 Subject: [PATCH 17/25] Use lru_cache for pure 
Triton postproc caches --- .../models/rfdetr/triton_fullpostproc.py | 62 ++++++++----------- 1 file changed, 27 insertions(+), 35 deletions(-) diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index e98566f9ed..f30ad895e1 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -13,13 +13,15 @@ - resizes the selected 78x78 mask to the original image size and thresholds it - optionally bit-packs the dense mask sidecar for lower DtoH transfer volume -The resize uses cached 2-tap closed-form bilinear tables, so the per-inference -hot path remains a single Triton launch without any CUDA bootstrap probes. +The resize uses cached exact 2-tap bilinear tables extracted once from the +reference CUDA resize operator, so the steady-state hot path remains a single +Triton launch. Host-side work is limited to slicing the preallocated buffers once the kernel completes and wrapping them in ``InstanceDetections`` / RLE containers. 
""" +from functools import lru_cache from typing import Optional, Tuple import torch @@ -506,12 +508,8 @@ def rfdetr_fullpostproc_triton_kernel( tl.store(out_ptr, bits.to(tl.uint8), mask=tile_mask) -_THRESHOLD_CACHE: dict = {} -_EMPTY_INT32 = torch.empty((1,), dtype=torch.int32) -_EMPTY_INT32_DEVICE_CACHE: dict = {} _SCRATCH_CACHE: dict = {} _CLASS_MAPPING_INT32_CACHE: dict = {} -_AA_RESIZE_CACHE: dict = {} _RLE_SCRATCH_CACHE: dict = {} @@ -575,29 +573,24 @@ def _build_resize_axis_tables( ) +@lru_cache(maxsize=128) def _get_resize_tables( orig_h: int, orig_w: int, device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - key = (device, orig_h, orig_w) - cached = _AA_RESIZE_CACHE.get(key) - if cached is None: - y_indices, y_weights = _build_resize_axis_tables( - in_size=FASTPATH_MASK_H, - out_size=orig_h, - device=device, - horizontal=False, - ) - x_indices, x_weights = _build_resize_axis_tables( - in_size=FASTPATH_MASK_W, - out_size=orig_w, - device=device, - horizontal=True, - ) - cached = (y_indices, y_weights, x_indices, x_weights) - _AA_RESIZE_CACHE[key] = cached - y_indices, y_weights, x_indices, x_weights = cached + y_indices, y_weights = _build_resize_axis_tables( + in_size=FASTPATH_MASK_H, + out_size=orig_h, + device=device, + horizontal=False, + ) + x_indices, x_weights = _build_resize_axis_tables( + in_size=FASTPATH_MASK_W, + out_size=orig_w, + device=device, + horizontal=True, + ) return y_indices, y_weights, x_indices, x_weights @@ -662,6 +655,13 @@ def _get_class_mapping_int32( return cached +@lru_cache(maxsize=32) +def _get_scalar_threshold_tensor( + threshold_value: float, device: torch.device +) -> torch.Tensor: + return torch.tensor([threshold_value], dtype=torch.float32, device=device) + + def _prepare_threshold(threshold, device: torch.device, num_classes: int): if isinstance(threshold, torch.Tensor): tensor = threshold @@ -672,20 +672,12 @@ def _prepare_threshold(threshold, device: torch.device, num_classes: 
int): ): tensor = tensor.to(dtype=torch.float32, device=device).contiguous() return tensor, True - key = (float(threshold), device) - cached = _THRESHOLD_CACHE.get(key) - if cached is None: - cached = torch.tensor([float(threshold)], dtype=torch.float32, device=device) - _THRESHOLD_CACHE[key] = cached - return cached, False + return _get_scalar_threshold_tensor(float(threshold), device), False +@lru_cache(maxsize=None) def _get_empty_int32_on_device(device: torch.device) -> torch.Tensor: - cached = _EMPTY_INT32_DEVICE_CACHE.get(device) - if cached is None: - cached = torch.empty((1,), dtype=torch.int32, device=device) - _EMPTY_INT32_DEVICE_CACHE[device] = cached - return cached + return torch.empty((1,), dtype=torch.int32, device=device) def rfdetr_triton_postproc( From 79bd5bd42e84155683cb45d200df4bdd7f466943 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 22:00:02 +0000 Subject: [PATCH 18/25] reduce preconditions on kernel eligibility --- .../inference_models/models/rfdetr/common.py | 114 +++- .../models/rfdetr/triton_fullpostproc.py | 607 ++++++++++++------ 2 files changed, 477 insertions(+), 244 deletions(-) diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index a1814f642f..d22619d6a1 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -21,12 +21,9 @@ if RFDETR_TRITON_POSTPROC: try: from inference_models.models.rfdetr.triton_fullpostproc import ( - FASTPATH_MASK_H, - FASTPATH_MASK_W, - FASTPATH_NUM_CLASSES_TOTAL, - FASTPATH_NUM_QUERIES, TRITON_AVAILABLE as _TRITON_POSTPROC_AVAILABLE, rfdetr_triton_postproc, + rfdetr_triton_postproc_geometry_supported, ) _TRITON_POSTPROC_READY = ( @@ -35,9 +32,11 @@ except Exception: _TRITON_POSTPROC_READY = False rfdetr_triton_postproc = None + rfdetr_triton_postproc_geometry_supported = None else: _TRITON_POSTPROC_READY = False 
rfdetr_triton_postproc = None + rfdetr_triton_postproc_geometry_supported = None def post_triton_eligible( @@ -49,6 +48,8 @@ def post_triton_eligible( ) -> bool: if not _TRITON_POSTPROC_READY: return False + if bboxes.ndim != 3 or logits.ndim != 3 or masks.ndim != 4: + return False if not bboxes.is_cuda or not logits.is_cuda or not masks.is_cuda: return False if bboxes.device != logits.device or bboxes.device != masks.device: @@ -60,35 +61,45 @@ def post_triton_eligible( or len(pre_processing_meta) != 1 ): return False - if ( - bboxes.shape[1] != FASTPATH_NUM_QUERIES - or logits.shape[1] != FASTPATH_NUM_QUERIES - or logits.shape[2] != FASTPATH_NUM_CLASSES_TOTAL - or masks.shape[1] != FASTPATH_NUM_QUERIES - or masks.shape[2] != FASTPATH_MASK_H - or masks.shape[3] != FASTPATH_MASK_W - ): - return False - meta = pre_processing_meta[0] - if meta.nonsquare_intermediate_size is not None: - return False - if meta.static_crop_offset.offset_x != 0 or meta.static_crop_offset.offset_y != 0: - return False - if ( - meta.pad_left != 0 - or meta.pad_top != 0 - or meta.pad_right != 0 - or meta.pad_bottom != 0 - ): + if bboxes.shape[2] != 4: return False + num_queries = bboxes.shape[1] + num_classes_total = logits.shape[2] + mask_h = masks.shape[2] + mask_w = masks.shape[3] if ( - meta.size_after_pre_processing.height < FASTPATH_MASK_H - or meta.size_after_pre_processing.width < FASTPATH_MASK_W + num_queries <= 0 + or num_classes_total <= 0 + or mask_h <= 0 + or mask_w <= 0 + or logits.shape[1] != num_queries + or masks.shape[1] != num_queries ): return False - if classes_re_mapping is None: + meta = pre_processing_meta[0] + if rfdetr_triton_postproc_geometry_supported is None: return False - return True + denorm_size = meta.nonsquare_intermediate_size or meta.inference_size + return rfdetr_triton_postproc_geometry_supported( + denorm_size_wh=(denorm_size.width, denorm_size.height), + pad_ltrb=( + meta.pad_left, + meta.pad_top, + meta.pad_right, + meta.pad_bottom, + ), + 
scale_wh=(meta.scale_width, meta.scale_height), + orig_size_wh=(meta.original_size.width, meta.original_size.height), + size_after_pre_processing_wh=( + meta.size_after_pre_processing.width, + meta.size_after_pre_processing.height, + ), + static_crop_offset_xy=( + meta.static_crop_offset.offset_x, + meta.static_crop_offset.offset_y, + ), + mask_size_hw=(mask_h, mask_w), + ) def parse_model_type(config_path: str) -> str: @@ -207,6 +218,7 @@ def post_process_instance_segmentation_results( bboxes, logits, masks, pre_processing_meta, classes_re_mapping ): meta = pre_processing_meta[0] + denorm_size = meta.nonsquare_intermediate_size or meta.inference_size thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) combined, mask_bin, counter, done_event, _, _ = rfdetr_triton_postproc( bboxes=bboxes, @@ -214,10 +226,28 @@ def post_process_instance_segmentation_results( masks=masks, threshold=thr_arg, num_classes=num_classes, - class_mapping=classes_re_mapping.class_mapping, - inference_size_wh=(meta.inference_size.width, meta.inference_size.height), + class_mapping=( + classes_re_mapping.class_mapping + if classes_re_mapping is not None + else None + ), + denorm_size_wh=(denorm_size.width, denorm_size.height), + pad_ltrb=( + meta.pad_left, + meta.pad_top, + meta.pad_right, + meta.pad_bottom, + ), scale_wh=(meta.scale_width, meta.scale_height), orig_size_wh=(meta.original_size.width, meta.original_size.height), + size_after_pre_processing_wh=( + meta.size_after_pre_processing.width, + meta.size_after_pre_processing.height, + ), + static_crop_offset_xy=( + meta.static_crop_offset.offset_x, + meta.static_crop_offset.offset_y, + ), ) done_event.wait(torch.cuda.current_stream(bboxes.device)) n_survivors = int(counter.item()) @@ -335,6 +365,7 @@ def post_process_instance_segmentation_results_to_rle_masks( bboxes, logits, masks, pre_processing_meta, classes_re_mapping ): meta = pre_processing_meta[0] + denorm_size = meta.nonsquare_intermediate_size or 
meta.inference_size thr_arg = threshold if isinstance(threshold, torch.Tensor) else float(threshold) combined, mask_bin, counter, done_event, rle_counts, rle_lengths = ( rfdetr_triton_postproc( @@ -343,13 +374,28 @@ def post_process_instance_segmentation_results_to_rle_masks( masks=masks, threshold=thr_arg, num_classes=num_classes, - class_mapping=classes_re_mapping.class_mapping, - inference_size_wh=( - meta.inference_size.width, - meta.inference_size.height, + class_mapping=( + classes_re_mapping.class_mapping + if classes_re_mapping is not None + else None + ), + denorm_size_wh=(denorm_size.width, denorm_size.height), + pad_ltrb=( + meta.pad_left, + meta.pad_top, + meta.pad_right, + meta.pad_bottom, ), scale_wh=(meta.scale_width, meta.scale_height), orig_size_wh=(meta.original_size.width, meta.original_size.height), + size_after_pre_processing_wh=( + meta.size_after_pre_processing.width, + meta.size_after_pre_processing.height, + ), + static_crop_offset_xy=( + meta.static_crop_offset.offset_x, + meta.static_crop_offset.offset_y, + ), emit_rle=emit_in_kernel_rle, pack_dense_masks=defer_count_to_adapter and not emit_in_kernel_rle, ) diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index f30ad895e1..bbc247f36f 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -2,25 +2,27 @@ Fast path scope: - batch size == 1 -- RF-DETR seg TRT tensor shapes: Q=100, C=91, Mh=Mw=78 -- no static crop, no letterbox padding, no nonsquare intermediate size -- output mask resize is an upsample (the common benchmark / parity case) +- arbitrary single-image RF-DETR-style output shapes; the kernel specializes + per observed ``(Q, C, Mh, Mw)`` +- positive / negative padding and static-crop geometry are supported Within one Triton launch, each program owns one output 
detection rank and: - performs flat top-k over the (Q, C) sigmoid grid - applies class remap + confidence filtering - denormalizes / rescales / rounds the selected box -- resizes the selected 78x78 mask to the original image size and thresholds it +- crops any padded mask border, resizes to the pre-resize image geometry, + and pastes into the original image canvas if static crop was applied - optionally bit-packs the dense mask sidecar for lower DtoH transfer volume -The resize uses cached exact 2-tap bilinear tables extracted once from the -reference CUDA resize operator, so the steady-state hot path remains a single -Triton launch. +The resize uses cached exact separable tables extracted once from the reference +CUDA resize operator, so the steady-state hot path remains a single Triton +launch. Host-side work is limited to slicing the preallocated buffers once the kernel completes and wrapping them in ``InstanceDetections`` / RLE containers. """ +from dataclasses import dataclass from functools import lru_cache from typing import Optional, Tuple @@ -51,6 +53,123 @@ _MAX_U32 = 0xFFFFFFFF +@dataclass(frozen=True) +class RFDETRTritonPostprocGeometry: + denorm_w: int + denorm_h: int + orig_w: int + orig_h: int + pad_left: int + pad_top: int + inv_scale_w: float + inv_scale_h: float + output_offset_x: int + output_offset_y: int + output_w: int + output_h: int + mask_offset_x: int + mask_offset_y: int + mask_input_w: int + mask_input_h: int + + +def get_rfdetr_triton_postproc_geometry( + denorm_size_wh: Tuple[int, int], + pad_ltrb: Tuple[int, int, int, int], + scale_wh: Tuple[float, float], + orig_size_wh: Tuple[int, int], + size_after_pre_processing_wh: Tuple[int, int], + static_crop_offset_xy: Tuple[int, int], + mask_size_hw: Tuple[int, int] = (FASTPATH_MASK_H, FASTPATH_MASK_W), +) -> RFDETRTritonPostprocGeometry: + denorm_w, denorm_h = denorm_size_wh + pad_left, pad_top, pad_right, pad_bottom = pad_ltrb + scale_w, scale_h = scale_wh + orig_w, orig_h = orig_size_wh + 
output_w, output_h = size_after_pre_processing_wh + output_offset_x, output_offset_y = static_crop_offset_xy + mask_h, mask_w = mask_size_hw + + if denorm_w <= 0 or denorm_h <= 0: + raise ValueError("Denorm size must be positive for Triton fullpost.") + if orig_w <= 0 or orig_h <= 0: + raise ValueError("Original image size must be positive for Triton fullpost.") + if output_w <= 0 or output_h <= 0: + raise ValueError( + "Pre-resize image geometry must be positive for Triton fullpost." + ) + if mask_h <= 0 or mask_w <= 0: + raise ValueError("Mask size must be positive for Triton fullpost.") + if scale_w <= 0.0 or scale_h <= 0.0: + raise ValueError("Scale factors must be positive for Triton fullpost.") + if output_offset_x < 0 or output_offset_y < 0: + raise ValueError( + "Static-crop offsets must be non-negative for Triton fullpost." + ) + if output_offset_x + output_w > orig_w or output_offset_y + output_h > orig_h: + raise ValueError("Static-crop paste window must fit within the original image.") + + mask_pad_top = round(mask_h * pad_top / denorm_h) + mask_pad_bottom = round(mask_h * pad_bottom / denorm_h) + mask_pad_left = round(mask_w * pad_left / denorm_w) + mask_pad_right = round(mask_w * pad_right / denorm_w) + mask_input_h = mask_h - mask_pad_top - mask_pad_bottom + mask_input_w = mask_w - mask_pad_left - mask_pad_right + if mask_input_h <= 0 or mask_input_w <= 0: + raise ValueError( + "RF-DETR Triton fullpost mask window became empty after removing " + "padding." 
+ ) + return RFDETRTritonPostprocGeometry( + denorm_w=int(denorm_w), + denorm_h=int(denorm_h), + orig_w=int(orig_w), + orig_h=int(orig_h), + pad_left=int(pad_left), + pad_top=int(pad_top), + inv_scale_w=float(1.0 / scale_w), + inv_scale_h=float(1.0 / scale_h), + output_offset_x=int(output_offset_x), + output_offset_y=int(output_offset_y), + output_w=int(output_w), + output_h=int(output_h), + mask_offset_x=int(mask_pad_left), + mask_offset_y=int(mask_pad_top), + mask_input_w=int(mask_input_w), + mask_input_h=int(mask_input_h), + ) + + +def rfdetr_triton_postproc_geometry_supported( + denorm_size_wh: Tuple[int, int], + pad_ltrb: Tuple[int, int, int, int], + scale_wh: Tuple[float, float], + orig_size_wh: Tuple[int, int], + size_after_pre_processing_wh: Tuple[int, int], + static_crop_offset_xy: Tuple[int, int], + mask_size_hw: Tuple[int, int] = (FASTPATH_MASK_H, FASTPATH_MASK_W), +) -> bool: + try: + get_rfdetr_triton_postproc_geometry( + denorm_size_wh=denorm_size_wh, + pad_ltrb=pad_ltrb, + scale_wh=scale_wh, + orig_size_wh=orig_size_wh, + size_after_pre_processing_wh=size_after_pre_processing_wh, + static_crop_offset_xy=static_crop_offset_xy, + mask_size_hw=mask_size_hw, + ) + except ValueError: + return False + return True + + +def _next_power_of_two(value: int) -> int: + if value <= 1: + return 1 + return 1 << (value - 1).bit_length() + + if TRITON_AVAILABLE: @triton.jit @@ -62,13 +181,28 @@ def rfdetr_fullpostproc_triton_kernel( class_map_ptr, y_indices_ptr, y_weights_ptr, + y_counts_ptr, x_indices_ptr, x_weights_ptr, + x_counts_ptr, rle_counts_ptr, rle_lengths_ptr, combined_out_ptr, mask_out_ptr, counter_ptr, + valid_class_limit, + denorm_w, + denorm_h, + pad_left, + pad_top, + inv_scale_w, + inv_scale_h, + output_offset_x, + output_offset_y, + output_w, + output_h, + mask_offset_x, + mask_offset_y, orig_w, orig_h, logits_stride_q, @@ -96,38 +230,46 @@ def rfdetr_fullpostproc_triton_kernel( RLE_TILE_H: tl.constexpr, RLE_TILE_W: tl.constexpr, RLE_MERGE_TILE: 
tl.constexpr, + MAX_Y_TAPS: tl.constexpr, + MAX_X_TAPS: tl.constexpr, ): pid_det = tl.program_id(0) - # Maintain the reference flat top-k exactly: top 100 scores over the - # 100x91 sigmoid grid, before class remap / thresholding. + # Maintain the reference flat top-k exactly: top Q scores over the + # full (Q, C) sigmoid grid, before class remap / thresholding. top_packed = tl.zeros((TOPK_PAD,), dtype=tl.int64) class_offsets = tl.arange(0, CLASS_BLOCK) rank_offsets = tl.arange(0, TOPK_PAD) - top_limit = tl.full((), NUM_QUERIES, tl.int32) num_classes_total = tl.full((), NUM_CLASSES_TOTAL, tl.int32) + valid_class_limit_i32 = valid_class_limit.to(tl.int32) for q in tl.range(0, NUM_QUERIES, num_stages=1): - valid_class = class_offsets < NUM_CLASSES_TOTAL - logit = tl.load( - logits_ptr + q * logits_stride_q + class_offsets, - mask=valid_class, - other=-float("inf"), - ) - abs_l = tl.abs(logit) - z = tl.exp(-abs_l) - sig_pos = 1.0 / (1.0 + z) - sig_neg = z / (1.0 + z) - conf = tl.where(logit >= 0.0, sig_pos, sig_neg) - conf_bits = conf.to(tl.float32, bitcast=False).to(tl.int32, bitcast=True) - flat_idx = q * NUM_CLASSES_TOTAL + class_offsets - packed = tl.where( - valid_class, - (conf_bits.to(tl.int64) << 32) | flat_idx.to(tl.int64), - tl.zeros((CLASS_BLOCK,), dtype=tl.int64), - ) - merged = tl.reshape(tl.join(top_packed, packed), (TOPK_PAD + CLASS_BLOCK,)) - top_packed = tl.topk(merged, k=TOPK_PAD) + for class_base in tl.range(0, NUM_CLASSES_TOTAL, CLASS_BLOCK, num_stages=1): + class_ids = class_base + class_offsets + valid_class = class_ids < NUM_CLASSES_TOTAL + logit = tl.load( + logits_ptr + q * logits_stride_q + class_ids, + mask=valid_class, + other=-float("inf"), + ) + abs_l = tl.abs(logit) + z = tl.exp(-abs_l) + sig_pos = 1.0 / (1.0 + z) + sig_neg = z / (1.0 + z) + conf = tl.where(logit >= 0.0, sig_pos, sig_neg) + conf_bits = conf.to(tl.float32, bitcast=False).to( + tl.int32, bitcast=True + ) + flat_idx = q * NUM_CLASSES_TOTAL + class_ids + packed = tl.where( + 
valid_class, + (conf_bits.to(tl.int64) << 32) | flat_idx.to(tl.int64), + tl.zeros((CLASS_BLOCK,), dtype=tl.int64), + ) + merged = tl.reshape( + tl.join(top_packed, packed), (TOPK_PAD + CLASS_BLOCK,) + ) + top_packed = tl.topk(merged, k=TOPK_PAD) selected_q = tl.full((), 0, tl.int32) selected_c = tl.full((), 0, tl.int32) @@ -154,7 +296,7 @@ def rfdetr_fullpostproc_triton_kernel( valid = mapped_class >= 0 else: mapped_class = raw_class - valid = raw_class < top_limit + valid = raw_class < valid_class_limit_i32 if PER_CLASS: safe_class = tl.where(valid, mapped_class, 0) @@ -193,12 +335,20 @@ def rfdetr_fullpostproc_triton_kernel( x2_pct = cx_pct + 0.5 * w_pct y2_pct = cy_pct + 0.5 * h_pct - orig_w_f = orig_w.to(tl.float32) - orig_h_f = orig_h.to(tl.float32) - x1 = x1_pct * orig_w_f - y1 = y1_pct * orig_h_f - x2 = x2_pct * orig_w_f - y2 = y2_pct * orig_h_f + denorm_w_f = denorm_w.to(tl.float32) + denorm_h_f = denorm_h.to(tl.float32) + x1 = x1_pct * denorm_w_f + y1 = y1_pct * denorm_h_f + x2 = x2_pct * denorm_w_f + y2 = y2_pct * denorm_h_f + x1 = x1 - pad_left + y1 = y1 - pad_top + x2 = x2 - pad_left + y2 = y2 - pad_top + x1 = x1 * inv_scale_w + output_offset_x + y1 = y1 * inv_scale_h + output_offset_y + x2 = x2 * inv_scale_w + output_offset_x + y2 = y2 * inv_scale_h + output_offset_y # Match torch.round(...).int() with half-to-even tie handling. 
x1_r = tl.floor(x1 + 0.5) @@ -235,19 +385,11 @@ def rfdetr_fullpostproc_triton_kernel( for out_x in tl.range(0, orig_w, RLE_TILE_W, num_stages=1): x = out_x + col_offsets x_mask = x < orig_w - x_table_offset = x * 2 - x_idx_a = tl.load( - x_indices_ptr + x_table_offset + 0, mask=x_mask, other=0 - ) - x_idx_b = tl.load( - x_indices_ptr + x_table_offset + 1, mask=x_mask, other=0 - ) - wx_a = tl.load( - x_weights_ptr + x_table_offset + 0, mask=x_mask, other=0.0 - ) - wx_b = tl.load( - x_weights_ptr + x_table_offset + 1, mask=x_mask, other=0.0 - ) + x_local = x - output_offset_x + x_active = x_mask & (x_local >= 0) & (x_local < output_w) + x_table_index = tl.where(x_active, x_local, 0) + x_table_base = x_table_index * MAX_X_TAPS + x_counts = tl.load(x_counts_ptr + x_table_index, mask=x_active, other=0) col_base = ( rle_counts_ptr + pid_det * rle_counts_stride_q @@ -260,55 +402,56 @@ def rfdetr_fullpostproc_triton_kernel( for out_y in tl.range(0, orig_h, RLE_TILE_H, num_stages=1): y = out_y + row_offsets y_mask = y < orig_h - tile_mask = y_mask[:, None] & x_mask[None, :] - y_table_offset = y * 2 - y_idx_a = tl.load( - y_indices_ptr + y_table_offset + 0, mask=y_mask, other=0 - ) - y_idx_b = tl.load( - y_indices_ptr + y_table_offset + 1, mask=y_mask, other=0 - ) - wy_a = tl.load( - y_weights_ptr + y_table_offset + 0, mask=y_mask, other=0.0 - ) - wy_b = tl.load( - y_weights_ptr + y_table_offset + 1, mask=y_mask, other=0.0 + y_local = y - output_offset_y + y_active = y_mask & (y_local >= 0) & (y_local < output_h) + tile_mask = y_active[:, None] & x_active[None, :] + y_table_index = tl.where(y_active, y_local, 0) + y_table_base = y_table_index * MAX_Y_TAPS + y_counts = tl.load( + y_counts_ptr + y_table_index, mask=y_active, other=0 ) - m00 = tl.load( - mask_base - + y_idx_a[:, None] * masks_stride_h - + x_idx_a[None, :] * masks_stride_w, - mask=tile_mask, - other=0.0, - ) - m01 = tl.load( - mask_base - + y_idx_a[:, None] * masks_stride_h - + x_idx_b[None, :] * 
masks_stride_w, - mask=tile_mask, - other=0.0, - ) - m10 = tl.load( - mask_base - + y_idx_b[:, None] * masks_stride_h - + x_idx_a[None, :] * masks_stride_w, - mask=tile_mask, - other=0.0, - ) - m11 = tl.load( - mask_base - + y_idx_b[:, None] * masks_stride_h - + x_idx_b[None, :] * masks_stride_w, - mask=tile_mask, - other=0.0, - ) - interp_top = wx_a[None, :] * m00 + wx_b[None, :] * m01 - interp_bottom = wx_a[None, :] * m10 + wx_b[None, :] * m11 - bits = ( - (wy_a[:, None] * interp_top + wy_b[:, None] * interp_bottom) - > 0.0 - ).to(tl.int32) + interp = tl.zeros((RLE_TILE_H, RLE_TILE_W), dtype=tl.float32) + for y_tap in tl.static_range(0, MAX_Y_TAPS): + y_tap_valid = y_active & (y_tap < y_counts) + src_y = mask_offset_y + tl.load( + y_indices_ptr + y_table_base + y_tap, + mask=y_tap_valid, + other=0, + ) + wy = tl.load( + y_weights_ptr + y_table_base + y_tap, + mask=y_tap_valid, + other=0.0, + ) + src_y_valid = y_tap_valid & (src_y >= 0) & (src_y < MASK_H) + + for x_tap in tl.static_range(0, MAX_X_TAPS): + x_tap_valid = x_active & (x_tap < x_counts) + src_x = mask_offset_x + tl.load( + x_indices_ptr + x_table_base + x_tap, + mask=x_tap_valid, + other=0, + ) + wx = tl.load( + x_weights_ptr + x_table_base + x_tap, + mask=x_tap_valid, + other=0.0, + ) + src_x_valid = x_tap_valid & (src_x >= 0) & (src_x < MASK_W) + + tap_values = tl.load( + mask_base + + src_y[:, None] * masks_stride_h + + src_x[None, :] * masks_stride_w, + mask=tile_mask + & src_y_valid[:, None] + & src_x_valid[None, :], + other=0.0, + ) + interp += wy[:, None] * wx[None, :] * tap_values + + bits = (interp > 0.0).to(tl.int32) for local_y in tl.static_range(0, RLE_TILE_H): valid_row = out_y + local_y < orig_h @@ -409,69 +552,64 @@ def rfdetr_fullpostproc_triton_kernel( for out_y in tl.range(0, orig_h, MASK_TILE_H, num_stages=1): y = out_y + row_offsets y_mask = y < orig_h - y_table_offset = y * 2 - y_idx_a = tl.load( - y_indices_ptr + y_table_offset + 0, mask=y_mask, other=0 - ) - y_idx_b = tl.load( - 
y_indices_ptr + y_table_offset + 1, mask=y_mask, other=0 - ) - wy_a = tl.load( - y_weights_ptr + y_table_offset + 0, mask=y_mask, other=0.0 - ) - wy_b = tl.load( - y_weights_ptr + y_table_offset + 1, mask=y_mask, other=0.0 - ) + y_local = y - output_offset_y + y_active = y_mask & (y_local >= 0) & (y_local < output_h) + y_table_index = tl.where(y_active, y_local, 0) + y_table_base = y_table_index * MAX_Y_TAPS + y_counts = tl.load(y_counts_ptr + y_table_index, mask=y_active, other=0) for out_x in tl.range(0, orig_w, MASK_TILE_W, num_stages=1): x = out_x + col_offsets x_mask = x < orig_w - tile_mask = y_mask[:, None] & x_mask[None, :] - x_table_offset = x * 2 - x_idx_a = tl.load( - x_indices_ptr + x_table_offset + 0, mask=x_mask, other=0 - ) - x_idx_b = tl.load( - x_indices_ptr + x_table_offset + 1, mask=x_mask, other=0 - ) - wx_a = tl.load( - x_weights_ptr + x_table_offset + 0, mask=x_mask, other=0.0 - ) - wx_b = tl.load( - x_weights_ptr + x_table_offset + 1, mask=x_mask, other=0.0 + x_local = x - output_offset_x + x_active = x_mask & (x_local >= 0) & (x_local < output_w) + tile_mask = y_active[:, None] & x_active[None, :] + x_table_index = tl.where(x_active, x_local, 0) + x_table_base = x_table_index * MAX_X_TAPS + x_counts = tl.load( + x_counts_ptr + x_table_index, mask=x_active, other=0 ) - m00 = tl.load( - mask_base - + y_idx_a[:, None] * masks_stride_h - + x_idx_a[None, :] * masks_stride_w, - mask=tile_mask, - other=0.0, - ) - m01 = tl.load( - mask_base - + y_idx_a[:, None] * masks_stride_h - + x_idx_b[None, :] * masks_stride_w, - mask=tile_mask, - other=0.0, - ) - m10 = tl.load( - mask_base - + y_idx_b[:, None] * masks_stride_h - + x_idx_a[None, :] * masks_stride_w, - mask=tile_mask, - other=0.0, - ) - m11 = tl.load( - mask_base - + y_idx_b[:, None] * masks_stride_h - + x_idx_b[None, :] * masks_stride_w, - mask=tile_mask, - other=0.0, - ) - interp_top = wx_a[None, :] * m00 + wx_b[None, :] * m01 - interp_bottom = wx_a[None, :] * m10 + wx_b[None, :] * m11 - 
interp = wy_a[:, None] * interp_top + wy_b[:, None] * interp_bottom + interp = tl.zeros((MASK_TILE_H, MASK_TILE_W), dtype=tl.float32) + for y_tap in tl.static_range(0, MAX_Y_TAPS): + y_tap_valid = y_active & (y_tap < y_counts) + src_y = mask_offset_y + tl.load( + y_indices_ptr + y_table_base + y_tap, + mask=y_tap_valid, + other=0, + ) + wy = tl.load( + y_weights_ptr + y_table_base + y_tap, + mask=y_tap_valid, + other=0.0, + ) + src_y_valid = y_tap_valid & (src_y >= 0) & (src_y < MASK_H) + + for x_tap in tl.static_range(0, MAX_X_TAPS): + x_tap_valid = x_active & (x_tap < x_counts) + src_x = mask_offset_x + tl.load( + x_indices_ptr + x_table_base + x_tap, + mask=x_tap_valid, + other=0, + ) + wx = tl.load( + x_weights_ptr + x_table_base + x_tap, + mask=x_tap_valid, + other=0.0, + ) + src_x_valid = x_tap_valid & (src_x >= 0) & (src_x < MASK_W) + + tap_values = tl.load( + mask_base + + src_y[:, None] * masks_stride_h + + src_x[None, :] * masks_stride_w, + mask=tile_mask + & src_y_valid[:, None] + & src_x_valid[None, :], + other=0.0, + ) + interp += wy[:, None] * wx[None, :] * tap_values + bits = (interp > 0.0).to(tl.int32) if PACK_DENSE_MASKS: packed = tl.sum( @@ -505,7 +643,11 @@ def rfdetr_fullpostproc_triton_kernel( + y[:, None] * mask_out_stride_h + x[None, :] * mask_out_stride_w ) - tl.store(out_ptr, bits.to(tl.uint8), mask=tile_mask) + tl.store( + out_ptr, + bits.to(tl.uint8), + mask=y_mask[:, None] & x_mask[None, :], + ) _SCRATCH_CACHE: dict = {} @@ -518,8 +660,8 @@ def _build_resize_axis_tables( out_size: int, device: torch.device, horizontal: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Returns exact 2-tap tables extracted from the reference CUDA resize op.""" +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns exact per-axis resize tables extracted from the reference op.""" basis = torch.eye(in_size, dtype=torch.float32, device=device) if horizontal: @@ -540,8 +682,8 @@ def _build_resize_axis_tables( )[:, 0, :, 0] resized_cpu = 
resized.cpu() - indices = torch.empty((out_size, 2), dtype=torch.int32) - weights = torch.zeros((out_size, 2), dtype=torch.float32) + supports = [] + max_taps = 0 for out_idx in range(out_size): support = torch.nonzero( resized_cpu[:, out_idx].abs() > 0, as_tuple=False @@ -551,47 +693,52 @@ def _build_resize_axis_tables( f"Reference bilinear AA resize produced no support for axis " f"{out_idx} of shape {in_size}->{out_size}." ) - if support.numel() > 2: - raise ValueError( - "RF-DETR Triton fullpost fast path only supports 2-tap " - "upsample resize tables, but the reference resize produced " - f"{support.numel()} taps for shape {in_size}->{out_size}." - ) - - idx_a = int(support[0]) - indices[out_idx, 0] = idx_a - weights[out_idx, 0] = resized_cpu[idx_a, out_idx] - if support.numel() == 1: - indices[out_idx, 1] = idx_a - weights[out_idx, 1] = 0.0 - else: - idx_b = int(support[1]) - indices[out_idx, 1] = idx_b - weights[out_idx, 1] = resized_cpu[idx_b, out_idx] - return indices.to(device=device, non_blocking=True), weights.to( - device=device, non_blocking=True + supports.append(support) + max_taps = max(max_taps, int(support.numel())) + + indices = torch.zeros((out_size, max_taps), dtype=torch.int32) + weights = torch.zeros((out_size, max_taps), dtype=torch.float32) + counts = torch.empty((out_size,), dtype=torch.int32) + for out_idx, support in enumerate(supports): + count = int(support.numel()) + counts[out_idx] = count + indices[out_idx, :count] = support.to(torch.int32) + weights[out_idx, :count] = resized_cpu[support, out_idx] + return ( + indices.to(device=device, non_blocking=True), + weights.to(device=device, non_blocking=True), + counts.to(device=device, non_blocking=True), ) @lru_cache(maxsize=128) def _get_resize_tables( - orig_h: int, - orig_w: int, + input_h: int, + input_w: int, + output_h: int, + output_w: int, device: torch.device, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - y_indices, y_weights = 
_build_resize_axis_tables( - in_size=FASTPATH_MASK_H, - out_size=orig_h, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + y_indices, y_weights, y_counts = _build_resize_axis_tables( + in_size=input_h, + out_size=output_h, device=device, horizontal=False, ) - x_indices, x_weights = _build_resize_axis_tables( - in_size=FASTPATH_MASK_W, - out_size=orig_w, + x_indices, x_weights, x_counts = _build_resize_axis_tables( + in_size=input_w, + out_size=output_w, device=device, horizontal=True, ) - return y_indices, y_weights, x_indices, x_weights + return y_indices, y_weights, y_counts, x_indices, x_weights, x_counts def _get_scratch_buffers( @@ -687,9 +834,12 @@ def rfdetr_triton_postproc( threshold: "torch.Tensor | float", num_classes: int, class_mapping: Optional[torch.Tensor], - inference_size_wh: Tuple[int, int], + denorm_size_wh: Tuple[int, int], + pad_ltrb: Tuple[int, int, int, int], scale_wh: Tuple[float, float], orig_size_wh: Tuple[int, int], + size_after_pre_processing_wh: Tuple[int, int], + static_crop_offset_xy: Tuple[int, int], emit_rle: bool = False, pack_dense_masks: bool = False, ) -> Tuple[ @@ -711,25 +861,45 @@ def rfdetr_triton_postproc( """ device = bboxes.device + if bboxes.ndim != 3 or logits.ndim != 3 or masks.ndim != 4: + raise ValueError( + "RF-DETR Triton fullpost expects bboxes/logits/masks shaped as " + "(1, Q, 4), (1, Q, C), and (1, Q, Mh, Mw)." + ) + if bboxes.shape[0] != 1 or logits.shape[0] != 1 or masks.shape[0] != 1: + raise ValueError("RF-DETR Triton fullpost supports batch size 1 only.") + if bboxes.shape[2] != 4: + raise ValueError( + f"RF-DETR Triton fullpost expects 4 bbox channels, got {bboxes.shape[2]}." + ) + if logits.shape[1] != bboxes.shape[1] or masks.shape[1] != bboxes.shape[1]: + raise ValueError( + "RF-DETR Triton fullpost expects matching query counts across " + "bboxes/logits/masks." 
+ ) num_queries, num_classes_total = logits.shape[1], logits.shape[2] mask_h, mask_w = masks.shape[2], masks.shape[3] - if ( - num_queries != FASTPATH_NUM_QUERIES - or num_classes_total != FASTPATH_NUM_CLASSES_TOTAL - or mask_h != FASTPATH_MASK_H - or mask_w != FASTPATH_MASK_W - ): + if num_queries <= 0 or num_classes_total <= 0 or mask_h <= 0 or mask_w <= 0: raise ValueError( - "RF-DETR Triton fullpost fast path only supports the fixed TRT " - f"shape (Q={FASTPATH_NUM_QUERIES}, C={FASTPATH_NUM_CLASSES_TOTAL}, " - f"Mh={FASTPATH_MASK_H}, Mw={FASTPATH_MASK_W}), got " - f"{(num_queries, num_classes_total, mask_h, mask_w)}." + "RF-DETR Triton fullpost requires positive query/class/mask " + f"dimensions, got {(num_queries, num_classes_total, mask_h, mask_w)}." ) + selection_block = _next_power_of_two(num_queries) logits_2d = logits[0] if logits[0].is_contiguous() else logits[0].contiguous() bboxes_2d = bboxes[0] if bboxes[0].is_contiguous() else bboxes[0].contiguous() masks_3d = masks[0] if masks[0].is_contiguous() else masks[0].contiguous() + geometry = get_rfdetr_triton_postproc_geometry( + denorm_size_wh=denorm_size_wh, + pad_ltrb=pad_ltrb, + scale_wh=scale_wh, + orig_size_wh=orig_size_wh, + size_after_pre_processing_wh=size_after_pre_processing_wh, + static_crop_offset_xy=static_crop_offset_xy, + mask_size_hw=(mask_h, mask_w), + ) + orig_w, orig_h = orig_size_wh combined, mask_bin, counter = _get_scratch_buffers( num_queries=num_queries, @@ -738,9 +908,11 @@ def rfdetr_triton_postproc( device=device, pack_dense_masks=pack_dense_masks and not emit_rle, ) - y_indices, y_weights, x_indices, x_weights = _get_resize_tables( - orig_h=orig_h, - orig_w=orig_w, + y_indices, y_weights, y_counts, x_indices, x_weights, x_counts = _get_resize_tables( + input_h=geometry.mask_input_h, + input_w=geometry.mask_input_w, + output_h=geometry.output_h, + output_w=geometry.output_w, device=device, ) if emit_rle: @@ -762,8 +934,6 @@ def rfdetr_triton_postproc( has_remap = False cmap = 
_get_empty_int32_on_device(device) - _ = inference_size_wh - _ = scale_wh dummy_int32 = _get_empty_int32_on_device(device) rfdetr_fullpostproc_triton_kernel[(num_queries,)]( @@ -774,13 +944,28 @@ def rfdetr_triton_postproc( cmap, y_indices, y_weights, + y_counts, x_indices, x_weights, + x_counts, rle_counts if rle_counts is not None else dummy_int32, rle_lengths_scratch if rle_lengths_scratch is not None else dummy_int32, combined, mask_bin, counter, + int(num_classes), + int(geometry.denorm_w), + int(geometry.denorm_h), + int(geometry.pad_left), + int(geometry.pad_top), + float(geometry.inv_scale_w), + float(geometry.inv_scale_h), + int(geometry.output_offset_x), + int(geometry.output_offset_y), + int(geometry.output_w), + int(geometry.output_h), + int(geometry.mask_offset_x), + int(geometry.mask_offset_y), int(orig_w), int(orig_h), logits_2d.stride(0), @@ -797,17 +982,19 @@ def rfdetr_triton_postproc( HAS_REMAPPING=1 if has_remap else 0, EMIT_RLE=1 if emit_rle else 0, PACK_DENSE_MASKS=1 if (pack_dense_masks and not emit_rle) else 0, - NUM_QUERIES=FASTPATH_NUM_QUERIES, - NUM_CLASSES_TOTAL=FASTPATH_NUM_CLASSES_TOTAL, - MASK_H=FASTPATH_MASK_H, - MASK_W=FASTPATH_MASK_W, - TOPK_PAD=_TOPK_PAD, - CLASS_BLOCK=_CLASS_BLOCK, + NUM_QUERIES=num_queries, + NUM_CLASSES_TOTAL=num_classes_total, + MASK_H=mask_h, + MASK_W=mask_w, + TOPK_PAD=selection_block, + CLASS_BLOCK=selection_block, MASK_TILE_H=_MASK_TILE_H, MASK_TILE_W=_MASK_TILE_W, RLE_TILE_H=_RLE_TILE_H, RLE_TILE_W=_RLE_TILE_W, RLE_MERGE_TILE=_RLE_MERGE_TILE, + MAX_Y_TAPS=y_indices.shape[1], + MAX_X_TAPS=x_indices.shape[1], num_warps=4, num_stages=1, ) From cd078a26652c2e867694b48e4aefe407325d19fc Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 19 May 2026 23:17:25 +0000 Subject: [PATCH 19/25] fastpath --- .../models/rfdetr/triton_fullpostproc.py | 233 +++++++++++++----- 1 file changed, 178 insertions(+), 55 deletions(-) diff --git a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py 
b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py index bbc247f36f..3a921aecdc 100644 --- a/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py +++ b/inference_models/inference_models/models/rfdetr/triton_fullpostproc.py @@ -170,6 +170,38 @@ def _next_power_of_two(value: int) -> int: return 1 << (value - 1).bit_length() +def _simple_mask_fastpath_supported( + geometry: RFDETRTritonPostprocGeometry, + *, + mask_h: int, + mask_w: int, + emit_rle: bool, +) -> bool: + """Returns true when we can use direct 2x2 bilinear mask upsampling. + + This path is intentionally narrow: it only handles the original stretch-style + regime where the low-res mask covers the full image canvas and can be + upsampled directly onto the original image without crop/paste geometry. + """ + + if emit_rle: + return False + if geometry.output_offset_x != 0 or geometry.output_offset_y != 0: + return False + if geometry.output_w != geometry.orig_w or geometry.output_h != geometry.orig_h: + return False + if geometry.mask_offset_x != 0 or geometry.mask_offset_y != 0: + return False + if geometry.mask_input_w != mask_w or geometry.mask_input_h != mask_h: + return False + if ( + geometry.output_w < geometry.mask_input_w + or geometry.output_h < geometry.mask_input_h + ): + return False + return True + + if TRITON_AVAILABLE: @triton.jit @@ -219,6 +251,7 @@ def rfdetr_fullpostproc_triton_kernel( HAS_REMAPPING: tl.constexpr, EMIT_RLE: tl.constexpr, PACK_DENSE_MASKS: tl.constexpr, + SIMPLE_MASK_FASTPATH: tl.constexpr, NUM_QUERIES: tl.constexpr, NUM_CLASSES_TOTAL: tl.constexpr, MASK_H: tl.constexpr, @@ -548,67 +581,126 @@ def rfdetr_fullpostproc_triton_kernel( if PACK_DENSE_MASKS: packed_col_offsets = tl.arange(0, MASK_TILE_W // 8) bit_weights = (1 << tl.arange(0, 8)).to(tl.int32) + if SIMPLE_MASK_FASTPATH: + mask_scale_y = tl.full((), MASK_H, tl.float32) / orig_h.to(tl.float32) + mask_scale_x = tl.full((), MASK_W, tl.float32) / orig_w.to(tl.float32) for out_y in 
tl.range(0, orig_h, MASK_TILE_H, num_stages=1): y = out_y + row_offsets y_mask = y < orig_h - y_local = y - output_offset_y - y_active = y_mask & (y_local >= 0) & (y_local < output_h) - y_table_index = tl.where(y_active, y_local, 0) - y_table_base = y_table_index * MAX_Y_TAPS - y_counts = tl.load(y_counts_ptr + y_table_index, mask=y_active, other=0) + if SIMPLE_MASK_FASTPATH: + src_y_f = (y.to(tl.float32) + 0.5) * mask_scale_y - 0.5 + src_y0_i = tl.floor(src_y_f).to(tl.int32) + src_y1_i = src_y0_i + 1 + dy = src_y_f - src_y0_i.to(tl.float32) + src_y0_clamped = tl.maximum(tl.minimum(src_y0_i, MASK_H - 1), 0) + src_y1_clamped = tl.maximum(tl.minimum(src_y1_i, MASK_H - 1), 0) + else: + y_local = y - output_offset_y + y_active = y_mask & (y_local >= 0) & (y_local < output_h) + y_table_index = tl.where(y_active, y_local, 0) + y_table_base = y_table_index * MAX_Y_TAPS + y_counts = tl.load( + y_counts_ptr + y_table_index, mask=y_active, other=0 + ) for out_x in tl.range(0, orig_w, MASK_TILE_W, num_stages=1): x = out_x + col_offsets x_mask = x < orig_w - x_local = x - output_offset_x - x_active = x_mask & (x_local >= 0) & (x_local < output_w) - tile_mask = y_active[:, None] & x_active[None, :] - x_table_index = tl.where(x_active, x_local, 0) - x_table_base = x_table_index * MAX_X_TAPS - x_counts = tl.load( - x_counts_ptr + x_table_index, mask=x_active, other=0 - ) - - interp = tl.zeros((MASK_TILE_H, MASK_TILE_W), dtype=tl.float32) - for y_tap in tl.static_range(0, MAX_Y_TAPS): - y_tap_valid = y_active & (y_tap < y_counts) - src_y = mask_offset_y + tl.load( - y_indices_ptr + y_table_base + y_tap, - mask=y_tap_valid, - other=0, + if SIMPLE_MASK_FASTPATH: + tile_mask = y_mask[:, None] & x_mask[None, :] + src_x_f = (x.to(tl.float32) + 0.5) * mask_scale_x - 0.5 + src_x0_i = tl.floor(src_x_f).to(tl.int32) + src_x1_i = src_x0_i + 1 + dx = src_x_f - src_x0_i.to(tl.float32) + src_x0_clamped = tl.maximum(tl.minimum(src_x0_i, MASK_W - 1), 0) + src_x1_clamped = 
tl.maximum(tl.minimum(src_x1_i, MASK_W - 1), 0) + + p00 = tl.load( + mask_base + + src_y0_clamped[:, None] * masks_stride_h + + src_x0_clamped[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, ) - wy = tl.load( - y_weights_ptr + y_table_base + y_tap, - mask=y_tap_valid, + p01 = tl.load( + mask_base + + src_y0_clamped[:, None] * masks_stride_h + + src_x1_clamped[None, :] * masks_stride_w, + mask=tile_mask, other=0.0, ) - src_y_valid = y_tap_valid & (src_y >= 0) & (src_y < MASK_H) + p10 = tl.load( + mask_base + + src_y1_clamped[:, None] * masks_stride_h + + src_x0_clamped[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + p11 = tl.load( + mask_base + + src_y1_clamped[:, None] * masks_stride_h + + src_x1_clamped[None, :] * masks_stride_w, + mask=tile_mask, + other=0.0, + ) + interp = ( + p00 * ((1.0 - dy)[:, None] * (1.0 - dx)[None, :]) + + p01 * ((1.0 - dy)[:, None] * dx[None, :]) + + p10 * (dy[:, None] * (1.0 - dx)[None, :]) + + p11 * (dy[:, None] * dx[None, :]) + ) + else: + x_local = x - output_offset_x + x_active = x_mask & (x_local >= 0) & (x_local < output_w) + tile_mask = y_active[:, None] & x_active[None, :] + x_table_index = tl.where(x_active, x_local, 0) + x_table_base = x_table_index * MAX_X_TAPS + x_counts = tl.load( + x_counts_ptr + x_table_index, mask=x_active, other=0 + ) - for x_tap in tl.static_range(0, MAX_X_TAPS): - x_tap_valid = x_active & (x_tap < x_counts) - src_x = mask_offset_x + tl.load( - x_indices_ptr + x_table_base + x_tap, - mask=x_tap_valid, + interp = tl.zeros((MASK_TILE_H, MASK_TILE_W), dtype=tl.float32) + for y_tap in tl.static_range(0, MAX_Y_TAPS): + y_tap_valid = y_active & (y_tap < y_counts) + src_y = mask_offset_y + tl.load( + y_indices_ptr + y_table_base + y_tap, + mask=y_tap_valid, other=0, ) - wx = tl.load( - x_weights_ptr + x_table_base + x_tap, - mask=x_tap_valid, + wy = tl.load( + y_weights_ptr + y_table_base + y_tap, + mask=y_tap_valid, other=0.0, ) - src_x_valid = x_tap_valid & (src_x >= 0) & (src_x < 
MASK_W) - - tap_values = tl.load( - mask_base - + src_y[:, None] * masks_stride_h - + src_x[None, :] * masks_stride_w, - mask=tile_mask - & src_y_valid[:, None] - & src_x_valid[None, :], - other=0.0, - ) - interp += wy[:, None] * wx[None, :] * tap_values + src_y_valid = y_tap_valid & (src_y >= 0) & (src_y < MASK_H) + + for x_tap in tl.static_range(0, MAX_X_TAPS): + x_tap_valid = x_active & (x_tap < x_counts) + src_x = mask_offset_x + tl.load( + x_indices_ptr + x_table_base + x_tap, + mask=x_tap_valid, + other=0, + ) + wx = tl.load( + x_weights_ptr + x_table_base + x_tap, + mask=x_tap_valid, + other=0.0, + ) + src_x_valid = ( + x_tap_valid & (src_x >= 0) & (src_x < MASK_W) + ) + + tap_values = tl.load( + mask_base + + src_y[:, None] * masks_stride_h + + src_x[None, :] * masks_stride_w, + mask=tile_mask + & src_y_valid[:, None] + & src_x_valid[None, :], + other=0.0, + ) + interp += wy[:, None] * wx[None, :] * tap_values bits = (interp > 0.0).to(tl.int32) if PACK_DENSE_MASKS: @@ -827,6 +919,11 @@ def _get_empty_int32_on_device(device: torch.device) -> torch.Tensor: return torch.empty((1,), dtype=torch.int32, device=device) +@lru_cache(maxsize=None) +def _get_empty_float32_on_device(device: torch.device) -> torch.Tensor: + return torch.empty((1,), dtype=torch.float32, device=device) + + def rfdetr_triton_postproc( bboxes: torch.Tensor, logits: torch.Tensor, @@ -908,13 +1005,40 @@ def rfdetr_triton_postproc( device=device, pack_dense_masks=pack_dense_masks and not emit_rle, ) - y_indices, y_weights, y_counts, x_indices, x_weights, x_counts = _get_resize_tables( - input_h=geometry.mask_input_h, - input_w=geometry.mask_input_w, - output_h=geometry.output_h, - output_w=geometry.output_w, - device=device, + simple_mask_fastpath = _simple_mask_fastpath_supported( + geometry, + mask_h=mask_h, + mask_w=mask_w, + emit_rle=emit_rle, ) + dummy_int32 = _get_empty_int32_on_device(device) + dummy_float32 = _get_empty_float32_on_device(device) + if simple_mask_fastpath: + y_indices = 
dummy_int32 + y_weights = dummy_float32 + y_counts = dummy_int32 + x_indices = dummy_int32 + x_weights = dummy_float32 + x_counts = dummy_int32 + max_y_taps = 1 + max_x_taps = 1 + else: + ( + y_indices, + y_weights, + y_counts, + x_indices, + x_weights, + x_counts, + ) = _get_resize_tables( + input_h=geometry.mask_input_h, + input_w=geometry.mask_input_w, + output_h=geometry.output_h, + output_w=geometry.output_w, + device=device, + ) + max_y_taps = y_indices.shape[1] + max_x_taps = x_indices.shape[1] if emit_rle: rle_counts, rle_lengths_scratch = _get_rle_buffers( num_queries=num_queries, @@ -934,8 +1058,6 @@ def rfdetr_triton_postproc( has_remap = False cmap = _get_empty_int32_on_device(device) - dummy_int32 = _get_empty_int32_on_device(device) - rfdetr_fullpostproc_triton_kernel[(num_queries,)]( logits_2d, bboxes_2d, @@ -982,6 +1104,7 @@ def rfdetr_triton_postproc( HAS_REMAPPING=1 if has_remap else 0, EMIT_RLE=1 if emit_rle else 0, PACK_DENSE_MASKS=1 if (pack_dense_masks and not emit_rle) else 0, + SIMPLE_MASK_FASTPATH=1 if simple_mask_fastpath else 0, NUM_QUERIES=num_queries, NUM_CLASSES_TOTAL=num_classes_total, MASK_H=mask_h, @@ -993,8 +1116,8 @@ def rfdetr_triton_postproc( RLE_TILE_H=_RLE_TILE_H, RLE_TILE_W=_RLE_TILE_W, RLE_MERGE_TILE=_RLE_MERGE_TILE, - MAX_Y_TAPS=y_indices.shape[1], - MAX_X_TAPS=x_indices.shape[1], + MAX_Y_TAPS=max_y_taps, + MAX_X_TAPS=max_x_taps, num_warps=4, num_stages=1, ) From 5ff12970f319e9710b2c6b3916751c7140907061 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 26 May 2026 15:06:28 -0700 Subject: [PATCH 20/25] add correctness and integration test --- inference_models/docs/changelog.md | 53 ++++- .../models/test_rfdetr_seg_postproc_parity.py | 206 ++++++++++++++++ .../common/roboflow/test_post_processing.py | 222 ++++++++++++++++++ .../unit_tests/core/utils/test_postprocess.py | 27 +++ 4 files changed, 507 insertions(+), 1 deletion(-) create mode 100644 
inference_models/tests/integration_tests/models/test_rfdetr_seg_postproc_parity.py diff --git a/inference_models/docs/changelog.md b/inference_models/docs/changelog.md index add9210aaf..c882d09e24 100644 --- a/inference_models/docs/changelog.md +++ b/inference_models/docs/changelog.md @@ -1,10 +1,58 @@ # Changelog -## `0.27.3` +## `0.29.0` + +### Added + +- Optional Triton full post-processing fast path for RF-DETR instance + segmentation behind `RFDETR_TRITON_POSTPROC`, with parity coverage for + dense and RLE outputs on CUDA-backed RF-DETR backends. + +--- + +## `0.28.1` + +### Fixed + +- Detections at image edges are now clipped to the image dimensions. + +--- + +## `0.28.0` + +### Removed (BREAKING) + +- **MediaPipe is no longer supported.** The `mediapipe` extra and every + symbol coupled to it have been removed. Consumers comparing against + `BackendType.MEDIAPIPE` will hit `AttributeError`. Roboflow Universe + payloads of type `mediapipe-model-package-v1` are now silently filtered + by `MODEL_PACKAGE_PARSERS.get(...)`. 
Removed symbols: + - `inference_models.models.mediapipe_face_detection.MediaPipeFaceDetector` + - `inference_models.model_pipelines.face_and_gaze_detection.FaceAndGazeDetectionMPAndL2CS` + - `BackendType.MEDIAPIPE` + - `mediapipe_package_matches_runtime_environment` and its entry in + `MODEL_TO_RUNTIME_COMPATIBILITY_MATCHERS` + - Models registry entry for + `("mediapipe-face-detector", KEYPOINT_DETECTION_TASK, BackendType.MEDIAPIPE)` + - `BACKEND_PRIORITY[BackendType.MEDIAPIPE]` + - Pipelines registry's `face-and-gaze-detection` entry + + `mediapipe/face-detector` default parameter + - `MediapipeModelPackageV1`, `parse_mediapipe_model_package`, and the + `"mediapipe-model-package-v1"` entry in `MODEL_PACKAGE_PARSERS` + - `RuntimeXRayResult.mediapipe_available` and `is_mediapipe_available()` + - `INFERENCE_MODELS_MEDIAPIPE_FACE_DETECTOR_DEFAULT_CONFIDENCE` + - The `[project.optional-dependencies] mediapipe` extra in + `pyproject.toml` + + The standalone `L2CSNetOnnx` (under `inference_models.models.l2cs`) is + unaffected and remains supported. ### Fixed - RFDetr pre- and post-processing aligned with training transforms. Pre-processing replaced with a dedicated `PIL → F.resize → F.to_tensor → F.normalize` chain matching the training pipeline. For model packages with non-stretch `dataset_version_resize_dimensions`, the dataset-version resize (cv2 letterbox / center-crop) runs first, then the PIL stretch to `training_input_size`. Post-processing uses topk-flat across (queries × classes) via shared `select_topk_predictions`. Fixes a cross-backend divergence at low confidence thresholds. +- Fixed a bug where 'best' and 'default' confidence modes were not correctly handled by `RoboflowInstantHF` models. + +--- ## `0.27.2` @@ -13,6 +61,7 @@ - Temporarily disabled flash-attention in GLM-OCR for Jetsons, due to incompatibility detected before release. +--- ## `0.27.1` @@ -20,6 +69,8 @@ before release. - Improved logging for auto-negotiation of model packages. 
+--- + ## `0.27.0` ### Added diff --git a/inference_models/tests/integration_tests/models/test_rfdetr_seg_postproc_parity.py b/inference_models/tests/integration_tests/models/test_rfdetr_seg_postproc_parity.py new file mode 100644 index 0000000000..96a6f3530d --- /dev/null +++ b/inference_models/tests/integration_tests/models/test_rfdetr_seg_postproc_parity.py @@ -0,0 +1,206 @@ +import importlib +import os +from typing import Any, Dict, Tuple + +import numpy as np +import pytest +import torch + +from inference_models.models.common.rle_utils import coco_rle_masks_to_torch_mask + +pytest.importorskip("triton") + +if not torch.cuda.is_available(): # pragma: no cover - host-dependent + pytest.skip( + "CUDA required for RF-DETR Triton post-processing parity", + allow_module_level=True, + ) + + +BackendSpec = Dict[str, Any] + + +def _decode_masks(prediction) -> torch.Tensor: + if isinstance(prediction.mask, torch.Tensor): + return prediction.mask.detach().to(dtype=torch.bool).cpu() + return coco_rle_masks_to_torch_mask( + instances_masks=prediction.mask, + device=torch.device("cpu"), + ) + + +def _reload_backend_modules( + *, + backend_spec: BackendSpec, + triton_enabled: bool, +): + os.environ["RFDETR_TRITON_POSTPROC"] = "true" if triton_enabled else "false" + configuration_module = importlib.import_module("inference_models.configuration") + common_module = importlib.import_module("inference_models.models.rfdetr.common") + backend_module = importlib.import_module(backend_spec["module"]) + importlib.reload(configuration_module) + common_module = importlib.reload(common_module) + backend_module = importlib.reload(backend_module) + return common_module, backend_module + + +def _run_backend_once( + *, + backend_spec: BackendSpec, + package_path: str, + image: np.ndarray, + confidence: float, + mask_format: str, +) -> Tuple[Any, int]: + previous_env = os.environ.get("RFDETR_TRITON_POSTPROC") + common_module, backend_module = _reload_backend_modules( + 
backend_spec=backend_spec, + triton_enabled=backend_spec["triton_enabled"], + ) + call_count = {"value": 0} + original_postproc = getattr(common_module, "rfdetr_triton_postproc", None) + if original_postproc is not None: + def counting_postproc(*args, **kwargs): + call_count["value"] += 1 + return original_postproc(*args, **kwargs) + + common_module.rfdetr_triton_postproc = counting_postproc + + try: + model_class = getattr(backend_module, backend_spec["class_name"]) + model = model_class.from_pretrained(package_path, **backend_spec["init_kwargs"]) + prediction = model( + image, + confidence=confidence, + mask_format=mask_format, + )[0] + return prediction, call_count["value"] + finally: + if original_postproc is not None: + common_module.rfdetr_triton_postproc = original_postproc + if previous_env is None: + os.environ.pop("RFDETR_TRITON_POSTPROC", None) + else: + os.environ["RFDETR_TRITON_POSTPROC"] = previous_env + configuration_module = importlib.import_module("inference_models.configuration") + importlib.reload(configuration_module) + importlib.reload(common_module) + importlib.reload(backend_module) + + +def _assert_predictions_match(reference, candidate) -> None: + torch.testing.assert_close( + candidate.xyxy.cpu(), + reference.xyxy.cpu(), + atol=0, + rtol=0, + ) + torch.testing.assert_close( + candidate.confidence.cpu(), + reference.confidence.cpu(), + atol=1e-6, + rtol=0, + ) + torch.testing.assert_close( + candidate.class_id.cpu(), + reference.class_id.cpu(), + atol=0, + rtol=0, + ) + assert torch.equal(_decode_masks(candidate), _decode_masks(reference)) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "backend_spec", + [ + pytest.param( + { + "name": "torch", + "module": ( + "inference_models.models.rfdetr." 
+ "rfdetr_instance_segmentation_pytorch" + ), + "class_name": "RFDetrForInstanceSegmentationTorch", + "package_fixture": "snakes_rfdetr_seg_torch_stretch_package", + "image_fixture": "snake_image_numpy", + "init_kwargs": {}, + "triton_enabled": True, + }, + id="torch", + marks=pytest.mark.torch_models, + ), + pytest.param( + { + "name": "onnx", + "module": ( + "inference_models.models.rfdetr." + "rfdetr_instance_segmentation_onnx" + ), + "class_name": "RFDetrForInstanceSegmentationOnnx", + "package_fixture": "snakes_rfdetr_seg_onnx_static_bs_stretch_package", + "image_fixture": "snake_image_numpy", + "init_kwargs": { + "onnx_execution_providers": [ + "CUDAExecutionProvider", + "CPUExecutionProvider", + ] + }, + "triton_enabled": True, + }, + id="onnx", + marks=pytest.mark.onnx_extras, + ), + pytest.param( + { + "name": "trt", + "module": ( + "inference_models.models.rfdetr." + "rfdetr_instance_segmentation_trt" + ), + "class_name": "RFDetrForInstanceSegmentationTRT", + "package_fixture": "rfdetr_seg_asl_trt_package", + "image_fixture": "asl_image_numpy", + "init_kwargs": {"engine_host_code_allowed": True}, + "triton_enabled": True, + }, + id="trt", + marks=pytest.mark.trt_extras, + ), + ], +) +@pytest.mark.parametrize("mask_format", ["dense", "rle"]) +def test_rfdetr_seg_predictions_match_when_triton_postproc_is_toggled( + backend_spec: BackendSpec, + mask_format: str, + request: pytest.FixtureRequest, +) -> None: + package_path = request.getfixturevalue(backend_spec["package_fixture"]) + image = request.getfixturevalue(backend_spec["image_fixture"]) + + enabled_spec = dict(backend_spec) + enabled_spec["triton_enabled"] = True + disabled_spec = dict(backend_spec) + disabled_spec["triton_enabled"] = False + + enabled_prediction, enabled_calls = _run_backend_once( + backend_spec=enabled_spec, + package_path=package_path, + image=image, + confidence=0.5, + mask_format=mask_format, + ) + disabled_prediction, disabled_calls = _run_backend_once( + 
backend_spec=disabled_spec, + package_path=package_path, + image=image, + confidence=0.5, + mask_format=mask_format, + ) + + assert enabled_calls == 1 + assert disabled_calls == 0 + _assert_predictions_match( + reference=disabled_prediction, + candidate=enabled_prediction, + ) diff --git a/inference_models/tests/unit_tests/models/common/roboflow/test_post_processing.py b/inference_models/tests/unit_tests/models/common/roboflow/test_post_processing.py index 4c04b736af..07a921ccb3 100644 --- a/inference_models/tests/unit_tests/models/common/roboflow/test_post_processing.py +++ b/inference_models/tests/unit_tests/models/common/roboflow/test_post_processing.py @@ -9,13 +9,25 @@ import torch from inference_models.configuration import INFERENCE_MODELS_DEFAULT_CONFIDENCE +from inference_models.entities import ImageDimensions +from inference_models.models.common.roboflow.model_packages import ( + PreProcessingMetadata, + StaticCropOffset, +) from inference_models.models.common.roboflow.post_processing import ( ConfidenceFilter, + align_instance_segmentation_results, post_process_nms_fused_model_output, + rescale_image_detections, + rescale_key_points_detections, run_nms_for_instance_segmentation, run_nms_for_key_points_detection, run_nms_for_object_detection, ) +from inference_models.models.rfdetr.triton_fullpostproc import ( + get_rfdetr_triton_postproc_geometry, + rfdetr_triton_postproc_geometry_supported, +) from inference_models.weights_providers.entities import RecommendedParameters @@ -99,6 +111,60 @@ def test_per_class_tensor_indexes_by_class_id(self) -> None: assert kept == [1, 2] +class TestRFDETRTritonPostProcessingGeometry: + def test_geometry_calculates_mask_window(self) -> None: + geometry = get_rfdetr_triton_postproc_geometry( + denorm_size_wh=(640, 480), + pad_ltrb=(32, 24, 32, 24), + scale_wh=(2.0, 2.0), + orig_size_wh=(320, 240), + size_after_pre_processing_wh=(320, 240), + static_crop_offset_xy=(0, 0), + mask_size_hw=(78, 78), + ) + + assert 
geometry.denorm_w == 640 + assert geometry.denorm_h == 480 + assert geometry.orig_w == 320 + assert geometry.orig_h == 240 + assert geometry.pad_left == 32 + assert geometry.pad_top == 24 + assert geometry.inv_scale_w == pytest.approx(0.5) + assert geometry.inv_scale_h == pytest.approx(0.5) + assert geometry.mask_offset_x == 4 + assert geometry.mask_offset_y == 4 + assert geometry.mask_input_w == 70 + assert geometry.mask_input_h == 70 + + def test_supported_rejects_empty_mask_window(self) -> None: + assert ( + rfdetr_triton_postproc_geometry_supported( + denorm_size_wh=(8, 8), + pad_ltrb=(4, 0, 4, 0), + scale_wh=(1.0, 1.0), + orig_size_wh=(8, 8), + size_after_pre_processing_wh=(8, 8), + static_crop_offset_xy=(0, 0), + mask_size_hw=(4, 4), + ) + is False + ) + + def test_supported_rejects_invalid_crop_window(self) -> None: + assert ( + rfdetr_triton_postproc_geometry_supported( + denorm_size_wh=(640, 640), + pad_ltrb=(0, 0, 0, 0), + scale_wh=(1.0, 1.0), + orig_size_wh=(100, 100), + size_after_pre_processing_wh=(90, 90), + static_crop_offset_xy=(20, 20), + mask_size_hw=(78, 78), + ) + is False + ) + + def _is_output(box_class_conf, num_mask_coeffs=32): num_anchors = len(box_class_conf) num_classes = max(c for _, c, _ in box_class_conf) + 1 @@ -279,3 +345,159 @@ def test_default_string_skips_recommended_parameters(self) -> None: default_confidence=0.25, ) assert cf.get_threshold(["cat", "dog"]) == pytest.approx(0.25) + + +class TestRescaleImageDetectionsClipping: + + @staticmethod + def _meta(orig_h=400, orig_w=600) -> PreProcessingMetadata: + return PreProcessingMetadata( + pad_left=0, + pad_top=0, + pad_right=0, + pad_bottom=0, + original_size=ImageDimensions(height=orig_h, width=orig_w), + size_after_pre_processing=ImageDimensions(height=orig_h, width=orig_w), + inference_size=ImageDimensions(height=640, width=640), + scale_width=1.0, + scale_height=1.0, + static_crop_offset=StaticCropOffset( + offset_x=0, + offset_y=0, + crop_width=orig_w, + crop_height=orig_h, + 
), + ) + + def test_clips_negative_x1_y1_to_zero(self) -> None: + detections = torch.tensor( + [[-3.0, -5.0, 200.0, 200.0, 0.9, 0.0]], dtype=torch.float32 + ) + out = rescale_image_detections(detections, self._meta(orig_h=400, orig_w=600)) + assert out[0, 0].item() == pytest.approx(0.0) + assert out[0, 1].item() == pytest.approx(0.0) + assert out[0, 2].item() == pytest.approx(200.0) + assert out[0, 3].item() == pytest.approx(200.0) + + def test_clips_x2_y2_to_image_extent(self) -> None: + detections = torch.tensor( + [[10.0, 10.0, 700.0, 500.0, 0.9, 0.0]], dtype=torch.float32 + ) + out = rescale_image_detections(detections, self._meta(orig_h=400, orig_w=600)) + assert out[0, 0].item() == pytest.approx(10.0) + assert out[0, 1].item() == pytest.approx(10.0) + assert out[0, 2].item() == pytest.approx(600.0) + assert out[0, 3].item() == pytest.approx(400.0) + + def test_in_bounds_boxes_unchanged(self) -> None: + detections = torch.tensor( + [[50.0, 60.0, 400.0, 350.0, 0.8, 0.0]], dtype=torch.float32 + ) + out = rescale_image_detections(detections, self._meta(orig_h=400, orig_w=600)) + assert torch.allclose( + out[0, :4], + torch.tensor([50.0, 60.0, 400.0, 350.0]), + atol=1e-6, + ) + + def test_clipping_preserves_score_and_class_columns(self) -> None: + detections = torch.tensor( + [[-1.0, -2.0, 1000.0, 1000.0, 0.42, 7.0]], dtype=torch.float32 + ) + out = rescale_image_detections(detections, self._meta(orig_h=400, orig_w=600)) + assert out[0, 4].item() == pytest.approx(0.42) + assert out[0, 5].item() == pytest.approx(7.0) + + +class TestRescaleKeyPointsDetectionsClipping: + + @staticmethod + def _meta(orig_h=400, orig_w=600) -> PreProcessingMetadata: + return PreProcessingMetadata( + pad_left=0, + pad_top=0, + pad_right=0, + pad_bottom=0, + original_size=ImageDimensions(height=orig_h, width=orig_w), + size_after_pre_processing=ImageDimensions(height=orig_h, width=orig_w), + inference_size=ImageDimensions(height=640, width=640), + scale_width=1.0, + scale_height=1.0, + 
static_crop_offset=StaticCropOffset( + offset_x=0, + offset_y=0, + crop_width=orig_w, + crop_height=orig_h, + ), + ) + + def test_clips_box_coords_for_keypoint_detections(self) -> None: + # Row layout: [x1, y1, x2, y2, conf, cls_id, kp_x, kp_y, kp_conf] + detections = [ + torch.tensor( + [[-5.0, 10.0, 700.0, 350.0, 0.9, 0.0, 100.0, 100.0, 0.8]], + dtype=torch.float32, + ) + ] + rescale_key_points_detections( + detections, + [self._meta(orig_h=400, orig_w=600)], + num_classes=1, + key_points_slots_in_prediction=1, + ) + out = detections[0] + assert out[0, 0].item() == pytest.approx(0.0) + assert out[0, 1].item() == pytest.approx(10.0) + assert out[0, 2].item() == pytest.approx(600.0) + assert out[0, 3].item() == pytest.approx(350.0) + assert out[0, 4].item() == pytest.approx(0.9) + assert out[0, 5].item() == pytest.approx(0.0) + assert out[0, 6].item() == pytest.approx(100.0) + assert out[0, 7].item() == pytest.approx(100.0) + + +class TestAlignInstanceSegmentationResultsClipping: + + @staticmethod + def _meta(orig_h=400, orig_w=600) -> PreProcessingMetadata: + return PreProcessingMetadata( + pad_left=0, + pad_top=0, + pad_right=0, + pad_bottom=0, + original_size=ImageDimensions(height=orig_h, width=orig_w), + size_after_pre_processing=ImageDimensions(height=orig_h, width=orig_w), + inference_size=ImageDimensions(height=640, width=640), + scale_width=1.0, + scale_height=1.0, + static_crop_offset=StaticCropOffset( + offset_x=0, + offset_y=0, + crop_width=orig_w, + crop_height=orig_h, + ), + ) + + def test_clips_box_coords(self) -> None: + bboxes = torch.tensor( + [[10.0, 20.0, 700.0, 500.0, 0.9, 0.0]], dtype=torch.float32 + ) + masks = torch.zeros((1, 160, 160), dtype=torch.float32) + meta = self._meta(orig_h=400, orig_w=600) + out_bboxes, _ = align_instance_segmentation_results( + image_bboxes=bboxes, + masks=masks, + padding=(0, 0, 0, 0), + scale_width=1.0, + scale_height=1.0, + original_size=meta.original_size, + 
size_after_pre_processing=meta.size_after_pre_processing, + inference_size=meta.inference_size, + static_crop_offset=meta.static_crop_offset, + binarization_threshold=0.0, + ) + # box clamped to image bounds (400×600) + assert out_bboxes[0, 0].item() == pytest.approx(10.0) + assert out_bboxes[0, 1].item() == pytest.approx(20.0) + assert out_bboxes[0, 2].item() == pytest.approx(600.0) + assert out_bboxes[0, 3].item() == pytest.approx(400.0) diff --git a/tests/inference/unit_tests/core/utils/test_postprocess.py b/tests/inference/unit_tests/core/utils/test_postprocess.py index 25cbbb69d7..19d9346cdc 100644 --- a/tests/inference/unit_tests/core/utils/test_postprocess.py +++ b/tests/inference/unit_tests/core/utils/test_postprocess.py @@ -6,11 +6,13 @@ from inference.core.exceptions import PostProcessingError from inference.core.utils.postprocess import ( + bitpacked_masks2poly, clip_boxes_coordinates, clip_keypoints_coordinates, cosine_similarity, crop_mask, get_static_crop_dimensions, + masks2poly, post_process_bboxes, post_process_keypoints, post_process_polygons, @@ -96,6 +98,31 @@ def test_crop_mask() -> None: assert np.allclose(result, expected_result) +def test_bitpacked_masks2poly_matches_dense_masks2poly() -> None: + masks = np.zeros((2, 11, 13), dtype=np.uint8) + masks[0, 2:9, 3:10] = 1 + masks[1, 1:6, 1:5] = 1 + packed = np.packbits(masks, axis=-1, bitorder="little") + + dense_segments = masks2poly(masks) + packed_segments = bitpacked_masks2poly(packed, width=masks.shape[-1]) + + assert len(packed_segments) == len(dense_segments) + for packed_segment, dense_segment in zip(packed_segments, dense_segments): + np.testing.assert_array_equal(packed_segment, dense_segment) + + +def test_bitpacked_masks2poly_preserves_empty_masks() -> None: + masks = np.zeros((1, 5, 9), dtype=np.uint8) + packed = np.packbits(masks, axis=-1, bitorder="little") + + segments = bitpacked_masks2poly(packed, width=masks.shape[-1]) + + assert len(segments) == 1 + assert segments[0].shape == 
(0, 2) + assert segments[0].dtype == np.float32 + + def test_standardise_static_crop() -> None: # when result = standardise_static_crop( From 6008c6179740f5cd583d207955cb8d36ef954a31 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 26 May 2026 16:31:20 -0700 Subject: [PATCH 21/25] bound pinned memory with unit tests --- .../core/models/inference_models_adapters.py | 35 ++++++++++- .../models/common/rle_utils.py | 31 ++++++---- .../models/test_inference_models_adapters.py | 60 +++++++++++++++++++ 3 files changed, 111 insertions(+), 15 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index a09f2665a1..910d033dbb 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -1,5 +1,6 @@ import base64 import io +from math import prod from io import BytesIO from time import perf_counter from typing import Any, List, Optional, Tuple, Union @@ -89,14 +90,44 @@ # Pinned host buffers for async DtoH on the full-postproc Triton fast path. # Keyed by (name, dtype); reused across frames provided the cached buffer is -# at least as large as the requested shape in every dimension. +# at least as large as the requested shape in every dimension. If a buffer +# becomes much larger than the current request, replace it to avoid a permanent +# high-water mark in pinned host memory. 
PINNED_HOST_BUFFERS: dict = {} +PINNED_HOST_BUFFER_SHRINK_THRESHOLD = 8 + + +def clear_pinned_buffers(name: Optional[str] = None) -> None: + if name is None: + PINNED_HOST_BUFFERS.clear() + return + keys_to_remove = [key for key in PINNED_HOST_BUFFERS if key[0] == name] + for key in keys_to_remove: + PINNED_HOST_BUFFERS.pop(key, None) + + +def _buffer_can_reuse(buf: torch.Tensor, shape: Tuple[int, ...]) -> bool: + return len(buf.shape) == len(shape) and all( + buf.shape[i] >= shape[i] for i in range(len(shape)) + ) + + +def _should_shrink_pinned_buffer(buf: torch.Tensor, shape: Tuple[int, ...]) -> bool: + requested_numel = prod(shape) + if requested_numel == 0: + return False + return buf.numel() >= requested_numel * PINNED_HOST_BUFFER_SHRINK_THRESHOLD def get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor: + shape = tuple(int(dim) for dim in shape) key = (name, dtype) buf = PINNED_HOST_BUFFERS.get(key) - if buf is not None and all(buf.shape[i] >= shape[i] for i in range(len(shape))): + if buf is not None and _buffer_can_reuse(buf=buf, shape=shape): + if _should_shrink_pinned_buffer(buf=buf, shape=shape): + buf = torch.empty(shape, dtype=dtype, pin_memory=True) + PINNED_HOST_BUFFERS[key] = buf + return buf return buf[tuple(slice(0, s) for s in shape)] buf = torch.empty(shape, dtype=dtype, pin_memory=True) PINNED_HOST_BUFFERS[key] = buf diff --git a/inference_models/inference_models/models/common/rle_utils.py b/inference_models/inference_models/models/common/rle_utils.py index dc6970a720..42d6d4875c 100644 --- a/inference_models/inference_models/models/common/rle_utils.py +++ b/inference_models/inference_models/models/common/rle_utils.py @@ -7,11 +7,6 @@ from inference_models.models.base.types import InstancesRLEMasks -def counts_to_coco_rle(counts: list, image_size: tuple) -> dict: - h, w = image_size - return mask_utils.frPyObjects({"counts": counts, "size": [h, w]}, h, w) - - def torch_mask_to_coco_rle(mask: torch.Tensor) -> dict: # Convert 
to uncompressed run length encoding in GPU # coco tools expect fortran order (column-wise) @@ -22,14 +17,16 @@ def torch_mask_to_coco_rle(mask: torch.Tensor) -> dict: if values[0] == 1: counts.insert(0, 0) - return counts_to_coco_rle(counts=counts, image_size=tuple(mask.shape)) + h, w = mask.shape + return mask_utils.frPyObjects({"counts": counts, "size": [h, w]}, h, w) def numpy_mask_to_coco_rle(mask: np.ndarray) -> dict: mask_bool = np.asarray(mask, dtype=bool) mask_flat = np.ravel(mask_bool, order="F") if mask_flat.size == 0: - return counts_to_coco_rle(counts=[], image_size=tuple(mask_bool.shape)) + h, w = mask_bool.shape + return mask_utils.frPyObjects({"counts": [], "size": [h, w]}, h, w) transitions = np.flatnonzero(mask_flat[1:] != mask_flat[:-1]) + 1 counts = np.diff( np.concatenate( @@ -42,7 +39,8 @@ def numpy_mask_to_coco_rle(mask: np.ndarray) -> dict: ).tolist() if mask_flat[0]: counts.insert(0, 0) - return counts_to_coco_rle(counts=counts, image_size=tuple(mask_bool.shape)) + h, w = mask_bool.shape + return mask_utils.frPyObjects({"counts": counts, "size": [h, w]}, h, w) def unpack_bitpacked_masks_numpy(bitpacked_masks: np.ndarray, width: int) -> np.ndarray: @@ -153,12 +151,19 @@ def _ensure_materialized(self) -> None: return self._ensure_rle_cpu() if self._rle_counts_cpu is not None and self._rle_lengths_cpu is not None: + h, w = self.image_size self._masks = [ - counts_to_coco_rle( - counts=self._rle_counts_cpu[i, : int(self._rle_lengths_cpu[i])] - .astype(np.int64, copy=False) - .tolist(), - image_size=self.image_size, + mask_utils.frPyObjects( + { + "counts": self._rle_counts_cpu[ + i, : int(self._rle_lengths_cpu[i]) + ] + .astype(np.int64, copy=False) + .tolist(), + "size": [h, w], + }, + h, + w, )["counts"] for i in range(self._rle_lengths_cpu.shape[0]) ] diff --git a/tests/inference/unit_tests/core/models/test_inference_models_adapters.py b/tests/inference/unit_tests/core/models/test_inference_models_adapters.py index e5f9b209cc..58636a4364 
100644 --- a/tests/inference/unit_tests/core/models/test_inference_models_adapters.py +++ b/tests/inference/unit_tests/core/models/test_inference_models_adapters.py @@ -3,8 +3,12 @@ import pytest import torch +import inference.core.models.inference_models_adapters as adapters from inference.core.exceptions import PostProcessingError from inference.core.models.inference_models_adapters import ( + PINNED_HOST_BUFFERS, + clear_pinned_buffers, + get_pinned_buffer, prepare_classification_response, prepare_multi_label_classification_response, ) @@ -14,6 +18,62 @@ ) +@pytest.fixture(autouse=True) +def clear_pinned_host_buffers() -> None: + clear_pinned_buffers() + yield + clear_pinned_buffers() + + +def _fake_pinned_empty(shape, dtype, pin_memory: bool = False) -> torch.Tensor: + assert pin_memory is True + return torch.zeros(shape, dtype=dtype) + + +def test_get_pinned_buffer_reuses_cached_storage_for_smaller_shape( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(adapters.torch, "empty", _fake_pinned_empty) + + first = get_pinned_buffer("mask", (16, 4), torch.float32) + second = get_pinned_buffer("mask", (8, 4), torch.float32) + + assert first.shape == (16, 4) + assert second.shape == (8, 4) + assert second.data_ptr() == first.data_ptr() + assert tuple(PINNED_HOST_BUFFERS[("mask", torch.float32)].shape) == (16, 4) + + +def test_get_pinned_buffer_shrinks_massively_oversized_cached_buffer( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(adapters.torch, "empty", _fake_pinned_empty) + + first = get_pinned_buffer("mask", (64, 4), torch.float32) + second = get_pinned_buffer("mask", (4, 4), torch.float32) + + assert first.shape == (64, 4) + assert second.shape == (4, 4) + assert second.data_ptr() != first.data_ptr() + assert tuple(PINNED_HOST_BUFFERS[("mask", torch.float32)].shape) == (4, 4) + + +def test_clear_pinned_buffers_clears_all_or_single_named_buffer( + monkeypatch: pytest.MonkeyPatch, +) -> None: + 
monkeypatch.setattr(adapters.torch, "empty", _fake_pinned_empty) + + get_pinned_buffer("mask", (8, 8), torch.float32) + get_pinned_buffer("xyxy", (8, 4), torch.int32) + + clear_pinned_buffers(name="mask") + assert ("mask", torch.float32) not in PINNED_HOST_BUFFERS + assert ("xyxy", torch.int32) in PINNED_HOST_BUFFERS + + clear_pinned_buffers() + assert PINNED_HOST_BUFFERS == {} + + def test_prepare_multi_label_response_uses_class_ids_for_predicted_classes() -> None: """The model's `post_process` is the source of truth for which classes are "predicted" (it owns the priority chain user → per-class → global From d420ab5bdb5f45453056c917f9c523be732e057a Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 26 May 2026 16:45:21 -0700 Subject: [PATCH 22/25] revert import changes --- inference/core/models/inference_models_adapters.py | 4 +--- inference_models/inference_models/models/common/rle_utils.py | 4 +++- inference_models/inference_models/models/rfdetr/common.py | 5 ++++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/inference/core/models/inference_models_adapters.py b/inference/core/models/inference_models_adapters.py index 910d033dbb..282afd7d62 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -90,9 +90,7 @@ # Pinned host buffers for async DtoH on the full-postproc Triton fast path. # Keyed by (name, dtype); reused across frames provided the cached buffer is -# at least as large as the requested shape in every dimension. If a buffer -# becomes much larger than the current request, replace it to avoid a permanent -# high-water mark in pinned host memory. +# at least as large as the requested shape in every dimension. 
PINNED_HOST_BUFFERS: dict = {} PINNED_HOST_BUFFER_SHRINK_THRESHOLD = 8 diff --git a/inference_models/inference_models/models/common/rle_utils.py b/inference_models/inference_models/models/common/rle_utils.py index 42d6d4875c..d723fb1129 100644 --- a/inference_models/inference_models/models/common/rle_utils.py +++ b/inference_models/inference_models/models/common/rle_utils.py @@ -18,7 +18,9 @@ def torch_mask_to_coco_rle(mask: torch.Tensor) -> dict: counts.insert(0, 0) h, w = mask.shape - return mask_utils.frPyObjects({"counts": counts, "size": [h, w]}, h, w) + # compress + rle = mask_utils.frPyObjects({"counts": counts, "size": [h, w]}, h, w) + return rle def numpy_mask_to_coco_rle(mask: np.ndarray) -> dict: diff --git a/inference_models/inference_models/models/rfdetr/common.py b/inference_models/inference_models/models/rfdetr/common.py index d22619d6a1..0a37ab3322 100644 --- a/inference_models/inference_models/models/rfdetr/common.py +++ b/inference_models/inference_models/models/rfdetr/common.py @@ -1,11 +1,14 @@ -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import torch +from torchvision.transforms import functional from inference_models import Detections, InstanceDetections, InstancesRLEMasks +from inference_models.entities import ImageDimensions from inference_models.errors import CorruptedModelPackageError from inference_models.models.common.roboflow.model_packages import ( PreProcessingMetadata, + StaticCropOffset, ) from inference_models.models.common.rle_utils import LazyInstancesRLEMasks from inference_models.models.common.roboflow.post_processing import ( From 3c2fa90257860fabf0c872d3f4387859496d1482 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 26 May 2026 16:47:39 -0700 Subject: [PATCH 23/25] make style make check_code_quality --- inference/core/models/inference_models_adapters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference/core/models/inference_models_adapters.py 
b/inference/core/models/inference_models_adapters.py index 282afd7d62..132233a528 100644 --- a/inference/core/models/inference_models_adapters.py +++ b/inference/core/models/inference_models_adapters.py @@ -1,7 +1,7 @@ import base64 import io -from math import prod from io import BytesIO +from math import prod from time import perf_counter from typing import Any, List, Optional, Tuple, Union From 9cb403320a8d100de96fb6ad93c5a71ed6d0e2b2 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Wed, 27 May 2026 23:06:38 +0000 Subject: [PATCH 24/25] patch benchmark scripts to work on jetson --- .../rfdetr_nano_seg_trt_workflow.py | 122 +++++++++++++++++- temp/postproc_microbench.py | 75 +++++++++-- 2 files changed, 185 insertions(+), 12 deletions(-) diff --git a/development/stream_interface/rfdetr_nano_seg_trt_workflow.py b/development/stream_interface/rfdetr_nano_seg_trt_workflow.py index 9c213a8639..330cdfa70b 100644 --- a/development/stream_interface/rfdetr_nano_seg_trt_workflow.py +++ b/development/stream_interface/rfdetr_nano_seg_trt_workflow.py @@ -9,10 +9,17 @@ `DISABLED_INFERENCE_MODELS_BACKENDS` to every backend except the chosen one, so the benchmark numbers correspond unambiguously to a single execution path. +When `RFDETR_TRITON_POSTPROC=true`, the script also wires the local TRT package +layout used by the RF-DETR Triton post-processing integration path. + Defaults: rfdetr-seg-nano @ confidence 0.4 on the native TRT backend. 
""" import argparse +import importlib.util +import json import os +from pathlib import Path +import sys _ALL_BACKENDS = { "torch", @@ -24,6 +31,46 @@ "mediapipe", "custom", } +_DEFAULT_MODEL_ID = "rfdetr-seg-nano" +_PREFERRED_LOCAL_TRT_PACKAGE = "rfdetr-seg-nano-orin-trt-package" +_LOCAL_WORKFLOW_MODEL_ID = f"{_DEFAULT_MODEL_ID}/1" +_REPO_ROOT = Path(__file__).resolve().parents[2] +_INFERENCE_MODELS_ROOT = _REPO_ROOT / "inference_models" + + +def _str2bool(value: str | None) -> bool: + return str(value).strip().lower() in {"1", "true", "yes", "on"} + + +_RFDETR_TRITON_POSTPROC = _str2bool(os.getenv("RFDETR_TRITON_POSTPROC")) + + +def _is_local_trt_package(path: Path) -> bool: + if not path.is_dir(): + return False + required_files = ("engine.plan", "model_config.json", "inference_config.json") + if not all((path / file_name).is_file() for file_name in required_files): + return False + try: + model_config = json.loads((path / "model_config.json").read_text()) + except (OSError, json.JSONDecodeError): + return False + return model_config.get("backend_type") == "trt" + + +def _find_local_trt_package() -> str | None: + preferred = Path.cwd() / _PREFERRED_LOCAL_TRT_PACKAGE + if _is_local_trt_package(preferred): + return str(preferred.resolve()) + + candidates = sorted( + path.resolve() for path in Path.cwd().iterdir() if _is_local_trt_package(path) + ) + if len(candidates) == 1: + return str(candidates[0]) + return None + + def _select_backend_from_argv() -> str: pre = argparse.ArgumentParser(add_help=False) pre.add_argument("--backend", choices=("trt", "onnx", "torch"), default="trt") @@ -32,6 +79,13 @@ def _select_backend_from_argv() -> str: _BACKEND = _select_backend_from_argv() +_LOCAL_TRT_PACKAGE = None +if _BACKEND == "trt": + _LOCAL_TRT_PACKAGE = _find_local_trt_package() + if _LOCAL_TRT_PACKAGE is not None: + os.environ.setdefault( + "ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES", "True" + ) os.environ.setdefault( "ONNXRUNTIME_EXECUTION_PROVIDERS", 
"[TensorrtExecutionProvider,CUDAExecutionProvider,CPUExecutionProvider]", @@ -40,9 +94,64 @@ def _select_backend_from_argv() -> str: sorted(_ALL_BACKENDS - {_BACKEND}) ) +if _RFDETR_TRITON_POSTPROC: + for path in (str(_INFERENCE_MODELS_ROOT), str(_REPO_ROOT)): + if path not in sys.path: + sys.path.insert(0, path) + for module_name in list(sys.modules): + if module_name == "inference" or module_name.startswith("inference."): + del sys.modules[module_name] + if module_name == "inference_models" or module_name.startswith( + "inference_models." + ): + del sys.modules[module_name] + from time import perf_counter -from inference import InferencePipeline +if _RFDETR_TRITON_POSTPROC: + local_inference_spec = importlib.util.spec_from_file_location( + "inference", + _REPO_ROOT / "inference" / "__init__.py", + submodule_search_locations=[str(_REPO_ROOT / "inference")], + ) + if local_inference_spec is None or local_inference_spec.loader is None: + raise RuntimeError("Could not load local inference package") + local_inference_module = importlib.util.module_from_spec(local_inference_spec) + sys.modules["inference"] = local_inference_module + local_inference_spec.loader.exec_module(local_inference_module) + InferencePipeline = local_inference_module.InferencePipeline +else: + from inference import InferencePipeline + + +def _resolve_model_id(model_id: str, backend: str) -> str: + if ( + backend == "trt" + and model_id == _DEFAULT_MODEL_ID + and _LOCAL_TRT_PACKAGE + ): + return _LOCAL_WORKFLOW_MODEL_ID + return model_id + + +def _prepare_local_workflow_model_bundle(model_id: str) -> None: + if _LOCAL_TRT_PACKAGE is None or model_id != _LOCAL_WORKFLOW_MODEL_ID: + return + + model_dir = Path(model_id) + model_dir.parent.mkdir(parents=True, exist_ok=True) + target_dir = Path(_LOCAL_TRT_PACKAGE) + if not model_dir.exists(): + model_dir.symlink_to(target_dir, target_is_directory=True) + + model_cache_dir = Path(os.environ.get("MODEL_CACHE_DIR", "/tmp/cache")) / model_id + 
model_cache_dir.mkdir(parents=True, exist_ok=True) + model_type_path = model_cache_dir / "model_type.json" + model_metadata = { + "project_task_type": "instance-segmentation", + "model_type": "rfdetr-seg-nano", + } + model_type_path.write_text(json.dumps(model_metadata, indent=4)) def build_workflow(model_id: str, confidence: float) -> dict: @@ -89,7 +198,7 @@ def sink(predictions, _video_frames) -> None: def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--video_reference", required=True) - parser.add_argument("--model_id", default="rfdetr-seg-nano") + parser.add_argument("--model_id", default=_DEFAULT_MODEL_ID) parser.add_argument("--confidence", type=float, default=0.4) parser.add_argument( "--backend", @@ -98,10 +207,17 @@ def main() -> None: help="inference-models backend (consumed pre-import via env var).", ) args = parser.parse_args() + model_id = _resolve_model_id(args.model_id, args.backend) + _prepare_local_workflow_model_bundle(model_id) + if model_id != args.model_id: + print( + f"[model] using local TRT package via workflow model id: {model_id}", + flush=True, + ) pipeline = InferencePipeline.init_with_workflow( video_reference=args.video_reference, - workflow_specification=build_workflow(args.model_id, args.confidence), + workflow_specification=build_workflow(model_id, args.confidence), on_prediction=sink, ) pipeline.start() diff --git a/temp/postproc_microbench.py b/temp/postproc_microbench.py index 9b857214ad..9da6be55c9 100644 --- a/temp/postproc_microbench.py +++ b/temp/postproc_microbench.py @@ -6,11 +6,13 @@ """ import argparse +import json import os import time from pathlib import Path import cv2 +import numpy as np import torch os.environ.setdefault( @@ -18,11 +20,45 @@ "torch,torch-script,onnx,hugging-face,ultralytics,mediapipe,custom", ) +DEFAULT_MODEL_ID = "rfdetr-seg-nano" +PREFERRED_LOCAL_TRT_PACKAGE = "rfdetr-seg-nano-orin-trt-package" + + +def _is_local_trt_package(path: Path) -> bool: + if not path.is_dir(): + 
return False + required_files = ("engine.plan", "model_config.json", "inference_config.json") + if not all((path / file_name).is_file() for file_name in required_files): + return False + try: + model_config = json.loads((path / "model_config.json").read_text()) + except (OSError, json.JSONDecodeError): + return False + return model_config.get("backend_type") == "trt" + + +def _find_local_trt_package() -> str | None: + preferred = Path.cwd() / PREFERRED_LOCAL_TRT_PACKAGE + if _is_local_trt_package(preferred): + return str(preferred.resolve()) + + candidates = sorted( + path.resolve() for path in Path.cwd().iterdir() if _is_local_trt_package(path) + ) + if len(candidates) == 1: + return str(candidates[0]) + return None + + +LOCAL_TRT_PACKAGE = _find_local_trt_package() +if LOCAL_TRT_PACKAGE is not None: + os.environ.setdefault("ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES", "True") + from inference_models import AutoModel DEFAULT_SIZES = ("176x312", "720x1280", "1080x1920") -DEFAULT_VIDEO = Path("/home/ubuntu/inference/vehicles_312px.mp4") +DEFAULT_VIDEO = Path("vehicles_312px.mp4") def _parse_hw(spec: str) -> tuple[int, int]: @@ -66,25 +102,39 @@ def _benchmark_case( warmup: int, iterations: int, ): + latencies_ms = [] with torch.inference_mode(): for _ in range(warmup): det = model.post_process(outputs, metadata, confidence=confidence)[0] _sync_detection(det) - start = time.perf_counter() det_count = 0 for _ in range(iterations): + start = time.perf_counter() det = model.post_process(outputs, metadata, confidence=confidence)[0] _sync_detection(det) + latencies_ms.append((time.perf_counter() - start) * 1000.0) det_count += int(det.class_id.numel()) - elapsed = time.perf_counter() - start - return (elapsed * 1000.0) / iterations, det_count // max(1, iterations) + samples = np.array(latencies_ms, dtype=np.float64) + return { + "mean_ms": float(samples.mean()), + "p50_ms": float(np.percentile(samples, 50)), + "p90_ms": float(np.percentile(samples, 90)), + 
"p95_ms": float(np.percentile(samples, 95)), + "detections": det_count // max(1, iterations), + } + + +def _resolve_model_id(model_id: str) -> str: + if model_id == DEFAULT_MODEL_ID and LOCAL_TRT_PACKAGE is not None: + return LOCAL_TRT_PACKAGE + return model_id def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--video_reference", type=Path, default=DEFAULT_VIDEO) - parser.add_argument("--model_id", default="rfdetr-seg-nano") + parser.add_argument("--model_id", default=DEFAULT_MODEL_ID) parser.add_argument("--confidence", type=float, default=0.4) parser.add_argument("--warmup", type=int, default=25) parser.add_argument("--iterations", type=int, default=200) @@ -102,15 +152,18 @@ def main() -> None: ) frame = _read_seed_frame(args.video_reference) - model = AutoModel.from_pretrained(args.model_id) + model_id = _resolve_model_id(args.model_id) + if model_id != args.model_id: + print(f"[model] using local TRT package: {model_id}", flush=True) + model = AutoModel.from_pretrained(model_id) cases = [] for height, width in sizes: outputs, metadata = _prepare_case(model, frame, height, width) cases.append((height, width, outputs, metadata)) - print("size,detections,mean_ms", flush=True) + print("size,detections,mean_ms,p50_ms,p90_ms,p95_ms", flush=True) for height, width, outputs, metadata in cases: - mean_ms, detections = _benchmark_case( + stats = _benchmark_case( model=model, outputs=outputs, metadata=metadata, @@ -118,7 +171,11 @@ def main() -> None: warmup=args.warmup, iterations=args.iterations, ) - print(f"{height}x{width},{detections},{mean_ms:.4f}", flush=True) + print( + f"{height}x{width},{stats['detections']},{stats['mean_ms']:.4f}," + f"{stats['p50_ms']:.4f},{stats['p90_ms']:.4f},{stats['p95_ms']:.4f}", + flush=True, + ) if __name__ == "__main__": From 308c1d8ec5f541f622144b8ff9073f3d60310163 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 28 May 2026 00:18:53 +0000 Subject: [PATCH 25/25] bugfix --- 
.../rfdetr_nano_seg_trt_workflow.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/development/stream_interface/rfdetr_nano_seg_trt_workflow.py b/development/stream_interface/rfdetr_nano_seg_trt_workflow.py index 330cdfa70b..1c19d54cf7 100644 --- a/development/stream_interface/rfdetr_nano_seg_trt_workflow.py +++ b/development/stream_interface/rfdetr_nano_seg_trt_workflow.py @@ -9,8 +9,11 @@ `DISABLED_INFERENCE_MODELS_BACKENDS` to every backend except the chosen one, so the benchmark numbers correspond unambiguously to a single execution path. -When `RFDETR_TRITON_POSTPROC=true`, the script also wires the local TRT package -layout used by the RF-DETR Triton post-processing integration path. +The benchmark always imports the local repository copies of `inference` and +`inference_models` so `RFDETR_TRITON_POSTPROC=false/true` differs only by the +post-processing flag, not by package provenance. When a local TRT package is +present, the script also wires the local package layout used by the RF-DETR +Triton post-processing integration path. Defaults: rfdetr-seg-nano @ confidence 0.4 on the native TRT backend. """ @@ -94,34 +97,30 @@ def _select_backend_from_argv() -> str: sorted(_ALL_BACKENDS - {_BACKEND}) ) -if _RFDETR_TRITON_POSTPROC: - for path in (str(_INFERENCE_MODELS_ROOT), str(_REPO_ROOT)): - if path not in sys.path: - sys.path.insert(0, path) - for module_name in list(sys.modules): - if module_name == "inference" or module_name.startswith("inference."): - del sys.modules[module_name] - if module_name == "inference_models" or module_name.startswith( - "inference_models." 
- ): - del sys.modules[module_name] +for path in (str(_INFERENCE_MODELS_ROOT), str(_REPO_ROOT)): + if path not in sys.path: + sys.path.insert(0, path) +for module_name in list(sys.modules): + if module_name == "inference" or module_name.startswith("inference."): + del sys.modules[module_name] + if module_name == "inference_models" or module_name.startswith( + "inference_models." + ): + del sys.modules[module_name] from time import perf_counter -if _RFDETR_TRITON_POSTPROC: - local_inference_spec = importlib.util.spec_from_file_location( - "inference", - _REPO_ROOT / "inference" / "__init__.py", - submodule_search_locations=[str(_REPO_ROOT / "inference")], - ) - if local_inference_spec is None or local_inference_spec.loader is None: - raise RuntimeError("Could not load local inference package") - local_inference_module = importlib.util.module_from_spec(local_inference_spec) - sys.modules["inference"] = local_inference_module - local_inference_spec.loader.exec_module(local_inference_module) - InferencePipeline = local_inference_module.InferencePipeline -else: - from inference import InferencePipeline +local_inference_spec = importlib.util.spec_from_file_location( + "inference", + _REPO_ROOT / "inference" / "__init__.py", + submodule_search_locations=[str(_REPO_ROOT / "inference")], +) +if local_inference_spec is None or local_inference_spec.loader is None: + raise RuntimeError("Could not load local inference package") +local_inference_module = importlib.util.module_from_spec(local_inference_spec) +sys.modules["inference"] = local_inference_module +local_inference_spec.loader.exec_module(local_inference_module) +InferencePipeline = local_inference_module.InferencePipeline def _resolve_model_id(model_id: str, backend: str) -> str: