From 17fb201433bf6097b70972801a612f8e3defdd9e Mon Sep 17 00:00:00 2001
From: Zhang Yi
Date: Sat, 16 Dec 2023 07:38:04 +0800
Subject: [PATCH] [CPU] SDPA supports multi-query and different input layout
 (#21513)

---
 .../kernels/scaled_attn/mha_single_token.cpp  |  43 +--
 src/plugins/intel_cpu/src/nodes/memory.cpp    |   7 +-
 .../intel_cpu/src/nodes/scaled_attn.cpp       |  52 ++-
 src/plugins/intel_cpu/src/nodes/scaled_attn.h |   1 +
 .../cpu_opset/common/op/sdpa.cpp              |  36 ++-
 .../cpu_opset/common/op/sdpa.hpp              |  12 +-
 .../pass/stateful_transpose_sdpa_fusion.cpp   | 176 +++++++++++
 .../pass/stateful_transpose_sdpa_fusion.hpp   |  18 ++
 .../transformation_pipeline.cpp               |   2 +
 .../src/concat_multiple_query_sdp.cpp         | 297 ++++++++++++++++++
 .../src/concat_transpose_sdp_transpose.cpp    | 296 +++++++++++++++++
 11 files changed, 898 insertions(+), 42 deletions(-)
 create mode 100644 src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.cpp
 create mode 100644 src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.hpp
 create mode 100644 src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp
 create mode 100644 src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_transpose_sdp_transpose.cpp

diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
index 8ac1aca6c1467e..7fa5abab0aefe2 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp
@@ -139,7 +139,11 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
     auto q_len = query.size(2);
     auto S = query.size(3);
     auto kv_len = present_key.size(2);
-
+    auto h_group_num = present_key.size(1);
+    size_t h_each_group_len = 1;
+    if (h_group_num != H) {
+        h_each_group_len = H / h_group_num;
+    }
     if (d_scale == 0.0f)
         d_scale = 1.0f / sqrt(S);
 
@@ -149,20 +153,21 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
 
     bool is_abcd = present_key.stride(1) >= present_key.stride(2);
     size_t dim0 = is_abcd ? B : kv_len;
-    size_t dim1 = is_abcd ? H : B;
-    size_t dim2 = is_abcd ? kv_len : H;
+    size_t dim1 = is_abcd ? h_group_num : B;
+    size_t dim2 = is_abcd ? kv_len : h_group_num;
 
     parallel_for3d(dim0, dim1, dim2, [&](size_t d0, size_t d1, size_t d2) {
         size_t b = is_abcd ? d0 : d1;
-        size_t h = is_abcd ? d1 : d2;
+        size_t h_group = is_abcd ? d1 : d2;
         size_t pk = is_abcd ? d2 : d0;
 
         // which batch item should be used at position pk?
        auto b_kv = beams ? beams.at<int32_t>({b, pk}) : b;
         for (size_t pq = 0; pq < q_len; pq++) {
-            buf_attn_w.at<float>({b, h, pq, pk}) = dot_product(&query.at<T>({b, h, pq, 0}),
-                                                               &present_key.at<T2>({b_kv, h, pk, 0}, true),
-                                                               S);
+            for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
+                buf_attn_w.at<float>({b, h, pq, pk}) =
+                    dot_product(&query.at<T>({b, h, pq, 0}), &present_key.at<T2>({b_kv, h_group, pk, 0}, true), S);
+            }
         }
     });
 
@@ -190,29 +195,31 @@ void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
     // buf_attn_w {B, H, q_len, kv_len}
     parallel_nt_static(nthr, [&](const size_t ithr, const size_t nthr) {
         size_t start{0}, end{0};
-        splitter(B * H * kv_len, nthr, ithr, start, end);
+        splitter(B * h_group_num * kv_len, nthr, ithr, start, end);
 
         memset(&buf_attn_score.at<float>({ithr, 0, 0, 0, 0}), 0, buf_attn_score.stride(0) * sizeof(float));
 
-        size_t b, h, pv;
+        size_t b, h_group, pv;
         if (start < end) {
             if (is_abcd)
-                parallel_it_init(start, b, B, h, H, pv, kv_len);
+                parallel_it_init(start, b, B, h_group, h_group_num, pv, kv_len);
             else
-                parallel_it_init(start, pv, kv_len, b, B, h, H);
+                parallel_it_init(start, pv, kv_len, b, B, h_group, h_group_num);
             for (size_t iwork = start; iwork < end; ++iwork) {
                 auto b_kv = beams ? beams.at<int32_t>({b, pv}) : b;
-                auto* v = &present_value.at<T2>({b_kv, h, pv, 0}, true);
+                auto* v = &present_value.at<T2>({b_kv, h_group, pv, 0}, true);
                 for (size_t pq = 0; pq < q_len; pq++) {
-                    attn_acc_value(&buf_attn_score.at<float>({ithr, b, pq, h, 0}),
-                                   buf_attn_w.at<float>({b, h, pq, pv}),
-                                   v,
-                                   S);
+                    for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) {
+                        attn_acc_value(&buf_attn_score.at<float>({ithr, b, pq, h, 0}),
+                                       buf_attn_w.at<float>({b, h, pq, pv}),
+                                       v,
+                                       S);
+                    }
                 }
                 if (is_abcd)
-                    parallel_it_step(b, B, h, H, pv, kv_len);
+                    parallel_it_step(b, B, h_group, h_group_num, pv, kv_len);
                 else
-                    parallel_it_step(pv, kv_len, b, B, h, H);
+                    parallel_it_step(pv, kv_len, b, B, h_group, h_group_num);
             }
         }
     });
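The kernel hunks above are the multi-query (grouped-query) core: present_key/present_value carry h_group_num KV heads, and each KV head serves h_each_group_len = H / h_group_num query heads. A minimal standalone sketch of that head-to-group mapping (illustrative values, not part of the patch):

#include <cstddef>
#include <cstdio>

// With H = 8 query heads over h_group_num = 2 KV heads, query heads 0..3
// share KV head 0 and query heads 4..7 share KV head 1, matching the
// h_group loops added in mha_single_token_kernel above.
int main() {
    const size_t H = 8, h_group_num = 2;
    const size_t h_each_group_len = H / h_group_num;  // 4
    for (size_t h_group = 0; h_group < h_group_num; h_group++)
        for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++)
            std::printf("query head %zu -> kv head %zu\n", h, h_group);
    return 0;
}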
diff --git a/src/plugins/intel_cpu/src/nodes/memory.cpp b/src/plugins/intel_cpu/src/nodes/memory.cpp
index f915703f071e7b..bce278e7c5d430 100644
--- a/src/plugins/intel_cpu/src/nodes/memory.cpp
+++ b/src/plugins/intel_cpu/src/nodes/memory.cpp
@@ -548,6 +548,7 @@ void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {
     // Since this is a very specialized implementation, lets mimic SDPA precision and set cabd layout
     precision = SDPA->getOriginalInputPrecisionAtPort(childPort);
 
+    // Just use a placeholder here; the actual layout is obtained at initOptimalPrimitiveDescriptor
     ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});
 
     PortConfig outPortConfig;
@@ -573,7 +574,6 @@ void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
         "failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");
 
     const auto& childConfig = childPd->getConfig();
-    auto childPrecision = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc()->getPrecision();
 
     auto selectedPd = getSelectedPrimitiveDescriptor();
     OPENVINO_ASSERT(selectedPd,
@@ -582,8 +582,9 @@ void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
         " failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");
 
     auto config = selectedPd->getConfig();
-    auto memDesc = config.outConfs.front().getMemDesc();
-    auto newMemDesc = memDesc->cloneWithNewPrecision(childPrecision);
+    // The physical layout varies across models, e.g. [LBHS] for chatglm, [BHLS] for Llama;
+    // SDPA knows the details, so trust the layout config provided by SDPA
+    auto newMemDesc = childConfig.inConfs.back().getMemDesc();
     config.outConfs.front().setMemDesc(newMemDesc);
     //bypass any checks, we enforce the child descriptor precision
     selectedPd->setConfig(config);
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
index c2d1ef17143337..fad65cbbd7ee37 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
@@ -512,10 +512,16 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
         v_input.assert_dims({B, 0, L1, S}, true);
         auto past_k_idx = inputs.size() - 2;
         auto past_k_mem = inputs[past_k_idx + 0];
-        L0 = past_k_mem->getStaticDims()[2];
+        const auto& permute_axes = config.config.permute_axes;
+        L0 = permute_axes.empty() ? past_k_mem->getStaticDims()[2] : past_k_mem->getStaticDims()[permute_axes[2]];
         // [B, H, L0, S]
         past_k_output.reset(outputs[1]);
         past_v_output.reset(outputs[2]);
+        if (!permute_axes.empty()) {
+            // [L, B, H, S] -> [B, H, L, S]
+            past_k_output = past_k_output.permute(permute_axes);
+            past_v_output = past_v_output.permute(permute_axes);
+        }
         attn_memcpy(k_input, v_input, past_k_output.slice(2, L0, L0 + L1), past_v_output.slice(2, L0, L0 + L1));
         if (!config.is_concat_inplaced) {
             PlainTensor past_k_input, past_v_input;
@@ -560,12 +566,18 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
         }
 
         // q: [B, H, L1, S]
+        const auto& permute_axes = config.config.permute_axes;
+
+        PlainTensor present_key, present_value;
+        if (!permute_axes.empty()) {
+            q_input = q_input.permute(permute_axes);
+            k_input = k_input.permute(permute_axes);
+            v_input = v_input.permute(permute_axes);
+        }
         B = q_input.size(0);
         H = q_input.size(1);
         L1 = q_input.size(2);
         S = q_input.size(-1);
-
-        PlainTensor present_key, present_value;
         concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value);
 
         ov::intel_cpu::PlainTensor output_emb(outputs[0]);
@@ -634,9 +646,20 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr
+        ArbitraryOrderDescCreator layoutDescCreator({2, 0, 1, 3});
+        const auto& permute_axes = m_config.config.permute_axes;
+        if (!permute_axes.empty()) {
+            // e.g. [L,B,H,S] -> permute[1,2,0,3] -> [B,H,L,S]
+            // The actual index of B is permute[0], H is permute[1], L is permute[2], S is permute[3]
+            layoutDescCreator = ArbitraryOrderDescCreator({static_cast<size_t>(permute_axes[2]),
+                                                           static_cast<size_t>(permute_axes[0]),
+                                                           static_cast<size_t>(permute_axes[1]),
+                                                           static_cast<size_t>(permute_axes[3])});
+        }
+        config.inConfs[orginSDPInputNumber + 0].setMemDesc(layoutDescCreator.createSharedDesc(
             kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
-        config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc(
+        config.inConfs[orginSDPInputNumber + 1].setMemDesc(layoutDescCreator.createSharedDesc(
             kvCachePrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));
-        config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc(
+        config.outConfs[1].setMemDesc(layoutDescCreator.createSharedDesc(
             kvCachePrecision, getOutputShapeAtPort(1)));
         config.outConfs[1].inPlace(orginSDPInputNumber + 0);
-        config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc(
+        config.outConfs[2].setMemDesc(layoutDescCreator.createSharedDesc(
             kvCachePrecision, getOutputShapeAtPort(2)));
         config.outConfs[2].inPlace(orginSDPInputNumber + 1);
     }
@@ -712,7 +743,6 @@ void ScaledDotProductAttention::createPrimitive() {
         m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0;
     }
 
-    auto rtPrecision = getOriginalInputPrecisionAtPort(0);
     if (rtPrecision == ov::element::bf16) {
         m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h
index 78bc9d4231478f..7ce7d9f09ef8ff 100644
--- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h
+++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h
@@ -54,6 +54,7 @@ class ScaledDotProductAttention : public Node {
     Config m_config;
     std::shared_ptr<Executor> m_executor;
     template <KernelTypes KType, typename T> struct AttentionExecutor;
+    ov::element::Type rtPrecision;
 };
 
 }  // namespace node
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp
index 4dc5ba799dd4eb..f581d981797513 100644
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.cpp
@@ -29,18 +29,43 @@ void ov::intel_cpu::ScaledDotProductAttentionWithKVCache::validate_and_infer_typ
     // [B, H, L0, S]
     auto past_kv_ps = get_input_partial_shape(input_num - 1);
 
+    auto output_logits = q_ps;
     NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == false);
     NODE_VALIDATION_CHECK(this, q_ps.size() >= 3);
+    // permute_axes from original to [B, H, L, S]
+    const auto& permute_axes = this->m_config.permute_axes;
     if (past_kv_ps.rank().is_static()) {
+        const size_t length_index = permute_axes.empty() ? q_ps.size() - 2 : permute_axes[permute_axes.size() - 2];
+        const size_t head_num_index = permute_axes.empty() ? q_ps.size() - 3 : permute_axes[permute_axes.size() - 3];
         NODE_VALIDATION_CHECK(this, q_ps.size() == past_kv_ps.size());
         for (size_t i = 0; i < q_ps.size(); i++) {
-            if (i == q_ps.size() - 2)
+            if (i == head_num_index) {
+                if (q_ps[i].is_static() && past_kv_ps[i].is_static()) {
+                    NODE_VALIDATION_CHECK(this,
+                                          q_ps[i].get_length() % past_kv_ps[i].get_length() == 0,
+                                          "shape not compatible at index ",
+                                          i);
+                }
+            } else if (i == length_index) {
                 continue;
-            NODE_VALIDATION_CHECK(this, q_ps[i].compatible(past_kv_ps[i]));
+            } else {
+                NODE_VALIDATION_CHECK(this,
+                                      q_ps[i].compatible(past_kv_ps[i]),
+                                      "shape not compatible at index ",
+                                      i);
+            }
         }
-        past_kv_ps[q_ps.size() - 2] += q_ps[q_ps.size() - 2];
+        past_kv_ps[length_index] += q_ps[length_index];
     }
-    set_output_type(0, get_input_element_type(0), q_ps);
+    if (!permute_axes.empty()) {
+        if (q_ps.rank().is_static()) {
+            // q_ps needs permute to BHLS
+            for (size_t i = 0; i < q_ps.size(); i++) {
+                output_logits[i] = q_ps[permute_axes[i]];
+            }
+        }
+    }
+    set_output_type(0, get_input_element_type(0), output_logits);
     set_output_type(1, get_input_element_type(input_num - 1), past_kv_ps);
     set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps);
 }
@@ -52,6 +77,7 @@ bool ov::intel_cpu::ScaledDotProductAttentionWithKVCache::visit_attributes(ov::A
     visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn);
     visitor.on_attribute("is_causal", m_config.is_causal);
     visitor.on_attribute("fuse_concat", m_config.fuse_concat);
+    visitor.on_attribute("permute_axes", m_config.permute_axes);
     visitor.finish_structure();
     return true;
-}
+}
\ No newline at end of file
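The validate_and_infer_types() change above derives the output shape by applying permute_axes to the query shape: output_logits[i] = q_ps[permute_axes[i]]. A minimal standalone sketch of that index mapping (illustrative values, not part of the patch):

#include <cstddef>
#include <cstdio>
#include <vector>

// A ChatGLM-style [L,B,H,S] query with permute_axes = {1, 2, 0, 3}
// infers a [B,H,L,S] output, mirroring the loop in sdpa.cpp above.
int main() {
    const std::vector<size_t> q_ps = {10, 1, 8, 64};        // [L, B, H, S]
    const std::vector<size_t> permute_axes = {1, 2, 0, 3};  // maps to [B, H, L, S]
    std::vector<size_t> output_logits(q_ps.size());
    for (size_t i = 0; i < q_ps.size(); i++)
        output_logits[i] = q_ps[permute_axes[i]];
    for (auto d : output_logits)
        std::printf("%zu ", d);  // prints: 1 8 10 64
    std::printf("\n");
    return 0;
}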
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp
index 94406caeab016e..753de527dc73f3 100644
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdpa.hpp
@@ -21,11 +21,13 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op {
     ScaledDotProductAttentionWithKVCache() = default;
 
     struct Config {
-        bool output_BLHxS = false;     // true implies that output is [B,L,H*S]
+        bool output_BLHxS = false;         // true implies that output is [B,L,H*S]
 
-        bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask
-        bool is_causal = false;        // apply causal mask internally
-        bool fuse_concat = false;      // fuse (concat->sdp) ==> sdp
+        bool fuse_causal_attn = false;     // fuse causal mask and attn mask into attn_mask
+        bool is_causal = false;            // apply causal mask internally
+        bool fuse_concat = false;          // fuse (concat->sdp) ==> sdp
+        std::vector<size_t> permute_axes;  // non-empty means the inputs are transposed; the permuted layout is [B,H,L,S]
+                                           // e.g. [L,B,H,S] -> permute[1, 2, 0, 3] -> [B,H,L,S]
     };
 
     ScaledDotProductAttentionWithKVCache(const OutputVector& args, const Config& cfg);
@@ -47,4 +49,4 @@ class ScaledDotProductAttentionWithKVCache : public ov::op::Op {
 };
 
 }   // namespace intel_cpu
-}   // namespace ov
+}   // namespace ov
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.cpp
new file mode 100644
index 00000000000000..e06bec71494095
--- /dev/null
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.cpp
@@ -0,0 +1,176 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "stateful_transpose_sdpa_fusion.hpp"
+
+#include <cstdint>
+#include <limits>
+#include <openvino/core/rt_info.hpp>
+#include <openvino/opsets/opset1.hpp>
+#include <openvino/opsets/opset13.hpp>
+#include <openvino/opsets/opset6.hpp>
+#include <openvino/opsets/opset8.hpp>
+#include <openvino/pass/pattern/op/or.hpp>
+#include <openvino/pass/pattern/op/wrap_type.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "itt.hpp"
+#include "ov_ops/type_relaxed.hpp"
+#include "transformations/cpu_opset/common/op/sdpa.hpp"
+
+namespace ov {
+namespace intel_cpu {
+
+StatefulTransposeSDPAFusion::StatefulTransposeSDPAFusion() {
+    MATCHER_SCOPE(StatefulTransposeSDPAFusion);
+    using namespace ov::pass::pattern;
+
+    auto past_k = wrap_type<opset6::ReadValue>();
+    auto past_v = wrap_type<opset6::ReadValue>();
+    auto convert_past_k = wrap_type<opset1::Convert>({past_k});
+    auto convert_past_v = wrap_type<opset1::Convert>({past_v});
+    auto concat_input_k = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_k, convert_past_k});
+    auto concat_input_v = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_v, convert_past_v});
+    auto concat_k = wrap_type<opset6::Concat>({concat_input_k, any_input()});
+    auto concat_v = wrap_type<opset6::Concat>({concat_input_v, any_input()});
+
+    // multi-query branch
+    auto reshape_k = wrap_type<opset6::Reshape>({concat_k, any_input()});
+    auto reshape_v = wrap_type<opset6::Reshape>({concat_v, any_input()});
+    auto constant_k = wrap_type<opset6::Constant>();
+    auto constant_v = wrap_type<opset6::Constant>();
+    auto multiply_k = wrap_type<opset6::Multiply>({reshape_k, constant_k});
+    auto multiply_v = wrap_type<opset6::Multiply>({reshape_v, constant_v});
+    auto reshape1_k = wrap_type<opset6::Reshape>({multiply_k, any_input()});
+    auto reshape1_v = wrap_type<opset6::Reshape>({multiply_v, any_input()});
+
+    auto transpose_k_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape1_k, concat_k});
+    auto transpose_v_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape1_v, concat_v});
+    auto order_k = wrap_type<opset6::Constant>();
+    auto order_v = wrap_type<opset6::Constant>();
+    auto transpose_k = wrap_type<opset6::Transpose>({transpose_k_input, order_k});
+    auto transpose_v = wrap_type<opset6::Transpose>({transpose_v_input, order_v});
+
+    auto order_q = wrap_type<opset6::Constant>();
+    auto q_input = any_input();
+    auto transpose_q = wrap_type<opset6::Transpose>({q_input, order_q});
+    auto sdp0 = wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v});
+    auto sdp1 = wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v, any_input()});
+    auto sdp2 =
+        wrap_type<opset13::ScaledDotProductAttention>({transpose_q, transpose_k, transpose_v, any_input(), any_input()});
+    auto sdp = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{sdp0, sdp1, sdp2});
+
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
+        const auto& pattern_map = m.get_pattern_value_map();
+        auto root = m.get_match_root();
+        auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
+            auto present_to = out.get_target_inputs();
+            if (present_to.size() != 2)
+                return;
+            for (auto& to : present_to) {
+                auto to_node = to.get_node();
+                if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
+                    auto cvt_targets = convert->get_output_target_inputs(0);
+                    if (cvt_targets.size() == 1) {
+                        to_node = cvt_targets.begin()->get_node();
+                        cvt = convert;
+                    }
+                }
+                assign = dynamic_cast<opset6::Assign*>(to_node);
+                if (assign)
+                    return;
+            }
+        };
+
+        std::shared_ptr<opset1::Convert> read_cvt_k_node, read_cvt_v_node;
+        const auto sdp_node = ov::as_type_ptr<opset13::ScaledDotProductAttention>(root);
+        const auto past_k_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_k).get_node_shared_ptr());
+        const auto past_v_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_v).get_node_shared_ptr());
+        const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
+        const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
+        if (pattern_map.count(convert_past_k)) {
+            read_cvt_k_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_k).get_node_shared_ptr());
+            read_cvt_v_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_v).get_node_shared_ptr());
+        }
+
+        // check that the broadcast arg is all ones
+        auto check_bcst = [&](const std::shared_ptr<ov::Node>& ptr) {
+            const auto constant_node = ov::as_type_ptr<opset6::Constant>(ptr);
+            const auto& bcst_arg = constant_node->cast_vector<int>();
+            return std::all_of(bcst_arg.begin(), bcst_arg.end(), [](int i) {
+                return i == 1.0;
+            });
+        };
+
+        if (pattern_map.count(constant_k)) {
+            if (!check_bcst(pattern_map.at(constant_k).get_node_shared_ptr()))
+                return false;
+        }
+
+        if (pattern_map.count(constant_v)) {
+            if (!check_bcst(pattern_map.at(constant_v).get_node_shared_ptr()))
+                return false;
+        }
+
+        opset6::Assign* assign_k_node = nullptr, *assign_v_node = nullptr;
+        opset1::Convert* assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
+        find_assign(concat_k_node, assign_k_node, assign_cvt_k_node);
+        if (!assign_k_node)
+            return false;
+        if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
+            return false;
+
+        find_assign(concat_v_node, assign_v_node, assign_cvt_v_node);
+        if (!assign_v_node)
+            return false;
+        if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
+            return false;
+
+        auto args = sdp_node->input_values();
+        args[0] = pattern_map.at(q_input).get_node_shared_ptr()->output(0);
+        args[1] = concat_k_node->input_value(1);
+        args[2] = concat_v_node->input_value(1);
+        args.push_back(read_cvt_k_node ? read_cvt_k_node->output(0) : past_k_node->output(0));
+        args.push_back(read_cvt_v_node ? read_cvt_v_node->output(0) : past_v_node->output(0));
+        ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config;
+
+        const auto order_q_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_q).get_node_shared_ptr());
+        const auto order_k_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_k).get_node_shared_ptr());
+        const auto order_v_node = ov::as_type_ptr<opset6::Constant>(pattern_map.at(order_v).get_node_shared_ptr());
+
+        const auto& permute_q = order_q_node->cast_vector<int32_t>();
+        const auto& permute_k = order_k_node->cast_vector<int32_t>();
+        const auto& permute_v = order_v_node->cast_vector<int32_t>();
+        if (permute_q != permute_k || permute_q != permute_v) {
+            return false;
+        }
+
+        config.is_causal = sdp_node->get_causal();
+        config.fuse_concat = true;
+
+        config.permute_axes.resize(permute_q.size());
+        for (size_t i = 0; i < permute_q.size(); i++) {
+            config.permute_axes[i] = static_cast<size_t>(permute_q[i]);
+        }
+        auto& old_node = sdp_node;
+        auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionWithKVCache>(args, config);
+        new_node->set_friendly_name(old_node->get_friendly_name());
+        ov::replace_node(old_node, {new_node->output(0)});
+        if (assign_cvt_k_node)
+            assign_cvt_k_node->set_arguments({new_node->output(1)});
+        else
+            assign_k_node->set_arguments({new_node->output(1)});
+
+        if (assign_cvt_v_node)
+            assign_cvt_v_node->set_arguments({new_node->output(2)});
+        else
+            assign_v_node->set_arguments({new_node->output(2)});
+
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(sdp, matcher_name);
+    this->register_matcher(m, callback);
+}
+
+}  // namespace intel_cpu
+}  // namespace ov
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.hpp
new file mode 100644
index 00000000000000..94c60c36d2b8cf
--- /dev/null
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.hpp
@@ -0,0 +1,18 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <openvino/pass/graph_rewrite.hpp>
+
+namespace ov {
+namespace intel_cpu {
+
+class StatefulTransposeSDPAFusion : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("StatefulTransposeSDPAFusion", "0");
+    StatefulTransposeSDPAFusion();
+};
+
+}  // namespace intel_cpu
+}  // namespace ov
\ No newline at end of file
diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index cf961d7978c5d7..774b16e32afbe5 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -114,6 +114,7 @@
 #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
 #include "transformations/cpu_opset/common/pass/rope_fusion.hpp"
 #include "transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp"
+#include "transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.hpp"
 
 // Snippets
 #include "snippets/pass/tokenization.hpp"
@@ -661,6 +662,7 @@ void Transformations::PostLpt() {
 
     CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion);
     CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulSDPAFusion);
+    CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulTransposeSDPAFusion);
 
     postLPTPassManager.run_passes(model);
 }
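The new pass is registered in the post-LPT pipeline above like any other MatcherPass; it can also be exercised on a model directly. A minimal usage sketch, assuming it is built inside the CPU plugin where the header path resolves (not part of the patch):

#include <memory>

#include <openvino/core/model.hpp>
#include <openvino/pass/manager.hpp>

#include "transformations/cpu_opset/common/pass/stateful_transpose_sdpa_fusion.hpp"

// Run the fusion on a model, mirroring how transformation_pipeline.cpp
// and the functional tests drive passes through ov::pass::Manager.
void run_fusion(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::intel_cpu::StatefulTransposeSDPAFusion>();
    manager.run_passes(model);
}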
diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp
new file mode 100644
index 00000000000000..3e4198aca479bf
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_multiple_query_sdp.cpp
@@ -0,0 +1,297 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <openvino/opsets/opset13.hpp>
+#include <transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp>
+
+#include "ov_models/builders.hpp"
+#include "ov_models/utils/ov_helpers.hpp"
+#include "shared_test_classes/base/layer_test_utils.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "test_utils/cpu_test_utils.hpp"
+#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
+
+using namespace ov::test;
+using namespace ngraph;
+using namespace CPUTestUtils;
+using namespace InferenceEngine;
+
+namespace SubgraphTestsDefinitions {
+
+using InputShapeAndTransposeOrder = std::pair<std::vector<InputShape>, std::vector<size_t>>;
+using ConcatMultiQuerySDPParams = std::tuple<ElementType, InputShapeAndTransposeOrder, bool>;
+// Subgraph:
+/*                          Parameter
+ *                              |
+ *  Parameter    ReadValue      |      ReadValue    Parameter
+ *      \           /           |          \           /
+ *       \         /            |           \         /
+ *         Concat           Transpose          Concat
+ *          / \                 |               / \
+ *         /   \                |              /   \
+ *        /  MultiQuery         |      MultiQuery   \
+ *       /       \              |          /         \
+ *      /     Transpose         |     Transpose       \
+ *     /           \            |        /             \
+ *  Assign     ScaledDotProductAttention             Assign
+ *                              |
+ *                          Transpose
+ *                              |
+ *                           Reshape
+ *                              |
+ *                             Add
+ *                              |
+ *                           Result
+ */
+
+class ConcatMultiQuerySDPTest : public testing::WithParamInterface<ConcatMultiQuerySDPParams>,
+                                virtual public ov::test::SubgraphBaseTest,
+                                public CPUTestsBase {
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<ConcatMultiQuerySDPParams>& obj) {
+        ElementType inType;
+        InputShapeAndTransposeOrder inputShapeAndOrders;
+        bool hasShapeof;
+        std::tie(inType, inputShapeAndOrders, hasShapeof) = obj.param;
+        std::ostringstream result;
+        std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
+        std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
+        result << "IS=";
+        for (const auto& shape : inputShapes) {
+            result << ov::test::utils::partialShape2str({shape.first}) << "_";
+        }
+        result << "TS=";
+        for (const auto& shape : inputShapes) {
+            result << "(";
+            if (!shape.second.empty()) {
+                for (const auto& itr : shape.second) {
+                    result << ov::test::utils::vec2str(itr);
+                }
+            }
+            result << ")_";
+        }
+        result << "Prc=" << inType << "_";
+        result << "HasShapeOf=" << hasShapeof;
+        result << "TransposeOrder=";
+        result << "(";
+        for (const auto& itr : transposeOrder) {
+            result << itr << ",";
+        }
+        result << ")";
+
+        return result.str();
+    }
+
+    void SetUp() override {
+        InputShapeAndTransposeOrder inputShapeAndOrders;
+        bool hasShapeOf;
+        ElementType inType;
+        std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
+        std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
+        std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
+        targetDevice = ov::test::utils::DEVICE_CPU;
+        rel_threshold = 1e-2f;
+        configuration[ov::hint::inference_precision.name()] = ov::element::f32;
+        if (inType == ElementType::bf16) {
+            configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
+            rel_threshold = 0.01f;
+        }
+        init_input_shapes(inputShapes);
+        ov::ParameterVector inputParams;
+        // q, k, v
+        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
+        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
+        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
+        inputParams[0]->set_friendly_name("q");
+        inputParams[1]->set_friendly_name("k");
+        inputParams[2]->set_friendly_name("v");
+        // pastkv init
+        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
+        auto var_k = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastk"});
+        auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
+        pastk->set_friendly_name("pastk_r");
+        auto var_v = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{inputDynamicShapes[2], inType, "pastv"});
+        auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
+        pastv->set_friendly_name("pastv_r");
+        std::shared_ptr<ov::Node> pastk_shapeof, pastv_shapeof;
+        if (hasShapeOf) {
+            pastk_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastk);
+            pastv_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastv);
+        }
+
+        // pre SDPA transpose
+        auto preOrder = op::v0::Constant::create(ov::element::i32, {4}, transposeOrder);
+        auto transposeQ = std::make_shared<ov::op::v1::Transpose>(inputParams[0], preOrder);
+
+        auto concat_axis = transposeOrder[2];
+        auto concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, inputParams[1]}, concat_axis);
+        auto concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, inputParams[2]}, concat_axis);
+
+        auto unsquezeAxis = op::v0::Constant::create(ov::element::i32, {}, {-2});
+        auto unsqueezeK = std::make_shared<ov::op::v0::Unsqueeze>(concatK, unsquezeAxis);
+        auto unsqueezeV = std::make_shared<ov::op::v0::Unsqueeze>(concatV, unsquezeAxis);
+
+        auto targetShape = op::v0::Constant::create(inType, {1, 1, 1, 4, 1}, {1});
+        auto broadcastK = std::make_shared<ov::op::v1::Multiply>(unsqueezeK, targetShape);
+        auto broadcastV = std::make_shared<ov::op::v1::Multiply>(unsqueezeV, targetShape);
+
+        auto target4D = op::v0::Constant::create(ov::element::i32, {4}, {0, 0, 8, 64});
+
+        auto reshapeK = std::make_shared<ov::op::v1::Reshape>(broadcastK, target4D, true);
+        auto reshapeV = std::make_shared<ov::op::v1::Reshape>(broadcastV, target4D, true);
+
+        auto transposeK = std::make_shared<ov::op::v1::Transpose>(reshapeK, preOrder);
+        auto transposeV = std::make_shared<ov::op::v1::Transpose>(reshapeV, preOrder);
+
+        auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(transposeQ, transposeK, transposeV, false);
+        sdp->set_friendly_name("mha");
+
+        // post SDPA transpose + reshape
+        auto get_reshape_order = [](const ov::PartialShape& qkv_shape,
+                                    const std::vector<size_t>& transposeOrder) -> std::vector<size_t> {
+            assert(transposeOrder.size() == 4);
+            auto H = qkv_shape[transposeOrder[1]].get_length();
+            auto S = qkv_shape[transposeOrder[3]].get_length();
+            return std::vector<size_t>{0, 0, static_cast<size_t>(H * S)};
+        };
+        const auto reshapeOrder = get_reshape_order(inputDynamicShapes[0], transposeOrder);
+
+        auto postOrder =
+            ov::op::v0::Constant::create(ov::element::i32, {4}, std::vector<size_t>{0, 2, 1, 3});  // BHLS -> BLHS
+        auto transposeSDP = std::make_shared<ov::op::v1::Transpose>(sdp, postOrder);
+
+        auto constReshape = ov::op::v0::Constant::create(ov::element::i32, {3}, reshapeOrder);
+        auto reshapeSDP = std::make_shared<ov::op::v1::Reshape>(transposeSDP, constReshape, true);  // BLHS -> B,L,HxS
+
+        auto add = std::make_shared<ov::op::v1::Add>(reshapeSDP, op::v0::Constant::create(inType, {1}, {1.0f}));
+        auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
+        auto pastv_assign = std::make_shared<ov::op::v6::Assign>(concatV, var_v);
+        pastk_assign->set_friendly_name("pastk_w");
+        pastv_assign->set_friendly_name("pastv_w");
+
+        ov::OutputVector results{add};
+        if (hasShapeOf) {
+            results.push_back(pastk_shapeof);
+            results.push_back(pastv_shapeof);
+        }
+        SinkVector sinks{pastk_assign, pastv_assign};
+        function = std::make_shared<ov::Model>(results, sinks, inputParams, "ConcatTransposeSDP");
+        targetDevice = ov::test::utils::DEVICE_CPU;
+
+        functionRefs = function->clone();
+        pass::Manager manager;
+        // decompose ScaledDotProductAttention
+        manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
+        manager.run_passes(functionRefs);
+    }
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
+        std::vector<ov::Shape> shapes(4);
+        shapes[0] = targetInputStaticShapes[0];
+        shapes[1] = targetInputStaticShapes[1];
+        shapes[2] = targetInputStaticShapes[1];
+        shapes[3] = targetInputStaticShapes[2];
+        SubgraphBaseTest::generate_inputs(shapes);
+    }
+    template <typename IT, typename T>
+    void strided_iota(IT first, size_t n, T value, T stride) {
+        for (size_t i = 0; i < n; i++) {
+            *first++ = value;
+            value += stride;
+        }
+    }
+    void generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
+        inputs.clear();
+        auto create_input = [this](std::shared_ptr<ov::op::v0::Parameter> param, ov::Shape shape, float val) {
+            if (param->get_element_type() == element::f32) {
+                ov::Tensor t{ov::element::f32, shape};
+                strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
+                inputs.insert({param, t});
+            } else {
+                ov::Tensor t{ov::element::bf16, shape};
+                strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
+                inputs.insert({param, t});
+            }
+        };
+        // q, k, v
+        create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f);
+        create_input(function->get_parameters()[1], targetInputStaticShapes[1], idx + 2.0f);
+        create_input(function->get_parameters()[2], targetInputStaticShapes[1], idx + 3.0f);
+        create_input(function->get_parameters()[3], targetInputStaticShapes[2], idx + 4.0f);
+    }
+    void prepare() {
+        compile_model();
+        inferRequest = compiledModel.create_infer_request();
+        ASSERT_TRUE(inferRequest);
+    }
+    void reset() {
+        for (auto&& state : inferRequest.query_state()) {
+            state.reset();
+        }
+    }
+    std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
+        function = model;
+        prepare();
+        std::vector<ov::Tensor> outputs;
+        int idx = 0;
+        for (auto&& shapes : targetStaticShapes) {
+            generate(idx++, shapes);
+            for (const auto& input : inputs) {
+                inferRequest.set_tensor(input.first, input.second);
+            }
+            inferRequest.infer();
+            auto outputTensor = inferRequest.get_output_tensor(0);
+            ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
+            outputTensor.copy_to(copy);
+            outputs.push_back(copy);
+        }
+        reset();
+
+        return outputs;
+    }
+};
+
+TEST_P(ConcatMultiQuerySDPTest, CompareWithRefs) {
+    auto actualOutputs = run_test(function);
+    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
+    CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
+    if (configuration[ov::hint::inference_precision.name()] == ov::element::bf16) {
+        CheckNumberOfNodesWithType(compiledModel, "Reorder", 5);
+    } else {
+        CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
+    }
+    CheckNumberOfNodesWithType(compiledModel, "Transpose", 1);
+    auto expectedOutputs = run_test(functionRefs);
+    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
+    for (size_t i = 0; i < actualOutputs.size(); i++) {
+        ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
+    }
+}
+
+namespace {
+const std::vector<InputShapeAndTransposeOrder> inputShapeAndReorders = {{
+    {// inputShapes ChatGLM
+     {
+         // L1, B, H, S
+         {{-1, 1, 8, 64}, {{10, 1, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {20, 1, 8, 64}, {1, 1, 8, 64}}},
+         {{-1, 1, 2, 64}, {{10, 1, 2, 64}, {1, 1, 2, 64}, {1, 1, 2, 64}, {20, 1, 2, 64}, {1, 1, 2, 64}}},
+         // L0, B, H, S
+         {{-1, 1, 2, 64}, {{0, 1, 2, 64}, {10, 1, 2, 64}, {11, 1, 2, 64}, {12, 1, 2, 64}, {32, 1, 2, 64}}},
+     },
+     // transposeOrder
+     {1, 2, 0, 3}},
+}};
+// TODO: BF16 test is disabled due to CI machine limitation
+INSTANTIATE_TEST_SUITE_P(smoke_ConcatMultiQuerySDPTest,
+                         ConcatMultiQuerySDPTest,
+                         ::testing::Combine(::testing::Values(ElementType::f32),
+                                            ::testing::ValuesIn(inputShapeAndReorders),
+                                            ::testing::Values(true, false)),
+                         ConcatMultiQuerySDPTest::getTestCaseName);
+} // namespace
+} // namespace SubgraphTestsDefinitions
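Both test files collapse the SDPA output [B,H,L,S] back to [B,L,H*S] via a Transpose with order {0,2,1,3} followed by a special-zero Reshape. A worked instance of the get_reshape_order() arithmetic used above (illustrative values, not part of the patch):

#include <cassert>
#include <cstddef>
#include <vector>

// For a ChatGLM-style query shape [L,B,H,S] = [10,1,8,64] with
// transposeOrder {1,2,0,3}: H = shape[order[1]] = shape[2] = 8 and
// S = shape[order[3]] = shape[3] = 64, so the Reshape pattern is
// {0, 0, H*S} = {0, 0, 512}, turning [B,L,H,S] into [B,L,H*S].
int main() {
    const std::vector<size_t> qkv_shape = {10, 1, 8, 64};  // [L, B, H, S]
    const std::vector<size_t> transposeOrder = {1, 2, 0, 3};
    const size_t H = qkv_shape[transposeOrder[1]];
    const size_t S = qkv_shape[transposeOrder[3]];
    const std::vector<size_t> reshape = {0, 0, H * S};  // special-zero Reshape
    assert(reshape[2] == 512);
    return 0;
}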
diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_transpose_sdp_transpose.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_transpose_sdp_transpose.cpp
new file mode 100644
index 00000000000000..ea6d37e9e6502e
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_transpose_sdp_transpose.cpp
@@ -0,0 +1,296 @@
+// Copyright (C) 2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <openvino/opsets/opset13.hpp>
+#include <transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp>
+
+#include "ov_models/builders.hpp"
+#include "ov_models/utils/ov_helpers.hpp"
+#include "shared_test_classes/base/layer_test_utils.hpp"
+#include "shared_test_classes/base/ov_subgraph.hpp"
+#include "test_utils/cpu_test_utils.hpp"
+#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
+
+using namespace ov::test;
+using namespace ngraph;
+using namespace CPUTestUtils;
+using namespace InferenceEngine;
+
+namespace SubgraphTestsDefinitions {
+
+using InputShapeAndTransposeOrder = std::pair<std::vector<InputShape>, std::vector<size_t>>;
+using ConcatSDPTransposeTestParams = std::tuple<ElementType, InputShapeAndTransposeOrder, bool>;
+// Subgraph:
+/*                          Parameter
+ *                              |
+ *  Parameter    ReadValue      |      ReadValue    Parameter
+ *      \           /           |          \           /
+ *       \         /            |           \         /
+ *         Concat           Transpose          Concat
+ *          / \                 |               / \
+ *         /   \                |              /   \
+ *        /  Transpose          |       Transpose   \
+ *       /       \              |          /         \
+ *  Assign     ScaledDotProductAttention             Assign
+ *                              |
+ *                          Transpose
+ *                              |
+ *                           Reshape
+ *                              |
+ *                             Add
+ *                              |
+ *                           Result
+ */
+
+class ConcatSDPTransposeTest : public testing::WithParamInterface<ConcatSDPTransposeTestParams>,
+                               virtual public ov::test::SubgraphBaseTest,
+                               public CPUTestsBase {
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<ConcatSDPTransposeTestParams>& obj) {
+        ElementType inType;
+        InputShapeAndTransposeOrder inputShapeAndOrders;
+        bool hasShapeof;
+        std::tie(inType, inputShapeAndOrders, hasShapeof) = obj.param;
+        std::ostringstream result;
+        std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
+        std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
+        result << "IS=";
+        for (const auto& shape : inputShapes) {
+            result << ov::test::utils::partialShape2str({shape.first}) << "_";
+        }
+        result << "TS=";
+        for (const auto& shape : inputShapes) {
+            result << "(";
+            if (!shape.second.empty()) {
+                for (const auto& itr : shape.second) {
+                    result << ov::test::utils::vec2str(itr);
+                }
+            }
+            result << ")_";
+        }
+        result << "Prc=" << inType << "_";
+        result << "HasShapeOf=" << hasShapeof;
+        result << "TransposeOrder=";
+        result << "(";
+        for (const auto& itr : transposeOrder) {
+            result << itr << ",";
+        }
+        result << ")";
+
+        return result.str();
+    }
+
+    void SetUp() override {
+        ElementType inType;
+        InputShapeAndTransposeOrder inputShapeAndOrders;
+        bool hasShapeOf;
+        std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
+        std::vector<InputShape>& inputShapes = inputShapeAndOrders.first;
+        std::vector<size_t>& transposeOrder = inputShapeAndOrders.second;
+        targetDevice = ov::test::utils::DEVICE_CPU;
+        rel_threshold = 1e-2f;
+        configuration[ov::hint::inference_precision.name()] = ov::element::f32;
+        if (inType == ElementType::bf16) {
+            configuration[ov::hint::inference_precision.name()] = ov::element::bf16;
+            rel_threshold = 0.01f;
+        }
+        init_input_shapes(inputShapes);
+        ov::ParameterVector inputParams;
+        // q, k, v
+        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
+        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
+        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
+        inputParams[0]->set_friendly_name("q");
+        inputParams[1]->set_friendly_name("k");
+        inputParams[2]->set_friendly_name("v");
+        // pastkv init
+        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
+        auto var_k = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastk"});
+        auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
+        pastk->set_friendly_name("pastk_r");
+        auto var_v = std::make_shared<ov::op::util::Variable>(
+            ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastv"});
+        auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
+        pastv->set_friendly_name("pastv_r");
+        std::shared_ptr<ov::Node> pastk_shapeof, pastv_shapeof;
+        if (hasShapeOf) {
+            pastk_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastk);
+            pastv_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastv);
+        }
+
+        // pre SDPA transpose
+        auto preOrder = ov::op::v0::Constant::create(ov::element::i32, {4}, transposeOrder);
+        auto transposeQ = std::make_shared<ov::op::v1::Transpose>(inputParams[0], preOrder);
+
+        auto concat_axis = transposeOrder[2];
+        auto concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, inputParams[1]}, concat_axis);
+        auto concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, inputParams[2]}, concat_axis);
+        auto transposeK = std::make_shared<ov::op::v1::Transpose>(concatK, preOrder);
+        auto transposeV = std::make_shared<ov::op::v1::Transpose>(concatV, preOrder);
+
+        auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(transposeQ, transposeK, transposeV, false);
+        sdp->set_friendly_name("mha");
+
+        // post SDPA transpose + reshape
+        auto get_reshape_order = [](const ov::PartialShape& qkv_shape,
+                                    const std::vector<size_t>& transposeOrder) -> std::vector<size_t> {
+            assert(transposeOrder.size() == 4);
+            auto H = qkv_shape[transposeOrder[1]].get_length();
+            auto S = qkv_shape[transposeOrder[3]].get_length();
+            return std::vector<size_t>{0, 0, static_cast<size_t>(H * S)};
+        };
+        const auto reshapeOrder = get_reshape_order(inputDynamicShapes[0], transposeOrder);
+
+        auto postOrder =
+            ov::op::v0::Constant::create(ov::element::i32, {4}, std::vector<size_t>{0, 2, 1, 3});  // BHLS -> BLHS
+        auto transposeSDP = std::make_shared<ov::op::v1::Transpose>(sdp, postOrder);
+
+        auto constReshape = ov::op::v0::Constant::create(ov::element::i32, {3}, reshapeOrder);
+        auto reshapeSDP = std::make_shared<ov::op::v1::Reshape>(transposeSDP, constReshape, true);  // BLHS -> B,L,HxS
+
+        auto add = std::make_shared<ov::op::v1::Add>(reshapeSDP, op::v0::Constant::create(inType, {1}, {1.0f}));
+        auto pastk_assign = std::make_shared<ov::op::v6::Assign>(concatK, var_k);
+        auto pastv_assign = std::make_shared<ov::op::v6::Assign>(concatV, var_v);
+        pastk_assign->set_friendly_name("pastk_w");
+        pastv_assign->set_friendly_name("pastv_w");
+
+        ov::OutputVector results{add};
+        if (hasShapeOf) {
+            results.push_back(pastk_shapeof);
+            results.push_back(pastv_shapeof);
+        }
+        SinkVector sinks{pastk_assign, pastv_assign};
+        function = std::make_shared<ov::Model>(results, sinks, inputParams, "ConcatTransposeSDP");
+        targetDevice = ov::test::utils::DEVICE_CPU;
+
+        functionRefs = function->clone();
+        pass::Manager manager;
+        // decompose ScaledDotProductAttention
+        manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
+        manager.run_passes(functionRefs);
+    }
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
+        std::vector<ov::Shape> shapes(4);
+        shapes[0] = targetInputStaticShapes[0];
+        shapes[1] = targetInputStaticShapes[0];
+        shapes[2] = targetInputStaticShapes[0];
+        shapes[3] = targetInputStaticShapes[1];
+        SubgraphBaseTest::generate_inputs(shapes);
+    }
+    template <typename IT, typename T>
+    void strided_iota(IT first, size_t n, T value, T stride) {
+        for (size_t i = 0; i < n; i++) {
+            *first++ = value;
+            value += stride;
+        }
+    }
+    void generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
+        inputs.clear();
+        auto create_input = [this](std::shared_ptr<ov::op::v0::Parameter> param, ov::Shape shape, float val) {
+            if (param->get_element_type() == element::f32) {
+                ov::Tensor t{ov::element::f32, shape};
+                strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
+                inputs.insert({param, t});
+            } else {
+                ov::Tensor t{ov::element::bf16, shape};
+                strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
+                inputs.insert({param, t});
+            }
+        };
+        // q, k, v
+        create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f);
+        create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f);
+        create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f);
+        create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f);
+    }
+    void prepare() {
+        compile_model();
+        inferRequest = compiledModel.create_infer_request();
+        ASSERT_TRUE(inferRequest);
+    }
+    void reset() {
+        for (auto&& state : inferRequest.query_state()) {
+            state.reset();
+        }
+    }
+    std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
+        function = model;
+        prepare();
+        std::vector<ov::Tensor> outputs;
+        int idx = 0;
+        for (auto&& shapes : targetStaticShapes) {
+            generate(idx++, shapes);
+            for (const auto& input : inputs) {
+                inferRequest.set_tensor(input.first, input.second);
+            }
+            inferRequest.infer();
+            auto outputTensor = inferRequest.get_output_tensor(0);
+            ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
+            outputTensor.copy_to(copy);
+            outputs.push_back(copy);
+        }
+        reset();
+
+        return outputs;
+    }
+};
+
+TEST_P(ConcatSDPTransposeTest, CompareWithRefs) {
+    auto actualOutputs = run_test(function);
+    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
+    CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
+    CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
+    CheckNumberOfNodesWithType(compiledModel, "Transpose", 1);
+    auto expectedOutputs = run_test(functionRefs);
+    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
+    for (size_t i = 0; i < actualOutputs.size(); i++) {
+        ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
+    }
+}
+
+namespace {
+const std::vector<InputShapeAndTransposeOrder> inputShapeAndReorders = {
+    {
+        // inputShapes LLama
+        {
+            // B, H, L1, S
+            {{1, 8, -1, 64}, {{1, 8, 10, 64}, {1, 8, 1, 64}, {1, 8, 1, 64}, {1, 8, 20, 64}, {1, 8, 1, 64}}},
+            // B, H, L0, S
+            {{1, 8, -1, 64}, {{1, 8, 0, 64}, {1, 8, 10, 64}, {1, 8, 11, 64}, {1, 8, 12, 64}, {1, 8, 32, 64}}},
+        },
+        // transposeOrder
+        {0, 1, 2, 3}},
+    {// inputShapes QWen
+     {
+         // B, L1, H, S
+         {{1, -1, 8, 64}, {{1, 10, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {1, 20, 8, 64}, {1, 1, 8, 64}}},
+         // B, L0, H, S
+         {{1, -1, 8, 64}, {{1, 0, 8, 64}, {1, 10, 8, 64}, {1, 11, 8, 64}, {1, 12, 8, 64}, {1, 32, 8, 64}}},
+     },
+     // transposeOrder
+     {0, 2, 1, 3}},
+    {// inputShapes ChatGLM
+     {
+         // L1, B, H, S
+         {{-1, 1, 8, 64}, {{10, 1, 8, 64}, {1, 1, 8, 64}, {1, 1, 8, 64}, {20, 1, 8, 64}, {1, 1, 8, 64}}},
+         // L0, B, H, S
+         {{-1, 1, 8, 64}, {{0, 1, 8, 64}, {10, 1, 8, 64}, {11, 1, 8, 64}, {12, 1, 8, 64}, {32, 1, 8, 64}}},
+     },
+     // transposeOrder
+     {1, 2, 0, 3}},
+};
+
+INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTransposeTest,
+                         ConcatSDPTransposeTest,
+                         ::testing::Combine(::testing::Values(ElementType::f32),
+                                            ::testing::ValuesIn(inputShapeAndReorders),
+                                            ::testing::Values(true, false)),
+                         ConcatSDPTransposeTest::getTestCaseName);
+
+} // namespace
+} // namespace SubgraphTestsDefinitions