@@ -52,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
52
52
53
53
// Max pooling
54
54
if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
55
- AtenMaxPool2dWithIndicesOp>(op)) {
55
+ AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) {
56
56
if (isa<mlir::FloatType>(elementTy)) {
57
57
auto constAttr = DenseElementsAttr::get (
58
58
constType,
@@ -73,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
73
73
return nullptr ;
74
74
}
75
75
76
+ // AtenMaxPool1dWithIndicesOp
77
+ template <>
78
+ LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
79
+ AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor,
80
+ ConversionPatternRewriter &rewriter) const {
81
+ Value input = adaptor.getSelf ();
82
+ auto inputTy = cast<RankedTensorType>(input.getType ());
83
+ auto inputElemTy = inputTy.getElementType ();
84
+ auto inputShape = inputTy.getShape ();
85
+ auto inputRank = inputTy.getRank ();
86
+
87
+ auto outValTy =
88
+ cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType (0 )));
89
+ auto outIdxTy =
90
+ cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType (1 )));
91
+
92
+ if (inputRank <= 1 ) {
93
+ return op.emitError (
94
+ " max_pooling1d only supports inputs with rank higher than 1" );
95
+ }
96
+
97
+ SmallVector<int64_t , 1 > padding, kernelSize, stride, dilation;
98
+ bool ceilMode = false ;
99
+
100
+ if (!(matchPattern (op.getKernelSize (),
101
+ m_TorchListOfConstantInts (kernelSize)))) {
102
+ return rewriter.notifyMatchFailure (
103
+ op, " non-const int kernel size unsupported!" );
104
+ }
105
+ if (!(matchPattern (op.getStride (), m_TorchListOfConstantInts (stride)))) {
106
+ return rewriter.notifyMatchFailure (op, " non-const int stride unsupported!" );
107
+ }
108
+ if (!(matchPattern (op.getPadding (), m_TorchListOfConstantInts (padding)))) {
109
+ return rewriter.notifyMatchFailure (op,
110
+ " non-const int padding unsupported!" );
111
+ }
112
+ if (!(matchPattern (op.getDilation (), m_TorchListOfConstantInts (dilation)))) {
113
+ return rewriter.notifyMatchFailure (op,
114
+ " non-const int dilation unsupported!" );
115
+ }
116
+ if (!(matchPattern (op.getCeilMode (), m_TorchConstantBool (&ceilMode)))) {
117
+ return rewriter.notifyMatchFailure (op,
118
+ " non-const bool ceil_mode unsupported!" );
119
+ }
120
+
121
+ SmallVector<int64_t > stablehloStride (inputRank, 1 );
122
+ SmallVector<int64_t > stablehloDilation (inputRank, 1 );
123
+ SmallVector<int64_t > stablehloKernelSize (inputRank, 1 );
124
+ SmallVector<int64_t > stablehloPadding (inputRank * 2 , 0 );
125
+
126
+ std::copy (stride.begin (), stride.end (),
127
+ stablehloStride.begin () + inputRank - 1 );
128
+ std::copy (dilation.begin (), dilation.end (),
129
+ stablehloDilation.begin () + inputRank - 1 );
130
+ std::copy (kernelSize.begin (), kernelSize.end (),
131
+ stablehloKernelSize.begin () + inputRank - 1 );
132
+ stablehloPadding[stablehloPadding.size () - 1 ] = padding[0 ];
133
+ stablehloPadding[stablehloPadding.size () - 2 ] = padding[0 ];
134
+
135
+ Value initVal = createInitialValueForAtenPoolingOp (op, inputElemTy, rewriter);
136
+
137
+ auto windowDimensions = rewriter.getDenseI64ArrayAttr (stablehloKernelSize);
138
+ auto windowStrides = rewriter.getDenseI64ArrayAttr (stablehloStride);
139
+ auto windowDilations = rewriter.getDenseI64ArrayAttr (stablehloDilation);
140
+ DenseIntElementsAttr pad = DenseIntElementsAttr::get (
141
+ RankedTensorType::get (
142
+ {static_cast <int64_t >(inputRank), static_cast <int64_t >(2 )},
143
+ rewriter.getI64Type ()),
144
+ stablehloPadding);
145
+ DenseI64ArrayAttr baseDilations;
146
+
147
+ auto inputShapeInfo = hlo::getDimIndexOfTensor (rewriter, op, input);
148
+ if (failed (inputShapeInfo)) {
149
+ return rewriter.notifyMatchFailure (
150
+ op, " failed to get dimension sizes of the input" );
151
+ }
152
+ auto inputShapeVec = *inputShapeInfo;
153
+ auto inputShapeTensor = rewriter.create <mlir::tensor::FromElementsOp>(
154
+ op->getLoc (), inputShapeVec);
155
+
156
+ // no need to reshape here for max_pool_1d. Need to make sure the iota
157
+ // dimension. dim=inputRank-2 or dim=inputRank-1?
158
+ auto indexTensor =
159
+ rewriter
160
+ .create <stablehlo::DynamicIotaOp>(
161
+ op->getLoc (),
162
+ RankedTensorType::get (inputShape, rewriter.getI64Type ()),
163
+ inputShapeTensor, static_cast <uint64_t >(inputRank - 1 ))
164
+ .getResult ();
165
+ Value initIdx = hlo::getConstTensor<int64_t >(rewriter, op, {0 }, {}).value ();
166
+
167
+ auto reduceWindowOp = rewriter.create <stablehlo::ReduceWindowOp>(
168
+ op->getLoc (), mlir::TypeRange{outValTy, outIdxTy},
169
+ mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
170
+ windowDimensions, windowStrides, baseDilations, windowDilations, pad);
171
+
172
+ // add block.
173
+ Block &block = reduceWindowOp.getBody ().emplaceBlock ();
174
+ auto blockValArgumentType = RankedTensorType::get ({}, inputElemTy);
175
+ auto blockIdxArgumentType = RankedTensorType::get ({}, rewriter.getI64Type ());
176
+ auto compareResultType = RankedTensorType::get ({}, rewriter.getI1Type ());
177
+ block.addArgument (blockValArgumentType, op->getLoc ());
178
+ block.addArgument (blockIdxArgumentType, op->getLoc ());
179
+ block.addArgument (blockValArgumentType, op->getLoc ());
180
+ block.addArgument (blockIdxArgumentType, op->getLoc ());
181
+ auto *firstValArg = block.args_begin ();
182
+ auto *firstIdxArg = std::next (firstValArg);
183
+ auto *secondValArg = std::next (firstIdxArg);
184
+ auto *secondIdxArg = std::next (secondValArg);
185
+
186
+ stablehlo::ComparisonTypeAttr compareTypeAttr;
187
+ if (isa<mlir::FloatType>(inputTy.getElementType ())) {
188
+ compareTypeAttr = stablehlo::ComparisonTypeAttr::get (
189
+ rewriter.getContext (), stablehlo::ComparisonType::FLOAT);
190
+ } else if (isa<mlir::IntegerType>(inputTy.getElementType ())) {
191
+ compareTypeAttr = stablehlo::ComparisonTypeAttr::get (
192
+ rewriter.getContext (), stablehlo::ComparisonType::SIGNED);
193
+ }
194
+
195
+ stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
196
+ stablehlo::ComparisonDirectionAttr::get (
197
+ rewriter.getContext (), stablehlo::ComparisonDirection::GE);
198
+ stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
199
+ stablehlo::ComparisonDirectionAttr::get (
200
+ rewriter.getContext (), stablehlo::ComparisonDirection::EQ);
201
+
202
+ {
203
+ OpBuilder::InsertionGuard guard (rewriter);
204
+ rewriter.setInsertionPointToStart (&block);
205
+
206
+ Value compareGeResult = rewriter.create <stablehlo::CompareOp>(
207
+ op->getLoc (), compareResultType, *firstValArg, *secondValArg,
208
+ compareGeDirectionAttr, compareTypeAttr);
209
+ Value retValResult = rewriter.create <stablehlo::SelectOp>(
210
+ op->getLoc (), compareGeResult, *firstValArg, *secondValArg);
211
+
212
+ // Get smaller index if compared values are equal.
213
+ Value compareEqResult = rewriter.create <stablehlo::CompareOp>(
214
+ op->getLoc (), compareResultType, *firstValArg, *secondValArg,
215
+ compareEqDirectionAttr, compareTypeAttr);
216
+ Value minIdx = rewriter.create <stablehlo::MinOp>(op->getLoc (), *firstIdxArg,
217
+ *secondIdxArg);
218
+ Value idxWithGeVal = rewriter.create <stablehlo::SelectOp>(
219
+ op->getLoc (), compareGeResult, *firstIdxArg, *secondIdxArg);
220
+ Value retIdxResult = rewriter.create <stablehlo::SelectOp>(
221
+ op->getLoc (), compareEqResult, minIdx, idxWithGeVal);
222
+
223
+ rewriter.create <stablehlo::ReturnOp>(
224
+ op->getLoc (), mlir::ValueRange{retValResult, retIdxResult});
225
+ }
226
+
227
+ rewriter.replaceOp (op, reduceWindowOp.getResults ());
228
+ return success ();
229
+ }
230
+
76
231
// AtenMaxPool2dWithIndicesOp
77
232
template <>
78
233
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
@@ -657,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
657
812
#define INSERT_ATEN_POOLING_PATTERN (AtenOp ) \
658
813
target.addIllegalOp <AtenOp>(); \
659
814
patterns.add <ConvertAtenOp<AtenOp>>(typeConverter, context, options)
815
+ INSERT_ATEN_POOLING_PATTERN (AtenMaxPool1dWithIndicesOp);
660
816
INSERT_ATEN_POOLING_PATTERN (AtenMaxPool2dWithIndicesOp);
661
817
INSERT_ATEN_POOLING_PATTERN (AtenCumsumOp);
662
818
#undef INSERT_ATEN_POOLING_PATTERN
0 commit comments