From 62e8e0873bf327584791152c6d741518d2886bef Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova
Date: Fri, 24 Jan 2025 09:59:39 +0400
Subject: [PATCH] [Snippets][CPU] Disable MHA tokenization in LLM (#28601)

### Details:
- *The second inference in an LLM is usually a single-token inference. This means that the `M` dimension of the MatMuls in the SDPA pattern will have the value `1` (during model compilation this dimension is dynamic, i.e. unknown). Snippets cannot provide efficient execution for single-token inference, so we decided to disable MHA tokenization by Snippets in the CPU Plugin for LLMs. We consider the presence of a `ScaledDotProductAttentionWithKVCache` or `PagedAttentionExtension` op in the model as a sign that the model is an LLM.*

### Tickets:
- *160634*
- *160978*

### TODO:
- [x] Performance validation on LLMs (the results are in the ticket CVS-160978)
---
 .../transformation_pipeline.cpp               | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index da61917a146db0..880cdd54c42812 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -1031,18 +1031,27 @@ void Transformations::MainSnippets(void) {
     }
     CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);
 
-    // - CPU Plugin Subgraph supports f32, bf16, quantized and fp16 (on avx_512_core_amx_fp16 target) BRGEMM
-    const bool isMHASupported =
-#if defined(OPENVINO_ARCH_ARM64)
-        false;
-#else
+#if defined(OPENVINO_ARCH_X86_64)
+    // Currently, Snippets don't provide efficient execution for single-token inference in the LLM case.
+    // To avoid performance degradations, we disable MHA tokenization into Subgraphs in LLMs.
+    // We consider the presence of `ScaledDotProductAttentionWithKVCache` and `PagedAttentionExtension` ops
+    // in the model as a sign that this model is an LLM.
+    const auto is_LLM = ov::op::util::has_op_with_type<ScaledDotProductAttentionWithKVCache>(model) ||
+                        ov::op::util::has_op_with_type<ov::op::PagedAttentionExtension>(model);
+
+    // CPU Plugin Subgraph supports f32, bf16, quantized and fp16 (on avx_512_core_amx_fp16 target) BRGEMM
+    const auto is_infer_prc_supported_by_MHA =
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
          one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
          one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
          one_of(config.inferencePrecision, ov::element::f16));
+    const bool isMHASupported = !is_LLM && is_infer_prc_supported_by_MHA;
+#else
+    const bool isMHASupported = false;
 #endif
+
     if (!isMHASupported) {
         CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
         CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::ExtractReshapesFromMHA);
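
For reference, below is a minimal standalone sketch (not part of the patch) of the LLM-detection heuristic used above. Instead of the `ov::op::util::has_op_with_type<T>()` helper the patch relies on, it matches ops by type name so that no plugin-internal headers are needed; the helper name `is_llm_like` is illustrative only:

```cpp
#include <memory>
#include <string>

#include "openvino/core/model.hpp"

// Hypothetical helper sketching the heuristic from this patch: walk all ops in
// the model and report whether any of them is one of the KV-cache attention ops
// whose presence is treated as a sign that the model is an LLM.
// The patch itself uses ov::op::util::has_op_with_type<T>() for the same check.
bool is_llm_like(const std::shared_ptr<const ov::Model>& model) {
    for (const auto& op : model->get_ops()) {
        const std::string type_name = op->get_type_info().name;
        if (type_name == "ScaledDotProductAttentionWithKVCache" ||
            type_name == "PagedAttentionExtension") {
            return true;
        }
    }
    return false;
}
```

When this check reports an LLM, `TokenizeMHASnippets` and the related MHA tokenization passes are disabled, so the SDPA pattern is executed by the plugin's regular (non-Snippets) paths, which handle the single-token (`M == 1`) case.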