diff --git a/tripy/tests/backend/api/test_executable.py b/tripy/tests/backend/api/test_executable.py index 3b588b463..40ac8d01a 100644 --- a/tripy/tests/backend/api/test_executable.py +++ b/tripy/tests/backend/api/test_executable.py @@ -87,7 +87,7 @@ def test_signature(self, single_return_executable): assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD assert param.annotation == tp.Tensor - assert signature.return_annotation == tp.Tensor + assert signature.return_annotation == Sequence[tp.Tensor] def test_signature_multiple_return_values(self, multiple_return_executable): signature = inspect.signature(multiple_return_executable) diff --git a/tripy/tests/frontend/test_tensor.py b/tripy/tests/frontend/test_tensor.py index f0a21ded2..837ec8af0 100644 --- a/tripy/tests/frontend/test_tensor.py +++ b/tripy/tests/frontend/test_tensor.py @@ -226,8 +226,7 @@ def test_no_explicit_cast(self): "devices", [ ("cpu", "gpu"), - # TODO(#155) - # ("gpu", "cpu"), + ("gpu", "cpu"), ], ) def test_explicit_copy(self, devices): diff --git a/tripy/tests/integration/test_iota.py b/tripy/tests/integration/test_iota.py index 39df35787..48094c882 100644 --- a/tripy/tests/integration/test_iota.py +++ b/tripy/tests/integration/test_iota.py @@ -82,16 +82,17 @@ def test_iota_like(self, dtype, shape, dim): @pytest.mark.parametrize("dtype", DATA_TYPES.values()) def test_negative_no_casting(self, dtype): - from tripy.frontend.trace.ops.iota import Iota + with tp.logger.use_verbosity("ir"): + from tripy.frontend.trace.ops.iota import Iota - if dtype in [tp.float32, tp.int32, tp.int64]: - pytest.skip("tp.iota() supports float32, int32, and int64 without cast") + if dtype in [tp.float32, tp.int32, tp.int64]: + pytest.skip("tp.iota() supports float32, int32, and int64 without cast") - # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint - a = tp.ones((2, 2)) - out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype) + # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint + a = tp.ones((2, 2)) + out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype) - exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values" + exception_str = "InternalError: failed to run compilation on module with symbol name." if dtype == tp.bool: exception_str = "InternalError: failed to run compilation" with helper.raises( diff --git a/tripy/tests/integration/test_quantize.py b/tripy/tests/integration/test_quantize.py index b50293869..9bf96be98 100644 --- a/tripy/tests/integration/test_quantize.py +++ b/tripy/tests/integration/test_quantize.py @@ -117,5 +117,6 @@ def test_non_constant_scale(self): input = tp.ones((4, 4)) scale = tp.ones((4,)) quantized = tp.quantize(input, scale, tp.int8, dim=0) + quantized_int32 = tp.cast(quantized, tp.int32) - assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8))) + assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32))) diff --git a/tripy/tripy/backend/api/compile.py b/tripy/tripy/backend/api/compile.py index 1f6315805..0491f8618 100644 --- a/tripy/tripy/backend/api/compile.py +++ b/tripy/tripy/backend/api/compile.py @@ -196,5 +196,4 @@ def process_arg(name, arg): return Executable( executable, compiled_arg_names, - output_devices=[out.device for out in trace.outputs], ) diff --git a/tripy/tripy/backend/api/executable.py b/tripy/tripy/backend/api/executable.py index 33347b314..8c9771531 100644 --- a/tripy/tripy/backend/api/executable.py +++ b/tripy/tripy/backend/api/executable.py @@ -14,7 +14,7 @@ # limitations under the License. import base64 import inspect -from typing import Sequence, Union +from typing import Sequence, Union, Tuple, Callable import mlir_tensorrt.runtime.api as runtime @@ -37,13 +37,11 @@ class Executable: """ # The constructor is intentionally undocumented because it is not meant to be called by users. - # TODO(#155): output_devices is not needed after they can be queried from executable - def __init__(self, executable, arg_names, output_devices): + def __init__(self, executable, arg_names): self._executable = executable self._executor = Executor(self._executable) self._arg_names = arg_names self._num_expected_args = len(arg_names) - self._output_devices = output_devices self._executable_signature = self._executable.get_signature("main") # Build a signature so the executable works with `inspect.signature` @@ -128,7 +126,7 @@ def add(a, b): tensor.eval() try: - executor_outputs = self._executor.execute(self._output_devices, input_tensors) + executor_outputs = self._executor.execute(input_tensors) except runtime.MTRTException as err: # TODO: Evaluate whether this should be moved into the executor if "function expects a memref type with element type" in str(err): @@ -170,15 +168,22 @@ def add(a, b): output_tensors = output_tensors[0] return output_tensors - def _get_arg_info(self, idx): - arg = self._executable_signature.get_arg(idx) - arg = runtime.MemRefType(arg) - arg_bound = self._executable_signature.get_arg_bound(idx) - shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max())) - if len(shape_bounds) == 0: - # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape - shape_bounds = tuple((x, x) for x in arg.shape) - return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype)) + def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo: + item = runtime.MemRefType(get_item(idx)) + bound = get_bound(idx) + shape_bounds = tuple(zip(bound.min(), bound.max())) + + if not shape_bounds: + # For static shape, fallback to item.shape + shape_bounds = tuple((x, x) for x in item.shape) + + return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype)) + + def _get_arg_info(self, idx: int) -> ArgInfo: + return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound) + + def _get_result_info(self, idx: int) -> ArgInfo: + return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound) def get_input_info(self) -> Sequence[ArgInfo]: """ @@ -221,11 +226,16 @@ def add(a, b): compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)]) print(compiled_add.get_output_info()) """ - output_info = [] - offset = self._executable_signature.get_num_input_args() - for idx in range(self._executable_signature.get_num_output_args()): - output_info.append(self._get_arg_info(idx + offset)) - return output_info + num_input_args = self._executable_signature.get_num_input_args() + num_output_args = self._executable_signature.get_num_output_args() + num_results = self._executable_signature.get_num_results() + + assert not (num_output_args and num_results), "Cannot have both output arguments and results" + + if num_output_args: + return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)] + else: + return [self._get_result_info(idx) for idx in range(num_results)] def save(self, path: str) -> None: """ @@ -289,7 +299,6 @@ def add(a, b): def encode_executable(executable): return { "arg_names": executable._arg_names, - "output_devices": executable._output_devices, "executable": base64.b64encode(executable._executable.serialize()).decode(), } @@ -300,5 +309,4 @@ def decode_executable(executable_dict): return Executable( runtime.Executable(executable_bytes), executable_dict["arg_names"], - executable_dict["output_devices"], ) diff --git a/tripy/tripy/backend/mlir/executor.py b/tripy/tripy/backend/mlir/executor.py index 1ec0e7fa5..447ad2cfa 100644 --- a/tripy/tripy/backend/mlir/executor.py +++ b/tripy/tripy/backend/mlir/executor.py @@ -41,7 +41,7 @@ def __init__(self, executable: runtime.Executable) -> None: self.signature = executable.get_signature("main") self.stream = default_stream() - def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]: + def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]: in_args = [] for inp in inputs: memref = inp.trace_tensor.producer.data diff --git a/tripy/tripy/frontend/tensor.py b/tripy/tripy/frontend/tensor.py index 87d84d388..5c5e2983c 100644 --- a/tripy/tripy/frontend/tensor.py +++ b/tripy/tripy/frontend/tensor.py @@ -188,7 +188,7 @@ def eval(self) -> runtime.MemRefValue: self.executor = Executor(executable) # Upon computing the value of this tensor, we switch it to have a `Storage` # parameter so that it does not need to be computed again. - data = self.executor.execute([out.device for out in flat_ir.outputs]) + data = self.executor.execute() self.executor.stream.synchronize() assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor" data = data[0]