Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhuoz004 committed Dec 11, 2024
1 parent 9f3b6c8 commit cd05778
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 14 deletions.
4 changes: 2 additions & 2 deletions tripy/tests/frontend/trace/ops/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_from_list(self):

def test_empty_list(self):
data = [[]]
storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data, dtype=tp.float16)
assert storage.dtype == tp.float16
storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data)
assert storage.dtype == tp.float32
assert storage.shape == (1, 0)
assert storage.device.kind == "gpu"

Expand Down
1 change: 0 additions & 1 deletion tripy/tripy/backend/mlir/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import mlir_tensorrt.runtime.api as runtime

from tripy.backend.mlir import utils as mlir_utils
from tripy.common import datatype
from tripy.common import device as tp_device
from tripy.utils import raise_error

Expand Down
3 changes: 1 addition & 2 deletions tripy/tripy/frontend/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tripy import export, utils
from tripy.backend.mlir import memref
from tripy.common import datatype
from tripy.common import utils as common_utils
from tripy.common.exception import raise_error, str_from_stack_info
from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY
from tripy.frontend.trace.ops import Storage
Expand Down Expand Up @@ -151,7 +150,7 @@ def raw_init(
data = memref.create_memref_view(data)
Storage.build_internal([], [instance.trace_tensor], data)
else:
Storage.build_internal([], [instance.trace_tensor], data, None, device)
Storage.build_internal([], [instance.trace_tensor], data, device)
# TODO(#155): Remove this hack:
instance.trace_tensor.device = utils.default(device, instance.trace_tensor.device)

Expand Down
16 changes: 7 additions & 9 deletions tripy/tripy/frontend/trace/ops/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(
inputs: List["Tensor"],
outputs: List["Tensor"],
data: Union[runtime.MemRefValue, Sequence[numbers.Number]],
dtype: datatype = None,
device: tp_device = None,
) -> None:
super().__init__(inputs, outputs)
Expand All @@ -56,19 +55,18 @@ def __init__(
self.device = tp_device.create_directly(
"gpu" if data.address_space == runtime.PointerType.device else "cpu", 0
)
elif common_utils.is_empty(data):
# special case: empty tensor
self.dtype = utils.default(dtype, datatype.float32)
self.shape = tuple(utils.get_shape(data))
self.data = memref.create_memref(shape=self.shape, dtype=self.dtype)
self.device = utils.default(device, tp_device.create_directly("gpu", 0))
else:
self.dtype = dtype if dtype else common_utils.get_element_type(data)
if common_utils.is_empty(data):
self.dtype = datatype.float32
data_array = None
else:
self.dtype = common_utils.get_element_type(data)
data_array = common_utils.convert_list_to_array(utils.flatten_list(data), dtype=self.dtype)
self.shape = tuple(utils.get_shape(data))
self.data = memref.create_memref(
shape=self.shape,
dtype=self.dtype,
array=common_utils.convert_list_to_array(utils.flatten_list(data), dtype=self.dtype),
array=data_array,
)
self.device = utils.default(device, tp_device.create_directly("gpu", 0))

Expand Down

0 comments on commit cd05778

Please sign in to comment.