[TORCH] Add support for aten.hinge_embedding_loss Op #4227

Open · wants to merge 3 commits into base: main
28 changes: 27 additions & 1 deletion include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9478,6 +9478,32 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
}];
}

def Torch_AtenHingeEmbeddingLossOp : Torch_Op<"aten.hinge_embedding_loss", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::hinge_embedding_loss : (Tensor, Tensor, float, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
Torch_FloatType:$margin,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenHingeEmbeddingLossOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenHingeEmbeddingLossOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
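
For orientation, here is a minimal eager-mode call exercising the ATen signature this op mirrors (illustrative values only; reduction follows the usual ATen convention of 0 = none, 1 = mean, 2 = sum):

import torch

inp = torch.randn(2, 3)
# target values are conventionally 1 or -1 for this loss
tgt = torch.randint(0, 2, (2, 3)).float() * 2 - 1
# positional arguments match (Tensor self, Tensor target, float margin, int reduction)
out = torch.ops.aten.hinge_embedding_loss(inp, tgt, 1.0, 1)
print(out.shape)  # torch.Size([]) -- reduction=1 (mean) produces a scalar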

def Torch_AtenPoissonNllLossOp : Torch_Op<"aten.poisson_nll_loss", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -9496,7 +9522,7 @@ def Torch_AtenPoissonNllLossOp : Torch_Op<"aten.poisson_nll_loss", [
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPoissonNllLossOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
22 changes: 22 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10700,6 +10700,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.hinge_embedding_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int) -> !torch.list<int> {\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg3 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mse_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
@@ -13398,6 +13412,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.hinge_embedding_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
83 changes: 83 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -10554,6 +10554,88 @@ class DecomposeAtenNllLossForwardOp
} // namespace

namespace {
// Decomposition of the aten.hinge_embedding_loss op
// Ref:
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L182
// The hinge embedding loss:
//           | input,                   if target ==  1
// loss(x) = |
//           | max(0, margin - input),  if target == -1
// Note that the target tensor may contain values other than 1 and -1.
class DecomposeHingeEmbeddingLoss
: public OpRewritePattern<AtenHingeEmbeddingLossOp> {
public:
using OpRewritePattern<AtenHingeEmbeddingLossOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenHingeEmbeddingLossOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto input = op.getSelf();
auto target = op.getTarget();

auto inputTy = dyn_cast<ValueTensorType>(input.getType());
if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
return rewriter.notifyMatchFailure(op, "input must have dtype and size");

auto targetTy = dyn_cast<ValueTensorType>(target.getType());
if (!targetTy || !targetTy.hasDtype() || !targetTy.hasSizes())
return rewriter.notifyMatchFailure(op, "target must have dtype and size");

int64_t reduction;
if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) {
return rewriter.notifyMatchFailure(op,
"reduction should be a constant int!");
}
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
Value minusOne = getConstantWithGivenDtypeAndValue(rewriter, loc, -1,
targetTy.getDtype());
Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
targetTy.getDtype());
Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
targetTy.getDtype());
Value alpha =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto boolType = targetTy.getWithSizesAndDtype(targetTy.getSizes(),
rewriter.getI1Type());
// input - margin
auto inputMinusMargin = rewriter.create<AtenSubScalarOp>(
loc, inputTy, input, op.getMargin(), alpha);
// multiply by -1 to get margin - input
auto marginDiff = rewriter.create<AtenMulScalarOp>(
loc, inputTy, inputMinusMargin, minusOne);
// max(0, margin - input) => clamping the minimum value of margin - input at
// 0
auto marginClamp =
rewriter.create<AtenClampMinOp>(loc, inputTy, marginDiff, zero);
// Compute mask: target != 1
auto targetNotOne =
rewriter.create<AtenNeScalarOp>(loc, boolType, target, one);
// If target != 1 use marginClamp otherwise 0.
auto outputMargin = rewriter.create<AtenWhereScalarOtherOp>(
loc, inputTy, targetNotOne, marginClamp, zero);
// Compute mask: target != -1
auto targetNotMinusOne =
rewriter.create<AtenNeScalarOp>(loc, boolType, target, minusOne);
// If target != -1 use the original input. Otherwise 0.
auto outputSelf = rewriter.create<AtenWhereScalarOtherOp>(
loc, inputTy, targetNotMinusOne, input, zero);
// Add : outputMargin + outputSelf
auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
outputSelf, /*alpha=*/alpha);
Collaborator review comment on lines +10608 to +10622:
Instead of doing all this, you can just do:

auto result = rewriter.create<AtenWhereScalarOtherOp>(
        loc, inputTy, targetNotOne, marginClamp, input);

Author (@sharavak) replied, Jun 17, 2025:
@vivekkhandelwal1 Thanks for the suggestion. I did try the simplified version initially, but it caused numerical validation errors in some test cases. This happens because the target tensor can sometimes contain values other than just -1 and 1.

To handle this properly and stay consistent with PyTorch's semantics, I decided to explicitly check for both target == 1 and target == -1. This way, the behavior stays correct even when target contains values other than -1 and 1.

For example:

import torch
input=torch.randn(2,3)
target=torch.randn(2,3)
torch.hinge_embedding_loss(input,target)

Output:
tensor([[1.1361, 1.0000, 1.0000],
        [1.4880, 1.1624, 1.0000]])
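
To make the distinction concrete, here is a minimal check (illustrative values only, not part of this patch) comparing the reference op against a single-where rewrite for a target value that is neither 1 nor -1:

import torch

inp = torch.tensor([[0.3]])
tgt = torch.tensor([[0.5]])  # neither 1 nor -1
# reference: both branches contribute -> input + max(0, margin - input) = 0.3 + 0.7 = 1.0
ref = torch.ops.aten.hinge_embedding_loss(inp, tgt, 1.0, 0)
# a single where keeps only one branch -> 0.7
single = torch.where(tgt != 1, (1.0 - inp).clamp(min=0), inp)
print(ref.item(), single.item())  # 1.0 vs 0.7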

Value loss = output;
Value none = rewriter.create<ConstantNoneOp>(loc);
// reduction: mean
if (reduction == 1) {
loss = rewriter.create<AtenMeanOp>(loc, resultTy, output, none);
} else if (reduction == 2) {
// reduction: sum
loss = rewriter.create<AtenSumOp>(loc, resultTy, output, none);
}
rewriter.replaceOp(op, loss);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenPoissonNllLossOp
: public OpRewritePattern<AtenPoissonNllLossOp> {
public:
@@ -12543,6 +12625,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeHingeEmbeddingLoss>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPoissonNllLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
patterns);
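
For reference, the decomposition above can be mirrored step for step in eager PyTorch; this is a hedged sketch (the helper name is illustrative, not part of the patch) that can be checked against the ATen reference:

import torch

def decomposed_hinge_embedding_loss(inp, target, margin=1.0, reduction=1):
    margin_diff = (inp - margin) * -1          # margin - input
    margin_clamp = margin_diff.clamp(min=0)    # max(0, margin - input)
    zeros = torch.zeros_like(inp)
    output_margin = torch.where(target != 1, margin_clamp, zeros)  # target != 1 branch
    output_self = torch.where(target != -1, inp, zeros)            # target != -1 branch
    output = output_margin + output_self
    if reduction == 1:                         # mean
        return output.mean()
    if reduction == 2:                         # sum
        return output.sum()
    return output                              # reduction == 0: none

inp, tgt = torch.randn(2, 3), torch.randn(2, 3)
assert torch.allclose(
    decomposed_hinge_embedding_loss(inp, tgt, 1.5, 1),
    torch.ops.aten.hinge_embedding_loss(inp, tgt, 1.5, 1),
)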
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -538,6 +538,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLerpTensorOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenL1LossOp>();
target.addIllegalOp<AtenHingeEmbeddingLossOp>();
target.addIllegalOp<AtenPoissonNllLossOp>();
target.addIllegalOp<AtenRandintLowOp>();
target.addIllegalOp<AtenRandintOp>();
11 changes: 11 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1795,6 +1795,9 @@
"L1LossMeanReductionModule_basic",
"L1LossNoReductionModule_basic",
"L1LossSumReductionModule_basic",
"HingeEmbeddingLossReductionMeanModule_basic",
"HingeEmbeddingLossReductionSumModule_basic",
"HingeEmbeddingLossReductionNoneModule_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"RandIntLowModule_basic",
@@ -2968,6 +2971,10 @@
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HardtanhBackward_basic",
"HingeEmbeddingLossBasicModule_basic",
"HingeEmbeddingLossReductionMeanModule_basic",
"HingeEmbeddingLossReductionSumModule_basic",
"HingeEmbeddingLossReductionNoneModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
@@ -3982,6 +3989,10 @@
"NllLossStaticModule_mean_basic",
"NllLossStaticModule_sum_basic",
"NllLossStaticModule_weight_basic",
"HingeEmbeddingLossBasicModule_basic",
"HingeEmbeddingLossReductionMeanModule_basic",
"HingeEmbeddingLossReductionSumModule_basic",
"HingeEmbeddingLossReductionNoneModule_basic",
"Exp2StaticModule_basic",
"ElementwiseRreluWithNoiseEvalModule_basic",
"ElementwiseRreluWithNoiseEvalStaticModule_basic",
@@ -2186,6 +2186,11 @@ def aten〇nll_loss_forward〡shape(self: List[int], target: List[int], weight:
def aten〇nll_loss_backward〡shape(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇hinge_embedding_loss〡shape(self: List[int], target: List[int], margin: float = 1., reduction: int = 1) -> List[int]:
    if reduction in [1, 2]:
        return []
    return upstream_shape_functions.unary(self)
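
A quick eager-mode sanity check of this shape rule (illustrative only): mean and sum reductions collapse to a 0-d result, while reduction=0 keeps the input shape:

import torch

inp, tgt = torch.rand(2, 3), torch.rand(2, 3)
assert torch.ops.aten.hinge_embedding_loss(inp, tgt, 1.0, 0).shape == (2, 3)
assert torch.ops.aten.hinge_embedding_loss(inp, tgt, 1.0, 1).shape == ()  # mean
assert torch.ops.aten.hinge_embedding_loss(inp, tgt, 1.0, 2).shape == ()  # sum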

# TODO: upstream this
def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]:
    if reduction == 0:
@@ -3982,6 +3987,13 @@ def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], se
return torch.int64
return result

def aten〇hinge_embedding_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], margin: float = 1., reduction: int = 1) -> int:
    self_rank, self_dtype = self_rank_dtype
    target_rank, target_dtype = target_rank_dtype
    ranks: List[Optional[int]] = [self_rank, target_rank]
    dtypes = [self_dtype, target_dtype]
    return promote_dtypes(ranks, dtypes)
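
Similarly, a small illustration of the promotion this dtype function encodes (assuming eager mode accepts an integer target here, which the reference implementation does not restrict):

import torch

inp = torch.rand(2, 3)                      # float32
tgt = torch.ones(2, 3, dtype=torch.int64)   # int64
out = torch.ops.aten.hinge_embedding_loss(inp, tgt, 1.0, 0)
assert out.dtype == torch.float32           # float32/int64 promote to float32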

@check_dtype_function(_check_tensors_with_the_same_dtype(
None, [(2, 4, 7, 6), (2, 4, 6, 5)], None, None,
[2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) +
@@ -761,6 +761,7 @@ def emit_with_mutating_variants(key, **kwargs):
    emit(
        "aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
    )
    emit("aten::hinge_embedding_loss : (Tensor, Tensor, float, int) -> (Tensor)")
    emit(
        "aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)"
    )
89 changes: 89 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -2455,6 +2455,95 @@ def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils):
# ==============================================================================


class HingeEmbeddingLossBasicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        return torch.ops.aten.hinge_embedding_loss(
            input, target, margin=1.5, reduction=1
        )


@register_test_case(module_factory=lambda: HingeEmbeddingLossBasicModule())
def HingeEmbeddingLossBasicModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 2, 3), tu.rand(1, 2, 3))


class HingeEmbeddingLossReductionMeanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        return torch.ops.aten.hinge_embedding_loss(input, target, reduction=1)


@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionMeanModule())
def HingeEmbeddingLossReductionMeanModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(8, 1), tu.rand(1, 1))


class HingeEmbeddingLossReductionSumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        return torch.ops.aten.hinge_embedding_loss(input, target, reduction=2)


@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionSumModule())
def HingeEmbeddingLossReductionSumModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5), tu.rand(1, 1))


class HingeEmbeddingLossReductionNoneModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
            ([-1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        return torch.ops.aten.hinge_embedding_loss(input, target, margin=1.0)


@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionNoneModule())
def HingeEmbeddingLossReductionNoneModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(8, 5), tu.rand(1))


# ==============================================================================


class TraceModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()