Recompose QLinearMatMul and remove Quantize-Dequantize pairs (#2875)
* Recompose QLinearMatMul and remove Quantize-Dequantize pairs

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
tungld authored Jul 16, 2024
1 parent 7879d17 commit 4a241ef
Showing 11 changed files with 161 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.cpp
@@ -165,6 +165,15 @@ Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
  return layerNormOp.getY();
}

Value OnnxBuilder::qlinearMatMul(Type outputType, Value a, Value aScale,
    Value aZeroPoint, Value b, Value bScale, Value bZeroPoint, Value yScale,
    Value yZeroPoint) const {
  return createOpAndInferShapes<ONNXQLinearMatMulOp>(toTensor(outputType),
      toTensor(a), toTensor(aScale), toTensor(aZeroPoint), toTensor(b),
      toTensor(bScale), toTensor(bZeroPoint), toTensor(yScale),
      toTensor(yZeroPoint));
}

Value OnnxBuilder::RMSLayerNorm(Type outputType, Value input, Value scale,
    Value bias, int64_t axis, FloatAttr epsilon) const {
  IntegerAttr axisAttr = getSignedInt64Attr(axis);
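The new builder forwards to createOpAndInferShapes, so shape inference runs as part of creating the op. A minimal usage sketch, mirroring how the Recompose pass below calls it (rewriter, loc, outputType, and the eight operand Values are assumed to be in scope):

// Sketch only: emit a QLinearMatMul through the OnnxBuilder helper.
MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
Value y = create.onnx.qlinearMatMul(outputType, a, aScale, aZeroPoint, b,
    bScale, bZeroPoint, yScale, yZeroPoint);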
6 changes: 6 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
@@ -91,6 +91,12 @@ struct OnnxBuilder : DialectBuilder {
      mlir::Value scale, mlir::Value bias, int64_t axis,
      mlir::FloatAttr epsilon) const;

  // ONNXQLinearMatMulOp
  mlir::Value qlinearMatMul(mlir::Type outputType, mlir::Value a,
      mlir::Value aScale, mlir::Value aZeroPoint, mlir::Value b,
      mlir::Value bScale, mlir::Value bZeroPoint, mlir::Value yScale,
      mlir::Value yZeroPoint) const;

  // ONNXRMSLayerNormalizationOp, version with one output only (Y).
  mlir::Value RMSLayerNorm(mlir::Type outputType, mlir::Value input,
      mlir::Value scale, mlir::Value bias, int64_t axis,
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
@@ -1822,6 +1822,7 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace",

def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear",
  [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
  let hasCanonicalizer = 1;
  let summary = "ONNX DequantizeLinear operation";
  let description = [{
  The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the full precision tensor.
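Setting hasCanonicalizer makes ODS declare a canonicalization hook on the generated ONNXDequantizeLinearOp class, roughly the following; the matching definition is the one added to Canonicalize.cpp below:

// Declaration emitted into the generated op class (schematic):
static void getCanonicalizationPatterns(
    ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);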
6 changes: 6 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -1858,3 +1858,9 @@ void ONNXWhereOp::getCanonicalizationPatterns(
    RewritePatternSet &result, MLIRContext *context) {
  result.insert<AlwaysFalseWherePattern>(context);
}

// Register "canonicalization" patterns on the ONNXDequantizeLinearOp.
void ONNXDequantizeLinearOp::getCanonicalizationPatterns(
    RewritePatternSet &result, MLIRContext *context) {
  result.insert<QuantizeDequantizePattern>(context);
}
11 changes: 11 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -1055,4 +1055,15 @@ def AlwaysFalseWherePattern : Pat<
  [(IsNegativeSplatConstant:$negative_constant), (AreAllDimSizes:$dims)]
>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXDequantizeLinear
//===----------------------------------------------------------------------===//

// Convert QuantizeLinear+DequantizeLinear to Identity.
def QuantizeDequantizePattern : Pat<
  (ONNXDequantizeLinearOp (ONNXQuantizeLinearOp $x, $x_scale, $x_zeropoint, $x_axis, $x_saturate),
      $y_scale, $y_zeropoint, $y_axis),
  (replaceWithValue $x)
>;

#endif // ONNX_REWRITE
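For readers who prefer imperative patterns over DRR, the rewrite above is roughly equivalent to the following hand-written sketch (hypothetical code, not part of the patch; the real pattern is generated from the TableGen record):

// Hypothetical C++ equivalent of QuantizeDequantizePattern:
// DequantizeLinear(QuantizeLinear(x, ...), ...) ==> x.
struct QuantizeDequantizeRemovalSketch
    : public OpRewritePattern<ONNXDequantizeLinearOp> {
  using OpRewritePattern<ONNXDequantizeLinearOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXDequantizeLinearOp dqOp, PatternRewriter &rewriter) const final {
    auto qOp = dqOp.getX().getDefiningOp<ONNXQuantizeLinearOp>();
    if (!qOp)
      return failure();
    // Forward the original tensor; the Quantize-Dequantize pair becomes dead.
    rewriter.replaceOp(dqOp, qOp.getX());
    return success();
  }
};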
71 changes: 71 additions & 0 deletions src/Dialect/ONNX/Transforms/Recompose.cpp
@@ -340,6 +340,65 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
  }
};

struct RecomposeQLinearMatMulFromQuantizeLinearPattern
    : public OpRewritePattern<ONNXQuantizeLinearOp> {
  using OpRewritePattern<ONNXQuantizeLinearOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXQuantizeLinearOp qlOp, PatternRewriter &rewriter) const final {
    using namespace onnx_mlir;
    Location loc = qlOp.getLoc();
    // Match
    Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint;
    if (!matchQLinearMatMulPattern(qlOp, a, aScale, aZeroPoint, b, bScale,
            bZeroPoint, outScale, outZeroPoint))
      return failure();

    // Replace
    MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
    Value res = create.onnx.qlinearMatMul(qlOp.getY().getType(), a, aScale,
        aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint);

    rewriter.replaceOp(qlOp, res);
    return success();
  }

  // Recompose QLinearMatMul, starting from QuantizeLinear.
  // Pattern: DequantizeLinear + MatMul + QuantizeLinear.
  static bool matchQLinearMatMulPattern(ONNXQuantizeLinearOp op, Value &a,
      Value &aScale, Value &aZeroPoint, Value &b, Value &bScale,
      Value &bZeroPoint, Value &outScale, Value &outZeroPoint) {
    Operation *quantizeOp = op.getOperation();
    outScale = op.getYScale();
    outZeroPoint = op.getYZeroPoint();
    // Matching MatMul.
    Value qlX, matA, matB;
    Operation *matmulOp;
    bool matchMatMul = onnx_mlir::operandOfOpDefinedBy<ONNXMatMulOp>(
        matmulOp, quantizeOp, qlX, 0);
    if (!matchMatMul)
      return false;
    matA = cast<ONNXMatMulOp>(matmulOp).getA();
    matB = cast<ONNXMatMulOp>(matmulOp).getB();
    // Matching input A of MatMul.
    auto dlOpA = matA.getDefiningOp<ONNXDequantizeLinearOp>();
    if (!dlOpA)
      return false;
    a = dlOpA.getX();
    aScale = dlOpA.getXScale();
    aZeroPoint = dlOpA.getXZeroPoint();
    // Matching input B of MatMul.
    auto dlOpB = matB.getDefiningOp<ONNXDequantizeLinearOp>();
    if (!dlOpB)
      return false;
    b = dlOpB.getX();
    bScale = dlOpB.getXScale();
    bZeroPoint = dlOpB.getXZeroPoint();
    // Matched the pattern.
    return true;
  }
};
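// Note: matchQLinearMatMulPattern is a static helper on purpose:
// RecomposeONNXToONNXPass::runOnOperation (below) reuses the exact same match
// to mark an ONNXQuantizeLinearOp as dynamically illegal precisely when this
// pattern can fire, so the conversion driver knows to apply the rewrite.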

struct RecomposeONNXToONNXPass
    : public PassWrapper<RecomposeONNXToONNXPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RecomposeONNXToONNXPass)
@@ -387,6 +446,17 @@ void RecomposeONNXToONNXPass::runOnOperation() {
            op, x, scale, axis, epsilon, isRMSLayerNorm);
      });

  // Recompose QLinearMatMul, starting from QuantizeLinear.
  // Pattern: DequantizeLinear + MatMul + QuantizeLinear.
  target.addDynamicallyLegalOp<ONNXQuantizeLinearOp>(
      [](ONNXQuantizeLinearOp op) {
        Value a, aScale, aZeroPoint, b, bScale, bZeroPoint, outScale,
            outZeroPoint;
        return !RecomposeQLinearMatMulFromQuantizeLinearPattern::
            matchQLinearMatMulPattern(op, a, aScale, aZeroPoint, b, bScale,
                bZeroPoint, outScale, outZeroPoint);
      });

  RewritePatternSet patterns(context);
  onnx_mlir::getRecomposeONNXToONNXPatterns(patterns);

@@ -400,6 +470,7 @@ void onnx_mlir::getRecomposeONNXToONNXPatterns(
    mlir::RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<RecomposeLayerNormFromMulPattern>(context);
  patterns.insert<RecomposeQLinearMatMulFromQuantizeLinearPattern>(context);
}

/*!
2 changes: 1 addition & 1 deletion test/mlir/driver/compile_phases.mlir
@@ -1,4 +1,4 @@
- // RUN: onnx-mlir %s | FileCheck %s
+ // RUN: onnx-mlir %s -o %t| FileCheck %s && rm %t.so

// CHECK: [1/5] {{.*}} Importing ONNX Model to MLIR Module
// CHECK: [2/5] {{.*}} Compiling and Optimizing MLIR Module
24 changes: 24 additions & 0 deletions test/mlir/driver/static_quantization.mlir
@@ -0,0 +1,24 @@
// RUN: onnx-mlir --printIR --EmitONNXIR %s | FileCheck %s

// COM: Check that Dequantize-MatMul-Quantize is always recomposed to QLinearMatMul before the removal of Quantize-Dequantize pairs is applied.
// COM: Otherwise, recomposing QLinearMatMul would fail with a pattern mismatch, since the DequantizeLinear feeding MatMul would already have been removed.
module {
func.func @qlinear_matmul(%arg0: tensor<?x?x768xf32>, %arg1: tensor<f32>, %arg2: tensor<i8>, %arg3: tensor<768x768xi8>, %arg4: tensor<f32>, %arg5: tensor<i8>, %arg6: tensor<f32>, %arg7: tensor<i8>) -> (tensor<?x?x768xi8>) {
%0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
%1 = "onnx.DequantizeLinear"(%0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>
%2 = "onnx.DequantizeLinear"(%arg3, %arg4, %arg5) {axis = 1 : si64} : (tensor<768x768xi8>, tensor<f32>, tensor<i8>) -> tensor<768x768xf32>
%3 = "onnx.MatMul"(%1, %2) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
%4 = "onnx.QuantizeLinear"(%3, %arg6, %arg7) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
return %4: tensor<?x?x768xi8>

}
"onnx.EntryPoint"() {func = @main_graph} : () -> ()

// CHECK-LABEL: func.func @qlinear_matmul
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>, [[PARAM_3_:%.+]]: tensor<768x768xi8>, [[PARAM_4_:%.+]]: tensor<f32>, [[PARAM_5_:%.+]]: tensor<i8>, [[PARAM_6_:%.+]]: tensor<f32>, [[PARAM_7_:%.+]]: tensor<i8>) -> tensor<?x?x768xi8> {
// CHECK: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 1 : si64, onnx_node_name = "onnx.QuantizeLinear_0", saturate = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
// CHECK: [[VAR_1_:%.+]] = "onnx.QLinearMatMul"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]]) {onnx_node_name = "onnx.QLinearMatMul_1"} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>, tensor<768x768xi8>, tensor<f32>, tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
// CHECK: return [[VAR_1_]] : tensor<?x?x768xi8>
// CHECK: }
// CHECK: "onnx.EntryPoint"() {func = @main_graph} : () -> ()
}
15 changes: 15 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -1825,3 +1825,18 @@ func.func @test_where_with_always_false_3(%arg0: tensor<?x?xi64>) -> tensor<2xi64>
// CHECK: onnx.Return [[VAR_6_]] : tensor<2xi64>
// CHECK: }
}

// -----

func.func @test_dequantize_linear(%arg0: tensor<?x?x768xf32>, %arg1: tensor<f32>, %arg2: tensor<i8>) -> (tensor<?x?x768xf32>) {
%0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
%1 = "onnx.DequantizeLinear"(%0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>
return %1: tensor<?x?x768xf32>

// CHECK-LABEL: func.func @test_dequantize_linear
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xf32>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>) -> tensor<?x?x768xf32> {
// CHECK-NOT: "onnx.QuantizeLinear"
// CHECK-NOT: "onnx.DequantizeLinear"
// CHECK: return [[PARAM_0_]] : tensor<?x?x768xf32>
// CHECK: }
}
16 changes: 16 additions & 0 deletions test/mlir/onnx/onnx_recompose.mlir
@@ -245,3 +245,19 @@ func.func @rms_layer_norm_v2(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>,
// CHECK: }
}

// -----

// COM: QLinearMatMul
func.func @qlinear_matmul(%arg0: tensor<?x?x768xi8>, %arg1: tensor<f32>, %arg2: tensor<i8>, %arg3: tensor<768x768xi8>, %arg4: tensor<f32>, %arg5: tensor<i8>, %arg6: tensor<f32>, %arg7: tensor<i8>) -> (tensor<?x?x768xi8>) {
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>
%1 = "onnx.DequantizeLinear"(%arg3, %arg4, %arg5) {axis = 1 : si64} : (tensor<768x768xi8>, tensor<f32>, tensor<i8>) -> tensor<768x768xf32>
%2 = "onnx.MatMul"(%0, %1) : (tensor<?x?x768xf32>, tensor<768x768xf32>) -> tensor<?x?x768xf32>
%3 = "onnx.QuantizeLinear"(%2, %arg6, %arg7) {axis = 1 : si64} : (tensor<?x?x768xf32>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
return %3: tensor<?x?x768xi8>

// CHECK-LABEL: func.func @qlinear_matmul
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x768xi8>, [[PARAM_1_:%.+]]: tensor<f32>, [[PARAM_2_:%.+]]: tensor<i8>, [[PARAM_3_:%.+]]: tensor<768x768xi8>, [[PARAM_4_:%.+]]: tensor<f32>, [[PARAM_5_:%.+]]: tensor<i8>, [[PARAM_6_:%.+]]: tensor<f32>, [[PARAM_7_:%.+]]: tensor<i8>) -> tensor<?x?x768xi8> {
// CHECK: [[VAR_0_:%.+]] = "onnx.QLinearMatMul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]]) : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>, tensor<768x768xi8>, tensor<f32>, tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xi8>
// CHECK: return [[VAR_0_]] : tensor<?x?x768xi8>
// CHECK: }
}
1 change: 1 addition & 0 deletions utils/gen_onnx_mlir.py
@@ -332,6 +332,7 @@
"Cast",
"Constant",
"DepthToSpace",
"DequantizeLinear",
"Div",
"Dropout",
"Equal",
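The list touched here appears to be gen_onnx_mlir.py's registry of ops that declare a hand-written canonicalizer (presumably OpsWithCanonicalizer; an assumption, since the diff does not show the list's name). Registering "DequantizeLinear" makes the generator emit the "let hasCanonicalizer = 1;" line seen in ONNXOps.td.inc above, so the flag survives future regenerations of that file.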
