Use scaled_dot_product_attention to improve performance
sgbihu authored and wine99 committed Jan 26, 2025
1 parent 392d731 commit d5ed312
Showing 1 changed file with 50 additions and 93 deletions.
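The change replaces the hand-built MatMul → Divide → Add(mask) → Softmax → MatMul chain in group_query_attention with a single v13::ScaledDotProductAttention node, letting plugins dispatch to their fused attention kernels. For orientation, here is a minimal standalone sketch of the computation the fused op performs, softmax(Q x K' * scale + mask) x V, in plain C++ with illustrative shapes and data (a numeric reference only, not OpenVINO's implementation):

#include <cmath>
#include <cstdio>
#include <vector>

// q: [Lq x d], k: [Lk x d], v: [Lk x d], mask: [Lq x Lk]; returns [Lq x d].
std::vector<float> sdpa(const std::vector<float>& q, const std::vector<float>& k,
                        const std::vector<float>& v, const std::vector<float>& mask,
                        int Lq, int Lk, int d, float scale) {
    std::vector<float> out(Lq * d, 0.0f);
    std::vector<float> row(Lk);
    for (int i = 0; i < Lq; ++i) {
        // scores = Q x K' * scale + mask, one query row at a time
        float max_s = -1e30f;
        for (int j = 0; j < Lk; ++j) {
            float s = 0.0f;
            for (int c = 0; c < d; ++c) s += q[i * d + c] * k[j * d + c];
            row[j] = s * scale + mask[i * Lk + j];
            if (row[j] > max_s) max_s = row[j];
        }
        // numerically stable softmax over the row
        float denom = 0.0f;
        for (int j = 0; j < Lk; ++j) { row[j] = std::exp(row[j] - max_s); denom += row[j]; }
        // out_i = softmax(row) x V
        for (int j = 0; j < Lk; ++j)
            for (int c = 0; c < d; ++c) out[i * d + c] += (row[j] / denom) * v[j * d + c];
    }
    return out;
}

int main() {
    const int Lq = 2, Lk = 3, d = 2;
    std::vector<float> q = {1, 0, 0, 1}, k = {1, 0, 0, 1, 1, 1}, v = {1, 2, 3, 4, 5, 6};
    std::vector<float> mask(Lq * Lk, 0.0f);  // no masking in this toy run
    auto out = sdpa(q, k, v, mask, Lq, Lk, d, 1.0f / std::sqrt(float(d)));
    for (int i = 0; i < Lq; ++i) std::printf("%f %f\n", out[i * d], out[i * d + 1]);
    return 0;
}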
@@ -38,6 +38,7 @@
 #include "openvino/op/subtract.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
+#include "openvino/op/scaled_dot_product_attention.hpp"
 #include "openvino/op/convert_like.hpp"
 #include "utils/split.hpp"
 
@@ -159,12 +160,6 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
     auto present_k = K;
     auto present_v = V;
 
-    std::shared_ptr<ov::Node> alpha;
-    if (scale == 0.0f) {
-        alpha = std::make_shared<v0::Sqrt>(head_size_node);
-    } else {
-        alpha = v0::Constant::create(ov::element::f32, ov::Shape{}, {1.0f / scale});
-    }
     const size_t kv_num_heads_factor = num_heads / kv_num_heads;
     if (kv_num_heads_factor > 1) {
         const auto kv_shape = std::make_shared<v3::ShapeOf>(K);
@@ -183,47 +178,43 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
         K = std::make_shared<v1::Reshape>(K, extended_kv_shape, false);
         V = std::make_shared<v1::Reshape>(V, extended_kv_shape, false);
     }
-    // compute softmax((Q x K') / sqrt(head_size))
-    std::shared_ptr<ov::Node> softmax_input = std::make_shared<v0::MatMul>(Q, K, false, true);
-    softmax_input = std::make_shared<v1::Divide>(softmax_input, alpha);
 
     // need to apply low-triangle mask to attention score.
-    auto past_seq_len_scalar = std::make_shared<v1::Reshape>(past_sequence_length, one_without_shape, false);
-    auto seqlens_1d_scalar = std::make_shared<v1::Reshape>(seqlens_1d, one_without_shape, false);
+    // two steps, construct the total_sequence x total_sequence triangle, then slice the current length
+    auto seqlens_1d_scalar = std::make_shared<v1::Reshape>(seqlens_1d, one_without_shape, false);  // 12 or 13
     std::shared_ptr<ov::Node> mask_per_line_node =
         std::make_shared<v4::Range>(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}),
                                     seqlens_1d_scalar,
                                     one_without_shape,
-                                    ov::element::i64);
-    mask_per_line_node = std::make_shared<v0::Unsqueeze>(mask_per_line_node, zero);
-    auto minus_inf = v0::Constant::create(element::f32, Shape{}, {-std::numeric_limits<float>::infinity()});
-    auto mask_shape = std::make_shared<v0::Concat>(ov::NodeVector{current_seqlen_size, seqlens_1d}, 0);
-    auto compare_mask = std::make_shared<v3::Broadcast>(mask_per_line_node, mask_shape);
-
-    std::shared_ptr<ov::Node> vertical_range =
-        std::make_shared<v4::Range>(past_seq_len_scalar, seqlens_1d_scalar, one_without_shape, ov::element::i64);
-    vertical_range = std::make_shared<v0::Unsqueeze>(vertical_range, one);
-
-    auto triu = std::make_shared<v1::Greater>(compare_mask, vertical_range);
-    auto typed_zero = std::make_shared<v1::ConvertLike>(zero, softmax_input);
-    auto typed_minus_inf = std::make_shared<v1::ConvertLike>(minus_inf, softmax_input);
-    auto minus_inf_mask = std::make_shared<v3::Broadcast>(typed_minus_inf, mask_shape);
-    auto atten_mask = std::make_shared<v1::Select>(triu, minus_inf_mask, typed_zero);
-
-    std::shared_ptr<ov::Node> softmax_input_added = std::make_shared<v1::Add>(softmax_input, atten_mask);
-    // softmax((Q x K' + mask) / sqrt(head_size))
-    const auto softmax = std::make_shared<v8::Softmax>(softmax_input_added, 3);
-
-    // softmax((Q x K' + mask) / sqrt(head_size)) x V
-    std::shared_ptr<ov::Node> output = std::make_shared<v0::MatMul>(softmax, V);
+                                    ov::element::i64);  // [0,1,2,...,]
+    auto hori_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, zero);  // 1x12 or 1x13
+    auto vert_range = std::make_shared<v0::Unsqueeze>(mask_per_line_node, one);   // 12x1 or 13x1
+    auto triu = std::make_shared<v1::Greater>(hori_range, vert_range);  // 12x12 or 13x13
+    auto typed_zero = v0::Constant::create(ov::element::f32, ov::Shape{}, {0});
+    auto minus_inf = v0::Constant::create(ov::element::f32, ov::Shape{}, {-std::numeric_limits<float>::infinity()});
+    auto atten_mask = std::make_shared<v1::Select>(triu, minus_inf, typed_zero);  // 12x12 or 13x13
+    auto atten_mask_sliced = std::make_shared<v8::Slice>(atten_mask,
+                                                         past_sequence_length,
+                                                         seqlens_1d,
+                                                         one,
+                                                         zero);  // slice to current query seqlen, 12x12 or 1x13
+
+    // compute softmax((Q x K') / sqrt(head_size)) x V
+    std::shared_ptr<ov::Node> qga_output;
+    if (scale != 0.0f) {
+        auto scale_node = v0::Constant::create(ov::element::f32, Shape{}, {scale});
+        qga_output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, scale_node, false);
+    } else {
+        qga_output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, false);
+    }
 
     // transpose the result from (batch_size, num_heads, sequence_length, head_size)
     // to (batch_size, sequence_length, num_heads, head_size)
-    output = std::make_shared<v1::Transpose>(output, perm);
+    auto qga_output_transposed = std::make_shared<v1::Transpose>(qga_output, perm);
     auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1});
     // reshape the result from (batch_size, sequence_length, num_heads, head_size)
     // to (batch_size, sequence_length, num_heads * head_size)
-    output = std::make_shared<v1::Reshape>(output, dim_merge_shape, true);
+    auto output = std::make_shared<v1::Reshape>(qga_output_transposed, dim_merge_shape, true);
 
     return {output, present_k, present_v};
 }
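The hunk above builds the causal mask exactly as its "// two steps" comment says: first a total_sequence x total_sequence triangle with -inf strictly above the diagonal, then a slice of the last rows to match the current query length (13 x 13 for a 13-token prompt, 1 x 13 when one new token attends to 12 cached ones). A standalone sketch of those two steps, using the toy lengths from the inline comments (plain C++, not the graph code itself):

#include <cstdio>
#include <limits>
#include <vector>

int main() {
    // Step 1: build the total_seq x total_seq triangle: -inf where j > i
    // (future positions), 0 elsewhere - the Range/Unsqueeze/Greater/Select
    // nodes in the hunk above.
    const int past_len = 12, total_len = 13;  // e.g. 12 cached tokens + 1 new token
    const float ninf = -std::numeric_limits<float>::infinity();
    std::vector<float> full(total_len * total_len, 0.0f);
    for (int i = 0; i < total_len; ++i)
        for (int j = 0; j < total_len; ++j)
            if (j > i) full[i * total_len + j] = ninf;

    // Step 2: slice rows [past_len, total_len) so the mask matches the
    // current query length - 1 x 13 here, 13 x 13 on the prompt iteration.
    const int cur_len = total_len - past_len;
    std::vector<float> sliced(full.begin() + past_len * total_len, full.end());

    for (int i = 0; i < cur_len; ++i) {
        for (int j = 0; j < total_len; ++j)
            std::printf(sliced[i * total_len + j] == ninf ? " -inf" : "    0");
        std::printf("\n");
    }
    return 0;
}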
@@ -254,75 +245,41 @@ std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
                                           bool interleaved) {
     auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
     auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
+    auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
+
     auto cache_shape = std::make_shared<v3::ShapeOf>(cos_cache);
     auto cache_last_dim = get_dimensions(cache_shape, {-1});
     auto cache_1st_dim = get_dimensions(cache_shape, {0});
     auto slice_cache_dim_shape = seqlen_k;
+
+    // TODO: check the shape
     auto input_shape = std::make_shared<v3::ShapeOf>(input);
+    auto cos = std::make_shared<v8::Slice>(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero);
+    auto sin = std::make_shared<v8::Slice>(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero);
 
-    auto dim_bns = get_dimensions(input_shape, {0, 1, 2});
-    // auto dim_head_size = get_dimensions(input_shape, {3});
-    // half_last_dim is same as cos_cache
-    std::shared_ptr<ov::Node> half_last_dim = cache_last_dim;
-    if (interleaved) {
-        auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});
-
-        auto real_cache_shape = std::make_shared<v0::Concat>(ov::NodeVector{cache_1st_dim, dim_head_size}, 0);
-        auto slice_cache_dim_shape = seqlen_k;
-        auto cache_shape = std::make_shared<v3::ShapeOf>(cos_cache);
-        auto cache_last_dim = get_dimensions(cos_cache, {-1});
-
-        // auto end_lens = std::make_shared<v1::Subtract>(half_last_dim, one);
-        // auto masks = std::make_shared<v12::Pad>(one,
-        //                                         zero,
-        //                                         end_lens,
-        //                                         op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}),
-        //                                         ov::op::PadMode::CONSTANT);
-        auto masks = std::make_shared<v3::Broadcast>(one, half_last_dim);
-        auto input_shape = std::make_shared<v3::ShapeOf>(input);
+
+    auto dim_bns = get_dimensions(input_shape, {0, 1, 2});
+    std::shared_ptr<ov::Node> half_last_dim = cache_last_dim;
+
+    if (interleaved) {
         auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1});
         auto split_input_shape = std::make_shared<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, two}, 0);
         auto reshaped_input = std::make_shared<v1::Reshape>(input, split_input_shape, false);
-        auto first_half = std::make_shared<v8::Slice>(reshaped_input, zero, one, one, negtive_one);
-        auto second_half = std::make_shared<v8::Slice>(reshaped_input, one, two, one, negtive_one);
-
-        auto second_input = std::make_shared<v0::Concat>(ov::NodeVector{second_half, first_half}, -1);
-
-        auto mask_shape = std::make_shared<v0::Concat>(ov::NodeVector{half_last_dim, one}, 0);
-        auto reshaped_mask = std::make_shared<v1::Reshape>(masks, mask_shape, false);
-        auto negtive_mask = std::make_shared<v0::Negative>(reshaped_mask);
-        auto concat_mask = std::make_shared<v0::Concat>(ov::NodeVector{negtive_mask, reshaped_mask}, -1);
-        auto real_mask = std::make_shared<v1::Reshape>(concat_mask, dim_head_size, false);
-        auto mask_f32 = std::make_shared<v0::Convert>(real_mask, ov::element::f32);
-
-        auto real_input0 = std::make_shared<v1::Reshape>(reshaped_input, input_shape, false);
-        auto real_input1 = std::make_shared<v1::Reshape>(second_input, input_shape, false);
-
-        auto new_cache_shape = std::make_shared<v0::Concat>(ov::NodeVector{cache_shape, two}, 0);
-        auto temp_cache_shape = std::make_shared<v0::Concat>(ov::NodeVector{cache_shape, one}, 0);
-        auto cos_cache_reshape = std::make_shared<v1::Reshape>(cos_cache, temp_cache_shape, false);
-        auto sin_cache_reshape = std::make_shared<v1::Reshape>(sin_cache, temp_cache_shape, false);
-        auto cos_cache_broadcasted = std::make_shared<v3::Broadcast>(cos_cache_reshape, new_cache_shape);
-        auto sin_cache_broadcasted = std::make_shared<v3::Broadcast>(sin_cache_reshape, new_cache_shape);
-        auto real_cos_input = std::make_shared<v1::Reshape>(cos_cache_broadcasted, real_cache_shape, false);
-        auto real_sin_input = std::make_shared<v1::Reshape>(sin_cache_broadcasted, real_cache_shape, false);
-        auto sliced_cos_input =
-            std::make_shared<v8::Slice>(real_cos_input, past_seqlen, slice_cache_dim_shape, one, zero);
-        auto sliced_sin_input =
-            std::make_shared<v8::Slice>(real_sin_input, past_seqlen, slice_cache_dim_shape, one, zero);
-        auto add_input0 = std::make_shared<v1::Multiply>(real_input0, sliced_cos_input);
-        auto add_input1 = std::make_shared<v1::Multiply>(real_input1, sliced_sin_input);
-        auto multi_input1 = std::make_shared<v1::Multiply>(add_input1, mask_f32);
-        auto result = std::make_shared<v1::Add>(add_input0, multi_input1);
-        return result;
+
+        auto in_split = ov::op::util::make_split(reshaped_input, 2, -1);
+        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos),
+                                                    std::make_shared<v1::Multiply>(in_split[1], sin));
+        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin),
+                                               std::make_shared<v1::Multiply>(in_split[1], cos));
+
+        auto concat_ret = std::make_shared<v0::Concat>(ov::NodeVector{res_0, res_1}, -1);
+        return std::make_shared<v1::Reshape>(concat_ret, input_shape, false);
     } else {
-        auto cos =
-            std::make_shared<v8::Slice>(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero);
-        auto sin =
-            std::make_shared<v8::Slice>(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero);
         auto in_split = ov::op::util::make_split(input, 2, -1);
-        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos), std::make_shared<v1::Multiply>(in_split[1], sin));
-        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin), std::make_shared<v1::Multiply>(in_split[1], cos));
+        auto res_0 = std::make_shared<v1::Subtract>(std::make_shared<v1::Multiply>(in_split[0], cos),
+                                                    std::make_shared<v1::Multiply>(in_split[1], sin));
+        auto res_1 = std::make_shared<v1::Add>(std::make_shared<v1::Multiply>(in_split[0], sin),
+                                               std::make_shared<v1::Multiply>(in_split[1], cos));
 
         return std::make_shared<v0::Concat>(ov::NodeVector{res_0, res_1}, -1);
     }
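After this change, both branches of rotaryEmbedding apply the standard rotary rotation to the split halves of the input: res_0 = x0*cos - x1*sin and res_1 = x0*sin + x1*cos; the interleaved path merely reshapes even/odd pairs into a trailing dimension before splitting. A standalone numeric sketch of that rotation for one head vector (plain C++ with made-up angles, not the graph code):

#include <cmath>
#include <cstdio>
#include <vector>

// Rotate one head vector x of size 2*half: split into halves (x0, x1), then
//   out0 = x0 * cos - x1 * sin
//   out1 = x0 * sin + x1 * cos
// mirroring the make_split / Multiply / Subtract / Add nodes above.
void rotate_half(const std::vector<float>& x, const std::vector<float>& cs,
                 const std::vector<float>& sn, std::vector<float>& out) {
    const size_t half = x.size() / 2;
    for (size_t i = 0; i < half; ++i) {
        out[i]        = x[i] * cs[i] - x[i + half] * sn[i];
        out[i + half] = x[i] * sn[i] + x[i + half] * cs[i];
    }
}

int main() {
    // Illustrative values: head_size 4, one cos/sin angle per rotated pair.
    std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
    std::vector<float> cs = {std::cos(0.5f), std::cos(0.25f)};
    std::vector<float> sn = {std::sin(0.5f), std::sin(0.25f)};
    std::vector<float> out(4);
    rotate_half(x, cs, sn, out);
    std::printf("%f %f %f %f\n", out[0], out[1], out[2], out[3]);
    return 0;
}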