[torch] Add integer support for pooling operations (llvm#3610)
If we pass an integer type to the pooling operation, we incorrectly compute the
padding value, which causes downstream compilation failures.
rsuderman authored Aug 8, 2024
1 parent 7f2a17e commit 4350672
Showing 1 changed file with 18 additions and 12 deletions.
lib/Conversion/TorchToLinalg/Pooling.cpp: 18 additions & 12 deletions
@@ -361,18 +361,29 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
 
     Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
 
+    TypedAttr smallestValueAttr;
+
+    if (auto fpty = dyn_cast<mlir::FloatType>(elementType)) {
+      smallestValueAttr = rewriter.getFloatAttr(
+          elementType,
+          APFloat::getInf(fpty.getFloatSemantics(), /*Negative=*/true));
+    } else if (auto intTy = dyn_cast<mlir::IntegerType>(elementType)) {
+      int64_t bw = intTy.getIntOrFloatBitWidth();
+      smallestValueAttr = rewriter.getIntegerAttr(
+          elementType, intTy.isUnsigned() ? APInt::getMinValue(bw)
+                                          : APInt::getSignedMinValue(bw));
+    }
+
+    if (!smallestValueAttr)
+      return rewriter.notifyMatchFailure(op, "invalid element type");
+
     if constexpr (Dim == 1) {
       SmallVector<Value, 4> outTensorShape;
       Value maxPool1d, paddedInput;
-      TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
-          elementType,
-          APFloat::getInf(
-              cast<mlir::FloatType>(elementType).getFloatSemantics(),
-              /*Negative=*/true));
       if (failed(createPoolingOp<linalg::PoolingNcwMaxOp>(
               op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
               /*dimensionality=*/1, kernelSizeIntValues, strideInts,
-              paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
+              paddingInts, dilationInts, smallestValueAttr, outTensorShape,
               paddedInput, maxPool1d)))
         return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d");
       Type newResultType = this->getTypeConverter()->convertType(op.getType());
@@ -382,15 +393,10 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
       SmallVector<Value, 4> outTensorShape;
       // `maxpool2d` contains the result of maxpool2d operation over the input.
       Value maxPool2d, paddedInput;
-      TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
-          elementType,
-          APFloat::getInf(
-              cast<mlir::FloatType>(elementType).getFloatSemantics(),
-              /*Negative=*/true));
       if (failed(createPoolingOp<linalg::PoolingNchwMaxOp>(
               op, rewriter, self, /*supportNonFPInput=*/true, ceilMode,
               /*dimensionality=*/2, kernelSizeIntValues, strideInts,
-              paddingInts, dilationInts, smallestFPValueAttr, outTensorShape,
+              paddingInts, dilationInts, smallestValueAttr, outTensorShape,
               paddedInput, maxPool2d)))
         return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d");
       Type newResultType = this->getTypeConverter()->convertType(op.getType());
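For context, the change above picks the padding value from the element type rather than always using a float: max pooling pads with the identity of the max operation, which is negative infinity for floating-point element types and the smallest representable value (the unsigned minimum for unsigned integers, the signed minimum otherwise) for integer element types. The following is a minimal, standalone sketch of that value selection using LLVM's APFloat/APInt helpers; it is illustrative only, not part of the patch, and assumes the LLVM ADT headers are available.

// Illustrative sketch only: shows the padding/identity values the patch
// selects, computed directly with LLVM's APFloat/APInt utilities.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include <iostream>

int main() {
  // Float element types pad with -inf, the identity of "max".
  llvm::APFloat negInf =
      llvm::APFloat::getInf(llvm::APFloat::IEEEsingle(), /*Negative=*/true);

  // Integer element types pad with the smallest representable value:
  // the unsigned minimum (0) for unsigned types, the signed minimum otherwise.
  llvm::APInt si8Min = llvm::APInt::getSignedMinValue(8); // -128 for si8
  llvm::APInt ui8Min = llvm::APInt::getMinValue(8);       // 0 for ui8

  std::cout << "f32 pad: " << negInf.convertToFloat() << "\n"
            << "si8 pad: " << si8Min.getSExtValue() << "\n"
            << "ui8 pad: " << ui8Min.getZExtValue() << "\n";
  return 0;
}

Using the signed minimum rather than zero for signed integer types matters because otherwise the padded region could win the max reduction for all-negative inputs.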
