Fix tripy tests
jhalakpatel committed Nov 4, 2024
1 parent a740e15 commit 15a15bf
Showing 8 changed files with 43 additions and 35 deletions.
2 changes: 1 addition & 1 deletion tripy/tests/backend/api/test_executable.py
@@ -87,7 +87,7 @@ def test_signature(self, single_return_executable):
             assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
             assert param.annotation == tp.Tensor
 
-        assert signature.return_annotation == tp.Tensor
+        assert signature.return_annotation == Sequence[tp.Tensor]
 
     def test_signature_multiple_return_values(self, multiple_return_executable):
         signature = inspect.signature(multiple_return_executable)
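For context, a minimal sketch of what the updated assertion checks, reusing the add/tp.compile example from the executable.py docstring later in this commit (the InputInfo shapes are illustrative, not prescribed by the change):

    import inspect
    from typing import Sequence

    import tripy as tp

    def add(a, b):
        return a + b

    # Compile a single-output function; after this commit its signature reports
    # Sequence[tp.Tensor] as the return annotation rather than tp.Tensor.
    compiled_add = tp.compile(
        add,
        args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)],
    )
    assert inspect.signature(compiled_add).return_annotation == Sequence[tp.Tensor]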
3 changes: 1 addition & 2 deletions tripy/tests/frontend/test_tensor.py
@@ -226,8 +226,7 @@ def test_no_explicit_cast(self):
         "devices",
         [
             ("cpu", "gpu"),
-            # TODO(#155)
-            # ("gpu", "cpu"),
+            ("gpu", "cpu"),
         ],
     )
     def test_explicit_copy(self, devices):
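Re-enabling the ("gpu", "cpu") case means device-to-host copies are exercised again. A rough sketch of that direction; tp.copy and tp.device are assumptions inferred from the test's parametrization, not shown in this hunk:

    import tripy as tp

    # Assumed helpers: tp.device(...) names a device, tp.copy(...) moves a tensor to it.
    a = tp.Tensor([1.0, 2.0], device=tp.device("gpu"))
    b = tp.copy(a, tp.device("cpu"))  # the previously-skipped gpu -> cpu direction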
15 changes: 8 additions & 7 deletions tripy/tests/integration/test_iota.py
@@ -82,16 +82,17 @@ def test_iota_like(self, dtype, shape, dim):
 
     @pytest.mark.parametrize("dtype", DATA_TYPES.values())
     def test_negative_no_casting(self, dtype):
-        from tripy.frontend.trace.ops.iota import Iota
+        with tp.logger.use_verbosity("ir"):
+            from tripy.frontend.trace.ops.iota import Iota
 
-        if dtype in [tp.float32, tp.int32, tp.int64]:
-            pytest.skip("tp.iota() supports float32, int32, and int64 without cast")
+            if dtype in [tp.float32, tp.int32, tp.int64]:
+                pytest.skip("tp.iota() supports float32, int32, and int64 without cast")
 
-        # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint
-        a = tp.ones((2, 2))
-        out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)
+            # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint
+            a = tp.ones((2, 2))
+            out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)
 
-        exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values"
+            exception_str = "InternalError: failed to run compilation on module with symbol name."
         if dtype == tp.bool:
             exception_str = "InternalError: failed to run compilation"
         with helper.raises(
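The test now wraps the failing build in tp.logger.use_verbosity("ir"), presumably so IR-level diagnostics are emitted while the generic "failed to run compilation" error is matched. A minimal sketch of that context manager in isolation (grounded only in how the test above uses it):

    import tripy as tp

    # Temporarily raise logger verbosity so IR-level output is printed during evaluation.
    with tp.logger.use_verbosity("ir"):
        (tp.ones((2, 2)) + tp.ones((2, 2))).eval()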
3 changes: 2 additions & 1 deletion tripy/tests/integration/test_quantize.py
@@ -117,5 +117,6 @@ def test_non_constant_scale(self):
         input = tp.ones((4, 4))
         scale = tp.ones((4,))
         quantized = tp.quantize(input, scale, tp.int8, dim=0)
+        quantized_int32 = tp.cast(quantized, tp.int32)
 
-        assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
+        assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32)))
1 change: 0 additions & 1 deletion tripy/tripy/backend/api/compile.py
@@ -196,5 +196,4 @@ def process_arg(name, arg):
     return Executable(
         executable,
         compiled_arg_names,
-        output_devices=[out.device for out in trace.outputs],
     )
50 changes: 29 additions & 21 deletions tripy/tripy/backend/api/executable.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import base64
 import inspect
-from typing import Sequence, Union
+from typing import Sequence, Union, Tuple, Callable
 
 import mlir_tensorrt.runtime.api as runtime
 
@@ -37,13 +37,11 @@ class Executable:
     """
 
    # The constructor is intentionally undocumented because it is not meant to be called by users.
-    # TODO(#155): output_devices is not needed after they can be queried from executable
-    def __init__(self, executable, arg_names, output_devices):
+    def __init__(self, executable, arg_names):
         self._executable = executable
         self._executor = Executor(self._executable)
         self._arg_names = arg_names
         self._num_expected_args = len(arg_names)
-        self._output_devices = output_devices
         self._executable_signature = self._executable.get_signature("main")
 
         # Build a signature so the executable works with `inspect.signature`
@@ -128,7 +126,7 @@ def add(a, b):
                 tensor.eval()
 
         try:
-            executor_outputs = self._executor.execute(self._output_devices, input_tensors)
+            executor_outputs = self._executor.execute(input_tensors)
         except runtime.MTRTException as err:
             # TODO: Evaluate whether this should be moved into the executor
             if "function expects a memref type with element type" in str(err):
@@ -170,15 +168,22 @@ def add(a, b):
             output_tensors = output_tensors[0]
         return output_tensors
 
-    def _get_arg_info(self, idx):
-        arg = self._executable_signature.get_arg(idx)
-        arg = runtime.MemRefType(arg)
-        arg_bound = self._executable_signature.get_arg_bound(idx)
-        shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max()))
-        if len(shape_bounds) == 0:
-            # For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape
-            shape_bounds = tuple((x, x) for x in arg.shape)
-        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
+    def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo:
+        item = runtime.MemRefType(get_item(idx))
+        bound = get_bound(idx)
+        shape_bounds = tuple(zip(bound.min(), bound.max()))
+
+        if not shape_bounds:
+            # For static shape, fallback to item.shape
+            shape_bounds = tuple((x, x) for x in item.shape)
+
+        return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype))
+
+    def _get_arg_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound)
+
+    def _get_result_info(self, idx: int) -> ArgInfo:
+        return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound)
 
     def get_input_info(self) -> Sequence[ArgInfo]:
         """
@@ -221,11 +226,16 @@ def add(a, b):
             compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
             print(compiled_add.get_output_info())
         """
-        output_info = []
-        offset = self._executable_signature.get_num_input_args()
-        for idx in range(self._executable_signature.get_num_output_args()):
-            output_info.append(self._get_arg_info(idx + offset))
-        return output_info
+        num_input_args = self._executable_signature.get_num_input_args()
+        num_output_args = self._executable_signature.get_num_output_args()
+        num_results = self._executable_signature.get_num_results()
+
+        assert not (num_output_args and num_results), "Cannot have both output arguments and results"
+
+        if num_output_args:
+            return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)]
+        else:
+            return [self._get_result_info(idx) for idx in range(num_results)]
 
     def save(self, path: str) -> None:
         """
@@ -289,7 +299,6 @@ def add(a, b):
 def encode_executable(executable):
     return {
         "arg_names": executable._arg_names,
-        "output_devices": executable._output_devices,
         "executable": base64.b64encode(executable._executable.serialize()).decode(),
     }
 
@@ -300,5 +309,4 @@ def decode_executable(executable_dict):
     return Executable(
         runtime.Executable(executable_bytes),
         executable_dict["arg_names"],
-        executable_dict["output_devices"],
     )
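Taken together, the executable now derives output metadata from its own signature (output arguments or results) instead of carrying output_devices through the constructor and serialization. A minimal usage sketch based on the docstring example above (shapes and dtypes are illustrative):

    import tripy as tp

    def add(a, b):
        return a + b

    compiled_add = tp.compile(
        add,
        args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)],
    )

    # Input bounds still come from the argument signature; output info is now read
    # from either the output arguments or the results of the executable signature.
    print(compiled_add.get_input_info())
    print(compiled_add.get_output_info())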
2 changes: 1 addition & 1 deletion tripy/tripy/backend/mlir/executor.py
@@ -41,7 +41,7 @@ def __init__(self, executable: runtime.Executable) -> None:
         self.signature = executable.get_signature("main")
         self.stream = default_stream()
 
-    def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
+    def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
         in_args = []
         for inp in inputs:
             memref = inp.trace_tensor.producer.data
2 changes: 1 addition & 1 deletion tripy/tripy/frontend/tensor.py
@@ -188,7 +188,7 @@ def eval(self) -> runtime.MemRefValue:
         self.executor = Executor(executable)
         # Upon computing the value of this tensor, we switch it to have a `Storage`
         # parameter so that it does not need to be computed again.
-        data = self.executor.execute([out.device for out in flat_ir.outputs])
+        data = self.executor.execute()
         self.executor.stream.synchronize()
         assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
         data = data[0]
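The same simplification applies to eager evaluation: Tensor.eval() now invokes the executor without passing output devices. A small sketch of the eager path this touches (assumes a working tripy install with a GPU backend; values are illustrative):

    import tripy as tp

    t = tp.ones((2, 2)) + tp.ones((2, 2))
    # eval() compiles and runs the traced graph; the executor now determines output
    # placement from the executable itself rather than from a list of output devices.
    data = t.eval()
    print(data)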
