diff --git a/tripy/tests/frontend/trace/ops/test_storage.py b/tripy/tests/frontend/trace/ops/test_storage.py index a2daf9937..991934482 100644 --- a/tripy/tests/frontend/trace/ops/test_storage.py +++ b/tripy/tests/frontend/trace/ops/test_storage.py @@ -34,7 +34,6 @@ def test_from_memref(self, device): module = np if device == "cpu" else cp data = memref.create_memref_view(module.ones((2, 2), dtype=module.float32)) storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) - assert storage.has_memref is True assert storage.dtype == tp.float32 assert storage.shape == (2, 2) assert storage.device.kind == device @@ -42,16 +41,14 @@ def test_from_memref(self, device): def test_from_list(self): data = [[1.0, 2.0], [3.0, 4.0]] storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) - assert storage.has_memref is False assert storage.dtype == tp.float32 assert storage.shape == (2, 2) assert storage.device.kind == "gpu" def test_empty_list(self): data = [[]] - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data, dtype=tp.float16) - assert storage.has_memref is True - assert storage.dtype == tp.float16 + storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + assert storage.dtype == tp.float32 assert storage.shape == (1, 0) assert storage.device.kind == "gpu" diff --git a/tripy/tests/frontend/trace/test_trace.py b/tripy/tests/frontend/trace/test_trace.py index d1696fa1d..700aa84ec 100644 --- a/tripy/tests/frontend/trace/test_trace.py +++ b/tripy/tests/frontend/trace/test_trace.py @@ -95,8 +95,8 @@ def test_str(self): str(trace) == dedent( """ - a = storage(data=[0], shape=(1,), dtype=int32, device=gpu:0) - b = storage(data=[1], shape=(1,), dtype=int32, device=gpu:0) + a = storage(shape=(1,), dtype=int32, device=gpu:0) + b = storage(shape=(1,), dtype=int32, device=gpu:0) c = a + b outputs: c: [shape=([-1]), dtype=(int32), loc=(gpu:0)] diff --git a/tripy/tripy/common/utils.py b/tripy/tripy/common/utils.py index af1d70c28..db296b1db 100644 --- a/tripy/tripy/common/utils.py +++ b/tripy/tripy/common/utils.py @@ -68,6 +68,5 @@ def convert_list_to_array(values: List[Any], dtype: str) -> bytes: return array.array(TYPE_TO_FORMAT[dtype], values) - def is_empty(data: Sequence) -> bool: return isinstance(data, Sequence) and all(map(is_empty, data)) diff --git a/tripy/tripy/frontend/tensor.py b/tripy/tripy/frontend/tensor.py index ec25f1d09..88d0b488d 100644 --- a/tripy/tripy/frontend/tensor.py +++ b/tripy/tripy/frontend/tensor.py @@ -150,7 +150,7 @@ def raw_init( data = memref.create_memref_view(data) Storage.build_internal([], [instance.trace_tensor], data) else: - Storage.build_internal([], [instance.trace_tensor], data, dtype, device) + Storage.build_internal([], [instance.trace_tensor], data, device) # TODO(#155): Remove this hack: instance.trace_tensor.device = utils.default(device, instance.trace_tensor.device) @@ -201,7 +201,7 @@ def device(self): return self.trace_tensor.device def eval(self) -> runtime.MemRefValue: - if isinstance(self.trace_tensor.producer, Storage) and self.trace_tensor.producer.has_memref: + if isinstance(self.trace_tensor.producer, Storage): # Exit early if the tensor has already been evaluated. # This happens before the imports below so we don't incur extra overhead. return self.trace_tensor.producer.data diff --git a/tripy/tripy/frontend/trace/ops/storage.py b/tripy/tripy/frontend/trace/ops/storage.py index 2caa321e6..9823e89d8 100644 --- a/tripy/tripy/frontend/trace/ops/storage.py +++ b/tripy/tripy/frontend/trace/ops/storage.py @@ -34,7 +34,7 @@ @dataclass(repr=False) class Storage(BaseTraceOp): - data: Union[runtime.MemRefValue, Sequence[numbers.Number]] + data: runtime.MemRefValue shape: Sequence[int] dtype: type device: tp_device @@ -44,7 +44,6 @@ def __init__( inputs: List["Tensor"], outputs: List["Tensor"], data: Union[runtime.MemRefValue, Sequence[numbers.Number]], - dtype: datatype = None, device: tp_device = None, ) -> None: super().__init__(inputs, outputs) @@ -56,30 +55,26 @@ def __init__( self.device = tp_device.create_directly( "gpu" if data.address_space == runtime.PointerType.device else "cpu", 0 ) - self.has_memref = True - elif common_utils.is_empty(data): - # special case: empty tensor - self.dtype = utils.default(dtype, datatype.float32) - self.shape = tuple(utils.get_shape(data)) - self.data = memref.create_memref(shape=self.shape, dtype=self.dtype) - self.device = utils.default(device, tp_device.create_directly("gpu", 0)) - self.has_memref = True else: - # If the input was a sequence, we need to copy it so that we don't take changes made - # to the list after the Storage op was constructed. - self.data = copy.copy(data) - self.dtype = dtype if dtype else common_utils.get_element_type(data) + if common_utils.is_empty(data): + self.dtype = datatype.float32 + data_array = None + else: + self.dtype = common_utils.get_element_type(data) + data_array = common_utils.convert_list_to_array(utils.flatten_list(data), dtype=self.dtype) self.shape = tuple(utils.get_shape(data)) + self.data = memref.create_memref( + shape=self.shape, + dtype=self.dtype, + array=data_array, + ) self.device = utils.default(device, tp_device.create_directly("gpu", 0)) - self.has_memref = False self.outputs[0].shape = list(self.shape) def str_skip_fields(self) -> Set[str]: - # skip data if i) it is a MemRefValue or ii) its volume exceeds threshold - if not isinstance(self.data, Sequence) or utils.should_omit_constant_in_str(self.shape): - return {"data"} - return set() + # skip data since it is always a memref value + return {"data"} def __eq__(self, other) -> bool: return self.data == other.data if isinstance(other, Storage) else False