Catch invalid broadcasts in pointwise reduce fusion #3659
src/fuse_reduce.cpp
Outdated
    auto bstrides = b->get_shape().strides();

    return std::all_of(
        reduce_axes.begin(), reduce_axes.end(), [&](auto a) { return bstrides.at(a) == 0; });
I think you also need to check that the non-reduce axes are not being broadcasted.
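One way to read this suggestion is sketched below. It is a standalone illustration, not the code from the PR: `is_valid_broadcast_sketch` is a hypothetical name, it takes plain vectors of lengths and strides instead of an `instruction_ref`, and the rule that a stride of 0 on a dimension of length 1 does not count as a broadcast is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical standalone version of the check. It works on plain vectors
// so it can be compiled without MIGraphX.
inline bool is_valid_broadcast_sketch(const std::vector<std::size_t>& lens,
                                      const std::vector<std::size_t>& strides,
                                      const std::vector<std::size_t>& reduce_axes)
{
    if(lens.size() != strides.size())
        return false;
    // Every reduce axis must refer to an existing dimension
    if(std::any_of(reduce_axes.begin(), reduce_axes.end(), [&](std::size_t a) {
           return a >= strides.size();
       }))
        return false;
    for(std::size_t axis = 0; axis < strides.size(); axis++)
    {
        bool is_reduce_axis =
            std::find(reduce_axes.begin(), reduce_axes.end(), axis) != reduce_axes.end();
        bool is_broadcasted = strides[axis] == 0;
        if(is_reduce_axis)
        {
            // Same condition as the original check: reduce axes need stride 0
            if(not is_broadcasted)
                return false;
        }
        else if(is_broadcasted and lens[axis] != 1)
        {
            // Assumed extra condition: a non-reduce axis must not be broadcasted
            return false;
        }
    }
    return true;
}
```

With lengths `{2, 3, 4}` and reduce axis 2, strides `{12, 4, 0}` pass, while `{0, 4, 0}` are rejected because axis 0 is broadcasted without being reduced.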
src/fuse_reduce.cpp
Outdated
    auto axes       = reduce->get_operator().to_value().at("axes").to_vector<size_t>();
    auto broadcast  = r.instructions["broadcast"];
    auto fbroadcast = r.instructions["final_broadcast"];
    if(not(is_valid_broadcast(broadcast, axes) and is_valid_broadcast(fbroadcast, axes)))
You should just check "broadcast" and not "final_broadcast" (which is the broadcast after contiguous, if there is one).
src/fuse_reduce.cpp
Outdated
    - match::find_matches(
    -     mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
    + match::find_matches(mpm, find_reduce_pointwise{}, find_pointwise_reduce{});
    + match::find_matches(mpm, find_reduce_reduce{});
Why is this moved out?
Because when it hits the broadcast condition, it's not fusing the other fused_reduce in the input.
We might need to add this to the matcher then.
src/fuse_reduce.cpp
Outdated
        auto broadcast = r.instructions["broadcast"];
        if(not is_valid_broadcast(broadcast, axes))
            return;
    }
This check needs to be applied to all matchers.
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

    @@            Coverage Diff             @@
    ##           develop    #3659      +/-   ##
    ===========================================
    + Coverage    92.18%   92.20%   +0.02%
    ===========================================
      Files          513      513
      Lines        21596    21653      +57
    ===========================================
    + Hits         19908    19965      +57
      Misses        1688     1688
So we could create a matcher like this to check the axes:

    template <class M>
    static auto match_broadcast_axes(M m)
    {
        return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
            optional<instruction_ref> result = m.match(ctx, ins);
            if(contains(ctx.instructions, "broadcast"))
            {
                auto axes      = ins->get_operator().to_value().at("axes").to_vector<size_t>();
                // Look up the bound broadcast through the matcher context
                auto broadcast = ctx.instructions["broadcast"];
                if(not is_valid_broadcast(broadcast, axes))
                    return nullopt;
            }
            return result;
        });
    }

And then we could change:

    static auto match_broadcastable_input(const std::string& op, const std::string& name)
    {
        auto match_op                 = match::name(op)(used_once_except_broadcast()).bind(name);
        auto match_op_input           = any_input(match_op, match::used_once());
        auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
        return match::any_of(match_op_input, match_broadcast_axes(broadcast_match_op_input));
    }

I don't know if there is a better way of doing this.
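The wrapping idea in the proposal above (run the inner matcher, then veto the result if the bound broadcast is invalid) can be illustrated without MIGraphX. The sketch below is only a toy model under stated assumptions: `toy_context`, `toy_matcher`, and `reject_invalid_broadcast` are hypothetical names, instructions are modeled as plain `int` ids, and none of this is the library's matcher API.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>

// Toy stand-ins for the matcher context and instruction handle (hypothetical).
struct toy_context
{
    std::map<std::string, int> instructions; // bound name -> instruction id
};
using toy_matcher = std::function<std::optional<int>(toy_context&, int)>;

// Wrap a matcher so that a successful match is rejected when the
// instruction bound as "broadcast" fails the validity predicate.
inline toy_matcher reject_invalid_broadcast(toy_matcher m, std::function<bool(int)> is_valid)
{
    return [=](toy_context& ctx, int ins) -> std::optional<int> {
        std::optional<int> result = m(ctx, ins);
        if(not result.has_value())
            return std::nullopt;
        auto it = ctx.instructions.find("broadcast");
        // Only veto when a broadcast was actually bound by the inner matcher
        if(it != ctx.instructions.end() and not is_valid(it->second))
            return std::nullopt;
        return result;
    };
}
```

Putting the check in the wrapper means the inner matcher simply fails for an invalid broadcast, so the pass never has to bail out from inside its apply step.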
Maybe a verify test could be added for this as well.
how can I force the verify test to run with
This build is OK for merge ✅
🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
Well I am hoping to remove layernorm fusion soon in #3465. For now the verify tests can run with the layernorm fusion, and when it gets removed in #3465, this test will check that it works correctly.