Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
bb1abe6
perf(rfdetr-seg): fused Triton preproc kernel for TRT path
aseembits93 May 9, 2026
3e4892a
perf(rfdetr): widen fast path — single integration point + torch+batc…
aseembits93 May 9, 2026
4e08a51
perf(rfdetr): add static_crop support to Triton fast path
aseembits93 May 9, 2026
666f96f
perf(rfdetr): accept CHW-layout uint8 tensors in fast path
aseembits93 May 9, 2026
c48b0ef
perf(rfdetr): add kill switch for Triton preproc fast path
aseembits93 May 11, 2026
7ef57d0
test(rfdetr): repro scripts for preproc parity + bench
aseembits93 May 11, 2026
06a02b2
rename flag and move to env.py
aseembits93 May 12, 2026
338feeb
almost ready
aseembits93 May 12, 2026
63522cc
removing temp benchmark and sanity check files
aseembits93 May 12, 2026
ecc2345
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 12, 2026
442516e
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 12, 2026
8c10378
Update inference_models/inference_models/models/rfdetr/pre_processing.py
aseembits93 May 12, 2026
ca5c5e4
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 12, 2026
5653b96
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 12, 2026
978d80e
move env var to inference_models
aseembits93 May 13, 2026
f0e7717
add seed for reproducibility
aseembits93 May 13, 2026
6b4f745
remove static crop var
aseembits93 May 13, 2026
84459cf
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 13, 2026
03e3aa3
Use model errors in Triton preprocess
aseembits93 May 13, 2026
1a80fd9
Warn when Triton preprocessing is unavailable
aseembits93 May 13, 2026
8ec83eb
Bound RF-DETR resample table cache
aseembits93 May 13, 2026
30b4129
update changelog
aseembits93 May 13, 2026
f23789f
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 18, 2026
5761e0f
remove testing scripts, move changelog update to new version
aseembits93 May 19, 2026
424319c
remove minimal workflow script
aseembits93 May 19, 2026
7e4609f
make style make check_code_quality pass
aseembits93 May 19, 2026
b859a26
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 21, 2026
6e83901
add correctness and integration test
aseembits93 May 22, 2026
8b13056
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 22, 2026
643ad84
typo
aseembits93 May 22, 2026
3d7cc82
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 22, 2026
6136c92
tighter bounds on correctness
aseembits93 May 22, 2026
a11f1b5
Update requirements on inference_models (#2370)
grzegorz-roboflow May 26, 2026
ee72821
Merge branch 'main' into perf/rfdetr-seg-triton-widen-scope
aseembits93 May 26, 2026
e3806f5
default is opt-in for triton preproc
aseembits93 May 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests_inference_experimental.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ jobs:
working-directory: inference_models
run: |
source .venv/bin/activate
python -m pytest -n auto tests/unit_tests
python -m pytest -n auto -m "not gpu_only" tests/unit_tests
2 changes: 1 addition & 1 deletion inference/core/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.2.10"
__version__ = "1.2.11"


if __name__ == "__main__":
Expand Down
6 changes: 6 additions & 0 deletions inference_models/docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## `0.29.0`

### Added

- Custom Triton Kernel for RFDETR pre-processing (Codeflash). Numerical parity with `PIL → F.resize → F.to_tensor → F.normalize` chain.

## `0.28.1`

### Fixed
Expand Down
5 changes: 5 additions & 0 deletions inference_models/inference_models/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,8 @@
"ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND"
)
DEFAULT_ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND = False

DEFAULT_USE_TRITON_FOR_PREPROCESSING = False

USE_TRITON_FOR_PREPROCESSING = get_boolean_from_env(
variable_name="USE_TRITON_FOR_PREPROCESSING", default=DEFAULT_USE_TRITON_FOR_PREPROCESSING)
236 changes: 236 additions & 0 deletions inference_models/inference_models/models/rfdetr/pre_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,22 @@

torch.Tensor inputs (advanced caller, float CHW [0, 1]):
tensor F.resize → F.normalize

Triton path: for the common case (single-stage resize, no contrast)
we invoke a single PIL-exact Triton kernel that writes the normalized
fp32 tensor directly to `target_device`.
"""

from functools import lru_cache
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image

from inference_models.configuration import USE_TRITON_FOR_PREPROCESSING
from inference_models import PreProcessingOverrides
from inference_models.entities import ColorFormat, ImageDimensions
from inference_models.errors import ModelRuntimeError
Expand All @@ -40,6 +47,43 @@
pre_process_numpy_image,
)

try:
from inference_models.models.rfdetr.triton_preprocess import (
TRITON_AVAILABLE as _TRITON_AVAILABLE,
build_resample_tables,
triton_preprocess_rfdetr_stretch, ResampleTables,
)
except ImportError:
_TRITON_AVAILABLE = False
build_resample_tables = None
triton_preprocess_rfdetr_stretch = None

if USE_TRITON_FOR_PREPROCESSING and not _TRITON_AVAILABLE:
warnings.warn(
"USE_TRITON_FOR_PREPROCESSING is enabled, but Triton is not available; "
"RF-DETR Triton preprocessing will be disabled.",
RuntimeWarning,
stacklevel=2,
)

@lru_cache(maxsize=50)
def _get_resample_tables_cached(
device_str: str, src_h: int, src_w: int, th: int, tw: int
) -> ResampleTables:
return build_resample_tables(
src_h=src_h,
src_w=src_w,
target_h=th,
target_w=tw,
device=torch.device(device_str),
)


def get_resample_tables(
device: torch.device, src_h: int, src_w: int, th: int, tw: int
) -> ResampleTables:
return _get_resample_tables_cached(str(device), src_h, src_w, th, tw)


def pre_process_network_input(
images: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
Expand Down Expand Up @@ -92,6 +136,23 @@ def pre_process_network_input(
else:
image_list = [images]

if triton_path_eligible(
image_list=image_list,
image_pre_processing=image_pre_processing,
network_input=network_input,
target_device=target_device,
pre_processing_overrides=pre_processing_overrides,
):
return triton_path_preprocess(
image_list=image_list,
image_pre_processing=image_pre_processing,
network_input=network_input,
target_size=target_size,
target_device=target_device,
input_color_mode=input_color_mode,
pre_processing_overrides=pre_processing_overrides,
)

tensors: List[torch.Tensor] = []
metadata: List[PreProcessingMetadata] = []
for img in image_list:
Expand Down Expand Up @@ -129,6 +190,181 @@ def pre_process_network_input(
return batch, metadata


def triton_path_eligible(
image_list: List[Union[np.ndarray, torch.Tensor]],
image_pre_processing: ImagePreProcessing,
network_input: NetworkInputDefinition,
target_device: torch.device,
pre_processing_overrides: Optional[PreProcessingOverrides],
) -> bool:
"""True when every item in `image_list` can go through the fused Triton
kernel. Predicate is intentionally conservative — any miss → PIL path."""
if not USE_TRITON_FOR_PREPROCESSING:
return False
if not _TRITON_AVAILABLE:
return False
# Kernel runs on CUDA.
if target_device.type != "cuda":
return False
# grayscale / contrast aren't implemented in the kernel; static_crop is
# (as a load-time offset + effective-dims substitution).
ipp = image_pre_processing
if (ipp.contrast is not None and ipp.contrast.enabled) or (
ipp.grayscale is not None and ipp.grayscale.enabled
):
return False
# Two-stage dataset-version resize isn't in the kernel.
ni = network_input
if _needs_two_step_resize(ni):
return False
if ni.input_channels != 3:
return False
if ni.scaling_factor not in (None, 255):
return False
if ni.normalization is None:
return False
if ni.resize_mode not in (
ResizeMode.STRETCH_TO,
ResizeMode.LETTERBOX,
ResizeMode.CENTER_CROP,
ResizeMode.LETTERBOX_REFLECT_EDGES,
):
return False
if not image_list:
return False
# `pre_process_network_input` already unbinds 4D inputs (batch tensors
# and 4D numpy arrays) into a list of 3D items before calling us, so
# the per-item check below only needs to handle 3D.
for img in image_list:
if isinstance(img, np.ndarray):
if img.dtype != np.uint8 or img.ndim != 3:
return False
if img.shape[2] != 3 and not looks_like_chw(img.shape):
return False
elif isinstance(img, torch.Tensor):
# Only uint8 3-channel images (HWC or CHW). Float tensors keep
# the existing tensor branch (F.resize bilinear *without* PIL's
# antialias — a caller-accepted divergence we don't silently
# change).
if img.dtype != torch.uint8 or img.ndim != 3:
return False
if img.shape[-1] != 3 and not looks_like_chw(img.shape):
return False
else:
return False
return True


def looks_like_chw(shape) -> bool:
"""Matches _tensor_to_hwc_uint8's CHW heuristic: first dim is 1/3/4 and
last dim is not. Catches torchvision.io.read_image's CHW output."""
return (
len(shape) == 3
and shape[0] in (1, 3, 4)
and shape[-1] not in (1, 3, 4)
)


def as_hwc_uint8_cuda(
img: Union[np.ndarray, torch.Tensor], device: torch.device
) -> torch.Tensor:
"""Return a contiguous (H, W, 3) uint8 CUDA tensor, copying if needed.
Accepts HWC or CHW 3D inputs (CHW is torchvision.io.read_image's layout)."""
if isinstance(img, torch.Tensor):
if looks_like_chw(img.shape):
img = img.permute(1, 2, 0)
if img.device != device:
img = img.to(device=device, non_blocking=True)
return img.contiguous()
if looks_like_chw(img.shape):
img = np.transpose(img, (1, 2, 0))
t = torch.from_numpy(np.ascontiguousarray(img))
return t.to(device=device, non_blocking=True)


def triton_path_preprocess(
image_list: List[Union[np.ndarray, torch.Tensor]],
image_pre_processing: ImagePreProcessing,
network_input: NetworkInputDefinition,
target_size: ImageDimensions,
target_device: torch.device,
input_color_mode: Optional[ColorMode],
pre_processing_overrides: Optional[PreProcessingOverrides],
) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
"""Per-image Triton launch + stack. Assumes `triton_path_eligible` passed."""
means, stds = network_input.normalization
means_t = (float(means[0]), float(means[1]), float(means[2]))
stds_t = (float(stds[0]), float(stds[1]), float(stds[2]))
th, tw = target_size.height, target_size.width

# Match _pre_process_numpy's swap semantics: the PIL path does
swap_rb = input_color_mode != network_input.color_mode

# Resolve whether static_crop is active for this call. Computed once
# because the config + override are call-level.
static_crop_overridden = (
pre_processing_overrides is not None
and pre_processing_overrides.disable_static_crop is True
)
crop_cfg = image_pre_processing.static_crop
crop_active = (
crop_cfg is not None and crop_cfg.enabled and not static_crop_overridden
)

outs: List[torch.Tensor] = []
metas: List[PreProcessingMetadata] = []
for img in image_list:
src_gpu = as_hwc_uint8_cuda(img, target_device)
sh, sw = int(src_gpu.shape[0]), int(src_gpu.shape[1])

if crop_active:
# Matches apply_static_crop_to_numpy_image: percentage-based.
x0 = int(crop_cfg.x_min / 100 * sw)
y0 = int(crop_cfg.y_min / 100 * sh)
x1 = int(crop_cfg.x_max / 100 * sw)
y1 = int(crop_cfg.y_max / 100 * sh)
crop_w = x1 - x0
crop_h = y1 - y0
else:
x0 = y0 = 0
crop_w, crop_h = sw, sh

tables = get_resample_tables(target_device, crop_h, crop_w, th, tw)
out = triton_preprocess_rfdetr_stretch(
src=src_gpu,
tables=tables,
target_h=th,
target_w=tw,
means=means_t,
stds=stds_t,
swap_rb=swap_rb,
crop_offset_y=y0,
crop_offset_x=x0,
crop_h=crop_h,
crop_w=crop_w,
)
outs.append(out[0]) # drop leading batch dim to match stack below
metas.append(
PreProcessingMetadata(
pad_left=0,
pad_top=0,
pad_right=0,
pad_bottom=0,
original_size=ImageDimensions(width=sw, height=sh),
size_after_pre_processing=ImageDimensions(width=crop_w, height=crop_h),
inference_size=ImageDimensions(width=tw, height=th),
scale_width=tw / crop_w,
scale_height=th / crop_h,
static_crop_offset=StaticCropOffset(
offset_x=x0, offset_y=y0, crop_width=crop_w, crop_height=crop_h
),
)
)

batch = torch.stack(outs).contiguous()
return batch, metas


def _pre_process_numpy(
image: np.ndarray,
image_pre_processing: ImagePreProcessing,
Expand Down
Loading
Loading