Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions inference_models/inference_models/models/common/trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class TRTCudaGraphState:
input_buffer: torch.Tensor
output_buffers: List[torch.Tensor]
execution_context: trt.IExecutionContext
consumer_done_event: Optional["torch.cuda.Event"] = None


class TRTCudaGraphCache:
Expand Down Expand Up @@ -571,7 +572,10 @@ def infer_from_trt_engine(
outputs=outputs,
trt_cuda_graph_cache=trt_cuda_graph_cache,
)
stream.synchronize()
graph_state = getattr(results[0], "_trt_graph_state", None)
produce_event = torch.cuda.Event()
produce_event.record(graph_state.cuda_stream if graph_state is not None else stream)
results[0]._trt_produce_event = produce_event
return results


Expand Down Expand Up @@ -692,24 +696,35 @@ def _execute_trt_engine(
if cache_key not in trt_cuda_graph_cache:
LOGGER.debug("Capturing CUDA graph for shape %s", input_shape)

reuse_input = bool(
getattr(pre_processed_images, "_trt_reuse_as_input_buffer", False)
)
results, trt_cuda_graph = _capture_cuda_graph(
pre_processed_images=pre_processed_images,
engine=engine,
device=device,
input_name=input_name,
outputs=outputs,
use_pre_processed_images_as_input_buffer=reuse_input,
)
trt_cuda_graph_cache[cache_key] = trt_cuda_graph
return results

else:
trt_cuda_graph_state = trt_cuda_graph_cache[cache_key]
stream = trt_cuda_graph_state.cuda_stream
consumer_done = trt_cuda_graph_state.consumer_done_event
if consumer_done is not None:
consumer_done.wait(stream)
with torch.cuda.stream(stream):
trt_cuda_graph_state.input_buffer.copy_(pre_processed_images)
if (
trt_cuda_graph_state.input_buffer.data_ptr()
!= pre_processed_images.data_ptr()
):
trt_cuda_graph_state.input_buffer.copy_(pre_processed_images)
trt_cuda_graph_state.cuda_graph.replay()
results = [buf.clone() for buf in trt_cuda_graph_state.output_buffers]
stream.synchronize()
results = list(trt_cuda_graph_state.output_buffers)
results[0]._trt_graph_state = trt_cuda_graph_state
return results

else:
Expand Down Expand Up @@ -752,14 +767,18 @@ def _capture_cuda_graph(
device: torch.device,
input_name: str,
outputs: List[str],
use_pre_processed_images_as_input_buffer: bool = False,
) -> Tuple[List[torch.Tensor], TRTCudaGraphState]:
# Each CUDA graph needs its own execution context. Sharing a single context
# across graphs for different input shapes causes TRT to reallocate internal
# workspace buffers, invalidating GPU addresses baked into earlier graphs.
graph_context = engine.create_execution_context()

input_buffer = torch.empty_like(pre_processed_images, device=device)
input_buffer.copy_(pre_processed_images)
if use_pre_processed_images_as_input_buffer:
input_buffer = pre_processed_images
else:
input_buffer = torch.empty_like(pre_processed_images, device=device)
input_buffer.copy_(pre_processed_images)

status = graph_context.set_input_shape(
input_name, tuple(pre_processed_images.shape)
Expand Down
Loading