Skip to content

Commit

Permalink
[GPU][Transformations] Fixed up predicates (#28607)
Browse files Browse the repository at this point in the history
### Details:
 - *Fixed up predicates*

Signed-off-by: Evgeniia Nugmanova <[email protected]>
  • Loading branch information
jane-intel authored Jan 22, 2025
1 parent db41516 commit 8a7c974
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ConvertStridedSlicesToVariadicSplit::ConvertStridedSlicesToVariadicSplit() {
return false;
user_count++;
}
return (user_count == num_users_to_fuse) && consumers_count(num_users_to_fuse);
return (user_count == num_users_to_fuse) && consumers_count(num_users_to_fuse)(output);
};

auto data_m = any_input();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,16 @@ UnsqueezeBroadcastReshapeMatmulFusion::UnsqueezeBroadcastReshapeMatmulFusion() {
return ov::as_type_ptr<ov::op::v1::Reshape>(output.get_node_shared_ptr()) == nullptr;
};

auto unsqueeze_predicate = [](const ov::Output<ov::Node>& output) -> bool {
return rank_equals(5)(output) && consumers_count(1);
};
auto unsqueeze_predicate = rank_equals(5) && consumers_count(1);

auto broadcast_predicate = [](const ov::Output<ov::Node>& output) -> bool {
const auto broadcast = ov::as_type_ptr<ov::op::v3::Broadcast>(output.get_node_shared_ptr());
if (!broadcast || broadcast->get_broadcast_spec().m_type != ov::op::BroadcastType::BIDIRECTIONAL)
return false;
return rank_equals(5)(output) && consumers_count(1);
return rank_equals(5)(output) && consumers_count(1)(output);
};

auto reshape_predicate = [](const ov::Output<ov::Node>& output) -> bool {
return rank_equals(4)(output) && consumers_count(1);
};
auto reshape_predicate = rank_equals(4) && consumers_count(1);

auto input_a_m = any_input(not_reshape);
auto input_b_m = wrap_type<ov::intel_gpu::op::KVCache>({any_input(), any_input()});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,16 @@ using ov::pass::pattern::op::Or;
UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
using namespace ov::pass::pattern;

auto unsqueeze_predicate = [](const ov::Output<ov::Node>& output) -> bool {
return rank_equals(5)(output) && consumers_count(1);
};
auto unsqueeze_predicate = rank_equals(5) && consumers_count(1);

auto broadcast_predicate = [](const ov::Output<ov::Node>& output) -> bool {
const auto broadcast = ov::as_type_ptr<ov::op::v3::Broadcast>(output.get_node_shared_ptr());
if (!broadcast || broadcast->get_broadcast_spec().m_type != ov::op::BroadcastType::BIDIRECTIONAL)
return false;
return rank_equals(5)(output) && consumers_count(1);
return rank_equals(5)(output) && consumers_count(1)(output);
};

auto reshape_predicate = [](const ov::Output<ov::Node>& output) -> bool {
return rank_equals(4)(output) && consumers_count(1);
};
auto reshape_predicate = rank_equals(4) && consumers_count(1);

auto input_a_m = any_input();
auto input_attn_mask = any_input();
Expand Down

0 comments on commit 8a7c974

Please sign in to comment.