From 2717a0754cd5f3adecda091e4ef6b022e6051473 Mon Sep 17 00:00:00 2001
From: Molly Sophia
Date: Sun, 25 Aug 2024 20:49:20 +0800
Subject: [PATCH 1/4] Add amx option in CMakeLists; add amx detection and
 register header file

Signed-off-by: Molly Sophia
---
 CMakeLists.txt                |  4 +++
 src/cpu.cpp                   | 15 +++++++++++
 src/cpu.h                     |  2 ++
 src/layer/arm/amx_usability.h | 50 +++++++++++++++++++++++++++++++++++
 src/platform.h.in             |  1 +
 5 files changed, 72 insertions(+)
 create mode 100644 src/layer/arm/amx_usability.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0f32a80c86ee..1db6803f937e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -150,6 +150,10 @@ if(NOT (CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "
     endif()
 endif()
 
+if(MACOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
+    option(NCNN_APPLE_AMX "optimize apple silicon platforms with apple amx" ON)
+endif()
+
 if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
     OR (APPLE AND CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
     OR (CMAKE_SYSTEM_PROCESSOR MATCHES "^(arm|aarch64)")
diff --git a/src/cpu.cpp b/src/cpu.cpp
index f9e64a1cc75b..180c5dc02b9d 100644
--- a/src/cpu.cpp
+++ b/src/cpu.cpp
@@ -2378,6 +2378,21 @@ int cpu_support_arm_svef32mm()
 #endif
 }
 
+int cpu_support_arm_amx()
+{
+    try_initialize_global_cpu_info();
+#if __aarch64__ && __APPLE__
+    return g_hw_cpufamily == CPUFAMILY_ARM_FIRESTORM_ICESTORM
+           || g_hw_cpufamily == CPUFAMILY_ARM_AVALANCHE_BLIZZARD
+           || g_hw_cpufamily == CPUFAMILY_ARM_IBIZA
+           || g_hw_cpufamily == CPUFAMILY_ARM_LOBOS
+           || g_hw_cpufamily == CPUFAMILY_ARM_PALMA;
+
+#else
+    return 0;
+#endif
+}
+
 int cpu_support_x86_avx()
 {
     try_initialize_global_cpu_info();
diff --git a/src/cpu.h b/src/cpu.h
index 2ae6b8c3ffe9..3677715cfea1 100644
--- a/src/cpu.h
+++ b/src/cpu.h
@@ -80,6 +80,8 @@ NCNN_EXPORT int cpu_support_arm_svebf16();
 NCNN_EXPORT int cpu_support_arm_svei8mm();
 // svef32mm = aarch64 svef32mm
 NCNN_EXPORT int cpu_support_arm_svef32mm();
+// amx = aarch64 apple amx
+NCNN_EXPORT int cpu_support_arm_amx();
 
 // avx = x86 avx
 NCNN_EXPORT int cpu_support_x86_avx();
diff --git a/src/layer/arm/amx_usability.h b/src/layer/arm/amx_usability.h
new file mode 100644
index 000000000000..6b4fc93adf01
--- /dev/null
+++ b/src/layer/arm/amx_usability.h
@@ -0,0 +1,50 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
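+//
+// Apple AMX background, per the corsix/amx notes referenced below: AMX is an
+// undocumented matrix coprocessor on Apple silicon. Its register state is
+// three pools of 64-byte rows: X (8 rows), Y (8 rows) and Z (64 rows).
+// The instructions are not part of the public ISA, so they are emitted as
+// raw .word encodings, with each operand (a pointer and/or flag bits)
+// passed through a general-purpose register.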
+ +#ifndef AMX_USABILITY_H +#define AMX_USABILITY_H + +// From https://github.com/corsix/amx/blob/main/aarch64.h +#define AMX_NOP_OP_IMM5(op, imm5) \ + __asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" : : "i"(op), "i"(imm5) : "memory") + +#define AMX_OP_GPR(op, gpr) \ + __asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" : : "i"(op), "r"((uint64_t)(gpr)) : "memory") + +#define AMX_LDX(gpr) AMX_OP_GPR( 0, gpr) +#define AMX_LDY(gpr) AMX_OP_GPR( 1, gpr) +#define AMX_STX(gpr) AMX_OP_GPR( 2, gpr) +#define AMX_STY(gpr) AMX_OP_GPR( 3, gpr) +#define AMX_LDZ(gpr) AMX_OP_GPR( 4, gpr) +#define AMX_STZ(gpr) AMX_OP_GPR( 5, gpr) +#define AMX_LDZI(gpr) AMX_OP_GPR( 6, gpr) +#define AMX_STZI(gpr) AMX_OP_GPR( 7, gpr) +#define AMX_EXTRX(gpr) AMX_OP_GPR( 8, gpr) +#define AMX_EXTRY(gpr) AMX_OP_GPR( 9, gpr) +#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr) +#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr) +#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr) +#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr) +#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr) +#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr) +#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr) +#define AMX_SET() AMX_NOP_OP_IMM5(17, 0) +#define AMX_CLR() AMX_NOP_OP_IMM5(17, 1) +#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr) +#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr) +#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr) +#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr) +#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr) + +#endif // AMX_USABILITY_H \ No newline at end of file diff --git a/src/platform.h.in b/src/platform.h.in index 50a9454b7da0..b45d2dbd31ea 100644 --- a/src/platform.h.in +++ b/src/platform.h.in @@ -45,6 +45,7 @@ #cmakedefine01 NCNN_AVX512BF16 #cmakedefine01 NCNN_AVX512FP16 #cmakedefine01 NCNN_VFPV4 +#cmakedefine01 NCNN_APPLE_AMX #cmakedefine01 NCNN_ARM82 #cmakedefine01 NCNN_ARM82DOT #cmakedefine01 NCNN_ARM82FP16FML From dfe89d7dc68a9b412bb13c4b8d9591881640cbe5 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Wed, 28 Aug 2024 10:50:18 +0800 Subject: [PATCH 2/4] amx innerproduct simple attempt Signed-off-by: Molly Sophia --- src/layer/arm/amx_usability.h | 72 +++++++++++++++++++++++++++++- src/layer/arm/innerproduct_arm.cpp | 72 ++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 2 deletions(-) diff --git a/src/layer/arm/amx_usability.h b/src/layer/arm/amx_usability.h index 6b4fc93adf01..6955081c495c 100644 --- a/src/layer/arm/amx_usability.h +++ b/src/layer/arm/amx_usability.h @@ -39,12 +39,80 @@ #define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr) #define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr) #define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr) -#define AMX_SET() AMX_NOP_OP_IMM5(17, 0) -#define AMX_CLR() AMX_NOP_OP_IMM5(17, 1) #define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr) #define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr) #define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr) #define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr) #define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr) +#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t)&*(ptr)) + (((uint64_t)((row) + (flags) * 64)) << 56)) +void amx_set() +{ + AMX_NOP_OP_IMM5(17, 0); +} + +void amx_clr() +{ + AMX_NOP_OP_IMM5(17, 1); +} + +void amx_ldx(bool pair, unsigned int x_row, const void * ptr) +{ + if (x_row >= 8) + return; + + uint64_t oprand = (uint64_t)ptr + ((uint64_t)x_row << 56); + if (pair) + oprand |= 1ULL << 62; + + AMX_LDX(oprand); +} + +void amx_ldy(bool pair, unsigned int y_row, const void * ptr) +{ + if (y_row >= 8) + return; + + uint64_t oprand = (uint64_t)ptr + ((uint64_t)y_row << 56); + if (pair) + oprand |= 1ULL << 62; + + AMX_LDY(oprand); +} + +void amx_ldz(bool 
pair, unsigned int z_row, const void * ptr) +{ + if (z_row >= 64) + return; + + uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56); + if (pair) + oprand |= 1ULL << 62; + + AMX_LDZ(oprand); +} + +void amx_stz(bool pair, unsigned int z_row, const void * ptr) +{ + if (z_row >= 64) + return; + + uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56); + if (pair) + oprand |= 1ULL << 62; + + AMX_STZ(oprand); +} + +void amx_fma32(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row) +{ + uint64_t oprand = 0; + if (vector) + oprand |= 1ULL << 63; + + oprand |= (uint64_t)y_offset & 0x1FF; + oprand |= ((uint64_t)x_offset & 0x1FF) << 10; + + AMX_FMA32(oprand); +} #endif // AMX_USABILITY_H \ No newline at end of file diff --git a/src/layer/arm/innerproduct_arm.cpp b/src/layer/arm/innerproduct_arm.cpp index 99e93d6fbe09..054cafefd91d 100644 --- a/src/layer/arm/innerproduct_arm.cpp +++ b/src/layer/arm/innerproduct_arm.cpp @@ -23,6 +23,10 @@ #include "arm_activation.h" #include "arm_usability.h" +#if __aarch64__ && NCNN_APPLE_AMX +#include "amx_usability.h" +#endif + #include "cpu.h" namespace ncnn { @@ -142,6 +146,12 @@ int InnerProduct_arm::destroy_pipeline(const Option& opt) int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +// #if __aarch64__ && NCNN_APPLE_AMX +// AMX_SET(); +// uint8_t arr[256]; +// AMX_LDX(arr); +// AMX_CLR(); +// #endif #if NCNN_INT8 if (opt.use_int8_inference && int8_scale_term) { @@ -369,6 +379,68 @@ int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Optio } int i = 0; +#if __aarch64__ && NCNN_APPLE_AMX + amx_set(); + + float _sums[32] = {0}; + amx_ldz(true, 0, _sums); + for (; i + 31 < num_input; i += 32) + { + amx_ldx(true, 0, m); + amx_ldy(true, 0, kptr); + amx_fma32(true, 0, 0, 0); + amx_fma32(true, 64, 64, 1); + + m += 32; + kptr += 32; + } + amx_stz(true, 0, _sums); +#if __ARM_NEON + { + float32x4_t _sum = vdupq_n_f32(0.f); + for (int q = 0; q < 8; q++) + { + _sum = vaddq_f32(_sum, vld1q_f32(_sums + q * 4)); + } + sum += vaddvq_f32(_sum); + memset(_sums, 0, 16 * sizeof(float)); + } +#else + for (int q = 0; q < 32; q++) + { + sum += _sums[q]; + _sums[q] = 0; + } +#endif + amx_ldz(false, 0, _sums); + for (; i + 15 < num_input; i += 16) + { + amx_ldx(false, 0, m); + amx_ldy(false, 0, kptr); + amx_fma32(true, 0, 0, 0); + + m += 16; + kptr += 16; + } + amx_stz(false, 0, _sums); +#if __ARM_NEON + { + float32x4_t _sum = vdupq_n_f32(0.f); + for (int q = 0; q < 4; q++) + { + _sum = vaddq_f32(_sum, vld1q_f32(_sums + q * 4)); + } + sum += vaddvq_f32(_sum); + } +#else + for (int q = 0; q < 16; q++) + { + sum += _sums[q]; + } +#endif + amx_clr(); +#endif + #if __ARM_NEON float32x4_t _sum = vdupq_n_f32(0.f); for (; i + 3 < num_input; i += 4) From f445e7156e307a4183416f48316b14874cb11945 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Wed, 28 Aug 2024 10:50:45 +0800 Subject: [PATCH 3/4] [WIP] amx conv im2col fp16s Signed-off-by: Molly Sophia --- src/layer/arm/amx_usability.h | 34 +- src/layer/arm/convolution_im2col_gemm_fp16s.h | 1706 ++++++++++++++++- 2 files changed, 1738 insertions(+), 2 deletions(-) diff --git a/src/layer/arm/amx_usability.h b/src/layer/arm/amx_usability.h index 6955081c495c..d2bdc267530f 100644 --- a/src/layer/arm/amx_usability.h +++ b/src/layer/arm/amx_usability.h @@ -103,7 +103,29 @@ void amx_stz(bool pair, unsigned int z_row, const void * ptr) AMX_STZ(oprand); } -void amx_fma32(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row) +void 
amx_fma16_masked(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row, uint8_t x_mode, uint8_t x_mask, uint8_t y_mode, uint8_t y_mask)
+{
+    uint64_t oprand = 0;
+    if (vector)
+        oprand |= 1ULL << 63;
+
+    oprand |= (uint64_t)y_offset & 0x1FF;
+    oprand |= ((uint64_t)x_offset & 0x1FF) << 10;
+    oprand |= ((uint64_t)z_row & 0x3F) << 20;
+    oprand |= ((uint64_t)y_mask & 0x1F) << 32;
+    oprand |= ((uint64_t)y_mode & 0x3) << 37;
+    oprand |= ((uint64_t)x_mask & 0x1F) << 41;
+    oprand |= ((uint64_t)x_mode & 0x3) << 46;
+
+    AMX_FMA16(oprand);
+}
+
+void amx_fma16(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row)
+{
+    amx_fma16_masked(vector, x_offset, y_offset, z_row, 0, 0, 0, 0);
+}
+
+void amx_fma32_masked(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row, uint8_t x_mode, uint8_t x_mask, uint8_t y_mode, uint8_t y_mask)
 {
     uint64_t oprand = 0;
     if (vector)
@@ -111,8 +133,18 @@ void amx_fma32(bool vector, unsigned int x_offset, unsigned int y_offset, int z_
     oprand |= (uint64_t)y_offset & 0x1FF;
     oprand |= ((uint64_t)x_offset & 0x1FF) << 10;
+    oprand |= ((uint64_t)z_row & 0x3F) << 20;
+    oprand |= ((uint64_t)y_mask & 0x1F) << 32;
+    oprand |= ((uint64_t)y_mode & 0x3) << 37;
+    oprand |= ((uint64_t)x_mask & 0x1F) << 41;
+    oprand |= ((uint64_t)x_mode & 0x3) << 46;
 
     AMX_FMA32(oprand);
 }
 
+void amx_fma32(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row)
+{
+    amx_fma32_masked(vector, x_offset, y_offset, z_row, 0, 0, 0, 0);
+}
+
 #endif // AMX_USABILITY_H
\ No newline at end of file
diff --git a/src/layer/arm/convolution_im2col_gemm_fp16s.h b/src/layer/arm/convolution_im2col_gemm_fp16s.h
index 8e499b3e73a8..7ab071db9016 100644
--- a/src/layer/arm/convolution_im2col_gemm_fp16s.h
+++ b/src/layer/arm/convolution_im2col_gemm_fp16s.h
@@ -3017,6 +3017,1696 @@ static void convolution_gemm_transB_packed_tile_fp16sa(const Mat& AT_tile, const
     }
 }
 
+#if __aarch64__ && NCNN_APPLE_AMX
+#include "amx_usability.h"
+
+static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end)
+{
+    // NCNN_LOGE("convolution_gemm_transB_packed_tile_fp16sa_amx %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk);
+    amx_set();
+    const int out_elempack = top_blob.elempack;
+    const int out_hstep = (int)top_blob.cstep;
+
+    const __fp16* pAT = AT_tile;
+    const __fp16* pBT = BT_tile;
+    const __fp16* pC = CT_tile;
+
+    __fp16* outptr = topT_tile;
+
+    int ii = 0;
+    for (; ii + 7 < max_ii; ii += 8)
+    {
+        __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j * out_elempack;
+
+        const __fp16* pB = pBT;
+
+        if (pC)
+        {
+            pC = (const __fp16*)CT_tile + i + ii;
+        }
+
+        int jj = 0;
+        // NCNN_LOGE("max_jj %d", max_jj);
+        for (; jj + 11 < max_jj; jj += 12)
+        {
+            const __fp16* pA = pAT;
+
+            if (k == 0)
+            {
+                if (pC)
+                {
+                    for (int r = 0; r < 12; r++)
+                        amx_ldz(false, 2*r, pC);
+                }
+                else
+                {
+                    __fp16 sums[32];
+                    memset(sums, 0, 32 * sizeof(__fp16));
+                    for (int r = 0; r < 12; r++)
+                        amx_ldz(false, 2*r, sums);
+                }
+            }
+            else
+            {
+                for (int r = 0; r < 12; r++)
+                    amx_ldz(false, 2*r, outptr + 8 * r);
+            }
+
+            int kk = 0;
+            for (; kk < max_kk; kk += 1)
+            {
+                // matrix-mode fp16 fma: accumulate the outer product of the
+                // first 8 lanes of pA (X) and the first 12 lanes of pB (Y),
+                // with output row y landing in Z row 2*y
+                amx_ldx(false, 0, pA);
+                amx_ldy(false, 0, pB);
+                amx_fma16_masked(false, 0, 0, 0, 0x2, 8, 0x2, 12);
+
+                pA += 8;
+                pB += 12;
+            }
+
+            if (k_end)
+            {
+                if (out_elempack == 8)
+                {
+                    __fp16 tmp[96 + 24];
+                    for (int r = 0; r < 12; r++) {
+                        amx_stz(false, 2*r, tmp + r * 8);
+                    }
+                    memcpy(outptr0, tmp, 96 * 
sizeof(__fp16)); + outptr0 += 96; + } + if (out_elempack == 4) + { + __fp16 tmp[32]; + for (int r = 0; r < 12; r++) { + amx_stz(false, 2*r, tmp); + float16x8_t _tmp = vld1q_f16(tmp); + vst1_f16(outptr0 + 4 * r, vget_low_f16(_tmp)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * r, vget_high_f16(_tmp)); + } + + outptr0 += 48; + } + if (out_elempack == 1) + { + float16x8_t _sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb; + __fp16 tmp[32]; + amx_stz(false, 0, tmp); + _sum0 = vld1q_f16(tmp); + amx_stz(false, 2, tmp); + _sum1 = vld1q_f16(tmp); + amx_stz(false, 4, tmp); + _sum2 = vld1q_f16(tmp); + amx_stz(false, 6, tmp); + _sum3 = vld1q_f16(tmp); + amx_stz(false, 8, tmp); + _sum4 = vld1q_f16(tmp); + amx_stz(false, 10, tmp); + _sum5 = vld1q_f16(tmp); + amx_stz(false, 12, tmp); + _sum6 = vld1q_f16(tmp); + amx_stz(false, 14, tmp); + _sum7 = vld1q_f16(tmp); + amx_stz(false, 16, tmp); + _sum8 = vld1q_f16(tmp); + amx_stz(false, 18, tmp); + _sum9 = vld1q_f16(tmp); + amx_stz(false, 20, tmp); + _suma = vld1q_f16(tmp); + amx_stz(false, 22, tmp); + _sumb = vld1q_f16(tmp); + transpose8x12_ph(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb); + + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + 4, vget_high_f16(_sum0)); + vst1_f16(outptr0 + 8, vget_low_f16(_sum1)); + vst1_f16(outptr0 + out_hstep, vget_high_f16(_sum1)); + vst1_f16(outptr0 + out_hstep + 4, vget_low_f16(_sum2)); + vst1_f16(outptr0 + out_hstep + 8, vget_high_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 2, vget_low_f16(_sum3)); + vst1_f16(outptr0 + out_hstep * 2 + 4, vget_high_f16(_sum3)); + vst1_f16(outptr0 + out_hstep * 2 + 8, vget_low_f16(_sum4)); + vst1_f16(outptr0 + out_hstep * 3, vget_high_f16(_sum4)); + vst1_f16(outptr0 + out_hstep * 3 + 4, vget_low_f16(_sum5)); + vst1_f16(outptr0 + out_hstep * 3 + 8, vget_high_f16(_sum5)); + vst1_f16(outptr0 + out_hstep * 4, vget_low_f16(_sum6)); + vst1_f16(outptr0 + out_hstep * 4 + 4, vget_high_f16(_sum6)); + vst1_f16(outptr0 + out_hstep * 4 + 8, vget_low_f16(_sum7)); + vst1_f16(outptr0 + out_hstep * 5, vget_high_f16(_sum7)); + vst1_f16(outptr0 + out_hstep * 5 + 4, vget_low_f16(_sum8)); + vst1_f16(outptr0 + out_hstep * 5 + 8, vget_high_f16(_sum8)); + vst1_f16(outptr0 + out_hstep * 6, vget_low_f16(_sum9)); + vst1_f16(outptr0 + out_hstep * 6 + 4, vget_high_f16(_sum9)); + vst1_f16(outptr0 + out_hstep * 6 + 8, vget_low_f16(_suma)); + vst1_f16(outptr0 + out_hstep * 7, vget_high_f16(_suma)); + vst1_f16(outptr0 + out_hstep * 7 + 4, vget_low_f16(_sumb)); + vst1_f16(outptr0 + out_hstep * 7 + 8, vget_high_f16(_sumb)); + + outptr0 += 12; + } + } + else + { + __fp16 tmp[32]; + for (int r = 0; r < 12; r++) { + amx_stz(false, 2*r, tmp); + memcpy(outptr0 + 8 * r, tmp, 8 * sizeof(__fp16)); + } + } + + outptr += 96; + } + for (; jj + 7 < max_jj; jj += 8) + { + const __fp16* pA = pAT; + + float16x8_t _sum0; + float16x8_t _sum1; + float16x8_t _sum2; + float16x8_t _sum3; + float16x8_t _sum4; + float16x8_t _sum5; + float16x8_t _sum6; + float16x8_t _sum7; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1q_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + else + { + _sum0 = vdupq_n_f16(0.f); + _sum1 = vdupq_n_f16(0.f); + _sum2 = vdupq_n_f16(0.f); + _sum3 = vdupq_n_f16(0.f); + _sum4 = vdupq_n_f16(0.f); + _sum5 = vdupq_n_f16(0.f); + _sum6 = vdupq_n_f16(0.f); + _sum7 = vdupq_n_f16(0.f); + } + } + else + { + _sum0 = vld1q_f16(outptr); + _sum1 = vld1q_f16(outptr + 8); + _sum2 
= vld1q_f16(outptr + 8 * 2); + _sum3 = vld1q_f16(outptr + 8 * 3); + _sum4 = vld1q_f16(outptr + 8 * 4); + _sum5 = vld1q_f16(outptr + 8 * 5); + _sum6 = vld1q_f16(outptr + 8 * 6); + _sum7 = vld1q_f16(outptr + 8 * 7); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x8_t _pA = vld1q_f16(pA); + + float16x8_t _pB = vld1q_f16(pB); + + _sum0 = vfmaq_laneq_f16(_sum0, _pA, _pB, 0); + _sum1 = vfmaq_laneq_f16(_sum1, _pA, _pB, 1); + _sum2 = vfmaq_laneq_f16(_sum2, _pA, _pB, 2); + _sum3 = vfmaq_laneq_f16(_sum3, _pA, _pB, 3); + _sum4 = vfmaq_laneq_f16(_sum4, _pA, _pB, 4); + _sum5 = vfmaq_laneq_f16(_sum5, _pA, _pB, 5); + _sum6 = vfmaq_laneq_f16(_sum6, _pA, _pB, 6); + _sum7 = vfmaq_laneq_f16(_sum7, _pA, _pB, 7); + + pA += 8; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 8) + { + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr0 + 8, _sum1); + vst1q_f16(outptr0 + 8 * 2, _sum2); + vst1q_f16(outptr0 + 8 * 3, _sum3); + vst1q_f16(outptr0 + 8 * 4, _sum4); + vst1q_f16(outptr0 + 8 * 5, _sum5); + vst1q_f16(outptr0 + 8 * 6, _sum6); + vst1q_f16(outptr0 + 8 * 7, _sum7); + outptr0 += 64; + } + if (out_elempack == 4) + { + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + 4, vget_low_f16(_sum1)); + vst1_f16(outptr0 + 4 * 2, vget_low_f16(_sum2)); + vst1_f16(outptr0 + 4 * 3, vget_low_f16(_sum3)); + vst1_f16(outptr0 + 4 * 4, vget_low_f16(_sum4)); + vst1_f16(outptr0 + 4 * 5, vget_low_f16(_sum5)); + vst1_f16(outptr0 + 4 * 6, vget_low_f16(_sum6)); + vst1_f16(outptr0 + 4 * 7, vget_low_f16(_sum7)); + + vst1_f16(outptr0 + out_hstep * 4, vget_high_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 4 + 4, vget_high_f16(_sum1)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 2, vget_high_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 3, vget_high_f16(_sum3)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 4, vget_high_f16(_sum4)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 5, vget_high_f16(_sum5)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 6, vget_high_f16(_sum6)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 7, vget_high_f16(_sum7)); + + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose8x8_ph(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr0 + out_hstep, _sum1); + vst1q_f16(outptr0 + out_hstep * 2, _sum2); + vst1q_f16(outptr0 + out_hstep * 3, _sum3); + vst1q_f16(outptr0 + out_hstep * 4, _sum4); + vst1q_f16(outptr0 + out_hstep * 5, _sum5); + vst1q_f16(outptr0 + out_hstep * 6, _sum6); + vst1q_f16(outptr0 + out_hstep * 7, _sum7); + + outptr0 += 8; + } + } + else + { + vst1q_f16(outptr, _sum0); + vst1q_f16(outptr + 8, _sum1); + vst1q_f16(outptr + 8 * 2, _sum2); + vst1q_f16(outptr + 8 * 3, _sum3); + vst1q_f16(outptr + 8 * 4, _sum4); + vst1q_f16(outptr + 8 * 5, _sum5); + vst1q_f16(outptr + 8 * 6, _sum6); + vst1q_f16(outptr + 8 * 7, _sum7); + } + + outptr += 64; + } + for (; jj + 3 < max_jj; jj += 4) + { + const __fp16* pA = pAT; + + float16x8_t _sum0; + float16x8_t _sum1; + float16x8_t _sum2; + float16x8_t _sum3; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1q_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + else + { + _sum0 = vdupq_n_f16(0.f); + _sum1 = vdupq_n_f16(0.f); + _sum2 = vdupq_n_f16(0.f); + _sum3 = vdupq_n_f16(0.f); + } + } + else + { + _sum0 = vld1q_f16(outptr); + _sum1 = vld1q_f16(outptr + 8); + _sum2 = vld1q_f16(outptr + 8 * 2); + _sum3 = vld1q_f16(outptr + 8 * 3); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x8_t _pA = vld1q_f16(pA); + + float16x4_t _pB = vld1_f16(pB); + + _sum0 = vfmaq_lane_f16(_sum0, 
_pA, _pB, 0); + _sum1 = vfmaq_lane_f16(_sum1, _pA, _pB, 1); + _sum2 = vfmaq_lane_f16(_sum2, _pA, _pB, 2); + _sum3 = vfmaq_lane_f16(_sum3, _pA, _pB, 3); + + pA += 8; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 8) + { + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr0 + 8, _sum1); + vst1q_f16(outptr0 + 8 * 2, _sum2); + vst1q_f16(outptr0 + 8 * 3, _sum3); + outptr0 += 32; + } + if (out_elempack == 4) + { + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + 4, vget_low_f16(_sum1)); + vst1_f16(outptr0 + 4 * 2, vget_low_f16(_sum2)); + vst1_f16(outptr0 + 4 * 3, vget_low_f16(_sum3)); + + vst1_f16(outptr0 + out_hstep * 4, vget_high_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 4 + 4, vget_high_f16(_sum1)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 2, vget_high_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 3, vget_high_f16(_sum3)); + + outptr0 += 16; + } + if (out_elempack == 1) + { + transpose8x4_ph(_sum0, _sum1, _sum2, _sum3); + + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 1, vget_high_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 2, vget_low_f16(_sum1)); + vst1_f16(outptr0 + out_hstep * 3, vget_high_f16(_sum1)); + vst1_f16(outptr0 + out_hstep * 4, vget_low_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 5, vget_high_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 6, vget_low_f16(_sum3)); + vst1_f16(outptr0 + out_hstep * 7, vget_high_f16(_sum3)); + + outptr0 += 4; + } + } + else + { + vst1q_f16(outptr, _sum0); + vst1q_f16(outptr + 8, _sum1); + vst1q_f16(outptr + 8 * 2, _sum2); + vst1q_f16(outptr + 8 * 3, _sum3); + } + + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + const __fp16* pA = pAT; + + float16x8_t _sum0; + float16x8_t _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1q_f16(pC); + _sum1 = _sum0; + } + else + { + _sum0 = vdupq_n_f16(0.f); + _sum1 = vdupq_n_f16(0.f); + } + } + else + { + _sum0 = vld1q_f16(outptr); + _sum1 = vld1q_f16(outptr + 8); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x8_t _pA = vld1q_f16(pA); + + float16x4_t _pB = vld1_f16(pB); + + _sum0 = vfmaq_lane_f16(_sum0, _pA, _pB, 0); + _sum1 = vfmaq_lane_f16(_sum1, _pA, _pB, 1); + + pA += 8; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 8) + { + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr0 + 8, _sum1); + outptr0 += 16; + } + if (out_elempack == 4) + { + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + 4, vget_low_f16(_sum1)); + + vst1_f16(outptr0 + out_hstep * 4, vget_high_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 4 + 4, vget_high_f16(_sum1)); + outptr0 += 8; + } + if (out_elempack == 1) + { + __fp16 sum0[8]; + __fp16 sum1[8]; + vst1q_f16(sum0, _sum0); + vst1q_f16(sum1, _sum1); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0[out_hstep * 4 + 1] = sum1[4]; + outptr0[out_hstep * 5 + 1] = sum1[5]; + outptr0[out_hstep * 6 + 1] = sum1[6]; + outptr0[out_hstep * 7 + 1] = sum1[7]; + outptr0 += 2; + } + } + else + { + vst1q_f16(outptr, _sum0); + vst1q_f16(outptr + 8, _sum1); + } + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + const __fp16* pA = pAT; + + float16x8_t _sum0; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1q_f16(pC); + } + else + 
{ + _sum0 = vdupq_n_f16(0.f); + } + } + else + { + _sum0 = vld1q_f16(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x8_t _pA = vld1q_f16(pA); + + float16x8_t _pB = vld1q_dup_f16(pB); + + _sum0 = vfmaq_f16(_sum0, _pA, _pB); + + pA += 8; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 8) + { + vst1q_f16(outptr0, _sum0); + outptr0 += 8; + } + if (out_elempack == 4) + { + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 4, vget_high_f16(_sum0)); + outptr0 += 4; + } + if (out_elempack == 1) + { + __fp16 sum0[8]; + vst1q_f16(sum0, _sum0); + + outptr0[0] = sum0[0]; + outptr0[out_hstep * 1] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + outptr0++; + } + } + else + { + vst1q_f16(outptr, _sum0); + } + + outptr += 8; + } + + pAT += max_kk * 8; + } + for (; ii + 3 < max_ii; ii += 4) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const __fp16* pB = pBT; + + if (pC) + { + pC = (const __fp16*)CT_tile + i + ii; + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + float16x4_t _sum0; + float16x4_t _sum1; + float16x4_t _sum2; + float16x4_t _sum3; + float16x4_t _sum4; + float16x4_t _sum5; + float16x4_t _sum6; + float16x4_t _sum7; + float16x4_t _sum8; + float16x4_t _sum9; + float16x4_t _suma; + float16x4_t _sumb; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + _sum2 = vdup_n_f16(0.f); + _sum3 = vdup_n_f16(0.f); + _sum4 = vdup_n_f16(0.f); + _sum5 = vdup_n_f16(0.f); + _sum6 = vdup_n_f16(0.f); + _sum7 = vdup_n_f16(0.f); + _sum8 = vdup_n_f16(0.f); + _sum9 = vdup_n_f16(0.f); + _suma = vdup_n_f16(0.f); + _sumb = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4 * 1); + _sum2 = vld1_f16(outptr + 4 * 2); + _sum3 = vld1_f16(outptr + 4 * 3); + _sum4 = vld1_f16(outptr + 4 * 4); + _sum5 = vld1_f16(outptr + 4 * 5); + _sum6 = vld1_f16(outptr + 4 * 6); + _sum7 = vld1_f16(outptr + 4 * 7); + _sum8 = vld1_f16(outptr + 4 * 8); + _sum9 = vld1_f16(outptr + 4 * 9); + _suma = vld1_f16(outptr + 4 * 10); + _sumb = vld1_f16(outptr + 4 * 11); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + float16x4_t _pB2 = vld1_f16(pB + 8); + + _sum0 = vfma_lane_f16(_sum0, _pA, _pB0, 0); + _sum1 = vfma_lane_f16(_sum1, _pA, _pB0, 1); + _sum2 = vfma_lane_f16(_sum2, _pA, _pB0, 2); + _sum3 = vfma_lane_f16(_sum3, _pA, _pB0, 3); + _sum4 = vfma_lane_f16(_sum4, _pA, _pB1, 0); + _sum5 = vfma_lane_f16(_sum5, _pA, _pB1, 1); + _sum6 = vfma_lane_f16(_sum6, _pA, _pB1, 2); + _sum7 = vfma_lane_f16(_sum7, _pA, _pB1, 3); + _sum8 = vfma_lane_f16(_sum8, _pA, _pB2, 0); + _sum9 = vfma_lane_f16(_sum9, _pA, _pB2, 1); + _suma = vfma_lane_f16(_suma, _pA, _pB2, 2); + _sumb = vfma_lane_f16(_sumb, _pA, _pB2, 3); + + pA += 4; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 4 * 2, _sum2); + vst1_f16(outptr0 + 4 * 3, _sum3); + vst1_f16(outptr0 + 4 * 4, 
_sum4); + vst1_f16(outptr0 + 4 * 5, _sum5); + vst1_f16(outptr0 + 4 * 6, _sum6); + vst1_f16(outptr0 + 4 * 7, _sum7); + vst1_f16(outptr0 + 4 * 8, _sum8); + vst1_f16(outptr0 + 4 * 9, _sum9); + vst1_f16(outptr0 + 4 * 10, _suma); + vst1_f16(outptr0 + 4 * 11, _sumb); + outptr0 += 48; + } + if (out_elempack == 1) + { + transpose4x12_ph(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb); + + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 8, _sum2); + vst1_f16(outptr0 + out_hstep, _sum3); + vst1_f16(outptr0 + out_hstep + 4, _sum4); + vst1_f16(outptr0 + out_hstep + 8, _sum5); + vst1_f16(outptr0 + out_hstep * 2, _sum6); + vst1_f16(outptr0 + out_hstep * 2 + 4, _sum7); + vst1_f16(outptr0 + out_hstep * 2 + 8, _sum8); + vst1_f16(outptr0 + out_hstep * 3, _sum9); + vst1_f16(outptr0 + out_hstep * 3 + 4, _suma); + vst1_f16(outptr0 + out_hstep * 3 + 8, _sumb); + outptr0 += 12; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + vst1_f16(outptr + 4 * 2, _sum2); + vst1_f16(outptr + 4 * 3, _sum3); + vst1_f16(outptr + 4 * 4, _sum4); + vst1_f16(outptr + 4 * 5, _sum5); + vst1_f16(outptr + 4 * 6, _sum6); + vst1_f16(outptr + 4 * 7, _sum7); + vst1_f16(outptr + 4 * 8, _sum8); + vst1_f16(outptr + 4 * 9, _sum9); + vst1_f16(outptr + 4 * 10, _suma); + vst1_f16(outptr + 4 * 11, _sumb); + } + + outptr += 48; + } + for (; jj + 7 < max_jj; jj += 8) + { + float16x4_t _sum0; + float16x4_t _sum1; + float16x4_t _sum2; + float16x4_t _sum3; + float16x4_t _sum4; + float16x4_t _sum5; + float16x4_t _sum6; + float16x4_t _sum7; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + _sum2 = vdup_n_f16(0.f); + _sum3 = vdup_n_f16(0.f); + _sum4 = vdup_n_f16(0.f); + _sum5 = vdup_n_f16(0.f); + _sum6 = vdup_n_f16(0.f); + _sum7 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4 * 1); + _sum2 = vld1_f16(outptr + 4 * 2); + _sum3 = vld1_f16(outptr + 4 * 3); + _sum4 = vld1_f16(outptr + 4 * 4); + _sum5 = vld1_f16(outptr + 4 * 5); + _sum6 = vld1_f16(outptr + 4 * 6); + _sum7 = vld1_f16(outptr + 4 * 7); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + + _sum0 = vfma_lane_f16(_sum0, _pA, _pB0, 0); + _sum1 = vfma_lane_f16(_sum1, _pA, _pB0, 1); + _sum2 = vfma_lane_f16(_sum2, _pA, _pB0, 2); + _sum3 = vfma_lane_f16(_sum3, _pA, _pB0, 3); + _sum4 = vfma_lane_f16(_sum4, _pA, _pB1, 0); + _sum5 = vfma_lane_f16(_sum5, _pA, _pB1, 1); + _sum6 = vfma_lane_f16(_sum6, _pA, _pB1, 2); + _sum7 = vfma_lane_f16(_sum7, _pA, _pB1, 3); + + pA += 4; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 4 * 2, _sum2); + vst1_f16(outptr0 + 4 * 3, _sum3); + vst1_f16(outptr0 + 4 * 4, _sum4); + vst1_f16(outptr0 + 4 * 5, _sum5); + vst1_f16(outptr0 + 4 * 6, _sum6); + vst1_f16(outptr0 + 4 * 7, _sum7); + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose4x8_ph(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + out_hstep, _sum2); + vst1_f16(outptr0 + out_hstep + 4, _sum3); + vst1_f16(outptr0 + out_hstep * 2, _sum4); + vst1_f16(outptr0 + 
out_hstep * 2 + 4, _sum5); + vst1_f16(outptr0 + out_hstep * 3, _sum6); + vst1_f16(outptr0 + out_hstep * 3 + 4, _sum7); + outptr0 += 8; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + vst1_f16(outptr + 4 * 2, _sum2); + vst1_f16(outptr + 4 * 3, _sum3); + vst1_f16(outptr + 4 * 4, _sum4); + vst1_f16(outptr + 4 * 5, _sum5); + vst1_f16(outptr + 4 * 6, _sum6); + vst1_f16(outptr + 4 * 7, _sum7); + } + + outptr += 32; + } + for (; jj + 3 < max_jj; jj += 4) + { + float16x4_t _sum0; + float16x4_t _sum1; + float16x4_t _sum2; + float16x4_t _sum3; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + _sum2 = vdup_n_f16(0.f); + _sum3 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4); + _sum2 = vld1_f16(outptr + 4 * 2); + _sum3 = vld1_f16(outptr + 4 * 3); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + float16x4_t _pB = vld1_f16(pB); + + _sum0 = vfma_lane_f16(_sum0, _pA, _pB, 0); + _sum1 = vfma_lane_f16(_sum1, _pA, _pB, 1); + _sum2 = vfma_lane_f16(_sum2, _pA, _pB, 2); + _sum3 = vfma_lane_f16(_sum3, _pA, _pB, 3); + + pA += 4; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 4 * 2, _sum2); + vst1_f16(outptr0 + 4 * 3, _sum3); + outptr0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_ph(_sum0, _sum1, _sum2, _sum3); + + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + out_hstep, _sum1); + vst1_f16(outptr0 + out_hstep * 2, _sum2); + vst1_f16(outptr0 + out_hstep * 3, _sum3); + outptr0 += 4; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + vst1_f16(outptr + 4 * 2, _sum2); + vst1_f16(outptr + 4 * 3, _sum3); + } + + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + float16x4_t _sum0; + float16x4_t _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + _sum1 = _sum0; + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + + _sum0 = vfma_n_f16(_sum0, _pA, pB[0]); + _sum1 = vfma_n_f16(_sum1, _pA, pB[1]); + + pA += 4; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + outptr0 += 8; + } + if (out_elempack == 1) + { + __fp16 sum0[4]; + __fp16 sum1[4]; + vst1_f16(sum0, _sum0); + vst1_f16(sum1, _sum1); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0 += 2; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + } + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + float16x4_t _sum0; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + } + else + { + _sum0 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + float16x4_t _pB = vdup_n_f16(pB[0]); + + _sum0 = vfma_f16(_sum0, _pA, _pB); + + pA += 4; + pB += 1; + } + + if (k_end) + { + if 
(out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + outptr0 += 4; + } + if (out_elempack == 1) + { + __fp16 sum0[4]; + vst1_f16(sum0, _sum0); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0++; + } + } + else + { + vst1_f16(outptr, _sum0); + } + + outptr += 4; + } + + pAT += max_kk * 4; + } + for (; ii + 1 < max_ii; ii += 2) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j; + + const __fp16* pB = pBT; + + if (pC) + { + pC = (const __fp16*)CT_tile + i + ii; + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + float16x4_t _sum00; + float16x4_t _sum01; + float16x4_t _sum02; + float16x4_t _sum10; + float16x4_t _sum11; + float16x4_t _sum12; + + if (k == 0) + { + if (pC) + { + _sum00 = vdup_n_f16(pC[0]); + _sum01 = vdup_n_f16(pC[0]); + _sum02 = vdup_n_f16(pC[0]); + _sum10 = vdup_n_f16(pC[1]); + _sum11 = vdup_n_f16(pC[1]); + _sum12 = vdup_n_f16(pC[1]); + } + else + { + _sum00 = vdup_n_f16(0.f); + _sum01 = vdup_n_f16(0.f); + _sum02 = vdup_n_f16(0.f); + _sum10 = vdup_n_f16(0.f); + _sum11 = vdup_n_f16(0.f); + _sum12 = vdup_n_f16(0.f); + } + } + else + { + float16x4x2_t _tmp01 = vld2_f16(outptr); + float16x4x2_t _tmp23 = vld2_f16(outptr + 8); + float16x4x2_t _tmp45 = vld2_f16(outptr + 16); + _sum00 = _tmp01.val[0]; + _sum01 = _tmp23.val[0]; + _sum02 = _tmp45.val[0]; + _sum10 = _tmp01.val[1]; + _sum11 = _tmp23.val[1]; + _sum12 = _tmp45.val[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + float16x4_t _pB2 = vld1_f16(pB + 8); + + float16x4_t _pA0 = vld1_dup_f16(pA); + float16x4_t _pA1 = vld1_dup_f16(pA + 1); + + _sum00 = vfma_f16(_sum00, _pB0, _pA0); + _sum01 = vfma_f16(_sum01, _pB1, _pA0); + _sum02 = vfma_f16(_sum02, _pB2, _pA0); + _sum10 = vfma_f16(_sum10, _pB0, _pA1); + _sum11 = vfma_f16(_sum11, _pB1, _pA1); + _sum12 = vfma_f16(_sum12, _pB2, _pA1); + + pA += 2; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum00); + vst1_f16(outptr0 + 4, _sum01); + vst1_f16(outptr0 + 8, _sum02); + vst1_f16(outptr0 + out_hstep, _sum10); + vst1_f16(outptr0 + out_hstep + 4, _sum11); + vst1_f16(outptr0 + out_hstep + 8, _sum12); + outptr0 += 12; + } + } + else + { + float16x4x2_t _tmp01; + _tmp01.val[0] = _sum00; + _tmp01.val[1] = _sum10; + float16x4x2_t _tmp23; + _tmp23.val[0] = _sum01; + _tmp23.val[1] = _sum11; + float16x4x2_t _tmp45; + _tmp45.val[0] = _sum02; + _tmp45.val[1] = _sum12; + vst2_f16(outptr, _tmp01); + vst2_f16(outptr + 8, _tmp23); + vst2_f16(outptr + 16, _tmp45); + } + + outptr += 24; + } + for (; jj + 7 < max_jj; jj += 8) + { + float16x4_t _sum00; + float16x4_t _sum01; + float16x4_t _sum10; + float16x4_t _sum11; + + if (k == 0) + { + if (pC) + { + _sum00 = vdup_n_f16(pC[0]); + _sum01 = vdup_n_f16(pC[0]); + _sum10 = vdup_n_f16(pC[1]); + _sum11 = vdup_n_f16(pC[1]); + } + else + { + _sum00 = vdup_n_f16(0.f); + _sum01 = vdup_n_f16(0.f); + _sum10 = vdup_n_f16(0.f); + _sum11 = vdup_n_f16(0.f); + } + } + else + { + float16x4x2_t _tmp01 = vld2_f16(outptr); + float16x4x2_t _tmp23 = vld2_f16(outptr + 8); + _sum00 = _tmp01.val[0]; + _sum01 = _tmp23.val[0]; + _sum10 = _tmp01.val[1]; + _sum11 = _tmp23.val[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + + float16x4_t _pA0 = vld1_dup_f16(pA); + float16x4_t _pA1 = vld1_dup_f16(pA + 1); + 
+ _sum00 = vfma_f16(_sum00, _pB0, _pA0); + _sum01 = vfma_f16(_sum01, _pB1, _pA0); + _sum10 = vfma_f16(_sum10, _pB0, _pA1); + _sum11 = vfma_f16(_sum11, _pB1, _pA1); + + pA += 2; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum00); + vst1_f16(outptr0 + 4, _sum01); + vst1_f16(outptr0 + out_hstep, _sum10); + vst1_f16(outptr0 + out_hstep + 4, _sum11); + outptr0 += 8; + } + } + else + { + float16x4x2_t _tmp01; + _tmp01.val[0] = _sum00; + _tmp01.val[1] = _sum10; + float16x4x2_t _tmp23; + _tmp23.val[0] = _sum01; + _tmp23.val[1] = _sum11; + vst2_f16(outptr, _tmp01); + vst2_f16(outptr + 8, _tmp23); + } + + outptr += 16; + } + for (; jj + 3 < max_jj; jj += 4) + { + float16x4_t _sum0; + float16x4_t _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = vdup_n_f16(pC[0]); + _sum1 = vdup_n_f16(pC[1]); + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + } + } + else + { + float16x4x2_t _tmp01 = vld2_f16(outptr); + _sum0 = _tmp01.val[0]; + _sum1 = _tmp01.val[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB = vld1_f16(pB); + + _sum0 = vfma_n_f16(_sum0, _pB, pA[0]); + _sum1 = vfma_n_f16(_sum1, _pB, pA[1]); + + pA += 2; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, (_sum0)); + vst1_f16(outptr0 + out_hstep, (_sum1)); + outptr0 += 4; + } + } + else + { + float16x4x2_t _tmp01; + _tmp01.val[0] = _sum0; + _tmp01.val[1] = _sum1; + vst2_f16(outptr, _tmp01); + } + + outptr += 8; + } + for (; jj + 1 < max_jj; jj += 2) + { + __fp16 sum00; + __fp16 sum01; + __fp16 sum10; + __fp16 sum11; + + if (k == 0) + { + if (pC) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[0]; + sum11 = pC[1]; + } + else + { + sum00 = 0.f; + sum01 = 0.f; + sum10 = 0.f; + sum11 = 0.f; + } + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + + pA += 2; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum00; + outptr0[1] = sum10; + outptr0[out_hstep] = sum01; + outptr0[out_hstep + 1] = sum11; + outptr0 += 2; + } + } + else + { + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + } + + outptr += 4; + } + for (; jj < max_jj; jj += 1) + { + __fp16 sum0; + __fp16 sum1; + + if (k == 0) + { + if (pC) + { + sum0 = pC[0]; + sum1 = pC[1]; + } + else + { + sum0 = 0.f; + sum1 = 0.f; + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[out_hstep] = sum1; + outptr0++; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j; + + const __fp16* pB = pBT; + + if (pC) + { + pC = (const __fp16*)CT_tile + i + ii; + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + float16x4_t _sum0; + float16x4_t _sum1; + float16x4_t _sum2; + + if (k == 0) + { + if (pC) + { + _sum0 = vdup_n_f16(pC[0]); + _sum1 = vdup_n_f16(pC[0]); + _sum2 = vdup_n_f16(pC[0]); + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + _sum2 = 
vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4); + _sum2 = vld1_f16(outptr + 8); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + float16x4_t _pB2 = vld1_f16(pB + 8); + + float16x4_t _pA0 = vdup_n_f16(pA[0]); + + _sum0 = vfma_f16(_sum0, _pA0, _pB0); + _sum1 = vfma_f16(_sum1, _pA0, _pB1); + _sum2 = vfma_f16(_sum2, _pA0, _pB2); + + pA += 1; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 8, _sum2); + outptr0 += 12; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + vst1_f16(outptr + 8, _sum2); + } + + outptr += 12; + } + for (; jj + 7 < max_jj; jj += 8) + { + float16x4_t _sum0; + float16x4_t _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = vdup_n_f16(pC[0]); + _sum1 = vdup_n_f16(pC[0]); + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + + float16x4_t _pA0 = vdup_n_f16(pA[0]); + + _sum0 = vfma_f16(_sum0, _pA0, _pB0); + _sum1 = vfma_f16(_sum1, _pA0, _pB1); + + pA += 1; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + outptr0 += 8; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + } + + outptr += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float16x4_t _sum; + + if (k == 0) + { + if (pC) + { + _sum = vdup_n_f16(pC[0]); + } + else + { + _sum = vdup_n_f16(0.f); + } + } + else + { + _sum = vld1_f16(outptr); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB = vld1_f16(pB); + float16x4_t _pA = vdup_n_f16(pA[0]); + + _sum = vfma_f16(_sum, _pA, _pB); + + pA += 1; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum); + outptr0 += 4; + } + } + else + { + vst1_f16(outptr, _sum); + } + + outptr += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + __fp16 sum0; + __fp16 sum1; + + if (k == 0) + { + if (pC) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + else + { + sum0 = 0.f; + sum1 = 0.f; + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + + pA += 1; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[1] = sum1; + outptr0 += 2; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + __fp16 sum; + + if (k == 0) + { + if (pC) + { + sum = pC[0]; + } + else + { + sum = 0.f; + } + } + else + { + sum = outptr[0]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum += pA[0] * pB[0]; + + pA += 1; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum; + outptr0++; + } + } + else + { + outptr[0] = sum; + } + + outptr += 1; + } + + pAT += max_kk; + } + amx_clr(); +} +#endif // __aarch64__ && NCNN_APPLE_AMX + static void convolution_im2col_gemm_get_optimal_tile_mnk_fp16sa(int M, int N, int K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) { // resolve optimal tile size from cache size @@ 
-3195,6 +4885,10 @@ static int convolution_im2col_gemm_fp16sa(const Mat& bottom_blob, Mat& top_blob, return -100; } +#if __aarch64__ && NCNN_APPLE_AMX + bool amx_supported = cpu_support_arm_amx(); +#endif + #pragma omp parallel for num_threads(nT) for (int ppj = 0; ppj < nn_M; ppj++) { @@ -3220,7 +4914,17 @@ static int convolution_im2col_gemm_fp16sa(const Mat& bottom_blob, Mat& top_blob, bool k_end = k + TILE_K >= K; - convolution_gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end, opt.use_a53_a55_optimized_kernel); +#if __aarch64__ && NCNN_APPLE_AMX +// #if 0 + if (amx_supported) + { + convolution_gemm_transB_packed_tile_fp16sa_amx(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end); + } + else +#endif + { + convolution_gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end, opt.use_a53_a55_optimized_kernel); + } } } } From 0a23eea025b3eb43563a707967f87e45d7a9019c Mon Sep 17 00:00:00 2001 From: MollySophia Date: Thu, 10 Oct 2024 03:20:09 +0000 Subject: [PATCH 4/4] apply code-format changes --- src/layer/arm/amx_usability.h | 84 ++++++++++--------- src/layer/arm/convolution_im2col_gemm_fp16s.h | 23 ++--- 2 files changed, 58 insertions(+), 49 deletions(-) diff --git a/src/layer/arm/amx_usability.h b/src/layer/arm/amx_usability.h index d2bdc267530f..d6783351000a 100644 --- a/src/layer/arm/amx_usability.h +++ b/src/layer/arm/amx_usability.h @@ -16,35 +16,41 @@ #define AMX_USABILITY_H // From https://github.com/corsix/amx/blob/main/aarch64.h -#define AMX_NOP_OP_IMM5(op, imm5) \ - __asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" : : "i"(op), "i"(imm5) : "memory") - -#define AMX_OP_GPR(op, gpr) \ - __asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" : : "i"(op), "r"((uint64_t)(gpr)) : "memory") - -#define AMX_LDX(gpr) AMX_OP_GPR( 0, gpr) -#define AMX_LDY(gpr) AMX_OP_GPR( 1, gpr) -#define AMX_STX(gpr) AMX_OP_GPR( 2, gpr) -#define AMX_STY(gpr) AMX_OP_GPR( 3, gpr) -#define AMX_LDZ(gpr) AMX_OP_GPR( 4, gpr) -#define AMX_STZ(gpr) AMX_OP_GPR( 5, gpr) -#define AMX_LDZI(gpr) AMX_OP_GPR( 6, gpr) -#define AMX_STZI(gpr) AMX_OP_GPR( 7, gpr) -#define AMX_EXTRX(gpr) AMX_OP_GPR( 8, gpr) -#define AMX_EXTRY(gpr) AMX_OP_GPR( 9, gpr) -#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr) -#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr) -#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr) -#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr) -#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr) -#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr) -#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr) -#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr) -#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr) -#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr) -#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr) -#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr) -#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t)&*(ptr)) + (((uint64_t)((row) + (flags) * 64)) << 56)) +#define AMX_NOP_OP_IMM5(op, imm5) \ + __asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" \ + : \ + : "i"(op), "i"(imm5) \ + : "memory") + +#define AMX_OP_GPR(op, gpr) \ + __asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" \ + : \ + : "i"(op), "r"((uint64_t)(gpr)) \ + : "memory") + +#define AMX_LDX(gpr) AMX_OP_GPR(0, gpr) +#define AMX_LDY(gpr) AMX_OP_GPR(1, gpr) +#define AMX_STX(gpr) AMX_OP_GPR(2, gpr) +#define AMX_STY(gpr) AMX_OP_GPR(3, gpr) +#define AMX_LDZ(gpr) AMX_OP_GPR(4, gpr) +#define AMX_STZ(gpr) AMX_OP_GPR(5, gpr) +#define AMX_LDZI(gpr) AMX_OP_GPR(6, gpr) 
+#define AMX_STZI(gpr) AMX_OP_GPR(7, gpr) +#define AMX_EXTRX(gpr) AMX_OP_GPR(8, gpr) +#define AMX_EXTRY(gpr) AMX_OP_GPR(9, gpr) +#define AMX_FMA64(gpr) AMX_OP_GPR(10, gpr) +#define AMX_FMS64(gpr) AMX_OP_GPR(11, gpr) +#define AMX_FMA32(gpr) AMX_OP_GPR(12, gpr) +#define AMX_FMS32(gpr) AMX_OP_GPR(13, gpr) +#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr) +#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr) +#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr) +#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr) +#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr) +#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr) +#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr) +#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr) +#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t) & *(ptr)) + (((uint64_t)((row) + (flags)*64)) << 56)) void amx_set() { AMX_NOP_OP_IMM5(17, 0); @@ -55,7 +61,7 @@ void amx_clr() AMX_NOP_OP_IMM5(17, 1); } -void amx_ldx(bool pair, unsigned int x_row, const void * ptr) +void amx_ldx(bool pair, unsigned int x_row, const void* ptr) { if (x_row >= 8) return; @@ -63,11 +69,11 @@ void amx_ldx(bool pair, unsigned int x_row, const void * ptr) uint64_t oprand = (uint64_t)ptr + ((uint64_t)x_row << 56); if (pair) oprand |= 1ULL << 62; - + AMX_LDX(oprand); } -void amx_ldy(bool pair, unsigned int y_row, const void * ptr) +void amx_ldy(bool pair, unsigned int y_row, const void* ptr) { if (y_row >= 8) return; @@ -75,11 +81,11 @@ void amx_ldy(bool pair, unsigned int y_row, const void * ptr) uint64_t oprand = (uint64_t)ptr + ((uint64_t)y_row << 56); if (pair) oprand |= 1ULL << 62; - + AMX_LDY(oprand); } -void amx_ldz(bool pair, unsigned int z_row, const void * ptr) +void amx_ldz(bool pair, unsigned int z_row, const void* ptr) { if (z_row >= 64) return; @@ -87,11 +93,11 @@ void amx_ldz(bool pair, unsigned int z_row, const void * ptr) uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56); if (pair) oprand |= 1ULL << 62; - + AMX_LDZ(oprand); } -void amx_stz(bool pair, unsigned int z_row, const void * ptr) +void amx_stz(bool pair, unsigned int z_row, const void* ptr) { if (z_row >= 64) return; @@ -99,7 +105,7 @@ void amx_stz(bool pair, unsigned int z_row, const void * ptr) uint64_t oprand = (uint64_t)ptr + ((uint64_t)z_row << 56); if (pair) oprand |= 1ULL << 62; - + AMX_STZ(oprand); } @@ -116,7 +122,7 @@ void amx_fma16_masked(bool vector, unsigned int x_offset, unsigned int y_offset, oprand |= ((uint64_t)y_mode & 0x3) << 37; oprand |= ((uint64_t)x_mask & 0x1F) << 41; oprand |= ((uint64_t)x_mode & 0x3) << 46; - + AMX_FMA16(oprand); } @@ -138,7 +144,7 @@ void amx_fma32_masked(bool vector, unsigned int x_offset, unsigned int y_offset, oprand |= ((uint64_t)y_mode & 0x3) << 37; oprand |= ((uint64_t)x_mask & 0x1F) << 41; oprand |= ((uint64_t)x_mode & 0x3) << 46; - + AMX_FMA32(oprand); } diff --git a/src/layer/arm/convolution_im2col_gemm_fp16s.h b/src/layer/arm/convolution_im2col_gemm_fp16s.h index 7ab071db9016..6d78ed65ae93 100644 --- a/src/layer/arm/convolution_im2col_gemm_fp16s.h +++ b/src/layer/arm/convolution_im2col_gemm_fp16s.h @@ -3056,20 +3056,20 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c if (pC) { for (int r = 0; r < 12; r++) - amx_ldz(false, 2*r, pC); + amx_ldz(false, 2 * r, pC); } else { __fp16 sums[32]; memset(sums, 0, 32 * sizeof(__fp16)); for (int r = 0; r < 12; r++) - amx_ldz(false, 2*r, sums); + amx_ldz(false, 2 * r, sums); } } else { for (int r = 0; r < 12; r++) - amx_ldz(false, 2*r, outptr + 8 * r); + amx_ldz(false, 2 * r, outptr + 8 * r); } int kk = 0; @@ -3088,8 +3088,9 @@ static void 
convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c if (out_elempack == 8) { __fp16 tmp[96 + 24]; - for (int r = 0; r < 12; r++) { - amx_stz(false, 2*r, tmp + r * 8); + for (int r = 0; r < 12; r++) + { + amx_stz(false, 2 * r, tmp + r * 8); } memcpy(outptr0, tmp, 96 * sizeof(__fp16)); outptr0 += 96; @@ -3097,8 +3098,9 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c if (out_elempack == 4) { __fp16 tmp[32]; - for (int r = 0; r < 12; r++) { - amx_stz(false, 2*r, tmp); + for (int r = 0; r < 12; r++) + { + amx_stz(false, 2 * r, tmp); float16x8_t _tmp = vld1q_f16(tmp); vst1_f16(outptr0 + 4 * r, vget_low_f16(_tmp)); vst1_f16(outptr0 + out_hstep * 4 + 4 * r, vget_high_f16(_tmp)); @@ -3167,8 +3169,9 @@ static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, c else { __fp16 tmp[32]; - for (int r = 0; r < 12; r++) { - amx_stz(false, 2*r, tmp); + for (int r = 0; r < 12; r++) + { + amx_stz(false, 2 * r, tmp); memcpy(outptr0 + 8 * r, tmp, 8 * sizeof(__fp16)); } } @@ -4915,7 +4918,7 @@ static int convolution_im2col_gemm_fp16sa(const Mat& bottom_blob, Mat& top_blob, bool k_end = k + TILE_K >= K; #if __aarch64__ && NCNN_APPLE_AMX -// #if 0 + // #if 0 if (amx_supported) { convolution_gemm_transB_packed_tile_fp16sa_amx(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end);
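
For reference, the call pattern the amx_usability.h helpers are built around, reduced to the smallest possible program: claim the per-thread AMX state, load one 64-byte row into X and one into Y, run a single vector-mode fp16 fma into Z, store the Z row back, then release the state. This is a minimal sketch, not part of the patch series: it assumes amx_usability.h is reachable on the include path, must be built as C++ for arm64 (the helpers use bool and __fp16), and will only execute on an Apple silicon CPU for which cpu_support_arm_amx() reports support, since the AMX instructions are undocumented and fault elsewhere.

#include <stdint.h>
#include <string.h>
#include <stdio.h>
#include "amx_usability.h"

int main()
{
    __fp16 a[32], b[32], z[32];
    for (int i = 0; i < 32; i++)
    {
        a[i] = (__fp16)1.f;
        b[i] = (__fp16)2.f;
    }
    memset(z, 0, sizeof(z));

    amx_set();                // claim per-thread AMX state
    amx_ldz(false, 0, z);     // Z row 0 <- 64 bytes of zeros
    amx_ldx(false, 0, a);     // X row 0 <- a
    amx_ldy(false, 0, b);     // Y row 0 <- b
    amx_fma16(true, 0, 0, 0); // vector mode: Z0[i] += X0[i] * Y0[i]
    amx_stz(false, 0, z);     // z <- Z row 0
    amx_clr();                // release AMX state

    printf("%f\n", (float)z[0]); // 1.0 * 2.0 accumulated into 0.0 -> prints 2.000000
    return 0;
}

Built with something like clang++ -O2 amx_demo.cc (file name hypothetical) on an arm64 Mac, this prints 2.000000. The kernels in the patches are this same set/load/fma/store/clr sequence widened to whole tiles, with the convolution path switched to the matrix-mode fma via amx_fma16_masked.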