Skip to content

Commit 2d7c4ec

Browse files
authored
Merge pull request #1605 from ROCm/fix_dequantizelinear_bp
[BACKPORT] Fix dequantizelinear definition Backport a fix to the definition of dequantizelinear so that it can be used in MIGraphX 6.2.1
2 parents 2d2ac48 + 1324731 commit 2d7c4ec

File tree

2 files changed

+22
-26
lines changed

2 files changed

+22
-26
lines changed

mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -837,8 +837,11 @@ struct SoftmaxConverter final
837837
};
838838
} // namespace
839839

840-
// MIGraphX pseudo code:
841-
// output[i] = static_cast<T>(input[i] - zero_pts[i]) * scales[i];
840+
// MIGraphX implements:
841+
// Let T = scale element type
842+
// output[i] = (convert<T>(input[i]) - convert<T>(zero_pts[i])) * scales[i];
843+
// For f32, this matches ONNX reference, dequantizing to f16, if it's ever done
844+
// will be less precise than the reference but that's probably fine.
842845
LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
843846
migraphx::DeQuantizeLinearOp op, OpAdaptor adaptor,
844847
ConversionPatternRewriter &rewriter) const {
@@ -847,34 +850,25 @@ LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
847850
Value output = op.getOutput();
848851
Location loc = op->getLoc();
849852

850-
Value shifted = input;
853+
Type outputType = getShapedElementTy(output);
854+
Value upcastInput = createCastOp(rewriter, loc, outputType, input);
855+
856+
Value shifted = upcastInput;
851857
if (auto bias = adaptor.getBias()) {
852-
Type inElemTy = getShapedElementTy(input);
853-
Type biasElemTy = getShapedElementTy(bias);
854-
Type elementType =
855-
inElemTy.getIntOrFloatBitWidth() <= biasElemTy.getIntOrFloatBitWidth()
856-
? biasElemTy
857-
: inElemTy;
858-
if (inElemTy != elementType)
859-
input = createCastOp(rewriter, loc, elementType, shifted);
860-
if (biasElemTy != elementType)
861-
bias = createCastOp(rewriter, loc, elementType, bias);
862-
shifted =
863-
createOpAndInfer<tosa::SubOp>(rewriter, loc, elementType, input, bias);
858+
Value upcastBias = createCastOp(rewriter, loc, outputType, bias);
859+
shifted = createOpAndInfer<tosa::SubOp>(rewriter, loc, outputType,
860+
upcastInput, upcastBias);
864861
}
865862

866-
Type outputType = getShapedElementTy(output);
867-
Value upCast = createCastOp(rewriter, loc, outputType, shifted);
868-
869863
Value scaled = createOpAndInfer<tosa::MulOp>(rewriter, loc, outputType,
870-
upCast, scale, /*shift=*/0);
864+
shifted, scale, /*shift=*/0);
871865

872866
rewriter.replaceOp(op, scaled);
873867
return success();
874868
}
875869

876870
// MIGraphX pseudo code:
877-
// int64_t quantized = static_cast<int32>(
871+
// int32_t quantized = static_cast<int32>(
878872
// std::round(input[i] / scales[i])) + zero_pts[i];
879873
// output[i] = std::max(-128, std::min(127, quantized));
880874
LogicalResult QuantizeLinearConverter::matchAndRewrite(

mlir/test/Conversion/MIGraphXToTosa/mixr-to-tosa-ops.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,29 @@ module {
2828
}
2929

3030
// CHECK-LABEL: func @dequantize_scale_bias
31-
// CHECK: tosa.sub
3231
// CHECK: tosa.cast{{.*}}f32
32+
// CHECK: tosa.cast{{.*}}f32
33+
// CHECK: tosa.sub
3334
// CHECK: tosa.mul
3435
func.func @dequantize_scale_bias(%arg: !migraphx.shaped<1x112x112x64xi32, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi32, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
3536
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi32, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi32, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
3637
return %1 : !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1>
3738
}
3839

3940
// CHECK-LABEL: func @dequantize_wide_bias
40-
// CHECK: tosa.cast{{.*}}i32
41-
// CHECK: tosa.sub{{.*}}i32
4241
// CHECK: tosa.cast{{.*}}f32
42+
// CHECK: tosa.cast{{.*}}f32
43+
// CHECK: tosa.sub{{.*}}f32
4344
// CHECK: tosa.mul
4445
func.func @dequantize_wide_bias(%arg: !migraphx.shaped<1x112x112x64xi8, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi32, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
4546
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi8, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi32, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
4647
return %1 : !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1>
4748
}
4849

4950
// CHECK-LABEL: func @dequantize_wide_input
50-
// CHECK: tosa.cast{{.*}}i32
51-
// CHECK: tosa.sub{{.*}}i32
5251
// CHECK: tosa.cast{{.*}}f32
52+
// CHECK: tosa.cast{{.*}}f32
53+
// CHECK: tosa.sub{{.*}}f32
5354
// CHECK: tosa.mul
5455
func.func @dequantize_wide_input(%arg: !migraphx.shaped<1x112x112x64xi32, 802816x7168x64x1>, %scale: !migraphx.shaped<64xf32, 1>, %bias: !migraphx.shaped<64xi8, 1>) -> !migraphx.shaped<1x112x112x64xf32, 802816x7168x64x1> attributes {kernel = "mixr"} {
5556
%1 = migraphx.dequantizelinear %arg, %scale, %bias : <1x112x112x64xi32, 802816x7168x64x1>, <64xf32, 1>, !migraphx.shaped<64xi8, 1> -> <1x112x112x64xf32, 802816x7168x64x1>
@@ -142,8 +143,9 @@ module {
142143

143144
// CHECK-LABEL: func @conv_with_quant
144145
// CHECK: tosa.conv2d{{.*}} quantization_info
145-
// CHECK: tosa.sub
146146
// CHECK: tosa.cast
147+
// CHECK: tosa.cast
148+
// CHECK: tosa.sub
147149
// CHECK: tosa.mul
148150
// CHECK: tosa.reciprocal
149151
// CHECK: tosa.mul

0 commit comments

Comments
 (0)