diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..2fb795f16ae2c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -265,9 +265,9 @@ struct BufferizationOptions {
       std::function<BaseMemRefType(TensorType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
   /// Tensor -> MemRef type converter.
-  /// Parameters: Value, memory space, bufferization options
+  /// Parameters: tensor type, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
-      Value, Attribute memorySpace, const BufferizationOptions &)>;
+      TensorType, Attribute memorySpace, const BufferizationOptions &)>;
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
       std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
   return newOp;
 }
 
-/// Return a MemRefType to which the type of the given value can be bufferized.
+/// Return a MemRefType to which the TensorType can be bufferized.
 ///
 /// If possible, op bufferization implementations should not use this function
 /// and instead infer precise memref types for tensor results by themselves.
@@ -667,7 +667,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
 /// Note: Canonicalization patterns could clean up layout maps and infer more
 /// precise layout maps after bufferization. However, many possible
 /// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
+BaseMemRefType getMemRefType(TensorType tensorType,
+                             const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},
                              Attribute memorySpace = nullptr);
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..dd43647682ea2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
 BaseMemRefType
-defaultUnknownTypeConverter(Value value, Attribute memorySpace,
+defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
                             const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(
-      llvm::cast<TensorType>(value.getType()), memorySpace);
+  return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
 }
 
 } // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  return getMemRefType(cast<TensorType>(value.getType()), options,
+                       /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
 // Bufferization-specific IRMapping support with debugging.
 //===----------------------------------------------------------------------===//
 
-BaseMemRefType bufferization::getMemRefType(Value value,
+BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
                                             Attribute memorySpace) {
-  auto tensorType = llvm::cast<TensorType>(value.getType());
-
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType =
           llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
                            memorySpace);
   }
 
-  return options.unknownTypeConverterFn(value, memorySpace, options);
+  return options.unknownTypeConverterFn(tensorType, memorySpace, options);
 }
 
 BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
   assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
+  auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
   if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(value, options);
+    return bufferization::getMemRefType(tensorType, options);
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7681d309a4af..7e9b9119ce949 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
                        "'unknown-type-conversion'");
         return signalPassFailure();
       }
-      opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
+      opt.unknownTypeConverterFn = [=](TensorType tensorType,
+                                       Attribute memorySpace,
                                        const BufferizationOptions &options) {
-        auto tensorType = cast<TensorType>(value.getType());
         if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
           return bufferization::getMemRefTypeWithStaticIdentityLayout(
               tensorType, memorySpace);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index a3ab53d818115..15e5102462ad7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
   OneShotBufferizationOptions options;
   options.bufferizeFunctionBoundaries = true;
   options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
-  options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
+  options.unknownTypeConverterFn = [](TensorType tensorType,
+                                      Attribute memorySpace,
                                       const BufferizationOptions &options) {
-    return getMemRefTypeWithStaticIdentityLayout(
-        cast<TensorType>(value.getType()), memorySpace);
+    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
   };
   if (analysisOnly) {
     options.testAnalysisOnly = true;
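Note for downstream users (not part of the patch above): any out-of-tree code that installs a custom unknownTypeConverterFn must switch from the old Value-based callback signature to the new TensorType-based one. A minimal sketch under those assumptions follows; the helper name configureUnknownTypeConversion is hypothetical, and the lambda body simply mirrors the in-tree default behavior shown in the diff.

// Sketch only: illustrates adapting a custom unknown-type converter to the
// new TensorType-based signature. "configureUnknownTypeConversion" is a
// hypothetical helper, not part of this patch.
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

static void configureUnknownTypeConversion(BufferizationOptions &options) {
  // The callback now receives the TensorType directly, so the old
  // cast<TensorType>(value.getType()) step is no longer needed.
  options.unknownTypeConverterFn = [](TensorType tensorType,
                                      Attribute memorySpace,
                                      const BufferizationOptions &opts) {
    // Same behavior as the in-tree default: use a fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
  };
}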