diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp
index 7033d813b7e671..1441e3e67a633d 100644
--- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp
@@ -86,8 +86,9 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
     auto cos_cache = node->input_value(7);
     auto sin_cache = node->input_value(8);
 
-    auto T = Q.get_element_type();
+    // The length of all tokens (past + current) is `seqlens_k` + 1, current = Q.shape[2], past = `seqlens_k` + 1 - current
+    const auto T = Q.get_element_type();
 
     const auto node_shape = std::make_shared<v3::ShapeOf>(Q);
     const auto batch_size = get_dimensions(node_shape, {0});
     const auto current_seqlen_size = get_dimensions(node_shape, {1});
@@ -262,12 +263,20 @@ std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
         auto reshaped_input = std::make_shared<v1::Reshape>(input, split_input_shape, false);
 
         auto in_split = make_split(reshaped_input, 2, -1);
-        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos),
-                                                    std::make_shared<v1::Multiply>(in_split[1], sin));
-        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin),
-                                               std::make_shared<v1::Multiply>(in_split[1], cos));
+        split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim}, 0);
+        auto in_split_0 = std::make_shared<v1::Reshape>(in_split[0], split_input_shape, false);
+        auto in_split_1 = std::make_shared<v1::Reshape>(in_split[1], split_input_shape, false);
+
+        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split_0, cos),
+                                                    std::make_shared<v1::Multiply>(in_split_1, sin));
+        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split_0, sin),
+                                               std::make_shared<v1::Multiply>(in_split_1, cos));
+
+        split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, one}, 0);
+        auto res_0_5d = std::make_shared<v1::Reshape>(res_0, split_input_shape, false);
+        auto res_1_5d = std::make_shared<v1::Reshape>(res_1, split_input_shape, false);
 
-        auto concat_ret = std::make_shared<v0::Concat>(ov::NodeVector{res_0, res_1}, -1);
+        auto concat_ret = std::make_shared<v0::Concat>(ov::NodeVector{res_0_5d, res_1_5d}, -1);
         return std::make_shared<v1::Reshape>(concat_ret, input_shape, false);
     } else {
         auto in_split = make_split(input, 2, -1);
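Reviewer note: in the interleaved branch above, the two `make_split` outputs keep the trailing singleton axis of the `[..., half, 2]` reshape, so multiplying them directly with the `[seq, half]` cos/sin slices broadcast along the wrong axis. The fix flattens each half to `[..., half]` before the multiplies, re-expands the results to `[..., half, 1]`, and concatenates on the last axis so the final `Reshape` back to the input shape restores the even/odd interleaving. A minimal scalar sketch of the per-token computation the rewritten branch now expresses (the helper name and signature are illustrative, not OpenVINO API):

```cpp
#include <cstddef>
#include <vector>

// Interleaved RoPE reference for one token's head vector. `x` holds
// head_size values; `cos_row`/`sin_row` hold head_size / 2 cached values
// for this token's position.
std::vector<float> rotary_interleaved_ref(const std::vector<float>& x,
                                          const std::vector<float>& cos_row,
                                          const std::vector<float>& sin_row) {
    const std::size_t half = x.size() / 2;
    std::vector<float> out(x.size());
    for (std::size_t i = 0; i < half; ++i) {
        const float x0 = x[2 * i];      // in_split_0 after the flattening reshape
        const float x1 = x[2 * i + 1];  // in_split_1
        out[2 * i] = x0 * cos_row[i] - x1 * sin_row[i];      // res_0
        out[2 * i + 1] = x0 * sin_row[i] + x1 * cos_row[i];  // res_1
    }
    return out;  // the trailing Concat/Reshape pair restores this even/odd layout
}
```

Feeding the 16-float K slice of `query` from the `gqa_past_1_input_1_rotary_interleaved` test below through this reference with row 1 of the caches reproduces the second half of its `expected_present_key`.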
diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp
index b9e9accb27bd64..475110e66bf5e3 100644
--- a/src/core/src/op/group_query_attention.cpp
+++ b/src/core/src/op/group_query_attention.cpp
@@ -52,13 +52,11 @@ void GroupQueryAttention::validate_and_infer_types() {
     }
     Dimension output_kv_len;
     PartialShape kv_past_shape = get_input_partial_shape(3);
-    // FIXME: Original GQA spec depends on the identical tensor set for input/output, but we cannot know it in advance,
-    // hence we base on sequence dimension static/dynamic
-    // https://github.com/openvinotoolkit/openvino/pull/27648
-    if (kv_past_shape[2].is_dynamic()) {
+    // FIXME: https://github.com/openvinotoolkit/openvino/pull/27648
+    if (kv_past_shape[2].is_static()) {
         output_kv_len = kv_past_shape[2] + sequence_len;
     } else {
-        output_kv_len = kv_past_shape[2];
+        output_kv_len = ov::Dimension();
     }
     auto element_type = get_input_element_type(0);
     NODE_VALIDATION_CHECK(this,
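The revised rule reads: a static past-KV length propagates `past + current`, while a dynamic past length now yields a fully dynamic present length instead of inheriting the past dimension. A hedged one-liner equivalent (an illustrative helper, not the actual `validate_and_infer_types` code):

```cpp
#include <openvino/core/dimension.hpp>

// Present-KV length: past + current when the past length is known,
// otherwise fully dynamic.
ov::Dimension present_kv_len(const ov::Dimension& past_len, const ov::Dimension& current_len) {
    return past_len.is_static() ? past_len + current_len : ov::Dimension();
}
```

For the `gqa_past_1_*` models below, a past length of 1 plus one new token infers a present length of 2, matching the declared `present_key`/`present_value` shapes.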
op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 1 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 0 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} \ No newline at end of file diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt new file mode 100644 index 00000000000000..f5ec39c9b0bd8e --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary.prototxt @@ -0,0 +1,244 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 0 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: 
"GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "total_sequence_length" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 8 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 32 + } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 16 + } + } + } + } + } +} +opset_import { + domain: "" + version: 21 +} diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt new file mode 100644 index 00000000000000..b61cf39552efc7 --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt @@ -0,0 +1,244 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "total_sequence_length" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 1 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 64 + } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 16 + } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } 
diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt
new file mode 100644
index 00000000000000..b61cf39552efc7
--- /dev/null
+++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_past_1_input_1_rotary_interleaved.prototxt
@@ -0,0 +1,244 @@
+ir_version: 10
+graph {
+  node {
+    input: "query"
+    input: ""
+    input: ""
+    input: "past_key"
+    input: "past_value"
+    input: "seqlens_k"
+    input: "total_sequence_length"
+    input: "cos_cache"
+    input: "sin_cache"
+    output: "output"
+    output: "present_key"
+    output: "present_value"
+    name: "GroupQueryAttention_0"
+    op_type: "GroupQueryAttention"
+    attribute {
+      name: "do_rotary"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "kv_num_heads"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "local_window_size"
+      i: -1
+      type: INT
+    }
+    attribute {
+      name: "num_heads"
+      i: 2
+      type: INT
+    }
+    attribute {
+      name: "rotary_interleaved"
+      i: 1
+      type: INT
+    }
+    attribute {
+      name: "smooth_softmax"
+      i: 0
+      type: INT
+    }
+    attribute {
+      name: "softcap"
+      f: 0
+      type: FLOAT
+    }
+    domain: "com.microsoft"
+  }
+  name: "GroupQueryAttention_Graph"
+  input {
+    name: "query"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 64
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "past_key"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "past_value"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "seqlens_k"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "total_sequence_length"
+    type {
+      tensor_type {
+        elem_type: 6
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "cos_cache"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "sin_cache"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 8
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "output"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 32
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "present_key"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "present_value"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 16
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  domain: ""
+  version: 21
+}
diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp
index 47a336f1749417..37cd5449484543 100644
--- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp
+++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp
@@ -1682,3 +1682,415 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_qlinear_mul) {
     test_case.add_expected_output(Shape{2, 2}, expected_output);
     test_case.run();
 }
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) {
+    const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary.onnx");
+
+    std::vector<float> query = {
+        -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500,
+        0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551,
+        -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459,
+        -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544,
+        0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215,
+        -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873,
+    };
+    std::vector<float> past_key = {};
+    std::vector<float> past_value = {};
+    std::vector<int32_t> seqlens_k = {0};
+    std::vector<int32_t> total_sequence_length = {1};
+    std::vector<float> cos_cache = {
+        0.8437,
+        -0.7849,
+        -0.7829,
+        0.4581,
+        -0.9870,
+        0.6273,
+        -0.9483,
+        -0.9962,
+    };
+    std::vector<float> sin_cache = {
+        0.5368,
+        0.6196,
+        -0.6222,
+        0.8889,
+        0.1605,
+        -0.7788,
+        0.3174,
+        -0.0872,
+    };
+
+    std::vector<float> expected_output = {-0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562,
+                                          -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873,
+                                          -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562,
+                                          -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873};
+
+    std::vector<float> expected_present_key = {1.2561098,
+                                               1.0199738,
+                                               -0.05948371,
+                                               -0.16574995,
+                                               2.5059946,
+                                               -1.738188,
+                                               -0.03158256,
+                                               -0.35975295,
+                                               1.0918287,
+                                               -0.90313876,
+                                               -0.4790303,
+                                               0.67029977,
+                                               -0.87039495,
+                                               0.7783688,
+                                               -0.81333745,
+                                               0.89886224};
+
+    std::vector<float> expected_present_value = {-0.2188,
+                                                 -2.4351,
+                                                 -0.0729,
+                                                 -0.034,
+                                                 0.9625,
+                                                 0.3492,
+                                                 -0.9215,
+                                                 -0.0562,
+                                                 -0.6227,
+                                                 -0.4637,
+                                                 1.9218,
+                                                 -0.4025,
+                                                 0.1239,
+                                                 1.1648,
+                                                 0.9234,
+                                                 1.3873};
+
+    auto test_case = ov::test::TestCase(model, s_device);
+    test_case.add_input(query);
+    test_case.add_input(past_key);
+    test_case.add_input(past_value);
+    test_case.add_input(seqlens_k);
+    test_case.add_input(total_sequence_length);
+    test_case.add_input(cos_cache);
+    test_case.add_input(sin_cache);
+    test_case.add_expected_output(Shape{1, 1, 32}, expected_output);
+    test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_key);
+    test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_value);
+    test_case.run_with_tolerance_as_fp();
+}
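Why these expectations hold: with an empty past and a single new token there is exactly one key/value position, so each softmax row is identically 1 and every query head returns V unchanged. `query` packs `[Q (2x16) | K (16) | V (16)]`, and V is not rotated, so `expected_present_value` is simply the last 16 floats of `query` and `expected_output` repeats them once per query head. A self-contained sanity check under those assumptions (names are hypothetical):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// `query_packed` is the 64-float packed [Q | K | V] input from the test
// above; `output` and `present_value` are the expected tensors.
void check_single_token_case(const std::vector<float>& query_packed,
                             const std::vector<float>& output,
                             const std::vector<float>& present_value) {
    const std::vector<float> v(query_packed.end() - 16, query_packed.end());
    assert(present_value == v);  // V passes through unrotated
    for (int head = 0; head < 2; ++head) {  // both query heads share the one KV head
        assert(std::equal(v.begin(), v.end(), output.begin() + head * 16));
    }
}
```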
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) {
+    const auto model = convert_model("com.microsoft/gqa_past_0_input_1_rotary_interleaved.onnx");
+
+    std::vector<float> query = {
+        -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500,
+        0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551,
+        -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459,
+        -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544,
+        0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215,
+        -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873,
+    };
+    std::vector<float> past_key = {};
+    std::vector<float> past_value = {};
+    std::vector<int32_t> seqlens_k = {0};
+    std::vector<int32_t> total_sequence_length = {1};
+    std::vector<float> cos_cache = {
+        0.8437,
+        -0.7849,
+        -0.7829,
+        0.4581,
+        -0.9870,
+        0.6273,
+        -0.9483,
+        -0.9962,
+    };
+    std::vector<float> sin_cache = {
+        0.5368,
+        0.6196,
+        -0.6222,
+        0.8889,
+        0.1605,
+        -0.7788,
+        0.3174,
+        -0.0872,
+    };
+
+    std::vector<float> expected_output = {-0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562,
+                                          -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873,
+                                          -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562,
+                                          -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873};
+
+    std::vector<float> expected_present_key = {2.118801,
+                                               -0.2640816,
+                                               -0.5926066,
+                                               -0.19455537,
+                                               0.9903903,
+                                               2.954185,
+                                               -0.35343042,
+                                               -0.07457897,
+                                               -0.25603274,
+                                               -0.03627284,
+                                               0.56591415,
+                                               0.02181074,
+                                               -0.1586003,
+                                               0.96567893,
+                                               -0.8591481,
+                                               0.85514885};
+
+    std::vector<float> expected_present_value = {-0.2188,
+                                                 -2.4351,
+                                                 -0.0729,
+                                                 -0.034,
+                                                 0.9625,
+                                                 0.3492,
+                                                 -0.9215,
+                                                 -0.0562,
+                                                 -0.6227,
+                                                 -0.4637,
+                                                 1.9218,
+                                                 -0.4025,
+                                                 0.1239,
+                                                 1.1648,
+                                                 0.9234,
+                                                 1.3873};
+
+    auto test_case = ov::test::TestCase(model, s_device);
+    test_case.add_input(query);
+    test_case.add_input(past_key);
+    test_case.add_input(past_value);
+    test_case.add_input(seqlens_k);
+    test_case.add_input(total_sequence_length);
+    test_case.add_input(cos_cache);
+    test_case.add_input(sin_cache);
+    test_case.add_expected_output(Shape{1, 1, 32}, expected_output);
+    test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_key);
+    test_case.add_expected_output(Shape{1, 1, 1, 16}, expected_present_value);
+    test_case.run_with_tolerance_as_fp();
+}
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) {
+    const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary.onnx");
+
+    std::vector<float> query = {
+        -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500,
+        0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551,
+        -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459,
+        -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544,
+        0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215,
+        -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873,
+    };
+    std::vector<float> past_key = {
+        -0.6136,
+        0.0316,
+        -0.4927,
+        0.2484,
+        0.4397,
+        0.1124,
+        0.6408,
+        0.4412,
+        -0.1023,
+        0.7924,
+        -0.2897,
+        0.0525,
+        0.5229,
+        2.3022,
+        -1.4689,
+        -1.5867,
+    };
+    std::vector<float> past_value = {
+        -0.5692,
+        0.9200,
+        1.1108,
+        1.2899,
+        -1.4782,
+        2.5672,
+        -0.4731,
+        0.3356,
+        -1.6293,
+        -0.5497,
+        -0.4798,
+        -0.4997,
+        -1.0670,
+        1.1149,
+        -0.1407,
+        0.8058,
+    };
+    std::vector<int32_t> seqlens_k = {1};
+    std::vector<int32_t> total_sequence_length = {2};
+    std::vector<float> cos_cache = {
+        0.8437,
+        -0.7849,
+        -0.7829,
+        0.4581,
+        -0.9870,
+        0.6273,
+        -0.9483,
+        -0.9962,
+        -0.9635,
+        -0.8046,
+        0.4139,
+        0.9863,
+        0.4117,
+        0.9874,
+        -0.9743,
+        0.9494,
+    };
+    std::vector<float> sin_cache = {
+        0.5368,
+        0.6196,
+        -0.6222,
+        0.8889,
+        0.1605,
+        -0.7788,
+        0.3174,
+        -0.0872,
+        0.2677,
+        -0.5938,
+        -0.9103,
+        -0.1650,
+        -0.9113,
+        -0.1583,
+        0.2253,
+        0.3140,
+    };
+
+    std::vector<float> expected_output = {
+        -0.53934956, 0.6341806, 1.0099611, 1.1771176, -1.270278, 2.3782496, -0.511299, 0.30222273,
+        -1.5435482, -0.5423737, -0.27520883, -0.4914196, -0.96554786, 1.1191509, -0.05004983, 0.85533774,
+        -0.49356747, 0.19581467, 0.8553029, 1.0041412, -0.9513843, 2.088453, -0.5698854, 0.25103146,
+        -1.4120293, -0.5311372, 0.03857604, -0.47871974, -0.8099488, 1.1256707, 0.08898184, 0.93131447};
+
+    std::vector<float> expected_present_key = {
+        -0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412,
+        -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867,
+        -1.6519198, 1.1400802, 0.45031136, 0.5877534, -0.65952265, -1.8121169, 0.04630837, 0.5568472,
+        0.20271924, 0.7458131, -0.17379119, 0.3623912, 2.5696063, -0.58594, -0.8126341, -0.7919839};
+
+    std::vector<float> expected_present_value = {-0.5692, 0.92, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356,
+                                                 -1.6293, -0.5497, -0.4798, -0.4997, -1.067, 1.1149, -0.1407, 0.8058,
+                                                 -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562,
+                                                 -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873};
+
+    auto test_case = ov::test::TestCase(model, s_device);
+    test_case.add_input(query);
+    test_case.add_input(past_key);
+    test_case.add_input(past_value);
+    test_case.add_input(seqlens_k);
+    test_case.add_input(total_sequence_length);
+    test_case.add_input(cos_cache);
+    test_case.add_input(sin_cache);
+    test_case.add_expected_output(Shape{1, 1, 32}, expected_output);
+    test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key);
+    test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value);
+    test_case.run_with_tolerance_as_fp();
+}
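In the non-interleaved test above (`rotary_interleaved = 0`), the new token sits at position 1, so the decomposition slices row 1 of the caches and rotates the first and second halves of each 16-float head against each other. The first 16 values of `expected_present_key` are `past_key` verbatim; the last 16 are the K slice of `query` (floats 32..47) rotated this way. A reference sketch under those assumptions (hypothetical helper, not the test's code):

```cpp
#include <cstddef>
#include <vector>

// Non-interleaved (half-split) RoPE for one token at `position`; the
// cos/sin tables are laid out [positions x half] as in the test caches.
std::vector<float> rotary_half_ref(const std::vector<float>& x,
                                   const std::vector<float>& cos_cache,
                                   const std::vector<float>& sin_cache,
                                   std::size_t position) {
    const std::size_t half = x.size() / 2;
    std::vector<float> out(x.size());
    for (std::size_t i = 0; i < half; ++i) {
        const float c = cos_cache[position * half + i];
        const float s = sin_cache[position * half + i];
        out[i] = x[i] * c - x[i + half] * s;
        out[i + half] = x[i] * s + x[i + half] * c;
    }
    return out;
}
```

For example, the first rotated element is 1.6459 * -0.9635 - 0.2469 * 0.2677 ≈ -1.65192, matching `expected_present_key[16]`.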
+
+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary_interleaved) {
+    const auto model = convert_model("com.microsoft/gqa_past_1_input_1_rotary_interleaved.onnx");
+
+    std::vector<float> query = {
+        -1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152, 0.3223, -1.2633, 0.3500,
+        0.3081, 0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959, 0.5667, 0.7935, 0.5988, -1.5551,
+        -0.3414, 1.8530, 0.7502, -0.5855, -0.1734, 0.1835, 1.3894, 1.5863, 0.9463, -0.8437, 1.6459,
+        -1.3602, 0.3446, 0.5199, -2.6133, -1.6965, -0.2282, 0.2800, 0.2469, 0.0769, 0.3380, 0.4544,
+        0.4569, -0.8654, 0.7813, -0.9268, -0.2188, -2.4351, -0.0729, -0.0340, 0.9625, 0.3492, -0.9215,
+        -0.0562, -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873,
+    };
+    std::vector<float> past_key = {
+        -0.6136,
+        0.0316,
+        -0.4927,
+        0.2484,
+        0.4397,
+        0.1124,
+        0.6408,
+        0.4412,
+        -0.1023,
+        0.7924,
+        -0.2897,
+        0.0525,
+        0.5229,
+        2.3022,
+        -1.4689,
+        -1.5867,
+    };
+    std::vector<float> past_value = {
+        -0.5692,
+        0.9200,
+        1.1108,
+        1.2899,
+        -1.4782,
+        2.5672,
+        -0.4731,
+        0.3356,
+        -1.6293,
+        -0.5497,
+        -0.4798,
+        -0.4997,
+        -1.0670,
+        1.1149,
+        -0.1407,
+        0.8058,
+    };
+    std::vector<int32_t> seqlens_k = {1};
+    std::vector<int32_t> total_sequence_length = {2};
+    std::vector<float> cos_cache = {
+        0.8437,
+        -0.7849,
+        -0.7829,
+        0.4581,
+        -0.9870,
+        0.6273,
+        -0.9483,
+        -0.9962,
+        -0.9635,
+        -0.8046,
+        0.4139,
+        0.9863,
+        0.4117,
+        0.9874,
+        -0.9743,
+        0.9494,
+    };
+    std::vector<float> sin_cache = {
+        0.5368,
+        0.6196,
+        -0.6222,
+        0.8889,
+        0.1605,
+        -0.7788,
+        0.3174,
+        -0.0872,
+        0.2677,
+        -0.5938,
+        -0.9103,
+        -0.1650,
+        -0.9113,
+        -0.1583,
+        0.2253,
+        0.3140,
+    };
+
+    std::vector<float> expected_output = {
+        -0.33396345, -1.332403, 0.31613833, 0.40111685, 0.16033238, 1.0781744, -0.7741276, 0.07257013,
+        -0.9535321, -0.491965, 1.1324831, -0.43444604, -0.2675047, 1.1483997, 0.57366973, 1.1961825,
+        -0.24709277, -2.164195, 0.02267693, 0.07289726, 0.7654276, 0.5282906, -0.8852943, -0.02456442,
+        -0.7039771, -0.47064403, 1.7278847, -0.41034833, 0.02774171, 1.1607709, 0.83748007, 1.3403473};
+
+    std::vector<float> expected_present_key = {
+        -0.6136, 0.0316, -0.4927, 0.2484, 0.4397, 0.1124, 0.6408, 0.4412,
+        -0.1023, 0.7924, -0.2897, 0.0525, 0.5229, 2.3022, -1.4689, -1.5867,
+        -1.2216992, 1.7511603, 0.03145146, -0.62293506, -2.625969, 1.6767058, -0.17887366, 0.313817,
+        0.1717277, -0.19334024, 0.4056727, 0.39516917, -0.25018305, 0.9460988, 1.0327814, -0.6345757};
+
+    std::vector<float> expected_present_value = {-0.5692, 0.92, 1.1108, 1.2899, -1.4782, 2.5672, -0.4731, 0.3356,
+                                                 -1.6293, -0.5497, -0.4798, -0.4997, -1.067, 1.1149, -0.1407, 0.8058,
+                                                 -0.2188, -2.4351, -0.0729, -0.034, 0.9625, 0.3492, -0.9215, -0.0562,
+                                                 -0.6227, -0.4637, 1.9218, -0.4025, 0.1239, 1.1648, 0.9234, 1.3873};
+
+    auto test_case = ov::test::TestCase(model, s_device);
+    test_case.add_input(query);
+    test_case.add_input(past_key);
+    test_case.add_input(past_value);
+    test_case.add_input(seqlens_k);
+    test_case.add_input(total_sequence_length);
+    test_case.add_input(cos_cache);
+    test_case.add_input(sin_cache);
+    test_case.add_expected_output(Shape{1, 1, 32}, expected_output);
+    test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key);
+    test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value);
+    test_case.run_with_tolerance_as_fp();
+}