Skip to content

Commit

Permalink
[torch-mlir][sparse] add ID-net example (llvm#3127)
Browse files Browse the repository at this point in the history
first sparse-in/sparse-out example, will be used
to make actual sparse output work!
  • Loading branch information
aartbik authored Apr 9, 2024
1 parent 8ff2852 commit 184d8c1
Showing 1 changed file with 44 additions and 5 deletions.
49 changes: 44 additions & 5 deletions test/python/fx_importer/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ def sparse_jit(f, *args, **kwargs):
if a.layout is torch.sparse_coo:
# Construct the additional position array required by MLIR with data
# array([0, nnz]).
xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy())
xargs.append(torch.tensor([0, a._nnz()], dtype=a._indices().dtype).numpy())
# Transform a tensor<ndim x nnz> into [tensor<nnz> x ndim] to conform
# MLIR SoA COO representation.
for idx in a.indices():
for idx in a._indices():
xargs.append(idx.numpy())
xargs.append(a.values().numpy())
xargs.append(a._values().numpy())
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
xargs.append(a.crow_indices().numpy())
xargs.append(a.col_indices().numpy())
Expand All @@ -189,6 +189,46 @@ def run(f):
print()


@run
# CHECK-LABEL: test_sparse_id
# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20],f64,#[[$COO]]>) -> !torch.vtensor<[10,20],f64,#[[$COO]]> {
# CHECK: return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]>
# CHECK: }
#
# CHECK: torch.sparse
# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9],
# CHECK: [ 0, 1, 10, 19]{{\]}}),
# CHECK: values=tensor([-1000., -1., 1., 1000.]),
# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo)
# CHECK: torch.mlir
#
def test_sparse_id():
class IdNet(torch.nn.Module):
def __init__(self):
super(IdNet, self).__init__()

def forward(self, x):
return x

net = IdNet()
idx = torch.tensor([[0, 1, 2, 9], [0, 1, 10, 19]])
val = torch.tensor([-1000.0, -1.0, 1.0, 1000.0], dtype=torch.float64)
sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20])
m = export_and_import(net, sparse_input)
print(m)

# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
# TODO: make output work
res1 = net(sparse_input)
# res2 = sparse_jit(net, sparse_input)
print("torch.sparse")
print(res1)
print("torch.mlir")
# print(res2)


@run
# CHECK-LABEL: test_sparse_sum
# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
Expand Down Expand Up @@ -362,8 +402,7 @@ def forward(self, x):

# Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
#
# TODO: note several issues that need to be fixed
# (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
# TODO: propagate sparsity into elt-wise (instead of dense result)
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
res3 = sparse_jit(net, batch_input)
Expand Down

0 comments on commit 184d8c1

Please sign in to comment.