
Catch invalid broadcasts in pointwise reduce fusion #3659

Merged: 8 commits into develop, Nov 29, 2024
Conversation

shivadbhavsar (Contributor)

No description provided.

@shivadbhavsar shivadbhavsar added the bugfix Fixes a bug found in the code. label Nov 26, 2024
@shivadbhavsar shivadbhavsar self-assigned this Nov 26, 2024
auto bstrides = b->get_shape().strides();

return std::all_of(
reduce_axes.begin(), reduce_axes.end(), [&](auto a) { return bstrides.at(a) == 0; });
Collaborator:
I think you need to check if the non-reduce axes are not being broadcasted as well.
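A minimal sketch of what the combined check could look like, assuming the helper keeps the is_valid_broadcast name used in this PR; it ignores the subtlety that a non-reduce axis of extent 1 can legitimately have stride 0, so treat it as an illustration rather than the merged code:

// Sketch only: a broadcast feeding the reduce is valid when every reduce
// axis is broadcast (stride 0) and no non-reduce axis is broadcast.
static bool is_valid_broadcast(instruction_ref b, const std::vector<std::size_t>& reduce_axes)
{
    auto bstrides = b->get_shape().strides();
    for(std::size_t axis = 0; axis < bstrides.size(); axis++)
    {
        bool is_reduce_axis = contains(reduce_axes, axis);
        bool is_broadcast   = bstrides.at(axis) == 0;
        if(is_reduce_axis != is_broadcast)
            return false;
    }
    return true;
}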

auto axes = reduce->get_operator().to_value().at("axes").to_vector<size_t>();
auto broadcast = r.instructions["broadcast"];
auto fbroadcast = r.instructions["final_broadcast"];
if(not(is_valid_broadcast(broadcast, axes) and is_valid_broadcast(fbroadcast, axes)))
Collaborator:
You should just check "broadcast" and not "final_broadcast" (which is the broadcast after the contiguous, if there is one).

-    match::find_matches(
-        mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
+    match::find_matches(mpm, find_reduce_pointwise{}, find_pointwise_reduce{});
+    match::find_matches(mpm, find_reduce_reduce{});
Collaborator:
Why is this moved out?

Contributor (author):
Because when it hits the broadcast condition, it's not fusing the other fused_reduce in the input.

Collaborator:
We might need to add this to the matcher then.

auto broadcast = r.instructions["broadcast"];
if(not is_valid_broadcast(broadcast, axes))
return;
}
Collaborator:
This check needs to be applied to all matchers.
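For illustration, the guard could be repeated at the top of each matcher's apply along these lines (a sketch that assumes each matcher binds its reduce instruction as "reduce"; this is not the merged code):

// Sketch: the same early-out in find_reduce_pointwise, find_pointwise_reduce,
// and find_reduce_reduce before any fusion is attempted.
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
    auto reduce = r.instructions.at("reduce");
    auto axes   = reduce->get_operator().to_value().at("axes").to_vector<size_t>();
    if(contains(r.instructions, "broadcast") and
       not is_valid_broadcast(r.instructions.at("broadcast"), axes))
        return;
    // ... the matcher's existing fusion logic continues here ...
}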

codecov bot commented Nov 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.20%. Comparing base (162b008) to head (4433e9b).
Report is 5 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3659      +/-   ##
===========================================
+ Coverage    92.18%   92.20%   +0.02%     
===========================================
  Files          513      513              
  Lines        21596    21653      +57     
===========================================
+ Hits         19908    19965      +57     
  Misses        1688     1688              


pfultz2 (Collaborator) commented Nov 26, 2024

So we could create a matcher like this to check the axes:

template<class M>
static auto match_broadcast_axes(M m)
{
    return make_basic_fun_matcher(
        // The explicit return type keeps both return statements consistent.
        [=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
            optional<instruction_ref> result = m.match(ctx, ins);
            if(contains(ctx.instructions, "broadcast"))
            {
                // Reject the match when the bound broadcast covers axes
                // that the reduce does not.
                auto axes      = ins->get_operator().to_value().at("axes").to_vector<size_t>();
                auto broadcast = ctx.instructions.at("broadcast");
                if(not is_valid_broadcast(broadcast, axes))
                    return nullopt;
            }
            return result;
        });
}

And then we could change match_broadcastable_input to use it like this:

static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
    auto match_op                 = match::name(op)(used_once_except_broadcast()).bind(name);
    auto match_op_input           = any_input(match_op, match::used_once());
    auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
    return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
}

I don't know if there is a better way of doing this.

pfultz2 (Collaborator) commented Nov 27, 2024

Maybe a verify test could be added for this as well.

shivadbhavsar (Contributor, author) commented:

> Maybe a verify test could be added for this as well.

How can I force the verify test to run with MIGRAPHX_DISABLE_LAYERNORM_FUSION=1? Or will CI do that at some point?

migraphx-bot (Collaborator) commented:

| Test | Batch | Rate new (4433e9) | Rate old (162b00) | Diff |
|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,256.58 | 3,257.53 | -0.03% |
| torchvision-resnet50_fp16 | 64 | 6,988.90 | 6,981.74 | 0.10% |
| torchvision-densenet121 | 32 | 2,434.33 | 2,435.44 | -0.05% |
| torchvision-densenet121_fp16 | 32 | 4,056.93 | 4,088.61 | -0.77% |
| torchvision-inceptionv3 | 32 | 1,629.29 | 1,629.88 | -0.04% |
| torchvision-inceptionv3_fp16 | 32 | 2,742.90 | 2,745.06 | -0.08% |
| cadene-inceptionv4 | 16 | 764.40 | 764.50 | -0.01% |
| cadene-resnext64x4 | 16 | 810.75 | 810.97 | -0.03% |
| slim-mobilenet | 64 | 7,390.45 | 7,464.92 | -1.00% |
| slim-nasnetalarge | 64 | 208.49 | 208.49 | 0.00% |
| slim-resnet50v2 | 64 | 3,443.34 | 3,440.82 | 0.07% |
| bert-mrpc-onnx | 8 | 1,149.61 | 1,145.86 | 0.33% |
| bert-mrpc-tf | 1 | 468.61 | 468.54 | 0.01% |
| pytorch-examples-wlang-gru | 1 | 417.22 | 418.01 | -0.19% |
| pytorch-examples-wlang-lstm | 1 | 408.92 | 408.51 | 0.10% |
| torchvision-resnet50_1 | 1 | 771.66 | 778.29 | -0.85% |
| cadene-dpn92_1 | 1 | 396.33 | 396.81 | -0.12% |
| cadene-resnext101_1 | 1 | 382.10 | 382.45 | -0.09% |
| onnx-taau-downsample | 1 | 345.54 | 345.96 | -0.12% |
| dlrm-criteoterabyte | 1 | 33.33 | 33.33 | -0.01% |
| dlrm-criteoterabyte_fp16 | 1 | 52.71 | 52.75 | -0.07% |
| agentmodel | 1 | 8,127.97 | 8,309.93 | -2.19% |
| unet_fp16 | 2 | 58.86 | 58.83 | 0.06% |
| resnet50v1_fp16 | 1 | 956.61 | 942.38 | 1.51% |
| resnet50v1_int8 | 1 | 1,014.63 | 1,025.97 | -1.11% |
| bert_base_cased_fp16 | 64 | 1,170.12 | 1,170.17 | -0.00% |
| bert_large_uncased_fp16 | 32 | 363.32 | 363.14 | 0.05% |
| bert_large_fp16 | 1 | 198.84 | 198.79 | 0.02% |
| distilgpt2_fp16 | 16 | 2,200.20 | 2,200.72 | -0.02% |
| yolov5s | 1 | 532.38 | 532.21 | 0.03% |
| tinyllama | 1 | 43.43 | 43.63 | -0.44% |
| vicuna-fastchat | 1 | 173.03 | 173.33 | -0.17% |
| whisper-tiny-encoder | 1 | 417.77 | 417.76 | 0.00% |
| whisper-tiny-decoder | 1 | 435.04 | 428.46 | 1.53% |

This build is OK for merge ✅

migraphx-bot (Collaborator) commented:

✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
✅ agentmodel: PASSED: MIGraphX meets tolerance
✅ unet: PASSED: MIGraphX meets tolerance
✅ resnet50v1: PASSED: MIGraphX meets tolerance
✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
✅ bert_large: PASSED: MIGraphX meets tolerance
✅ yolov5s: PASSED: MIGraphX meets tolerance
✅ tinyllama: PASSED: MIGraphX meets tolerance
✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

pfultz2 (Collaborator) commented Nov 27, 2024

> How can I force the verify test to run with MIGRAPHX_DISABLE_LAYERNORM_FUSION=1? Or will CI do that at some point?

Well, I am hoping to remove layernorm fusion soon in #3465. For now the verify tests can run with layernorm fusion enabled, and when it gets removed in #3465, this test will check that it works correctly.

causten merged commit 241e24e into develop on Nov 29, 2024 (43 of 44 checks passed).
causten deleted the fix_reduce_fusion branch on Nov 29, 2024 at 15:56.