diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
index b4e85812afffcf..cc9dde3d256e0d 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -38,6 +38,7 @@
 #include "openvino/op/subtract.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
+#include "openvino/op/scaled_dot_product_attention.hpp"
 #include "openvino/op/convert_like.hpp"
 
 #include "utils/split.hpp"
@@ -159,12 +160,6 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
 
     auto present_k = K;
     auto present_v = V;
 
-    std::shared_ptr<ov::Node> alpha;
-    if (scale == 0.0f) {
-        alpha = std::make_shared<v0::Sqrt>(head_size_node);
-    } else {
-        alpha = v0::Constant::create(ov::element::f32, ov::Shape{}, {1.0f / scale});
-    }
     const size_t kv_num_heads_factor = num_heads / kv_num_heads;
     if (kv_num_heads_factor > 1) {
         const auto kv_shape = std::make_shared<v3::ShapeOf>(K);
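Note on the removed `alpha` block (not part of the patch): `v13::ScaledDotProductAttention` applies the default `1/sqrt(head_size)` factor itself when no explicit scale input is provided, so the manual `Sqrt`/`Divide` is folded into the fused op introduced in the next hunk. As a reference for what the fused op computes per head, here is a minimal standalone sketch with assumed dense `std::vector` shapes; it is illustrative only, not the OpenVINO implementation or API:

    // Reference semantics of scaled dot-product attention for one head.
    // Shapes and names are assumptions for illustration.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    using Mat = std::vector<std::vector<float>>;

    // Q: q_len x d, K/V: kv_len x d, mask: q_len x kv_len (additive, 0 or -inf).
    // When the GQA scale attribute is 0, scale defaults to 1 / sqrt(d).
    Mat sdpa(const Mat& Q, const Mat& K, const Mat& V, const Mat& mask, float scale) {
        const size_t q_len = Q.size(), kv_len = K.size(), d = Q[0].size();
        Mat out(q_len, std::vector<float>(d, 0.0f));
        for (size_t i = 0; i < q_len; ++i) {
            std::vector<float> s(kv_len);
            float s_max = -INFINITY;
            for (size_t j = 0; j < kv_len; ++j) {  // scores = Q x K' * scale + mask
                float dot = 0.0f;
                for (size_t k = 0; k < d; ++k)
                    dot += Q[i][k] * K[j][k];
                s[j] = dot * scale + mask[i][j];
                s_max = std::max(s_max, s[j]);
            }
            float denom = 0.0f;
            for (size_t j = 0; j < kv_len; ++j) {  // softmax over the key axis
                s[j] = std::exp(s[j] - s_max);
                denom += s[j];
            }
            for (size_t j = 0; j < kv_len; ++j)  // out = softmax(scores) x V
                for (size_t k = 0; k < d; ++k)
                    out[i][k] += (s[j] / denom) * V[j][k];
        }
        return out;
    }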
@@ -183,47 +178,43 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
         K = std::make_shared<v1::Reshape>(K, extended_kv_shape, false);
         V = std::make_shared<v1::Reshape>(V, extended_kv_shape, false);
     }
 
-    // compute softmax((Q x K') / sqrt(head_size))
-    std::shared_ptr<ov::Node> softmax_input = std::make_shared<v0::MatMul>(Q, K, false, true);
-    softmax_input = std::make_shared<v1::Divide>(softmax_input, alpha);
     // need to apply low-triangle mask to attention score.
-    auto past_seq_len_scalar = std::make_shared<v1::Reshape>(past_sequence_length, one_without_shape, false);
-    auto seqlens_1d_scalar = std::make_shared<v1::Reshape>(seqlens_1d, one_without_shape, false);
+    // two steps, construct the total_sequence x total_sequence triangle, then slice the current length
+    auto seqlens_1d_scalar = std::make_shared<v1::Reshape>(seqlens_1d, one_without_shape, false);  // 12 or 13
     std::shared_ptr<ov::Node> mask_per_line_node =
         std::make_shared<v4::Range>(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}),
                                     seqlens_1d_scalar,
                                     one_without_shape,
-                                    ov::element::i64);
-    mask_per_line_node = std::make_shared<v0::Unsqueeze>(mask_per_line_node, zero);
-    auto minus_inf = v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()});
-    auto mask_shape = std::make_shared<v0::Concat>(ov::NodeVector{current_seqlen_size, seqlens_1d}, 0);
-    auto compare_mask = std::make_shared<v3::Broadcast>(mask_per_line_node, mask_shape);
-
-    std::shared_ptr<ov::Node> vertical_range =
-        std::make_shared<v4::Range>(past_seq_len_scalar, seqlens_1d_scalar, one_without_shape, ov::element::i64);
-    vertical_range = std::make_shared<v0::Unsqueeze>(vertical_range, one);
-
-    auto triu = std::make_shared<v1::Greater>(compare_mask, vertical_range);
-    auto typed_zero = std::make_shared<v1::ConvertLike>(zero, softmax_input);
-    auto typed_minus_inf = std::make_shared<v1::ConvertLike>(minus_inf, softmax_input);
-    auto minus_inf_mask = std::make_shared<v3::Broadcast>(typed_minus_inf, mask_shape);
-    auto atten_mask = std::make_shared<v1::Select>(triu, minus_inf_mask, typed_zero);
-
-    std::shared_ptr<ov::Node> softmax_input_added = std::make_shared<v1::Add>(softmax_input, atten_mask);
-    // softmax((Q x K' + mask) / sqrt(head_size))
-    const auto softmax = std::make_shared<v8::Softmax>(softmax_input_added, 3);
-
-    // softmax((Q x K' + mask) / sqrt(head_size)) x V
-    std::shared_ptr<ov::Node> output = std::make_shared<v0::MatMul>(softmax, V);
+                                    ov::element::i64);  // [0,1,2,...,]
+    auto hori_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, zero);  // 1x12 or 1x13
+    auto vert_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, one);   // 12x1 or 13x1
+    auto triu = std::make_shared<v1::Greater>(hori_range, vert_range);            // 12x12 or 13x13
+    auto typed_zero = v0::Constant::create(ov::element::f32, ov::Shape{}, {0});
+    auto minus_inf = v0::Constant::create(ov::element::f32, ov::Shape{}, {-std::numeric_limits<float>::infinity()});
+    auto atten_mask = std::make_shared<v1::Select>(triu, minus_inf, typed_zero);  // 12x12 or 13x13
+    auto atten_mask_sliced = std::make_shared<v8::Slice>(atten_mask,
+                                                         past_sequence_length,
+                                                         seqlens_1d,
+                                                         one,
+                                                         zero);  // slice to current query seqlen, 12x12 or 1x13
+
+    // compute softmax((Q x K') / sqrt(head_size)) x V
+    std::shared_ptr<ov::Node> qga_output;
+    if (scale != 0.0f) {
+        auto scale_node = v0::Constant::create(ov::element::f32, Shape{}, {scale});
+        qga_output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, scale_node, false);
+    } else {
+        qga_output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, false);
+    }
 
     // transpose the result from (batch_size, num_heads, sequence_length, head_size)
     // to (batch_size, sequence_length, num_heads, head_size)
-    output = std::make_shared<v1::Transpose>(output, perm);
+    auto qga_output_transposed = std::make_shared<v1::Transpose>(qga_output, perm);
 
     auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1});
 
     // reshape the result from (batch_size, sequence_length, num_heads, head_size)
     // to (batch_size, sequence_length, num_heads * head_size)
-    output = std::make_shared<v1::Reshape>(output, dim_merge_shape, true);
+    auto output = std::make_shared<v1::Reshape>(qga_output_transposed, dim_merge_shape, true);
 
     return {output, present_k, present_v};
 }
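The mask construction above is the main functional change: instead of broadcasting a per-row range against a vertical range and adding the mask to the raw scores, the graph now builds one `total_len x total_len` upper-triangular `-inf` mask and slices out the rows belonging to the current query window (`[past_sequence_length, seqlens_1d)`), which covers both prefill (`past_len == 0`, square mask) and decode (`cur_len == 1`, a single all-visible row). A standalone sketch of the sliced mask, with assumed example lengths (commentary only, not part of the patch):

    // Prints the cur_len x total_len additive mask the graph feeds to SDPA.
    #include <cstdio>
    #include <limits>
    #include <vector>

    int main() {
        const int past_len = 9, cur_len = 3, total_len = past_len + cur_len;
        const float ninf = -std::numeric_limits<float>::infinity();
        // Step 1: total_len x total_len triangle; positions j > i are masked.
        std::vector<std::vector<float>> full(total_len, std::vector<float>(total_len, 0.0f));
        for (int i = 0; i < total_len; ++i)
            for (int j = i + 1; j < total_len; ++j)
                full[i][j] = ninf;
        // Step 2: slice rows [past_len, total_len), so query token t may attend
        // to every position <= past_len + t.
        for (int i = past_len; i < total_len; ++i) {
            for (int j = 0; j < total_len; ++j)
                std::printf("%5s", full[i][j] == ninf ? "-inf" : "0");
            std::printf("\n");
        }
    }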
@@ -254,75 +245,41 @@ std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
                                           bool interleaved) {
     auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
     auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
-    auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
 
-    auto cache_shape = std::make_shared<v3::ShapeOf>(cos_cache);
-    auto cache_last_dim = get_dimensions(cache_shape, {-1});
-    auto cache_1st_dim = get_dimensions(cache_shape, {0});
+    auto slice_cache_dim_shape = seqlen_k;
 
-    // TODO: check the shape
-    auto input_shape = std::make_shared<v3::ShapeOf>(input);
+    auto cos = std::make_shared<v8::Slice>(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero);
+    auto sin = std::make_shared<v8::Slice>(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero);
 
-    auto dim_bns = get_dimensions(input_shape, {0, 1, 2});
-    // auto dim_head_size = get_dimensions(input_shape, {3});
-    // half_last_dim is same as cos_cache
-    std::shared_ptr<ov::Node> half_last_dim = cache_last_dim;
+    if (interleaved) {
+        auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
 
-    auto real_cache_shape = std::make_shared<v0::Concat>(ov::NodeVector{cache_1st_dim, dim_head_size}, 0);
-    auto slice_cache_dim_shape = seqlen_k;
+        auto cache_shape = std::make_shared<v3::ShapeOf>(cos_cache);
+        auto cache_last_dim = get_dimensions(cos_cache, {-1});
 
-    // auto end_lens = std::make_shared<v1::Subtract>(half_last_dim, one);
-    // auto masks = std::make_shared<v1::Pad>(one,
-    //                                        zero,
-    //                                        end_lens,
-    //                                        op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}),
-    //                                        ov::op::PadMode::CONSTANT);
-    auto masks = std::make_shared<v3::Broadcast>(one, half_last_dim);
+        auto input_shape = std::make_shared<v3::ShapeOf>(input);
+
+        auto dim_bns = get_dimensions(input_shape, {0, 1, 2});
+        std::shared_ptr<ov::Node> half_last_dim = cache_last_dim;
 
-    if (interleaved) {
         auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1});
         auto split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, two}, 0);
         auto reshaped_input = std::make_shared<v1::Reshape>(input, split_input_shape, false);
-        auto first_half = std::make_shared<v8::Slice>(reshaped_input, zero, one, one, negtive_one);
-        auto second_half = std::make_shared<v8::Slice>(reshaped_input, one, two, one, negtive_one);
-
-        auto second_input = std::make_shared<v0::Concat>(ov::NodeVector{second_half, first_half}, -1);
-
-        auto mask_shape = std::make_shared<v0::Concat>(ov::NodeVector{half_last_dim, one}, 0);
-        auto reshaped_mask = std::make_shared<v1::Reshape>(masks, mask_shape, false);
-        auto negtive_mask = std::make_shared<v0::Negative>(reshaped_mask);
-        auto concat_mask = std::make_shared<v0::Concat>(ov::NodeVector{negtive_mask, reshaped_mask}, -1);
-        auto real_mask = std::make_shared<v1::Reshape>(concat_mask, dim_head_size, false);
-        auto mask_f32 = std::make_shared<v0::Convert>(real_mask, ov::element::f32);
-
-        auto real_input0 = std::make_shared<v1::Reshape>(reshaped_input, input_shape, false);
-        auto real_input1 = std::make_shared<v1::Reshape>(second_input, input_shape, false);
-
-        auto new_cache_shape = std::make_shared<v0::Concat>(ov::NodeVector{cache_shape, two}, 0);
-        auto temp_cache_shape = std::make_shared<v0::Concat>(ov::NodeVector{cache_shape, one}, 0);
-        auto cos_cache_reshape = std::make_shared<v1::Reshape>(cos_cache, temp_cache_shape, false);
-        auto sin_cache_reshape = std::make_shared<v1::Reshape>(sin_cache, temp_cache_shape, false);
-        auto cos_cache_broadcasted = std::make_shared<v3::Broadcast>(cos_cache_reshape, new_cache_shape);
-        auto sin_cache_broadcasted = std::make_shared<v3::Broadcast>(sin_cache_reshape, new_cache_shape);
-        auto real_cos_input = std::make_shared<v1::Reshape>(cos_cache_broadcasted, real_cache_shape, false);
-        auto real_sin_input = std::make_shared<v1::Reshape>(sin_cache_broadcasted, real_cache_shape, false);
-        auto sliced_cos_input =
-            std::make_shared<v8::Slice>(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero);
-        auto sliced_sin_input =
-            std::make_shared<v8::Slice>(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero);
-        auto add_input0 = std::make_shared<v1::Multiply>(real_input0, sliced_cos_input);
-        auto add_input1 = std::make_shared<v1::Multiply>(real_input1, sliced_sin_input);
-        auto multi_input1 = std::make_shared<v1::Multiply>(add_input1, mask_f32);
-        auto result = std::make_shared<v1::Add>(add_input0, multi_input1);
-        return result;
+
+        auto in_split = ov::op::util::make_split(reshaped_input, 2, -1);
+        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos),
+                                                    std::make_shared<v1::Multiply>(in_split[1], sin));
+        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin),
+                                               std::make_shared<v1::Multiply>(in_split[1], cos));
+
+        auto concat_ret = std::make_shared<v0::Concat>(ov::NodeVector{res_0, res_1}, -1);
+        return std::make_shared<v1::Reshape>(concat_ret, input_shape, false);
     } else {
-        auto cos =
-            std::make_shared<v8::Slice>(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero);
-        auto sin =
-            std::make_shared<v8::Slice>(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero);
         auto in_split = ov::op::util::make_split(input, 2, -1);
-        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos), std::make_shared<v1::Multiply>(in_split[1], sin));
-        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin), std::make_shared<v1::Multiply>(in_split[1], cos));
+        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos),
+                                                    std::make_shared<v1::Multiply>(in_split[1], sin));
+        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin),
+                                               std::make_shared<v1::Multiply>(in_split[1], cos));
        return std::make_shared<v0::Concat>(ov::NodeVector{res_0, res_1}, -1);
    }
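Both `rotaryEmbedding` branches now share the sliced `cos`/`sin` tables and the same rotation, `x0' = x0*cos - x1*sin`, `x1' = x0*sin + x1*cos`; they differ only in how the feature pairs are laid out: interleaved pairs `(x[2i], x[2i+1])` after the reshape to `(..., half, 2)`, versus split halves `(x[i], x[i+half])`. A scalar sketch of the two layouts (function names and dense shapes assumed for illustration, not part of the patch):

    // Scalar reference for the rotation both branches implement.
    #include <cstddef>
    #include <vector>

    // Non-interleaved ("half") layout: pairs are (x[i], x[i + d/2]).
    void rope_half(std::vector<float>& x, const std::vector<float>& cos_row,
                   const std::vector<float>& sin_row) {
        const size_t half = x.size() / 2;
        for (size_t i = 0; i < half; ++i) {
            const float x0 = x[i], x1 = x[i + half];
            x[i] = x0 * cos_row[i] - x1 * sin_row[i];
            x[i + half] = x0 * sin_row[i] + x1 * cos_row[i];
        }
    }

    // Interleaved layout: pairs are (x[2*i], x[2*i + 1]).
    void rope_interleaved(std::vector<float>& x, const std::vector<float>& cos_row,
                          const std::vector<float>& sin_row) {
        const size_t half = x.size() / 2;
        for (size_t i = 0; i < half; ++i) {
            const float x0 = x[2 * i], x1 = x[2 * i + 1];
            x[2 * i] = x0 * cos_row[i] - x1 * sin_row[i];
            x[2 * i + 1] = x0 * sin_row[i] + x1 * cos_row[i];
        }
    }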