
Parser changes to handle MatMulIntegerToFloat #3445

Open
wants to merge 12 commits into develop

Conversation

TedThemistokleous
Collaborator

@TedThemistokleous TedThemistokleous commented Sep 16, 2024

Changes to the MatMul parser to handle the Microsoft Contrib operator MatMulIntegerToFloat.

Since we have the scale and zero points in our operands, we can perform a multiply after the int8 biases are added, and then insert a regular dot on the scaled input values, which should give the same output as the input data types.

Able to leverage the existing set of tests for matmul.

Needs #3526, as this has uncovered a bug with dequantizelinear.
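
Roughly, the lowering looks like the following (a minimal sketch in MIGraphX parser terms; the exact helper names and argument handling in the patch differ):

```cpp
// Sketch: dequantize each quantized input with its scale/zero point, then
// run a regular dot so the result is already in the scale's float type.
// a0/a1 are the quantized A/B inputs; scale_a0/scale_a1 and zp_a0/zp_a1 are
// the per-input scales and zero points taken from the operand list.
auto dq_a0 = info.add_instruction(make_op("dequantizelinear"), a0, scale_a0, zp_a0);
auto dq_a1 = info.add_instruction(make_op("dequantizelinear"), a1, scale_a1, zp_a1);
auto dot   = info.add_instruction(make_op("dot"), dq_a0, dq_a1);
// The optional float bias (last input) is added after the dot.
if(has_valid_scale_bias)
    dot = info.add_common_op("add", dot, bias_arg);
```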

@TedThemistokleous TedThemistokleous self-assigned this Sep 16, 2024
@TedThemistokleous
Collaborator Author

TedThemistokleous commented Sep 16, 2024

TODO:

  • Add parser tests for error cases
  • Add parser tests for the base case
  • Add parser tests for bias and zero-point cases
  • Add verify tests for all of the above

@TedThemistokleous TedThemistokleous added the onnxruntime (PR changes interaction between MIGraphX and Onnxruntime), Onnx Operators (Adding or modifying an Onnx Operator in the MIGraphX codebase), and UAI labels Sep 16, 2024

codecov bot commented Sep 16, 2024

Codecov Report

Attention: Patch coverage is 90.54054% with 7 lines in your changes missing coverage. Please review.

Project coverage is 92.17%. Comparing base (2803cb3) to head (c0c8120).
Report is 4 commits behind head on develop.

Files with missing lines | Patch % | Lines
src/onnx/parse_matmul.cpp | 90.54% | 7 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3445      +/-   ##
===========================================
- Coverage    92.17%   92.17%   -0.01%     
===========================================
  Files          512      512              
  Lines        21387    21459      +72     
===========================================
+ Hits         19714    19779      +65     
- Misses        1673     1680       +7     

☔ View full report in Codecov by Sentry.

Updated parser to handle bias case as well as bad scale conditions

Initial float/half tests
bad scale tests
bad bias tests
avoid tidy screaming about complexity
TedThemistokleous and others added 2 commits October 11, 2024 17:45
Use dequantizelinear which eliminates the need to add in shifts due to int8/uint8 mismatches

still needs parser tests

if(not(contains(supported_dq_types, scale_arg->get_shape().type())))
{
    MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALDED: Scales must be float or half_type");
Contributor

SCALDED?

Contributor

Also, would this message be proper for the MatMul operator?

Collaborator Author

It won't reach there, as it's gated by whether the operator contains the scaled inputs. This variant of MatMul also includes the dequantize to convert the quantized input types to float.
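
Roughly (a sketch using the names from this diff; the real control flow in the patch may differ):

```cpp
// Only the MatMulIntegerToFloat variant ("quant_dot_scaled") carries scale
// inputs, so this check, and its throw, is never reached for plain
// MatMul/MatMulInteger.
if(has_scales)
{
    if(not contains(supported_dq_types, scale_arg->get_shape().type()))
        MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Scales must be float or half_type");
}
```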

{
has_valid_scale_bias = false;

if(args.size() > index)
Contributor

I don't see an index defined in MatMulIntegerToFloat. Is this for some other operator? Thanks.

Collaborator Author

It's the argument index. We're doing the check here so it's done for every arg.
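
A hypothetical completion of the snippet above (index is the optional input's position in the ONNX argument list: a_scale, b_scale, a_zero_point, b_zero_point, bias; the variable name below is illustrative):

```cpp
// Run once per optional argument: a missing trailing input simply leaves
// has_valid_scale_bias false instead of indexing out of range.
if(args.size() > index)
{
    scale_bias_arg       = args.at(index); // hypothetical output variable
    has_valid_scale_bias = true;
}
```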

MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Bias have same dim as matrix B column");
}

has_valid_scale_bias = true;
Contributor

As against invalid? ;-)

Collaborator Author

If the scale bias doesn't exist, no bias is added at the end of the MatMulIntegerToFloat.

instruction_ref& zp_a0,
bool no_zp)
{
if(no_zp)
Contributor

It seems only zp_a0 needs to be in the if clause?

Collaborator Author

No, the second input too, as it's bound by the column of the input vector (which is always 1-D).

I broke this out instead of adding it inline to encapsulate the logic.

return dequantized_op;
}

static instruction_ref handle_scaled_output(const onnx_parser::node_info& info,
Contributor

Too many parameters. Ideally they should be handled by a struct parameter.

Collaborator Author

They're the same set of parameters gathered by the operator. These are all needed for the dequantize steps and for adding the proper unsqueeze->transpose paths. Order matters here with respect to matrix input A or B.
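
For illustration, the shape of the refactor being suggested (hypothetical struct and field names, not part of the patch):

```cpp
// Bundle the dequantize inputs so handle_scaled_output takes one parameter.
// Field order mirrors the operand order for matrix inputs A and B.
struct scaled_matmul_inputs
{
    instruction_ref a0;       // quantized matrix A
    instruction_ref a1;       // quantized matrix B
    instruction_ref scale_a0; // a_scale
    instruction_ref scale_a1; // b_scale
    instruction_ref zp_a0;    // a_zero_point
    instruction_ref zp_a1;    // b_zero_point
    instruction_ref bias;     // optional float bias
    bool has_valid_scale_bias = false;
};

static instruction_ref handle_scaled_output(const onnx_parser::node_info& info,
                                            const scaled_matmul_inputs& in);
```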

@@ -173,12 +333,20 @@ struct parse_matmul : op_parser<parse_matmul>
}

auto is_quant_dot = opd.op_name == "quant_dot";
auto has_scales = opd.op_name == "quant_dot_scaled";
Contributor

A little confusing naming convention between the quant_dot* names and MatMul*. And then there is has_scales, which is presumably also a quant_dot.

Collaborator Author

What do you suggest I name it? quant_dot_dequant? This operator essentially takes in quantized input and dequantizes the output.

@@ -200,23 +368,50 @@ struct parse_matmul : op_parser<parse_matmul>
auto s0_lens = a0->get_shape().lens();
auto s1_lens = a1->get_shape().lens();

-        if(not is_quant_dot and args.size() > 2)
+        if(not is_quant_dot and args.size() > 2 and not has_scales)
Contributor

Would it be simpler to check if it is just a dot, instead of looking for quant_dot and quant_dot_scaled, as this clause seems to be doing? Thanks.

Collaborator Author

Sure, that can easily be swapped.
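
e.g. (sketch of the swapped check):

```cpp
// Equivalent, simpler form of the condition above: only the plain dot
// variant takes a third argument as a bias here.
if(opd.op_name == "dot" and args.size() > 2)
```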


// Only INT8 or UINT8 type currently supported
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::uint8_type,
                                                     migraphx::shape::int8_type};
const auto a0_type = a0->get_shape().type();
const auto a1_type = a1->get_shape().type();

-        if(is_quant_dot and
+        if((is_quant_dot or has_scales) and
Contributor

If it has_scales, then it is perhaps not a MATMULINTEGER, as shown in the exception message.

Collaborator Author

Sure, it's simple to just add op.name() here as part of the string. Both MatMulInteger and MatMulIntegerToFloat have the same error on this.
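
e.g. (sketch, reusing the parser's opd.op_name field; exact wording is illustrative):

```cpp
// Include the parsed operator's name so MatMulInteger and
// MatMulIntegerToFloat each report which op rejected the input types.
MIGRAPHX_THROW(opd.op_name + ": only INT8 or UINT8 types currently supported");
```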

@migraphx-bot
Collaborator

| Test | Batch | Rate new (c0c812) | Rate old (275f85) | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,261.05 | 3,260.33 | 0.02% | |
| torchvision-resnet50_fp16 | 64 | 6,995.46 | 6,996.93 | -0.02% | |
| torchvision-densenet121 | 32 | 2,436.48 | 2,436.13 | 0.01% | |
| torchvision-densenet121_fp16 | 32 | 4,085.19 | 4,080.65 | 0.11% | |
| torchvision-inceptionv3 | 32 | 1,639.24 | 1,639.01 | 0.01% | |
| torchvision-inceptionv3_fp16 | 32 | 2,762.50 | 2,759.60 | 0.11% | |
| cadene-inceptionv4 | 16 | 776.44 | 775.88 | 0.07% | |
| cadene-resnext64x4 | 16 | 811.17 | 811.11 | 0.01% | |
| slim-mobilenet | 64 | 7,533.85 | 7,539.27 | -0.07% | |
| slim-nasnetalarge | 64 | 211.54 | 211.54 | 0.00% | |
| slim-resnet50v2 | 64 | 3,505.27 | 3,505.47 | -0.01% | |
| bert-mrpc-onnx | 8 | 1,147.46 | 1,151.85 | -0.38% | |
| bert-mrpc-tf | 1 | 462.50 | 463.50 | -0.22% | |
| pytorch-examples-wlang-gru | 1 | 423.10 | 411.56 | 2.80% | |
| pytorch-examples-wlang-lstm | 1 | 408.41 | 388.27 | 5.19% | 🔆 |
| torchvision-resnet50_1 | 1 | 780.36 | 764.51 | 2.07% | |
| cadene-dpn92_1 | 1 | 398.48 | 402.09 | -0.90% | |
| cadene-resnext101_1 | 1 | 384.08 | 381.69 | 0.62% | |
| onnx-taau-downsample | 1 | 342.65 | 342.54 | 0.03% | |
| dlrm-criteoterabyte | 1 | 33.35 | 33.35 | 0.01% | |
| dlrm-criteoterabyte_fp16 | 1 | 52.79 | 52.74 | 0.08% | |
| agentmodel | 1 | 8,480.76 | 8,186.92 | 3.59% | 🔆 |
| unet_fp16 | 2 | 58.90 | 58.82 | 0.14% | |
| resnet50v1_fp16 | 1 | 969.99 | 953.22 | 1.76% | |
| resnet50v1_int8 | 1 | 1,016.65 | 1,003.11 | 1.35% | |
| bert_base_cased_fp16 | 64 | 1,172.15 | 1,170.57 | 0.13% | |
| bert_large_uncased_fp16 | 32 | 363.63 | 363.63 | -0.00% | |
| bert_large_fp16 | 1 | 199.98 | 201.70 | -0.85% | |
| distilgpt2_fp16 | 16 | 2,200.56 | 2,202.28 | -0.08% | |
| yolov5s | 1 | 534.60 | 537.82 | -0.60% | |
| tinyllama | 1 | 43.46 | 43.44 | 0.06% | |
| vicuna-fastchat | 1 | 170.74 | 173.47 | -1.57% | |
| whisper-tiny-encoder | 1 | 418.63 | 418.26 | 0.09% | |
| whisper-tiny-decoder | 1 | 428.98 | 436.61 | -1.75% | |

Check results before merge 🔆

@migraphx-bot
Collaborator


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance
