Always construct memref value in storage op (#439)
yizhuoz004 authored Dec 11, 2024
1 parent 498eb8a commit e2e0e2b
Showing 5 changed files with 20 additions and 29 deletions.
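
Before this change, a Storage op held either a runtime.MemRefValue or a raw Python sequence, with a has_memref flag recording which case applied; this commit makes the constructor build a memref eagerly for every input, so the flag, the dtype override parameter, and the consumer-side branching all go away. A rough user-level sketch of the resulting behavior (assuming a standard tripy install; tp.Tensor is the public wrapper that builds the Storage op internally):

import tripy as tp

# A list input is now converted to a runtime.MemRefValue inside Storage.__init__,
# so this tensor's producer owns a memref before any evaluation happens.
t = tp.Tensor([[1.0, 2.0], [3.0, 4.0]])

# eval() can therefore take its early-exit path for list-backed tensors too,
# returning the stored memref directly (see the tensor.py change below).
print(t.eval())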
7 changes: 2 additions & 5 deletions tripy/tests/frontend/trace/ops/test_storage.py
@@ -34,24 +34,21 @@ def test_from_memref(self, device):
         module = np if device == "cpu" else cp
         data = memref.create_memref_view(module.ones((2, 2), dtype=module.float32))
         storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data)
-        assert storage.has_memref is True
         assert storage.dtype == tp.float32
         assert storage.shape == (2, 2)
         assert storage.device.kind == device
 
     def test_from_list(self):
         data = [[1.0, 2.0], [3.0, 4.0]]
         storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data)
-        assert storage.has_memref is False
         assert storage.dtype == tp.float32
         assert storage.shape == (2, 2)
         assert storage.device.kind == "gpu"
 
     def test_empty_list(self):
         data = [[]]
-        storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data, dtype=tp.float16)
-        assert storage.has_memref is True
-        assert storage.dtype == tp.float16
+        storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data)
+        assert storage.dtype == tp.float32
         assert storage.shape == (1, 0)
         assert storage.device.kind == "gpu"
 
4 changes: 2 additions & 2 deletions tripy/tests/frontend/trace/test_trace.py
@@ -95,8 +95,8 @@ def test_str(self):
             str(trace)
             == dedent(
                 """
-                a = storage(data=[0], shape=(1,), dtype=int32, device=gpu:0)
-                b = storage(data=[1], shape=(1,), dtype=int32, device=gpu:0)
+                a = storage(shape=(1,), dtype=int32, device=gpu:0)
+                b = storage(shape=(1,), dtype=int32, device=gpu:0)
                 c = a + b
                 outputs:
                     c: [shape=([-1]), dtype=(int32), loc=(gpu:0)]
1 change: 0 additions & 1 deletion tripy/tripy/common/utils.py
@@ -68,6 +68,5 @@ def convert_list_to_array(values: List[Any], dtype: str) -> bytes:
 
     return array.array(TYPE_TO_FORMAT[dtype], values)
 
-
 def is_empty(data: Sequence) -> bool:
     return isinstance(data, Sequence) and all(map(is_empty, data))
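
As a point of reference, is_empty (shown as context above) reports True only when a possibly nested sequence contains no leaf values at any depth, which is what lets the storage op below map inputs like [[]] to a zero-volume memref. A self-contained restatement with a few checks (the asserts are illustrative, not from the repo):

from collections.abc import Sequence

def is_empty(data) -> bool:
    # True only for a sequence whose elements are all themselves empty sequences.
    return isinstance(data, Sequence) and all(map(is_empty, data))

assert is_empty([]) is True
assert is_empty([[]]) is True       # shape (1, 0): no leaf values
assert is_empty([[], []]) is True   # shape (2, 0)
assert is_empty([[1.0]]) is False   # 1.0 is not a Sequence
assert is_empty(5) is False         # non-sequence input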
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/tensor.py
@@ -150,7 +150,7 @@ def raw_init(
             data = memref.create_memref_view(data)
             Storage.build_internal([], [instance.trace_tensor], data)
         else:
-            Storage.build_internal([], [instance.trace_tensor], data, dtype, device)
+            Storage.build_internal([], [instance.trace_tensor], data, device)
         # TODO(#155): Remove this hack:
         instance.trace_tensor.device = utils.default(device, instance.trace_tensor.device)
 
@@ -201,7 +201,7 @@ def device(self):
         return self.trace_tensor.device
 
     def eval(self) -> runtime.MemRefValue:
-        if isinstance(self.trace_tensor.producer, Storage) and self.trace_tensor.producer.has_memref:
+        if isinstance(self.trace_tensor.producer, Storage):
             # Exit early if the tensor has already been evaluated.
             # This happens before the imports below so we don't incur extra overhead.
             return self.trace_tensor.producer.data
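
One practical consequence of dropping the has_memref guard: every evaluated tensor, and now every tensor constructed from host data, hits the same early-exit path. A hypothetical usage sketch (assumes, per the early-exit comment above, that eval stores its result back into a Storage producer):

import tripy as tp

a = tp.Tensor([1, 2])
b = tp.Tensor([3, 4])
c = a + b
first = c.eval()   # traces, compiles, and executes; result lands in a Storage op
second = c.eval()  # early exit: producer is a Storage, so its memref is returned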
33 changes: 14 additions & 19 deletions tripy/tripy/frontend/trace/ops/storage.py
@@ -34,7 +34,7 @@
 @dataclass(repr=False)
 class Storage(BaseTraceOp):
 
-    data: Union[runtime.MemRefValue, Sequence[numbers.Number]]
+    data: runtime.MemRefValue
     shape: Sequence[int]
     dtype: type
     device: tp_device
@@ -44,7 +44,6 @@ def __init__(
         inputs: List["Tensor"],
         outputs: List["Tensor"],
         data: Union[runtime.MemRefValue, Sequence[numbers.Number]],
-        dtype: datatype = None,
         device: tp_device = None,
     ) -> None:
         super().__init__(inputs, outputs)
@@ -56,30 +55,26 @@ def __init__(
             self.device = tp_device.create_directly(
                 "gpu" if data.address_space == runtime.PointerType.device else "cpu", 0
             )
-            self.has_memref = True
-        elif common_utils.is_empty(data):
-            # special case: empty tensor
-            self.dtype = utils.default(dtype, datatype.float32)
-            self.shape = tuple(utils.get_shape(data))
-            self.data = memref.create_memref(shape=self.shape, dtype=self.dtype)
-            self.device = utils.default(device, tp_device.create_directly("gpu", 0))
-            self.has_memref = True
         else:
-            # If the input was a sequence, we need to copy it so that we don't take changes made
-            # to the list after the Storage op was constructed.
-            self.data = copy.copy(data)
-            self.dtype = dtype if dtype else common_utils.get_element_type(data)
+            if common_utils.is_empty(data):
+                self.dtype = datatype.float32
+                data_array = None
+            else:
+                self.dtype = common_utils.get_element_type(data)
+                data_array = common_utils.convert_list_to_array(utils.flatten_list(data), dtype=self.dtype)
             self.shape = tuple(utils.get_shape(data))
+            self.data = memref.create_memref(
+                shape=self.shape,
+                dtype=self.dtype,
+                array=data_array,
+            )
             self.device = utils.default(device, tp_device.create_directly("gpu", 0))
-            self.has_memref = False
 
         self.outputs[0].shape = list(self.shape)
 
     def str_skip_fields(self) -> Set[str]:
-        # skip data if i) it is a MemRefValue or ii) its volume exceeds threshold
-        if not isinstance(self.data, Sequence) or utils.should_omit_constant_in_str(self.shape):
-            return {"data"}
-        return set()
+        # skip data since it is always a memref value
+        return {"data"}
 
     def __eq__(self, other) -> bool:
         return self.data == other.data if isinstance(other, Storage) else False
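
To make the new else-branch concrete, here is a stdlib-only sketch of the flatten/shape/convert steps for a nested-list input (flatten_list, get_shape, and convert_list_to_array are tripy internals; these stand-ins cover only the 2-D case, and "f" is the array typecode corresponding to float32):

import array

def flatten_list(values):
    # Simplified stand-in for utils.flatten_list (2-D case only).
    return [x for row in values for x in row]

def get_shape(values):
    # Simplified stand-in for utils.get_shape.
    shape = []
    while isinstance(values, list):
        shape.append(len(values))
        values = values[0] if values else None
    return shape

data = [[1.0, 2.0], [3.0, 4.0]]
shape = tuple(get_shape(data))              # (2, 2)
buf = array.array("f", flatten_list(data))  # contiguous float32-like buffer
# memref.create_memref(shape=shape, dtype=..., array=buf) then wraps the buffer
# in a memref on the target device (tripy-internal; not reproduced here).
print(shape, buf)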
