diff --git a/inference_models/inference_models/models/common/trt.py b/inference_models/inference_models/models/common/trt.py index 2d62408930..7e78c907be 100644 --- a/inference_models/inference_models/models/common/trt.py +++ b/inference_models/inference_models/models/common/trt.py @@ -78,6 +78,7 @@ class TRTCudaGraphState: input_buffer: torch.Tensor output_buffers: List[torch.Tensor] execution_context: trt.IExecutionContext + consumer_done_event: Optional["torch.cuda.Event"] = None class TRTCudaGraphCache: @@ -571,7 +572,10 @@ def infer_from_trt_engine( outputs=outputs, trt_cuda_graph_cache=trt_cuda_graph_cache, ) - stream.synchronize() + graph_state = getattr(results[0], "_trt_graph_state", None) + produce_event = torch.cuda.Event() + produce_event.record(graph_state.cuda_stream if graph_state is not None else stream) + results[0]._trt_produce_event = produce_event return results @@ -692,12 +696,16 @@ def _execute_trt_engine( if cache_key not in trt_cuda_graph_cache: LOGGER.debug("Capturing CUDA graph for shape %s", input_shape) + reuse_input = bool( + getattr(pre_processed_images, "_trt_reuse_as_input_buffer", False) + ) results, trt_cuda_graph = _capture_cuda_graph( pre_processed_images=pre_processed_images, engine=engine, device=device, input_name=input_name, outputs=outputs, + use_pre_processed_images_as_input_buffer=reuse_input, ) trt_cuda_graph_cache[cache_key] = trt_cuda_graph return results @@ -705,11 +713,18 @@ def _execute_trt_engine( else: trt_cuda_graph_state = trt_cuda_graph_cache[cache_key] stream = trt_cuda_graph_state.cuda_stream + consumer_done = trt_cuda_graph_state.consumer_done_event + if consumer_done is not None: + consumer_done.wait(stream) with torch.cuda.stream(stream): - trt_cuda_graph_state.input_buffer.copy_(pre_processed_images) + if ( + trt_cuda_graph_state.input_buffer.data_ptr() + != pre_processed_images.data_ptr() + ): + trt_cuda_graph_state.input_buffer.copy_(pre_processed_images) trt_cuda_graph_state.cuda_graph.replay() - results = [buf.clone() for buf in trt_cuda_graph_state.output_buffers] - stream.synchronize() + results = list(trt_cuda_graph_state.output_buffers) + results[0]._trt_graph_state = trt_cuda_graph_state return results else: @@ -752,14 +767,18 @@ def _capture_cuda_graph( device: torch.device, input_name: str, outputs: List[str], + use_pre_processed_images_as_input_buffer: bool = False, ) -> Tuple[List[torch.Tensor], TRTCudaGraphState]: # Each CUDA graph needs its own execution context. Sharing a single context # across graphs for different input shapes causes TRT to reallocate internal # workspace buffers, invalidating GPU addresses baked into earlier graphs. graph_context = engine.create_execution_context() - input_buffer = torch.empty_like(pre_processed_images, device=device) - input_buffer.copy_(pre_processed_images) + if use_pre_processed_images_as_input_buffer: + input_buffer = pre_processed_images + else: + input_buffer = torch.empty_like(pre_processed_images, device=device) + input_buffer.copy_(pre_processed_images) status = graph_context.set_input_shape( input_name, tuple(pre_processed_images.shape)