🐛 [Bug] Incorrect output when using grid-sample operator with use_explicit_typing since tensorrt 10.14.1.48 #4074

@TNTwise

Description

Bug Description

After updating to TensorRT 10.14.1.48, a grid_sample op forced to float32 inside a float16 model produces incorrect output when building a mixed-precision engine. This happens on every torch-tensorrt version; the regression appears only when switching TensorRT versions.

To Reproduce

Steps to reproduce the behavior:

  1. Enable use_explicit_typing, if not already enabled.
  2. Force grid_sample to run in fp32 while the rest of the network runs in fp16.

My model build config:

```python
model_trt = torch_tensorrt.dynamo.compile(
    exported_program,
    tuple(inputs),
    device=device,
    enabled_precisions=(torch.float,),
    use_explicit_typing=True,
    num_avg_timing_iters=4,
    workspace_size=0,
    min_block_size=1,
)
```

This line of code causes the issue:

```python
F.grid_sample(
    input=tenInput.to(torch.float),
    grid=g.to(torch.float),
    mode='bilinear',
    padding_mode=pd,
    align_corners=True,
).to(torch.float16)
```
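The pattern above (upcast to fp32 around grid_sample, downcast the result to fp16) can be isolated in a minimal eager-mode sketch. The module and tensor names below are illustrative, not taken from the original model, and the padding mode is a placeholder for the original `pd`; with an identity sampling grid the warp should reproduce the input, which is the property that breaks in the built engine:

```python
import torch
import torch.nn.functional as F

class Warp(torch.nn.Module):
    """Hypothetical minimal module mirroring the problematic op pattern."""
    def forward(self, tenInput: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        # Upcast both operands to fp32 for the sample, then cast back to fp16,
        # as in the original model code.
        return F.grid_sample(
            input=tenInput.to(torch.float),
            grid=g.to(torch.float),
            mode="bilinear",
            padding_mode="border",  # placeholder for the original `pd`
            align_corners=True,
        ).to(torch.float16)

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8, dtype=torch.float16)
# Identity affine grid in [-1, 1]; grid_sample expects shape (N, H, W, 2).
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
grid = F.affine_grid(theta, size=(1, 3, 8, 8), align_corners=True)
out = Warp()(x, grid.to(torch.float16))
# With an identity grid, the fp32 sample should return the input
# up to fp16 rounding.
print(torch.allclose(out, x, atol=1e-2))
```

Exporting a module like this and compiling it with the config above should reproduce the mismatch between TensorRT versions.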

[screenshot: incorrect warped output]

Expected behavior

(Inferencing video with RIFE) - correct warping of the frame, as produced by TensorRT 10.13 and below

[screenshot: expected correct output]

Logs when building engine:

2026-02-07 15:00:55,522 INFO torch_tensorrt.dynamo._compiler: Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=4, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=True, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=False, tiling_optimization_level='none', l2_limit_for_tiling=-1, use_distributed_mode_trace=False, offload_module_to_cpu=False, enable_autocast=False, autocast_low_precision_type=None, autocast_excluded_nodes=set(), autocast_excluded_ops=set(), autocast_max_output_threshold=512, autocast_max_depth_of_reduction=None, autocast_calibration_dataloader=None, enable_resource_partitioning=False, cpu_memory_budget=None)

2026-02-07 15:00:56,526 INFO torch_tensorrt.dynamo._compiler: Partitioning the graph via the fast partitioner
2026-02-07 15:00:56,676 INFO torch_tensorrt [TensorRT Conversion Context]: [MemUsageChange] Init CUDA: CPU -2, GPU +0, now: CPU 406, GPU 1933 (MiB)
2026-02-07 15:00:56,811 INFO torch_tensorrt.dynamo.conversion._TRTInterpreter: Converted node img0 [img0] (Inputs: () | Outputs: (img0: (1, 3, 1088, 1920)@torch.float16))
2026-02-07 15:00:56,812 INFO torch_tensorrt.dynamo.conversion._TRTInterpreter: Converted node /clamp [aten.clamp.default] (Inputs: (img0: (1, 3, 1088, 1920)@torch.float16, 0.0, 1.0) | Outputs: (clamp: (1, 3, 1088, 1920)@torch.float16))
2026-02-07 15:00:56,812 INFO torch_tensorrt.dynamo.conversion._TRTInterpreter: Converted node img1 [img1] (Inputs: () | Outputs: (img1: (1, 3, 1088, 1920)@torch.float16))
2026-02-07 15:00:56,813 INFO torch_tensorrt.dynamo.conversion._TRTInterpreter: Converted node /clamp_1 [aten.clamp.default] (Inputs: (img1: (1, 3, 1088, 1920)@torch.float16, 0.0, 1.0) | Outputs: (clamp_1: (1, 3, 1088, 1920)@torch.float16))
2026-02-07 15:00:56,813 INFO torch_tensorrt.dynamo.conversion._TRTInterpreter: Converted node f0 [f0] (Inputs: () | Outputs: (f0: (1, 4, 1088, 1920)@torch.float16))
2026-02-07 15:00:56,813 INFO torch_tensorrt.dynamo.conversion._TRTInterpreter: Converted node f1 [f1] (Inputs: () | Outputs: (f1: (1, 4, 1088, 1920)@torch.float16))
2026-02-07 15:00:56,813 INFO torch_tensorrt.dynamo.conversion._TRTInterpreter: Converted node timestep [timestep] (Inputs: () | Outputs: (timestep: (1, 1, 1088, 1920)@torch.float16))
...

...
Start building engine.
2026-02-07 15:01:01,186 INFO torch_tensorrt [TensorRT Conversion Context]: [MemUsageChange] Init builder kernel library: CPU -391, GPU +0, now: CPU 4026, GPU 2432 (MiB)
2026-02-07 15:01:01,187 INFO torch_tensorrt [TensorRT Conversion Context]: Global timing cache in use. Profiling results in this builder pass will be stored.
2026-02-07 15:01:01,188 INFO torch_tensorrt [TensorRT Conversion Context]: Compiler backend is used during engine build.
2026-02-07 15:01:01,802 INFO torch_tensorrt [TensorRT Conversion Context]: Detected 1 inputs and 1 output network tensors.
2026-02-07 15:01:01,808 INFO torch_tensorrt [TensorRT Conversion Context]: Total Host Persistent Memory: 80 bytes
2026-02-07 15:01:01,808 INFO torch_tensorrt [TensorRT Conversion Context]: Total Device Persistent Memory: 0 bytes
2026-02-07 15:01:01,808 INFO torch_tensorrt [TensorRT Conversion Context]: Max Scratch Memory: 50137088 bytes
2026-02-07 15:01:01,808 INFO torch_tensorrt [TensorRT Conversion Context]: [BlockAssignment] Started assigning block shifts. This will take 1 steps to complete.
2026-02-07 15:01:01,808 INFO torch_tensorrt [TensorRT Conversion Context]: [BlockAssignment] Algorithm ShiftNTopDown took 0.01567ms to assign 1 blocks to 1 nodes requiring 50137088 bytes.
2026-02-07 15:01:01,808 INFO torch_tensorrt [TensorRT Conversion Context]: Total Activation Memory: 50137088 bytes
2026-02-07 15:01:01,809 INFO torch_tensorrt [TensorRT Conversion Context]: Total Weights Memory: 12561536 bytes
2026-02-07 15:01:01,809 INFO torch_tensorrt [TensorRT Conversion Context]: Compiler backend is used during engine execution.
2026-02-07 15:01:01,809 INFO torch_tensorrt [TensorRT Conversion Context]: Engine generation completed in 0.622215 seconds.
2026-02-07 15:01:01,809 INFO torch_tensorrt [TensorRT Conversion Context]: [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 577 MiB
2026-02-07 15:01:01,809 INFO torch_tensorrt.dynamo.conversion._TRTInterpreter: Build TRT engine elapsed time: 0:00:00.807143
2026-02-07 15:01:01,927 INFO torch_tensorrt [TensorRT Conversion Context]: Serialized 6196 bytes of code generator cache.
2026-02-07 15:01:02,044 INFO torch_tensorrt [TensorRT Conversion Context]: Serialized 9769322 bytes of compilation cache.
2026-02-07 15:01:02,044 INFO torch_tensorrt [TensorRT Conversion Context]: Serialized 0 timing cache entries

Environment

  • Torch-TensorRT Version (e.g. 2.10.0):
  • PyTorch Version (e.g. 2.10.0):
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Arch Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.12.9
  • CUDA version: 13.0
  • GPU models and configuration: RTX 5090
  • Any other relevant information:

Additional context

This issue doesn't happen when enabled_precisions is strictly float32 or float16; it only occurs with use_explicit_typing.
No error is raised, just the change in the output.

Thank you, and I apologize if I am missing any information.
