Skip to content

Commit

Permalink
[Torch Dialect] Add torch.aten.mul.int_float (required to simplify shape calculation of `upsample_nearest2d`) (#3764)
Browse files Browse the repository at this point in the history

As per title. See also
[PR](#3750) for
`torch.aten.mul.float_int`.

---------

Co-authored-by: zjgarvey <[email protected]>
  • Loading branch information
2 people authored and qingyunqu committed Nov 21, 2024
1 parent abb9282 commit ac4cb97
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 6 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15645,6 +15645,31 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
let hasCanonicalizer = 1;
}

// Scalar mixed-type multiply: `aten::mul.int_float : (int, float) -> (float)`.
// Declares operand `a` (!torch.int) and `b` (!torch.float) with a single
// !torch.float result. Marked with value semantics / read-only traits and
// parsed/printed via the shared default torch-op assembly helpers.
def Torch_AtenMulIntFloatOp : Torch_Op<"aten.mul.int_float", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::mul.int_float : (int, float) -> (float)`";
let arguments = (ins
Torch_IntType:$a,
Torch_FloatType:$b
);
let results = (outs
Torch_FloatType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
// Delegate to the default torch-op parser/printer (2 operands, 1 result).
ParseResult AtenMulIntFloatOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenMulIntFloatOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
// Folder is implemented as AtenMulIntFloatOp::fold in lib/Dialect/Torch/IR/TorchOps.cpp.
let hasFolder = 1;
}

def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
8 changes: 7 additions & 1 deletion lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
Value b = adaptor.getB();
if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
if (llvm::is_one_of<AtenOp, AtenMulIntFloatOp>::value)
a = convertScalarToDtype(rewriter, op.getLoc(), a, b.getType());
rewriter.template replaceOpWithNewOp<BinOp>(op, a, b);
return success();
}
Expand Down Expand Up @@ -467,15 +469,19 @@ class ConvertTorchToArith
patterns.add<ConvertAtenAddOp>(typeConverter, context);

target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
AtenMulIntOp>();
AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenAddFloatIntOp, arith::AddFOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntFloatOp, arith::MulFOp>>(
typeConverter, context);
target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context);
Expand Down
13 changes: 13 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3932,6 +3932,19 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
[](double a, double b) -> double { return a * b; });
}

//===----------------------------------------------------------------------===//
// AtenMulIntFloatOp
//===----------------------------------------------------------------------===//

/// Constant-fold `aten.mul.int_float` when both operands are compile-time
/// constants, multiplying in double precision via the shared binary-float
/// fold helper.
OpFoldResult AtenMulIntFloatOp::fold(FoldAdaptor adaptor) {
  auto lhs = adaptor.getA();
  auto rhs = adaptor.getB();
  // Folding requires both operands to be materialized constant attributes.
  if (!lhs || !rhs)
    return nullptr;
  auto mulFn = [](double a, double b) -> double { return a * b; };
  return atenBinaryFloatOperatorFoldHelper(adaptor.getOperands(), mulFn);
}

//===----------------------------------------------------------------------===//
// AtenSubOp
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 5 additions & 5 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.aten.append.t %1, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n"
" %22 = torch.aten.mul.int_float %20, %21 : !torch.int, !torch.float -> !torch.float\n"
" %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n"
" %24 = torch.aten.append.t %1, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
Expand Down Expand Up @@ -10931,7 +10931,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %19 : !torch.list<int>\n"
Expand Down Expand Up @@ -11011,11 +11011,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
" %21 = torch.operator \"aten.mul.int_float\"(%19, %20) : (!torch.int, !torch.float) -> !torch.float \n"
" %21 = torch.aten.mul.int_float %19, %20 : !torch.int, !torch.float -> !torch.float\n"
" %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n"
" %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %23 : !torch.list<int>\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,7 @@ def emit_with_mutating_variants(key, **kwargs):
has_folder=True,
has_canonicalizer=True,
)
emit("aten::mul.int_float : (int, float) -> (float)", has_folder=True)
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)")
Expand Down
14 changes: 14 additions & 0 deletions test/Conversion/TorchToArith/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,20 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
return %0 : !torch.int
}

// Lowering test for TorchToArith: `torch.aten.mul.int_float` unwraps the int
// operand to i64 and the float operand to f64, converts the int to f64 with
// arith.sitofp, multiplies with arith.mulf, and wraps the f64 result back
// into a !torch.float.
// CHECK-LABEL: func.func @torch.aten.mul.int_float(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
// CHECK: %[[LHS_F64:.*]] = arith.sitofp %[[LHS_I64]] : i64 to f64
// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]]
// CHECK: return %[[OUT]] : !torch.float
func.func @torch.aten.mul.int_float(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float {
%0 = torch.aten.mul.int_float %arg0, %arg1 : !torch.int, !torch.float -> !torch.float
return %0 : !torch.float
}

// CHECK-LABEL: func.func @torch.aten.div.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
Expand Down
10 changes: 10 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,16 @@ func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
return %ret : !torch.int
}

// Folder test: constant operands int 2 and float 3.0 fold to the single
// float constant 6.0, eliminating the op entirely.
// CHECK-LABEL: func.func @torch.aten.mul.int_float() -> !torch.float {
// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00
// CHECK: return %[[CST6]] : !torch.float
func.func @torch.aten.mul.int_float() -> !torch.float {
%cst2 = torch.constant.int 2
%cst3 = torch.constant.float 3.0
%ret = torch.aten.mul.int_float %cst2, %cst3: !torch.int, !torch.float -> !torch.float
return %ret : !torch.float
}

// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
// CHECK: return %[[CST30]] : !torch.float
Expand Down

0 comments on commit ac4cb97

Please sign in to comment.