Skip to content

Commit

Permalink
Fix GQA decomp interleave; Add onnx frontend tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wine99 committed Jan 23, 2025
1 parent 91971ec commit f4770e0
Show file tree
Hide file tree
Showing 7 changed files with 1,412 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
auto cos_cache = node->input_value(7);
auto sin_cache = node->input_value(8);

auto T = Q.get_element_type();
// The total number of tokens (past + current) is `seqlens_k` + 1, where current = Q.shape[2] and past = `seqlens_k` + 1 - current

const auto T = Q.get_element_type();
const auto node_shape = std::make_shared<v3::ShapeOf>(Q);
const auto batch_size = get_dimensions(node_shape, {0});
const auto current_seqlen_size = get_dimensions(node_shape, {1});
Expand Down Expand Up @@ -262,12 +263,20 @@ std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
auto reshaped_input = std::make_shared<v1::Reshape>(input, split_input_shape, false);

auto in_split = make_split(reshaped_input, 2, -1);
auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos),
std::make_shared<v1::Multiply>(in_split[1], sin));
auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin),
std::make_shared<v1::Multiply>(in_split[1], cos));
split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim}, 0);
auto in_split_0 = std::make_shared<v1::Reshape>(in_split[0], split_input_shape, false);
auto in_split_1 = std::make_shared<v1::Reshape>(in_split[1], split_input_shape, false);

auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split_0, cos),
std::make_shared<v1::Multiply>(in_split_1, sin));
auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split_0, sin),
std::make_shared<v1::Multiply>(in_split_1, cos));

split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, one}, 0);
auto res_0_5d = std::make_shared<v1::Reshape>(res_0, split_input_shape, false);
auto res_1_5d = std::make_shared<v1::Reshape>(res_1, split_input_shape, false);

auto concat_ret = std::make_shared<v0::Concat>(ov::NodeVector{res_0, res_1}, -1);
auto concat_ret = std::make_shared<v0::Concat>(ov::NodeVector{res_0_5d, res_1_5d}, -1);
return std::make_shared<v1::Reshape>(concat_ret, input_shape, false);
} else {
auto in_split = make_split(input, 2, -1);
Expand Down
8 changes: 3 additions & 5 deletions src/core/src/op/group_query_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,11 @@ void GroupQueryAttention::validate_and_infer_types() {
}
Dimension output_kv_len;
PartialShape kv_past_shape = get_input_partial_shape(3);
// FIXME: The original GQA spec relies on the same tensor set being used for input and output, but we cannot know that in advance,
// hence we decide based on whether the sequence dimension is static or dynamic
// https://github.com/openvinotoolkit/openvino/pull/27648
if (kv_past_shape[2].is_dynamic()) {
// FIXME: https://github.com/openvinotoolkit/openvino/pull/27648
if (kv_past_shape[2].is_static()) {
output_kv_len = kv_past_shape[2] + sequence_len;
} else {
output_kv_len = kv_past_shape[2];
output_kv_len = ov::Dimension();
}
auto element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
ir_version: 10
graph {
node {
input: "query"
input: ""
input: ""
input: "past_key"
input: "past_value"
input: "seqlens_k"
input: "total_sequence_length"
input: "cos_cache"
input: "sin_cache"
output: "output"
output: "present_key"
output: "present_value"
name: "GroupQueryAttention_0"
op_type: "GroupQueryAttention"
attribute {
name: "do_rotary"
i: 1
type: INT
}
attribute {
name: "kv_num_heads"
i: 1
type: INT
}
attribute {
name: "local_window_size"
i: -1
type: INT
}
attribute {
name: "num_heads"
i: 2
type: INT
}
attribute {
name: "rotary_interleaved"
i: 0
type: INT
}
attribute {
name: "smooth_softmax"
i: 0
type: INT
}
attribute {
name: "softcap"
f: 0
type: FLOAT
}
domain: "com.microsoft"
}
name: "GroupQueryAttention_Graph"
input {
name: "query"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 64
}
}
}
}
}
input {
name: "past_key"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 0
}
dim {
dim_value: 16
}
}
}
}
}
input {
name: "past_value"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 0
}
dim {
dim_value: 16
}
}
}
}
}
input {
name: "seqlens_k"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "total_sequence_length"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "cos_cache"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 8
}
}
}
}
}
input {
name: "sin_cache"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 8
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 32
}
}
}
}
}
output {
name: "present_key"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 16
}
}
}
}
}
output {
name: "present_value"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 16
}
}
}
}
}
}
opset_import {
version: 11
}
opset_import {
domain: "com.microsoft"
version: 1
}
Loading

0 comments on commit f4770e0

Please sign in to comment.