-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] move LinalgToStandard to Linalg as ConvertToFunctionCalls #121392
base: main
Are you sure you want to change the base?
Conversation
The remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. Practically, this pass / pattern set was only performing the rewrite of Linalg operaitons to function calls. All this makes the existence of the pass highly confusing. Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically, one of them was called library-calls.mlir. Simplify the code a little.
@llvm/pr-subscribers-mlir Author: Oleksandr "Alex" Zinenko (ftynse) ChangesThe remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. Practically, this pass / pattern set was only performing the rewrite of Linalg operaitons to function calls. All this makes the existence of the pass highly confusing. Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically, one of them was called library-calls.mlir. Simplify the code a little. Full diff: https://github.com/llvm/llvm-project/pull/121392.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index eefa2c4724833b..346cf62cdb8e86 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -22,34 +22,11 @@ class OperationPass;
namespace linalg {
-//===----------------------------------------------------------------------===//
-// Patterns to convert a LinalgOp to func.call @external library implementation.
-//===----------------------------------------------------------------------===//
-// These patterns are exposed individually because they are expected to be
-// typically used individually.
-
-// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
-// function. The implementation of the function can be either in the same module
-// or in an externally linked library.
-// This is a generic entry point for all LinalgOp, except for CopyOp, for which
-// more specialized patterns are provided.
-class LinalgOpToLibraryCallRewrite
- : public OpInterfaceRewritePattern<LinalgOp> {
-public:
- using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
- LogicalResult matchAndRewrite(LinalgOp op,
- PatternRewriter &rewriter) const override;
-};
-
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
-
-} // namespace linalg
-
/// Create a pass to convert Linalg operations to the Standard dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
+} // namespace linalg
+
} // namespace mlir
#endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 58ee87cf820396..7a3ffa97bd5212 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -713,17 +713,6 @@ def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
];
}
-//===----------------------------------------------------------------------===//
-// LinalgToStandard
-//===----------------------------------------------------------------------===//
-
-def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
- let summary = "Convert the operations from the linalg dialect into the "
- "Standard dialect";
- let constructor = "mlir::createConvertLinalgToStandardPass()";
- let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"];
-}
-
//===----------------------------------------------------------------------===//
// MathToLibm
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0a..99c6d1c14674a1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -58,6 +58,13 @@ def ConvertLinalgToParallelLoopsPass
];
}
+def ConvertLinalgToFunctionCallsPass
+ : Pass<"convert-linalg-to-function-calls", "ModuleOp"> {
+ let summary = "Convert the operations from the Linalg dialect into "
+ "function calls";
+ let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect"];
+}
+
def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let options = [
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..1ae27136512873 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1885,6 +1885,32 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+//===----------------------------------------------------------------------===//
+// Patterns to convert a LinalgOp to func.call @external library implementation.
+//
+// These patterns are exposed individually because they are expected to be
+// typically used individually.
+//===----------------------------------------------------------------------===//
+
+// Creates a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
+// function. The implementation of the function can be either in the same module
+// or in an externally linked library.
+// This is a generic entry point for all LinalgOp, except for CopyOp, for which
+// more specialized patterns are provided.
+class LinalgOpToLibraryCallRewrite
+ : public OpInterfaceRewritePattern<LinalgOp> {
+public:
+ using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Populates the given list with patterns that convert from Linalg to library
+/// calls using the `func` dialect.
+void populateLinalgToFunctionCallsConversionPatterns(
+ RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 62461c0cea08af..1c7318bb584d45 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -31,7 +31,6 @@ add_subdirectory(GPUToSPIRV)
add_subdirectory(GPUToVulkan)
add_subdirectory(IndexToLLVM)
add_subdirectory(IndexToSPIRV)
-add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
diff --git a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
deleted file mode 100644
index 7fc4af54031855..00000000000000
--- a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
+++ /dev/null
@@ -1,23 +0,0 @@
-add_mlir_conversion_library(MLIRLinalgToStandard
- LinalgToStandard.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard
-
- DEPENDS
- MLIRConversionPassIncGen
-
- LINK_COMPONENTS
- Core
-
- LINK_LIBS PUBLIC
- MLIRFuncDialect
- MLIRIR
- MLIRLinalgDialect
- MLIRLinalgTransforms
- MLIRLLVMDialect
- MLIRMemRefDialect
- MLIRPass
- MLIRSCFDialect
- MLIRTransforms
- )
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..d6bdf1d52dd1da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
FoldAddIntoDest.cpp
+ FunctionCalls.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
@@ -68,6 +69,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRMeshTransforms
MLIRLinalgDialect
MLIRLinalgUtils
+ MLIRLLVMDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRPass
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
similarity index 87%
rename from mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
index 4d1f35c767304d..a202dac0aa2326 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
@@ -1,4 +1,4 @@
-//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
+//===- LinalgToFunctionCalls.cpp - Linalg to function calls conversion ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,19 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Pass/Pass.h"
namespace mlir {
-#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
-#include "mlir/Conversion/Passes.h.inc"
+#define GEN_PASS_DEF_CONVERTLINALGTOFUNCTIONCALLSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
using namespace mlir;
@@ -123,8 +122,7 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
return success();
}
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void mlir::linalg::populateLinalgToStandardConversionPatterns(
+void mlir::linalg::populateLinalgToFunctionCallsConversionPatterns(
RewritePatternSet &patterns) {
// TODO: ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
@@ -132,13 +130,14 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
}
namespace {
-struct ConvertLinalgToStandardPass
- : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
+struct ConvertLinalgToFunctionCallsPass
+ : public impl::ConvertLinalgToFunctionCallsPassBase<
+ ConvertLinalgToFunctionCallsPass> {
void runOnOperation() override;
};
} // namespace
-void ConvertLinalgToStandardPass::runOnOperation() {
+void ConvertLinalgToFunctionCallsPass::runOnOperation() {
auto module = getOperation();
ConversionTarget target(getContext());
target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
@@ -146,12 +145,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
scf::SCFDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
RewritePatternSet patterns(&getContext());
- populateLinalgToStandardConversionPatterns(patterns);
+ populateLinalgToFunctionCallsConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
-
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertLinalgToStandardPass() {
- return std::make_unique<ConvertLinalgToStandardPass>();
-}
diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/function-calls.mlir
similarity index 61%
rename from mlir/test/Dialect/Linalg/library-calls.mlir
rename to mlir/test/Dialect/Linalg/function-calls.mlir
index 1fa675d8b4b68a..103fcb16c51732 100644
--- a/mlir/test/Dialect/Linalg/library-calls.mlir
+++ b/mlir/test/Dialect/Linalg/function-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-linalg-to-std -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-function-calls -split-input-file --verify-diagnostics | FileCheck %s
func.func private @printMemrefF32(memref<*xf32>)
@@ -99,3 +99,85 @@ func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8
ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>)
return
}
+
+// -----
+
+func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
+ %arg1: memref<?xf32, strided<[1], offset: ?>>,
+ %arg2: memref<f32>) {
+ linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
+ memref<?xf32, strided<[1], offset: ?>>)
+ outs(%arg2: memref<f32>)
+ return
+}
+// CHECK-LABEL: func @dot(
+// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
+// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] :
+// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] :
+// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] :
+// CHECK-SAME: memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32(
+// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) :
+// CHECK-SAME: memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
+
+// -----
+
+#matmul_accesses = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmul_trait = {
+ iterator_types = ["parallel", "parallel", "reduction"],
+ indexing_maps = #matmul_accesses,
+ library_call = "external_outerproduct_matmul"
+}
+
+!vector_type_A = vector<4xf32>
+!vector_type_B = vector<4xf32>
+!vector_type_C = vector<4x4xf32>
+
+!matrix_type_A = memref<?x?x!vector_type_A>
+!matrix_type_B = memref<?x?x!vector_type_B>
+!matrix_type_C = memref<?x?x!vector_type_C>
+
+func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
+ linalg.generic #matmul_trait
+ ins(%A, %B : !matrix_type_A, !matrix_type_B)
+ outs(%C : !matrix_type_C) {
+ ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
+ %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
+ linalg.yield %d: !vector_type_C
+ }
+ return
+}
+// CHECK-LABEL: func @matmul_vec_impl(
+// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) {
+ // expected-error @below {{failed to legalize}}
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<?xf32>
+ return
+}
+
+// -----
+
+func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // expected-error @below {{failed to legalize}}
+ %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
deleted file mode 100644
index f50016f9ea477f..00000000000000
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ /dev/null
@@ -1,81 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-std --split-input-file -verify-diagnostics | FileCheck %s
-
-func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
- %arg1: memref<?xf32, strided<[1], offset: ?>>,
- %arg2: memref<f32>) {
- linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
- memref<?xf32, strided<[1], offset: ?>>)
- outs(%arg2: memref<f32>)
- return
-}
-// CHECK-LABEL: func @dot(
-// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
-// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] :
-// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] :
-// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] :
-// CHECK-SAME: memref<f32> to memref<f32, strided<[], offset: ?>>
-// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32(
-// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) :
-// CHECK-SAME: memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
-
-// -----
-
-#matmul_accesses = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>
-]
-#matmul_trait = {
- iterator_types = ["parallel", "parallel", "reduction"],
- indexing_maps = #matmul_accesses,
- library_call = "external_outerproduct_matmul"
-}
-
-!vector_type_A = vector<4xf32>
-!vector_type_B = vector<4xf32>
-!vector_type_C = vector<4x4xf32>
-
-!matrix_type_A = memref<?x?x!vector_type_A>
-!matrix_type_B = memref<?x?x!vector_type_B>
-!matrix_type_C = memref<?x?x!vector_type_C>
-
-func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
- linalg.generic #matmul_trait
- ins(%A, %B : !matrix_type_A, !matrix_type_B)
- outs(%C : !matrix_type_C) {
- ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
- %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
- linalg.yield %d: !vector_type_C
- }
- return
-}
-// CHECK-LABEL: func @matmul_vec_impl(
-// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d0)>
-
-func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) {
- // expected-error @below {{failed to legalize}}
- %0 = linalg.generic {
- indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
- ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<?xf32>
- return
-}
-
-// -----
-
-func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error @below {{failed to legalize}}
- %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
|
@llvm/pr-subscribers-mlir-linalg Author: Oleksandr "Alex" Zinenko (ftynse) ChangesThe remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. Practically, this pass / pattern set was only performing the rewrite of Linalg operaitons to function calls. All this makes the existence of the pass highly confusing. Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically, one of them was called library-calls.mlir. Simplify the code a little. Full diff: https://github.com/llvm/llvm-project/pull/121392.diff 10 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index eefa2c4724833b..346cf62cdb8e86 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -22,34 +22,11 @@ class OperationPass;
namespace linalg {
-//===----------------------------------------------------------------------===//
-// Patterns to convert a LinalgOp to func.call @external library implementation.
-//===----------------------------------------------------------------------===//
-// These patterns are exposed individually because they are expected to be
-// typically used individually.
-
-// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
-// function. The implementation of the function can be either in the same module
-// or in an externally linked library.
-// This is a generic entry point for all LinalgOp, except for CopyOp, for which
-// more specialized patterns are provided.
-class LinalgOpToLibraryCallRewrite
- : public OpInterfaceRewritePattern<LinalgOp> {
-public:
- using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
- LogicalResult matchAndRewrite(LinalgOp op,
- PatternRewriter &rewriter) const override;
-};
-
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
-
-} // namespace linalg
-
/// Create a pass to convert Linalg operations to the Standard dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
+} // namespace linalg
+
} // namespace mlir
#endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 58ee87cf820396..7a3ffa97bd5212 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -713,17 +713,6 @@ def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
];
}
-//===----------------------------------------------------------------------===//
-// LinalgToStandard
-//===----------------------------------------------------------------------===//
-
-def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
- let summary = "Convert the operations from the linalg dialect into the "
- "Standard dialect";
- let constructor = "mlir::createConvertLinalgToStandardPass()";
- let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"];
-}
-
//===----------------------------------------------------------------------===//
// MathToLibm
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0a..99c6d1c14674a1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -58,6 +58,13 @@ def ConvertLinalgToParallelLoopsPass
];
}
+def ConvertLinalgToFunctionCallsPass
+ : Pass<"convert-linalg-to-function-calls", "ModuleOp"> {
+ let summary = "Convert the operations from the Linalg dialect into "
+ "function calls";
+ let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect"];
+}
+
def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let options = [
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..1ae27136512873 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1885,6 +1885,32 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+//===----------------------------------------------------------------------===//
+// Patterns to convert a LinalgOp to func.call @external library implementation.
+//
+// These patterns are exposed individually because they are expected to be
+// typically used individually.
+//===----------------------------------------------------------------------===//
+
+// Creates a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
+// function. The implementation of the function can be either in the same module
+// or in an externally linked library.
+// This is a generic entry point for all LinalgOp, except for CopyOp, for which
+// more specialized patterns are provided.
+class LinalgOpToLibraryCallRewrite
+ : public OpInterfaceRewritePattern<LinalgOp> {
+public:
+ using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Populates the given list with patterns that convert from Linalg to library
+/// calls using the `func` dialect.
+void populateLinalgToFunctionCallsConversionPatterns(
+ RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 62461c0cea08af..1c7318bb584d45 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -31,7 +31,6 @@ add_subdirectory(GPUToSPIRV)
add_subdirectory(GPUToVulkan)
add_subdirectory(IndexToLLVM)
add_subdirectory(IndexToSPIRV)
-add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
diff --git a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
deleted file mode 100644
index 7fc4af54031855..00000000000000
--- a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
+++ /dev/null
@@ -1,23 +0,0 @@
-add_mlir_conversion_library(MLIRLinalgToStandard
- LinalgToStandard.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard
-
- DEPENDS
- MLIRConversionPassIncGen
-
- LINK_COMPONENTS
- Core
-
- LINK_LIBS PUBLIC
- MLIRFuncDialect
- MLIRIR
- MLIRLinalgDialect
- MLIRLinalgTransforms
- MLIRLLVMDialect
- MLIRMemRefDialect
- MLIRPass
- MLIRSCFDialect
- MLIRTransforms
- )
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..d6bdf1d52dd1da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
EliminateEmptyTensors.cpp
EraseUnusedOperandsAndResults.cpp
FoldAddIntoDest.cpp
+ FunctionCalls.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
Generalization.cpp
@@ -68,6 +69,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRMeshTransforms
MLIRLinalgDialect
MLIRLinalgUtils
+ MLIRLLVMDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRPass
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
similarity index 87%
rename from mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
index 4d1f35c767304d..a202dac0aa2326 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
@@ -1,4 +1,4 @@
-//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
+//===- LinalgToFunctionCalls.cpp - Linalg to function calls conversion ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,19 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Pass/Pass.h"
namespace mlir {
-#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
-#include "mlir/Conversion/Passes.h.inc"
+#define GEN_PASS_DEF_CONVERTLINALGTOFUNCTIONCALLSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
using namespace mlir;
@@ -123,8 +122,7 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
return success();
}
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void mlir::linalg::populateLinalgToStandardConversionPatterns(
+void mlir::linalg::populateLinalgToFunctionCallsConversionPatterns(
RewritePatternSet &patterns) {
// TODO: ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
@@ -132,13 +130,14 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
}
namespace {
-struct ConvertLinalgToStandardPass
- : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
+struct ConvertLinalgToFunctionCallsPass
+ : public impl::ConvertLinalgToFunctionCallsPassBase<
+ ConvertLinalgToFunctionCallsPass> {
void runOnOperation() override;
};
} // namespace
-void ConvertLinalgToStandardPass::runOnOperation() {
+void ConvertLinalgToFunctionCallsPass::runOnOperation() {
auto module = getOperation();
ConversionTarget target(getContext());
target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
@@ -146,12 +145,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
scf::SCFDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
RewritePatternSet patterns(&getContext());
- populateLinalgToStandardConversionPatterns(patterns);
+ populateLinalgToFunctionCallsConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
-
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertLinalgToStandardPass() {
- return std::make_unique<ConvertLinalgToStandardPass>();
-}
diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/function-calls.mlir
similarity index 61%
rename from mlir/test/Dialect/Linalg/library-calls.mlir
rename to mlir/test/Dialect/Linalg/function-calls.mlir
index 1fa675d8b4b68a..103fcb16c51732 100644
--- a/mlir/test/Dialect/Linalg/library-calls.mlir
+++ b/mlir/test/Dialect/Linalg/function-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-linalg-to-std -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-function-calls -split-input-file --verify-diagnostics | FileCheck %s
func.func private @printMemrefF32(memref<*xf32>)
@@ -99,3 +99,85 @@ func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8
ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>)
return
}
+
+// -----
+
+func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
+ %arg1: memref<?xf32, strided<[1], offset: ?>>,
+ %arg2: memref<f32>) {
+ linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
+ memref<?xf32, strided<[1], offset: ?>>)
+ outs(%arg2: memref<f32>)
+ return
+}
+// CHECK-LABEL: func @dot(
+// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
+// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] :
+// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] :
+// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] :
+// CHECK-SAME: memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32(
+// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) :
+// CHECK-SAME: memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
+
+// -----
+
+#matmul_accesses = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmul_trait = {
+ iterator_types = ["parallel", "parallel", "reduction"],
+ indexing_maps = #matmul_accesses,
+ library_call = "external_outerproduct_matmul"
+}
+
+!vector_type_A = vector<4xf32>
+!vector_type_B = vector<4xf32>
+!vector_type_C = vector<4x4xf32>
+
+!matrix_type_A = memref<?x?x!vector_type_A>
+!matrix_type_B = memref<?x?x!vector_type_B>
+!matrix_type_C = memref<?x?x!vector_type_C>
+
+func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
+ linalg.generic #matmul_trait
+ ins(%A, %B : !matrix_type_A, !matrix_type_B)
+ outs(%C : !matrix_type_C) {
+ ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
+ %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
+ linalg.yield %d: !vector_type_C
+ }
+ return
+}
+// CHECK-LABEL: func @matmul_vec_impl(
+// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) {
+ // expected-error @below {{failed to legalize}}
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<?xf32>
+ return
+}
+
+// -----
+
+func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // expected-error @below {{failed to legalize}}
+ %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
deleted file mode 100644
index f50016f9ea477f..00000000000000
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ /dev/null
@@ -1,81 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-std --split-input-file -verify-diagnostics | FileCheck %s
-
-func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
- %arg1: memref<?xf32, strided<[1], offset: ?>>,
- %arg2: memref<f32>) {
- linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
- memref<?xf32, strided<[1], offset: ?>>)
- outs(%arg2: memref<f32>)
- return
-}
-// CHECK-LABEL: func @dot(
-// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-// CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
-// CHECK: %[[o0:.*]] = memref.cast %[[arg0]] :
-// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-// CHECK: %[[o1:.*]] = memref.cast %[[arg1]] :
-// CHECK-SAME: memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-// CHECK: %[[o2:.*]] = memref.cast %[[arg2]] :
-// CHECK-SAME: memref<f32> to memref<f32, strided<[], offset: ?>>
-// CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32(
-// CHECK-SAME: %[[o0]], %[[o1]], %[[o2]]) :
-// CHECK-SAME: memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
-
-// -----
-
-#matmul_accesses = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>
-]
-#matmul_trait = {
- iterator_types = ["parallel", "parallel", "reduction"],
- indexing_maps = #matmul_accesses,
- library_call = "external_outerproduct_matmul"
-}
-
-!vector_type_A = vector<4xf32>
-!vector_type_B = vector<4xf32>
-!vector_type_C = vector<4x4xf32>
-
-!matrix_type_A = memref<?x?x!vector_type_A>
-!matrix_type_B = memref<?x?x!vector_type_B>
-!matrix_type_C = memref<?x?x!vector_type_C>
-
-func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
- linalg.generic #matmul_trait
- ins(%A, %B : !matrix_type_A, !matrix_type_B)
- outs(%C : !matrix_type_C) {
- ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
- %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
- linalg.yield %d: !vector_type_C
- }
- return
-}
-// CHECK-LABEL: func @matmul_vec_impl(
-// CHECK: call @external_outerproduct_matmul(%{{.*}}) :
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d0)>
-
-func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) {
- // expected-error @below {{failed to legalize}}
- %0 = linalg.generic {
- indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
- ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- } -> tensor<?xf32>
- return
-}
-
-// -----
-
-func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error @below {{failed to legalize}}
- %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0 : tensor<4x8xf32>
-}
|
The remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. Practically, this pass / pattern set was only performing the rewrite of Linalg operaitons to function calls. All this makes the existence of the pass highly confusing.
Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically, one of them was called library-calls.mlir. Simplify the code a little.