[ARM CPU] hgemm optimized for gqa #23107

Draft · wants to merge 5 commits into main
3 changes: 3 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -95,6 +95,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
@@ -394,6 +395,7 @@ else()
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -406,6 +408,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -198,6 +198,10 @@ class GQAAttentionBase {
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/,
output, static_cast<int>(present_buffer_sequence_length), nullptr);
} else if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size,
q, static_cast<int>(head_size), k, static_cast<int>(head_size), output,
static_cast<int>(present_buffer_sequence_length), alpha, 0.0f /*beta*/, nullptr);
} else {
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
auto q_k_fp32 = allocator->Alloc(bytes);
102 changes: 101 additions & 1 deletion onnxruntime/core/mlas/inc/mlas.h
@@ -1458,7 +1458,107 @@
T* output
);

/**
* @brief Supplies matrix data information to the half precision gemm functions
*/
struct MLAS_HGEMM_DATA_PARAMS {
const MLAS_FP16* A = nullptr; /**< Supplies the address of matrix A */
size_t lda = 0; /**< Supplies the first dimension of matrix A. */
const MLAS_FP16* B = nullptr; /**< Supplies the address of matrix B */
size_t ldb = 0; /**< Supplies the first dimension of matrix B. */
MLAS_FP16* C = nullptr; /**< Supplies the address of matrix C */
size_t ldc = 0; /**< Supplies the first dimension of matrix C. */
MLAS_FP16 alpha = MLAS_FP16(1.0f); /**< Supplies the scalar alpha multiplier (see GEMM definition) */

Check failure on line 1471 in onnxruntime/core/mlas/inc/mlas.h (GitHub Actions / Vcpkg): field has incomplete type 'MLAS_FP16' (aka 'onnxruntime::MLFloat16')
MLAS_FP16 beta = MLAS_FP16(0.0f); /**< Supplies the scalar beta multiplier (see GEMM definition) */

Check failure on line 1472 in onnxruntime/core/mlas/inc/mlas.h (GitHub Actions / Vcpkg): field has incomplete type 'MLAS_FP16' (aka 'onnxruntime::MLFloat16')
};

/**
* @brief Check whether current CPU supports half precision gemm.
*/
bool
MLASCALL
MlasHGemmSupported(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB
);

/**
* @brief Batched half precision matrix/matrix multiply operation (HGEMM)
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
* @param Data Supplies an array of matrix data parameters
* @param BatchSize Supplies the number of multiplications in this batch
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
*/
void
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_HGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
);
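
For orientation, here is a minimal sketch of driving the batched entry point directly. The buffer names, batch layout, and thread_pool below are illustrative assumptions, not part of this PR; each MLAS_HGEMM_DATA_PARAMS entry simply points at one A/B/C triple:

#include <vector>

// Sketch: BatchSize independent fp16 products C_i = A_i * B_i^T.
// Assumes A holds BatchSize tiles of M x K, B holds BatchSize tiles of N x K
// (row-major), and C holds BatchSize tiles of M x N, matching
// TransA = CblasNoTrans, TransB = CblasTrans.
void RunHGemmBatchSketch(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C,
                         size_t BatchSize, size_t M, size_t N, size_t K,
                         MLAS_THREADPOOL* thread_pool) {
    std::vector<MLAS_HGEMM_DATA_PARAMS> data(BatchSize);
    for (size_t i = 0; i < BatchSize; ++i) {
        data[i].A = A + i * M * K;
        data[i].lda = K;                  // K elements per row of A
        data[i].B = B + i * N * K;
        data[i].ldb = K;                  // B stored as N x K, transposed by the kernel
        data[i].C = C + i * M * N;
        data[i].ldc = N;
        data[i].alpha = MLAS_FP16(1.0f);  // defaults shown explicitly for clarity
        data[i].beta = MLAS_FP16(0.0f);
    }
    MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), BatchSize, thread_pool);
}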

/**
* @brief Half precision matrix/matrix multiply operation (HGEMM)
* C = alpha * op(A) * op(B) + beta * C
*
* @param TransA Supplies the transpose operation for matrix A. Currently only CblasNoTrans is supported.
* @param TransB Supplies the transpose operation for matrix B. Currently only CblasTrans is supported.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
* @param C Supplies the address of matrix C
* @param ldc Supplies the first dimension of matrix C.
* @param alpha Supplies the scalar alpha multiplier (see GEMM definition)
* @param beta Supplies the scalar beta multiplier (see GEMM definition)
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support
* should be used.
*/
inline
void
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_FP16* A,
size_t lda,
const MLAS_FP16* B,
size_t ldb,
MLAS_FP16* C,
size_t ldc,
MLAS_FP16 alpha,

Check failure on line 1545 in onnxruntime/core/mlas/inc/mlas.h (GitHub Actions / Vcpkg): variable has incomplete type 'MLAS_FP16' (aka 'onnxruntime::MLFloat16')
MLAS_FP16 beta,

Check failure on line 1546 in onnxruntime/core/mlas/inc/mlas.h (GitHub Actions / Vcpkg): variable has incomplete type 'MLAS_FP16' (aka 'onnxruntime::MLFloat16')
MLAS_THREADPOOL* ThreadPool
) {
MLAS_HGEMM_DATA_PARAMS Data;
Data.alpha = alpha;
Data.A = A;
Data.lda = lda;
Data.B = B;
Data.ldb = ldb;
Data.beta = beta;
Data.C = C;
Data.ldc = ldc;
MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
}
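
A usage sketch for this convenience wrapper, mirroring the capability-gated dispatch added to gqa_attention_base.h. Q, K, scores, the dimension names, and the scale parameter are placeholders, and the fp32 fallback is only indicated by a comment:

// Sketch: scores = scale * Q * K^T for one attention head, fp16 end to end
// when the CPU supports it. Q is seq_len x head_size, K is
// total_seq_len x head_size (row-major), scores is seq_len x total_seq_len.
inline void QKScoresFp16Sketch(const MLAS_FP16* Q, const MLAS_FP16* K, MLAS_FP16* scores,
                               size_t seq_len, size_t total_seq_len, size_t head_size,
                               float scale, MLAS_THREADPOOL* thread_pool) {
    if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
        MlasGemm(CblasNoTrans, CblasTrans, seq_len, total_seq_len, head_size,
                 Q, head_size, K, head_size, scores, total_seq_len,
                 MLAS_FP16(scale), MLAS_FP16(0.0f), thread_pool);
    } else {
        // Fall back to the fp32 path (cast inputs to float, run the float GEMM,
        // cast the result back), as gqa_attention_base.h does.
    }
}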

/**
* @brief Whether current CPU supports FP16 acceleration.
*/
bool MLASCALL
100 changes: 100 additions & 0 deletions onnxruntime/core/mlas/lib/fp16_common.h
@@ -16,6 +16,7 @@ Module Name:

#pragma once

#include <arm_neon.h>
#include "mlas_float16.h"
#include "mlasi.h"

@@ -349,4 +350,103 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT
return vbsl_f16(select, ones, zeros);
}

MLAS_FORCEINLINE
void
Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3,
MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7)
{
// |v00|v01|v02|v03|v04|v05|v06|v07|
// |v10|v11|v12|v13|v14|v15|v16|v17|
// |v20|v21|v22|v23|v24|v25|v26|v27|
// |v30|v31|v32|v33|v34|v35|v36|v37|
// |v40|v41|v42|v43|v44|v45|v46|v47|
// |v50|v51|v52|v53|v54|v55|v56|v57|
// |v60|v61|v62|v63|v64|v65|v66|v67|
// |v70|v71|v72|v73|v74|v75|v76|v77|
float16x8x2_t t01 = vtrnq_f16(v0, v1);
float16x8x2_t t23 = vtrnq_f16(v2, v3);
float16x8x2_t t45 = vtrnq_f16(v4, v5);
float16x8x2_t t67 = vtrnq_f16(v6, v7);
// |v00|v10|v02|v12|v04|v14|v06|v16|
// |v01|v11|v03|v13|v05|v15|v07|v17|
// |v20|v30|v22|v32|v24|v34|v26|v36|
// |v21|v31|v23|v33|v25|v35|v27|v37|
// |v40|v50|v42|v52|v44|v54|v46|v56|
// |v41|v51|v43|v53|v45|v55|v47|v57|
// |v60|v70|v62|v72|v64|v74|v66|v76|
// |v61|v71|v63|v73|v65|v75|v67|v77|
float32x4x2_t t02 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]));
float32x4x2_t t13 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]));
float32x4x2_t t46 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[0]), vreinterpretq_f32_f16(t67.val[0]));
float32x4x2_t t57 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[1]), vreinterpretq_f32_f16(t67.val[1]));
// |v00|v10|v20|v30|v04|v14|v24|v34|
// |v01|v11|v21|v31|v05|v15|v25|v35|
// |v02|v12|v22|v32|v06|v16|v26|v36|
// |v03|v13|v23|v33|v07|v17|v27|v37|
// |v40|v50|v60|v70|v44|v54|v64|v74|
// |v41|v51|v61|v71|v45|v55|v65|v75|
// |v42|v52|v62|v72|v46|v56|v66|v76|
// |v43|v53|v63|v73|v47|v57|v67|v77|
v0 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0])));
v4 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0])));
v2 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1])));
v6 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1])));
v1 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0])));
v5 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0])));
v3 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1])));
v7 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1])));
// |v00|v10|v20|v30|v40|v50|v60|v70|
// |v01|v11|v21|v31|v41|v51|v61|v71|
// |v02|v12|v22|v32|v42|v52|v62|v72|
// |v03|v13|v23|v33|v43|v53|v63|v73|
// |v04|v14|v24|v34|v44|v54|v64|v74|
// |v05|v15|v25|v35|v45|v55|v65|v75|
// |v06|v16|v26|v36|v46|v56|v66|v76|
// |v07|v17|v27|v37|v47|v57|v67|v77|
}

MLAS_FORCEINLINE
void
Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3)
{
// |v00|v01|v02|v03|v04|v05|v06|v07|
// |v10|v11|v12|v13|v14|v15|v16|v17|
// |v20|v21|v22|v23|v24|v25|v26|v27|
// |v30|v31|v32|v33|v34|v35|v36|v37|
// =>
// |v00|v10|v20|v30|v04|v14|v24|v34|
// |v01|v11|v21|v31|v05|v15|v25|v35|
// |v02|v12|v22|v32|v06|v16|v26|v36|
// |v03|v13|v23|v33|v07|v17|v27|v37|
float16x8x2_t t01 = vtrnq_f16(v0, v1);
float16x8x2_t t23 = vtrnq_f16(v2, v3);

v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])));
v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])));
v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])));
v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])));
}

MLAS_FORCEINLINE
void
Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3)
{
// |v00|v01|v02|v03|
// |v10|v11|v12|v13|
// |v20|v21|v22|v23|
// |v30|v31|v32|v33|
// =>
// |v00|v10|v20|v30|
// |v01|v11|v21|v31|
// |v02|v12|v22|v32|
// |v03|v13|v23|v33|
float16x4x2_t t01 = vtrn_f16(v0, v1);
float16x4x2_t t23 = vtrn_f16(v2, v3);

v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0])));
v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0])));
v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
}

#endif // fp16 vector intrinsic supported
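
The lane diagrams above specify the expected layouts exactly. As a sanity check, here is a small standalone sketch (not part of the PR) that compares Transpose4x4 against plain index swapping; it assumes MLAS_FLOAT16X4 aliases float16x4_t, as the NEON intrinsics above imply, and a compiler targeting armv8.2-a+fp16:

#include <arm_neon.h>
#include <cassert>

// Verifies out[r][c] == in[c][r] after the in-register 4x4 transpose.
static void CheckTranspose4x4()
{
    float16_t in[4][4], out[4][4];
    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            in[r][c] = static_cast<float16_t>(r * 4 + c);
        }
    }

    MLAS_FLOAT16X4 v0 = vld1_f16(in[0]);
    MLAS_FLOAT16X4 v1 = vld1_f16(in[1]);
    MLAS_FLOAT16X4 v2 = vld1_f16(in[2]);
    MLAS_FLOAT16X4 v3 = vld1_f16(in[3]);

    Transpose4x4(v0, v1, v2, v3);

    vst1_f16(out[0], v0);
    vst1_f16(out[1], v1);
    vst1_f16(out[2], v2);
    vst1_f16(out[3], v3);

    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            assert(static_cast<float>(out[r][c]) == static_cast<float>(in[c][r]));
        }
    }
}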