Skip to content

[mlir][bufferization] Use Type instead of Value in unknown conversion #144658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ struct BufferizationOptions {
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, bufferization options
/// Parameters: tensor type, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a bonus: this is now aligned with FunctionArgTypeConverterFn

Value, Attribute memorySpace, const BufferizationOptions &)>;
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
Expand Down Expand Up @@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}

/// Return a MemRefType to which the type of the given value can be bufferized.
/// Return a MemRefType to which the TensorType can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
Expand All @@ -667,7 +667,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
BaseMemRefType getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = nullptr);

Expand Down
19 changes: 9 additions & 10 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithFullyDynamicLayout(
llvm::cast<TensorType>(value.getType()), memorySpace);
return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
}

} // namespace
Expand Down Expand Up @@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
if (!memSpace.has_value())
return op->emitError("could not infer memory space");

return getMemRefType(value, options, /*layout=*/{}, *memSpace);
return getMemRefType(cast<TensorType>(value.getType()), options,
/*layout=*/{}, *memSpace);
}

bool bufferization::hasTensorSemantics(Operation *op) {
Expand Down Expand Up @@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
// Bufferization-specific IRMapping support with debugging.
//===----------------------------------------------------------------------===//

BaseMemRefType bufferization::getMemRefType(Value value,
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
auto tensorType = llvm::cast<TensorType>(value.getType());

// Case 1: Unranked memref type.
if (auto unrankedTensorType =
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
Expand All @@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
memorySpace);
}

return options.unknownTypeConverterFn(value, memorySpace, options);
return options.unknownTypeConverterFn(tensorType, memorySpace, options);
}

BaseMemRefType
Expand Down Expand Up @@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
auto tensorType = cast<TensorType>(value.getType());

// No further analysis is possible for a block argument.
if (llvm::isa<BlockArgument>(value))
return bufferization::getMemRefType(value, options);
return bufferization::getMemRefType(tensorType, options);

// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
Expand All @@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");

return getMemRefType(value, options, /*layout=*/{}, *memSpace);
return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
}

bool bufferization::detail::defaultIsRepetitiveRegion(
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ struct OneShotBufferizePass
"'unknown-type-conversion'");
return signalPassFailure();
}
opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
opt.unknownTypeConverterFn = [=](TensorType tensorType,
Attribute memorySpace,
const BufferizationOptions &options) {
auto tensorType = cast<TensorType>(value.getType());
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(
tensorType, memorySpace);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
OneShotBufferizationOptions options;
options.bufferizeFunctionBoundaries = true;
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
options.unknownTypeConverterFn = [](TensorType tensorType,
Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
cast<TensorType>(value.getType()), memorySpace);
return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
};
if (analysisOnly) {
options.testAnalysisOnly = true;
Expand Down
Loading