diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index c6cb7dddfe3b..8d6b38814c37 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -311,6 +311,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) return failure(); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + if (binder.tensorOperandAtIndex(lhsZp, 2)) { lhsZp = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -323,9 +326,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); } - auto lhsTy = dyn_cast(lhs.getType()); - auto rhsTy = dyn_cast(rhs.getType()); - if (auto zpTy = dyn_cast(lhsZp.getType())) { for (auto dim : zpTy.getSizes()) if (dim != 1) @@ -366,8 +366,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rhs = rewriter.create( binder.getLoc(), rhsQTy, rhs, scale, rhsZp); - rewriter.replaceOpWithNewOp(binder.op, resultType, lhs, - rhs); + rewriter.replaceOpWithNewOp(binder.op, resultType, + lhs, rhs); return success(); }); patterns.onOp("Mul", 7, diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index d89e15756b24..2c8e1653258a 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -36,6 +36,26 @@ static void getZeroPoint(Value value, Value &zeropoint) { } } +// for uint8 types, we shift down by 128 so that we can faithfully +// represent the quantization with signed i8 types. +static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, + Value &zp, bool isUnsignedType, int64_t numBits) { + if (!isUnsignedType) + return; + int64_t minSI = -(1 << (numBits - 1)); + Value minSIValue = rewriter.create(loc, minSI, 32); + zp = rewriter.create(loc, zp, minSIValue); + minSIValue = rewriter.create(loc, minSI, numBits); + arg = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, ValueRange{arg}, + arg.getType().cast().getElementType(), + [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value result = + rewriter.create(loc, payloadArgs[0], minSIValue); + b.create(loc, result); + }); +} + static Value transposeValue(Location loc, Value value, ArrayRef perms, PatternRewriter &rewriter) { auto valueTy = value.getType().cast(); @@ -154,32 +174,12 @@ class ConvertAtenMmOp : public OpConversionPattern { rhsZeroPoint = rewriter.create( loc, rewriter.getI32Type(), rhsZeroPoint); - // for uint8 types, we shift down by 128 so that we can faithfully - // represent the quantization with signed i8 types. 
- auto signShift = [&](Value &arg, Value &zp, bool isUnsignedType, - int64_t numBits) { - if (!isUnsignedType) - return; - int64_t minSI = -std::pow(2, numBits - 1); - Value minSIValue = - rewriter.create(loc, minSI, 32); - zp = rewriter.create(loc, zp, minSIValue); - minSIValue = rewriter.create(loc, minSI, numBits); - arg = torch_to_linalg::createElementwiseLinalgGeneric( - rewriter, loc, ValueRange{arg}, - arg.getType().cast().getElementType(), - [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { - Value result = rewriter.create(loc, payloadArgs[0], - minSIValue); - b.create(loc, result); - }); - }; - + // change uint8 quantization -> int8 quantization int64_t numBits = lhsType.getElementType().cast().getWidth(); - signShift(lhs, lhsZeroPoint, isUnsigned, numBits); + signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits); numBits = rhsType.getElementType().cast().getWidth(); - signShift(rhs, rhsZeroPoint, isUnsignedR, numBits); + signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); matmul = rewriter @@ -302,10 +302,57 @@ class ConvertAtenMatmulOp : public OpConversionPattern { auto lhsType = lhs.getType().cast(); auto rhsType = rhs.getType().cast(); + auto lhsTorchType = cast(op.getSelf().getType()); + auto rhsTorchType = cast(op.getOther().getType()); + // Get the rank of both matrix. unsigned lhsRank = lhsType.getRank(); unsigned rhsRank = rhsType.getRank(); + Value lhsZeroPoint, rhsZeroPoint; + getZeroPoint(op.getSelf(), lhsZeroPoint); + getZeroPoint(op.getOther(), rhsZeroPoint); + + if (static_cast(lhsZeroPoint) != static_cast(rhsZeroPoint)) { + return rewriter.notifyMatchFailure( + op, "unsupported: aten.matmul with mixed quantization"); + } + + bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType); + bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType); + + if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) { + // Allows quantized types to mismatch + return rewriter.notifyMatchFailure( + op, "unsupported: aten.matmul with different input element types"); + } + + if (lhsZeroPoint) { + if (lhsRank < 2 || rhsRank < 2) { + return rewriter.notifyMatchFailure( + op, "unsupported: quantized aten.mm with vector"); + } + lhsZeroPoint = typeConverter->materializeTargetConversion( + rewriter, loc, + getTypeConverter()->convertType(lhsZeroPoint.getType()), + lhsZeroPoint); + rhsZeroPoint = typeConverter->materializeTargetConversion( + rewriter, loc, + getTypeConverter()->convertType(rhsZeroPoint.getType()), + rhsZeroPoint); + lhsZeroPoint = rewriter.create( + loc, rewriter.getI32Type(), lhsZeroPoint); + rhsZeroPoint = rewriter.create( + loc, rewriter.getI32Type(), rhsZeroPoint); + + // change uint8 quantization -> int8 quantization + int64_t numBits = + lhsType.getElementType().cast().getWidth(); + signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits); + numBits = rhsType.getElementType().cast().getWidth(); + signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); + } + Type newResultType = getTypeConverter()->convertType(op.getType()); auto resultType = cast(newResultType); Type elementType = resultType.getElementType(); @@ -366,7 +413,7 @@ class ConvertAtenMatmulOp : public OpConversionPattern { return success(); } - // Fourth Case: Vec-Vec Multiplication. + // Fourth Case: Mat-Mat Multiplication. 
if (lhsRank == 2 && rhsRank == 2) { Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); @@ -376,11 +423,20 @@ class ConvertAtenMatmulOp : public OpConversionPattern { Value zeroTensor = createZeroInitTensor( rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); - Value matmul = - rewriter - .create(loc, zeroTensor.getType(), - ValueRange{lhs, rhs}, zeroTensor) - .getResult(0); + Value matmul; + if (lhsZeroPoint) { + matmul = rewriter + .create( + loc, zeroTensor.getType(), + ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, + zeroTensor) + .getResult(0); + } else { + matmul = rewriter + .create(loc, zeroTensor.getType(), + ValueRange{lhs, rhs}, zeroTensor) + .getResult(0); + } rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } @@ -465,12 +521,24 @@ class ConvertAtenMatmulOp : public OpConversionPattern { rewriter, loc, ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1}, elementType); - Value matmul = - rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor) - .getResult(0); + Value matmul; + if (lhsZeroPoint) { + matmul = rewriter + .create( + loc, zeroTensor.getType(), + ValueRange{broadcastedLhs, broadcastedRhs, + lhsZeroPoint, rhsZeroPoint}, + zeroTensor) + .getResult(0); + rewriter.replaceOpWithNewOp(op, newResultType, + matmul); + return success(); + } + matmul = rewriter + .create( + loc, zeroTensor.getType(), + ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor) + .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } @@ -519,13 +587,24 @@ class ConvertAtenMatmulOp : public OpConversionPattern { loc, rewriter.getZeroAttr(elementType)); Value zeroTensor = rewriter.create(loc, c0, initTensor).getResult(0); - - Value batchMatMul = - rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{collapsedLhs, collapsedRhs}, zeroTensor) - .getResult(0); + Value batchMatMul; + + if (lhsZeroPoint) { + batchMatMul = rewriter + .create( + loc, zeroTensor.getType(), + ValueRange{collapsedLhs, collapsedRhs, + lhsZeroPoint, rhsZeroPoint}, + zeroTensor) + .getResult(0); + } else { + batchMatMul = + rewriter + .create( + loc, zeroTensor.getType(), + ValueRange{collapsedLhs, collapsedRhs}, zeroTensor) + .getResult(0); + } Value expandResult = rewriter.create( loc, resultType, batchMatMul, reassociation); rewriter.replaceOpWithNewOp(op, newResultType, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 30684e619cf2..6d93820b20de 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4049,8 +4049,8 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { if (sz == Torch::kUnknownSize || sz < 0) return nullptr; - ShapedType shapedty = - mlir::RankedTensorType::get(sizes, resultTensorType.getDtype()); + ShapedType shapedty = mlir::RankedTensorType::get( + resultTensorType.getSizes(), resultTensorType.getDtype()); auto elementType = shapedty.getElementType(); if (isa(elementType)) { diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index 1dc0cc9a9d8d..bff463c4cee6 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -285,9 +285,11 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { RemoveUnused, RemoveUnused, RemoveUnused, QuantizeOperands, - QuantizeOperands, QuantizeTransposedOperands, - QuantizeAccumulator, 
QuantizeBias>( - context); + QuantizeOperands, + QuantizeTransposedOperands, + QuantizeAccumulator, QuantizeOperands, + QuantizeTransposedOperands, QuantizeAccumulator, + QuantizeBias>(context); GreedyRewriteConfig config; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 920fc388599d..0f9bca677431 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -150,6 +150,7 @@ 'AtenIntBoolOpModule_basic', 'QuantizedMLP_basic', 'QuantizedSingleLayer_basic', + 'QuantizedBatchedInputSingleLayer_basic', 'QuantizedNoLayer_basic', 'ScalarImplicitFloatModule_basic', 'ScalarImplicitIntModule_basic', @@ -317,6 +318,12 @@ "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "AtenMmQuint8_basic", + "AtenMmQint8_basic", + "AtenMmQMixedSigni8_basic", + "AtenMatmulQMixedSigni8Transpose_basic", + "AtenMatmulQMixedSigni8_basic", + "AtenMatmulQint8MV_basic", + "AtenMatmulQint8_basic", "Conv2dQInt8Module_basic", # Dynamo not supporting conv_tbc @@ -1802,6 +1809,7 @@ "NeIntModule_basic", "QuantizedMLP_basic", "QuantizedSingleLayer_basic", + "QuantizedBatchedInputSingleLayer_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", "SliceEndSleStartModule_basic", @@ -1952,7 +1960,13 @@ "AtenIntTensorCharDtypeModule_basic", "AtenItemFpOpModule_basic", "AtenItemIntOpModule_basic", + "AtenMmQint8_basic", "AtenMmQuint8_basic", + "AtenMmQMixedSigni8_basic", + "AtenMatmulQMixedSigni8Transpose_basic", + "AtenMatmulQMixedSigni8_basic", + "AtenMatmulQint8MV_basic", + "AtenMatmulQint8_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", "AtenSubFloatModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index 1dea4cbe2e02..24d76b76d769 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -13,19 +13,13 @@ "NativeGroupNormBackwardModule_basic", "QuantizedMLP_basic", "QuantizedSingleLayer_basic", + "QuantizedBatchedInputSingleLayer_basic", + "AtenMatmulQint8MV_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "ElementwiseToDtypeI64ToUI8Module_basic", } -# TODO: Delete once torch 2.1.0 is released -if torch_version_for_comparison() < version.parse("2.1.0.dev"): - COMMON_TORCH_MLIR_LOWERING_XFAILS.update({ - "ScaledDotProductAttentionDifferentModule_basic", - "ScaledDotProductAttentionSameModule_basic" - }) - - def register_all_tests(): """Registers all the built-in E2E tests that Torch-MLIR provides.""" # Side-effecting import statements. 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index 80f02a7b5dc8..5ed5dc4dbdbe 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -266,7 +266,7 @@ def AtenMmIntTypes_basic(module, tu: TestUtils): # ============================================================================== -class AtenMmQuint8(torch.nn.Module): +class AtenMmQint8(torch.nn.Module): def __init__(self): super().__init__() @@ -278,17 +278,173 @@ def __init__(self): ([4, 3], torch.int8, True), ]) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8) + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.1, 8) + qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) + qy = torch.dequantize(qy) + qz = torch.mm(qx, qy) + return qz + +@register_test_case(module_factory=lambda: AtenMmQint8()) +def AtenMmQint8_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, 3, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class AtenMmQuint8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.uint8, True), + ([4, 3], torch.uint8, True), + ]) + def forward(self, x, y): + qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65) + qx = torch.dequantize(qx) + qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160) qy = torch.dequantize(qy) qz = torch.mm(qx, qy) return qz @register_test_case(module_factory=lambda: AtenMmQuint8()) def AtenMmQuint8_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=0, high=255).to(torch.uint8), + tu.randint(4, 3, low=0, high=255).to(torch.uint8)) + +# ============================================================================== + +class AtenMmQMixedSigni8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int8, True), + ([4, 3], torch.uint8, True), + ]) + def forward(self, x, y): + qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) + qx = torch.dequantize(qx) + qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) + qy = torch.dequantize(qy) + qz = torch.mm(qx, qy) + return qz + +@register_test_case(module_factory=lambda: AtenMmQMixedSigni8()) +def AtenMmQMixedSigni8_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), - tu.randint(4, 3, low=-128, high=127).to(torch.int8)) + tu.randint(4, 3, low=0, high=255).to(torch.uint8)) + +# ============================================================================== + +class AtenMatmulQint8MV(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int8, True), + ([4], torch.int8, True), + ]) + def forward(self, x, y): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) + qx = torch.dequantize(qx) + qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) + qy = torch.dequantize(qy) + qz = torch.matmul(qx, qy) + return qz + +@register_test_case(module_factory=lambda: AtenMatmulQint8MV()) +def AtenMatmulQint8MV_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, low=-128, 
high=127).to(torch.int8)) + +# ============================================================================== +class AtenMatmulQint8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4, -1, 3, 4], torch.int8, True), + ([-1, 4, 3], torch.int8, True), + ]) + def forward(self, x, y): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) + qx = torch.dequantize(qx) + qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) + qy = torch.dequantize(qy) + qz = torch.matmul(qx, qy) + return qz + +@register_test_case(module_factory=lambda: AtenMatmulQint8()) +def AtenMatmulQint8_basic(module, tu: TestUtils): + module.forward(tu.randint(4, 7, 3, 4, low=-128, high=127).to(torch.int8), + tu.randint(7, 4, 3, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class AtenMatmulQMixedSigni8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([7, -1, -1, -1], torch.int8, True), + ([-1, -1, -1], torch.uint8, True), + ]) + def forward(self, x, y): + qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) + qx = torch.dequantize(qx) + qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) + qy = torch.dequantize(qy) + qz = torch.matmul(qx, qy) + return qz + +@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8()) +def AtenMatmulQMixedSigni8_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8), + tu.randint(2, 4, 3, low=0, high=255).to(torch.uint8)) + +# ============================================================================== + +class AtenMatmulQMixedSigni8Transpose(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([7, -1, -1, -1], torch.int8, True), + ([-1, -1, -1], torch.uint8, True), + ]) + def forward(self, x, y): + qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) + qx = torch.dequantize(qx) + qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) + qy = torch.dequantize(qy) + qy = torch.transpose(qy, 1, 2) + qz = torch.matmul(qx, qy) + return qz + +@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose()) +def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 2, 3, 4, low=-128, high=127).to(torch.int8), + tu.randint(2, 6, 4, low=0, high=255).to(torch.uint8)) # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py index 9f217908a3fe..47e8adffdfd8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py @@ -15,6 +15,9 @@ def get_quant_model_input(): return 2 * torch.rand((1, 16)) - 1 +def get_batched_quant_model_input(): + return 2 * torch.rand((1, 2, 16)) - 1 + class QuantizedNoLayer(nn.Module): def __init__(self): super().__init__() @@ -83,6 +86,42 @@ def get_quantized_single_layer(): def QuantizedSingleLayer_basic(module, tu: TestUtils): module.forward(get_quant_model_input()) +class QuantizedBatchedInputSingleLayer(nn.Module): + def __init__(self): + super().__init__() + torch.random.manual_seed(0) + self.layers = nn.Sequential( + nn.Linear(16, 8), + ) + self.quantize = torch.quantization.QuantStub() + self.dequantize = 
torch.quantization.DeQuantStub() + + @export + @annotate_args([ + None, + ([1, 2, 16], torch.float32, True), + ]) + def forward(self, x): + x = self.quantize(x) + x = self.layers(x) + x = self.dequantize(x) + return x + +def get_batched_quantized_single_layer(): + model = QuantizedBatchedInputSingleLayer() + model.eval() + model.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(model, inplace=True) + torch.manual_seed(0) + for _ in range(32): + model(get_batched_quant_model_input()) + torch.quantization.convert(model, inplace=True) + return model + +@register_test_case(module_factory=get_batched_quantized_single_layer) +def QuantizedBatchedInputSingleLayer_basic(module, tu: TestUtils): + module.forward(get_batched_quant_model_input()) + class QuantizedMLP(nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index dc32918f6515..4aa3716dec12 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -345,13 +345,27 @@ func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vt // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00 // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8> // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8> - // CHECK: %[[MM:.+]] = torch.aten.mm %[[LMAKE]], %[[RMAKE]] + // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]] // CHECK: return %[[MM]] return %0 : !torch.vtensor<[4,2],si32> } // ----- +// CHECK-LABEL: @test_matmulinteger_batched +func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1: !torch.vtensor<[3,2],ui8>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[7,4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32> + // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2 + // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3 + // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00 + // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8> + // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8> + // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]] + // CHECK: return %[[MM]] + return %0 : !torch.vtensor<[7,4,2],si32> +} +// ----- + // CHECK-LABEL: func.func @test_mul func.func @test_mul(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", 
torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
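

The signShift helper factored out in Linear.cpp above rests on a simple identity: subtracting 128 from both a uint8 value and its zero point leaves the dequantized result unchanged, because scale * ((x - 128) - (zp - 128)) == scale * (x - zp). That is why uint8 quantization can be rewritten as int8 quantization before the lowering builds the quantized linalg matmul. The following is a minimal NumPy sketch of that equivalence, illustrative only and not part of the patch; the scale and zero-point values are made up.

import numpy as np

scale, zp = 0.1, 160                      # hypothetical uint8 quantization parameters
x_u8 = np.random.randint(0, 256, size=(3, 4), dtype=np.uint8)

# reference dequantization of the uint8 data
dequant_u8 = scale * (x_u8.astype(np.int32) - zp)

# shift the data and the zero point down by 128 (minSI for 8 bits),
# mirroring what signShift does so the values fit in signed i8
x_i8 = x_u8.astype(np.int32) - 128
zp_i8 = zp - 128

dequant_i8 = scale * (x_i8 - zp_i8)

# both paths dequantize to the same real values
assert np.allclose(dequant_u8, dequant_i8)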