diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index 2d060f3c2da64..f4694a30a8a12 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -245,7 +245,7 @@ def MemRefEraseDeadAllocAndStoresOp ]> { let description = [{ This applies memory optimization on memref. In particular it does store to - load forwarding, dead store elimination and dead alloc elimination. + load forwarding, dead store elimination and dead alloc/alloca elimination. #### Return modes diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 3f9fb071e0ba8..8735b10255ae3 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -156,13 +156,15 @@ static bool resultIsNotRead(Operation *op, std::vector &uses) { void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) { std::vector opToErase; - parentOp->walk([&](memref::AllocOp op) { + parentOp->walk([&](Operation *op) { std::vector candidates; - if (resultIsNotRead(op, candidates)) { + if (isa(op) && + resultIsNotRead(op, candidates)) { llvm::append_range(opToErase, candidates); - opToErase.push_back(op.getOperation()); + opToErase.push_back(op); } }); + for (Operation *op : opToErase) rewriter.eraseOp(op); } diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index acab37e482cfe..3b37c62fcb28e 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -327,6 +327,30 @@ module attributes {transform.with_named_sequence} { } } +// ----- + +// CHECK-LABEL: func.func @dead_alloca +func.func @dead_alloca() { + // CHECK-NOT: %{{.+}} = memref.alloca + %0 = memref.alloca() : memref<8x64xf32, 3> + %1 = memref.subview %0[0, 0] [8, 4] [1, 1] : memref<8x64xf32, 3> to + memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 64 + d1)>, 3> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<0.000000e+00> : vector<1x4xf32> + vector.transfer_write %cst_0, %1[%c0, %c0] {in_bounds = [true, true]} : + vector<1x4xf32>, memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 64 + d1)>, 3> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () + transform.yield + } +} + + // ----- // CHECK-LABEL: @store_to_load