diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8bbc0dc74686..f05ceff88a6a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9478,6 +9478,32 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [ }]; } +def Torch_AtenHingeEmbeddingLossOp : Torch_Op<"aten.hinge_embedding_loss", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::hinge_embedding_loss : (Tensor, Tensor, float, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + Torch_FloatType:$margin, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenHingeEmbeddingLossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenHingeEmbeddingLossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenPoissonNllLossOp : Torch_Op<"aten.poisson_nll_loss", [ AllowsTypeRefinement, HasValueSemantics, @@ -9496,7 +9522,7 @@ def Torch_AtenPoissonNllLossOp : Torch_Op<"aten.poisson_nll_loss", [ AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ + let extraClassDefinition = [{ ParseResult AtenPoissonNllLossOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 6, 1); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index e2a88128091f..8ec8f948296f 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10700,6 +10700,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.hinge_embedding_loss\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.int) -> !torch.list {\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg3 : !torch.list, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mse_loss\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" @@ -13398,6 +13412,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %6 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hinge_embedding_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.float, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 64cfadb7c89c..f3ce5207ab63 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -10554,6 +10554,88 @@ class DecomposeAtenNllLossForwardOp } // namespace namespace { +// Decompostion of aten.hinge_embedding_loss op +// Ref: +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L182 +// The Hinge Embedding Loss: +// | input, if target == 1 +// loss(x) = | +// | max(0, margin - input), if target == -1 +// target tensor may have values other than 1 and -1 +class DecomposeHingeEmbeddingLoss + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenHingeEmbeddingLossOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto input = op.getSelf(); + auto target = op.getTarget(); + + auto inputTy = dyn_cast(input.getType()); + if (!inputTy.hasDtype() || !inputTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "input must have dtype and size"); + + auto targetTy = dyn_cast(target.getType()); + if (!targetTy.hasDtype() || !targetTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "target must have dtype and size"); + + int64_t reduction; + if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) { + return rewriter.notifyMatchFailure(op, + "reduction should be a constant int!"); + } + auto resultTy = dyn_cast(op.getType()); + Value minusOne = getConstantWithGivenDtypeAndValue(rewriter, loc, -1, + targetTy.getDtype()); + Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1, + targetTy.getDtype()); + Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0, + targetTy.getDtype()); + Value alpha = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto boolType = targetTy.getWithSizesAndDtype(targetTy.getSizes(), + rewriter.getI1Type()); + // input - margin + auto inputMinusMargin = rewriter.create( + loc, inputTy, input, op.getMargin(), alpha); + // multiply by -1 to get margin - input + auto marginDiff = rewriter.create( + loc, inputTy, inputMinusMargin, minusOne); + // max(0, margin - input) => clamping the minimum value of margin - input at + // 0 + auto marginClamp = + rewriter.create(loc, inputTy, marginDiff, zero); + // Compute mask: target != 1 + auto targetNotOne = + rewriter.create(loc, boolType, target, one); + // If target != 1 use marginClamp otherwise 0. + auto outputMargin = rewriter.create( + loc, inputTy, targetNotOne, marginClamp, zero); + // Compute mask: target != -1 + auto targetNotMinusOne = + rewriter.create(loc, boolType, target, minusOne); + // If target != -1 use the original input. Otherwise 0. + auto outputSelf = rewriter.create( + loc, inputTy, targetNotMinusOne, input, zero); + // Add : outputMargin + outputSelf + auto output = rewriter.create(loc, inputTy, outputMargin, + outputSelf, /*alpha=*/alpha); + Value loss = output; + Value none = rewriter.create(loc); + // reduction: mean + if (reduction == 1) { + loss = rewriter.create(loc, resultTy, output, none); + } else if (reduction == 2) { + // reduction: sum + loss = rewriter.create(loc, resultTy, output, none); + } + rewriter.replaceOp(op, loss); + return success(); + } +}; +} // namespace + +namespace { class DecomposeAtenPoissonNllLossOp : public OpRewritePattern { public: @@ -12543,6 +12625,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index f99537f8cd93..21839bcc7eed 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -538,6 +538,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a346d6da2cc7..2b6f5b18f011 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1795,6 +1795,9 @@ "L1LossMeanReductionModule_basic", "L1LossNoReductionModule_basic", "L1LossSumReductionModule_basic", + "HingeEmbeddingLossReductionMeanModule_basic", + "HingeEmbeddingLossReductionSumModule_basic", + "HingeEmbeddingLossReductionNoneModule_basic", "PixelShuffleModuleStaticRank3Int64_basic", "PixelShuffleModuleStaticRank4Float32_basic", "RandIntLowModule_basic", @@ -2968,6 +2971,10 @@ "GtFloatIntModule_basic", "GtIntModule_basic", "HardtanhBackward_basic", + "HingeEmbeddingLossBasicModule_basic", + "HingeEmbeddingLossReductionMeanModule_basic", + "HingeEmbeddingLossReductionSumModule_basic", + "HingeEmbeddingLossReductionNoneModule_basic", "HstackBasicComplexModule_basic", "HstackBasicFloatModule_basic", "HstackBasicIntFloatModule_basic", @@ -3982,6 +3989,10 @@ "NllLossStaticModule_mean_basic", "NllLossStaticModule_sum_basic", "NllLossStaticModule_weight_basic", + "HingeEmbeddingLossBasicModule_basic", + "HingeEmbeddingLossReductionMeanModule_basic", + "HingeEmbeddingLossReductionSumModule_basic", + "HingeEmbeddingLossReductionNoneModule_basic", "Exp2StaticModule_basic", "ElementwiseRreluWithNoiseEvalModule_basic", "ElementwiseRreluWithNoiseEvalStaticModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 4f59be02aef1..fb25b0c682b0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2186,6 +2186,11 @@ def aten〇nll_loss_forward〡shape(self: List[int], target: List[int], weight: def aten〇nll_loss_backward〡shape(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇hinge_embedding_loss〡shape(self: List[int], target: List[int], margin: float = 1., reduction: int = 1) -> List[int]: + if reduction in [1,2]: + return [] + return upstream_shape_functions.unary(self) + # TODO: upstream this def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]: if reduction == 0: @@ -3982,6 +3987,13 @@ def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], se return torch.int64 return result +def aten〇hinge_embedding_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], margin: float = 1., reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype + ranks: List[Optional[int]] = [self_rank, target_rank] + dtypes = [self_dtype, target_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function(_check_tensors_with_the_same_dtype( None, [(2, 4, 7, 6), (2, 4, 6, 5)], None, None, [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) + diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 753beb972fd0..ff0fba139ea2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -761,6 +761,7 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)" ) + emit("aten::hinge_embedding_loss : (Tensor, Tensor, float, int) -> (Tensor)") emit( "aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)" ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 0eb0545e7f11..d9264db0657b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -2455,6 +2455,95 @@ def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class HingeEmbeddingLossBasicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, input, target): + return torch.ops.aten.hinge_embedding_loss( + input, target, margin=1.5, reduction=1 + ) + + +@register_test_case(module_factory=lambda: HingeEmbeddingLossBasicModule()) +def HingeEmbeddingLossBasicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 3), tu.rand(1, 2, 3)) + + +class HingeEmbeddingLossReductionMeanModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, input, target): + return torch.ops.aten.hinge_embedding_loss(input, target, reduction=1) + + +@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionMeanModule()) +def HingeEmbeddingLossReductionMeanModule_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 1), tu.rand(1, 1)) + + +class HingeEmbeddingLossReductionSumModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, input, target): + return torch.ops.aten.hinge_embedding_loss(input, target, reduction=2) + + +@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionSumModule()) +def HingeEmbeddingLossReductionSumModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5), tu.rand(1, 1)) + + +class HingeEmbeddingLossReductionNoneModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, input, target): + return torch.ops.aten.hinge_embedding_loss(input, target, margin=1.0) + + +@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionNoneModule()) +def HingeEmbeddingLossReductionNoneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 5), tu.rand(1)) + + +# ============================================================================== + + class TraceModule(torch.nn.Module): def __init__(self) -> None: super().__init__()