[mlir][vector] Add support for vector extract/insert_strided_slice in vector distribution. #145421


Open
wants to merge 3 commits into main

Conversation

charithaintc
Contributor

This PR adds initial support for vector.extract_strided_slice and vector.insert_strided_slice ops in vector distribution.

Initial support assumes that sinking these ops does not require any cross-lane communication. This requires:

  • Only a single dim is distributed.

For extract_strided_slice

  • The extracted source vector must have rank greater than 1.
  • Extraction happens only in non-distributed dims (the distributed dim is fully extracted).

For insert_strided_slice

  • Both the inserted value (kD) and dest (nD) vectors are distributed (no broadcasting is involved). This requires that the distributed dimension is contained within the last k dims of the dest vector.

(Check code comments for more details)
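For illustration, here is a minimal before/after sketch of the supported extract_strided_slice case, adapted from the tests added in this PR (warp size 32, inner dimension distributed; the value names are illustrative):

```mlir
// Before distribution: the slice narrows only the outer, non-distributed dim
// (64 -> 24); the distributed inner dim (32) is taken in full, so no
// cross-lane communication is needed.
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<24x1xf32>) {
  %src = "some_def"() : () -> (vector<64x32xf32>)
  %1 = vector.extract_strided_slice %src {offsets = [8], sizes = [24], strides = [1]}
    : vector<64x32xf32> to vector<24x32xf32>
  gpu.yield %1 : vector<24x32xf32>
}

// After distribution: the source is yielded in distributed form and the
// extract sinks out of the warp region, now operating on per-lane vectors.
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
  %src = "some_def"() : () -> (vector<64x32xf32>)
  gpu.yield %src : vector<64x32xf32>
}
%r = vector.extract_strided_slice %w {offsets = [8], sizes = [24], strides = [1]}
  : vector<64x1xf32> to vector<24x1xf32>
```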

@llvmbot
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes

This PR adds initial support for vector.extract_strided_slice and vector.insert_strided_slice ops in vector distribution.

Initial support assumes that sinking these ops does not require any cross-lane communication. This requires:

  • Only a single dim is distributed.

For extract_strided_slice

  • The extracted source vector must have rank greater than 1.
  • Extraction happens only in non-distributed dims (the distributed dim is fully extracted).

For insert_strided_slice

  • Both the inserted value (kD) and dest (nD) vectors are distributed (no broadcasting is involved). This requires that the distributed dimension is contained within the last k dims of the dest vector.

(Check code comments for more details)


Full diff: https://github.com/llvm/llvm-project/pull/145421.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+210-10)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+80)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 045c192787f10..297bb40cbb334 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,9 +15,12 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <utility>
 
@@ -52,6 +55,21 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   return map;
 }
 
+static int getDistributedDim(VectorType origType, VectorType distributedType) {
+  assert(origType.getRank() == distributedType.getRank() &&
+         "sequential and distributed vector types must have the same rank");
+  int64_t distributedDim = -1;
+  for (int64_t i = 0; i < origType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != origType.getDimSize(i)) {
+      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+      // support distributing multiple dimensions in the future.
+      assert(distributedDim == -1 && "found multiple distributed dims");
+      distributedDim = i;
+    }
+  }
+  return distributedDim;
+}
+
 namespace {
 
 /// Helper struct to create the load / store operations that permit transit
@@ -1076,6 +1094,195 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };
 
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   %insert = vector.insert_strided_slice %src, %dest {offsets = [0, 0],
+///     strides = [1, 1]} : vector<4x16xf32> into vector<8x16xf32>
+///   gpu.yield %insert : vector<8x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+/// vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1
+///   {offsets = [0, 0], strides = [1, 1]} : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assumes that both src and dest vectors are distributed
+/// to lanes and that sinking the insert op does not require any cross-lane
+/// communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          insertOp, "result vector type must be 2D or higher");
+    // Find the distributed dimension of the dest vector. There should be
+    // exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t destDistributedDim =
+        getDistributedDim(yieldedType, distributedType);
+    assert(destDistributedDim != -1 && "could not find distributed dimension");
+    (void)destDistributedDim;
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    // TODO: Add support for case where source vector is not distributed.
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be in the last k dims");
+    // Distributed dimension must be fully inserted.
+    if (srcType.getDimSize(sourceDistributedDim) !=
+        destType.getDimSize(destDistributedDim))
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be fully inserted");
+    SmallVector<int64_t> newSourceDistShape(
+        insertOp.getSourceVectorType().getShape()),
+        newDestDistShape(insertOp.getDestVectorType().getShape());
+    newSourceDistShape[sourceDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    newDestDistShape[destDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    auto newSourceTy =
+        VectorType::get(newSourceDistShape, distributedType.getElementType());
+    auto newDestTy =
+        VectorType::get(newDestDistShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {newSourceTy, newDestTy}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
+    auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
+    // Create a new insert strided slice op that inserts distributed source into
+    // distributed dest.
+    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), distributedDest.getType(), distributedSource,
+        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+    return success();
+  }
+};
+
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+///   ...
+///   %src = ... : vector<32x16xf32>
+///   %extract = vector.extract_strided_slice %src {offsets = [0], sizes = [16],
+///     strides = [1]} : vector<32x16xf32> to vector<16x16xf32>
+///   gpu.yield %extract : vector<16x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
+///   ...
+///   %src = ... : vector<32x16xf32>
+///   gpu.yield %src : vector<32x16xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0 {offsets = [0], sizes = [16],
+///   strides = [1]} : vector<32x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on
+/// non-distributed dimensions (does not require cross-lane communication).
+struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp =
+        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "result vector type must be 2D or higher");
+
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+    assert(distributedDim != -1 && "could not find distributed dimension");
+    (void)distributedDim;
+
+    // Distributed dimension must be fully extracted.
+    // TODO: Partial extraction from the distributed dimension requires
+    // cross-lane communication.
+    if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
+      int64_t distributedDimOffset =
+          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
+              .getInt();
+      int64_t distributedDimSize =
+          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
+              .getInt();
+      if (distributedDimOffset != 0 ||
+          distributedDimSize != yieldedType.getDimSize(distributedDim))
+        return rewriter.notifyMatchFailure(
+            extractOp, "distributed dimension must be fully extracted");
+    }
+    SmallVector<int64_t> newDistributedShape(
+        extractOp.getSourceVectorType().getShape());
+    newDistributedShape[distributedDim] =
+        distributedType.getDimSize(distributedDim);
+    auto newDistributedType =
+        VectorType::get(newDistributedShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    // Update the distributed sizes to match the distributed type.
+    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
+      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+          distributedType.getDimSize(distributedDim));
+
+    // Create a new extract strided slice op that extracts from the
+    // distributed vector.
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), distributedType, distributedVec,
+        extractOp.getOffsets(),
+        ArrayAttr::get(rewriter.getContext(), distributedSizes),
+        extractOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
+                                newExtract);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public WarpDistributionPattern {
@@ -1122,15 +1329,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
     auto distributedType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     auto yieldedType = cast<VectorType>(operand->get().getType());
-    int64_t distributedDim = -1;
-    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
-      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
-        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
-        // support distributing multiple dimensions in the future.
-        assert(distributedDim == -1 && "found multiple distributed dims");
-        distributedDim = i;
-      }
-    }
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
     assert(distributedDim != -1 && "could not find distributed dimension");
     (void)distributedDim;
 
@@ -1764,7 +1963,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
   patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
                WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
                WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
+               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
       patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 38771f2593449..8c3060c91f0d1 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1296,6 +1296,86 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
   return %r : vector<4x96xf32>
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_outer(
+//  CHECK-PROP-SAME: %[[LANEID:.*]]: index
+//       CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<64x1xf32>) {
+//       CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<64x32xf32>
+//       CHECK-PROP: gpu.yield %[[VEC]] : vector<64x32xf32>
+//       CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+//  CHECK-PROP-SAME: {offsets = [8], sizes = [24], strides = [1]} : vector<64x1xf32> to vector<24x1xf32>
+//       CHECK-PROP: return %[[EXTRACT]] : vector<24x1xf32>
+func.func @vector_extract_strided_slice_2d_distr_outer(%laneid: index) -> (vector<24x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<24x1xf32>) {
+    %0 = "some_def"() : () -> (vector<64x32xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [8], sizes = [24], strides = [1]}
+      : vector<64x32xf32> to vector<24x32xf32>
+    gpu.yield %1 : vector<24x32xf32>
+  }
+  return %r : vector<24x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_extract_strided_slice_2d_distr_inner(
+//  CHECK-PROP-SAME: %[[LANEID:.*]]: index
+//       CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x64xf32>) {
+//       CHECK-PROP: %[[VEC:.*]] = "some_def"() : () -> vector<32x64xf32>
+//       CHECK-PROP: gpu.yield %[[VEC]] : vector<32x64xf32>
+//       CHECK-PROP: %[[EXTRACT:.*]] = vector.extract_strided_slice %[[W]]
+//  CHECK-PROP-SAME: {offsets = [0, 12], sizes = [1, 8], strides = [1, 1]} : vector<1x64xf32> to vector<1x8xf32>
+//       CHECK-PROP: return %[[EXTRACT]] : vector<1x8xf32>
+func.func @vector_extract_strided_slice_2d_distr_inner(%laneid: index) -> (vector<1x8xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x8xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %1 = vector.extract_strided_slice %0 { offsets = [0, 12], sizes = [32, 8], strides = [1, 1]}
+      : vector<32x64xf32> to vector<32x8xf32>
+    gpu.yield %1 : vector<32x8xf32>
+  }
+  return %r : vector<1x8xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_1d_to_2d(
+//  CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+//       CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}} -> (vector<1xf32>, vector<64x1xf32>) {
+//       CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<32xf32>
+//       CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+//       CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<32xf32>, vector<64x32xf32>
+//       CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1
+//  CHECK-PROP-SAME: {offsets = [18, 0], strides = [1]} : vector<1xf32> into vector<64x1xf32>
+//       CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_1d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<32xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [18, 0], strides = [1]}
+      : vector<32xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
+  }
+  return %r : vector<64x1xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_insert_strided_slice_2d_to_2d(
+//  CHECK-PROP-SAME: %[[LANEID:.*]]: index)
+//       CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<16x1xf32>, vector<64x1xf32>) {
+//       CHECK-PROP: %[[SRC:.*]] = "some_def"() : () -> vector<16x32xf32>
+//       CHECK-PROP: %[[DEST:.*]] = "some_def"() : () -> vector<64x32xf32>
+//       CHECK-PROP: gpu.yield %[[SRC]], %[[DEST]] : vector<16x32xf32>, vector<64x32xf32>
+//       CHECK-PROP: %[[INSERT:.*]] = vector.insert_strided_slice %[[W]]#0, %[[W]]#1 {offsets = [36, 0], strides = [1, 1]} :
+//  CHECK-PROP-SAME: vector<16x1xf32> into vector<64x1xf32>
+//       CHECK-PROP: return %[[INSERT]] : vector<64x1xf32>
+func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1xf32>) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
+    %0 = "some_def"() : () -> (vector<16x32xf32>)
+    %1 = "some_def"() : () -> (vector<64x32xf32>)
+    %2 = vector.insert_strided_slice %0, %1 { offsets = [36, 0],  strides = [1, 1]}
+      : vector<16x32xf32> into vector<64x32xf32>
+    gpu.yield %2 : vector<64x32xf32>
+  }
+  return %r : vector<64x1xf32>
+}
+
 // -----
 
 // Make sure that all operands of the transfer_read op are properly propagated.

@llvmbot
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-vector

@charithaintc
Contributor Author

@Garra1980 please have a look.

@@ -52,6 +55,21 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
return map;
}

static int getDistributedDim(VectorType origType, VectorType distributedType) {


Can you please add a short description here?
