-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang] Improve designate/elemental indices match in opt-bufferization. #121371
base: main
Are you sure you want to change the base?
Conversation
This pattern appears in `tonto`: `rys1%w = rys1%w * ...`, where component `w` is a pointer. Due to the computations transforming the elemental's one-based indices to the array indices, the indices match check did not pass in opt-bufferization. This patch recognizes this indices adjusting pattern, and returns the one-based indices for the designator.
@llvm/pr-subscribers-flang-fir-hlfir Author: Slava Zakharin (vzakhari) ChangesThis pattern appears in Full diff: https://github.com/llvm/llvm-project/pull/121371.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index bf3cf861e46f4a..bfaabed0136785 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -87,6 +87,13 @@ class ElementalAssignBufferization
/// determines if the transformation can be applied to this elemental
static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);
+ /// Returns the array indices for the given hlfir.designate.
+ /// It recognizes the computations used to transform the one-based indices
+ /// into the array's lb-based indices, and returns the one-based indices
+ /// in these cases.
+ static llvm::SmallVector<mlir::Value>
+ getDesignatorIndices(hlfir::DesignateOp designate);
+
public:
using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
@@ -430,6 +437,73 @@ bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
return false;
}
+llvm::SmallVector<mlir::Value>
+ElementalAssignBufferization::getDesignatorIndices(
+ hlfir::DesignateOp designate) {
+ mlir::Value memref = designate.getMemref();
+
+ // If the object is a box, then the indices may be adjusted
+ // according to the box's lower bound(s). Scan through
+ // the computations to try to find the one-based indices.
+ if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
+ // Look for the following pattern:
+ // %13 = fir.load %12 : !fir.ref<!fir.box<...>
+ // %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
+ // %17 = arith.subi %14#0, %c1 : index
+ // %18 = arith.addi %arg2, %17 : index
+ // %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ...
+ //
+ // %arg2 is a one-based index.
+
+ auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
+ // Return true, if v and dim are such that:
+ // %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
+ // %17 = arith.subi %14#0, %c1 : index
+ // %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ...
+ if (auto subOp =
+ mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
+ auto cst = fir::getIntIfConstant(subOp.getRhs());
+ if (!cst || *cst != 1)
+ return false;
+ if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
+ subOp.getLhs().getDefiningOp())) {
+ if (memref != dimsOp.getVal() ||
+ dimsOp.getResult(0) != subOp.getLhs())
+ return false;
+ auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
+ return dimsOpDim && dimsOpDim == dim;
+ }
+ }
+ return false;
+ };
+
+ llvm::SmallVector<mlir::Value> newIndices;
+ for (auto index : llvm::enumerate(designate.getIndices())) {
+ if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
+ index.value().getDefiningOp())) {
+ for (unsigned opNum = 0; opNum < 2; ++opNum)
+ if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
+ newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
+ break;
+ }
+
+ // If new one-based index was not added, exit early.
+ if (newIndices.size() <= index.index())
+ break;
+ }
+ }
+
+ // If any of the indices is not adjusted to the array's lb,
+ // then return the original designator indices.
+ if (newIndices.size() != designate.getIndices().size())
+ return designate.getIndices();
+
+ return newIndices;
+ }
+
+ return designate.getIndices();
+}
+
std::optional<ElementalAssignBufferization::MatchInfo>
ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
mlir::Operation::user_range users = elemental->getUsers();
@@ -557,7 +631,7 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
<< " at " << elemental.getLoc() << "\n");
return std::nullopt;
}
- auto indices = designate.getIndices();
+ auto indices = getDesignatorIndices(designate);
auto elementalIndices = elemental.getIndices();
if (indices.size() == elementalIndices.size() &&
std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
diff --git a/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir b/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
new file mode 100644
index 00000000000000..ae91930d44eb12
--- /dev/null
+++ b/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
@@ -0,0 +1,69 @@
+// RUN: fir-opt --opt-bufferization %s | FileCheck %s
+
+// Verify that the hlfir.assign of hlfir.elemental is optimized
+// into element-per-element assignment:
+// subroutine test1(p)
+// real, pointer :: p(:)
+// p = p + 1.0
+// end subroutine test1
+
+func.func @_QPtest1(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "p"}) {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = fir.dummy_scope : !fir.dscope
+ %1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest1Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+ %2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> (index, index, index)
+ %4 = fir.shape %3#1 : (index) -> !fir.shape<1>
+ %5 = hlfir.elemental %4 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+ ^bb0(%arg1: index):
+ %6 = arith.subi %3#0, %c1 : index
+ %7 = arith.addi %arg1, %6 : index
+ %8 = hlfir.designate %2 (%7) : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
+ %9 = fir.load %8 : !fir.ref<f32>
+ %10 = arith.addf %9, %cst fastmath<contract> : f32
+ hlfir.yield_element %10 : f32
+ }
+ hlfir.assign %5 to %2 : !hlfir.expr<?xf32>, !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ hlfir.destroy %5 : !hlfir.expr<?xf32>
+ return
+}
+// CHECK-LABEL: func.func @_QPtest1(
+// CHECK-NOT: hlfir.assign
+// CHECK: hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
+// CHECK-NOT: hlfir.assign
+
+// subroutine test2(p)
+// real, pointer :: p(:,:)
+// p = p + 1.0
+// end subroutine test2
+func.func @_QPtest2(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> {fir.bindc_name = "p"}) {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = fir.dummy_scope : !fir.dscope
+ %1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest2Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>)
+ %2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
+ %3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
+ %4:3 = fir.box_dims %2, %c1 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
+ %5 = fir.shape %3#1, %4#1 : (index, index) -> !fir.shape<2>
+ %6 = hlfir.elemental %5 unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+ ^bb0(%arg1: index, %arg2: index):
+ %7 = arith.subi %3#0, %c1 : index
+ %8 = arith.addi %arg1, %7 : index
+ %9 = arith.subi %4#0, %c1 : index
+ %10 = arith.addi %arg2, %9 : index
+ %11 = hlfir.designate %2 (%8, %10) : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index, index) -> !fir.ref<f32>
+ %12 = fir.load %11 : !fir.ref<f32>
+ %13 = arith.addf %12, %cst fastmath<contract> : f32
+ hlfir.yield_element %13 : f32
+ }
+ hlfir.assign %6 to %2 : !hlfir.expr<?x?xf32>, !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
+ hlfir.destroy %6 : !hlfir.expr<?x?xf32>
+ return
+}
+// CHECK-LABEL: func.func @_QPtest2(
+// CHECK-NOT: hlfir.assign
+// CHECK: hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
+// CHECK-NOT: hlfir.assign
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
This pattern appears in
tonto
:rys1%w = rys1%w * ...
, wherecomponent
w
is a pointer. Due to the computations transformingthe elemental's one-based indices to the array indices,
the indices match check did not pass in opt-bufferization.
This patch recognizes this indices adjusting pattern,
and returns the one-based indices for the designator.