
Commit df8b6bd

langc23 authored and pytorchmergebot committed
support fp16 shgemm under openblas (pytorch#169042)
# Purpose

This PR adds support for fp16 shgemm under OpenBLAS. We tested it with vLLM on the platform below; with this patch, vLLM shows faster inference speed under fp16.

**Platform info:**

    Architecture:          riscv64
    Byte Order:            Little Endian
    CPU(s):                64
    On-line CPU(s) list:   0-63
    Vendor ID:             0x5b7
    BIOS Vendor ID:        SOPHGO
    Model name:            -
    BIOS Model name:       SG2044 Not Set CPU @ 2.6GHz
    BIOS CPU family:       513
    CPU family:            0x80000000090c0d00
    Model:                 0x2047000
    Thread(s) per core:    1
    Core(s) per socket:    64
    Socket(s):             1
    Frequency boost:       disabled
    CPU(s) scaling MHz:    100%
    CPU max MHz:           2600.0000
    CPU min MHz:           1000.0000
    Caches (sum of all):
      L1d:                 4 MiB (64 instances)
      L1i:                 4 MiB (64 instances)
      L2:                  32 MiB (16 instances)
      L3:                  64 MiB (1 instance)
    Vulnerabilities:
      Gather data sampling:   Not affected
      Itlb multihit:          Not affected
      L1tf:                   Not affected
      Mds:                    Not affected
      Meltdown:               Not affected
      Mmio stale data:        Not affected
      Reg file data sampling: Not affected
      Retbleed:               Not affected
    ISA: rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintntl_zihintpause_zihpm_zawrs_zfa_zfh_zfhmin_zca_zcb_zcd_zba_zbb_zbc_zbs_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_sscofpmf_sstc_svinval_svnapot_svpbmt

**Branch**

    openblas: develop
    torch:    develop
    vllm:     main

# Test Plan

- Base: without this PR
- PyTorch with OpenBLAS FP16 GEMM: with this PR applied

**Base**

    export VLLM_CPU_OMP_THREADS_BIND=0-63
    export VLLM_CPU_KVCACHE_SPACE=60
    vllm bench latency \
        --model /home/models/Qwen2.5-7B-Instruct \
        --tensor-parallel-size 1 \
        --dtype float16 \
        --input-len 16 \
        --output-len 16 \
        --enforce-eager \
        --max-model-len 8192 \
        --max-num-batched-tokens 8192 \
        --batch-size 1 \
        --n 1 \
        --num-iters-warmup 5 \
        --num-iters 8 \
        --seed 42 \
        --output-json ./latency_results_fp16_latency_base.json

**PyTorch with OpenBLAS FP16 GEMM**

    export VLLM_CPU_OMP_THREADS_BIND=0-63
    export VLLM_CPU_KVCACHE_SPACE=60
    vllm bench latency \
        --model /home/models/Qwen2.5-7B-Instruct \
        --tensor-parallel-size 1 \
        --dtype float16 \
        --input-len 16 \
        --output-len 16 \
        --enforce-eager \
        --max-model-len 8192 \
        --max-num-batched-tokens 8192 \
        --batch-size 1 \
        --n 1 \
        --num-iters-warmup 5 \
        --num-iters 8 \
        --seed 42 \
        --output-json ./latency_results_fp16_latency_with_openblas_support.json

# Result

Average latency (avg_latency in the JSON below) drops from about 62.5 to about 32.4, roughly a 1.9x speedup.

**Base**

    {
      "avg_latency": 62.53946338250171,
      "latencies": [
        58.46783778001554,
        58.230652199999895,
        58.335780619992875,
        59.77051957999356,
        58.587668860011036,
        59.31567866000114,
        58.460076240007766,
        89.14749311999185
      ],
      "percentiles": {
        "10": 58.30424209399498,
        "25": 58.42900233500404,
        "50": 58.52775332001329,
        "75": 59.429388889999245,
        "90": 68.58361164199304,
        "99": 87.09110497219196
      }
    }

**PyTorch with OpenBLAS FP16 GEMM**

    {
      "avg_latency": 32.42863222499727,
      "latencies": [
        30.742418120033108,
        33.67000828002347,
        29.747197599965148,
        32.11275753995869,
        34.566938299976755,
        30.849812360014766,
        34.46360486000776,
        33.27632073999848
      ],
      "percentiles": {
        "10": 30.44385196401272,
        "25": 30.82296380001935,
        "50": 32.69453913997859,
        "75": 33.86840742501954,
        "90": 34.49460489199846,
        "99": 34.55970495917892
      }
    }

Pull Request resolved: pytorch#169042
Approved by: https://github.com/aditew01, https://github.com/albanD
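
As a rough illustration of the primitive this change targets (not part of the PR's test plan; the sizes and iteration counts below are arbitrary assumptions), a plain fp16 matmul on CPU goes through `at::native::cpublas::gemm` and, on a build where `BLAS_HAS_SHGEMM` was detected, can take the new OpenBLAS shgemm path:

```python
# Hypothetical micro-benchmark of an fp16 CPU matmul; shapes are illustrative.
import time
import torch

torch.manual_seed(0)
a = torch.randn(1024, 1024, dtype=torch.float16)  # CPU tensors by default
b = torch.randn(1024, 1024, dtype=torch.float16)

for _ in range(3):          # warm-up iterations
    torch.mm(a, b)

iters = 10
start = time.perf_counter()
for _ in range(iters):
    torch.mm(a, b)
elapsed = time.perf_counter() - start
print(f"fp16 CPU matmul: {elapsed / iters * 1e3:.2f} ms/iter")
```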
1 parent 8121f2c commit df8b6bd

File tree

3 files changed, +63 -0 lines changed


aten/src/ATen/native/CPUBlas.cpp

Lines changed: 51 additions & 0 deletions
@@ -28,6 +28,14 @@ extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k,
                         float *beta,
                         float *c, int *ldc);
 #endif // BLAS_HAS_SBGEMM
+#ifdef BLAS_HAS_SHGEMM
+extern "C" void shgemm_(char *transa, char *transb, int *m, int *n, int *k,
+                        float *alpha,
+                        const at::Half *a, int *lda,
+                        const at::Half *b, int *ldb,
+                        float *beta,
+                        float *c, int *ldc);
+#endif // BLAS_HAS_SHGEMM
 extern "C" void cswap_(int *n, const void *x, int *incx, void *y, int *incy);
 extern "C" void dcopy_(int *n, const double *x, int *incx, double *y, int *incy);
 extern "C" void scopy_(int *n, const float *x, int *incx, float *y, int *incy);

@@ -413,6 +421,34 @@
       mkldnn_fp16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
     return;
   }
+#endif
+#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SHGEMM)
+  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
+    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
+    char transa_ = to_blas(transa), transb_ = to_blas(transb);
+    float alpha_ = alpha, beta_ = beta;
+    int c_size = n_ * m_;
+    // C matrix in OpenBLAS shgemm are of type "float" so we have to convert, copy and copy back.
+    std::vector<float> float_v(c_size, 0.0f);
+    for (const auto j : c10::irange(n)) {
+      for (const auto i : c10::irange(m)) {
+        float_v[j * m_ + i] = c10::convert<float>(c[j * ldc_ + i]);
+      }
+    }
+    shgemm_(&transa_, &transb_,
+            &m_, &n_, &k_,
+            &alpha_,
+            a, &lda_,
+            b, &ldb_,
+            &beta_,
+            float_v.data(), &m_);
+    for (const auto j : c10::irange(n)) {
+      for (const auto i : c10::irange(m)) {
+        c[j * ldc_ + i] = c10::convert<at::Half>(float_v[j * m_ + i]);
+      }
+    }
+    return;
+  }
 #endif
   gemm_stub(
       at::kCPU, at::kHalf,

@@ -471,6 +507,21 @@
     const float beta,
     float *c, int64_t ldc) {
   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
+#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SHGEMM)
+  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
+    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
+    char transa_ = to_blas(transa), transb_ = to_blas(transb);
+    float alpha_ = alpha, beta_ = beta;
+    shgemm_(&transa_, &transb_,
+            &m_, &n_, &k_,
+            &alpha_,
+            a, &lda_,
+            b, &ldb_,
+            &beta_,
+            c, &ldc_);
+    return;
+  }
+#endif
 #ifdef MKL_HAS_SHGEMM
   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
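
The at::Half overload above stages the C matrix through a float buffer because OpenBLAS shgemm consumes fp16 A and B but a float C with float alpha/beta. A quick sanity check of that behavior from Python, sketched below under the assumption that the build routes fp16 CPU matmul through this path (shapes and tolerances are arbitrary), compares the fp16 result against an fp32 reference:

```python
# Illustrative numerics check, not taken from the PR.
import torch

torch.manual_seed(0)
a = torch.randn(256, 64, dtype=torch.float16)
b = torch.randn(64, 128, dtype=torch.float16)

out_fp16 = torch.mm(a, b)                               # fp16 CPU GEMM path
ref = torch.mm(a.float(), b.float()).to(torch.float16)  # fp32 reference, cast back

max_err = (out_fp16.float() - ref.float()).abs().max().item()
print("allclose:", torch.allclose(out_fp16, ref, rtol=1e-2, atol=1e-2))
print(f"max abs difference: {max_err:.4f}")
```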

cmake/Modules/FindBLAS.cmake

Lines changed: 11 additions & 0 deletions
@@ -346,3 +346,14 @@ IF(BLAS_LIBRARIES)
     add_compile_options(-DBLAS_HAS_SBGEMM)
   ENDIF(BLAS_HAS_SBGEMM)
 ENDIF(BLAS_LIBRARIES)
+
+# Blas has fp16 (half precision) support?
+IF(BLAS_LIBRARIES)
+  INCLUDE(CheckFunctionExists)
+  SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
+  check_function_exists("shgemm_" BLAS_HAS_SHGEMM)
+  set(CMAKE_REQUIRED_LIBRARIES)
+  IF(BLAS_HAS_SHGEMM)
+    add_compile_options(-DBLAS_HAS_SHGEMM)
+  ENDIF(BLAS_HAS_SHGEMM)
+ENDIF(BLAS_LIBRARIES)
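
The CMake probe above asks whether the selected BLAS exports `shgemm_`. A rough runtime analogue, handy when checking why `BLAS_HAS_SHGEMM` did or did not get set, is to probe an installed OpenBLAS for the same symbol via ctypes; the library name below is an assumption and varies by system:

```python
# Hypothetical symbol probe mirroring check_function_exists("shgemm_" ...).
import ctypes
import ctypes.util

libname = ctypes.util.find_library("openblas") or "libopenblas.so"  # assumed name
try:
    blas = ctypes.CDLL(libname)
except OSError as err:
    print(f"could not load {libname}: {err}")
else:
    # Attribute access on a CDLL performs a dlsym lookup, like the CMake check.
    has_shgemm = hasattr(blas, "shgemm_")
    print(f"{libname}: shgemm_ {'found' if has_shgemm else 'not found'}")
```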

cmake/Summary.cmake

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ function(caffe2_print_configuration_summary)
   if(${USE_BLAS})
     message(STATUS "    BLAS                  : ${BLAS_INFO}")
     message(STATUS "    BLAS_HAS_SBGEMM       : ${BLAS_HAS_SBGEMM}")
+    message(STATUS "    BLAS_HAS_SHGEMM       : ${BLAS_HAS_SHGEMM}")
   endif()
   message(STATUS "  USE_LAPACK            : ${USE_LAPACK}")
   if(${USE_LAPACK})
