diff --git a/cmake/gen/avx256vnni_microkernels.cmake b/cmake/gen/avx256vnni_microkernels.cmake index 83cce6898e3..f4209ab5364 100644 --- a/cmake/gen/avx256vnni_microkernels.cmake +++ b/cmake/gen/avx256vnni_microkernels.cmake @@ -15,7 +15,8 @@ SET(PROD_AVX256VNNI_MICROKERNEL_SRCS src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-1x8c8-minmax-avx256vnni.c src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni.c src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni.c - src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c) + src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c + src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c) SET(NON_PROD_AVX256VNNI_MICROKERNEL_SRCS src/qd8-f16-qc4w-gemm/gen/qd8-f16-qc4w-gemm-1x8c8-minmax-avx256vnni-prfm.c @@ -110,6 +111,8 @@ SET(NON_PROD_AVX256VNNI_MICROKERNEL_SRCS src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni.c src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c + src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c + src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni-prfm.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x8c8-minmax-fp32-avx256vnni-prfm.c @@ -144,6 +147,7 @@ SET(NON_PROD_AVX256VNNI_MICROKERNEL_SRCS src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-14x8c8-minmax-fp32-avx256vnni.c src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c + src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c src/qs8-rsum/gen/qs8-rsum-avx256vnni-u32.c src/qs8-rsum/gen/qs8-rsum-avx256vnni-u64-acc2.c src/qs8-rsum/gen/qs8-rsum-avx256vnni-u128-acc2.c diff --git a/cmake/gen/avxvnni_microkernels.cmake b/cmake/gen/avxvnni_microkernels.cmake index 14adf2f97dd..e8a6bce98d1 100644 --- a/cmake/gen/avxvnni_microkernels.cmake +++ b/cmake/gen/avxvnni_microkernels.cmake @@ -132,6 +132,8 @@ SET(NON_PROD_AVXVNNI_MICROKERNEL_SRCS src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c + src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c + src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avxvnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-avxvnni-prfm.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-avxvnni.c @@ -161,6 +163,8 @@ SET(NON_PROD_AVXVNNI_MICROKERNEL_SRCS src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x8c8-minmax-fp32-avxvnni-prfm.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x8c8-minmax-fp32-avxvnni.c src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c + src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c + src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c src/qs8-rsum/gen/qs8-rsum-avxvnni-u32.c src/qs8-rsum/gen/qs8-rsum-avxvnni-u64-acc2.c src/qs8-rsum/gen/qs8-rsum-avxvnni-u128-acc4.c) diff --git a/cmake/gen/scalar_microkernels.cmake b/cmake/gen/scalar_microkernels.cmake index 3b180f1e3f5..b4e3b340933 100644 --- a/cmake/gen/scalar_microkernels.cmake +++ b/cmake/gen/scalar_microkernels.cmake @@ -159,7 +159,6 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/qs8-gavgpool/gen/qs8-gavgpool-7p7x-minmax-fp32-scalar-imagic-c4.c 
src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-imagic-c1.c src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-imagic-c4.c - src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-imagic.c @@ -178,7 +177,6 @@ SET(PROD_SCALAR_MICROKERNEL_SRCS src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x4-minmax-fp32-scalar-lrintf.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x2-minmax-fp32-scalar-imagic.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x4-minmax-fp32-scalar-lrintf.c - src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-scalar.c src/qs8-rdsum/gen/qs8-rdsum-minmax-fp32-scalar-u1-acc1.c src/qs8-rsum/gen/qs8-rsum-scalar-u4.c src/qs8-vadd/gen/qs8-vadd-minmax-scalar-u1.c @@ -637,6 +635,7 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c2.c src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c4.c src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c + src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c @@ -721,6 +720,7 @@ SET(NON_PROD_SCALAR_MICROKERNEL_SRCS src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-imagic.c src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-lrintf.c src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-scalar.c + src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-scalar.c src/qs8-requantization/qs8-requantization-fp32-scalar-fmagic.c src/qs8-requantization/qs8-requantization-fp32-scalar-lrintf.c src/qs8-requantization/qs8-requantization-gemmlowp-scalar.c diff --git a/gen/avx256vnni_microkernels.bzl b/gen/avx256vnni_microkernels.bzl index e762617889c..8d1fc78cd8b 100644 --- a/gen/avx256vnni_microkernels.bzl +++ b/gen/avx256vnni_microkernels.bzl @@ -12,6 +12,7 @@ PROD_AVX256VNNI_MICROKERNEL_SRCS = [ "src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-8x8c8-minmax-avx256vnni.c", "src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-1x8c8-minmax-avx256vnni.c", "src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-8x8c8-minmax-avx256vnni.c", + "src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c", ] NON_PROD_AVX256VNNI_MICROKERNEL_SRCS = [ @@ -107,6 +108,8 @@ NON_PROD_AVX256VNNI_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-14x8c8-minmax-avx256vnni.c", "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c", "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c", + "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c", + "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni-prfm.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avx256vnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-5x8c8-minmax-fp32-avx256vnni-prfm.c", @@ -141,6 +144,7 @@ NON_PROD_AVX256VNNI_MICROKERNEL_SRCS = [ "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-14x8c8-minmax-fp32-avx256vnni.c", "src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c", "src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c", + "src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c", "src/qs8-rsum/gen/qs8-rsum-avx256vnni-u32.c", "src/qs8-rsum/gen/qs8-rsum-avx256vnni-u64-acc2.c", "src/qs8-rsum/gen/qs8-rsum-avx256vnni-u128-acc2.c", diff --git a/gen/avxvnni_microkernels.bzl 
b/gen/avxvnni_microkernels.bzl index 384f3cb1643..3782abc69c8 100644 --- a/gen/avxvnni_microkernels.bzl +++ b/gen/avxvnni_microkernels.bzl @@ -129,6 +129,8 @@ NON_PROD_AVXVNNI_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c", "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c", + "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c", + "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avxvnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-avxvnni-prfm.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-2x8c8-minmax-fp32-avxvnni.c", @@ -158,6 +160,8 @@ NON_PROD_AVXVNNI_MICROKERNEL_SRCS = [ "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x8c8-minmax-fp32-avxvnni-prfm.c", "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-8x8c8-minmax-fp32-avxvnni.c", "src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c", + "src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c", + "src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c", "src/qs8-rsum/gen/qs8-rsum-avxvnni-u32.c", "src/qs8-rsum/gen/qs8-rsum-avxvnni-u64-acc2.c", "src/qs8-rsum/gen/qs8-rsum-avxvnni-u128-acc4.c", diff --git a/gen/scalar_microkernels.bzl b/gen/scalar_microkernels.bzl index a9309987572..a7c2bba6c7c 100644 --- a/gen/scalar_microkernels.bzl +++ b/gen/scalar_microkernels.bzl @@ -155,7 +155,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-gavgpool/gen/qs8-gavgpool-7p7x-minmax-fp32-scalar-imagic-c4.c", "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-imagic-c1.c", "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-imagic-c4.c", - "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x64c4-gemm-goi-scalar.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p1c-minmax-fp32-scalar-fmagic.c", "src/qs8-qc8w-dwconv/gen/qs8-qc8w-dwconv-3p2c-minmax-fp32-scalar-imagic.c", @@ -174,7 +173,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x4-minmax-fp32-scalar-lrintf.c", "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-2x2-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-3x4-minmax-fp32-scalar-lrintf.c", - "src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-scalar.c", "src/qs8-rdsum/gen/qs8-rdsum-minmax-fp32-scalar-u1-acc1.c", "src/qs8-rsum/gen/qs8-rsum-scalar-u4.c", "src/qs8-vadd/gen/qs8-vadd-minmax-scalar-u1.c", @@ -634,6 +632,7 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c2.c", "src/qs8-gavgpool/gen/qs8-gavgpool-7x-minmax-fp32-scalar-lrintf-c4.c", "src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-scalar.c", + "src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x16c4-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-scalar.c", "src/qs8-packw/gen/qs8-packw-x32c4-gemm-goi-scalar.c", @@ -718,6 +717,7 @@ NON_PROD_SCALAR_MICROKERNEL_SRCS = [ "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-imagic.c", "src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x4-minmax-fp32-scalar-lrintf.c", "src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-scalar.c", + "src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-scalar.c", "src/qs8-requantization/qs8-requantization-fp32-scalar-fmagic.c", "src/qs8-requantization/qs8-requantization-fp32-scalar-lrintf.c", "src/qs8-requantization/qs8-requantization-gemmlowp-scalar.c", diff --git a/scripts/generate-x8-packw.sh b/scripts/generate-x8-packw.sh index 
9ed28b05e4e..58bb3d2f991 100755 --- a/scripts/generate-x8-packw.sh +++ b/scripts/generate-x8-packw.sh @@ -42,4 +42,15 @@ tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -D IZP= tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c & tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c & tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c & + +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c & + +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=0 -D AVX=10 -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c & +tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=128 -D AVX=10 -D PREFETCH=1 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c & + wait diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index d26a884b553..3273f12f2f3 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -3238,42 +3238,45 @@ static void init_qs8_qc8w_gemm_config(void) { const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); #if XNN_ENABLE_AVX512AMX - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512amx) { - qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx); - qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(16)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx); - qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x64c4__avx512amx); - qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(16)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx); - qs8_qc8w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params; - qs8_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qs8_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
- qs8_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar; - qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w; - qs8_qc8w_gemm_config.pack_igemm_kgo = (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w; - qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_deconv_goki_w; - qs8_qc8w_gemm_config.mr = 16; - qs8_qc8w_gemm_config.nr = 64; - qs8_qc8w_gemm_config.log2_kr = 2; - } else + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512amx) { + qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x64c4__avx512amx); + qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(16)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x64c4__avx512amx); + qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x64c4__avx512amx); + qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(16)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_16x64c4__avx512amx); + qs8_qc8w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params; + qs8_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qs8_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qs8_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; + qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x64c4__scalar; + qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w; + qs8_qc8w_gemm_config.pack_igemm_kgo = (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w; + qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_deconv_goki_w; + qs8_qc8w_gemm_config.mr = 16; + qs8_qc8w_gemm_config.nr = 64; + qs8_qc8w_gemm_config.log2_kr = 2; + } else #endif - if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { - qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm); - qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm); - qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm); - qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm); - qs8_qc8w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params; - qs8_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. - qs8_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
- qs8_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_to_qu8_gemm_gio_w; - qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__scalar; - qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_to_qu8_conv_goki_w; - qs8_qc8w_gemm_config.pack_igemm_kgo = (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_to_qu8_conv_kgo_w; - qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_to_qu8_deconv_goki_w; - qs8_qc8w_gemm_config.mr = 7; - qs8_qc8w_gemm_config.nr = 16; - qs8_qc8w_gemm_config.log2_kr = 3; - #if XNN_ENABLE_AVXVNNIINT8 - } else if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnniint8) { + #if XNN_ENABLE_AVX256VNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) { + qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm); + qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm); + qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512vnni_prfm); + qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_7x16c8__avx512vnni_prfm); + qs8_qc8w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params; + qs8_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. + qs8_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. + qs8_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_to_qu8_gemm_gio_w; + qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni; + qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_to_qu8_conv_goki_w; + qs8_qc8w_gemm_config.pack_igemm_kgo = (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_to_qu8_conv_kgo_w; + qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_to_qu8_deconv_goki_w; + qs8_qc8w_gemm_config.mr = 7; + qs8_qc8w_gemm_config.nr = 16; + qs8_qc8w_gemm_config.log2_kr = 3; + } else + #endif + #if XNN_ENABLE_AVXVNNIINT8 && XNN_ENABLE_AVXVNNI + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnniint8 && hardware_config->use_x86_avxvnni) { qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avxvnniint8_prfm); qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x8c8__avxvnniint8_prfm); qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avxvnniint8_prfm); @@ -3282,24 +3285,17 @@ static void init_qs8_qc8w_gemm_config(void) { qs8_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function. qs8_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function. 
qs8_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w; - #if XNN_ENABLE_AVXVNNI - if (hardware_config->use_x86_avxvnni) { - qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni; - } else - #else - { - qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__scalar; - } - #endif + qs8_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni; qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w; qs8_qc8w_gemm_config.pack_igemm_kgo = (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w; qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_deconv_goki_w; qs8_qc8w_gemm_config.mr = 5; qs8_qc8w_gemm_config.nr = 8; qs8_qc8w_gemm_config.log2_kr = 3; + } else #endif #if XNN_ENABLE_AVXVNNI - } else if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) { qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avxvnni_prfm); qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x8c8__avxvnni_prfm); qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avxvnni_prfm); @@ -3315,8 +3311,9 @@ static void init_qs8_qc8w_gemm_config(void) { qs8_qc8w_gemm_config.mr = 5; qs8_qc8w_gemm_config.nr = 8; qs8_qc8w_gemm_config.log2_kr = 3; + } else #endif - } else if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) { qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm); qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x16c8__avx512skx_prfm); qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x16c8__avx512skx_prfm); @@ -3328,21 +3325,23 @@ static void init_qs8_qc8w_gemm_config(void) { qs8_qc8w_gemm_config.mr = 7; qs8_qc8w_gemm_config.nr = 16; qs8_qc8w_gemm_config.log2_kr = 3; + } else #if XNN_ENABLE_AVX256SKX - } else if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256skx) { - qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avx256skx); - qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__avx256skx); - qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avx256skx); - qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__avx256skx); - qs8_qc8w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params; - qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w; - qs8_qc8w_gemm_config.pack_igemm_kgo = 
(xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w; - qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_deconv_goki_w; - qs8_qc8w_gemm_config.mr = 4; - qs8_qc8w_gemm_config.nr = 8; - qs8_qc8w_gemm_config.log2_kr = 3; + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx256skx) { + qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avx256skx); + qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x8c8__avx256skx); + qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avx256skx); + qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_4x8c8__avx256skx); + qs8_qc8w_gemm_config.init.qs8_qc8w = xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params; + qs8_qc8w_gemm_config.pack_igemm_goki = (xnn_pack_conv_goki_w_fn) xnn_pack_qs8_conv_goki_w; + qs8_qc8w_gemm_config.pack_igemm_kgo = (xnn_pack_conv_kgo_w_fn) xnn_pack_qs8_conv_kgo_w; + qs8_qc8w_gemm_config.pack_deconv_goki = (xnn_pack_deconv_goki_w_fn) xnn_pack_qs8_deconv_goki_w; + qs8_qc8w_gemm_config.mr = 4; + qs8_qc8w_gemm_config.nr = 8; + qs8_qc8w_gemm_config.log2_kr = 3; + } else #endif - } else if (hardware_config->use_x86_avx2) { + if (hardware_config->use_x86_avx2) { qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x8c8__avx2); qs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(3)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x8c8__avx2); qs8_qc8w_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_qs8_qc8w_igemm_minmax_fp32_ukernel_1x8c8__avx2); diff --git a/src/configs/hardware-config.c b/src/configs/hardware-config.c index 52603ea4081..0f66e884572 100644 --- a/src/configs/hardware-config.c +++ b/src/configs/hardware-config.c @@ -150,19 +150,19 @@ static void init_hardware_config(void) { #else hardware_config.use_x86_avxvnniint8 = 0; #endif -#if XNN_ENABLE_AVX256SKX && XNN_ENABLE_AVX512AMX +#if XNN_ENABLE_AVX256SKX // Using cpuinfo_has_x86_amx_int8 as placeholder for cpuinfo_has_x86_avx10 hardware_config.use_x86_avx256skx = hardware_config.use_x86_avx512skx || cpuinfo_has_x86_amx_int8(); #else hardware_config.use_x86_avx256skx = 0; #endif -#if XNN_ENABLE_AVX256VNNI && XNN_ENABLE_AVX512AMX +#if XNN_ENABLE_AVX256VNNI // Using cpuinfo_has_x86_amx_int8 as placeholder for cpuinfo_has_x86_avx10 hardware_config.use_x86_avx256vnni = (hardware_config.use_x86_avx512skx && cpuinfo_has_x86_avxvnni()) || cpuinfo_has_x86_amx_int8(); #else hardware_config.use_x86_avx256vnni = 0; #endif -#if XNN_ENABLE_AVX256VNNIGFNI && XNN_ENABLE_AVX512AMX +#if XNN_ENABLE_AVX256VNNIGFNI // Using cpuinfo_has_x86_amx_int8 as placeholder for cpuinfo_has_x86_avx10 hardware_config.use_x86_avx256vnnigfni = hardware_config.use_x86_avx256vnni && cpuinfo_has_x86_gfni(); #else diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c new file mode 100644 index 00000000000..cfe2ec2166a --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni-prfm.c @@ -0,0 +1,805 @@ +// Auto-generated file. Do not edit! 
+// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + + +void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + const __m256i vone = _mm256_set1_epi8(1); + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w14); +
xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); 
+ v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = 
_mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, 
vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + 
xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC 
remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = 
_mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 
16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c new file mode 100644 index 00000000000..67091f212a8 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avx256vnni.c @@ -0,0 +1,708 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include <immintrin.h> + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + + +void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + const __m256i vone = _mm256_set1_epi8(1); + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const uint32_t izp = (uint32_t) (params ?
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = 
_mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, 
*(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 
1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); 
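
The KC main loop gathers eight consecutive weights from each of four rows into one 256-bit register: _mm256_set1_epi64x(unaligned_load_u64(wN)) broadcasts a row's 8 bytes into every 64-bit lane, and the _mm256_blend_epi32 masks 0x0C, 0x30 and 0xC0 then keep exactly one 64-bit lane per row, so v0 holds rows n..n+3, v4 rows n+4..n+7, and so on; the four stores lay all sixteen rows out as one 128-byte block. A minimal scalar sketch of the equivalent data movement, purely illustrative and not part of the diff (the helper name is made up):

#include <stdint.h>
#include <string.h>

// Illustrative only: what one v0/v4/v8/v12 register contains before it is stored.
static void gather_4_rows_of_8(int8_t dst[32], const int8_t* wa, const int8_t* wb,
                               const int8_t* wc, const int8_t* wd) {
  memcpy(dst +  0, wa, 8);  // 64-bit lane 0: row n
  memcpy(dst +  8, wb, 8);  // 64-bit lane 1: row n+1 (blend mask 0x0C)
  memcpy(dst + 16, wc, 8);  // 64-bit lane 2: row n+2 (blend mask 0x30)
  memcpy(dst + 24, wd, 8);  // 64-bit lane 3: row n+3 (blend mask 0xC0)
}
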
+ __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = 
_mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 
10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c new file mode 100644 index 00000000000..6e2926d419d --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c @@ -0,0 +1,805 @@ +// Auto-generated file. Do not edit! 
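
The avx256vnni kernel above and the avxvnni variants that follow produce the same packed layout; they differ only in the dot-product intrinsic flavour and in software prefetching. For one full 16-column tile the output is 16 int32 bias slots (zero-filled when bias is NULL, and later adjusted by the zero-point term), followed by one 128-byte block per group of 8 K values with the final partial block zero-padded, after which the kernel skips extra_bytes before the next tile. A hedged scalar reference of that layout, written for illustration only (the function name is invented, not XNNPACK API):

#include <stdint.h>
#include <stddef.h>
#include <string.h>

static void pack_tile_x16c8_reference(
    const int8_t* w,          // 16 rows, each kc bytes, row-major (goi)
    const int32_t* bias,      // 16 biases, or NULL
    size_t kc,
    int8_t* out)
{
  // 1) 16 int32 bias slots (adjusted afterwards by -izp * row_sum).
  if (bias != NULL) {
    memcpy(out, bias, 16 * sizeof(int32_t));
  } else {
    memset(out, 0, 16 * sizeof(int32_t));
  }
  out += 16 * sizeof(int32_t);

  // 2) One 128-byte block per 8 columns of K: 16 rows x 8 bytes, zero-padded.
  for (size_t k = 0; k < kc; k += 8) {
    const size_t kb = (kc - k) < 8 ? (kc - k) : 8;
    for (size_t n = 0; n < 16; n++) {
      memset(out, 0, 8);
      memcpy(out, w + n * kc + k, kb);
      out += 8;
    }
  }
}
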
+// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + + +void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + const __m256i vone = _mm256_set1_epi8(1); + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w14); + 
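
The _prfm variant only adds software prefetching on top of the same packing: each of the sixteen row pointers is warmed with two cache lines before the K loop, and inside the loop every row is prefetched 128 bytes ahead while only 8 bytes per row are consumed per iteration, so the loads stay ahead of the stream. A minimal sketch of that pattern, assuming the helper lowers to _mm_prefetch with the T0 hint on x86 (the actual definition lives in xnnpack/prefetch.h):

#include <xmmintrin.h>
#include <stdint.h>

// Illustrative only: warm the first two cache lines of one weight row.
static inline void warm_row(const int8_t* w) {
  _mm_prefetch((const char*) w, _MM_HINT_T0);
  _mm_prefetch((const char*) w + 64, _MM_HINT_T0);
}
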
xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const 
int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = 
_mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = 
_mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const 
int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + 
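
Throughout the K loops the packed block is also multiplied against an all-ones vector: _mm256_dpbusd_epi32 and _mm256_dpbusd_avx_epi32 (the EVEX AVX512-VNNI and VEX AVX-VNNI spellings of the same operation) accumulate, per 32-bit lane, the sum of four signed weight bytes, so each accumulator carries two partial sums per row. After the tile, _mm256_hadd_epi32 merges those pairs within each 128-bit lane and _mm256_permute4x64_epi64 with _MM_SHUFFLE(3, 1, 2, 0) restores row order, yielding the per-row weight sums that get folded into the bias: for an input quantized with zero point izp, sum_k (x[k] - izp) * w[n][k] = sum_k x[k] * w[n][k] - izp * sum_k w[n][k], so subtracting izp * ksum here lets the GEMM consume raw quantized inputs. A scalar sketch of that adjustment, illustrative only (the function name is invented):

#include <stdint.h>
#include <stddef.h>

static void adjust_bias_for_input_zero_point(
    int32_t packed_b[16],      // the 16 bias slots written at the top of the tile
    const int8_t* w,           // 16 rows, each kc bytes (zero padding sums to 0)
    size_t kc,
    int32_t izp)               // input zero point from the packing params
{
  for (size_t n = 0; n < 16; n++) {
    int32_t ksum = 0;
    for (size_t k = 0; k < kc; k++) {
      ksum += (int32_t) w[n * kc + k];
    }
    // Mirrors vpack = _mm256_sub_epi32(vpack, _mm256_mullo_epi32(vksum, vzeropoint)).
    packed_b[n] -= ksum * izp;
  }
}
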
w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = 
_mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 
8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c new file mode 100644 index 00000000000..8ddf82e3b88 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c @@ -0,0 +1,708 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + + +void xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + const __m256i vone = _mm256_set1_epi8(1); + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const uint32_t izp = (uint32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i 
v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = 
_mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t 
*)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 
14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const 
int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c index 723562466c2..f3f8d2672b4 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni-prfm.c @@ -53,16 +53,13 @@ void 
xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { - const __m256i vb = _mm256_loadu_si256((const __m256i*) b); - _mm256_storeu_si256((__m256i*) out, vb); + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); b += 8; } else { - _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); } out += 8 * sizeof(uint32_t); @@ -90,18 +87,20 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( xnn_prefetch_to_l1((const int8_t*) w7); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + // KC main loop multiple of 8x8 size_t k = kc; for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); xnn_prefetch_to_l1((const int8_t*) w0 + 128); xnn_prefetch_to_l1((const int8_t*) w1 + 128); xnn_prefetch_to_l1((const int8_t*) w2 + 128); @@ -111,11 +110,11 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 128); xnn_prefetch_to_l1((const int8_t*) w7 + 128); - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); w0 += 8; w1 += 8; @@ -128,22 +127,21 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( out += 64; } - // KC remainder 1..KR-1 + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = 
_mm256_setzero_si256(); if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); w0 += 4; w1 += 4; w2 += 4; @@ -155,25 +153,23 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); } w0 += 2; @@ -187,48 +183,44 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const 
int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = 
_mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); } w0 += 1; @@ -241,26 +233,29 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( w7 += 1; } - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); out += 64; } - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); out = (int8_t*) ((uintptr_t) out + extra_bytes); w0 = w7; } // NC remainder (1..7) if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; @@ -277,196 +272,215 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } out += (8 - n) * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - if XNN_UNPREDICTABLE(n < 2) { - w1 = w0; - } - const int8_t* w2 = w1 + kc; - if XNN_UNPREDICTABLE(n <= 2) { - w2 = w1; - } - const int8_t* w3 = w2 + kc; - if XNN_UNPREDICTABLE(n < 4) { - w3 = w2; - } - const int8_t* w4 = w3 + kc; - if XNN_UNPREDICTABLE(n <= 4) { - w4 = w3; - } - const int8_t* w5 = w4 + kc; - if XNN_UNPREDICTABLE(n < 6) { - w5 = w4; - } - const int8_t* w6 = w5 + kc; - if XNN_UNPREDICTABLE(n <= 6) { - w6 = w5; - } - const int8_t* w7 = w6 + kc; - if XNN_UNPREDICTABLE(n < 8) { - w7 = w6; - } - - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - // KC main loop multiple of 8x8 - size_t k = kc; - for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - w0 += 8; - w1 += 8; - w2 += 8; - w3 += 8; - w4 += 8; - w5 += 8; - w6 += 8; - w7 += 8; - out += 64; - } - - // KC remainder of 1..7 - if (k 
!= 0) { - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); - - if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); - w0 += 4; - w1 += 4; - w2 += 4; - w3 += 4; - w4 += 4; - w5 += 4; - w6 += 4; - w7 += 4; - } - if (k & 2) { - if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); - } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); - } - - w0 += 2; - w1 += 2; - w2 += 2; - w3 += 2; - w4 += 2; - w5 += 2; - w6 += 2; - w7 += 2; - } - if (k & 1) { - if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); - } - else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); - } - else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const 
int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); - } - else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); - } - - w0 += 1; - w1 += 1; - w2 += 1; - w3 += 1; - w4 += 1; - w5 += 1; - w6 += 1; - w7 += 1; - } - - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - out += 64; - } - - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); - out = (int8_t*) ((uintptr_t) out + extra_bytes); + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + 
xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t 
*)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } + weights += nc * kc; } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c index 16b459073d5..7919e5d3561 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avx256vnni.c @@ -52,16 +52,13 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { - const __m256i vb = _mm256_loadu_si256((const __m256i*) b); - _mm256_storeu_si256((__m256i*) out, vb); + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); b += 8; } else { - _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); } out += 8 * sizeof(uint32_t); @@ -73,24 +70,26 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( const int8_t* w6 = w5 + kc; const int8_t* w7 = w6 + kc; + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + // KC main loop multiple of 8x8 size_t k = kc; for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v0 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); w0 += 8; w1 += 8; @@ -103,22 +102,21 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( out += 64; } - // KC remainder 1..KR-1 + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); w0 += 4; w1 += 4; w2 += 4; @@ -130,25 +128,23 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const 
int16_t *)w7, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); } w0 += 2; @@ -162,48 +158,44 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = 
_mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); } w0 += 1; @@ -216,26 +208,29 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( w7 += 1; } - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); out += 64; } - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); out = (int8_t*) ((uintptr_t) out + extra_bytes); w0 = w7; } // NC remainder (1..7) if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; @@ -252,196 +247,191 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } out += (8 - n) * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - if XNN_UNPREDICTABLE(n < 2) { - w1 = w0; - } - const int8_t* w2 = w1 + kc; - if XNN_UNPREDICTABLE(n <= 2) { - w2 = w1; - } - const int8_t* w3 = w2 + kc; - if XNN_UNPREDICTABLE(n < 4) { - w3 = w2; - } - const int8_t* w4 = w3 + kc; - if XNN_UNPREDICTABLE(n <= 4) { - w4 = w3; - } - const int8_t* w5 = w4 + kc; - if XNN_UNPREDICTABLE(n < 6) { - w5 = w4; - } - const int8_t* w6 = w5 + kc; - if XNN_UNPREDICTABLE(n <= 6) { - w6 = w5; - } - const 
int8_t* w7 = w6 + kc; - if XNN_UNPREDICTABLE(n < 8) { - w7 = w6; - } - - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - // KC main loop multiple of 8x8 - size_t k = kc; - for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - w0 += 8; - w1 += 8; - w2 += 8; - w3 += 8; - w4 += 8; - w5 += 8; - w6 += 8; - w7 += 8; - out += 64; - } - - // KC remainder of 1..7 - if (k != 0) { - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); - - if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); - w0 += 4; - w1 += 4; - w2 += 4; - w3 += 4; - w4 += 4; - w5 += 4; - w6 += 4; - w7 += 4; - } - if (k & 2) { - if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); - } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); - } - - w0 += 2; - w1 += 2; - w2 += 2; - w3 += 2; - w4 += 2; - w5 += 2; - w6 += 2; - w7 += 2; - } - if (k & 1) { - if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = 
_mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); - } - else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); - } - else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); - } - else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); - } - - w0 += 1; - w1 += 1; - w2 += 1; - w3 += 1; - w4 += 1; - w5 += 1; - w6 += 1; - w7 += 1; - } - - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - out += 64; - } - - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); - out = (int8_t*) ((uintptr_t) out + extra_bytes); + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = 
_mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = 
_mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } + weights += nc * kc; } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c index 0be860be66b..687273d457f 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c @@ -53,16 +53,13 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { - const __m256i vb = _mm256_loadu_si256((const __m256i*) b); - _mm256_storeu_si256((__m256i*) out, vb); + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); b += 8; } else { - _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); } out += 8 * sizeof(uint32_t); @@ -90,18 +87,20 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( xnn_prefetch_to_l1((const int8_t*) w7); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + // KC main loop multiple of 8x8 size_t k = kc; for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) 
unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); xnn_prefetch_to_l1((const int8_t*) w0 + 128); xnn_prefetch_to_l1((const int8_t*) w1 + 128); xnn_prefetch_to_l1((const int8_t*) w2 + 128); @@ -111,11 +110,11 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 128); xnn_prefetch_to_l1((const int8_t*) w7 + 128); - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); w0 += 8; w1 += 8; @@ -128,22 +127,21 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( out += 64; } - // KC remainder 1..KR-1 + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); w0 += 4; w1 += 4; w2 += 4; @@ -155,25 +153,23 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = 
_mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); } w0 += 2; @@ -187,48 +183,44 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = 
_mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); } w0 += 1; @@ -241,26 +233,29 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( w7 += 1; } - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); out += 64; } - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); out = (int8_t*) ((uintptr_t) out + extra_bytes); w0 = w7; } // NC remainder (1..7) if 
XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; @@ -277,196 +272,215 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } out += (8 - n) * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - if XNN_UNPREDICTABLE(n < 2) { - w1 = w0; - } - const int8_t* w2 = w1 + kc; - if XNN_UNPREDICTABLE(n <= 2) { - w2 = w1; - } - const int8_t* w3 = w2 + kc; - if XNN_UNPREDICTABLE(n < 4) { - w3 = w2; - } - const int8_t* w4 = w3 + kc; - if XNN_UNPREDICTABLE(n <= 4) { - w4 = w3; - } - const int8_t* w5 = w4 + kc; - if XNN_UNPREDICTABLE(n < 6) { - w5 = w4; - } - const int8_t* w6 = w5 + kc; - if XNN_UNPREDICTABLE(n <= 6) { - w6 = w5; - } - const int8_t* w7 = w6 + kc; - if XNN_UNPREDICTABLE(n < 8) { - w7 = w6; - } - - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - // KC main loop multiple of 8x8 - size_t k = kc; - for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - w0 += 8; - w1 += 8; - w2 += 8; - w3 += 8; - w4 += 8; - w5 += 8; - w6 += 8; - w7 += 8; - out += 64; - } - - // KC remainder of 1..7 - if (k != 0) { - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); - - if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); - w0 += 4; - w1 += 4; - w2 += 4; - w3 += 4; - w4 += 4; - w5 += 4; - w6 += 4; - w7 += 4; - } - if (k & 2) { - if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); - } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const 
int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); - } - - w0 += 2; - w1 += 2; - w2 += 2; - w3 += 2; - w4 += 2; - w5 += 2; - w6 += 2; - w7 += 2; - } - if (k & 1) { - if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); - } - else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); - } - else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); - } - else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); - } - - w0 += 1; - w1 += 1; - w2 += 1; - w3 += 1; - w4 += 1; - w5 += 1; - w6 += 1; - w7 += 1; - } - - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - out += 64; - } - - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); - out = (int8_t*) ((uintptr_t) out + extra_bytes); + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const 
int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const 
int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } + weights += 
nc * kc; } while (--g != 0); } diff --git a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c index 30f6dc2f8f8..d7b555ff552 100644 --- a/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c +++ b/src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c @@ -52,16 +52,13 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { - const __m256i vb = _mm256_loadu_si256((const __m256i*) b); - _mm256_storeu_si256((__m256i*) out, vb); + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); b += 8; } else { - _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); } out += 8 * sizeof(uint32_t); @@ -73,24 +70,26 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( const int8_t* w6 = w5 + kc; const int8_t* w7 = w6 + kc; + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + // KC main loop multiple of 8x8 size_t k = kc; for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); w0 += 8; w1 += 8; @@ -103,22 +102,21 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( out += 64; } - // KC remainder 1..KR-1 + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); if (k & 4) 
{ - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); w0 += 4; w1 += 4; w2 += 4; @@ -130,25 +128,23 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); } w0 += 2; @@ -162,48 +158,44 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = 
_mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + 
v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); } w0 += 1; @@ -216,26 +208,29 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( w7 += 1; } - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); out += 64; } - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); out = (int8_t*) ((uintptr_t) out + extra_bytes); w0 = w7; } // NC remainder (1..7) if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; @@ -252,196 +247,191 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni( } out += (8 - n) * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - if XNN_UNPREDICTABLE(n < 2) { - w1 = w0; - } - const int8_t* w2 = w1 + kc; - if XNN_UNPREDICTABLE(n <= 2) { - w2 = w1; - } - const int8_t* w3 = w2 + kc; - if XNN_UNPREDICTABLE(n < 4) { - w3 = w2; - } - const int8_t* w4 = w3 + kc; - if XNN_UNPREDICTABLE(n <= 4) { - w4 = w3; - } - const int8_t* w5 = w4 + kc; - if XNN_UNPREDICTABLE(n < 6) { - w5 = w4; - } - const int8_t* w6 = w5 + kc; - if XNN_UNPREDICTABLE(n <= 6) { - w6 = w5; - } - const int8_t* w7 = w6 + kc; - if XNN_UNPREDICTABLE(n < 8) { - w7 = w6; - } - - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - // KC main loop multiple of 8x8 - size_t k = kc; - for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - w0 += 8; - w1 += 8; - w2 += 8; - w3 += 8; - w4 += 8; - w5 += 8; - w6 += 8; - w7 += 8; - out += 64; - } - - // KC remainder of 1..7 - if (k != 0) { - __m256i v0123x8 = 
_mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); - - if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); - w0 += 4; - w1 += 4; - w2 += 4; - w3 += 4; - w4 += 4; - w5 += 4; - w6 += 4; - w7 += 4; - } - if (k & 2) { - if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); - } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); - } - - w0 += 2; - w1 += 2; - w2 += 2; - w3 += 2; - w4 += 2; - w5 += 2; - w6 += 2; - w7 += 2; - } - if (k & 1) { - if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); - } - else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); - } - else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 
= _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); - } - else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); - } - - w0 += 1; - w1 += 1; - w2 += 1; - w3 += 1; - w4 += 1; - w5 += 1; - w6 += 1; - w7 += 1; - } - - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - out += 64; - } - - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); - out = (int8_t*) ((uintptr_t) out + extra_bytes); + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = 
_mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = 
_mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } + weights += nc * kc; } while (--g != 0); } diff --git a/src/qs8-packw/qs8-packw.h b/src/qs8-packw/qs8-packw.h index 2b7bd77950d..21088b9e763 100644 --- a/src/qs8-packw/qs8-packw.h +++ b/src/qs8-packw/qs8-packw.h @@ -20,6 +20,11 @@ XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvn XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm, 8, 8, 1, 8, 1, 0) XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni, 8, 8, 1, 8, 1, 128) XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm, 8, 8, 1, 8, 1, 128) + +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni, 16, 8, 1, 8, 1, 0) +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm, 16, 8, 1, 8, 1, 0) +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni, 16, 8, 1, 8, 1, 128) +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm, 16, 8, 1, 8, 1, 128) #endif #if XNN_ENABLE_AVX256VNNI && (XNN_ARCH_X86_64 || XNN_ARCH_X86) @@ -27,4 +32,9 @@ XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_packw_gemm_goi_ukernel_x8c8__av XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm, 8, 8, 1, 8, 1, 0) XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni, 8, 8, 1, 8, 1, 128) XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm, 8, 8, 1, 8, 1, 128) + +XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni, 16, 8, 1, 8, 1, 0) +XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm, 16, 8, 1, 8, 1, 0) +XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni, 16, 8, 1, 8, 1, 128) +XNN_QS8_UKERNEL(xnn_arch_x86_avx256vnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm, 16, 8, 1, 8, 1, 128) #endif diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c new file mode 100644 index 00000000000..4b40552f071 --- /dev/null +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni-prfm.c @@ -0,0 +1,805 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + + +void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + const __m256i vone = _mm256_set1_epi8(1); + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 
= _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = 
_mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, 
*(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], 
v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w12 + 
64); + xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, 
*(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); 
+ v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = 
_mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c new file mode 100644 index 00000000000..357f922f971 --- /dev/null +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avx256vnni.c @@ -0,0 +1,708 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + + +void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avx256vnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 16); + assert(kr == 8); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + const __m256i vone = _mm256_set1_epi8(1); + int8_t* out = (int8_t*) packed_weights; + const uint32_t* b = (const uint32_t*) bias; + const uint32_t izp = (uint32_t) (params ? 
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = 
_mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, 
*(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 
1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); 
+ __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = 
_mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 
10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c new file mode 100644 index 00000000000..a63ac278656 --- /dev/null +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni-prfm.c @@ -0,0 +1,805 @@ +// Auto-generated file. Do not edit! 
+// Template: src/x8-packw/kr-avxvnni.c.in
+// Generator: tools/xngen
+//
+// Copyright 2024 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include <immintrin.h>
+
+#include "xnnpack/packw.h"
+#include "xnnpack/unaligned.h"
+#include "xnnpack/prefetch.h"
+
+
+void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni_prfm(
+  size_t g,
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  size_t sr,
+  const int8_t* weights,
+  const int32_t* bias,
+  const void* scale,
+  int8_t* packed_weights,
+  size_t extra_bytes,
+  const void* params)
+{
+  assert(g != 0);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(nr == 16);
+  assert(kr == 8);
+  assert(sr == 1);
+  assert(weights != NULL);
+  assert(packed_weights != NULL);
+
+  const __m256i vone = _mm256_set1_epi8(1);
+  int8_t* out = (int8_t*) packed_weights;
+  const uint32_t* b = (const uint32_t*) bias;
+  const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128);
+  __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp);
+
+  do {
+    // NC main loop multiple of 16
+    const int8_t* w0 = (const int8_t*) weights;
+    size_t n = nc;
+    for (;n >= 16; n -= 16) {
+      int32_t* packed_b = (int32_t*) out;
+      if XNN_LIKELY(b != NULL) {
+        const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0));
+        const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8));
+        _mm256_storeu_si256((__m256i*) (out + 0), vb0);
+        _mm256_storeu_si256((__m256i*) (out + 32), vb8);
+        b += 16;
+      } else {
+        _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256());
+        _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256());
+      }
+      out += 16 * sizeof(uint32_t);
+
+      const int8_t* w1 = w0 + kc;
+      const int8_t* w2 = w1 + kc;
+      const int8_t* w3 = w2 + kc;
+      const int8_t* w4 = w3 + kc;
+      const int8_t* w5 = w4 + kc;
+      const int8_t* w6 = w5 + kc;
+      const int8_t* w7 = w6 + kc;
+      const int8_t* w8 = w7 + kc;
+      const int8_t* w9 = w8 + kc;
+      const int8_t* w10 = w9 + kc;
+      const int8_t* w11 = w10 + kc;
+      const int8_t* w12 = w11 + kc;
+      const int8_t* w13 = w12 + kc;
+      const int8_t* w14 = w13 + kc;
+      const int8_t* w15 = w14 + kc;
+      xnn_prefetch_to_l1((const int8_t*) w0);
+      xnn_prefetch_to_l1((const int8_t*) w0 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w1);
+      xnn_prefetch_to_l1((const int8_t*) w1 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w2);
+      xnn_prefetch_to_l1((const int8_t*) w2 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w3);
+      xnn_prefetch_to_l1((const int8_t*) w3 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w4);
+      xnn_prefetch_to_l1((const int8_t*) w4 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w5);
+      xnn_prefetch_to_l1((const int8_t*) w5 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w6);
+      xnn_prefetch_to_l1((const int8_t*) w6 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w7);
+      xnn_prefetch_to_l1((const int8_t*) w7 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w8);
+      xnn_prefetch_to_l1((const int8_t*) w8 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w9);
+      xnn_prefetch_to_l1((const int8_t*) w9 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w10);
+      xnn_prefetch_to_l1((const int8_t*) w10 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w11);
+      xnn_prefetch_to_l1((const int8_t*) w11 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w12);
+      xnn_prefetch_to_l1((const int8_t*) w12 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w13);
+      xnn_prefetch_to_l1((const int8_t*) w13 + 64);
+      xnn_prefetch_to_l1((const int8_t*) w14);
+
xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const 
int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = 
_mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = 
_mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w8); + xnn_prefetch_to_l1((const int8_t*) w8 + 64); + xnn_prefetch_to_l1((const int8_t*) w9); + xnn_prefetch_to_l1((const 
int8_t*) w9 + 64); + xnn_prefetch_to_l1((const int8_t*) w10); + xnn_prefetch_to_l1((const int8_t*) w10 + 64); + xnn_prefetch_to_l1((const int8_t*) w11); + xnn_prefetch_to_l1((const int8_t*) w11 + 64); + xnn_prefetch_to_l1((const int8_t*) w12); + xnn_prefetch_to_l1((const int8_t*) w12 + 64); + xnn_prefetch_to_l1((const int8_t*) w13); + xnn_prefetch_to_l1((const int8_t*) w13 + 64); + xnn_prefetch_to_l1((const int8_t*) w14); + xnn_prefetch_to_l1((const int8_t*) w14 + 64); + xnn_prefetch_to_l1((const int8_t*) w15); + xnn_prefetch_to_l1((const int8_t*) w15 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w8 + 128); + xnn_prefetch_to_l1((const int8_t*) w9 + 128); + xnn_prefetch_to_l1((const int8_t*) w10 + 128); + xnn_prefetch_to_l1((const int8_t*) w11 + 128); + xnn_prefetch_to_l1((const int8_t*) w12 + 128); + xnn_prefetch_to_l1((const int8_t*) w13 + 128); + xnn_prefetch_to_l1((const int8_t*) w14 + 128); + xnn_prefetch_to_l1((const int8_t*) w15 + 128); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + 
w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = 
_mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 
8);
+            v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16);
+            v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24);
+          }
+
+          w0 += 1;
+          w1 += 1;
+          w2 += 1;
+          w3 += 1;
+          w4 += 1;
+          w5 += 1;
+          w6 += 1;
+          w7 += 1;
+          w8 += 1;
+          w9 += 1;
+          w10 += 1;
+          w11 += 1;
+          w12 += 1;
+          w13 += 1;
+          w14 += 1;
+          w15 += 1;
+        }
+
+        vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0);
+        vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4);
+        vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8);
+        vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12);
+
+        _mm256_storeu_si256((__m256i *)&out[0], v0);
+        _mm256_storeu_si256((__m256i *)&out[32], v4);
+        _mm256_storeu_si256((__m256i *)&out[64], v8);
+        _mm256_storeu_si256((__m256i *)&out[96], v12);
+
+        out += 128;
+      }
+
+      __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4);
+      vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0));
+      __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12);
+      vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0));
+      vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint);
+      vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint);
+      __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0));
+      __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8));
+      vpack0 = _mm256_sub_epi32(vpack0, vksum0);
+      vpack8 = _mm256_sub_epi32(vpack8, vksum8);
+      _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0);
+      _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8);
+      out = (int8_t*) ((uintptr_t) out + extra_bytes);
+    }
+
+    weights += nc * kc;
+  } while (--g != 0);
+}
diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c
new file mode 100644
index 00000000000..f11674bc7a3
--- /dev/null
+++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x16c8-gemm-goi-avxvnni.c
@@ -0,0 +1,708 @@
+// Auto-generated file. Do not edit!
+// Template: src/x8-packw/kr-avxvnni.c.in
+// Generator: tools/xngen
+//
+// Copyright 2024 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include <immintrin.h>
+
+#include "xnnpack/packw.h"
+#include "xnnpack/unaligned.h"
+
+
+void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__avxvnni(
+  size_t g,
+  size_t nc,
+  size_t kc,
+  size_t nr,
+  size_t kr,
+  size_t sr,
+  const int8_t* weights,
+  const int32_t* bias,
+  const void* scale,
+  int8_t* packed_weights,
+  size_t extra_bytes,
+  const void* params)
+{
+  assert(g != 0);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(nr == 16);
+  assert(kr == 8);
+  assert(sr == 1);
+  assert(weights != NULL);
+  assert(packed_weights != NULL);
+
+  const __m256i vone = _mm256_set1_epi8(1);
+  int8_t* out = (int8_t*) packed_weights;
+  const uint32_t* b = (const uint32_t*) bias;
+  const uint32_t izp = (uint32_t) (params ?
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + 128): 128); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 16 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (;n >= 16; n -= 16) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + const __m256i vb8 = _mm256_loadu_si256((const __m256i*) (b + 8)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + _mm256_storeu_si256((__m256i*) (out + 32), vb8); + b += 16; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 32), _mm256_setzero_si256()); + } + out += 16 * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + const int8_t* w8 = w7 + kc; + const int8_t* w9 = w8 + kc; + const int8_t* w10 = w9 + kc; + const int8_t* w11 = w10 + kc; + const int8_t* w12 = w11 + kc; + const int8_t* w13 = w12 + kc; + const int8_t* w14 = w13 + kc; + const int8_t* w15 = w14 + kc; + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + 
__m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = 
_mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t 
*)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w15; + } + + // NC remainder (1..15) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 15); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((uint32_t*) out) = *b++; + out += sizeof(uint32_t); + } while (--nb != 0); + } else { + size_t nb = n; + do { + *((uint32_t*) out) = 0; + out += sizeof(uint32_t); + } while (--nb != 0); + } + out += (16 - n) * sizeof(uint32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + const int8_t* w8 = w7 + kc; + if XNN_UNPREDICTABLE(n <= 8) { + w8 = w7; + } + const int8_t* w9 = w8 + kc; + if XNN_UNPREDICTABLE(n < 10) { + w9 = w8; + } + const int8_t* w10 = w9 + kc; + if XNN_UNPREDICTABLE(n <= 10) { + w10 = w9; + } + const int8_t* w11 = w10 + kc; + if XNN_UNPREDICTABLE(n < 12) { + w11 = w10; + } + const int8_t* w12 = w11 + kc; + if XNN_UNPREDICTABLE(n <= 12) { + w12 = w11; + } + const int8_t* w13 = w12 + kc; + if XNN_UNPREDICTABLE(n < 14) { + w13 = w12; + } + const int8_t* w14 = w13 + kc; + if XNN_UNPREDICTABLE(n <= 14) { + w14 = w13; + } + const int8_t* w15 = w14 + kc; + if XNN_UNPREDICTABLE(n < 16) { + w15 = w14; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + __m256i vacc8 = _mm256_setzero_si256(); + __m256i vacc12 = _mm256_setzero_si256(); + + // KC main loop multiple of 16x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, 
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w8)); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w9)), 0x0C); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w10)), 0x30); + v8 = _mm256_blend_epi32(v8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w11)), 0xC0); + __m256i v12 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w12)); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w13)), 0x0C); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w14)), 0x30); + v12 = _mm256_blend_epi32(v12, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w15)), 0xC0); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + w8 += 8; + w9 += 8; + w10 += 8; + w11 += 8; + w12 += 8; + w13 += 8; + w14 += 8; + w15 += 8; + out += 128; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + __m256i v8 = _mm256_setzero_si256(); + __m256i v12 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w8, 0); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w9, 2); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w10, 4); + v8 = _mm256_insert_epi32(v8, *(const int32_t *)w11, 6); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w12, 0); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w13, 2); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w14, 4); + v12 = _mm256_insert_epi32(v12, *(const int32_t *)w15, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + w8 += 4; + w9 += 4; + w10 += 4; + w11 += 4; + w12 += 4; + w13 += 4; + w14 += 4; + w15 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 
14); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 2); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 6); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 10); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 14); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 2); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 6); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 10); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w8, 0); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w9, 4); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w10, 8); + v8 = _mm256_insert_epi16(v8, *(const int16_t *)w11, 12); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w12, 0); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w13, 4); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w14, 8); + v12 = _mm256_insert_epi16(v12, *(const int16_t *)w15, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + w8 += 2; + w9 += 2; + w10 += 2; + w11 += 2; + w12 += 2; + w13 += 2; + w14 += 2; + w15 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 6); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 14); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 22); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 30); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 6); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 14); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 22); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 4); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 12); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 20); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 28); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 4); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 12); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 20); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const 
int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 2); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 10); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 18); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 26); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 2); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 10); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 18); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w8, 0); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w9, 8); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w10, 16); + v8 = _mm256_insert_epi8(v8, *(const int8_t *)w11, 24); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w12, 0); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w13, 8); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w14, 16); + v12 = _mm256_insert_epi8(v12, *(const int8_t *)w15, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + w8 += 1; + w9 += 1; + w10 += 1; + w11 += 1; + w12 += 1; + w13 += 1; + w14 += 1; + w15 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + vacc8 = _mm256_dpbusd_avx_epi32(vacc8, vone, v8); + vacc12 = _mm256_dpbusd_avx_epi32(vacc12, vone, v12); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + _mm256_storeu_si256((__m256i *)&out[64], v8); + _mm256_storeu_si256((__m256i *)&out[96], v12); + + out += 128; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + __m256i vksum8 = _mm256_hadd_epi32(vacc8, vacc12); + vksum8 = _mm256_permute4x64_epi64(vksum8, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + vksum8 = _mm256_mullo_epi32(vksum8, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + __m256i vpack8 = _mm256_loadu_si256((const __m256i*) (packed_b + 8)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + vpack8 = _mm256_sub_epi32(vpack8, vksum8); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + _mm256_storeu_si256((__m256i *) (packed_b + 8), vpack8); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c index e0df268cc5e..56153c32777 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni-prfm.c @@ -53,16 
+53,13 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { - const __m256i vb = _mm256_loadu_si256((const __m256i*) b); - _mm256_storeu_si256((__m256i*) out, vb); + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); b += 8; } else { - _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); } out += 8 * sizeof(uint32_t); @@ -90,18 +87,20 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( xnn_prefetch_to_l1((const int8_t*) w7); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + // KC main loop multiple of 8x8 size_t k = kc; for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); xnn_prefetch_to_l1((const int8_t*) w0 + 128); xnn_prefetch_to_l1((const int8_t*) w1 + 128); xnn_prefetch_to_l1((const int8_t*) w2 + 128); @@ -111,11 +110,11 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 128); xnn_prefetch_to_l1((const int8_t*) w7 + 128); - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); w0 += 8; w1 += 8; @@ -128,22 +127,21 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( out += 64; } - // KC remainder 1..KR-1 + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + __m256i v0 = 
_mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); w0 += 4; w1 += 4; w2 += 4; @@ -155,25 +153,23 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); } w0 += 2; @@ -187,48 +183,44 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 
22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = 
_mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); } w0 += 1; @@ -241,26 +233,29 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( w7 += 1; } - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); out += 64; } - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); out = (int8_t*) ((uintptr_t) out + extra_bytes); w0 = w7; } // NC remainder (1..7) if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; @@ -277,196 +272,215 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni_prfm( } out += (8 - n) * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - if XNN_UNPREDICTABLE(n < 2) { - w1 = w0; - } - const int8_t* w2 = w1 + kc; - if XNN_UNPREDICTABLE(n <= 2) { - w2 = w1; - } - const int8_t* w3 = w2 + kc; - if XNN_UNPREDICTABLE(n < 4) { - w3 = w2; - } - const int8_t* w4 = w3 + kc; - if XNN_UNPREDICTABLE(n <= 4) { - w4 = w3; - } - const int8_t* w5 = w4 + kc; - if XNN_UNPREDICTABLE(n < 6) { - w5 = w4; - } - const int8_t* w6 = w5 + kc; - if XNN_UNPREDICTABLE(n <= 6) { - w6 = w5; - } - const int8_t* w7 = w6 + kc; - if XNN_UNPREDICTABLE(n < 8) { - w7 = w6; - } - - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - // KC main loop multiple of 8x8 - size_t k = kc; - for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - w0 += 8; - w1 += 8; - w2 += 8; - w3 += 8; - w4 += 8; - w5 += 8; - w6 
+= 8; - w7 += 8; - out += 64; - } - - // KC remainder of 1..7 - if (k != 0) { - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); - - if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); - w0 += 4; - w1 += 4; - w2 += 4; - w3 += 4; - w4 += 4; - w5 += 4; - w6 += 4; - w7 += 4; - } - if (k & 2) { - if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); - } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); - } - - w0 += 2; - w1 += 2; - w2 += 2; - w3 += 2; - w4 += 2; - w5 += 2; - w6 += 2; - w7 += 2; - } - if (k & 1) { - if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); - } - else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); - } - else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, 
*(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); - } - else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); - } - - w0 += 1; - w1 += 1; - w2 += 1; - w3 += 1; - w4 += 1; - w5 += 1; - w6 += 1; - w7 += 1; - } - - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - out += 64; - } - - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); - out = (int8_t*) ((uintptr_t) out + extra_bytes); + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + xnn_prefetch_to_l1((const 
int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, 
*(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } + weights += nc * kc; } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c index 8bd40b63205..53ca5f206d1 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avx256vnni.c @@ -52,16 +52,13 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { - const __m256i vb = _mm256_loadu_si256((const __m256i*) b); - _mm256_storeu_si256((__m256i*) out, vb); + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); b += 8; } else { - _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); } out += 8 * sizeof(uint32_t); @@ -73,24 +70,26 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( const int8_t* w6 = w5 + kc; const int8_t* w7 = w6 + kc; + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + // KC main loop multiple of 8x8 size_t k = kc; for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, 
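
Reviewer note: the hadd/permute pair in these epilogues is the only non-obvious step. After the dpbusd accumulation, each 256-bit accumulator carries two partial dword sums per row, with rows 0-1 in the low 128-bit lane and rows 2-3 in the high lane; _mm256_hadd_epi32 pairs the partials but interleaves rows across lanes as [r0 r1 r4 r5 | r2 r3 r6 r7], and _mm256_permute4x64_epi64 with _MM_SHUFFLE(3, 1, 2, 0) restores row order so the vector lines up with packed_b. A standalone AVX2 demo of just that fix-up (not part of the diff, compile with -mavx2):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
  // Row r's two partials are 10*r+1 and 10*r+2, so its total is 20*r+3.
  __m256i vacc0 = _mm256_setr_epi32( 1,  2, 11, 12, 21, 22, 31, 32);  // rows 0..3
  __m256i vacc4 = _mm256_setr_epi32(41, 42, 51, 52, 61, 62, 71, 72);  // rows 4..7

  // hadd leaves [r0 r1 r4 r5 | r2 r3 r6 r7]; permute4x64(3,1,2,0) -> [r0 .. r7].
  __m256i vksum = _mm256_hadd_epi32(vacc0, vacc4);
  vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0));

  int32_t ksum[8];
  _mm256_storeu_si256((__m256i*) ksum, vksum);
  for (int r = 0; r < 8; r++) {
    printf("row %d sum = %d\n", r, ksum[r]);  // prints 3, 23, 43, ..., 143 in row order
  }
  return 0;
}
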
_mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); w0 += 8; w1 += 8; @@ -103,22 +102,21 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( out += 64; } - // KC remainder 1..KR-1 + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); w0 += 4; w1 += 4; w2 += 4; @@ -130,25 +128,23 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const 
int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); } w0 += 2; @@ -162,48 +158,44 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = 
_mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); } w0 += 1; @@ -216,26 +208,29 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( w7 += 1; } - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); out += 64; } - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); out = (int8_t*) ((uintptr_t) out + extra_bytes); w0 = w7; } // NC remainder (1..7) if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; @@ -252,196 +247,191 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avx256vnni( } out += (8 - n) * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - if XNN_UNPREDICTABLE(n < 2) { - w1 = w0; - } - const int8_t* w2 = w1 + kc; - if XNN_UNPREDICTABLE(n <= 2) { - w2 = w1; - } - const int8_t* w3 = w2 + kc; - if XNN_UNPREDICTABLE(n < 4) { - w3 = w2; - } - const int8_t* w4 = w3 + kc; - if XNN_UNPREDICTABLE(n <= 4) { - w4 = w3; - } - const int8_t* w5 = 
w4 + kc; - if XNN_UNPREDICTABLE(n < 6) { - w5 = w4; - } - const int8_t* w6 = w5 + kc; - if XNN_UNPREDICTABLE(n <= 6) { - w6 = w5; - } - const int8_t* w7 = w6 + kc; - if XNN_UNPREDICTABLE(n < 8) { - w7 = w6; - } - - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - // KC main loop multiple of 8x8 - size_t k = kc; - for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - w0 += 8; - w1 += 8; - w2 += 8; - w3 += 8; - w4 += 8; - w5 += 8; - w6 += 8; - w7 += 8; - out += 64; - } - - // KC remainder of 1..7 - if (k != 0) { - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); - - if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); - w0 += 4; - w1 += 4; - w2 += 4; - w3 += 4; - w4 += 4; - w5 += 4; - w6 += 4; - w7 += 4; - } - if (k & 2) { - if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); - } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); - } - - w0 += 2; - w1 += 2; - w2 += 2; - w3 += 2; - w4 += 2; - w5 += 2; - w6 += 2; - w7 += 2; - } - if (k & 1) { - if ((k & 4) && (k & 2)) { - v0123x8 = 
_mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); - } - else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); - } - else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); - } - else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); - } - - w0 += 1; - w1 += 1; - w2 += 1; - w3 += 1; - w4 += 1; - w5 += 1; - w6 += 1; - w7 += 1; - } - - vacc0124x8 = _mm256_dpbusd_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - out += 64; - } - - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); - out = (int8_t*) ((uintptr_t) out + extra_bytes); + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + 
size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = 
_mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0 = _mm256_dpbusd_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); } + weights += nc * kc; } while (--g != 0); } diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c index 7fce3b2b730..6bf982a99e5 100644 --- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c +++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni-prfm.c @@ -53,16 +53,13 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( const int8_t* w0 = (const int8_t*) weights; size_t n = nc; for (;n >= 8; n -= 8) { - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { - const __m256i vb = _mm256_loadu_si256((const __m256i*) b); - _mm256_storeu_si256((__m256i*) out, vb); + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); b += 8; } else { - _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); } out += 8 * sizeof(uint32_t); @@ -90,18 +87,20 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( xnn_prefetch_to_l1((const int8_t*) w7); xnn_prefetch_to_l1((const int8_t*) w7 + 64); + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + // KC main loop multiple of 8x8 size_t k = kc; for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 
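
Reviewer note: in the NC remainder, both the old and new code clamp the unused row pointers instead of branching on n inside the packing loops: once a row index reaches n, its pointer falls back to the previous row, so the unrolled 8-row path runs unchanged for n in 1..7 without reading past the weights, and the duplicated rows only fill the padding slots of the partial tile (which the consuming GEMM is expected to ignore). A compact sketch of the clamping rule (hypothetical helper, equivalent to the chain of XNN_UNPREDICTABLE checks):

#include <stddef.h>
#include <stdint.h>

// Rows past the last valid one alias their predecessor, matching
//   w1 = (n < 2) ? w0 : w0 + kc,  w2 = (n <= 2) ? w1 : w1 + kc,  ...
static void clamp_row_pointers(const int8_t* w0, size_t kc, size_t n,
                               const int8_t* w[8]) {
  w[0] = w0;
  for (size_t i = 1; i < 8; i++) {
    w[i] = (i < n) ? w[i - 1] + kc : w[i - 1];
  }
}
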
0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); xnn_prefetch_to_l1((const int8_t*) w0 + 128); xnn_prefetch_to_l1((const int8_t*) w1 + 128); xnn_prefetch_to_l1((const int8_t*) w2 + 128); @@ -111,11 +110,11 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( xnn_prefetch_to_l1((const int8_t*) w6 + 128); xnn_prefetch_to_l1((const int8_t*) w7 + 128); - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); w0 += 8; w1 += 8; @@ -128,22 +127,21 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( out += 64; } - // KC remainder 1..KR-1 + // KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); w0 += 4; w1 += 4; w2 += 4; @@ -155,25 +153,23 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, 
*(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); } w0 += 2; @@ -187,48 +183,44 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 
28); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); } w0 += 1; @@ -241,26 +233,29 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( w7 += 1; } - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); out += 64; } - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 
= _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); out = (int8_t*) ((uintptr_t) out + extra_bytes); w0 = w7; } // NC remainder (1..7) if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; @@ -277,196 +272,215 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm( } out += (8 - n) * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - if XNN_UNPREDICTABLE(n < 2) { - w1 = w0; - } - const int8_t* w2 = w1 + kc; - if XNN_UNPREDICTABLE(n <= 2) { - w2 = w1; - } - const int8_t* w3 = w2 + kc; - if XNN_UNPREDICTABLE(n < 4) { - w3 = w2; - } - const int8_t* w4 = w3 + kc; - if XNN_UNPREDICTABLE(n <= 4) { - w4 = w3; - } - const int8_t* w5 = w4 + kc; - if XNN_UNPREDICTABLE(n < 6) { - w5 = w4; - } - const int8_t* w6 = w5 + kc; - if XNN_UNPREDICTABLE(n <= 6) { - w6 = w5; - } - const int8_t* w7 = w6 + kc; - if XNN_UNPREDICTABLE(n < 8) { - w7 = w6; - } - - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - // KC main loop multiple of 8x8 - size_t k = kc; - for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - w0 += 8; - w1 += 8; - w2 += 8; - w3 += 8; - w4 += 8; - w5 += 8; - w6 += 8; - w7 += 8; - out += 64; - } - - // KC remainder of 1..7 - if (k != 0) { - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); - - if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); - w0 += 4; - w1 += 4; - w2 += 4; - w3 += 4; - w4 += 4; - w5 += 4; - w6 += 4; - w7 += 4; - } - if (k & 2) { - if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = 
_mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); - } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); - } - - w0 += 2; - w1 += 2; - w2 += 2; - w3 += 2; - w4 += 2; - w5 += 2; - w6 += 2; - w7 += 2; - } - if (k & 1) { - if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); - } - else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); - } - else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); - } - else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); - } - - w0 += 1; - w1 += 1; - w2 += 1; - w3 += 1; - w4 += 1; - w5 += 1; - w6 += 1; - w7 += 1; - } - - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - out += 64; - } - - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i 
*)packed_b, _mm256_sub_epi32(vpack, vksum)); - out = (int8_t*) ((uintptr_t) out + extra_bytes); + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 
+= 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + out += 64; + } + + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) 
(packed_b + 0));
+      vpack0 = _mm256_sub_epi32(vpack0, vksum0);
+      _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0);
+      out = (int8_t*) ((uintptr_t) out + extra_bytes);
     }
+    weights += nc * kc;
   } while (--g != 0);
 }
diff --git a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c
index c3f18f74e1f..9c6c539df97 100644
--- a/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c
+++ b/src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c
@@ -52,16 +52,13 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni(
     const int8_t* w0 = (const int8_t*) weights;
     size_t n = nc;
     for (;n >= 8; n -= 8) {
-      __m256i vacc0124x8 = _mm256_setzero_si256();
-      __m256i vacc4567x8 = _mm256_setzero_si256();
-
       int32_t* packed_b = (int32_t*) out;
       if XNN_LIKELY(b != NULL) {
-        const __m256i vb = _mm256_loadu_si256((const __m256i*) b);
-        _mm256_storeu_si256((__m256i*) out, vb);
+        const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0));
+        _mm256_storeu_si256((__m256i*) (out + 0), vb0);
         b += 8;
       } else {
-        _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256());
+        _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256());
       }
       out += 8 * sizeof(uint32_t);
@@ -73,24 +70,26 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni(
       const int8_t* w6 = w5 + kc;
       const int8_t* w7 = w6 + kc;
+      __m256i vacc0 = _mm256_setzero_si256();
+      __m256i vacc4 = _mm256_setzero_si256();
+
       // KC main loop multiple of 8x8
       size_t k = kc;
       for (; k >= 8; k -= 8) {
-        __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0));
-        v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C);
-        v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30);
-        v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0);
-
-        __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4));
-        v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C);
-        v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30);
-        v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0);
+        __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0));
+        v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C);
+        v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30);
+        v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0);
+        __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4));
+        v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C);
+        v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30);
+        v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0);
-        vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8);
-        vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8);
+        vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0);
+        vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4);
-        _mm256_storeu_si256((__m256i *)&out[0], v0123x8);
-        _mm256_storeu_si256((__m256i *)&out[32], v4567x8);
+        _mm256_storeu_si256((__m256i *)&out[0], v0);
+        _mm256_storeu_si256((__m256i *)&out[32], v4);
         w0 += 8;
         w1 += 8;
@@ -103,22 +102,21 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni(
         out += 64;
       }
-      // KC remainder 1..KR-1
+
// KC remainder of 1..7 if (k != 0) { assert(k >= 1 && k <= 7); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); w0 += 4; w1 += 4; w2 += 4; @@ -130,25 +128,23 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); } w0 += 2; @@ -162,48 +158,44 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, 
*(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const 
int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); } w0 += 1; @@ -216,26 +208,29 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( w7 += 1; } - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); out += 64; } - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4); + vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0)); + vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); out = (int8_t*) ((uintptr_t) out + extra_bytes); w0 = w7; } // NC remainder (1..7) if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; @@ -252,196 +247,191 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni( } out += (8 - n) * sizeof(uint32_t); - const int8_t* w1 = w0 + kc; - if XNN_UNPREDICTABLE(n < 2) { - w1 = w0; - } - const int8_t* w2 = w1 + kc; - if XNN_UNPREDICTABLE(n <= 2) { - w2 = w1; - } - const int8_t* w3 = w2 + kc; - if XNN_UNPREDICTABLE(n < 4) { - w3 = w2; - } - const int8_t* w4 = w3 + kc; - if XNN_UNPREDICTABLE(n <= 4) { - w4 = w3; - } - const int8_t* w5 = w4 + kc; - if XNN_UNPREDICTABLE(n < 6) { - w5 = w4; - } - const int8_t* w6 = w5 + kc; - if XNN_UNPREDICTABLE(n <= 6) { - w6 = w5; - } - const int8_t* w7 = w6 + kc; - if XNN_UNPREDICTABLE(n < 8) { - w7 = w6; - } - - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - // KC main loop multiple of 8x8 - size_t k = kc; - for (; k >= 8; k -= 8) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); - - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i 
*)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - w0 += 8; - w1 += 8; - w2 += 8; - w3 += 8; - w4 += 8; - w5 += 8; - w6 += 8; - w7 += 8; - out += 64; - } - - // KC remainder of 1..7 - if (k != 0) { - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); - - if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); - w0 += 4; - w1 += 4; - w2 += 4; - w3 += 4; - w4 += 4; - w5 += 4; - w6 += 4; - w7 += 4; - } - if (k & 2) { - if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); - } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); - } - - w0 += 2; - w1 += 2; - w2 += 2; - w3 += 2; - w4 += 2; - w5 += 2; - w6 += 2; - w7 += 2; - } - if (k & 1) { - if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); - } - else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); - } - else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = 
_mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); - } - else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); - } - - w0 += 1; - w1 += 1; - w2 += 1; - w3 += 1; - w4 += 1; - w5 += 1; - w6 += 1; - w7 += 1; - } - - vacc0124x8 = _mm256_dpbusd_avx_epi32(vacc0124x8, vone, v0123x8); - vacc4567x8 = _mm256_dpbusd_avx_epi32(vacc4567x8, vone, v4567x8); - - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[32], v4567x8); - - out += 64; - } - - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); - out = (int8_t*) ((uintptr_t) out + extra_bytes); + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0 = _mm256_setzero_si256(); + __m256i vacc4 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x8 + size_t k = kc; + for (; k >= 8; k -= 8) { + __m256i v0 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); + __m256i v4 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); + v4 = _mm256_blend_epi32(v4, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 = _mm256_dpbusd_avx_epi32(vacc4, vone, v4); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + _mm256_storeu_si256((__m256i *)&out[32], v4); + + w0 += 8; + w1 += 8; + w2 += 8; + w3 += 8; + w4 += 8; + w5 += 8; + w6 += 8; + w7 += 8; + out += 64; + } + + // KC remainder of 1..7 + if (k != 0) { + assert(k >= 1 && k <= 7); + __m256i v0 = _mm256_setzero_si256(); + __m256i v4 = _mm256_setzero_si256(); + + if (k & 4) { + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w0, 0); + v0 = _mm256_insert_epi32(v0, *(const 
int32_t *)w1, 2); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w2, 4); + v0 = _mm256_insert_epi32(v0, *(const int32_t *)w3, 6); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w4, 0); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w5, 2); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w6, 4); + v4 = _mm256_insert_epi32(v4, *(const int32_t *)w7, 6); + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + } + if (k & 2) { + if (k & 4) { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 2); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 6); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 10); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 14); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 2); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 6); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 10); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 14); + } else { + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w0, 0); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w1, 4); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w2, 8); + v0 = _mm256_insert_epi16(v0, *(const int16_t *)w3, 12); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w4, 0); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w5, 4); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w6, 8); + v4 = _mm256_insert_epi16(v4, *(const int16_t *)w7, 12); + } + + w0 += 2; + w1 += 2; + w2 += 2; + w3 += 2; + w4 += 2; + w5 += 2; + w6 += 2; + w7 += 2; + } + if (k & 1) { + if ((k & 4) && (k & 2)) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 6); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 14); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 22); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 30); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 6); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 14); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 22); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 30); + } + else if (k & 4) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 4); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 12); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 20); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 28); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 4); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 12); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 20); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 28); + } + else if (k & 2) { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 2); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 10); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 18); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 26); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 2); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 10); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 18); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 26); + } + else { + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w0, 0); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w1, 8); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w2, 16); + v0 = _mm256_insert_epi8(v0, *(const int8_t *)w3, 24); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w4, 0); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w5, 8); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w6, 16); + v4 = _mm256_insert_epi8(v4, *(const int8_t *)w7, 24); + } + + w0 += 1; + w1 += 1; + w2 += 1; + w3 += 1; + w4 += 1; + w5 += 1; + w6 += 1; + w7 += 1; + } + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + vacc4 
= _mm256_dpbusd_avx_epi32(vacc4, vone, v4);
+
+        _mm256_storeu_si256((__m256i *)&out[0], v0);
+        _mm256_storeu_si256((__m256i *)&out[32], v4);
+
+        out += 64;
+      }
+
+      __m256i vksum0 = _mm256_hadd_epi32(vacc0, vacc4);
+      vksum0 = _mm256_permute4x64_epi64(vksum0, _MM_SHUFFLE(3, 1, 2, 0));
+      vksum0 = _mm256_mullo_epi32(vksum0, vzeropoint);
+      __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0));
+      vpack0 = _mm256_sub_epi32(vpack0, vksum0);
+      _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0);
+      out = (int8_t*) ((uintptr_t) out + extra_bytes);
     }
+    weights += nc * kc;
   } while (--g != 0);
 }
diff --git a/src/x8-packw/kr-avxvnni.c.in b/src/x8-packw/kr-avxvnni.c.in
index 56ac77e0891..7300efd09e8 100644
--- a/src/x8-packw/kr-avxvnni.c.in
+++ b/src/x8-packw/kr-avxvnni.c.in
@@ -3,7 +3,7 @@
 // This source code is licensed under the BSD-style license found in the
 // LICENSE file in the root directory of this source tree.
-$assert NR == 8
+$assert NR in [8, 16]
 $assert KR == 8
 $assert TYPE in ["int8_t"]
 $assert IZP in [0, 128]
@@ -20,12 +20,11 @@ $if PREFETCH:
   #include "xnnpack/prefetch.h"
-$BITS = {"int8_t": 8}[TYPE]
 $BTYPE = {"int8_t": "uint32_t"}[TYPE]
 $WTYPE = {"int8_t": "int8_t"}[TYPE]
 $_MM256_DPBUSD_EPI32 = "_mm256_dpbusd_avx_epi32" if AVX == 2 else "_mm256_dpbusd_epi32"
 $ISA = "avxvnni" if AVX == 2 else "avx256vnni"
-void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_prfm" if PREFETCH else ""}(
+void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_prfm" if PREFETCH else ""}(
   size_t g,
   size_t nc,
   size_t kc,
@@ -33,10 +32,7 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N
   size_t kr,
   size_t sr,
   const ${WTYPE}* weights,
-  $if BITS == 8:
-    const int32_t* bias,
-  $else:
-    const ${WTYPE}* bias,
+  const int32_t* bias,
   const void* scale,
   ${WTYPE}* packed_weights,
   size_t extra_bytes,
@@ -54,31 +50,26 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N
   const __m256i vone = _mm256_set1_epi8(1);
   ${TYPE}* out = (${TYPE}*) packed_weights;
   const ${BTYPE}* b = (const ${BTYPE}*) bias;
-  $if BITS == 8:
-    const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP});
-    __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp);
+  const uint32_t izp = (uint32_t) (params ?
(((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP}); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); do { // NC main loop multiple of ${NR} const ${TYPE}* w0 = (const ${TYPE}*) weights; size_t n = nc; for (;n >= ${NR}; n -= ${NR}) { - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); - - $if BITS == 8: - int32_t* packed_b = (int32_t*) out; + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { - const __m256i vb = _mm256_loadu_si256((const __m256i*) b); - _mm256_storeu_si256((__m256i*) out, vb); + $for N in range(0, NR, 8): + const __m256i vb${N} = _mm256_loadu_si256((const __m256i*) (b + ${N})); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), vb${N}); b += ${NR}; } else { - _mm256_storeu_si256((__m256i*) out, _mm256_setzero_si256()); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), _mm256_setzero_si256()); } - $if BTYPE == TYPE: - out += ${NR}; - $else: - out += ${NR} * sizeof(${BTYPE}); + out += ${NR} * sizeof(${BTYPE}); $for N in range(1, NR): const ${TYPE}* w${N} = w${N-1} + kc; @@ -87,74 +78,60 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N xnn_prefetch_to_l1((const int8_t*) w${N}); xnn_prefetch_to_l1((const int8_t*) w${N} + 64); + $for N in range(0, NR, 4): + __m256i vacc${N} = _mm256_setzero_si256(); + // KC main loop multiple of ${NR}x${KR} size_t k = kc; for (; k >= ${KR}; k -= ${KR}) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + $for N in range(0, NR, 4): + __m256i v${N} = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N})); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+1})), 0x0C); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+2})), 0x30); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+3})), 0xC0); $if PREFETCH: $for N in range(0, NR): xnn_prefetch_to_l1((const int8_t*) w${N} + 128); - $if BITS == 8: - vacc0124x8 = ${_MM256_DPBUSD_EPI32}(vacc0124x8, vone, v0123x8); - vacc4567x8 = ${_MM256_DPBUSD_EPI32}(vacc4567x8, vone, v4567x8); + $for N in range(0, NR, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); $for N in range(NR): w${N} += ${KR}; out += ${NR*KR}; } - // KC remainder 1..KR-1 + // KC remainder of 1..${KR-1} if (k != 0) { assert(k >= 1 && k <= ${KR-1}); - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + $for N in range(0, NR, 4): + __m256i v${N} = _mm256_setzero_si256(); if (k & 4) { - v0123x8 = 
_mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi32(v${N}, *(const int32_t *)w${N}, 0); + v${N} = _mm256_insert_epi32(v${N}, *(const int32_t *)w${N+1}, 2); + v${N} = _mm256_insert_epi32(v${N}, *(const int32_t *)w${N+2}, 4); + v${N} = _mm256_insert_epi32(v${N}, *(const int32_t *)w${N+3}, 6); $for N in range(NR): w${N} += 4; } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N}, 2); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+1}, 6); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+2}, 10); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+3}, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N}, 0); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+1}, 4); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+2}, 8); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+3}, 12); } $for N in range(NR): @@ -162,103 +139,82 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N}, 6); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+1}, 14); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+2}, 22); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t 
*)w${N+3}, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N}, 4); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+1}, 12); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+2}, 20); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+3}, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N}, 2); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+1}, 10); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+2}, 18); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+3}, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N}, 0); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+1}, 8); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+2}, 16); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+3}, 24); } $for N in range(NR): w${N} += 1; } - $if BITS == 8: - vacc0124x8 = ${_MM256_DPBUSD_EPI32}(vacc0124x8, vone, v0123x8); - vacc4567x8 = ${_MM256_DPBUSD_EPI32}(vacc4567x8, vone, v4567x8); + $for N in range(0, NR, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); out += ${NR*KR}; } - $if BITS == 8: - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = _mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + $for N in range(0, NR, 8): + __m256i vksum${N} = _mm256_hadd_epi32(vacc${N}, vacc${N+4}); + vksum${N} = _mm256_permute4x64_epi64(vksum${N}, _MM_SHUFFLE(3, 1, 2, 0)); + $for N in range(0, NR, 8): 
+ vksum${N} = _mm256_mullo_epi32(vksum${N}, vzeropoint); + $for N in range(0, NR, 8): + __m256i vpack${N} = _mm256_loadu_si256((const __m256i*) (packed_b + ${N})); + $for N in range(0, NR, 8): + vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); out = (${TYPE}*) ((uintptr_t) out + extra_bytes); w0 = w${NR-1}; } // NC remainder (1..${NR-1}) if XNN_UNLIKELY(n != 0) { - $if BITS == 8: - int32_t* packed_b = (int32_t*) out; + assert(n >= 1 && n <= ${NR-1}); + + int32_t* packed_b = (int32_t*) out; if XNN_LIKELY(b != NULL) { size_t nb = n; do { - $if BTYPE == TYPE: - *out++ = *b++; - $else: - *((${BTYPE}*) out) = *b++; - out += sizeof(${BTYPE}); + *((${BTYPE}*) out) = *b++; + out += sizeof(${BTYPE}); } while (--nb != 0); } else { size_t nb = n; do { - $if BTYPE == TYPE: - *out++ = 0; - $else: - *((${BTYPE}*) out) = 0; - out += sizeof(${BTYPE}); + *((${BTYPE}*) out) = 0; + out += sizeof(${BTYPE}); } while (--nb != 0); } - $if BTYPE == TYPE: - out += (${NR} - n); - $else: - out += (${NR} - n) * sizeof(${BTYPE}); + out += (${NR} - n) * sizeof(${BTYPE}); - $if NR > 2: $for N in range(1, NR): const ${TYPE}* w${N} = w${N-1} + kc; $if N % 2 == 0: @@ -269,30 +225,31 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N if XNN_UNPREDICTABLE(n < ${N+1}) { w${N} = w${N-1}; } + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N}); + xnn_prefetch_to_l1((const int8_t*) w${N} + 64); - $if BITS == 8: - __m256i vacc0124x8 = _mm256_setzero_si256(); - __m256i vacc4567x8 = _mm256_setzero_si256(); + $for N in range(0, NR, 4): + __m256i vacc${N} = _mm256_setzero_si256(); // KC main loop multiple of ${NR}x${KR} size_t k = kc; for (; k >= ${KR}; k -= ${KR}) { - __m256i v0123x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w0)); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w1)), 0x0C); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w2)), 0x30); - v0123x8 = _mm256_blend_epi32(v0123x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w3)), 0xC0); - - __m256i v4567x8 = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w4)); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w5)), 0x0C); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w6)), 0x30); - v4567x8 = _mm256_blend_epi32(v4567x8, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w7)), 0xC0); + $for N in range(0, NR, 4): + __m256i v${N} = _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N})); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+1})), 0x0C); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+2})), 0x30); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi64x((int64_t) unaligned_load_u64(w${N+3})), 0xC0); + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 128); - $if BITS == 8: - vacc0124x8 = ${_MM256_DPBUSD_EPI32}(vacc0124x8, vone, v0123x8); - vacc4567x8 = ${_MM256_DPBUSD_EPI32}(vacc4567x8, vone, v4567x8); + $for N in range(0, NR, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); $for N in range(NR): w${N} += ${KR}; @@ 
-301,43 +258,32 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N // KC remainder of 1..${KR-1} if (k != 0) { - __m256i v0123x8 = _mm256_setzero_si256(); - __m256i v4567x8 = _mm256_setzero_si256(); + assert(k >= 1 && k <= ${KR-1}); + $for N in range(0, NR, 4): + __m256i v${N} = _mm256_setzero_si256(); if (k & 4) { - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w0, 0); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w1, 2); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w2, 4); - v0123x8 = _mm256_insert_epi32(v0123x8, *(const int32_t *)w3, 6); - - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w4, 0); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w5, 2); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w6, 4); - v4567x8 = _mm256_insert_epi32(v4567x8, *(const int32_t *)w7, 6); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi32(v${N}, *(const int32_t *)w${N}, 0); + v${N} = _mm256_insert_epi32(v${N}, *(const int32_t *)w${N+1}, 2); + v${N} = _mm256_insert_epi32(v${N}, *(const int32_t *)w${N+2}, 4); + v${N} = _mm256_insert_epi32(v${N}, *(const int32_t *)w${N+3}, 6); $for N in range(NR): w${N} += 4; } if (k & 2) { if (k & 4) { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 2); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 6); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 10); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 14); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 2); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 6); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 10); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 14); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N}, 2); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+1}, 6); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+2}, 10); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+3}, 14); } else { - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w0, 0); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w1, 4); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w2, 8); - v0123x8 = _mm256_insert_epi16(v0123x8, *(const int16_t *)w3, 12); - - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w4, 0); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w5, 4); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w6, 8); - v4567x8 = _mm256_insert_epi16(v4567x8, *(const int16_t *)w7, 12); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N}, 0); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+1}, 4); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+2}, 8); + v${N} = _mm256_insert_epi16(v${N}, *(const int16_t *)w${N+3}, 12); } $for N in range(NR): @@ -345,72 +291,61 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N } if (k & 1) { if ((k & 4) && (k & 2)) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 6); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 14); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 22); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 30); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 6); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 14); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const 
int8_t *)w6, 22); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 30); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N}, 6); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+1}, 14); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+2}, 22); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+3}, 30); } else if (k & 4) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 4); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 12); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 20); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 28); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 4); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 12); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 20); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 28); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N}, 4); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+1}, 12); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+2}, 20); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+3}, 28); } else if (k & 2) { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 2); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 10); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 18); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 26); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 2); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 10); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 18); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 26); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N}, 2); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+1}, 10); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+2}, 18); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+3}, 26); } else { - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w0, 0); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w1, 8); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w2, 16); - v0123x8 = _mm256_insert_epi8(v0123x8, *(const int8_t *)w3, 24); - - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w4, 0); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w5, 8); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w6, 16); - v4567x8 = _mm256_insert_epi8(v4567x8, *(const int8_t *)w7, 24); + $for N in range(0, NR, 4): + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N}, 0); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+1}, 8); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+2}, 16); + v${N} = _mm256_insert_epi8(v${N}, *(const int8_t *)w${N+3}, 24); } $for N in range(NR): w${N} += 1; } - $if BITS == 8: - vacc0124x8 = ${_MM256_DPBUSD_EPI32}(vacc0124x8, vone, v0123x8); - vacc4567x8 = ${_MM256_DPBUSD_EPI32}(vacc4567x8, vone, v4567x8); + $for N in range(0, NR, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); - _mm256_storeu_si256((__m256i *)&out[0], v0123x8); - _mm256_storeu_si256((__m256i *)&out[${4 * KR}], v4567x8); + $for N in range(0, NR, 4): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); out += ${NR*KR}; } - $if BITS == 8: - __m256i vksum = _mm256_hadd_epi32(vacc0124x8, vacc4567x8); - vksum = _mm256_permute4x64_epi64(vksum, _MM_SHUFFLE(3, 1, 2, 0)); - vksum = 
_mm256_mullo_epi32(vksum, vzeropoint); - __m256i vpack = _mm256_loadu_si256((const __m256i*) packed_b); - _mm256_storeu_si256((__m256i *)packed_b, _mm256_sub_epi32(vpack, vksum)); + $for N in range(0, NR, 8): + __m256i vksum${N} = _mm256_hadd_epi32(vacc${N}, vacc${N+4}); + vksum${N} = _mm256_permute4x64_epi64(vksum${N}, _MM_SHUFFLE(3, 1, 2, 0)); + $for N in range(0, NR, 8): + vksum${N} = _mm256_mullo_epi32(vksum${N}, vzeropoint); + $for N in range(0, NR, 8): + __m256i vpack${N} = _mm256_loadu_si256((const __m256i*) (packed_b + ${N})); + $for N in range(0, NR, 8): + vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); out = (${TYPE}*) ((uintptr_t) out + extra_bytes); } + weights += nc * kc; } while (--g != 0); } diff --git a/src/x8-packw/kr-scalar.c.in b/src/x8-packw/kr-scalar.c.in index 5a8ead608ba..d0a433fae64 100644 --- a/src/x8-packw/kr-scalar.c.in +++ b/src/x8-packw/kr-scalar.c.in @@ -100,12 +100,14 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N assert(k >= 1 && k <= ${KR-1}); $for N in range(NR): const ${TYPE} v${N}x0 = w${N}[0]; - ksum${N} += (uint32_t) v${N}x0; + $if BITS == 8: + ksum${N} += (uint32_t) v${N}x0; out[${N*KR}] = v${N}x0; $for K in range(1, KR): if (${K} < k) { const ${TYPE} v${N}x${K} = w${N}[${K}]; - ksum${N} += (uint32_t) v${N}x${K}; + $if BITS == 8: + ksum${N} += (uint32_t) v${N}x${K}; out[${N*KR+K}] = v${N}x${K}; } w${N} += k; @@ -184,12 +186,14 @@ void xnn_qs${BITS}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${N assert(k >= 1 && k <= ${KR-1}); $for N in range(NR-1): const ${TYPE} v${N}x0 = w${N}[0]; - ksum${N} += (uint32_t) v${N}x0; + $if BITS == 8: + ksum${N} += (uint32_t) v${N}x0; out[${N*KR}] = v${N}x0; $for K in range(1, KR): if (${K} < k) { const ${TYPE} v${N}x${K} = w${N}[${K}]; - ksum${N} += (uint32_t) v${N}x${K}; + $if BITS == 8: + ksum${N} += (uint32_t) v${N}x${K}; out[${N*KR+K}] = v${N}x${K}; } w${N} += k; diff --git a/test/packw-microkernel-tester.h b/test/packw-microkernel-tester.h index 56b2abfc440..f4be8cacbc7 100644 --- a/test/packw-microkernel-tester.h +++ b/test/packw-microkernel-tester.h @@ -118,21 +118,13 @@ class PackWMicrokernelTester { const xnn_qs8_packing_params packing_params = { 0 }; // Compute reference results. - if (izp() == 128) { - xnn_pack_qs8_to_qu8_gemm_goi_w(/*g=*/1, n(), k(), nr(), kr(), sr(), - reinterpret_cast(weights.data()), - bias_data, - /*scale=*/nullptr, - reinterpret_cast(packed_w_ref.data()), - /*extra_bytes=*/0, &packing_params); - } else { - xnn_pack_qs8_gemm_goi_w(/*g=*/1, n(), k(), nr(), kr(), sr(), - reinterpret_cast(weights.data()), - bias_data, - /*scale=*/nullptr, - reinterpret_cast(packed_w_ref.data()), - /*extra_bytes=*/0, &packing_params); - } + auto* pack_function = izp() == 128 ? xnn_pack_qs8_to_qu8_gemm_goi_w : xnn_pack_qs8_gemm_goi_w; + pack_function(/*g=*/1, n(), k(), nr(), kr(), sr(), + reinterpret_cast(weights.data()), + bias_data, + /*scale=*/nullptr, + reinterpret_cast(packed_w_ref.data()), + /*extra_bytes=*/0, &packing_params); // Call optimized micro-kernel. packw(/*g=*/1, n(), k(), nr(), kr(), sr(), @@ -140,7 +132,9 @@ class PackWMicrokernelTester { // Verify bias results. for (size_t i = 0; i < packed_n() * sizeof(int32_t); i++) { - EXPECT_EQ((int32_t) packed_w[i], (int32_t) packed_w_ref[i]); + if (packed_w_ref[i] != INT8_C(0x7B)) { // Allow pad to differ + EXPECT_EQ((int32_t) packed_w[i], (int32_t) packed_w_ref[i]); + } } // Verify weights results. 
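For context on the kr-avxvnni.c.in changes above: each __m256i v${N} gathers KR=8 bytes from four consecutive rows (broadcast with _mm256_set1_epi64x, then merged in with _mm256_blend_epi32), the running vpdpbusd against a vector of ones accumulates per-row byte sums, and the closing _mm256_hadd_epi32 / _mm256_permute4x64_epi64 / _mm256_mullo_epi32 / _mm256_sub_epi32 sequence folds those sums into the packed bias. A minimal scalar sketch of that compensation, assuming `izp` is the zero point carried in `vzeropoint` and using hypothetical `weights`/`packed_b` layouts:

```c
#include <stddef.h>
#include <stdint.h>

// Illustrative sketch only, not the generated kernel: the per-row weight sum
// is multiplied by the zero point and subtracted from the packed bias, which
// is what the vpdpbusd/hadd/permute/mullo/sub sequence computes 16 rows at a time.
static void compensate_bias(int32_t* packed_b,      // nr packed bias values
                            const int8_t* weights,  // nr rows of kc weights each
                            size_t nr, size_t kc, int32_t izp) {
  for (size_t n = 0; n < nr; n++) {
    int32_t ksum = 0;
    for (size_t k = 0; k < kc; k++) {
      ksum += (int32_t) weights[n * kc + k];
    }
    packed_b[n] -= ksum * izp;
  }
}
```

For NR=16 the `$for N in range(0, NR, 4)` loops expand to accumulators vacc0/vacc4/vacc8/vacc12, and the `range(0, NR, 8)` reduction runs for N=0 and N=8, which is why `packed_b` is now loaded and stored at offset `packed_b + ${N}` instead of through a single vector.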
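The packw-microkernel-tester.h hunk above also relaxes the bias verification: both reference packers are called through one pointer (they take the same argument list), and bytes the reference packer left at the 0x7B fill value are skipped, since kernels may write different padding there. A minimal sketch of that sentinel-skipping check, under the assumption that the buffers were pre-filled with 0x7B before packing (hypothetical helper and buffer names):

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Sketch only: compare the packed-bias region byte for byte, but ignore
// positions the reference packer never wrote (still equal to the 0x7B fill),
// because those are padding the optimized kernel is free to fill differently.
static void check_packed_bias(const int8_t* packed_w, const int8_t* packed_w_ref,
                              size_t packed_n) {
  for (size_t i = 0; i < packed_n * sizeof(int32_t); i++) {
    if (packed_w_ref[i] != INT8_C(0x7B)) {
      assert(packed_w[i] == packed_w_ref[i]);
    }
  }
}
```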
diff --git a/test/qs8-packw.cc b/test/qs8-packw.cc index 324cec3348f..0f69c429a15 100644 --- a/test/qs8-packw.cc +++ b/test/qs8-packw.cc @@ -10,6 +10,7 @@ #include "xnnpack/common.h" #include "xnnpack/isa-checks.h" #include "xnnpack/packw.h" +#include "next_prime.h" #include "packw-microkernel-tester.h" namespace { @@ -34,13 +35,15 @@ std::string GetTestQS8Name(const testing::TestParamInfo& const XnnTestQS8Param xnn_test_qs8_params[] = { #include "src/qs8-packw/qs8-packw.h" }; + #undef XNN_QS8_UKERNEL } // namespace -TEST_P(XnnTestQS8, k_eq_kblock) { +TEST_P(XnnTestQS8, null_bias) { TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); PackWMicrokernelTester() + .nullbias(true) .n(GetParam().nr * GetParam().nr_scale) .k(GetParam().kblock) .nr(GetParam().nr * GetParam().nr_scale) @@ -50,14 +53,11 @@ TEST_P(XnnTestQS8, k_eq_kblock) { .Test(GetParam().ukernel); } -TEST_P(XnnTestQS8, k_div_kblock) { - if (GetParam().kblock <= 1) { - GTEST_SKIP(); - } +TEST_P(XnnTestQS8, k_eq_kblock) { TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); PackWMicrokernelTester() .n(GetParam().nr * GetParam().nr_scale) - .k(GetParam().kblock * 5) + .k(GetParam().kblock) .nr(GetParam().nr * GetParam().nr_scale) .kr(GetParam().kr) .sr(GetParam().sr) @@ -65,6 +65,20 @@ TEST_P(XnnTestQS8, k_div_kblock) { .Test(GetParam().ukernel); } +TEST_P(XnnTestQS8, k_div_kblock) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + for (size_t k = GetParam().kblock; k < GetParam().kblock * 5; k += GetParam().kblock) { + PackWMicrokernelTester() + .n(GetParam().nr * GetParam().nr_scale) + .k(k) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .izp(GetParam().izp) + .Test(GetParam().ukernel); + } +} + TEST_P(XnnTestQS8, k_lt_kblock) { if (GetParam().kblock <= 1) { GTEST_SKIP(); @@ -84,7 +98,7 @@ TEST_P(XnnTestQS8, k_lt_kblock) { TEST_P(XnnTestQS8, k_gt_kblock) { TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t k = GetParam().kblock + 1; k < (GetParam().kblock == 1 ? 4 : GetParam().kblock * 2); k++) { + for (size_t k = GetParam().kblock + 1; k < GetParam().kblock * 5; k = xnnpack::NextPrime(k + 1)) { PackWMicrokernelTester() .n(GetParam().nr * GetParam().nr_scale) .k(k) @@ -96,12 +110,74 @@ TEST_P(XnnTestQS8, k_gt_kblock) { } } -TEST_P(XnnTestQS8, n_eq_nr) { +TEST_P(XnnTestQS8, n_eq_1) { + if (GetParam().nr <= 1 || GetParam().nr_scale != 1) { + GTEST_SKIP(); + } TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t k = 1; k < (GetParam().kblock == 1 ? 
4 : GetParam().kblock * 2); k++) { + PackWMicrokernelTester() + .n(1 * GetParam().nr_scale) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .izp(GetParam().izp) + .Test(GetParam().ukernel); +} + + +TEST_P(XnnTestQS8, n_div_nr_null_bias) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + for (size_t n = GetParam().nr; n < GetParam().nr * 5; n += GetParam().nr) { PackWMicrokernelTester() - .n(GetParam().nr * GetParam().nr_scale) - .k(k) + .nullbias(true) + .n(n * GetParam().nr_scale) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .izp(GetParam().izp) + .Test(GetParam().ukernel); + } +} + +TEST_P(XnnTestQS8, n_div_nr) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + for (size_t n = GetParam().nr; n < GetParam().nr * 5; n += GetParam().nr) { + PackWMicrokernelTester() + .n(n * GetParam().nr_scale) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .izp(GetParam().izp) + .Test(GetParam().ukernel); + } +} + +TEST_P(XnnTestQS8, n_lt_nr) { + if (GetParam().nr <= 1 || GetParam().nr_scale != 1) { + GTEST_SKIP(); + } + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + for (size_t n = 1; n < GetParam().nr * GetParam().nr_scale; n++) { + PackWMicrokernelTester() + .n(n) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .izp(GetParam().izp) + .Test(GetParam().ukernel); + } +} + +TEST_P(XnnTestQS8, n_gt_nr) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + for (size_t n = GetParam().nr * GetParam().nr_scale; n < GetParam().nr * GetParam().nr_scale * 5; n = xnnpack::NextPrime(n + 1)) { + PackWMicrokernelTester() + .n(n) + .k(xnnpack::NextPrime(GetParam().kblock + 1)) .nr(GetParam().nr * GetParam().nr_scale) .kr(GetParam().kr) .sr(GetParam().sr) @@ -114,4 +190,3 @@ INSTANTIATE_TEST_SUITE_P(qs8_packw, XnnTestQS8, testing::ValuesIn(xnn_test_qs8_params), GetTestQS8Name); - diff --git a/test/x8-packw.cc b/test/x8-packw.cc index 4d64ba3fe4d..c6bdcc55ab7 100644 --- a/test/x8-packw.cc +++ b/test/x8-packw.cc @@ -10,6 +10,7 @@ #include "xnnpack/common.h" #include "xnnpack/isa-checks.h" #include "xnnpack/packw.h" +#include "next_prime.h" #include "packw-microkernel-tester.h" namespace { @@ -39,9 +40,10 @@ const XnnTestParam xnn_test_params[] = { } // namespace -TEST_P(XnnTest, k_eq_kblock) { +TEST_P(XnnTest, null_bias) { TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); PackWMicrokernelTester() + .nullbias(true) .n(GetParam().nr * GetParam().nr_scale) .k(GetParam().kblock) .nr(GetParam().nr * GetParam().nr_scale) @@ -50,26 +52,20 @@ TEST_P(XnnTest, k_eq_kblock) { .Test(GetParam().ukernel); } -TEST_P(XnnTest, k_div_kblock) { - if (GetParam().kblock <= 1) { - GTEST_SKIP(); - } +TEST_P(XnnTest, k_eq_kblock) { TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); PackWMicrokernelTester() .n(GetParam().nr * GetParam().nr_scale) - .k(GetParam().kblock * 5) + .k(GetParam().kblock) .nr(GetParam().nr * GetParam().nr_scale) .kr(GetParam().kr) .sr(GetParam().sr) .Test(GetParam().ukernel); } -TEST_P(XnnTest, k_lt_kblock) { - if (GetParam().kblock <= 1) { - GTEST_SKIP(); - } +TEST_P(XnnTest, k_div_kblock) { TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t k = 1; k < GetParam().kblock; k++) { + for (size_t k = GetParam().kblock; k < GetParam().kblock * 5; k += GetParam().kblock) { PackWMicrokernelTester() .n(GetParam().nr * GetParam().nr_scale) .k(k) @@ -80,9 
+76,12 @@ TEST_P(XnnTest, k_lt_kblock) { } } -TEST_P(XnnTest, k_gt_kblock) { +TEST_P(XnnTest, k_lt_kblock) { + if (GetParam().kblock <= 1) { + GTEST_SKIP(); + } TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t k = GetParam().kblock + 1; k < (GetParam().kblock == 1 ? 4 : GetParam().kblock * 2); k++) { + for (size_t k = 1; k < GetParam().kblock; k++) { PackWMicrokernelTester() .n(GetParam().nr * GetParam().nr_scale) .k(k) @@ -93,9 +92,9 @@ TEST_P(XnnTest, k_gt_kblock) { } } -TEST_P(XnnTest, n_eq_nr) { +TEST_P(XnnTest, k_gt_kblock) { TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t k = 1; k < (GetParam().kblock == 1 ? 4 : GetParam().kblock * 2); k++) { + for (size_t k = GetParam().kblock + 1; k < GetParam().kblock * 5; k = xnnpack::NextPrime(k + 1)) { PackWMicrokernelTester() .n(GetParam().nr * GetParam().nr_scale) .k(k) @@ -106,15 +105,26 @@ TEST_P(XnnTest, n_eq_nr) { } } -TEST_P(XnnTest, n_div_nr) { +TEST_P(XnnTest, n_eq_1) { if (GetParam().nr <= 1 || GetParam().nr_scale != 1) { GTEST_SKIP(); } TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t k = 1; k < (GetParam().kblock == 1 ? 4 : GetParam().kblock * 2); k++) { + PackWMicrokernelTester() + .n(1 * GetParam().nr_scale) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .Test(GetParam().ukernel); +} + +TEST_P(XnnTest, n_div_nr) { + TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); + for (size_t n = GetParam().nr; n < GetParam().nr * 5; n += GetParam().nr) { PackWMicrokernelTester() - .n(GetParam().nr * 2 * GetParam().nr_scale) - .k(k) + .n(n * GetParam().nr_scale) + .k(GetParam().kblock) .nr(GetParam().nr * GetParam().nr_scale) .kr(GetParam().kr) .sr(GetParam().sr) @@ -127,118 +137,30 @@ TEST_P(XnnTest, n_lt_nr) { GTEST_SKIP(); } TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t k = 1; k < (GetParam().kblock == 1 ? 4 : GetParam().kblock * 2); k++) { - for (size_t n = 1; n < GetParam().nr * GetParam().nr_scale; n++) { - PackWMicrokernelTester() - .n(n) - .k(k) - .nr(GetParam().nr * GetParam().nr_scale) - .kr(GetParam().kr) - .sr(GetParam().sr) - .Test(GetParam().ukernel); - } + for (size_t n = 1; n < GetParam().nr * GetParam().nr_scale; n++) { + PackWMicrokernelTester() + .n(n) + .k(GetParam().kblock) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .Test(GetParam().ukernel); } } TEST_P(XnnTest, n_gt_nr) { TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t k = 1; k < (GetParam().kblock == 1 ? 4 : GetParam().kblock * 2); k++) { - if (GetParam().nr_scale == 1) { - for (size_t n = GetParam().nr + 1; n < (GetParam().nr == 1 ? 4 : GetParam().nr * 2); n++) { - PackWMicrokernelTester() - .n(n) - .k(k) - .nr(GetParam().nr) - .kr(GetParam().kr) - .sr(GetParam().sr) - .Test(GetParam().ukernel); - } - } else { - for (size_t n = (GetParam().nr + 1) * GetParam().nr_scale; - n < (GetParam().nr == 1 ? 4 : GetParam().nr * 2) * GetParam().nr_scale; - n += 1 * GetParam().nr_scale) { - PackWMicrokernelTester() - .n(n) - .k(k) - .nr(GetParam().nr * GetParam().nr_scale) - .kr(GetParam().kr) - .sr(GetParam().sr) - .Test(GetParam().ukernel); - } - } - } -} - -TEST_P(XnnTest, g_gt_1) { - TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t g = 2; g <= 3; g++) { - for (size_t k = 1; k < (GetParam().kblock == 1 ? 4 : GetParam().kblock * 2); k++) { - if (GetParam().nr_scale == 1) { - for (size_t n = GetParam().nr + 1; n < (GetParam().nr == 1 ? 
4 : GetParam().nr * 2); n++) { - PackWMicrokernelTester() - .g(g) - .n(n) - .k(k) - .nr(GetParam().nr) - .kr(GetParam().kr) - .sr(GetParam().sr) - .Test(GetParam().ukernel); - } - } else { - for (size_t n = (GetParam().nr + 1) * GetParam().nr_scale; - n < (GetParam().nr == 1 ? 4 : GetParam().nr * 2) * GetParam().nr_scale; - n += 1 * GetParam().nr_scale) { - PackWMicrokernelTester() - .g(g) - .n(n) - .k(k) - .nr(GetParam().nr * GetParam().nr_scale) - .kr(GetParam().kr) - .sr(GetParam().sr) - .Test(GetParam().ukernel); - } - } - } - } -} - -TEST_P(XnnTest, null_bias) { - TEST_REQUIRES_ARCH_FLAGS(GetParam().arch_flags); - for (size_t g = 2; g <= 3; g++) { - for (size_t k = 1; k < (GetParam().kblock == 1 ? 4 : GetParam().kblock * 2); k++) { - if (GetParam().nr_scale == 1) { - for (size_t n = GetParam().nr + 1; n < (GetParam().nr == 1 ? 4 : GetParam().nr * 2); n++) { - PackWMicrokernelTester() - .nullbias(true) - .g(g) - .n(n) - .k(k) - .nr(GetParam().nr) - .kr(GetParam().kr) - .sr(GetParam().sr) - .Test(GetParam().ukernel); - } - } else { - for (size_t n = (GetParam().nr + 1) * GetParam().nr_scale; - n < (GetParam().nr == 1 ? 4 : GetParam().nr * 2) * GetParam().nr_scale; - n += 1 * GetParam().nr_scale) { - PackWMicrokernelTester() - .nullbias(true) - .g(g) - .n(n) - .k(k) - .nr(GetParam().nr * GetParam().nr_scale) - .kr(GetParam().kr) - .sr(GetParam().sr) - .Test(GetParam().ukernel); - } - } - } + for (size_t n = GetParam().nr * GetParam().nr_scale; n < GetParam().nr * GetParam().nr_scale * 5; n = xnnpack::NextPrime(n + 1)) { + PackWMicrokernelTester() + .n(n) + .k(xnnpack::NextPrime(GetParam().kblock + 1)) + .nr(GetParam().nr * GetParam().nr_scale) + .kr(GetParam().kr) + .sr(GetParam().sr) + .Test(GetParam().ukernel); } } - INSTANTIATE_TEST_SUITE_P(x8_packw, XnnTest, testing::ValuesIn(xnn_test_params), GetTestName); -
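Both qs8-packw.cc and x8-packw.cc above move from fixed k/n windows to sweeps: k_div_kblock now walks exact multiples of kblock, while k_gt_kblock and n_gt_nr step through xnnpack::NextPrime so awkward remainder sizes are covered without testing every value in the range. A self-contained sketch of the prime-stepped k sweep, using a stand-in next_prime helper that mirrors the assumed behaviour of xnnpack::NextPrime (smallest prime not below its argument):

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

// Stand-in helper for illustration only: smallest prime >= n (n >= 2 assumed).
static size_t next_prime(size_t n) {
  for (;; n++) {
    bool prime = true;
    for (size_t d = 2; d * d <= n; d++) {
      if (n % d == 0) { prime = false; break; }
    }
    if (prime) return n;
  }
}

int main(void) {
  const size_t kblock = 8;  // example value; the tests take it from GetParam().kblock
  // Mirrors the k_gt_kblock loop: start at kblock + 1, then hop from prime to
  // prime until 5 * kblock. For kblock = 8 this visits 9, 11, 13, 17, 19, 23, 29, 31, 37.
  for (size_t k = kblock + 1; k < kblock * 5; k = next_prime(k + 1)) {
    printf("%zu ", k);
  }
  printf("\n");
  return 0;
}
```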