Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] move LinalgToStandard to Linalg as ConvertToFunctionCalls #121392

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ftynse
Copy link
Member

@ftynse ftynse commented Dec 31, 2024

The remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. Practically, this pass / pattern set was only performing the rewrite of Linalg operaitons to function calls. All this makes the existence of the pass highly confusing.

Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically, one of them was called library-calls.mlir. Simplify the code a little.

The remnants of the ConvertLinalgToStandard pass were still present in the
codebase under this name, years after the Standard dialect was dismantled.
Practically, this pass / pattern set was only performing the rewrite of Linalg
operaitons to function calls. All this makes the existence of the pass highly
confusing.

Move the logic under Linalg/Transforms, similarly to other "lowerings" from
Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard
to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically,
one of them was called library-calls.mlir. Simplify the code a little.
@llvmbot
Copy link
Member

llvmbot commented Dec 31, 2024

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

The remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. Practically, this pass / pattern set was only performing the rewrite of Linalg operaitons to function calls. All this makes the existence of the pass highly confusing.

Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically, one of them was called library-calls.mlir. Simplify the code a little.


Full diff: https://github.com/llvm/llvm-project/pull/121392.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h (+2-25)
  • (modified) mlir/include/mlir/Conversion/Passes.td (-11)
  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+7)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+26)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (-1)
  • (removed) mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt (-23)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+2)
  • (renamed) mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp (+11-17)
  • (renamed) mlir/test/Dialect/Linalg/function-calls.mlir (+83-1)
  • (removed) mlir/test/Dialect/Linalg/standard.mlir (-81)
diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index eefa2c4724833b..346cf62cdb8e86 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -22,34 +22,11 @@ class OperationPass;
 
 namespace linalg {
 
-//===----------------------------------------------------------------------===//
-// Patterns to convert a LinalgOp to func.call @external library implementation.
-//===----------------------------------------------------------------------===//
-// These patterns are exposed individually because they are expected to be
-// typically used individually.
-
-// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
-// function. The implementation of the function can be either in the same module
-// or in an externally linked library.
-// This is a generic entry point for all LinalgOp, except for CopyOp, for which
-// more specialized patterns are provided.
-class LinalgOpToLibraryCallRewrite
-    : public OpInterfaceRewritePattern<LinalgOp> {
-public:
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
-
-} // namespace linalg
-
 /// Create a pass to convert Linalg operations to the Standard dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
 
+} // namespace linalg
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 58ee87cf820396..7a3ffa97bd5212 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -713,17 +713,6 @@ def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// LinalgToStandard
-//===----------------------------------------------------------------------===//
-
-def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
-  let summary = "Convert the operations from the linalg dialect into the "
-                "Standard dialect";
-  let constructor = "mlir::createConvertLinalgToStandardPass()";
-  let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"];
-}
-
 //===----------------------------------------------------------------------===//
 // MathToLibm
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0a..99c6d1c14674a1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -58,6 +58,13 @@ def ConvertLinalgToParallelLoopsPass
   ];
 }
 
+def ConvertLinalgToFunctionCallsPass
+    : Pass<"convert-linalg-to-function-calls", "ModuleOp"> {
+  let summary = "Convert the operations from the Linalg dialect into "
+                "function calls";
+  let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect"];
+}
+
 def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let options = [
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..1ae27136512873 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1885,6 +1885,32 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+//===----------------------------------------------------------------------===//
+// Patterns to convert a LinalgOp to func.call @external library implementation.
+//
+// These patterns are exposed individually because they are expected to be
+// typically used individually.
+//===----------------------------------------------------------------------===//
+
+// Creates a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
+// function. The implementation of the function can be either in the same module
+// or in an externally linked library.
+// This is a generic entry point for all LinalgOp, except for CopyOp, for which
+// more specialized patterns are provided.
+class LinalgOpToLibraryCallRewrite
+    : public OpInterfaceRewritePattern<LinalgOp> {
+public:
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Populates the given list with patterns that convert from Linalg to library
+/// calls using the `func` dialect.
+void populateLinalgToFunctionCallsConversionPatterns(
+    RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 62461c0cea08af..1c7318bb584d45 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -31,7 +31,6 @@ add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
 add_subdirectory(IndexToLLVM)
 add_subdirectory(IndexToSPIRV)
-add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
diff --git a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
deleted file mode 100644
index 7fc4af54031855..00000000000000
--- a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
+++ /dev/null
@@ -1,23 +0,0 @@
-add_mlir_conversion_library(MLIRLinalgToStandard
-  LinalgToStandard.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRFuncDialect
-  MLIRIR
-  MLIRLinalgDialect
-  MLIRLinalgTransforms
-  MLIRLLVMDialect
-  MLIRMemRefDialect
-  MLIRPass
-  MLIRSCFDialect
-  MLIRTransforms
-  )
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..d6bdf1d52dd1da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
+  FunctionCalls.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
@@ -68,6 +69,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRMeshTransforms
   MLIRLinalgDialect
   MLIRLinalgUtils
+  MLIRLLVMDialect
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRPass
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
similarity index 87%
rename from mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
index 4d1f35c767304d..a202dac0aa2326 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
@@ -1,4 +1,4 @@
-//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
+//===- LinalgToFunctionCalls.cpp - Linalg to function calls conversion ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,19 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
-
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Pass/Pass.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
-#include "mlir/Conversion/Passes.h.inc"
+#define GEN_PASS_DEF_CONVERTLINALGTOFUNCTIONCALLSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
 } // namespace mlir
 
 using namespace mlir;
@@ -123,8 +122,7 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
   return success();
 }
 
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void mlir::linalg::populateLinalgToStandardConversionPatterns(
+void mlir::linalg::populateLinalgToFunctionCallsConversionPatterns(
     RewritePatternSet &patterns) {
   // TODO: ConvOp conversion needs to export a descriptor with relevant
   // attribute values such as kernel striding and dilation.
@@ -132,13 +130,14 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
 }
 
 namespace {
-struct ConvertLinalgToStandardPass
-    : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
+struct ConvertLinalgToFunctionCallsPass
+    : public impl::ConvertLinalgToFunctionCallsPassBase<
+          ConvertLinalgToFunctionCallsPass> {
   void runOnOperation() override;
 };
 } // namespace
 
-void ConvertLinalgToStandardPass::runOnOperation() {
+void ConvertLinalgToFunctionCallsPass::runOnOperation() {
   auto module = getOperation();
   ConversionTarget target(getContext());
   target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
@@ -146,12 +145,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
                          scf::SCFDialect>();
   target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
   RewritePatternSet patterns(&getContext());
-  populateLinalgToStandardConversionPatterns(patterns);
+  populateLinalgToFunctionCallsConversionPatterns(patterns);
   if (failed(applyFullConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
-
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertLinalgToStandardPass() {
-  return std::make_unique<ConvertLinalgToStandardPass>();
-}
diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/function-calls.mlir
similarity index 61%
rename from mlir/test/Dialect/Linalg/library-calls.mlir
rename to mlir/test/Dialect/Linalg/function-calls.mlir
index 1fa675d8b4b68a..103fcb16c51732 100644
--- a/mlir/test/Dialect/Linalg/library-calls.mlir
+++ b/mlir/test/Dialect/Linalg/function-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-linalg-to-std -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-function-calls -split-input-file --verify-diagnostics | FileCheck %s
 
 func.func private @printMemrefF32(memref<*xf32>)
 
@@ -99,3 +99,85 @@ func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8
                               ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>)
   return
 }
+
+// -----
+
+func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
+          %arg1: memref<?xf32, strided<[1], offset: ?>>,
+          %arg2: memref<f32>) {
+  linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
+                               memref<?xf32, strided<[1], offset: ?>>)
+             outs(%arg2: memref<f32>)
+  return
+}
+// CHECK-LABEL: func @dot(
+//  CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+//  CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+//  CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
+//       CHECK:   %[[o0:.*]] = memref.cast %[[arg0]] :
+//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   %[[o1:.*]] = memref.cast %[[arg1]] :
+//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   %[[o2:.*]] = memref.cast %[[arg2]] :
+//  CHECK-SAME:     memref<f32> to memref<f32, strided<[], offset: ?>>
+//       CHECK:   call @linalg_dot_viewsxf32_viewsxf32_viewf32(
+//  CHECK-SAME:     %[[o0]], %[[o1]], %[[o2]]) :
+//  CHECK-SAME:   memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
+
+// -----
+
+#matmul_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmul_trait = {
+  iterator_types = ["parallel", "parallel", "reduction"],
+  indexing_maps = #matmul_accesses,
+  library_call = "external_outerproduct_matmul"
+}
+
+!vector_type_A = vector<4xf32>
+!vector_type_B = vector<4xf32>
+!vector_type_C = vector<4x4xf32>
+
+!matrix_type_A = memref<?x?x!vector_type_A>
+!matrix_type_B = memref<?x?x!vector_type_B>
+!matrix_type_C = memref<?x?x!vector_type_C>
+
+func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
+  linalg.generic #matmul_trait
+      ins(%A, %B : !matrix_type_A, !matrix_type_B)
+     outs(%C : !matrix_type_C) {
+    ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
+      %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
+      linalg.yield %d: !vector_type_C
+  }
+  return
+}
+// CHECK-LABEL: func @matmul_vec_impl(
+// CHECK:  call @external_outerproduct_matmul(%{{.*}}) :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>)  {
+  // expected-error @below {{failed to legalize}}
+  %0 = linalg.generic {
+    indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
+  ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+  ^bb0(%in: f32, %out: f32): 
+    linalg.yield %in : f32
+  } -> tensor<?xf32>
+  return 
+}
+
+// -----
+
+func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @below {{failed to legalize}}
+  %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
deleted file mode 100644
index f50016f9ea477f..00000000000000
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ /dev/null
@@ -1,81 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-std --split-input-file -verify-diagnostics | FileCheck %s
-
-func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
-          %arg1: memref<?xf32, strided<[1], offset: ?>>,
-          %arg2: memref<f32>) {
-  linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
-                               memref<?xf32, strided<[1], offset: ?>>)
-             outs(%arg2: memref<f32>)
-  return
-}
-// CHECK-LABEL: func @dot(
-//  CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-//  CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-//  CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
-//       CHECK:   %[[o0:.*]] = memref.cast %[[arg0]] :
-//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-//       CHECK:   %[[o1:.*]] = memref.cast %[[arg1]] :
-//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-//       CHECK:   %[[o2:.*]] = memref.cast %[[arg2]] :
-//  CHECK-SAME:     memref<f32> to memref<f32, strided<[], offset: ?>>
-//       CHECK:   call @linalg_dot_viewsxf32_viewsxf32_viewf32(
-//  CHECK-SAME:     %[[o0]], %[[o1]], %[[o2]]) :
-//  CHECK-SAME:   memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
-
-// -----
-
-#matmul_accesses = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmul_trait = {
-  iterator_types = ["parallel", "parallel", "reduction"],
-  indexing_maps = #matmul_accesses,
-  library_call = "external_outerproduct_matmul"
-}
-
-!vector_type_A = vector<4xf32>
-!vector_type_B = vector<4xf32>
-!vector_type_C = vector<4x4xf32>
-
-!matrix_type_A = memref<?x?x!vector_type_A>
-!matrix_type_B = memref<?x?x!vector_type_B>
-!matrix_type_C = memref<?x?x!vector_type_C>
-
-func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
-  linalg.generic #matmul_trait
-      ins(%A, %B : !matrix_type_A, !matrix_type_B)
-     outs(%C : !matrix_type_C) {
-    ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
-      %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
-      linalg.yield %d: !vector_type_C
-  }
-  return
-}
-// CHECK-LABEL: func @matmul_vec_impl(
-// CHECK:  call @external_outerproduct_matmul(%{{.*}}) :
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d0)>
-
-func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>)  {
-  // expected-error @below {{failed to legalize}}
-  %0 = linalg.generic {
-    indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
-  ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
-  ^bb0(%in: f32, %out: f32): 
-    linalg.yield %in : f32
-  } -> tensor<?xf32>
-  return 
-}
-
-// -----
-
-func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // expected-error @below {{failed to legalize}}
-  %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0 : tensor<4x8xf32>
-}

@llvmbot
Copy link
Member

llvmbot commented Dec 31, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

The remnants of the ConvertLinalgToStandard pass were still present in the codebase under this name, years after the Standard dialect was dismantled. Practically, this pass / pattern set was only performing the rewrite of Linalg operaitons to function calls. All this makes the existence of the pass highly confusing.

Move the logic under Linalg/Transforms, similarly to other "lowerings" from Linalg, e.g., the one to (affine or SCF) loops. Rename ConvertLinalgToStandard to ConvertLinalgToFunctionCalls. Merge the two relevant test files, ironically, one of them was called library-calls.mlir. Simplify the code a little.


Full diff: https://github.com/llvm/llvm-project/pull/121392.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h (+2-25)
  • (modified) mlir/include/mlir/Conversion/Passes.td (-11)
  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+7)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+26)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (-1)
  • (removed) mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt (-23)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+2)
  • (renamed) mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp (+11-17)
  • (renamed) mlir/test/Dialect/Linalg/function-calls.mlir (+83-1)
  • (removed) mlir/test/Dialect/Linalg/standard.mlir (-81)
diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index eefa2c4724833b..346cf62cdb8e86 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -22,34 +22,11 @@ class OperationPass;
 
 namespace linalg {
 
-//===----------------------------------------------------------------------===//
-// Patterns to convert a LinalgOp to func.call @external library implementation.
-//===----------------------------------------------------------------------===//
-// These patterns are exposed individually because they are expected to be
-// typically used individually.
-
-// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
-// function. The implementation of the function can be either in the same module
-// or in an externally linked library.
-// This is a generic entry point for all LinalgOp, except for CopyOp, for which
-// more specialized patterns are provided.
-class LinalgOpToLibraryCallRewrite
-    : public OpInterfaceRewritePattern<LinalgOp> {
-public:
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
-
-} // namespace linalg
-
 /// Create a pass to convert Linalg operations to the Standard dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
 
+} // namespace linalg
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 58ee87cf820396..7a3ffa97bd5212 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -713,17 +713,6 @@ def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// LinalgToStandard
-//===----------------------------------------------------------------------===//
-
-def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
-  let summary = "Convert the operations from the linalg dialect into the "
-                "Standard dialect";
-  let constructor = "mlir::createConvertLinalgToStandardPass()";
-  let dependentDialects = ["func::FuncDialect", "memref::MemRefDialect"];
-}
-
 //===----------------------------------------------------------------------===//
 // MathToLibm
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d96ad919b65f0a..99c6d1c14674a1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -58,6 +58,13 @@ def ConvertLinalgToParallelLoopsPass
   ];
 }
 
+def ConvertLinalgToFunctionCallsPass
+    : Pass<"convert-linalg-to-function-calls", "ModuleOp"> {
+  let summary = "Convert the operations from the Linalg dialect into "
+                "function calls";
+  let dependentDialects = ["func::FuncDialect", "LLVM::LLVMDialect"];
+}
+
 def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let options = [
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..1ae27136512873 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1885,6 +1885,32 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+//===----------------------------------------------------------------------===//
+// Patterns to convert a LinalgOp to func.call @external library implementation.
+//
+// These patterns are exposed individually because they are expected to be
+// typically used individually.
+//===----------------------------------------------------------------------===//
+
+// Creates a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
+// function. The implementation of the function can be either in the same module
+// or in an externally linked library.
+// This is a generic entry point for all LinalgOp, except for CopyOp, for which
+// more specialized patterns are provided.
+class LinalgOpToLibraryCallRewrite
+    : public OpInterfaceRewritePattern<LinalgOp> {
+public:
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Populates the given list with patterns that convert from Linalg to library
+/// calls using the `func` dialect.
+void populateLinalgToFunctionCallsConversionPatterns(
+    RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 62461c0cea08af..1c7318bb584d45 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -31,7 +31,6 @@ add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
 add_subdirectory(IndexToLLVM)
 add_subdirectory(IndexToSPIRV)
-add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
diff --git a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt b/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
deleted file mode 100644
index 7fc4af54031855..00000000000000
--- a/mlir/lib/Conversion/LinalgToStandard/CMakeLists.txt
+++ /dev/null
@@ -1,23 +0,0 @@
-add_mlir_conversion_library(MLIRLinalgToStandard
-  LinalgToStandard.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/LinalgToStandard
-
-  DEPENDS
-  MLIRConversionPassIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRFuncDialect
-  MLIRIR
-  MLIRLinalgDialect
-  MLIRLinalgTransforms
-  MLIRLLVMDialect
-  MLIRMemRefDialect
-  MLIRPass
-  MLIRSCFDialect
-  MLIRTransforms
-  )
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..d6bdf1d52dd1da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EliminateEmptyTensors.cpp
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
+  FunctionCalls.cpp
   FusePadOpWithLinalgProducer.cpp
   Fusion.cpp
   Generalization.cpp
@@ -68,6 +69,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRMeshTransforms
   MLIRLinalgDialect
   MLIRLinalgUtils
+  MLIRLLVMDialect
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRPass
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
similarity index 87%
rename from mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
index 4d1f35c767304d..a202dac0aa2326 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FunctionCalls.cpp
@@ -1,4 +1,4 @@
-//===- LinalgToStandard.cpp - conversion from Linalg to Standard dialect --===//
+//===- LinalgToFunctionCalls.cpp - Linalg to function calls conversion ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,19 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
-
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Pass/Pass.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
-#include "mlir/Conversion/Passes.h.inc"
+#define GEN_PASS_DEF_CONVERTLINALGTOFUNCTIONCALLSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
 } // namespace mlir
 
 using namespace mlir;
@@ -123,8 +122,7 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
   return success();
 }
 
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void mlir::linalg::populateLinalgToStandardConversionPatterns(
+void mlir::linalg::populateLinalgToFunctionCallsConversionPatterns(
     RewritePatternSet &patterns) {
   // TODO: ConvOp conversion needs to export a descriptor with relevant
   // attribute values such as kernel striding and dilation.
@@ -132,13 +130,14 @@ void mlir::linalg::populateLinalgToStandardConversionPatterns(
 }
 
 namespace {
-struct ConvertLinalgToStandardPass
-    : public impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
+struct ConvertLinalgToFunctionCallsPass
+    : public impl::ConvertLinalgToFunctionCallsPassBase<
+          ConvertLinalgToFunctionCallsPass> {
   void runOnOperation() override;
 };
 } // namespace
 
-void ConvertLinalgToStandardPass::runOnOperation() {
+void ConvertLinalgToFunctionCallsPass::runOnOperation() {
   auto module = getOperation();
   ConversionTarget target(getContext());
   target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
@@ -146,12 +145,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
                          scf::SCFDialect>();
   target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
   RewritePatternSet patterns(&getContext());
-  populateLinalgToStandardConversionPatterns(patterns);
+  populateLinalgToFunctionCallsConversionPatterns(patterns);
   if (failed(applyFullConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
-
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertLinalgToStandardPass() {
-  return std::make_unique<ConvertLinalgToStandardPass>();
-}
diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/function-calls.mlir
similarity index 61%
rename from mlir/test/Dialect/Linalg/library-calls.mlir
rename to mlir/test/Dialect/Linalg/function-calls.mlir
index 1fa675d8b4b68a..103fcb16c51732 100644
--- a/mlir/test/Dialect/Linalg/library-calls.mlir
+++ b/mlir/test/Dialect/Linalg/function-calls.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-linalg-to-std -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-linalg-to-function-calls -split-input-file --verify-diagnostics | FileCheck %s
 
 func.func private @printMemrefF32(memref<*xf32>)
 
@@ -99,3 +99,85 @@ func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8
                               ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>)
   return
 }
+
+// -----
+
+func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
+          %arg1: memref<?xf32, strided<[1], offset: ?>>,
+          %arg2: memref<f32>) {
+  linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
+                               memref<?xf32, strided<[1], offset: ?>>)
+             outs(%arg2: memref<f32>)
+  return
+}
+// CHECK-LABEL: func @dot(
+//  CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+//  CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
+//  CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
+//       CHECK:   %[[o0:.*]] = memref.cast %[[arg0]] :
+//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   %[[o1:.*]] = memref.cast %[[arg1]] :
+//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+//       CHECK:   %[[o2:.*]] = memref.cast %[[arg2]] :
+//  CHECK-SAME:     memref<f32> to memref<f32, strided<[], offset: ?>>
+//       CHECK:   call @linalg_dot_viewsxf32_viewsxf32_viewf32(
+//  CHECK-SAME:     %[[o0]], %[[o1]], %[[o2]]) :
+//  CHECK-SAME:   memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
+
+// -----
+
+#matmul_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmul_trait = {
+  iterator_types = ["parallel", "parallel", "reduction"],
+  indexing_maps = #matmul_accesses,
+  library_call = "external_outerproduct_matmul"
+}
+
+!vector_type_A = vector<4xf32>
+!vector_type_B = vector<4xf32>
+!vector_type_C = vector<4x4xf32>
+
+!matrix_type_A = memref<?x?x!vector_type_A>
+!matrix_type_B = memref<?x?x!vector_type_B>
+!matrix_type_C = memref<?x?x!vector_type_C>
+
+func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
+  linalg.generic #matmul_trait
+      ins(%A, %B : !matrix_type_A, !matrix_type_B)
+     outs(%C : !matrix_type_C) {
+    ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
+      %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
+      linalg.yield %d: !vector_type_C
+  }
+  return
+}
+// CHECK-LABEL: func @matmul_vec_impl(
+// CHECK:  call @external_outerproduct_matmul(%{{.*}}) :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>)  {
+  // expected-error @below {{failed to legalize}}
+  %0 = linalg.generic {
+    indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
+  ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
+  ^bb0(%in: f32, %out: f32): 
+    linalg.yield %in : f32
+  } -> tensor<?xf32>
+  return 
+}
+
+// -----
+
+func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @below {{failed to legalize}}
+  %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
deleted file mode 100644
index f50016f9ea477f..00000000000000
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ /dev/null
@@ -1,81 +0,0 @@
-// RUN: mlir-opt %s -convert-linalg-to-std --split-input-file -verify-diagnostics | FileCheck %s
-
-func.func @dot(%arg0: memref<?xf32, strided<[1], offset: ?>>,
-          %arg1: memref<?xf32, strided<[1], offset: ?>>,
-          %arg2: memref<f32>) {
-  linalg.dot ins(%arg0, %arg1: memref<?xf32, strided<[1], offset: ?>>,
-                               memref<?xf32, strided<[1], offset: ?>>)
-             outs(%arg2: memref<f32>)
-  return
-}
-// CHECK-LABEL: func @dot(
-//  CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-//  CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref<?xf32, strided<[1], offset: ?>>,
-//  CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref<f32>) {
-//       CHECK:   %[[o0:.*]] = memref.cast %[[arg0]] :
-//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-//       CHECK:   %[[o1:.*]] = memref.cast %[[arg1]] :
-//  CHECK-SAME:     memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
-//       CHECK:   %[[o2:.*]] = memref.cast %[[arg2]] :
-//  CHECK-SAME:     memref<f32> to memref<f32, strided<[], offset: ?>>
-//       CHECK:   call @linalg_dot_viewsxf32_viewsxf32_viewf32(
-//  CHECK-SAME:     %[[o0]], %[[o1]], %[[o2]]) :
-//  CHECK-SAME:   memref<?xf32, strided<[?], offset: ?>>, memref<?xf32, strided<[?], offset: ?>>, memref<f32, strided<[], offset: ?>>
-
-// -----
-
-#matmul_accesses = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmul_trait = {
-  iterator_types = ["parallel", "parallel", "reduction"],
-  indexing_maps = #matmul_accesses,
-  library_call = "external_outerproduct_matmul"
-}
-
-!vector_type_A = vector<4xf32>
-!vector_type_B = vector<4xf32>
-!vector_type_C = vector<4x4xf32>
-
-!matrix_type_A = memref<?x?x!vector_type_A>
-!matrix_type_B = memref<?x?x!vector_type_B>
-!matrix_type_C = memref<?x?x!vector_type_C>
-
-func.func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
-  linalg.generic #matmul_trait
-      ins(%A, %B : !matrix_type_A, !matrix_type_B)
-     outs(%C : !matrix_type_C) {
-    ^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
-      %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
-      linalg.yield %d: !vector_type_C
-  }
-  return
-}
-// CHECK-LABEL: func @matmul_vec_impl(
-// CHECK:  call @external_outerproduct_matmul(%{{.*}}) :
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d0)>
-
-func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>)  {
-  // expected-error @below {{failed to legalize}}
-  %0 = linalg.generic {
-    indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]}
-  ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?xf32>) {
-  ^bb0(%in: f32, %out: f32): 
-    linalg.yield %in : f32
-  } -> tensor<?xf32>
-  return 
-}
-
-// -----
-
-func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // expected-error @below {{failed to legalize}}
-  %0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0 : tensor<4x8xf32>
-}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants