Skip to content

Commit

Permalink
[Snippets][CPU] Moved N_tail processing to the end in BrgemmCopyBKernel
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 24, 2025
1 parent 62e8e08 commit 9a47d46
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -226,21 +226,13 @@ void BrgemmCopyBKernel::generate() {
size_t start_out = 0;
size_t start_comp = 0;

auto add_ptr_increments = [&](size_t current_N) {
for (size_t nb = 0; nb < div_up(N_blk, wei_N_blk); nb++) {
const auto current_N = N_blk - nb * wei_N_blk < wei_N_blk ? wei_N_tail : wei_N_blk;
emit_brgemm_copy_b_kernel_call(current_N, K, start_in, start_out, start_comp);

start_in += is_transpose ? K * current_N * wei_data_size : current_N * wei_data_size;
start_out += current_N * vnni_factor * wei_data_size;
start_comp += is_with_comp ? current_N * sizeof(int32_t) : 0;
};

// OneDNN requires tail handling before main iterations
if (wei_N_tail != 0) {
emit_brgemm_copy_b_kernel_call(wei_N_tail, K, start_in, start_out, start_comp);
add_ptr_increments(wei_N_tail);
}

for (auto nb = wei_N_tail; nb < N_blk; nb += wei_N_blk) {
emit_brgemm_copy_b_kernel_call(wei_N_blk, K, start_in, start_out, start_comp);
add_ptr_increments(wei_N_blk);
}

postamble();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ template <
typename T,
typename = typename std::enable_if<(std::is_same<T, size_t>::value || std::is_same<T, int64_t>::value), bool>::type>
T compute_LDB(T n_block, const ov::element::Type& precision) {
    // Computes the leading dimension (LDB) of the repacked weights buffer for
    // the given N block size and weights precision.
    //
    // OneDNN's copy_b repacking lays B out in inner blocks of
    // compute_inner_n_block(precision) elements, so LDB must be n_block
    // rounded UP to a multiple of that inner block (not merely max'ed with it:
    // e.g. n_block = 1.5 * inner_block must yield 2 * inner_block).
    //
    // NOTE(review): the previous explicit is_dynamic_value<T>(n_block) pass-through
    // was dropped here — presumably rnd_up propagates dynamic values unchanged;
    // verify against ov::snippets::utils::rnd_up before relying on this with
    // dynamic shapes.
    return ov::snippets::utils::rnd_up(n_block, static_cast<T>(compute_inner_n_block(precision)));
}
/**
* @brief Retrieves the expression pointer for the brgemm_copy_b expression corresponding to the given BrgemmCPU
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void RepackedWeightsBufferExpression::init_allocation_size(

const auto& precision = get_node()->get_input_element_type(0);
// Repacking buffer shape is set in accordance to OneDNN requirements
const size_t N_dim = std::max(n_blk, compute_inner_n_block(precision));
const size_t N_dim = snippets::utils::rnd_up(n_blk, compute_inner_n_block(precision));
if (!in_layout.empty() && in_layout.back() != in_layout.size() - 1) {
// In case of transpose, K dimension must be rounded-up to number of elems in vector register
// For the details, please see 'transpose16x8' and 'fixup16x16' implementations and usage in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ VectorDims BrgemmExternalRepackingAdjuster::get_blk_shape(const VectorDims& plan
const auto K = *++planar_shape.rbegin();
const auto N = *planar_shape.rbegin();
const auto new_K = snippets::utils::div_up(K, vnni_factor);
const auto new_N = std::max(N, brgemm_utils::repacking::compute_inner_n_block(prc));
const auto new_N = snippets::utils::rnd_up(N, brgemm_utils::repacking::compute_inner_n_block(prc));
VectorDims blk_shape(planar_shape.begin(), planar_shape.end() - brgemm_kernel_rank);
blk_shape.insert(blk_shape.end(), {new_K, new_N, vnni_factor});
return blk_shape;
Expand Down

0 comments on commit 9a47d46

Please sign in to comment.