
Commit 335cf5f

[stablehlo] support aten_adaptive_max_pool1d lowering (#3728)
1 parent aa7e77e commit 335cf5f

6 files changed, +286 -1 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 29 additions & 0 deletions
@@ -7078,6 +7078,35 @@ def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [
   }];
 }
 
+def Torch_AtenMaxPool1dWithIndicesOp : Torch_Op<"aten.max_pool1d_with_indices", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$kernel_size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchListOfTorchIntType:$padding,
+    AnyTorchListOfTorchIntType:$dilation,
+    Torch_BoolType:$ceil_mode
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result0,
+    AnyTorchOptionalTensorType:$result1
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMaxPool1dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 2);
+    }
+    void AtenMaxPool1dWithIndicesOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 2);
+    }
+  }];
+}
+
 def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   AllowsTypeRefinement,
   HasValueSemantics,
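
For reference, the generated op mirrors the ATen signature named in its summary: four list-typed size arguments plus `ceil_mode`, returning both the pooled values and the indices of the selected elements. A minimal plain-PyTorch illustration of that contract (not part of this commit; `F.max_pool1d` with `return_indices=True` dispatches to `aten::max_pool1d_with_indices`):

```python
# Illustrative only: the (values, indices) contract of
# aten::max_pool1d_with_indices, exercised through the functional API.
import torch
import torch.nn.functional as F

x = torch.arange(8, dtype=torch.float32).reshape(1, 1, 8)
values, indices = F.max_pool1d(
    x, kernel_size=2, stride=2, padding=0, dilation=1,
    ceil_mode=False, return_indices=True)
print(values)   # tensor([[[1., 3., 5., 7.]]])
print(indices)  # tensor([[[1, 3, 5, 7]]])
```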

lib/Conversion/TorchToStablehlo/Pooling.cpp

Lines changed: 157 additions & 1 deletion
@@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
 
   // Max pooling
   if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
-          AtenMaxPool2dWithIndicesOp>(op)) {
+          AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       auto constAttr = DenseElementsAttr::get(
           constType,
@@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
   return nullptr;
 }
 
+// AtenMaxPool1dWithIndicesOp
+template <>
+LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
+    AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = cast<RankedTensorType>(input.getType());
+  auto inputElemTy = inputTy.getElementType();
+  auto inputShape = inputTy.getShape();
+  auto inputRank = inputTy.getRank();
+
+  auto outValTy =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
+  auto outIdxTy =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
+
+  if (inputRank <= 1) {
+    return op.emitError(
+        "max_pool1d only supports inputs with rank higher than 1");
+  }
+
+  SmallVector<int64_t, 1> padding, kernelSize, stride, dilation;
+  bool ceilMode = false;
+
+  if (!(matchPattern(op.getKernelSize(),
+                     m_TorchListOfConstantInts(kernelSize)))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const int kernel size unsupported!");
+  }
+  if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
+    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
+  }
+  if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int padding unsupported!");
+  }
+  if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const int dilation unsupported!");
+  }
+  if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const bool ceil_mode unsupported!");
+  }
+
+  SmallVector<int64_t> stablehloStride(inputRank, 1);
+  SmallVector<int64_t> stablehloDilation(inputRank, 1);
+  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
+  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
+
+  std::copy(stride.begin(), stride.end(),
+            stablehloStride.begin() + inputRank - 1);
+  std::copy(dilation.begin(), dilation.end(),
+            stablehloDilation.begin() + inputRank - 1);
+  std::copy(kernelSize.begin(), kernelSize.end(),
+            stablehloKernelSize.begin() + inputRank - 1);
+  stablehloPadding[stablehloPadding.size() - 1] = padding[0];
+  stablehloPadding[stablehloPadding.size() - 2] = padding[0];
+
+  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
+
+  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
+  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
+  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
+  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
+      RankedTensorType::get(
+          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
+          rewriter.getI64Type()),
+      stablehloPadding);
+  DenseI64ArrayAttr baseDilations;
+
+  auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
+  if (failed(inputShapeInfo)) {
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  }
+  auto inputShapeVec = *inputShapeInfo;
+  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
+      op->getLoc(), inputShapeVec);
+
+  // No reshape is needed for max_pool1d; the iota runs along the pooled
+  // (last) dimension, i.e. dim = inputRank - 1.
+  auto indexTensor =
+      rewriter
+          .create<stablehlo::DynamicIotaOp>(
+              op->getLoc(),
+              RankedTensorType::get(inputShape, rewriter.getI64Type()),
+              inputShapeTensor, static_cast<uint64_t>(inputRank - 1))
+          .getResult();
+  Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
+
+  auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
+      op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
+      mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
+      windowDimensions, windowStrides, baseDilations, windowDilations, pad);
+
+  // Add the reduction block.
+  Block &block = reduceWindowOp.getBody().emplaceBlock();
+  auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
+  auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
+  auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
+  block.addArgument(blockValArgumentType, op->getLoc());
+  block.addArgument(blockIdxArgumentType, op->getLoc());
+  block.addArgument(blockValArgumentType, op->getLoc());
+  block.addArgument(blockIdxArgumentType, op->getLoc());
+  auto *firstValArg = block.args_begin();
+  auto *firstIdxArg = std::next(firstValArg);
+  auto *secondValArg = std::next(firstIdxArg);
+  auto *secondIdxArg = std::next(secondValArg);
+
+  stablehlo::ComparisonTypeAttr compareTypeAttr;
+  if (isa<mlir::FloatType>(inputTy.getElementType())) {
+    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
+  } else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
+    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
+        rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
+  }
+
+  stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::GE);
+  stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
+      stablehlo::ComparisonDirectionAttr::get(
+          rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+
+    Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
+        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+        compareGeDirectionAttr, compareTypeAttr);
+    Value retValResult = rewriter.create<stablehlo::SelectOp>(
+        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);
+
+    // Get the smaller index if the compared values are equal.
+    Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
+        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
+        compareEqDirectionAttr, compareTypeAttr);
+    Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
+                                                     *secondIdxArg);
+    Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
+        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
+    Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
+        op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
+
+    rewriter.create<stablehlo::ReturnOp>(
+        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
+  }
+
+  rewriter.replaceOp(op, reduceWindowOp.getResults());
+  return success();
+}
+
 // AtenMaxPool2dWithIndicesOp
 template <>
 LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
 #define INSERT_ATEN_POOLING_PATTERN(AtenOp)                                    \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
+  INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp);
   INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
   INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
 #undef INSERT_ATEN_POOLING_PATTERN
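
The lowering pairs every input element with its position along the pooled (last) dimension via `stablehlo.dynamic_iota`, then folds `(value, index)` pairs through the `reduce_window` body: keep the greater value, and on a value tie keep the smaller index, so the first maximum wins, as in PyTorch. A small pure-Python model of that reduction body (illustrative only, not torch-mlir code; `reduce_pair` and `max_pool1d_window` are hypothetical helper names):

```python
# Model of the reduce_window body: fold (value, index) pairs, keeping the
# greater value and, on a value tie, the smaller index.
from functools import reduce


def reduce_pair(a, b):
    (v1, i1), (v2, i2) = a, b
    ge = v1 >= v2            # stablehlo.compare, direction GE
    val = v1 if ge else v2   # stablehlo.select on the values
    if v1 == v2:             # stablehlo.compare, direction EQ
        idx = min(i1, i2)    # stablehlo.minimum on the indices
    else:
        idx = i1 if ge else i2
    return (val, idx)


def max_pool1d_window(window):
    # Pair each element with its position (the dynamic_iota tensor), then
    # fold with the reduction body, starting from the inits (-inf, 0).
    pairs = zip(window, range(len(window)))
    return reduce(reduce_pair, pairs, (float("-inf"), 0))


print(max_pool1d_window([3.0, 7.0, 7.0, 1.0]))  # (7.0, 1): tie keeps index 1
```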

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 80 additions & 0 deletions
@@ -7298,6 +7298,85 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
 };
 } // namespace
 
+namespace {
+// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
+// op.
+class DecomposeAtenAdaptiveMaxPool1dOp
+    : public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
+  using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getSelf();
+    std::optional<unsigned> maybeRank = getTensorRank(input);
+    if (!maybeRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    }
+    unsigned rank = *maybeRank;
+    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(rank - 1));
+    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
+
+    Value outputShape = op.getOutputSize();
+    SmallVector<Value> outputShapeSizesTorchInt;
+    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+    Value outputSize = outputShapeSizesTorchInt[0];
+
+    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+
+    int64_t outputSizeInt;
+    if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
+      return rewriter.notifyMatchFailure(
+          op, "the output size of adaptive_max_pool1d must be a constant int");
+    }
+
+    SmallVector<Value, 1> kernelSize;
+    if (outputSizeInt == 1) {
+      BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
+      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+      kernelSize.push_back(
+          inputShape[rank - 1] == kUnknownSize
+              ? inputSize
+              : rewriter.create<Torch::ConstantIntOp>(
+                    loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
+    } else {
+      if (!isAssumingStrictSymbolicShapes(rewriter)) {
+        Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
+        rewriter.create<RuntimeAssertOp>(
+            loc, cond,
+            "unimplemented: only support cases where input and output size "
+            "are equal for non-unit output size");
+      }
+      kernelSize.push_back(constantOne);
+    }
+
+    Value kernelSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
+    Value strideList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantOne});
+    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantZero});
+    Value dilationList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        ValueRange{constantOne});
+
+    rewriter.replaceOpWithNewOp<AtenMaxPool1dWithIndicesOp>(
+        op, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
+        paddingSizeList, dilationList,
+        /*ceil_mode=*/constantFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.

@@ -9801,6 +9880,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
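
The decomposition handles two cases: for `output_size == 1` the kernel spans the whole pooled dimension, and otherwise it asserts that input and output sizes match and uses a unit kernel. A plain-PyTorch sanity check of the first case, assuming a static length of 7 (illustrative; the pattern itself emits `aten.max_pool1d_with_indices` in Torch dialect IR):

```python
# adaptive_max_pool1d with output_size=1 is a max over the whole last
# dimension, i.e. max_pool1d_with_indices with kernel_size equal to the
# input length, stride 1, no padding, unit dilation.
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 7)
a_vals, a_idx = F.adaptive_max_pool1d(x, output_size=1, return_indices=True)
p_vals, p_idx = F.max_pool1d(x, kernel_size=7, stride=1, padding=0,
                             dilation=1, ceil_mode=False, return_indices=True)
assert torch.equal(a_vals, p_vals) and torch.equal(a_idx, p_idx)
```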

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -588,6 +588,7 @@
     "AdaptiveAvgPool3dDynamic_basic",
     "AdaptiveMaxPool1dDynamicNoBatch_basic",
     "AdaptiveMaxPool1dDynamic_basic",
+    "AdaptiveMaxPool1dDimOneStatic_basic",
     "AdaptiveMaxPool1dStatic_basic",
     "AdaptiveMaxPool2dDynamicNoBatch_basic",
     "AdaptiveMaxPool2dDynamicWithIndices_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
@@ -617,6 +617,9 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
     )
     emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
+    emit(
+        "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
+    )
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
     emit(

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 16 additions & 0 deletions
@@ -1783,6 +1783,22 @@ def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 512, 10))
 
 
+class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(1), return_indices=False)
+
+    @export
+    @annotate_args([None, ([1, 512, 7], torch.float32, True)])
+    def forward(self, x):
+        return self.amp1d(x)
+
+
+@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDimOneStatic())
+def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7))
+
+
 # AdaptiveMaxPool2d
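Restated in plain PyTorch, the new e2e case checks that `AdaptiveMaxPool1d(1)` on a static `[1, 512, 7]` input collapses the last dimension; a hedged sketch of the equivalent behavior (illustrative, not the e2e harness itself):

```python
# AdaptiveMaxPool1d with output_size=1 is the same reduction as amax over
# the last dimension with keepdim=True.
import torch

x = torch.rand(1, 512, 7)
out = torch.nn.AdaptiveMaxPool1d(output_size=1)(x)
assert out.shape == (1, 512, 1)
assert torch.equal(out, x.amax(dim=-1, keepdim=True))
```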