Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions development/stream_interface/rfdetr_nano_seg_trt_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Minimal benchmark: RF-DETR instance segmentation through inference-models,
run via InferencePipeline on a single video source.

Workflow has exactly one block — the segmentation model. No annotators, no
buffer strategies, no rate limiting.

The `--backend` flag (trt | onnx | torch) is parsed before importing
`inference` and pins the auto-loader by setting
`DISABLED_INFERENCE_MODELS_BACKENDS` to every backend except the chosen one,
so the benchmark numbers correspond unambiguously to a single execution path.

Defaults: rfdetr-seg-nano @ confidence 0.4 on the native TRT backend.
"""
import argparse
import os

# Every backend name the inference-models auto-loader may be told to disable.
_ALL_BACKENDS = {
    "custom",
    "hugging-face",
    "mediapipe",
    "onnx",
    "torch",
    "torch-script",
    "trt",
    "ultralytics",
}


def _select_backend_from_argv() -> str:
    """Return the value of ``--backend`` from the command line.

    A throw-away parser (no ``-h`` handling) looks only at ``--backend`` and
    ignores every other argument, so it can run before the real argument
    parser exists. Falls back to ``"trt"`` when the flag is absent.
    """
    early_parser = argparse.ArgumentParser(add_help=False)
    early_parser.add_argument(
        "--backend",
        default="trt",
        choices=("trt", "onnx", "torch"),
    )
    known, _unknown = early_parser.parse_known_args()
    return known.backend


# This must run before `inference` is imported: the environment variables
# below are how the chosen backend is communicated to it.
_BACKEND = _select_backend_from_argv()
if "ONNXRUNTIME_EXECUTION_PROVIDERS" not in os.environ:
    # Respect a value already exported by the caller.
    os.environ["ONNXRUNTIME_EXECUTION_PROVIDERS"] = (
        "[TensorrtExecutionProvider,CUDAExecutionProvider,CPUExecutionProvider]"
    )
# Disable every backend except the selected one.
os.environ["DISABLED_INFERENCE_MODELS_BACKENDS"] = ",".join(
    sorted(name for name in _ALL_BACKENDS if name != _BACKEND)
)

from time import perf_counter

from inference import InferencePipeline


def build_workflow(model_id: str, confidence: float) -> dict:
    """Assemble the single-step workflow specification used by the benchmark.

    Args:
        model_id: Identifier of the instance segmentation model to run.
        confidence: Value passed as the step's custom confidence threshold.

    Returns:
        A workflow definition with one image input, one segmentation step
        and one output exposing that step's predictions.
    """
    segmentation_step = {
        "type": "roboflow_core/roboflow_instance_segmentation_model@v3",
        "name": "segmentation",
        "images": "$inputs.image",
        "model_id": model_id,
        "confidence_mode": "custom",
        "custom_confidence": confidence,
    }
    predictions_output = {
        "type": "JsonField",
        "name": "predictions",
        "selector": "$steps.segmentation.predictions",
    }
    workflow = {"version": "1.0"}
    workflow["inputs"] = [{"type": "WorkflowImage", "name": "image"}]
    workflow["steps"] = [segmentation_step]
    workflow["outputs"] = [predictions_output]
    return workflow

# Benchmark counters shared between `sink` (writer) and the final summary.
FRAME_COUNT = 0  # number of non-None predictions delivered so far
START_TIME = None  # perf_counter() reading taken on the first sink call
PROGRESS_EVERY = 50  # emit a progress line each time this many frames pass


def sink(predictions, _video_frames) -> None:
    """Prediction callback: count delivered frames and print periodic FPS.

    Args:
        predictions: A single prediction or a list of predictions; ``None``
            entries are not counted as frames.
        _video_frames: Unused.

    The clock starts on the first invocation, after that call's frames have
    already been counted. A progress line is printed whenever the running
    total crosses a multiple of ``PROGRESS_EVERY``.
    """
    global FRAME_COUNT, START_TIME
    del _video_frames
    if not isinstance(predictions, list):
        predictions = [predictions]

    previous_count = FRAME_COUNT
    FRAME_COUNT += sum(p is not None for p in predictions)

    now = perf_counter()
    if START_TIME is None:
        START_TIME = now

    # Compare completed multiples instead of testing `% == 0`: this never
    # fires at zero frames or on an empty delivery, and still fires when a
    # batch steps over the boundary.
    if FRAME_COUNT // PROGRESS_EVERY > previous_count // PROGRESS_EVERY:
        elapsed = now - START_TIME
        # Elapsed can be zero on the very first call; avoid dividing by it.
        fps = FRAME_COUNT / elapsed if elapsed > 0 else 0.0
        print(f"[progress] frames={FRAME_COUNT} fps={fps:.2f}", flush=True)


def main() -> None:
    """Run the benchmark pipeline on one video source and print a summary.

    Parses the command line, builds the single-step workflow, runs it
    through ``InferencePipeline`` with ``sink`` as the prediction callback,
    and prints the total frame count, elapsed time and average FPS.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_reference", required=True)
    parser.add_argument("--model_id", default="rfdetr-seg-nano")
    parser.add_argument("--confidence", type=float, default=0.4)
    # Declared here only so it is accepted and shows up in --help; the value
    # has already been consumed before `inference` was imported.
    parser.add_argument(
        "--backend",
        choices=("trt", "onnx", "torch"),
        default="trt",
        help="inference-models backend (consumed pre-import via env var).",
    )
    args = parser.parse_args()

    pipeline = InferencePipeline.init_with_workflow(
        video_reference=args.video_reference,
        workflow_specification=build_workflow(args.model_id, args.confidence),
        on_prediction=sink,
    )
    pipeline.start()
    try:
        pipeline.join()
    except KeyboardInterrupt:
        # Stop the pipeline cleanly so the summary below is still printed.
        pipeline.terminate()
        pipeline.join()

    # `is not None` rather than truthiness: START_TIME is a clock reading.
    elapsed = perf_counter() - START_TIME if START_TIME is not None else 0.0
    fps = FRAME_COUNT / elapsed if elapsed > 0 else 0.0
    print(f"frames={FRAME_COUNT} elapsed={elapsed:.2f}s fps={fps:.2f}")


if __name__ == "__main__":
    main()
63 changes: 59 additions & 4 deletions inference/core/models/inference_models_adapters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import io
import os
from io import BytesIO
from time import perf_counter
from typing import Any, List, Optional, Tuple, Union
Expand All @@ -8,6 +9,26 @@
import torch
from PIL import Image, ImageDraw, ImageFont


# Cache of pinned host buffers for async DtoH, keyed by (name, dtype).
# Pinned memory lets torch's copy_(non_blocking=True) actually run async.
# Buffers are allocated on first use, grown when a larger shape is
# requested, and sliced down to the requested shape on every call.
_PINNED_HOST_BUFFERS: dict = {}


def _get_pinned_buffer(name: str, shape, dtype: torch.dtype) -> torch.Tensor:
    """Return a reusable host tensor of exactly ``shape`` for ``(name, dtype)``.

    Args:
        name: Logical buffer name; part of the cache key.
        shape: Requested shape (any iterable of int-convertible sizes).
        dtype: Element type; part of the cache key.

    Returns:
        A tensor of the requested shape that is a view into the cached
        buffer. The storage is shared between calls with the same
        ``(name, dtype)``, so the contents must be consumed or copied before
        the next such call overwrites them.
    """
    shape = tuple(int(s) for s in shape)
    key = (name, dtype)
    cached = _PINNED_HOST_BUFFERS.get(key)

    alloc_shape = shape
    # Only a cached buffer of the same rank can be sliced to the request;
    # a rank change always reallocates.
    if cached is not None and cached.dim() == len(shape):
        if all(have >= want for have, want in zip(cached.shape, shape)):
            return cached[tuple(slice(0, s) for s in shape)]
        # Grow to the element-wise maximum so alternating shapes do not
        # force a reallocation on every call.
        alloc_shape = tuple(
            max(have, want) for have, want in zip(cached.shape, shape)
        )

    # Pinning needs a CUDA runtime; fall back to pageable memory without one.
    buf = torch.empty(
        alloc_shape, dtype=dtype, pin_memory=torch.cuda.is_available()
    )
    _PINNED_HOST_BUFFERS[key] = buf
    return buf[tuple(slice(0, s) for s in shape)]

from inference.core.entities.requests import (
ClassificationInferenceRequest,
InferenceRequest,
Expand Down Expand Up @@ -308,17 +329,51 @@ def postprocess(
detections_list = self._model.post_process(
predictions, preprocess_return_metadata, **mapped_kwargs
)
gpu_fastpath = os.getenv("RFDETR_GPU_POSTPROCESS", "true").lower() in ("true", "1")

responses: List[InstanceSegmentationInferenceResponse] = []
for preproc_metadata, det in zip(preprocess_return_metadata, detections_list):
H = preproc_metadata.original_size.height
W = preproc_metadata.original_size.width

xyxy = det.xyxy.detach().cpu().numpy()
confs = det.confidence.detach().cpu().numpy()
masks = det.mask.detach().cpu().numpy()
# Fast path: when the detections carry `_combined_gpu` -- a combined
# (n, 6) int32 buffer holding xyxy in columns 0-3, the float32 confidence
# bit pattern in column 4 and the class id in column 5 -- and the masks
# are on the GPU, copy both tensors into reusable pinned host buffers
# with non_blocking=True so the two DtoH transfers can pipeline on the
# copy engine, then synchronize the current stream once before reading.
# If only the combined buffer is available it is moved with a plain
# .cpu(); otherwise each field is transferred separately.
combined_gpu = getattr(det, "_combined_gpu", None)
if combined_gpu is not None and det.mask.is_cuda:
mask_gpu = det.mask
combined_host = _get_pinned_buffer(
"combined", combined_gpu.shape, combined_gpu.dtype
)
mask_host = _get_pinned_buffer(
"mask", mask_gpu.shape, mask_gpu.dtype
)
combined_host.copy_(combined_gpu, non_blocking=True)
mask_host.copy_(mask_gpu, non_blocking=True)
torch.cuda.current_stream(combined_gpu.device).synchronize()
combined_cpu = combined_host.numpy()
xyxy = combined_cpu[:, :4]
confs = combined_cpu[:, 4].view(np.float32)
class_ids = combined_cpu[:, 5]
masks = mask_host.numpy()
elif combined_gpu is not None:
combined_cpu = combined_gpu.detach().cpu().numpy()
xyxy = combined_cpu[:, :4]
confs = combined_cpu[:, 4].view(np.float32)
class_ids = combined_cpu[:, 5]
masks = det.mask.detach().cpu().numpy()
else:
xyxy = det.xyxy.detach().cpu().numpy()
confs = det.confidence.detach().cpu().numpy()
class_ids = det.class_id.detach().cpu().numpy()
masks = det.mask.detach().cpu().numpy()
polys = masks2poly(masks)
class_ids = det.class_id.detach().cpu().numpy()

predictions: List[InstanceSegmentationPrediction] = []

Expand Down
111 changes: 100 additions & 11 deletions inference/models/rfdetr/rfdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,37 @@
run_session_via_iobinding,
)
from inference.core.utils.postprocess import mask2poly
from inference.core.utils.environment import str2bool
from inference.core.utils.preprocess import letterbox_image

# torch is imported only when PyTorch-based preprocessing is enabled.
# CUDA availability is probed once, at import time, and cached in
# CUDA_IS_AVAILABLE.
if USE_PYTORCH_FOR_PREPROCESSING:
import torch

CUDA_IS_AVAILABLE = torch.cuda.is_available()

# Opt-in switch for the Triton-based preprocessing path (off by default).
RFDETR_USE_TRITON_PREPROC = str2bool(os.getenv("RFDETR_USE_TRITON_PREPROC", False))

# Start from the "unavailable" state; only a fully successful import with
# both Triton and CUDA present flips the ready flag below.
_TRITON_PREPROC_READY = False
triton_preprocess_rfdetr = None

if RFDETR_USE_TRITON_PREPROC:
    try:
        import torch

        from inference.models.rfdetr.triton_preprocess import (
            TRITON_AVAILABLE,
            triton_preprocess_rfdetr,
        )

        _TRITON_PREPROC_READY = TRITON_AVAILABLE and torch.cuda.is_available()
    except Exception as _triton_err:  # pragma: no cover
        # Best effort: any failure while enabling the path falls back to
        # the regular preprocessing instead of breaking module import.
        logger.warning(
            "RFDETR_USE_TRITON_PREPROC=1 but triton preprocess unavailable: %s",
            _triton_err,
        )
        _TRITON_PREPROC_READY = False
        triton_preprocess_rfdetr = None

ROBOFLOW_BACKGROUND_CLASS = "background_class83422"


Expand All @@ -73,6 +97,25 @@ class RFDETRObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel):
preprocess_means = [0.485, 0.456, 0.406]
preprocess_stds = [0.229, 0.224, 0.225]

_preproc_cache_initialized: bool = False

def _ensure_preproc_cache(self, device) -> None:
    """Create the cached preprocessing tensors on ``device`` if needed.

    Populates, on first use or whenever ``device`` differs from the one the
    cache was built for:

    * ``_preproc_scale_gpu``  -- ``1 / (255 * std)``, shape (1, 3, 1, 1)
    * ``_preproc_offset_gpu`` -- ``-mean / std``, shape (1, 3, 1, 1)
    * ``_input_buffer``       -- float32, (1, 3, img_size_h, img_size_w)
    * ``_preproc_bgr2rgb_idx``-- channel permutation ``[2, 1, 0]``

    so that ``offset + x * scale`` equals ``(x / 255 - mean) / std``.

    Args:
        device: Device (or device string) the tensors must live on.
    """
    device = torch.device(device)
    # Tensors cached for another device would make later ops fail with a
    # device mismatch, so the cache is tied to the device it was built for.
    if (
        self._preproc_cache_initialized
        and getattr(self, "_preproc_cache_device", None) == device
    ):
        return

    means = torch.tensor(self.preprocess_means, device=device, dtype=torch.float32)
    stds = torch.tensor(self.preprocess_stds, device=device, dtype=torch.float32)
    self._preproc_scale_gpu = (1.0 / (255.0 * stds)).view(1, 3, 1, 1).contiguous()
    self._preproc_offset_gpu = (-means / stds).view(1, 3, 1, 1).contiguous()
    self._input_buffer = torch.empty(
        (1, 3, self.img_size_h, self.img_size_w),
        dtype=torch.float32,
        device=device,
    )
    self._preproc_bgr2rgb_idx = torch.tensor(
        [2, 1, 0], device=device, dtype=torch.long
    )
    self._preproc_cache_device = device
    # Set last so the flag is only true once every tensor exists.
    self._preproc_cache_initialized = True

@property
def weights_file(self) -> str:
"""Gets the weights file for the RFDETR model.
Expand All @@ -82,6 +125,38 @@ def weights_file(self) -> str:
"""
return "weights.onnx"

def _try_triton_preprocess(
    self, image: Any
) -> Optional[Tuple["torch.Tensor", Tuple[int, int]]]:
    """Attempt the Triton preprocessing path for a single image.

    Args:
        image: Either a numpy array or an ``InferenceRequestImage``; any
            other type is rejected.

    Returns:
        ``(preprocessed, (orig_h, orig_w))`` on success, or ``None`` when
        the input is not eligible (non-square preprocessing required,
        unsupported type, image could not be loaded, or the pixel data is
        not a uint8 array of shape (H, W, 3)) so the caller can fall back
        to the regular path.
    """
    # Cheapest rejection first: avoids decoding a request image that the
    # caller would then have to decode again on the fallback path.
    if self._needs_nonsquare_preproc:
        return None

    if isinstance(image, np.ndarray):
        src = image
    elif isinstance(image, InferenceRequestImage):
        try:
            src, _ = load_image(image, disable_preproc_auto_orient=True)
        except Exception as load_err:
            # Deliberate best effort: the regular path will load the image
            # again and surface the error if it is real.
            logger.debug(
                "Triton preprocess skipped, image could not be loaded: %s",
                load_err,
            )
            return None
    else:
        return None

    if (
        not isinstance(src, np.ndarray)
        or src.dtype != np.uint8
        or src.ndim != 3
        or src.shape[2] != 3
    ):
        return None

    orig_h, orig_w = src.shape[0], src.shape[1]
    # NOTE(review): the source buffer is not pinned here, so non_blocking
    # is presumably a no-op for this upload -- confirm.
    src_gpu = torch.from_numpy(np.ascontiguousarray(src)).cuda(non_blocking=True)
    out = triton_preprocess_rfdetr(
        src_gpu,
        target_h=self.img_size_h,
        target_w=self.img_size_w,
        means=tuple(self.preprocess_means),
        stds=tuple(self.preprocess_stds),
        pad_color=114,
    )
    return out, (orig_h, orig_w)

def preproc_image(
self,
image: Union[Any, InferenceRequestImage],
Expand All @@ -103,6 +178,15 @@ def preproc_image(
Returns:
Tuple[np.ndarray, Tuple[int, int]]: A tuple containing a numpy array of the preprocessed image pixel data and a tuple of the images original size.
"""
if (
_TRITON_PREPROC_READY
and USE_PYTORCH_FOR_PREPROCESSING
and self.resize_method in ("Fit (grey edges) in", "Stretch to")
):
triton_out = self._try_triton_preprocess(image)
if triton_out is not None:
return triton_out

if isinstance(image, Image.Image) and USE_PYTORCH_FOR_PREPROCESSING:
if CUDA_IS_AVAILABLE:
np_image = torch.from_numpy(np.asarray(image, copy=False)).cuda()
Expand Down Expand Up @@ -135,15 +219,12 @@ def preproc_image(
)
preprocessed_image = preprocessed_image.float()

preprocessed_image /= 255.0

means = torch.tensor(
self.preprocess_means, device=preprocessed_image.device
).view(3, 1, 1)
stds = torch.tensor(
self.preprocess_stds, device=preprocessed_image.device
).view(3, 1, 1)
preprocessed_image = (preprocessed_image - means) / stds
self._ensure_preproc_cache(preprocessed_image.device)
preprocessed_image = torch.addcmul(
self._preproc_offset_gpu,
preprocessed_image,
self._preproc_scale_gpu,
)
else:
preprocessed_image = preprocessed_image.astype(np.float32)
preprocessed_image /= 255.0
Expand Down Expand Up @@ -224,14 +305,22 @@ def preproc_image(
if isinstance(resized, np.ndarray):
resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
else:
resized = resized[:, [2, 1, 0], :, :]
resized = resized.index_select(1, self._preproc_bgr2rgb_idx)

if isinstance(resized, np.ndarray):
img_in = np.transpose(resized, (2, 0, 1))
img_in = img_in.astype(np.float32)
img_in = np.expand_dims(img_in, axis=0)
elif USE_PYTORCH_FOR_PREPROCESSING:
img_in = resized.float()
if (
not self.batching_enabled
and self._preproc_cache_initialized
and resized.shape == self._input_buffer.shape
):
self._input_buffer.copy_(resized, non_blocking=True)
img_in = self._input_buffer
else:
img_in = resized
else:
raise ValueError(
f"Received an image of unknown type, {type(resized)}; "
Expand Down
Loading