
[mlir][linalg] Constrain the parameters m, r in Winograd ops #144657

Open — wants to merge 2 commits into main

Conversation

Hsiangkai (Contributor)

We only support a fixed set of minimal filtering algorithms for the Winograd Conv2D decomposition. Instead of letting users specify arbitrary integers, define a fixed set of enumeration values for the parameters of the minimal filtering algorithm.
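
For reference, a small standalone sketch (not part of the patch) of what the three enum cases mean in terms of tile sizes, using the relation alpha = m + r - 1 described in the enum's documentation; the case names are the ones added in LinalgEnums.td, everything else here is illustrative:

#include <cstdint>
#include <cstdio>

int main() {
  // The three supported minimal filtering configurations F(m, r).
  struct Fmr { const char *name; int64_t m, r; };
  const Fmr cases[] = {{"F_2_3", 2, 3}, {"F_4_3", 4, 3}, {"F_2_5", 2, 5}};
  for (const Fmr &c : cases) {
    int64_t alpha = c.m + c.r - 1; // input tile size
    std::printf("%s: %lld x %lld output, %lld x %lld filter, %lld x %lld input tile\n",
                c.name, (long long)c.m, (long long)c.m, (long long)c.r,
                (long long)c.r, (long long)alpha, (long long)alpha);
  }
  return 0;
}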

llvmbot (Member) commented Jun 18, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Hsiangkai Wang (Hsiangkai)

Changes

We only support a fixed set of minimal filtering algorithms for the Winograd Conv2D decomposition. Instead of letting users specify arbitrary integers, define a fixed set of enumeration values for the parameters of the minimal filtering algorithm.


Patch is 86.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144657.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+7)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+18)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+6-12)
  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5-4)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+44-13)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+64-79)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+15-15)
  • (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+12-12)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+12-12)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-winograd.mlir (+18-18)
  • (modified) mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+21-21)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+3-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..69e09f6b32c2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -100,6 +100,13 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
 
 #include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
 
+namespace mlir {
+namespace linalg {
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
+} // namespace linalg
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Linalg Attributes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index ce68afe471fe8..8c98c0b8b8683 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -122,4 +122,22 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
 
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr",
+    "Winograd Conv2D F(m, r)",
+    [
+      I32EnumAttrCase<"F_2_3", 0>,
+      I32EnumAttrCase<"F_4_3", 1>,
+      I32EnumAttrCase<"F_2_5", 2>,
+      I32EnumAttrCase<"Unknown", -1>,
+    ]>{
+  let cppNamespace = "mlir::linalg";
+}
+
 #endif // LINALG_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1b48bf5fcb237..7ff44c2e1d2ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,15 +183,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
                        TensorRankOf<[AnyType], [4]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $filter `:` type($filter) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
@@ -254,15 +252,13 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
                        TensorRankOf<[AnyType], [6]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [6]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $input `:` type($input) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
@@ -343,15 +339,13 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
                        TensorRankOf<[AnyType], [4]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $value `:` type($value) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..1b035ec2a457e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -9,6 +9,7 @@
 #ifndef LINALG_TRANSFORM_OPS
 #define LINALG_TRANSFORM_OPS
 
+include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
 include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -2802,8 +2803,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$m,
-                       I64Attr:$r);
+                       WinogradConv2DFmr:$fmr);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
   let assemblyFormat =
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e5ee7724cd32d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -36,6 +36,7 @@ class BufferizationState;
 namespace linalg {
 
 class LinalgOp;
+enum class WinogradConv2DFmr : uint32_t;
 
 //===----------------------------------------------------------------------===//
 // Utils.
@@ -1337,8 +1338,8 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
 /// F(m x m, r x r). m is the dimension size of output and r is the dimension
 /// size of filter.
 FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r);
+                                      linalg::Conv2DNhwcFhwcOp op,
+                                      WinogradConv2DFmr fmr);
 
 /// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
 /// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
@@ -1879,8 +1880,8 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
 /// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r);
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
+                                    WinogradConv2DFmr fmr);
 
 /// Patterns to decompose Winograd operators.
 void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..ac592ad808311 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2989,8 +2989,9 @@ LogicalResult WinogradFilterTransformOp::verify() {
   ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterH = filterShape[getFilterHDim()];
   int64_t filterW = filterShape[getFilterWDim()];
-  int64_t r = getR();
-  int64_t m = getM();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   if (filterH != r && filterH != 1)
     return emitOpError("expect filter height either equals to r or 1");
@@ -3046,8 +3047,9 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
   ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterH = filterShape[getFilterHDim()];
   int64_t filterW = filterShape[getFilterWDim()];
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t alpha = m + r - 1;
   int64_t alphaH = filterH != 1 ? alpha : 1;
   int64_t alphaW = filterW != 1 ? alpha : 1;
@@ -3124,8 +3126,9 @@ LogicalResult WinogradInputTransformOp::verify() {
   ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[getInputHDim()];
   int64_t inputW = inputShape[getInputWDim()];
-  int m = getM();
-  int r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t tileSize = m + r - 1;
 
   auto outputType = cast<ShapedType>(getOutput().getType());
@@ -3194,8 +3197,9 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
   int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
   int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
 
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t alpha = m + r - 1;
   int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
   int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
@@ -3224,8 +3228,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   ShapedType outputType = getOutputOperandType();
   ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3303,8 +3308,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
   int64_t valueW = valueShape[getValueAlphaWDim()];
   int64_t valueTileH = valueShape[getValueTileHDim()];
   int64_t valueTileW = valueShape[getValueTileWDim()];
-  int m = getM();
-  int r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   bool leftTransform = valueH != 1;
   bool rightTransform = valueW != 1;
 
@@ -3365,7 +3371,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  int64_t m = getM();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
@@ -3623,6 +3631,29 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
 namespace mlir {
 namespace linalg {
 
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r) {
+  if (m == 2 && r == 3)
+    return WinogradConv2DFmr::F_2_3;
+  if (m == 4 && r == 3)
+    return WinogradConv2DFmr::F_4_3;
+  if (m == 2 && r == 5)
+    return WinogradConv2DFmr::F_2_5;
+  return WinogradConv2DFmr::Unknown;
+}
+
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
+  switch (fmr) {
+  case WinogradConv2DFmr::F_2_3:
+    return {2, 3};
+  case WinogradConv2DFmr::F_4_3:
+    return {4, 3};
+  case WinogradConv2DFmr::F_2_5:
+    return {2, 5};
+  default:
+    return {-1, -1};
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // MatMulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..528c561167824 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4030,7 +4030,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
   bool supported = TypeSwitch<Operation *, bool>(target)
                        .Case([&](linalg::Conv2DNhwcFhwcOp op) {
                          maybeTransformed =
-                             winogradConv2D(rewriter, op, getM(), getR());
+                             winogradConv2D(rewriter, op, getFmr());
                          return true;
                        })
                        .Default([&](Operation *op) { return false; });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index e4221d4748415..a4f835350fb52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -176,19 +177,6 @@ constexpr float A_2x2_5x5[] = {
 };
 // clang-format on
 
-using TransformMapKeyTy = std::pair<int, int>;
-
-/// We use F(m, r) to define the size of minimal filtering algorithms.
-/// m is the output dimension and r is the filter dimension. We can get
-/// the input dimension, alpha, from the formula, alpha = m + r - 1.
-///
-/// For example, when m = 2 and r = 3, we know its input size is 4.
-/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-/// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
   TransformMatrix(const float *table, int64_t rows, int64_t cols,
@@ -344,22 +332,22 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
 ///     %ret = linalg.matmul %ret, GT
 ///     %inserted = insert %ret into filter<h x w x c x f>
 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
-                      Value retValue, int64_t m, int64_t r,
+                      Value retValue, WinogradConv2DFmr fmr,
                       bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to G transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       GMatrices = {
-          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
-          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
-          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
       };
 
   // Map from (m, r) to GT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       GTMatrices = {
-          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
-          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
-          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
       };
 
   auto filterType = cast<ShapedType>(filter.getType());
@@ -370,6 +358,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
   int64_t filterW = filterShape[2];
   int64_t filterC = filterShape[3];
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   if (filterH != r && filterH != 1)
     return Value();
   if (filterW != r && filterW != 1)
@@ -387,14 +377,13 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                             zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
-    TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     Value matmulRetValue = extractFilter;
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix G.
-      auto it = GMatrices.find(key);
+      auto it = GMatrices.find(fmr);
       if (it == GMatrices.end())
         return {};
       const TransformMatrix &GMatrix = it->second;
@@ -416,7 +405,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 
     if (rightTransform) {
       // Get constant transform matrix GT.
-      auto it = GTMatrices.find(key);
+      auto it = GTMatrices.find(fmr);
       if (it == GTMatrices.end())
         return {};
       const TransformMatrix &GTMatrix = it->second;
@@ -476,24 +465,26 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 ///                            %output<alphaH x alphaW x tileH x tileW x N x C>
 ///                            at [0, 0, %h, %w, %n, %c]
 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
-                     Value retValue, int64_t m, int64_t r,
+                     Value retValue, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to BT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       BTMatrices = {
-          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
       };
 
   // Map from (m, r) to B transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       BMatrices = {
-          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
       };
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto inputType = cast<ShapedType>(input.getType());
   Type elementType = inputType.getElementType();
   auto inputShape = inputType.getShape(); // N, H, W, C
@@ -529,7 +520,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                             widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
-    TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     int64_t retCols = 1;
     Value matmulRetValue = extractInput;
@@ -537,7 +527,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
         loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix BT.
-      auto it = BTMatrices.find(key);
+      auto it = BTMatrices.find(fmr);
       if (it == BTMatrices.end())
         return {};
       const TransformMatrix &BTMatrix = it->second;
@@ -560,7 +550,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
 
     if (rightTransform) {
       // Get constant transform matrix B.
-      auto it = BMatrices.find(key);
+      auto it = BMatrices.find(fmr);
       if (it == BMatrices.end())
         return {};
       const TransformMatrix &BMatrix = it->second;
@@ -696,24 +686,26 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
 ///                            output<N x H x W x F>
 ///                            at [%n, (%h x m), (%w x m), %f]
 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
-                      Value output, int64_t m, int64_t r,
+                      Value output, WinogradConv2DFmr fmr,
                       bool l...
[truncated]
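
To illustrate the C++ API change on the caller side, here is a hedged sketch of how a downstream pass might migrate from passing raw (m, r) integers to the new enum. Only linalg::WinogradConv2DFmr, linalg::getWinogradConv2DFmr, and linalg::winogradConv2D come from this patch; the wrapper function and its name are hypothetical:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Before this patch: linalg::winogradConv2D(rewriter, op, /*m=*/4, /*r=*/3);
// After: map the requested (m, r) onto a supported enum case first.
static FailureOr<Operation *> applyWinograd(RewriterBase &rewriter,
                                            linalg::Conv2DNhwcFhwcOp op,
                                            int64_t m, int64_t r) {
  linalg::WinogradConv2DFmr fmr = linalg::getWinogradConv2DFmr(m, r);
  if (fmr == linalg::WinogradConv2DFmr::Unknown)
    return failure(); // Only F(2, 3), F(4, 3), and F(2, 5) are supported.
  return linalg::winogradConv2D(rewriter, op, fmr);
}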

GeorgeARM (Contributor) commented Jun 18, 2025

Is the implementation unable to work functionally with other m, r parameters?
If so, constraining is fine; but if it can handle them, then the constraint semantics should be orthogonal. Otherwise we restrict experimentation and may end up with a long list of configurations. Users of Winograd, I presume, need to be aware of what they are doing. If we want to ease use, we could have a utility function that returns some "safe", "good" parameters for a given kernel config.
If we don't have transform matrices for a configuration, we could just create a side-band API to inject/extract transform matrices and work with that? Just brainstorming here, really.
Thoughts?
