Update MLIR-TRT execution function to use non-DPS style calling convention
jhalakpatel committed Oct 15, 2024
1 parent 60ccc43 commit ba2dd98
Showing 3 changed files with 7 additions and 118 deletions.
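
For context: under the destination-passing style (DPS) calling convention, the caller pre-allocates every output buffer and hands it to the runtime through `out_args`; under the non-DPS convention, `execute_function` allocates the outputs itself and returns them. A minimal sketch of the two call shapes, using names from the executor diff below (`session`, `in_args`, `stream`, `runtime_client`, and `out_tensor_info` are assumed to be set up as in that code):

# DPS style (before this commit): pre-allocate output memrefs and pass them in.
out_args = [
    create_empty_memref(shape=info.shape, dtype=info.dtype, device=info.device)
    for info in out_tensor_info
]
session.execute_function(
    "main", in_args=in_args, out_args=out_args, stream=stream._active_cuda_stream
)

# Non-DPS style (after this commit): the session allocates and returns the outputs.
outputs = session.execute_function(
    "main", in_args=in_args, stream=stream._active_cuda_stream, client=runtime_client
)
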
1 change: 1 addition & 0 deletions tripy/tripy/backend/mlir/compiler.py
@@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level):
f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
"--tensorrt-strongly-typed=True",
"--use-non-dps-call-conv",
]
if config.enable_mlir_debug or config.enable_tensorrt_debug:
opts.append("--debug=true")
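
For reference, a self-contained sketch of the option list `_make_mlir_opts` would produce after this change; `G_TIMING_CACHE_FILE`, `trt_builder_opt_level`, and the debug toggle below are placeholder values standing in for the real ones in compiler.py:

# Placeholder values; in compiler.py these come from module globals and config.
G_TIMING_CACHE_FILE = "/tmp/tensorrt-timing-cache"
trt_builder_opt_level = 0
enable_debug = False

opts = [
    f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
    f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
    "--tensorrt-strongly-typed=True",
    "--use-non-dps-call-conv",  # new: compile "main" with the non-DPS calling convention
]
if enable_debug:
    opts.append("--debug=true")
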
118 changes: 3 additions & 115 deletions tripy/tripy/backend/mlir/executor.py
@@ -37,79 +37,6 @@ def __init__(self, executable: runtime.Executable) -> None:
self.device = self.runtime_client.get_devices()[0] # Assume a single device is available.
self.signature = executable.get_signature("main")
self.stream = default_stream()
self.num_input_args = self.signature.get_num_input_args()
self.num_output_args = self.signature.get_num_output_args()
self.output_args = [
self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args)
]
self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args]

def _create_shape_memref(self, shape):
shape = make_tuple(shape)
if len(shape) == 0:
# create an empty memref
return self.runtime_client.create_memref(
shape=(0,), dtype=runtime.runtime.ScalarTypeCode.i64, stream=self.stream._active_cuda_stream
)
return self.runtime_client.create_memref(
convert_list_to_array(shape, datatype.int64),
shape=(len(shape),),
dtype=runtime.ScalarTypeCode.i64,
stream=self.stream._active_cuda_stream,
)

def _get_outputs_shape(self):
outputs_shape = []
all_outputs_known = True
for memref in self.output_memrefs:
outputs_shape.append(memref.shape)
all_outputs_known &= all(dim >= 0 for dim in memref.shape)
return outputs_shape, all_outputs_known

def _get_inputs_runtime_shape(self, inputs):
inputs_shape = []
for input in inputs:
inputs_shape.append(input.trace_tensor.producer.data.shape)
return inputs_shape

def _execute_shape_inference(self, inputs_shape, outputs_shape):
inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape]
outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape]
self.session.execute_function(
name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref
)

outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref]
return outputs_runtime_shape

def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
outputs_tensor_info = []
for index in range(self.num_output_args):
memref = self.output_memrefs[index]
dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)

output_device = output_devices[index]
if not output_device:
output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0))

runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])]
outputs_tensor_info.append(
TensorInfo(
len(runtime_shape),
tuple(runtime_shape),
dtype,
output_device,
)
)
return outputs_tensor_info

def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]):
outputs_shape, all_outputs_known = self._get_outputs_shape()
if not all_outputs_known:
inputs_shape = self._get_inputs_runtime_shape(inputs)
outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape)
output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices)
return output_tensor_info

def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
in_args = []
@@ -129,49 +56,10 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->
)
in_args.append(memref)

# HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices)

# Allocate output memory and store buffer pointers.
outputs = [
create_empty_memref(
shape=info.shape,
dtype=info.dtype,
device=info.device,
stream=self.stream._active_cuda_stream,
use_cache=False,
)
for info in out_tensor_info
]

out_args = []
for out in outputs:
memref = out
# HACK (#155): MLIR-TensorRT requires inputs to be on device.
# Remove explicit copy to device once #155 is addressed.
if memref.address_space != runtime.PointerType.device:
memref = self.runtime_client.copy_to_device(
host_memref=memref,
device=self.runtime_client.get_devices()[0],
stream=self.stream._active_cuda_stream,
)
if not memref:
raise_error("Could not allocate output memref", details=memref.error_details)
out_args.append(memref)

# Execute and populate device pointers.
self.session.execute_function(
"main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream
outputs = self.session.execute_function(
"main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client
)

# For outputs that were on the host, do the copy back
# TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
for idx, out_info in enumerate(out_tensor_info):
if out_info.device.kind != "gpu":
self.runtime_client.copy_to_host(
device_memref=out_args[idx],
existing_host_memref=outputs[idx],
stream=self.stream._active_cuda_stream,
)

# For now return results on GPU.
return outputs
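
Since the non-DPS path currently returns every result on the GPU (see the trailing comment above), a caller that needs host-resident data would copy it back explicitly. A sketch using the same `copy_to_host` API that appears in the removed code; `outputs`, `runtime_client`, `stream`, and `host_memref` are assumptions about the calling context:

# Assumption: `outputs` is the list returned by execute(), and `host_memref` was
# allocated on the CPU with a matching shape and dtype (e.g. via create_empty_memref).
runtime_client.copy_to_host(
    device_memref=outputs[0],
    existing_host_memref=host_memref,
    stream=stream._active_cuda_stream,
)
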
6 changes: 3 additions & 3 deletions tripy/tripy/frontend/tensor.py
@@ -178,11 +178,11 @@ def eval(self) -> runtime.MemRefValue:

compiler = Compiler(trt_builder_opt_level=0)
executable = compiler.compile(mlir, flat_ir=flat_ir)
executor = Executor(executable)
self.executor = Executor(executable)
# Upon computing the value of this tensor, we switch it to have a `Storage`
# parameter so that it does not need to be computed again.
data = executor.execute([out.device for out in flat_ir.outputs])
executor.stream.synchronize()
data = self.executor.execute([out.device for out in flat_ir.outputs])
self.executor.stream.synchronize()
assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
data = data[0]
# Data is present now. Assign the underlying device type.
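
A condensed sketch of the eval() flow after this change, simplified from the diff above. Keeping the Executor on `self` plausibly ties the lifetime of the runtime session (which now owns the returned output memrefs) to the tensor; that rationale is an assumption, not stated in the commit:

# Simplified from tensor.py's eval(); `mlir` and `flat_ir` are produced earlier in eval().
compiler = Compiler(trt_builder_opt_level=0)
executable = compiler.compile(mlir, flat_ir=flat_ir)
self.executor = Executor(executable)  # kept on the tensor so the session outlives this call
data = self.executor.execute([out.device for out in flat_ir.outputs])
self.executor.stream.synchronize()  # wait for the asynchronous execution to finish
assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
data = data[0]
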
