Skip to content

Commit

Permalink
[FORK][WA]Fall back avx512 gemm to avx2 gemm when __BUILD_GEMM_AVX512…
Browse files Browse the repository at this point in the history
… is false.

[FORK][FEATURE] cpu: remove gemm legacy on avx512.
  • Loading branch information
luweizhou2016 committed Nov 15, 2023
1 parent a9e6db9 commit 8b178e8
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 129 deletions.
4 changes: 2 additions & 2 deletions src/cpu/gemm/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ dnnl_status_t gemm_s8x8s32(const char *transa, const char *transb,
if (*M == 0 || *N == 0 || *K == 0) return dnnl_success;

#if DNNL_X64 && !__BUILD_GEMM_NONE
bool use_jit = mayiuse(avx512_core);
bool use_jit = avx512_gemm_available();
bool use_s8u8 = true
&& utils::everyone_is(0, *ao, *bo) // so far a requirement
&& IMPLICATION(USE_MKL_IGEMM == 0, mayiuse(sse41));
Expand Down Expand Up @@ -299,7 +299,7 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
bfloat16_t *dummy_bo = nullptr;
float *dummy_co = nullptr;

if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
auto status = gemm_driver(transa, transb, dummyOffsetC, M, N, K, alpha,
(const bfloat16_t *)A, lda, dummy_ao, (const bfloat16_t *)B,
ldb, dummy_bo, beta, (float *)C, ldc, dummy_co, false);
Expand Down
19 changes: 17 additions & 2 deletions src/cpu/gemm/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@
#define __BUILD_GEMM_AVX2 __BUILD_GEMM_AVX512 || BUILD_GEMM_AVX2
#define __BUILD_GEMM_SSE41 __BUILD_GEMM_AVX2 || BUILD_GEMM_SSE41
#define __BUILD_GEMM_NONE BUILD_GEMM_KERNELS_NONE

#if __BUILD_GEMM_AVX512
#define avx512_gemm_available() mayiuse(avx512_core)
#define avx512_amx_gemm_available() mayiuse(avx512_core_amx)
#define avx512_bf16_gemm_available() mayiuse(avx512_core_bf16)
#define avx512_vnni_gemm_available() mayiuse(avx512_core_vnni)
#define avx512_bf16_ymm_gemm_available() mayiuse(avx512_core_bf16_ymm)
#else
#define avx512_gemm_available() false
#define avx512_amx_gemm_available() false
#define avx512_bf16_gemm_available() false
#define avx512_vnni_gemm_available() false
#define avx512_bf16_ymm_gemm_available() false
#endif

#else
#define __BUILD_GEMM_AMX 0
#define __BUILD_GEMM_AVX512 0
Expand Down Expand Up @@ -91,9 +106,9 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
#if !defined(USE_MKL_IGEMM) && defined(DNNL_X64)
#define IGEMM_S8U8S32_ISA_STR \
JIT_IMPL_NAME_HELPER(IGEMM_S8U8S32_IMPL_STR ":", \
mayiuse(avx512_core_vnni) \
avx512_vnni_gemm_available() \
? avx512_core_vnni \
: (mayiuse(avx512_core) ? avx512_core : isa_undef), \
: (avx512_gemm_available() ? avx512_core : isa_undef), \
"")
#else
#define IGEMM_S8U8S32_ISA_STR IGEMM_S8U8S32_IMPL_STR
Expand Down
36 changes: 19 additions & 17 deletions src/cpu/x64/gemm/f32/jit_avx2_kernel_sgemm_kern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ namespace impl {
namespace cpu {
namespace x64 {

#define avx512_gemm_available() false

int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const {
while (!(((idx / unroll_n_) < std::max(1, um / nelt_per_vecreg_))
|| ((idx % unroll_n_) < un)))
Expand All @@ -36,7 +38,7 @@ int jit_avx2_kernel_sgemm_kern::next_acc(int idx, int um, int un) const {

void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload(
int um, int un, int k_idx, int n_idx) {
if (!mayiuse(avx512_core)) {
if (!avx512_gemm_available()) {
if ((n_idx == 0) && (k_idx == 0) && (un == unroll_n_) && (um != 16)) {
prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]);
offb_ += 16;
Expand All @@ -46,7 +48,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeBload(

void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA(
int um, int un, int k_idx, int n_idx, int m_idx) {
if (!mayiuse(avx512_core)) {
if (!avx512_gemm_available()) {
if ((um == 16) || (un < unroll_n_)) {
if ((k_idx + m_idx + n_idx) == 0) {
prefetcht0(ptr[BO_ + elt_size_ * (PREFETCHSIZEB_ + offb_)]);
Expand All @@ -63,7 +65,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_beforeFMA(

void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA(
int um, int un, int k_idx, int n_idx, int m_idx) {
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
if ((um < unroll_m_) && (m_idx == 0)) {
if (((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 0) && (n_idx % 6 == 0))
|| ((k_idx % (nb_zmm_a_ / unroll_m_reg_) == 1)
Expand All @@ -87,7 +89,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterFMA(

void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload(
int um, int un, int k_idx, int n_idx) {
if (!mayiuse(avx512_core)) {
if (!avx512_gemm_available()) {
if ((um == unroll_m_) && (un == 2)) {
if (k_idx % 3 == 0) {
if (n_idx == 1) {
Expand All @@ -111,7 +113,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_afterBload(

void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA(
int k_idx, int n_idx, int m_idx) {
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
if (((m_idx + (k_idx % (nb_zmm_a_ / unroll_m_reg_)) * unroll_m_reg_)
== 0)
&& (n_idx == 1)) {
Expand All @@ -126,7 +128,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchB_afterFMA(

void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA(
int um, int un, int k_idx, int n_idx, int m_idx) {
if (!mayiuse(avx512_core)) {
if (!avx512_gemm_available()) {
if ((um == unroll_m_) && (un == unroll_n_)) {
if (((k_idx == 0) && (n_idx % 2 == 1) && (m_idx == 0))
|| ((k_idx == 1) && (n_idx == 2) && (m_idx == 0))
Expand Down Expand Up @@ -160,7 +162,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchA_beforeFMA(

void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload(
int um, int un, int k_idx, int n_idx) {
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
if (um == unroll_m_) {
if (n_idx == std::min(1, un - 1)) {
if (k_idx == unroll_k_ - 1)
Expand All @@ -173,7 +175,7 @@ void jit_avx2_kernel_sgemm_kern::prefetchC_afterBload(
}

void jit_avx2_kernel_sgemm_kern::prefetchC_beforeKloop(int um) {
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
if (um < unroll_m_) {
prefetchw(ptr[CO2_ + elt_size_ * 0]);
prefetchw(ptr[CO2_ + elt_size_ * 8]);
Expand Down Expand Up @@ -228,7 +230,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
mov(C_, ptr[rsp + get_size_of_abi_save_regs() + C_off]);
mov(LDC_, ptr[rsp + get_size_of_abi_save_regs() + LDC_off]);

if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
for (i = zmm_acc_idx_; i < unroll_m_reg_ * unroll_n_ + zmm_acc_idx_;
i++)
vpxorq(Xbyak::Zmm(i), Xbyak::Zmm(i), Xbyak::Zmm(i));
Expand Down Expand Up @@ -267,7 +269,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
add(AA_, A_);
mov(CO1_, C_);

if ((unroll_x == unroll_m_) || (!mayiuse(avx512_core)))
if ((unroll_x == unroll_m_) || (!avx512_gemm_available()))
lea(CO2_, ptr[C_ + LDC_ * 2]);

add(C_, unroll_x * elt_size_);
Expand All @@ -292,12 +294,12 @@ void jit_avx2_kernel_sgemm_kern::generate() {
T_NEAR);
}

if (!mayiuse(avx512_core))
if (!avx512_gemm_available())
prefetcht2(ptr[AA_ - addr_off_ * elt_size_]);

switch (unroll_x) {
case 8:
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
loop<Xbyak::Zmm, Xbyak::Zmm, Xbyak::Address, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vbroadcastf64x4,
Expand All @@ -319,7 +321,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {

break;
case 4:
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
loop<Xbyak::Zmm, Xbyak::Ymm, Xbyak::Address, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vbroadcastf32x4,
Expand All @@ -340,7 +342,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {

break;
case 2:
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
loop<Xbyak::Zmm, Xbyak::Ymm, Xbyak::Operand, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vbroadcastsd,
Expand All @@ -357,7 +359,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
&Xbyak::CodeGenerator::vmovsd);
break;
case 1:
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
loop<Xbyak::Zmm, Xbyak::Xmm, Xbyak::Operand, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vbroadcastss,
Expand All @@ -377,7 +379,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {

break;
default:
if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
loop<Xbyak::Zmm, Xbyak::Xmm, Xbyak::Operand, Xbyak::Xmm,
Xbyak::Operand>(unroll_x, unroll_y,
&Xbyak::CodeGenerator::vmovups,
Expand All @@ -400,7 +402,7 @@ void jit_avx2_kernel_sgemm_kern::generate() {
break;
}

if (mayiuse(avx512_core)) {
if (avx512_gemm_available()) {
sub(AA_, -16 * elt_size_);
} else {
if ((unroll_y != unroll_n_) || (unroll_x <= 4)) {
Expand Down
Loading

0 comments on commit 8b178e8

Please sign in to comment.