diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp
index 30d95ed6a2bf7a..8ee370619448b9 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp
@@ -226,21 +226,13 @@ void BrgemmCopyBKernel::generate() {
     size_t start_out = 0;
     size_t start_comp = 0;
 
-    auto add_ptr_increments = [&](size_t current_N) {
+    for (size_t nb = 0; nb < div_up(N_blk, wei_N_blk); nb++) {
+        const auto current_N = N_blk - nb * wei_N_blk < wei_N_blk ? wei_N_tail : wei_N_blk;
+        emit_brgemm_copy_b_kernel_call(current_N, K, start_in, start_out, start_comp);
+
         start_in += is_transpose ? K * current_N * wei_data_size : current_N * wei_data_size;
         start_out += current_N * vnni_factor * wei_data_size;
         start_comp += is_with_comp ? current_N * sizeof(int32_t) : 0;
-    };
-
-    // OneDNN requires tail handling before main iterations
-    if (wei_N_tail != 0) {
-        emit_brgemm_copy_b_kernel_call(wei_N_tail, K, start_in, start_out, start_comp);
-        add_ptr_increments(wei_N_tail);
-    }
-
-    for (auto nb = wei_N_tail; nb < N_blk; nb += wei_N_blk) {
-        emit_brgemm_copy_b_kernel_call(wei_N_blk, K, start_in, start_out, start_comp);
-        add_ptr_increments(wei_N_blk);
     }
 
     postamble();
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
index b5f470c1c695ba..1929d4517d05e9 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
@@ -69,9 +69,7 @@ template <
     typename T,
     typename = typename std::enable_if<(std::is_same<T, size_t>::value || std::is_same<T, int64_t>::value), bool>::type>
 T compute_LDB(T n_block, const ov::element::Type& precision) {
-    return snippets::utils::is_dynamic_value(n_block)
-               ? n_block
-               : std::max(n_block, static_cast<T>(compute_inner_n_block(precision)));
+    return ov::snippets::utils::rnd_up(n_block, static_cast<T>(compute_inner_n_block(precision)));
 }
 /**
  * @brief Retrieves the expression pointer for the brgemm_copy_b expression corresponding to the given BrgemmCPU
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp
index dac2ee94741aa1..6912a714f6ebfd 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/expressions/brgemm_copy_b_buffer_expressions.cpp
@@ -45,7 +45,7 @@ void RepackedWeightsBufferExpression::init_allocation_size(
 
     const auto& precision = get_node()->get_input_element_type(0);
     // Repacking buffer shape is set in accordance to OneDNN requirements
-    const size_t N_dim = std::max(n_blk, compute_inner_n_block(precision));
+    const size_t N_dim = snippets::utils::rnd_up(n_blk, compute_inner_n_block(precision));
     if (!in_layout.empty() && in_layout.back() != in_layout.size() - 1) {
         // In case of transpose, K dimension must be rounded-up to number of elems in vector register
         // For the details, please see 'transpose16x8' and 'fixup16x16' implementations and usage in
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
index add7c66d3d7ffc..f8bdeeaedef047 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
@@ -57,7 +57,7 @@ VectorDims BrgemmExternalRepackingAdjuster::get_blk_shape(const VectorDims& plan
     const auto K = *++planar_shape.rbegin();
     const auto N = *planar_shape.rbegin();
     const auto new_K = snippets::utils::div_up(K, vnni_factor);
-    const auto new_N = std::max(N, brgemm_utils::repacking::compute_inner_n_block(prc));
+    const auto new_N = snippets::utils::rnd_up(N, brgemm_utils::repacking::compute_inner_n_block(prc));
     VectorDims blk_shape(planar_shape.begin(), planar_shape.end() - brgemm_kernel_rank);
     blk_shape.insert(blk_shape.end(), {new_K, new_N, vnni_factor});
     return blk_shape;
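Side note on the std::max -> rnd_up change (a minimal sketch, not part of the patch): the repacked-weights buffer is laid out in whole inner N blocks, so when N is larger than the inner block size but not a multiple of it, std::max under-sizes the N dimension while rnd_up rounds it up to the next block boundary. In the sketch below, div_up/rnd_up are local stand-ins for the snippets::utils helpers, and inner_n_block = 16 is an assumed return value of compute_inner_n_block() for the chosen precision.

```cpp
// Self-contained illustration of the allocation-size difference.
#include <algorithm>
#include <cstddef>
#include <iostream>

// Local stand-ins for snippets::utils::div_up / rnd_up, defined here only so the
// sketch compiles on its own.
static size_t div_up(size_t a, size_t b) {
    return (a + b - 1) / b;
}
static size_t rnd_up(size_t a, size_t b) {
    return div_up(a, b) * b;
}

int main() {
    // Assumed inner N block reported by compute_inner_n_block() for the precision.
    const size_t inner_n_block = 16;

    for (size_t N : {8, 16, 40}) {
        const size_t old_dim = std::max(N, inner_n_block);  // N dimension before the patch
        const size_t new_dim = rnd_up(N, inner_n_block);    // N dimension after the patch
        std::cout << "N=" << N << "  std::max=" << old_dim << "  rnd_up=" << new_dim << '\n';
    }
    // N=40 is the interesting case: std::max keeps 40, but the copy_b kernel writes
    // whole 16-wide inner blocks, so 48 columns per repacked K row are needed.
    return 0;
}
```

The same block-granular view shows up in BrgemmCopyBKernel::generate() above: it now walks div_up(N_blk, wei_N_blk) blocks in a single loop and emits the wei_N_tail call as the last iteration, instead of handling the tail before the main iterations.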