-
Notifications
You must be signed in to change notification settings - Fork 14k
[mlir][bufferization] Use Type instead of Value in unknown conversion #144658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][bufferization] Use Type instead of Value in unknown conversion #144658
Conversation
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir-sparse Author: Andrei Golubev (andrey-golubev) ChangesGenerally, bufferization should be able to create a memref from a tensor without needing to know more than just a mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of mlir::Value. Additionally, apply the same rationale to getMemRefType() helper function. Both changes are prerequisites to enable custom types support in one-shot bufferization. Full diff: https://github.com/llvm/llvm-project/pull/144658.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..2fb795f16ae2c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -265,9 +265,9 @@ struct BufferizationOptions {
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
- /// Parameters: Value, memory space, bufferization options
+ /// Parameters: tensor type, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
- Value, Attribute memorySpace, const BufferizationOptions &)>;
+ TensorType, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}
-/// Return a MemRefType to which the type of the given value can be bufferized.
+/// Return a MemRefType to which the TensorType can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
@@ -667,7 +667,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
+BaseMemRefType getMemRefType(TensorType tensorType,
+ const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = nullptr);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..dd43647682ea2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
-defaultUnknownTypeConverter(Value value, Attribute memorySpace,
+defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(
- llvm::cast<TensorType>(value.getType()), memorySpace);
+ return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
}
} // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ return getMemRefType(cast<TensorType>(value.getType()), options,
+ /*layout=*/{}, *memSpace);
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
// Bufferization-specific IRMapping support with debugging.
//===----------------------------------------------------------------------===//
-BaseMemRefType bufferization::getMemRefType(Value value,
+BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
- auto tensorType = llvm::cast<TensorType>(value.getType());
-
// Case 1: Unranked memref type.
if (auto unrankedTensorType =
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
memorySpace);
}
- return options.unknownTypeConverterFn(value, memorySpace, options);
+ return options.unknownTypeConverterFn(tensorType, memorySpace, options);
}
BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
+ auto tensorType = cast<TensorType>(value.getType());
// No further analysis is possible for a block argument.
if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(value, options);
+ return bufferization::getMemRefType(tensorType, options);
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7681d309a4af..7e9b9119ce949 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
"'unknown-type-conversion'");
return signalPassFailure();
}
- opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
+ opt.unknownTypeConverterFn = [=](TensorType tensorType,
+ Attribute memorySpace,
const BufferizationOptions &options) {
- auto tensorType = cast<TensorType>(value.getType());
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(
tensorType, memorySpace);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index a3ab53d818115..15e5102462ad7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
OneShotBufferizationOptions options;
options.bufferizeFunctionBoundaries = true;
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
- options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
+ options.unknownTypeConverterFn = [](TensorType tensorType,
+ Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithStaticIdentityLayout(
- cast<TensorType>(value.getType()), memorySpace);
+ return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
};
if (analysisOnly) {
options.testAnalysisOnly = true;
|
@llvm/pr-subscribers-mlir Author: Andrei Golubev (andrey-golubev) ChangesGenerally, bufferization should be able to create a memref from a tensor without needing to know more than just a mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of mlir::Value. Additionally, apply the same rationale to getMemRefType() helper function. Both changes are prerequisites to enable custom types support in one-shot bufferization. Full diff: https://github.com/llvm/llvm-project/pull/144658.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..2fb795f16ae2c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -265,9 +265,9 @@ struct BufferizationOptions {
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
- /// Parameters: Value, memory space, bufferization options
+ /// Parameters: tensor type, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
- Value, Attribute memorySpace, const BufferizationOptions &)>;
+ TensorType, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}
-/// Return a MemRefType to which the type of the given value can be bufferized.
+/// Return a MemRefType to which the TensorType can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
@@ -667,7 +667,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
+BaseMemRefType getMemRefType(TensorType tensorType,
+ const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = nullptr);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..dd43647682ea2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
-defaultUnknownTypeConverter(Value value, Attribute memorySpace,
+defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(
- llvm::cast<TensorType>(value.getType()), memorySpace);
+ return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
}
} // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ return getMemRefType(cast<TensorType>(value.getType()), options,
+ /*layout=*/{}, *memSpace);
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
// Bufferization-specific IRMapping support with debugging.
//===----------------------------------------------------------------------===//
-BaseMemRefType bufferization::getMemRefType(Value value,
+BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
- auto tensorType = llvm::cast<TensorType>(value.getType());
-
// Case 1: Unranked memref type.
if (auto unrankedTensorType =
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
memorySpace);
}
- return options.unknownTypeConverterFn(value, memorySpace, options);
+ return options.unknownTypeConverterFn(tensorType, memorySpace, options);
}
BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
+ auto tensorType = cast<TensorType>(value.getType());
// No further analysis is possible for a block argument.
if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(value, options);
+ return bufferization::getMemRefType(tensorType, options);
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7681d309a4af..7e9b9119ce949 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
"'unknown-type-conversion'");
return signalPassFailure();
}
- opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
+ opt.unknownTypeConverterFn = [=](TensorType tensorType,
+ Attribute memorySpace,
const BufferizationOptions &options) {
- auto tensorType = cast<TensorType>(value.getType());
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(
tensorType, memorySpace);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index a3ab53d818115..15e5102462ad7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
OneShotBufferizationOptions options;
options.bufferizeFunctionBoundaries = true;
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
- options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
+ options.unknownTypeConverterFn = [](TensorType tensorType,
+ Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithStaticIdentityLayout(
- cast<TensorType>(value.getType()), memorySpace);
+ return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
};
if (analysisOnly) {
options.testAnalysisOnly = true;
|
@@ -265,9 +265,9 @@ struct BufferizationOptions { | |||
std::function<BaseMemRefType(TensorType, Attribute memorySpace, | |||
func::FuncOp, const BufferizationOptions &)>; | |||
/// Tensor -> MemRef type converter. | |||
/// Parameters: Value, memory space, bufferization options | |||
/// Parameters: tensor type, memory space, bufferization options | |||
using UnknownTypeConverterFn = std::function<BaseMemRefType( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as a bonus: this is now aligned with FunctionArgTypeConverterFn
Generally, bufferization should be able to create a memref from a tensor without needing to know more than just a mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of mlir::Value. Additionally, apply the same rationale to getMemRefType() helper function. Both changes are prerequisites to enable custom types support in one-shot bufferization.
@matthias-springer I guess if this looks reasonable, could you merge as well? (I don't have the rights). |
…llvm#144658) Generally, bufferization should be able to create a memref from a tensor without needing to know more than just a mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of mlir::Value. Additionally, apply the same rationale to getMemRefType() helper function. Both changes are prerequisites to enable custom types support in one-shot bufferization.
Generally, bufferization should be able to create a memref from a tensor without needing to know more than just a mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of mlir::Value. Additionally, apply the same rationale to getMemRefType() helper function.
Both changes are prerequisites to enable custom types support in one-shot bufferization.