From 21005a7d63b9795148aa24ae042cd74136cd79c2 Mon Sep 17 00:00:00 2001
From: yyp0
Date: Thu, 31 Oct 2024 12:03:31 +0800
Subject: [PATCH 1/2] [Torch] fix cross_entropy_loss decomposition

---
 lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1fefb59a4cac..fdb9e5fec475 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8925,10 +8925,19 @@ class DecomposeAtenCrossEntropyLossOp
         loc, rewriter.getI64IntegerAttr(1));
     Value logSoftmax = rewriter.create<AtenLogSoftmaxIntOp>(
         loc, self.getType(), self, dim, /*dtype=*/noneVal);
+
+    Type secondType;
+    if (reductionInt == 0) {
+      secondType = target.getType();
+    } else {
+      auto targetType = dyn_cast<BaseTensorType>(target.getType());
+      secondType = targetType.getWithSizesAndDtype({}, targetType.getDtype());
+    }
+
     Value nllLoss =
         rewriter
             .create<AtenNllLossForwardOp>(
-                loc, op.getType(), target.getType(), logSoftmax, target,
+                loc, op.getType(), secondType, logSoftmax, target,
                 op.getWeight(), op.getReduction(), op.getIgnoreIndex())
             ->getResult(0);
     rewriter.replaceOp(op, nllLoss);

From 92e0098a00c08c504d593378669b58e1bc7fc6b8 Mon Sep 17 00:00:00 2001
From: yyp0
Date: Thu, 31 Oct 2024 12:13:48 +0800
Subject: [PATCH 2/2] update

---
 lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index fdb9e5fec475..7a8718846255 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8896,6 +8896,12 @@ class DecomposeAtenCrossEntropyLossOp
         op, "Unimplemented: unranked target tensor");
     unsigned targetRank = maybeRank.value();
 
+    Value reduction = op.getReduction();
+    int64_t reductionInt;
+    if
(!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
 
     // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d
     // of the form [minibatch] the cross entropy loss decomposes to the
     // combination of softmax and nll loss as follows: