diff --git a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.cpp
index f22b32b23ea407..7661673e764949 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.cpp
@@ -24,6 +24,7 @@
 #include "openvino/pass/pattern/op/or.hpp"
 #include "openvino/pass/visualize_tree.hpp"
 #include "transformations/utils/utils.hpp"
+#include "openvino/opsets/opset8.hpp"
 
 namespace ov {
 namespace intel_gpu {
@@ -42,7 +43,8 @@ KVCacheFusionMatcher::KVCacheFusionMatcher() {
     auto gather_input = std::make_shared(OutputVector{past, convert_past});
     auto beam_idx = wrap_type();
     auto gather_past = wrap_type({gather_input, beam_idx, wrap_type()});
-    auto concat_past_input = std::make_shared(OutputVector{past, convert_past, gather_past});
+    auto gather_convert = wrap_type({gather_past});
+    auto concat_past_input = std::make_shared(OutputVector{past, convert_past, gather_past, gather_convert});
     auto concat = wrap_type({concat_past_input, any_input()});
     auto convert_present = wrap_type({concat});
     auto present_input = std::make_shared(OutputVector{concat, convert_present});
@@ -63,8 +65,10 @@ KVCacheFusionMatcher::KVCacheFusionMatcher() {
             return false;
 
         // TODO: Support conversion internally
-        if (!concat_node || concat_node->get_output_element_type(0) != past_node->get_output_element_type(0))
-            return false;
+        if (ov::is_type(concat_past_input)) {
+            if (!concat_node || concat_node->get_output_element_type(0) != past_node->get_output_element_type(0))
+                return false;
+        }
 
         auto variable = past_node->get_variable();
         auto concat_axis = concat_node->get_axis();
diff --git a/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp
index 2b0d2ed5eaf145..3f4480eaef0cbb 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp
@@ -52,8 +52,13 @@ UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
     auto reshape_b_m = wrap_type({broadcast_b_m, any_input()}, reshape_predicate);
     auto reshape_c_m = wrap_type({broadcast_c_m, any_input()}, reshape_predicate);
 
+    auto convert_reshape_b_m = wrap_type({reshape_b_m});
+    auto reshape_b_m_input = std::make_shared(OutputVector{reshape_b_m, convert_reshape_b_m});
+    auto convert_reshape_c_m = wrap_type({reshape_c_m});
+    auto reshape_c_m_input = std::make_shared(OutputVector{reshape_c_m, convert_reshape_c_m});
+
     auto sdpa_without_attn_mask_m = wrap_type({ input_a_m, reshape_b_m, reshape_c_m });
-    auto sdpa_with_attn_mask_m = wrap_type({ input_a_m, reshape_b_m, reshape_c_m, input_attn_mask });
+    auto sdpa_with_attn_mask_m = wrap_type({ input_a_m, reshape_b_m_input, reshape_c_m_input, input_attn_mask });
     auto sdpa_with_attn_mask_and_scale_m = wrap_type({ input_a_m, reshape_b_m, reshape_c_m, input_attn_mask, input_scale });
 
     auto sdpa_m = std::make_shared(OutputVector{sdpa_without_attn_mask_m, sdpa_with_attn_mask_m, sdpa_with_attn_mask_and_scale_m});
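Both hunks apply the same pattern-matching technique: an `ov::pass::pattern::op::Or` node lets the matcher accept either a tensor directly or a `Convert` applied to it, so an optional dtype conversion in the graph no longer prevents the fusion from matching. The sketch below is illustrative only, not the PR's code; the concrete op types passed to `wrap_type<>` (`v8::Gather`, `v0::Convert`, `v0::Concat`) and the pass name are assumptions, since the diff above has its template arguments stripped.

```cpp
#include <memory>

#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

// Illustrative sketch (hypothetical pass, not from the PR): match
// "Gather -> Concat" where an optional Convert may sit between the two nodes.
class OptionalConvertBeforeConcat : public ov::pass::MatcherPass {
public:
    OptionalConvertBeforeConcat() {
        using namespace ov::pass::pattern;

        auto gather = wrap_type<ov::op::v8::Gather>({any_input(), any_input(), any_input()});
        auto gather_convert = wrap_type<ov::op::v0::Convert>({gather});
        // The Or node makes the Convert optional: either branch satisfies the pattern.
        auto concat_input = std::make_shared<ov::pass::pattern::op::Or>(
            ov::OutputVector{gather, gather_convert});
        auto concat = wrap_type<ov::op::v0::Concat>({concat_input, any_input()});

        ov::matcher_pass_callback callback = [=](Matcher& m) {
            const auto& pattern_map = m.get_pattern_value_map();
            // The Convert entry is present only when that branch of the Or matched,
            // so a callback can branch on it (e.g. relax an element-type check).
            const bool convert_matched = pattern_map.count(gather_convert) > 0;
            (void)convert_matched;
            return false;  // sketch only: no graph rewrite is performed
        };

        register_matcher(std::make_shared<Matcher>(concat, "OptionalConvertBeforeConcat"), callback);
    }
};
```

This is the same shape as the changes above: `kv_cache_fusion.cpp` adds a `gather_convert` alternative in front of the `Concat`, and `unsqueeze_broadcast_reshape_sdpa_fusion.cpp` adds `reshape_*_m_input` alternatives in front of the SDPA node, in both cases by listing the converted and unconverted outputs in one `Or` input.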