@@ -837,8 +837,11 @@ struct SoftmaxConverter final
837
837
};
838
838
} // namespace
839
839
840
- // MIGraphX pseudo code:
841
- // output[i] = static_cast<T>(input[i] - zero_pts[i]) * scales[i];
840
+ // MIGraphX implements:
841
+ // Let T = scale element type
842
+ // output[i] = (convert<T>(input[i]) - convert<T>(zero_pts[i])) * scales[i];
843
+ // For f32, this matches the ONNX reference. Dequantizing to f16, if it's ever
844
+ // done, will be less precise than the reference, but that's probably fine.
842
845
LogicalResult DeQuantizeLinearConverter::matchAndRewrite (
843
846
migraphx::DeQuantizeLinearOp op, OpAdaptor adaptor,
844
847
ConversionPatternRewriter &rewriter) const {
@@ -847,34 +850,25 @@ LogicalResult DeQuantizeLinearConverter::matchAndRewrite(
847
850
Value output = op.getOutput ();
848
851
Location loc = op->getLoc ();
849
852
850
- Value shifted = input;
853
+ Type outputType = getShapedElementTy (output);
854
+ Value upcastInput = createCastOp (rewriter, loc, outputType, input);
855
+
856
+ Value shifted = upcastInput;
851
857
if (auto bias = adaptor.getBias ()) {
852
- Type inElemTy = getShapedElementTy (input);
853
- Type biasElemTy = getShapedElementTy (bias);
854
- Type elementType =
855
- inElemTy.getIntOrFloatBitWidth () <= biasElemTy.getIntOrFloatBitWidth ()
856
- ? biasElemTy
857
- : inElemTy;
858
- if (inElemTy != elementType)
859
- input = createCastOp (rewriter, loc, elementType, shifted);
860
- if (biasElemTy != elementType)
861
- bias = createCastOp (rewriter, loc, elementType, bias);
862
- shifted =
863
- createOpAndInfer<tosa::SubOp>(rewriter, loc, elementType, input, bias);
858
+ Value upcastBias = createCastOp (rewriter, loc, outputType, bias);
859
+ shifted = createOpAndInfer<tosa::SubOp>(rewriter, loc, outputType,
860
+ upcastInput, upcastBias);
864
861
}
865
862
866
- Type outputType = getShapedElementTy (output);
867
- Value upCast = createCastOp (rewriter, loc, outputType, shifted);
868
-
869
863
Value scaled = createOpAndInfer<tosa::MulOp>(rewriter, loc, outputType,
870
- upCast , scale, /* shift=*/ 0 );
864
+ shifted , scale, /* shift=*/ 0 );
871
865
872
866
rewriter.replaceOp (op, scaled);
873
867
return success ();
874
868
}
875
869
876
870
// MIGraphX pseudo code:
877
- // int64_t quantized = static_cast<int32>(
871
+ // int32_t quantized = static_cast<int32_t>(
878
872
// std::round(input[i] / scales[i])) + zero_pts[i];
879
873
// output[i] = std::max(-128, std::min(127, quantized));
880
874
LogicalResult QuantizeLinearConverter::matchAndRewrite (
0 commit comments