Overhaul backend function execution for improved performance and flexibility #270

Status: Open. Wants to merge 1 commit into base: main.
3 changes: 1 addition & 2 deletions tripy/tests/frontend/test_tensor.py
@@ -226,8 +226,7 @@ def test_no_explicit_cast(self):
         "devices",
         [
             ("cpu", "gpu"),
-            # TODO(#155)
-            # ("gpu", "cpu"),
+            ("gpu", "cpu"),
         ],
     )
     def test_explicit_copy(self, devices):
2 changes: 1 addition & 1 deletion tripy/tests/integration/test_iota.py
@@ -91,7 +91,7 @@ def test_negative_no_casting(self, dtype):
         a = tp.ones((2, 2))
         out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)

-        exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values"
+        exception_str = "InternalError: failed to run compilation on module with symbol name."
         if dtype == tp.bool:
             exception_str = "InternalError: failed to run compilation"
         with helper.raises(
3 changes: 2 additions & 1 deletion tripy/tests/integration/test_quantize.py
@@ -117,5 +117,6 @@ def test_non_constant_scale(self):
         input = tp.ones((4, 4))
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)
+        quantized_int32 = tp.cast(quantized, tp.int32)

-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32)))
1 change: 0 additions & 1 deletion tripy/tripy/backend/api/compile.py
@@ -197,5 +197,4 @@ def process_arg(name, arg):
     return Executable(
         executable,
         compiled_arg_names,
-        output_devices=[out.device for out in trace.outputs],
     )
46 changes: 24 additions & 22 deletions tripy/tripy/backend/api/executable.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import base64
 import inspect
-from typing import Sequence, Union
+from typing import Sequence, Union, Tuple, Callable

 import mlir_tensorrt.runtime.api as runtime
@@ -37,21 +37,19 @@ class Executable:
     """

     # The constructor is intentionally undocumented because it is not meant to be called by users.
-    # TODO(#155): output_devices is not needed after they can be queried from executable
-    def __init__(self, executable, arg_names, output_devices):
+    def __init__(self, executable, arg_names):
         self._executable = executable
         self._executor = Executor(self._executable)
         self._arg_names = arg_names
         self._num_expected_args = len(arg_names)
-        self._output_devices = output_devices
         self._executable_signature = self._executable.get_signature("main")

         # Build a signature so the executable works with `inspect.signature`
         params = []
         for name in self._arg_names:
             params.append(inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Tensor))

-        return_annotation = Tensor if self._executable_signature.get_num_output_args() == 1 else Sequence[Tensor]
+        return_annotation = Tensor if self._executable_signature.get_num_results() == 1 else Sequence[Tensor]

         self.__signature__ = inspect.Signature(params, return_annotation=return_annotation)
@@ -128,7 +126,7 @@ def add(a, b):
                 tensor.eval()

         try:
-            executor_outputs = self._executor.execute(self._output_devices, input_tensors)
+            executor_outputs = self._executor.execute(input_tensors)
         except runtime.MTRTException as err:
             # TODO: Evaluate whether this should be moved into the executor
             if "function expects a memref type with element type" in str(err):
@@ -170,15 +168,22 @@ def add(a, b):
             output_tensors = output_tensors[0]
         return output_tensors

-    def _get_arg_info(self, idx):
-        arg = self._executable_signature.get_arg(idx)
-        arg = runtime.MemRefType(arg)
-        arg_bound = self._executable_signature.get_arg_bound(idx)
-        shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max()))
-        if len(shape_bounds) == 0:
-            # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape
-            shape_bounds = tuple((x, x) for x in arg.shape)
-        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
+    def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo:
+        item = runtime.MemRefType(get_item(idx))
+        bound = get_bound(idx)
+        shape_bounds = tuple(zip(bound.min(), bound.max()))
+
+        if not shape_bounds:
+            # For static shape, fallback to item.shape
+            shape_bounds = tuple((x, x) for x in item.shape)
+
+        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype))
+
+    def _get_arg_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound)
+
+    def _get_result_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound)

     def get_input_info(self) -> Sequence[ArgInfo]:
         """
@@ -221,11 +226,10 @@ def add(a, b):
             compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
             print(compiled_add.get_output_info())
         """
-        output_info = []
-        offset = self._executable_signature.get_num_input_args()
-        for idx in range(self._executable_signature.get_num_output_args()):
-            output_info.append(self._get_arg_info(idx + offset))
-        return output_info
+        num_input_args = self._executable_signature.get_num_input_args()
+        num_results = self._executable_signature.get_num_results()
+
+        return [self._get_result_info(idx) for idx in range(num_results)]

     def save(self, path: str) -> None:
         """
@@ -289,7 +293,6 @@ def add(a, b):
 def encode_executable(executable):
     return {
         "arg_names": executable._arg_names,
-        "output_devices": executable._output_devices,
         "executable": base64.b64encode(executable._executable.serialize()).decode(),
     }

@@ -300,5 +303,4 @@ def decode_executable(executable_dict):
     return Executable(
         runtime.Executable(executable_bytes),
         executable_dict["arg_names"],
-        executable_dict["output_devices"],
     )
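A quick round trip through the (de)serialization helpers above; note the encoded dictionary no longer carries "output_devices". `compiled_add` is again the assumed compiled function from this file's docstring examples:

    blob = encode_executable(compiled_add)  # {"arg_names": [...], "executable": "<base64>"}
    restored = decode_executable(blob)
    assert inspect.signature(restored) == inspect.signature(compiled_add)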
1 change: 1 addition & 0 deletions tripy/tripy/backend/mlir/compiler.py
@@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level):
             f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
             f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
             "--tensorrt-strongly-typed=True",
+            "--enable-non-dps-returns",
         ]
         if config.enable_mlir_debug or config.enable_tensorrt_debug:
             opts.append("--debug=true")
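For context (not stated in the diff, but implied by the executor changes below): `--enable-non-dps-returns` switches the compiled function away from destination-passing style (DPS), where the caller pre-allocates output buffers and passes them as trailing out_args, to ordinary returns, where the runtime allocates outputs and hands them back. Both calling conventions appear in the executor diff below:

    # DPS: caller allocates outputs and passes them in.
    session.execute_function("main", in_args=in_args, out_args=out_args, stream=stream)

    # Non-DPS: runtime allocates outputs and returns them.
    outputs = session.execute_function("main", in_args=in_args, stream=stream, client=client)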
117 changes: 3 additions & 114 deletions tripy/tripy/backend/mlir/executor.py
@@ -31,89 +31,14 @@

 class Executor:
     def __init__(self, executable: runtime.Executable) -> None:
-
         self.runtime_client = MLIRRuntimeClient()
         session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
         self.session = runtime.RuntimeSession(session_options, executable)
         self.device = self.runtime_client.get_devices()[0]  # Assume a single device is available.
-        self.signature = executable.get_signature("main")
         self.stream = default_stream()
-        self.num_input_args = self.signature.get_num_input_args()
-        self.num_output_args = self.signature.get_num_output_args()
-        self.output_args = [
-            self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args)
-        ]
-        self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args]

-    def _create_shape_memref(self, shape):
-        shape = make_tuple(shape)
-        if len(shape) == 0:
-            return create_memref(
-                shape=(0,),
-                dtype=datatype.int64,
-                device=device("cpu"),
-            )
-        return create_memref(
-            array=convert_list_to_array(shape, datatype.int64),
-            shape=(len(shape),),
-            dtype=datatype.int64,
-            device=device("cpu"),
-        )
-
-    def _get_outputs_shape(self):
-        outputs_shape = []
-        all_outputs_known = True
-        for memref in self.output_memrefs:
-            outputs_shape.append(memref.shape)
-            all_outputs_known &= all(dim >= 0 for dim in memref.shape)
-        return outputs_shape, all_outputs_known
-
-    def _get_inputs_runtime_shape(self, inputs):
-        inputs_shape = []
-        for input in inputs:
-            inputs_shape.append(input.trace_tensor.producer.data.shape)
-        return inputs_shape
-
-    def _execute_shape_inference(self, inputs_shape, outputs_shape):
-        inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape]
-        outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape]
-        self.session.execute_function(
-            name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref
-        )
-
-        outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref]
-        return outputs_runtime_shape
-
-    def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
-        outputs_tensor_info = []
-        for index in range(self.num_output_args):
-            memref = self.output_memrefs[index]
-            dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)
-
-            output_device = output_devices[index]
-            if not output_device:
-                output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0))
-
-            runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])]
-            outputs_tensor_info.append(
-                TensorInfo(
-                    len(runtime_shape),
-                    tuple(runtime_shape),
-                    dtype,
-                    output_device,
-                )
-            )
-        return outputs_tensor_info
-
-    def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]):
-        outputs_shape, all_outputs_known = self._get_outputs_shape()
-        if not all_outputs_known:
-            inputs_shape = self._get_inputs_runtime_shape(inputs)
-            outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape)
-        output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices)
-        return output_tensor_info
-
-    def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
+    def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
         in_args = []
         for inp in inputs:
             memref = inp.trace_tensor.producer.data
@@ -131,45 +56,9 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->
             )
             in_args.append(memref)

-        # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
-        out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices)
-
-        # Allocate output memory and store buffer pointers.
-        outputs = [
-            create_memref(
-                shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
-            )
-            for info in out_tensor_info
-        ]
-
-        out_args = []
-        for out in outputs:
-            memref = out
-            # HACK (#155): MLIR-TensorRT requires inputs to be on device.
-            # Remove explicit copy to device once #155 is addressed.
-            if memref.address_space != runtime.PointerType.device:
-                memref = self.runtime_client.copy_to_device(
-                    host_memref=memref,
-                    device=self.runtime_client.get_devices()[0],
-                    stream=self.stream._active_cuda_stream,
-                )
-            if not memref:
-                raise_error("Could not allocate output memref", details=memref.error_details)
-            out_args.append(memref)
-
-        # Execute and populate device pointers.
-        self.session.execute_function(
-            "main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream
-        )
+        outputs = self.session.execute_function(
+            "main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client
+        )

-        # For outputs that were on the host, do the copy back
-        # TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
-        for idx, out_info in enumerate(out_tensor_info):
-            if out_info.device.kind != "gpu":
-                self.runtime_client.copy_to_host(
-                    device_memref=out_args[idx],
-                    existing_host_memref=outputs[idx],
-                    stream=self.stream._active_cuda_stream,
-                )

         return outputs
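The effect of the non-DPS path is that `execute` no longer needs shape inference, output pre-allocation, or host/device copy workarounds. A condensed sketch of the new flow (names mirror this file; `client` is the MLIRRuntimeClient created in `__init__`):

    import mlir_tensorrt.runtime.api as runtime

    def run(executable: runtime.Executable, in_args, stream, client):
        options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
        session = runtime.RuntimeSession(options, executable)
        # The session allocates output memrefs itself and returns them:
        # no out_args, no shape-function call, no copy-back loop.
        return session.execute_function(
            "main", in_args=in_args, stream=stream, client=client
        )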
9 changes: 9 additions & 0 deletions tripy/tripy/flat_ir/ops/copy.py
@@ -29,6 +29,12 @@ class CopyOp(BaseFlatIROp):

     target: tripy.common.device

+    def set_memory_space_attr(self, tensor, mem_space_attr):
+        current_type = tensor.type
+        # Set the encoding attribute on the operation's result
+        new_type = ir.RankedTensorType.get(current_type.shape, current_type.element_type, encoding=mem_space_attr)
+        tensor.set_type(new_type)

     def to_mlir(self, operands):
         from mlir_tensorrt.compiler.dialects import bufferization, tensor, arith

@@ -46,7 +52,10 @@ def to_mlir(self, operands):
                 sliced_dims.append(dim)

         alloc_tensor = bufferization.alloc_tensor(inp_type, sliced_dims, memory_space=mem_space_attr)
+        self.set_memory_space_attr(alloc_tensor, mem_space_attr)
         result_tensor = bufferization.materialize_in_destination(inp_type, operands[0], alloc_tensor)
+        self.set_memory_space_attr(result_tensor, mem_space_attr)
         cast_tensor = tensor.cast(self.outputs[0].to_mlir(), result_tensor)
+        self.set_memory_space_attr(cast_tensor, mem_space_attr)

         return [cast_tensor]
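The repetition here appears deliberate: `alloc_tensor`, `result_tensor`, and `cast_tensor` are distinct SSA values, each carrying its own tensor type, so the memory-space encoding is re-applied to every intermediate result, presumably so that bufferization keeps the copied buffer in the requested address space throughout.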
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/ops/tensor_initializers.py
@@ -281,7 +281,7 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
 @constraints.dtypes(
     constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"},
     variables={
-        "T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"],
+        "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"],
     },
 )
 def arange(
@@ -346,7 +346,7 @@ def arange(
 @constraints.dtypes(
     constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"},
     variables={
-        "T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"],
+        "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"],
     },
 )
 def arange(
3 changes: 2 additions & 1 deletion tripy/tripy/frontend/tensor.py
@@ -186,10 +186,11 @@ def eval(self) -> runtime.MemRefValue:

         compiler = Compiler(trt_builder_opt_level=0)
         executable = compiler.compile(mlir, flat_ir=flat_ir)
+        # Ensure that session and client are available as long as tensor lives.

[Review comment (author) on the added line above: Remove this comment.]

         executor = Executor(executable)
         # Upon computing the value of this tensor, we switch it to have a `Storage`
         # parameter so that it does not need to be computed again.
-        data = executor.execute([out.device for out in flat_ir.outputs])
+        data = executor.execute()
         executor.stream.synchronize()
         assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
         data = data[0]
[Review comment (Collaborator): Could we also remove the hack a few lines below?

    self.trace_tensor.device = flat_ir.outputs[0].device
]

[Reply (jhalakpatel, author, Nov 5, 2024): There are still a few issues:

  1. MLIR-TensorRT still allocates inputs only on the device.
  2. This PR fixes the copy operation only.

Since inputs are always allocated on device, there will be a device mismatch between inputs (always on device) and outputs (which can now be on host as well as device):

    def infer_devices(self):
        """
        Infers devices for the operation and updates output tensor devices accordingly.
        """
        assert (
>           self.inputs and len(self.outputs) == 1 and all(inp.device == self.inputs[0].device for inp in self.inputs)
        ), "Default implementation cannot handle cases where there are no inputs, multiple outputs, or multiple inputs with different devices. Please override."
E       AssertionError: Default implementation cannot handle cases where there are no inputs, multiple outputs, or multiple inputs with different devices. Please override.
]
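A hypothetical repro of the mismatch described above (assumed user-facing API, not taken from the PR):

    import tripy as tp

    a = tp.ones((2, 2))                      # inputs are always allocated on device
    b = tp.copy(a, device=tp.device("cpu"))  # this output may now live on host
    c = a + b                                # inputs on different devices: infer_devices asserts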
2 changes: 1 addition & 1 deletion tripy/tripy/frontend/trace/ops/gather.py
@@ -91,7 +91,7 @@ def to_flat_ir(self, inputs, outputs):
 @constraints.dtypes(
     constraints={"input": "T1", "index": "T2", constraints.RETURN_VALUE: "T1"},
     variables={
-        "T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"],
+        "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
         "T2": ["int32"],
     },
 )
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/iota.py
@@ -69,7 +69,7 @@ def iota_impl(shape: "tripy.Tensor", dim: int, dtype: datatype.dtype, output_rank: int):
 @constraints.dtypes(
     constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"},
     variables={
-        "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "bool"],
+        "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
     },
 )
 def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
@@ -101,7 +101,7 @@ def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
     constraints={"input": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"},
     variables={
         "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
-        "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "bool"],
+        "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
     },
 )
 def iota_like(input: "tripy.Tensor", dim: int = 0, dtype: Optional[datatype.dtype] = None) -> "tripy.Tensor":
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/reduce.py
@@ -413,7 +413,7 @@ def _arg_min_max_impl(tensor: "tripy.Tensor", kind: ArgMinMax.Kind, dim: Optional[int]):
 @export.public_api(document_under="operations/functions")
 @constraints.dtypes(
     constraints={"input": "T1", constraints.RETURN_VALUE: "T2"},
-    variables={"T1": ["float32", "float16", "bfloat16", "int32", "bool", "int8"], "T2": ["int32"]},
+    variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]},
 )
 def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor":
     """
@@ -445,7 +445,7 @@ def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor":
 @export.public_api(document_under="operations/functions")
 @constraints.dtypes(
     constraints={"input": "T1", constraints.RETURN_VALUE: "T2"},
-    variables={"T1": ["float32", "float16", "bfloat16", "int32", "bool", "int8"], "T2": ["int32"]},
+    variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]},
 )
 def argmin(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor":
     """