[mlir][bufferization] Use Type instead of Value in unknown conversion #144658


Merged

Conversation

andrey-golubev (Contributor)

Generally, bufferization should be able to create a memref from a tensor without needing to know more than just an mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of an mlir::Value. Additionally, apply the same rationale to the getMemRefType() helper function.

Both changes are prerequisites for enabling custom type support in one-shot bufferization.
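For downstream users, updating a custom converter is mechanical: the callback now receives the TensorType directly, so the cast from Value disappears. A minimal before/after sketch, assuming a hypothetical configureOptions() helper and the fully dynamic layout used by the default converter in this patch:

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

void configureOptions(BufferizationOptions &opts) {
  // Before this patch, the converter took a Value and extracted the type:
  //   opts.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
  //                                    const BufferizationOptions &options) {
  //     return getMemRefTypeWithFullyDynamicLayout(
  //         cast<TensorType>(value.getType()), memorySpace);
  //   };

  // After this patch, the TensorType is passed in directly.
  opts.unknownTypeConverterFn = [](TensorType tensorType, Attribute memorySpace,
                                   const BufferizationOptions &options) {
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
  };
}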

llvmbot (Member) commented Jun 18, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir-sparse

Author: Andrei Golubev (andrey-golubev)

Changes

Generally, bufferization should be able to create a memref from a tensor without needing to know more than just an mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of an mlir::Value. Additionally, apply the same rationale to the getMemRefType() helper function.

Both changes are prerequisites for enabling custom type support in one-shot bufferization.


Full diff: https://github.com/llvm/llvm-project/pull/144658.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+5-4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-10)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+3-3)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..2fb795f16ae2c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -265,9 +265,9 @@ struct BufferizationOptions {
       std::function<BaseMemRefType(TensorType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
   /// Tensor -> MemRef type converter.
-  /// Parameters: Value, memory space, bufferization options
+  /// Parameters: tensor type, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
-      Value, Attribute memorySpace, const BufferizationOptions &)>;
+      TensorType, Attribute memorySpace, const BufferizationOptions &)>;
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
       std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
   return newOp;
 }
 
-/// Return a MemRefType to which the type of the given value can be bufferized.
+/// Return a MemRefType to which the TensorType can be bufferized.
 ///
 /// If possible, op bufferization implementations should not use this function
 /// and instead infer precise memref types for tensor results by themselves.
@@ -667,7 +667,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
 /// Note: Canonicalization patterns could clean up layout maps and infer more
 /// precise layout maps after bufferization. However, many possible
 /// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
+BaseMemRefType getMemRefType(TensorType tensorType,
+                             const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},
                              Attribute memorySpace = nullptr);
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..dd43647682ea2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
 BaseMemRefType
-defaultUnknownTypeConverter(Value value, Attribute memorySpace,
+defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
                             const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(
-      llvm::cast<TensorType>(value.getType()), memorySpace);
+  return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
 }
 
 } // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  return getMemRefType(cast<TensorType>(value.getType()), options,
+                       /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
 // Bufferization-specific IRMapping support with debugging.
 //===----------------------------------------------------------------------===//
 
-BaseMemRefType bufferization::getMemRefType(Value value,
+BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
                                             Attribute memorySpace) {
-  auto tensorType = llvm::cast<TensorType>(value.getType());
-
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType =
           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
                            memorySpace);
   }
 
-  return options.unknownTypeConverterFn(value, memorySpace, options);
+  return options.unknownTypeConverterFn(tensorType, memorySpace, options);
 }
 
 BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
+  auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
   if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(value, options);
+    return bufferization::getMemRefType(tensorType, options);
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7681d309a4af..7e9b9119ce949 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
                   "'unknown-type-conversion'");
         return signalPassFailure();
       }
-      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
+      opt.unknownTypeConverterFn = [=](TensorType tensorType,
+                                       Attribute memorySpace,
                                        const BufferizationOptions &options) {
-        auto tensorType = cast<TensorType>(value.getType());
         if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
           return bufferization::getMemRefTypeWithStaticIdentityLayout(
               tensorType, memorySpace);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index a3ab53d818115..15e5102462ad7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
   OneShotBufferizationOptions options;
   options.bufferizeFunctionBoundaries = true;
   options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
-  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
+  options.unknownTypeConverterFn = [](TensorType tensorType,
+                                      Attribute memorySpace,
                                       const BufferizationOptions &options) {
-    return getMemRefTypeWithStaticIdentityLayout(
-        cast<TensorType>(value.getType()), memorySpace);
+    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
   };
   if (analysisOnly) {
     options.testAnalysisOnly = true;
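Call sites of getMemRefType() migrate the same way: the cast<TensorType> moves out of the helper and into the caller, as the hunks above show. A hypothetical downstream wrapper (the name getBufferTypeFor is an assumption, not part of the patch):

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;

static BaseMemRefType
getBufferTypeFor(Value value,
                 const bufferization::BufferizationOptions &options,
                 Attribute memorySpace) {
  // The value is expected to be tensor-typed; the cast now happens here,
  // not inside getMemRefType().
  auto tensorType = cast<TensorType>(value.getType());
  return bufferization::getMemRefType(tensorType, options,
                                      /*layout=*/{}, memorySpace);
}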


@@ -265,9 +265,9 @@ struct BufferizationOptions {
       std::function<BaseMemRefType(TensorType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
   /// Tensor -> MemRef type converter.
-  /// Parameters: Value, memory space, bufferization options
+  /// Parameters: tensor type, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
andrey-golubev (Contributor, Author)

as a bonus: this is now aligned with FunctionArgTypeConverterFn
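
For reference, the two aliases now take the same leading parameters. Both declarations, as they appear in the header diff above, side by side:

using FunctionArgTypeConverterFn =
    std::function<BaseMemRefType(TensorType, Attribute memorySpace,
                                 func::FuncOp, const BufferizationOptions &)>;
using UnknownTypeConverterFn = std::function<BaseMemRefType(
    TensorType, Attribute memorySpace, const BufferizationOptions &)>;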

andrey-golubev (Contributor, Author)

@matthias-springer I guess if this looks reasonable, could you merge as well? (I don't have the rights).

matthias-springer merged commit a1c2a71 into llvm:main Jun 18, 2025
11 checks passed
andrey-golubev deleted the bufferization_tensor_type branch June 18, 2025 12:45
fschlimb pushed a commit to fschlimb/llvm-project that referenced this pull request Jun 18, 2025
[mlir][bufferization] Use Type instead of Value in unknown conversion (llvm#144658)

Labels: mlir, mlir:bufferization, mlir:sparse