Commit abb9282

zezhangqingyunqu authored and committed
Add canonicalize pattern for aten.mul.int and aten.floordiv.int (#3680)
This PR adds `floordiv` to the `PY_BUILTIN_TO_TORCH_OP` table. For the `aten.mul.int` and `aten.floordiv.int` ops, we add new canonicalization patterns, as follows:

```mlir
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.mul.int %1, %const-6
```

will be replaced by

```mlir
%2 = torch.aten.mul.int %input, %const-30
```

and

```mlir
%1 = torch.aten.mul.int %input, %const-5
%2 = torch.aten.floordiv.int %1, %const-5
```

will directly return `%input`.

This PR also relaxes the `float` type constraint in TorchToTosa for the `AtenRsubScalarOp` conversion.

To test: `cmake --build build --target check-torch-mlir-all`
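For context, chains like these typically arise from Python integer arithmetic on symbolic shapes, which the FX importer maps to `aten` ops via `PY_BUILTIN_TO_TORCH_OP`. Below is a minimal sketch of source code that could produce the patterns above; the module and the dynamic-shape export call are illustrative assumptions, not taken from this PR:

```python
import torch

class ShapeArithmetic(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.shape[0]      # symbolic int when dim 0 is dynamic
        k = (n * 5) * 6     # two chained torch.aten.mul.int; canonicalizes to n * 30
        m = (n * 5) // 5    # floordiv.int over mul.int; canonicalizes to n
        return x.new_zeros(k + m)

# Exporting with a dynamic batch dimension keeps the integer arithmetic
# symbolic instead of constant-folding it away at trace time.
batch = torch.export.Dim("batch")
exported = torch.export.export(
    ShapeArithmetic(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}}
)
```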
1 parent 16bbcb0 · commit abb9282

7 files changed: +114 additions, −7 deletions

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 2 additions & 0 deletions

```diff
@@ -15492,6 +15492,7 @@ def Torch_AtenFloordivIntOp : Torch_Op<"aten.floordiv.int", [
     }
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
@@ -15641,6 +15642,7 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
     }
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
```

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 0 additions & 4 deletions

```diff
@@ -1823,10 +1823,6 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Only ranked tensor types supported in TOSA Rsub");
 
-  if (!isa<mlir::FloatType>(selfTy.getElementType()))
-    return rewriter.notifyMatchFailure(
-        op, "Only floating-point datatype legalization supported");
-
   Value otherTensor, alphaTensor;
 
   if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
```

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 77 additions & 0 deletions

```diff
@@ -3508,6 +3508,44 @@ OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) {
       [](int64_t a, int64_t b) { return std::floor(a / (double)b); });
 }
 
+void AtenFloordivIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                    MLIRContext *context) {
+  patterns.add(+[](AtenFloordivIntOp op, PatternRewriter &rewriter) {
+    int64_t lhs, rhs;
+    bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
+    bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
+    if (lConstant && rConstant)
+      return failure();
+    if (lConstant || rConstant) {
+      int64_t firstConstant = lConstant ? lhs : rhs;
+      Value firstOperand = lConstant ? op.getB() : op.getA();
+      if (firstOperand.getDefiningOp() &&
+          firstOperand.getDefiningOp<AtenMulIntOp>()) {
+        auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
+        int64_t prevLhs, prevRhs;
+        bool prevLConstant =
+            matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
+        bool prevRConstant =
+            matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
+        if (prevLConstant && prevRConstant)
+          return failure();
+        if ((prevLConstant || prevRConstant) &&
+            prevMulIntOp->hasOneUse() == 1) {
+          int64_t secondConstant = prevLConstant ? prevLhs : prevRhs;
+          if (secondConstant == firstConstant) {
+            rewriter.replaceAllUsesWith(
+                op.getResult(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0));
+            rewriter.eraseOp(op);
+            rewriter.eraseOp(prevMulIntOp);
+            return success();
+          }
+        }
+      }
+    }
+    return failure();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenRemainderIntOp
 //===----------------------------------------------------------------------===//
@@ -3799,6 +3837,45 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+void AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  patterns.add(+[](AtenMulIntOp op, PatternRewriter &rewriter) {
+    int64_t lhs, rhs;
+    bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs));
+    bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs));
+    if (lConstant && rConstant)
+      return failure();
+    if (lConstant || rConstant) {
+      int64_t firstConstant = lConstant ? lhs : rhs;
+      Value firstOperand = lConstant ? op.getB() : op.getA();
+      if (firstOperand.getDefiningOp() &&
+          firstOperand.getDefiningOp<AtenMulIntOp>()) {
+        auto prevMulIntOp = firstOperand.getDefiningOp<AtenMulIntOp>();
+        int64_t prevLhs, prevRhs;
+        bool prevLConstant =
+            matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs));
+        bool prevRConstant =
+            matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs));
+        if (prevLConstant && prevRConstant)
+          return failure();
+        if ((prevLConstant || prevRConstant) &&
+            prevMulIntOp->hasOneUse() == 1) {
+          auto newConstant = rewriter.create<Torch::ConstantIntOp>(
+              op.getLoc(), rewriter.getI64IntegerAttr(
+                               prevLConstant ? prevLhs * firstConstant
+                                             : prevRhs * firstConstant));
+          rewriter.replaceOpWithNewOp<AtenMulIntOp>(
+              op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0),
+              newConstant);
+          rewriter.eraseOp(prevMulIntOp);
+          return success();
+        }
+      }
+    }
+    return failure();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenMulFloatOp
 //===----------------------------------------------------------------------===//
```
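To see the rewrite logic at a glance, here is a small Python model of the two patterns. This is an illustrative sketch only: the `Var`/`Mul`/`FloorDiv` classes are invented stand-ins for SSA values, not torch-mlir APIs, and the real patterns additionally require the inner `mul.int` to have a single use before erasing it; the sketch works on pure expression trees, so that check is omitted.

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class Var:  # stands in for a non-constant !torch.int SSA value
    name: str

@dataclass
class Mul:  # torch.aten.mul.int
    a: "Expr"
    b: "Expr"

@dataclass
class FloorDiv:  # torch.aten.floordiv.int
    a: "Expr"
    b: "Expr"

Expr = Union[Var, int, Mul, FloorDiv]

def split_const(op):
    """If exactly one operand is a constant int, return (constant, other)."""
    if isinstance(op.a, int) != isinstance(op.b, int):
        return (op.a, op.b) if isinstance(op.a, int) else (op.b, op.a)
    return None

def canonicalize(op):
    # mul(mul(x, c1), c2) -> mul(x, c1 * c2)
    if isinstance(op, Mul) and (outer := split_const(op)) is not None:
        c2, inner = outer
        if isinstance(inner, Mul) and (found := split_const(inner)) is not None:
            c1, x = found
            return Mul(x, c1 * c2)
    # floordiv(mul(x, c), c) -> x, exact for any integer x and nonzero c
    if isinstance(op, FloorDiv) and isinstance(op.b, int):
        if isinstance(op.a, Mul) and (found := split_const(op.a)) is not None:
            c, x = found
            if c == op.b:
                return x
    return op

print(canonicalize(Mul(Mul(Var("input"), 5), 6)))
# -> Mul(a=Var(name='input'), b=30)
print(canonicalize(FloorDiv(Mul(Var("input"), 5), 5)))
# -> Var(name='input')
```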

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -2003,6 +2003,7 @@
     "RsubFloatModule_basic",
     "RsubFloatModule_noalpha_basic",
     "RsubInt0d_NumToTensor_Module_basic",
+    "RsubIntModule_basic",
     "ScalarTensorDefaultDtypeModule_basic",
     "ScalarTensorFloat32Module_basic",
     "ScalarTensorInt32Module_basic",
```

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 10 additions & 2 deletions

This generator produces the ODS entries in GeneratedTorchOps.td; passing `has_canonicalizer=True` emits the `let hasCanonicalizer = 1;` lines shown in the first file above.

```diff
@@ -1086,13 +1086,21 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::le.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
-    emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
+    emit(
+        "aten::floordiv.int : (int, int) -> (int)",
+        has_folder=True,
+        has_canonicalizer=True,
+    )
     emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
     emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True)
     emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::add.int : (int, int) -> (int)", has_folder=True)
     emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
-    emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
+    emit(
+        "aten::mul.int : (int, int) -> (int)",
+        has_folder=True,
+        has_canonicalizer=True,
+    )
     emit("aten::div.int : (int, int) -> (float)", has_folder=True)
     emit("aten::neg.int : (int) -> (int)", has_folder=True)
     emit("aten::log.int : (int) -> (float)")
```

python/torch_mlir/extras/fx_importer.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -267,6 +267,7 @@
     "gt": torch.ops.aten.gt,
     "mod": torch.ops.aten.fmod,
     "eq": torch.ops.aten.eq,
+    "floordiv": torch.ops.aten.floordiv,
 }
 
 # torch with cuda has a __version__ that looks like "2.1.0+cu113",
```
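This dictionary is the `PY_BUILTIN_TO_TORCH_OP` table from the commit message: it maps Python builtin operators that appear in FX graphs (e.g. `operator.floordiv` from `a // b` on symbolic ints) to their aten equivalents. A rough sketch of the lookup it enables, using a trimmed stand-in copy of the table (the importer's surrounding node-handling code is not shown and may differ):

```python
import operator
import torch

# Trimmed stand-in for the real table in python/torch_mlir/extras/fx_importer.py.
PY_BUILTIN_TO_TORCH_OP = {
    "eq": torch.ops.aten.eq,
    "floordiv": torch.ops.aten.floordiv,  # entry added by this commit
}

# An FX graph records `a // b` on Python ints as a call to operator.floordiv;
# the builtin's __name__ ("floordiv") keys the table to recover the aten op.
aten_op = PY_BUILTIN_TO_TORCH_OP[operator.floordiv.__name__]
print(aten_op)  # e.g. <OpOverloadPacket(op='aten.floordiv')>
```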

test/Dialect/Torch/canonicalize.mlir

Lines changed: 23 additions & 1 deletion

```diff
@@ -1168,6 +1168,19 @@ func.func @torch.aten.mul.int() -> !torch.int {
   return %ret : !torch.int
 }
 
+// CHECK-LABEL: func.func @torch.aten.mul.int$canonicalize(
+// CHECK-SAME:    %[[ARG:.*]]: !torch.int) -> !torch.int {
+// CHECK:         %[[CST30:.*]] = torch.constant.int 30
+// CHECK:         %[[RET:.*]] = torch.aten.mul.int %[[ARG]], %[[CST30]] : !torch.int, !torch.int -> !torch.int
+// CHECK:         return %[[RET]] : !torch.int
+func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
+  %cst6 = torch.constant.int 6
+  %cst5 = torch.constant.int 5
+  %1 = torch.aten.mul.int %arg0, %cst5: !torch.int, !torch.int -> !torch.int
+  %ret = torch.aten.mul.int %1, %cst6: !torch.int, !torch.int -> !torch.int
+  return %ret : !torch.int
+}
+
 // CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
 // CHECK:     %[[CST30:.*]] = torch.constant.float 3.000000e+01
 // CHECK:     return %[[CST30]] : !torch.float
@@ -1207,6 +1220,16 @@ func.func @torch.aten.floordiv.int() -> !torch.int {
   return %ret : !torch.int
 }
 
+// CHECK-LABEL: func.func @torch.aten.floordiv.int$canonicalize(
+// CHECK-SAME:    %[[ARG:.*]]: !torch.int) -> !torch.int {
+// CHECK:         return %[[ARG]] : !torch.int
+func.func @torch.aten.floordiv.int$canonicalize(%arg0: !torch.int) -> !torch.int {
+  %cst6 = torch.constant.int 6
+  %1 = torch.aten.mul.int %arg0, %cst6: !torch.int, !torch.int -> !torch.int
+  %ret = torch.aten.floordiv.int %1, %cst6: !torch.int, !torch.int -> !torch.int
+  return %ret : !torch.int
+}
+
 // CHECK-LABEL: func.func @torch.aten.remainder.int() -> !torch.int {
 // CHECK:     %[[CST3:.*]] = torch.constant.int 3
 // CHECK:     return %[[CST3]] : !torch.int
@@ -3122,7 +3145,6 @@ func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!
   return %1 : !torch.tensor
 }
 
-
 // -----
 
 // CHECK-LABEL: @torch.symbolic_int$canonicalize(
```
