Skip to content

Commit 338518b

Browse files
committed
Minor fix for negative axes
Add negative axes lit tests for MVN Signed-off-by: Zahid Wakeel <[email protected]>
1 parent c9fe057 commit 338518b

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1641,7 +1641,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
16411641
if (!Torch::isValidDim(dim, inputRank)) {
16421642
return failure();
16431643
}
1644-
reduced_shape[i] = 1;
1644+
reduced_shape[dim] = 1;
16451645
}
16461646
Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
16471647
resultType.getContext(), reduced_shape, resultType.getDtype());

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,6 +1646,30 @@ func.func @test_meanvarnorm_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch
16461646

16471647
// -----
16481648

1649+
// CHECK-LABEL: func.func @test_meanvarnorm_neg_axes(
1650+
func.func @test_meanvarnorm_neg_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
1651+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
1652+
// CHECK: %[[VAL_0:.*]] = torch.constant.bool true
1653+
// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
1654+
// CHECK: %[[VAL_2:.*]] = torch.constant.none
1655+
// CHECK: %[[VAL_3:.*]] = torch.constant.int -1
1656+
// CHECK: %[[VAL_4:.*]] = torch.constant.int -3
1657+
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
1658+
// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_5]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2,1],f32>
1659+
// CHECK: %[[VAL_7:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_5]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,1,2,1],f32>
1660+
// CHECK: %[[VAL_8:.*]] = torch.constant.int 1
1661+
// CHECK: %[[VAL_9:.*]] = torch.constant.float 1.000000e-09
1662+
// CHECK: %[[VAL_10:.*]] = torch.aten.add.Scalar %[[VAL_7]], %[[VAL_9]], %[[VAL_8]] : !torch.vtensor<[3,1,2,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,1,2,1],f32>
1663+
// CHECK: %[[VAL_11:.*]] = torch.aten.sqrt %[[VAL_10]] : !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,1,2,1],f32>
1664+
// CHECK: %[[VAL_12:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_6]], %[[VAL_8]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
1665+
// CHECK: %[[VAL_13:.*]] = torch.aten.div.Tensor %[[VAL_12]], %[[VAL_11]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
1666+
// CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5,2,2],f32>
1667+
// CHECK: }
1668+
%0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [-1 : si64, -3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
1669+
return %0 : !torch.vtensor<[3,5,2,2],f32>
1670+
1671+
// -----
1672+
16491673
// CHECK-LABEL: func.func @test_not_2d
16501674
func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
16511675
// CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>

0 commit comments

Comments
 (0)