diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp
index 0714523fb5..e9e455e26b 100644
--- a/src/Dialect/ONNX/Transforms/Decompose.cpp
+++ b/src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -20,6 +20,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <numeric>
+
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -484,9 +486,7 @@ namespace {
 /// Include the patterns defined in the Declarative Rewrite framework.
 #include "src/Dialect/ONNX/Transforms/ONNXDecompose.inc"
 
-#ifdef ONNX_MLIR_ENABLE_STABLEHLO
-
-RankedTensorType createResultType(
+RankedTensorType createReducedType(
     Type outputType, int64_t axisValue, bool keepDims) {
   RankedTensorType outputShapeType =
      mlir::dyn_cast<RankedTensorType>(outputType);
@@ -507,6 +507,8 @@ RankedTensorType createResultType(
   return resultType;
 }
 
+#ifdef ONNX_MLIR_ENABLE_STABLEHLO
+
 struct SoftmaxPattern : public OpRewritePattern<ONNXSoftmaxOp> {
   using OpRewritePattern<ONNXSoftmaxOp>::OpRewritePattern;
 
@@ -526,7 +528,7 @@ struct SoftmaxPattern : public OpRewritePattern<ONNXSoftmaxOp> {
         rewriter.getIntegerType(64, /*isSigned=*/true), 1);
     ArrayAttr axisAttr = rewriter.getI64ArrayAttr({axisValue});
     RankedTensorType resultType =
-        createResultType(inputType, axisValue, /*keepDims=*/true);
+        createReducedType(inputType, axisValue, /*keepDims=*/true);
     Value maxInput = rewriter.create<ONNXReduceMaxV13Op>(
         odsLoc, resultType, input, axisAttr, keepDimsAttr);
     Value subValue =
@@ -985,6 +987,117 @@ struct GroupNormIntoLayerNormPattern2
   }
 };
 
+/// Decompose `onnx.SoftmaxCrossEntropyLoss` into the following sequence:
+/// In the following we assume the class axis is dim=1 of scores.
+/// 1. one_hot_encoded = onnx.CastLike(onnx.OneHot(labels, dim=1), scores)
+/// 2. log_softmax = onnx.Log(onnx.Softmax(scores, dim=1))
+/// 3. product = onnx.Mul(log_softmax, one_hot_encoded)
+///    if the `weights` arg is not `none` then we additionally perform
+///    product = onnx.Mul(product, onnx.Unsqueeze(weights))
+///    where unsqueezing makes the operation broadcastable.
+/// 4. reduce_sum = onnx.ReduceSum(product, dim=1)
+/// 5. loss = onnx.ReduceMean(reduce_sum) if reduction == "mean"
+///           onnx.ReduceSum(reduce_sum)  if reduction == "sum"
+///           onnx.Squeeze(reduce_sum)    if reduction == "none"
+///
+struct SoftmaxCrossEntropyPattern
+    : public OpRewritePattern<ONNXSoftmaxCrossEntropyLossOp> {
+  using OpRewritePattern<ONNXSoftmaxCrossEntropyLossOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ONNXSoftmaxCrossEntropyLossOp sceOp,
+      PatternRewriter &rewriter) const final {
+    auto loc = sceOp.getLoc();
+    onnx_mlir::OnnxBuilder create(rewriter, loc);
+    auto scores = sceOp.getScores();
+    auto labels = sceOp.getLabels();
+    auto weights = sceOp.getWeights();
+    auto scoresTy = cast<ShapedType>(scores.getType());
+    auto labelsTy = cast<ShapedType>(labels.getType());
+    SmallVector<int64_t> newLabelsShape(labelsTy.getShape());
+    newLabelsShape.insert(newLabelsShape.begin() + 1, scoresTy.getShape()[1]);
+    auto none = rewriter.create<ONNXNoneOp>(loc);
+    auto numClasses = (scoresTy.isDynamicDim(1))
+                          ? create.dim(scores, 1)
+                          : create.constantInt64({scoresTy.getShape()[1]});
+    auto elemTy = scoresTy.getElementType();
+    // Compute one hot encoded labels and cast to `scores` element type.
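+    // OneHot writes oneHotVals[0]/oneHotVals[1] (here 0 and 1) along the new
+    // class axis; the Cast converts that 0/1 mask to the element type of
+    // `scores` so the element-wise Mul with log-softmax below is well typed.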
+    auto oneHotValsAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()),
+        ArrayRef<int64_t>{0, 1});
+    auto oneHotVals = create.constant(oneHotValsAttr);
+    auto oneHot = create.cast(
+        rewriter.create<ONNXOneHotOp>(loc,
+            RankedTensorType::get(newLabelsShape, labelsTy.getElementType()),
+            labels, numClasses, oneHotVals, /*axis=*/1),
+        /*saturate=*/
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), 1),
+        TypeAttr::get(elemTy));
+    // Compute logsoftmax of scores.
+    auto softmax =
+        rewriter.create<ONNXSoftmaxOp>(loc, scoresTy, scores, /*axis=*/1);
+    auto logSoftmax = rewriter.create<ONNXLogOp>(loc, scoresTy, softmax);
+    auto prod = rewriter.create<ONNXMulOp>(loc, logSoftmax, oneHot);
+    // Multiply by `weights` if not none.
+    if (auto weightTy = dyn_cast<RankedTensorType>(weights.getType())) {
+      // Unsqueeze weight from [C] to [1 x C x 1 x ... x 1] to make it
+      // broadcast-compliant.
+      llvm::SmallVector<int64_t> unsqueezedShape(scoresTy.getRank(), 1);
+      unsqueezedShape[1] = scoresTy.getShape()[1];
+      llvm::SmallVector<int64_t> axesList(scoresTy.getRank() - 1, 0);
+      std::iota(axesList.begin() + 1, axesList.end(), 2);
+      auto axes = create.constantInt64(axesList);
+      auto weightsUnsqueezed = create.unsqueeze(
+          RankedTensorType::get(unsqueezedShape, elemTy), weights, axes);
+      prod = rewriter.create<ONNXMulOp>(loc, prod, weightsUnsqueezed);
+    }
+    // Reduction across `class` (dim=1) axis.
+    auto axes = create.constant(onnx_mlir::createDenseArrayAttr(
+        rewriter, rewriter.getI64ArrayAttr({1})));
+    auto reducedType = createReducedType(scoresTy, 1, /*keepdims=*/true);
+    Value loss = rewriter.create<ONNXReduceSumOp>(loc, reducedType, prod, axes);
+    // ReduceMean/ReduceSum/Squeeze if reduction = mean/sum/none respectively.
+    // Set `axes=none` to indicate reducing all dims.
+    auto reduction = cast<StringAttr>(sceOp.getReductionAttr()).getValue();
+    if (reduction == "mean") {
+      if (isa<NoneType>(weights.getType())) {
+        loss = rewriter.create<ONNXReduceMeanOp>(loc,
+            RankedTensorType::get({}, elemTy), loss, none,
+            /*keepdims=*/0);
+      } else {
+        auto sumL = rewriter.create<ONNXReduceSumOp>(loc,
+            RankedTensorType::get({}, elemTy), loss, none,
+            /*keepdims=*/0);
+        // Perform einsum(one_hot, weights) as a simple way of producing
+        // W[n][d1][d2]...[dk] = weights[labels[n][d1][d2]...[dk]]
+        auto scatteredWeights = rewriter.create<ONNXEinsumOp>(loc,
+            RankedTensorType::get(labelsTy.getShape(), elemTy),
+            ValueRange{oneHot, weights}, "ij...,j->i...");
+        auto sumW = rewriter.create<ONNXReduceSumOp>(loc,
+            RankedTensorType::get({}, elemTy), scatteredWeights, none,
+            /*keepdims=*/0);
+        loss = rewriter.create<ONNXDivOp>(loc, sumL, sumW);
+      }
+    } else if (reduction == "sum") {
+      loss = rewriter.create<ONNXReduceSumOp>(loc,
+          RankedTensorType::get({}, elemTy), loss, none,
+          /*keepdims=*/0);
+    } else if (reduction == "none") {
+      loss = rewriter.create<ONNXSqueezeOp>(loc,
+          createReducedType(reducedType, 1, /*keepdims=*/false), loss, axes);
+    } else {
+      llvm_unreachable("unexpected reduction type");
+    }
+    // Negate.
+    loss = rewriter.create<ONNXNegOp>(loc, loss.getType(), loss);
+    // Second return value replacement depends on whether it is `none` or not.
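+    // Per the ONNX spec the log_prob result is optional; log_softmax is only
+    // forwarded as its replacement when the original op produced it, otherwise
+    // the second result is replaced by the `none` value created above.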
+    if (isa<NoneType>(sceOp.getLogProb().getType()))
+      rewriter.replaceOp(sceOp, {loss, none});
+    else
+      rewriter.replaceOp(sceOp, {loss, logSoftmax});
+    return success();
+  }
+};
+
 /// Decompose `onnx.Sum` to a sequence of `onnx.Add`
 struct SumToAddPattern : public OpRewritePattern<ONNXSumOp> {
   using OpRewritePattern<ONNXSumOp>::OpRewritePattern;
@@ -1114,6 +1227,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<ONNXSoftmaxCrossEntropyLossOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
@@ -1190,6 +1304,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
   patterns.insert(context);
   patterns.insert(context);
   patterns.insert(context);
+  patterns.insert<SoftmaxCrossEntropyPattern>(context);
   patterns.insert(context);
 
   // TODO: consider whether to include SoftmaxPattern here
diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir
index 2fe2f9e374..c7c22bc08d 100644
--- a/test/mlir/onnx/onnx_decompose.mlir
+++ b/test/mlir/onnx/onnx_decompose.mlir
@@ -745,3 +745,186 @@ func.func @test_sum_single_input_to_unranked(%arg0: tensor<64x128x10xf32>) -> tensor<*xf32> {
   // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[ARG0]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
   // CHECK-NEXT: onnx.Return %[[CAST]]
 }
+
+// -----
+
+func.func @sce_mean(%arg0: tensor<64x10xf32>, %arg1: tensor<64xi64>) -> tensor<f32> {
+  %0 = "onnx.NoValue"() {value, weight} : () -> none
+  %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %0) {reduction = "mean"} : (tensor<64x10xf32>, tensor<64xi64>, none) -> (tensor<f32>, none)
+  onnx.Return %output : tensor<f32>
+  // CHECK-LABEL: func @sce_mean
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
+  // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value}
+  // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64>
+  // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64>
+  // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10xi64>
+  // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32}
+  // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64}
+  // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]])
+  // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]])
+  // CHECK-NEXT: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64>
+  // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1xf32>
+  // CHECK-NEXT: %[[MEAN:.*]] = "onnx.ReduceMean"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<f32>
+  // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]])
+  // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor<f32>
+}
+
+// -----
+
+func.func @sce_mean_return_log_prob(%arg0: tensor<64x10xf32>, %arg1: tensor<64xi64>) -> (tensor<f32>, tensor<64x10xf32>) {
+  %0 = "onnx.NoValue"() {value, weight} : () -> none
+  %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %0) {reduction = "mean"} : (tensor<64x10xf32>, tensor<64xi64>, none) -> (tensor<f32>, tensor<64x10xf32>)
+  onnx.Return %output, %log_prob : tensor<f32>, tensor<64x10xf32>
+  // CHECK-LABEL: func @sce_mean_return_log_prob
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
+  // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value}
+  // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64>
+  // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64>
+  // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10xi64>
+  // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32}
+  // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64}
+  // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]])
+  // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]])
+  // CHECK-NEXT: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64>
+  // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1xf32>
+  // CHECK-NEXT: %[[MEAN:.*]] = "onnx.ReduceMean"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<f32>
+  // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]])
+  // CHECK-NEXT: onnx.Return %[[LOSS]], %[[LOG_SOFTMAX]] : tensor<f32>, tensor<64x10xf32>
+}
+
+// -----
+
+func.func @sce_mean_with_weight_NCD1D2(%arg0: tensor<64x10x2x3xf32>, %arg1: tensor<64x2x3xi64>, %arg2: tensor<10xf32>) -> tensor<f32> {
+  %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %arg2) {reduction = "mean"} : (tensor<64x10x2x3xf32>, tensor<64x2x3xi64>, tensor<10xf32>) -> (tensor<f32>, none)
+  onnx.Return %output : tensor<f32>
+  // CHECK-LABEL: func @sce_mean_with_weight_NCD1D2
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}})
+  // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value}
+  // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64>
+  // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64>
+  // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10x2x3xi64>
+  // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32}
+  // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64}
+  // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]])
+  // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]])
+  // CHECK-NEXT: %[[UNSQUEEZE_AXES:.*]] = onnx.Constant dense<[0, 2, 3]> : tensor<3xi64>
+  // CHECK-NEXT: %[[UNSQUEEZE_WEIGHT:.*]] = "onnx.Unsqueeze"(%[[ARG2]], %[[UNSQUEEZE_AXES]]) : ({{.*}}) -> tensor<1x10x1x1xf32>
+  // CHECK-NEXT: %[[WEIGHT_PROD:.*]] = "onnx.Mul"(%[[PROD]], %[[UNSQUEEZE_WEIGHT]])
+  // CHECK-NEXT: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64>
+  // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[WEIGHT_PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1x2x3xf32>
+  // CHECK-NEXT: %[[SUML:.*]] = "onnx.ReduceSum"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<f32>
+
+  // This block is an `onnx.Einsum` expanded by a different rewrite pattern
+  // CHECK-NEXT: %[[TRANSPOSE_ONE_HOT:.*]] = "onnx.Transpose"(%[[ONE_HOT_LABELS_F]]) {perm = [0, 2, 3, 1]} : ({{.*}}) -> tensor<64x2x3x10xf32>
+  // CHECK-NEXT: %[[COLLAPSED_SHAPE:.*]] = onnx.Constant dense<[384, 10]> : tensor<2xi64>
+  // CHECK-NEXT: %[[COLLAPSED_ONE_HOT:.*]] = "onnx.Reshape"(%[[TRANSPOSE_ONE_HOT]], %[[COLLAPSED_SHAPE]]) {allowzero = 0 : si64} : ({{.*}}) -> tensor<384x10xf32>
+  // CHECK-NEXT: %[[EXPANDED_WEIGHT_SHAPE:.*]] = onnx.Constant dense<[10, 1]> : tensor<2xi64>
+  // CHECK-NEXT: %[[EXPANDED_WEIGHT:.*]] = "onnx.Reshape"(%[[ARG2]], %[[EXPANDED_WEIGHT_SHAPE]]) {allowzero = 0 : si64} : ({{.*}}) -> tensor<10x1xf32>
+  // CHECK-NEXT: %[[MATMUL:.*]] = "onnx.MatMul"(%[[COLLAPSED_ONE_HOT]], %[[EXPANDED_WEIGHT]]) : ({{.*}}) -> tensor<384x1xf32>
+  // CHECK-NEXT: %[[W_SHAPE:.*]] = onnx.Constant dense<[64, 2, 3]> : tensor<3xi64>
+  // CHECK-NEXT: %[[W:.*]] = "onnx.Reshape"(%[[MATMUL]], %[[W_SHAPE]]) {allowzero = 0 : si64} : ({{.*}}) -> tensor<64x2x3xf32>
+
+  // CHECK-NEXT: %[[SUMW:.*]] = "onnx.ReduceSum"(%[[W]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<f32>
+  // CHECK-NEXT: %[[MEAN:.*]] = "onnx.Div"(%[[SUML]], %[[SUMW]])
+  // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]])
+  // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor<f32>
+}
+
+// -----
+
+func.func @sce_mean_with_weight_NCD1D2_dynamic_num_classes(%arg0: tensor<64x?x2x3xf32>, %arg1: tensor<64x2x3xi64>, %arg2: tensor<?xf32>) -> tensor<f32> {
+  %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %arg2) {reduction = "mean"} : (tensor<64x?x2x3xf32>, tensor<64x2x3xi64>, tensor<?xf32>) -> (tensor<f32>, none)
+  onnx.Return %output : tensor<f32>
+  // CHECK-LABEL: func @sce_mean_with_weight_NCD1D2_dynamic_num_classes
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}})
+  // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value}
+  // CHECK-DAG: %[[NUM_CLASSES:.*]] = "onnx.Dim"(%[[ARG0]]) {axis = 1 : si64} : (tensor<64x?x2x3xf32>) -> tensor<1xi64>
+  // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64>
+  // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x?x2x3xi64>
+  // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32}
+  // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64}
+  // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]])
+  // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]])
+  // CHECK-NEXT: %[[UNSQUEEZE_AXES:.*]] = onnx.Constant dense<[0, 2, 3]> : tensor<3xi64>
+  // CHECK-NEXT: %[[UNSQUEEZE_WEIGHT:.*]] = "onnx.Unsqueeze"(%[[ARG2]], %[[UNSQUEEZE_AXES]]) : ({{.*}}) -> tensor<1x?x1x1xf32>
+  // CHECK-NEXT: %[[WEIGHT_PROD:.*]] = "onnx.Mul"(%[[PROD]], %[[UNSQUEEZE_WEIGHT]])
+  // CHECK-NEXT: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64>
+  // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[WEIGHT_PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1x2x3xf32>
+  // CHECK-NEXT: %[[SUML:.*]] = "onnx.ReduceSum"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<f32>
+  // CHECK-NEXT: %[[S_WEIGHTS:.*]] = "onnx.Einsum"(%[[ONE_HOT_LABELS_F]], %[[ARG2]]) {equation = "ij...,j->i..."}
+  // CHECK-NEXT: %[[SUMW:.*]] = "onnx.ReduceSum"(%[[S_WEIGHTS]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<f32>
+  // CHECK-NEXT: %[[MEAN:.*]] = "onnx.Div"(%[[SUML]], %[[SUMW]])
+  // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]])
+  // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor<f32>
+}
+
+// -----
+
+func.func @sce_sum(%arg0: tensor<64x10xf32>, %arg1: tensor<64xi64>) -> tensor<f32> {
+  %0 = "onnx.NoValue"() {value, weight} : () -> none
+  %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %0) {reduction = "sum"} : (tensor<64x10xf32>, tensor<64xi64>, none) -> (tensor<f32>, none)
+  onnx.Return %output : tensor<f32>
+  // CHECK-LABEL: func @sce_sum
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
+  // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value}
+  // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64>
+  // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64>
+  // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10xi64>
+  // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32}
+  // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64}
+  // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]])
+  // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]])
+  // CHECK-NEXT: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64>
+  // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1xf32>
+  // CHECK-NEXT: %[[MEAN:.*]] = "onnx.ReduceSum"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<f32>
+  // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]])
+  // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor<f32>
+}
+
+// -----
+
+func.func @sce_sum_with_weight(%arg0: tensor<64x10xf32>, %arg1: tensor<64xi64>, %arg2: tensor<10xf32>) -> tensor<f32> {
+  %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %arg2) {reduction = "sum"} : (tensor<64x10xf32>, tensor<64xi64>, tensor<10xf32>) -> (tensor<f32>, none)
+  onnx.Return %output : tensor<f32>
+  // CHECK-LABEL: func @sce_sum_with_weight
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}})
+  // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value}
+  // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64>
+  // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64>
+  // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10xi64>
+  // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32}
+  // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64}
+  // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]])
+  // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]])
+  // CHECK-NEXT: %[[UNSQUEEZE_AXES:.*]] = onnx.Constant dense<0> : tensor<1xi64>
+  // CHECK-NEXT: %[[UNSQUEEZE_WEIGHT:.*]] = "onnx.Unsqueeze"(%[[ARG2]], %[[UNSQUEEZE_AXES]]) : ({{.*}}) -> tensor<1x10xf32>
+  // CHECK-NEXT: %[[WEIGHT_PROD:.*]] = "onnx.Mul"(%[[PROD]], %[[UNSQUEEZE_WEIGHT]])
+  // CHECK-NEXT: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64>
+  // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[WEIGHT_PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1xf32>
+  // CHECK-NEXT: %[[MEAN:.*]] = "onnx.ReduceSum"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<f32>
+  // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]])
+  // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor<f32>
+}
+
+// -----
+
+func.func @sce_none(%arg0: tensor<64x10x2x3xf32>, %arg1: tensor<64x2x3xi64>) -> tensor<64x2x3xf32> {
+  %0 = "onnx.NoValue"() {value, weight} : () -> none
+  %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %0) {reduction = "none"} : (tensor<64x10x2x3xf32>, tensor<64x2x3xi64>, none) -> (tensor<64x2x3xf32>, none)
+  onnx.Return %output : tensor<64x2x3xf32>
+  // CHECK-LABEL: func @sce_none
+  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
+  // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value}
+  // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64>
+  // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64>
+  // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10x2x3xi64>
+  // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32}
+  // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64}
+  // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]])
+  // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]])
+  // CHECK-NEXT: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64>
+  // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1x2x3xf32>
+  // CHECK-NEXT: %[[SQUEEZE:.*]] = "onnx.Squeeze"(%[[SUM]], %[[REDUCE_AXIS]]) : ({{.*}}) -> tensor<64x2x3xf32>
+  // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[SQUEEZE]])
+  // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor<64x2x3xf32>
+}
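
Reviewer note: the following stand-alone sketch is not part of the patch; it is a minimal reference, with illustrative names, of the quantity the decomposition computes for a 2-D [N x C] score tensor with optional class weights and reduction = "mean" (softmax over the class axis, log, one-hot selection, per-sample weighting, and division by the summed weights, mirroring sumL / sumW above).

// sce_reference.cpp -- hypothetical helper, only for checking expected values.
#include <cmath>
#include <cstdio>
#include <vector>

double sceMean(const std::vector<std::vector<double>> &scores,
    const std::vector<int> &labels, const std::vector<double> &weights) {
  double lossSum = 0.0, weightSum = 0.0;
  for (size_t n = 0; n < scores.size(); ++n) {
    // Softmax over the class axis followed by log (steps 1-2 of the pattern).
    double denom = 0.0;
    for (double s : scores[n])
      denom += std::exp(s);
    double logProb = scores[n][labels[n]] - std::log(denom);
    // One-hot selection reduces the class axis to the single labeled term.
    double w = weights.empty() ? 1.0 : weights[labels[n]];
    lossSum += w * logProb;
    weightSum += w;
  }
  // "mean" with weights divides by the summed weights, not the sample count.
  return -lossSum / weightSum;
}

int main() {
  std::vector<std::vector<double>> scores = {{1.0, 2.0, 0.5}, {0.1, 0.2, 3.0}};
  std::vector<int> labels = {1, 2};
  std::printf("unweighted mean loss: %f\n", sceMean(scores, labels, {}));
  return 0;
}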