[flang] Improve designate/elemental indices match in opt-bufferization. #121371

Open
wants to merge 1 commit into
base: main

vzakhari
Contributor

This pattern appears in `tonto`: `rys1%w = rys1%w * ...`, where
component `w` is a pointer. Because extra computations convert
the elemental's one-based indices into the array's lower-bound-based indices,
the index match check in opt-bufferization failed.
This patch recognizes this index-adjustment pattern
and returns the one-based indices for the designator.
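
To illustrate the pattern being recognized, here is a minimal, self-contained sketch. It does not use the MLIR or FIR APIs: `Expr`, `Kind`, and the builder helpers are hypothetical stand-ins for SSA values, and `oneBasedIndex` only loosely mirrors what the patch's `getDesignatorIndices` does for a single index. The idea is that an index of the form `i + (lb(dim) - 1)`, in either operand order, is reduced back to `i`.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Toy stand-ins for SSA values; these are NOT MLIR types.
enum class Kind { Arg, Const, LowerBound, Sub, Add };

struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

struct Expr {
  Kind kind;
  long value;        // Arg: argument number, Const: constant, LowerBound: dimension
  ExprPtr lhs, rhs;  // operands of Sub/Add, null otherwise
};

static ExprPtr make(Kind k, long v, ExprPtr l = nullptr, ExprPtr r = nullptr) {
  return std::make_shared<Expr>(Expr{k, v, std::move(l), std::move(r)});
}
ExprPtr arg(long n) { return make(Kind::Arg, n); }
ExprPtr constant(long c) { return make(Kind::Const, c); }
ExprPtr lowerBound(long dim) { return make(Kind::LowerBound, dim); }
ExprPtr sub(ExprPtr l, ExprPtr r) {
  assert(l && r);
  return make(Kind::Sub, 0, std::move(l), std::move(r));
}
ExprPtr add(ExprPtr l, ExprPtr r) {
  assert(l && r);
  return make(Kind::Add, 0, std::move(l), std::move(r));
}

// True if e has the form (lowerBound(dim) - 1).
bool isNormalizedLb(const ExprPtr &e, long dim) {
  return e && e->kind == Kind::Sub &&
         e->lhs->kind == Kind::LowerBound && e->lhs->value == dim &&
         e->rhs->kind == Kind::Const && e->rhs->value == 1;
}

// If index is (i + (lowerBound(dim) - 1)) in either operand order,
// return i; otherwise return null to signal "pattern not recognized".
ExprPtr oneBasedIndex(const ExprPtr &index, long dim) {
  if (!index || index->kind != Kind::Add)
    return nullptr;
  if (isNormalizedLb(index->rhs, dim))
    return index->lhs;
  if (isNormalizedLb(index->lhs, dim))
    return index->rhs;
  return nullptr;
}
```

For example, with `i = arg(1)`, calling `oneBasedIndex(add(i, sub(lowerBound(0), constant(1))), 0)` returns the same node as `i`, while asking about dimension `1` returns null. In the patch itself, a failure on any dimension makes the function fall back to the original designator indices.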

@vzakhari vzakhari requested review from tblah and jeanPerier December 31, 2024 05:09
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Dec 31, 2024
@llvmbot
Member

llvmbot commented Dec 31, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Full diff: https://github.com/llvm/llvm-project/pull/121371.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+75-1)
  • (added) flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir (+69)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index bf3cf861e46f4a..bfaabed0136785 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -87,6 +87,13 @@ class ElementalAssignBufferization
   /// determines if the transformation can be applied to this elemental
   static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental);
 
+  /// Returns the array indices for the given hlfir.designate.
+  /// It recognizes the computations used to transform the one-based indices
+  /// into the array's lb-based indices, and returns the one-based indices
+  /// in these cases.
+  static llvm::SmallVector<mlir::Value>
+  getDesignatorIndices(hlfir::DesignateOp designate);
+
 public:
   using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
 
@@ -430,6 +437,73 @@ bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) {
   return false;
 }
 
+llvm::SmallVector<mlir::Value>
+ElementalAssignBufferization::getDesignatorIndices(
+    hlfir::DesignateOp designate) {
+  mlir::Value memref = designate.getMemref();
+
+  // If the object is a box, then the indices may be adjusted
+  // according to the box's lower bound(s). Scan through
+  // the computations to try to find the one-based indices.
+  if (mlir::isa<fir::BaseBoxType>(memref.getType())) {
+    // Look for the following pattern:
+    //   %13 = fir.load %12 : !fir.ref<!fir.box<...>
+    //   %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ...
+    //   %17 = arith.subi %14#0, %c1 : index
+    //   %18 = arith.addi %arg2, %17 : index
+    //   %19 = hlfir.designate %13 (%18)  : (!fir.box<...>, index) -> ...
+    //
+    // %arg2 is a one-based index.
+
+    auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) {
+      // Return true, if v and dim are such that:
+      //   %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ...
+      //   %17 = arith.subi %14#0, %c1 : index
+      //   %19 = hlfir.designate %13 (...)  : (!fir.box<...>, index) -> ...
+      if (auto subOp =
+              mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) {
+        auto cst = fir::getIntIfConstant(subOp.getRhs());
+        if (!cst || *cst != 1)
+          return false;
+        if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>(
+                subOp.getLhs().getDefiningOp())) {
+          if (memref != dimsOp.getVal() ||
+              dimsOp.getResult(0) != subOp.getLhs())
+            return false;
+          auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim());
+          return dimsOpDim && dimsOpDim == dim;
+        }
+      }
+      return false;
+    };
+
+    llvm::SmallVector<mlir::Value> newIndices;
+    for (auto index : llvm::enumerate(designate.getIndices())) {
+      if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>(
+              index.value().getDefiningOp())) {
+        for (unsigned opNum = 0; opNum < 2; ++opNum)
+          if (isNormalizedLb(addOp->getOperand(opNum), index.index())) {
+            newIndices.push_back(addOp->getOperand((opNum + 1) % 2));
+            break;
+          }
+
+        // If new one-based index was not added, exit early.
+        if (newIndices.size() <= index.index())
+          break;
+      }
+    }
+
+    // If any of the indices is not adjusted to the array's lb,
+    // then return the original designator indices.
+    if (newIndices.size() != designate.getIndices().size())
+      return designate.getIndices();
+
+    return newIndices;
+  }
+
+  return designate.getIndices();
+}
+
 std::optional<ElementalAssignBufferization::MatchInfo>
 ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
   mlir::Operation::user_range users = elemental->getUsers();
@@ -557,7 +631,7 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
                                   << " at " << elemental.getLoc() << "\n");
           return std::nullopt;
         }
-        auto indices = designate.getIndices();
+        auto indices = getDesignatorIndices(designate);
         auto elementalIndices = elemental.getIndices();
         if (indices.size() == elementalIndices.size() &&
             std::equal(indices.begin(), indices.end(), elementalIndices.begin(),
diff --git a/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir b/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
new file mode 100644
index 00000000000000..ae91930d44eb12
--- /dev/null
+++ b/flang/test/HLFIR/opt-bufferization-same-ptr-elemental.fir
@@ -0,0 +1,69 @@
+// RUN: fir-opt --opt-bufferization %s | FileCheck %s
+
+// Verify that the hlfir.assign of hlfir.elemental is optimized
+// into element-per-element assignment:
+// subroutine test1(p)
+//   real, pointer :: p(:)
+//   p = p + 1.0
+// end subroutine test1
+
+func.func @_QPtest1(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "p"}) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest1Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+  %2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+  %3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> (index, index, index)
+  %4 = fir.shape %3#1 : (index) -> !fir.shape<1>
+  %5 = hlfir.elemental %4 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+  ^bb0(%arg1: index):
+    %6 = arith.subi %3#0, %c1 : index
+    %7 = arith.addi %arg1, %6 : index
+    %8 = hlfir.designate %2 (%7)  : (!fir.box<!fir.ptr<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
+    %9 = fir.load %8 : !fir.ref<f32>
+    %10 = arith.addf %9, %cst fastmath<contract> : f32
+    hlfir.yield_element %10 : f32
+  }
+  hlfir.assign %5 to %2 : !hlfir.expr<?xf32>, !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  hlfir.destroy %5 : !hlfir.expr<?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @_QPtest1(
+// CHECK-NOT:         hlfir.assign
+// CHECK:             hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
+// CHECK-NOT:         hlfir.assign
+
+// subroutine test2(p)
+//   real, pointer :: p(:,:)
+//   p = p + 1.0
+// end subroutine test2
+func.func @_QPtest2(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>> {fir.bindc_name = "p"}) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg0 dummy_scope %0 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFtest2Ep"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>)
+  %2 = fir.load %1#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?x?xf32>>>>
+  %3:3 = fir.box_dims %2, %c0 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
+  %4:3 = fir.box_dims %2, %c1 : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index) -> (index, index, index)
+  %5 = fir.shape %3#1, %4#1 : (index, index) -> !fir.shape<2>
+  %6 = hlfir.elemental %5 unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
+  ^bb0(%arg1: index, %arg2: index):
+    %7 = arith.subi %3#0, %c1 : index
+    %8 = arith.addi %arg1, %7 : index
+    %9 = arith.subi %4#0, %c1 : index
+    %10 = arith.addi %arg2, %9 : index
+    %11 = hlfir.designate %2 (%8, %10)  : (!fir.box<!fir.ptr<!fir.array<?x?xf32>>>, index, index) -> !fir.ref<f32>
+    %12 = fir.load %11 : !fir.ref<f32>
+    %13 = arith.addf %12, %cst fastmath<contract> : f32
+    hlfir.yield_element %13 : f32
+  }
+  hlfir.assign %6 to %2 : !hlfir.expr<?x?xf32>, !fir.box<!fir.ptr<!fir.array<?x?xf32>>>
+  hlfir.destroy %6 : !hlfir.expr<?x?xf32>
+  return
+}
+// CHECK-LABEL:   func.func @_QPtest2(
+// CHECK-NOT:         hlfir.assign
+// CHECK:             hlfir.assign %{{.*}} to %{{.*}} : f32, !fir.ref<f32>
+// CHECK-NOT:         hlfir.assign
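
As a rough illustration of why the index match matters for `p = p + 1.0`, the sketch below models the two evaluation strategies on a plain `std::vector<float>`. This is an assumption-laden analogy, not the pass's implementation: the function names are hypothetical, and the shifted-read case is added only as a counter-example that does not appear in the PR. When iteration `i` reads and writes only element `i`, updating in place gives the same result as going through a temporary; when the read index differs from the write index, it may not.

```cpp
#include <cstddef>
#include <vector>

// Reference semantics: evaluate the whole right-hand side into a
// temporary, then copy it to the destination.
std::vector<float> assignViaTemporary(std::vector<float> p, float addend) {
  std::vector<float> tmp(p.size());
  for (std::size_t i = 0; i < p.size(); ++i)
    tmp[i] = p[i] + addend;
  return tmp;
}

// Element-per-element form: iteration i reads and writes only p[i],
// so no temporary is needed.
std::vector<float> assignInPlace(std::vector<float> p, float addend) {
  for (std::size_t i = 0; i < p.size(); ++i)
    p[i] = p[i] + addend;
  return p;
}

// Counter-example (hypothetical): the read index (i - 1) differs from
// the write index (i), so the in-place loop sees already-updated values.
std::vector<float> shiftedInPlace(std::vector<float> p) {
  for (std::size_t i = 1; i < p.size(); ++i)
    p[i] = p[i - 1] + 1.0f;
  return p;
}

// Same shifted computation, but reading only the original values.
std::vector<float> shiftedViaTemporary(std::vector<float> p) {
  std::vector<float> tmp(p);
  for (std::size_t i = 1; i < p.size(); ++i)
    tmp[i] = p[i - 1] + 1.0f;
  return tmp;
}
```

With the input `{1, 5, 9}`, both `assign*` variants produce `{2, 6, 10}`, whereas `shiftedInPlace` yields `{1, 2, 3}` and `shiftedViaTemporary` yields `{1, 2, 6}`.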

Contributor

@tblah tblah left a comment

LGTM, thanks!
