diff --git a/src/layer/arm/amx_usability.h b/src/layer/arm/amx_usability.h index 6955081c495c..d2bdc267530f 100644 --- a/src/layer/arm/amx_usability.h +++ b/src/layer/arm/amx_usability.h @@ -103,7 +103,29 @@ void amx_stz(bool pair, unsigned int z_row, const void * ptr) AMX_STZ(oprand); } -void amx_fma32(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row) +void amx_fma16_masked(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row, uint8_t x_mode, uint8_t x_mask, uint8_t y_mode, uint8_t y_mask) +{ + uint64_t oprand = 0; + if (vector) + oprand |= 1ULL << 63; + + oprand |= (uint64_t)y_offset & 0x1FF; + oprand |= ((uint64_t)x_offset & 0x1FF) << 10; + oprand |= ((uint64_t)z_row & 0x3F) << 20; + oprand |= ((uint64_t)y_mask & 0x1F) << 32; + oprand |= ((uint64_t)y_mode & 0x3) << 37; + oprand |= ((uint64_t)x_mask & 0x1F) << 41; + oprand |= ((uint64_t)x_mode & 0x3) << 46; + + AMX_FMA16(oprand); +} + +void amx_fma16(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row) +{ + amx_fma16_masked(vector, x_offset, y_offset, z_row, 0, 0, 0, 0); +} + +void amx_fma32_masked(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row, uint8_t x_mode, uint8_t x_mask, uint8_t y_mode, uint8_t y_mask) { uint64_t oprand = 0; if (vector) @@ -111,8 +133,18 @@ void amx_fma32(bool vector, unsigned int x_offset, unsigned int y_offset, int z_ oprand |= (uint64_t)y_offset & 0x1FF; oprand |= ((uint64_t)x_offset & 0x1FF) << 10; + oprand |= ((uint64_t)z_row & 0x3F) << 20; + oprand |= ((uint64_t)y_mask & 0x1F) << 32; + oprand |= ((uint64_t)y_mode & 0x3) << 37; + oprand |= ((uint64_t)x_mask & 0x1F) << 41; + oprand |= ((uint64_t)x_mode & 0x3) << 46; AMX_FMA32(oprand); } +void amx_fma32(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row) +{ + amx_fma32_masked(vector, x_offset, y_offset, z_row, 0, 0, 0, 0); +} + #endif // AMX_USABILITY_H \ No newline at end of file diff --git a/src/layer/arm/convolution_im2col_gemm_fp16s.h b/src/layer/arm/convolution_im2col_gemm_fp16s.h index 8e499b3e73a8..7ab071db9016 100644 --- a/src/layer/arm/convolution_im2col_gemm_fp16s.h +++ b/src/layer/arm/convolution_im2col_gemm_fp16s.h @@ -3017,6 +3017,1696 @@ static void convolution_gemm_transB_packed_tile_fp16sa(const Mat& AT_tile, const } } +#if __aarch64__ && NCNN_APPLE_AMX +#include + +static void convolution_gemm_transB_packed_tile_fp16sa_amx(const Mat& AT_tile, const Mat& BT_tile, const Mat& CT_tile, Mat& topT_tile, Mat& top_blob, int i, int max_ii, int j, int max_jj, int k, int max_kk, bool k_end) +{ + // NCNN_LOGE("convolution_gemm_transB_packed_tile_fp16sa_amx %d %d %d %d %d %d", i, max_ii, j, max_jj, k, max_kk); + amx_set(); + const int out_elempack = top_blob.elempack; + const int out_hstep = (int)top_blob.cstep; + + const __fp16* pAT = AT_tile; + const __fp16* pBT = BT_tile; + const __fp16* pC = CT_tile; + + __fp16* outptr = topT_tile; + + int ii = 0; + for (; ii + 7 < max_ii; ii += 8) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const __fp16* pB = pBT; + + if (pC) + { + pC = (const __fp16*)CT_tile + i + ii; + } + + int jj = 0; + // NCNN_LOGE("max_jj %d", max_jj); + for (; jj + 11 < max_jj; jj += 12) + { + const __fp16* pA = pAT; + + if (k == 0) + { + if (pC) + { + for (int r = 0; r < 12; r++) + amx_ldz(false, 2*r, pC); + } + else + { + __fp16 sums[32]; + memset(sums, 0, 32 * sizeof(__fp16)); + for (int r = 0; r < 12; r++) + amx_ldz(false, 2*r, sums); + } + } + else + { + for (int r = 0; r < 12; r++) + amx_ldz(false, 2*r, outptr + 8 * r); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + amx_ldx(false, 0, pA); + amx_ldy(false, 0, pB); + amx_fma16_masked(false, 0, 0, 0, 0x2, 8, 0x2, 12); + + pA += 8; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 8) + { + __fp16 tmp[96 + 24]; + for (int r = 0; r < 12; r++) { + amx_stz(false, 2*r, tmp + r * 8); + } + memcpy(outptr0, tmp, 96 * sizeof(__fp16)); + outptr0 += 96; + } + if (out_elempack == 4) + { + __fp16 tmp[32]; + for (int r = 0; r < 12; r++) { + amx_stz(false, 2*r, tmp); + float16x8_t _tmp = vld1q_f16(tmp); + vst1_f16(outptr0 + 4 * r, vget_low_f16(_tmp)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * r, vget_high_f16(_tmp)); + } + + outptr0 += 48; + } + if (out_elempack == 1) + { + float16x8_t _sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb; + __fp16 tmp[32]; + amx_stz(false, 0, tmp); + _sum0 = vld1q_f16(tmp); + amx_stz(false, 2, tmp); + _sum1 = vld1q_f16(tmp); + amx_stz(false, 4, tmp); + _sum2 = vld1q_f16(tmp); + amx_stz(false, 6, tmp); + _sum3 = vld1q_f16(tmp); + amx_stz(false, 8, tmp); + _sum4 = vld1q_f16(tmp); + amx_stz(false, 10, tmp); + _sum5 = vld1q_f16(tmp); + amx_stz(false, 12, tmp); + _sum6 = vld1q_f16(tmp); + amx_stz(false, 14, tmp); + _sum7 = vld1q_f16(tmp); + amx_stz(false, 16, tmp); + _sum8 = vld1q_f16(tmp); + amx_stz(false, 18, tmp); + _sum9 = vld1q_f16(tmp); + amx_stz(false, 20, tmp); + _suma = vld1q_f16(tmp); + amx_stz(false, 22, tmp); + _sumb = vld1q_f16(tmp); + transpose8x12_ph(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb); + + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + 4, vget_high_f16(_sum0)); + vst1_f16(outptr0 + 8, vget_low_f16(_sum1)); + vst1_f16(outptr0 + out_hstep, vget_high_f16(_sum1)); + vst1_f16(outptr0 + out_hstep + 4, vget_low_f16(_sum2)); + vst1_f16(outptr0 + out_hstep + 8, vget_high_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 2, vget_low_f16(_sum3)); + vst1_f16(outptr0 + out_hstep * 2 + 4, vget_high_f16(_sum3)); + vst1_f16(outptr0 + out_hstep * 2 + 8, vget_low_f16(_sum4)); + vst1_f16(outptr0 + out_hstep * 3, vget_high_f16(_sum4)); + vst1_f16(outptr0 + out_hstep * 3 + 4, vget_low_f16(_sum5)); + vst1_f16(outptr0 + out_hstep * 3 + 8, vget_high_f16(_sum5)); + vst1_f16(outptr0 + out_hstep * 4, vget_low_f16(_sum6)); + vst1_f16(outptr0 + out_hstep * 4 + 4, vget_high_f16(_sum6)); + vst1_f16(outptr0 + out_hstep * 4 + 8, vget_low_f16(_sum7)); + vst1_f16(outptr0 + out_hstep * 5, vget_high_f16(_sum7)); + vst1_f16(outptr0 + out_hstep * 5 + 4, vget_low_f16(_sum8)); + vst1_f16(outptr0 + out_hstep * 5 + 8, vget_high_f16(_sum8)); + vst1_f16(outptr0 + out_hstep * 6, vget_low_f16(_sum9)); + vst1_f16(outptr0 + out_hstep * 6 + 4, vget_high_f16(_sum9)); + vst1_f16(outptr0 + out_hstep * 6 + 8, vget_low_f16(_suma)); + vst1_f16(outptr0 + out_hstep * 7, vget_high_f16(_suma)); + vst1_f16(outptr0 + out_hstep * 7 + 4, vget_low_f16(_sumb)); + vst1_f16(outptr0 + out_hstep * 7 + 8, vget_high_f16(_sumb)); + + outptr0 += 12; + } + } + else + { + __fp16 tmp[32]; + for (int r = 0; r < 12; r++) { + amx_stz(false, 2*r, tmp); + memcpy(outptr0 + 8 * r, tmp, 8 * sizeof(__fp16)); + } + } + + outptr += 96; + } + for (; jj + 7 < max_jj; jj += 8) + { + const __fp16* pA = pAT; + + float16x8_t _sum0; + float16x8_t _sum1; + float16x8_t _sum2; + float16x8_t _sum3; + float16x8_t _sum4; + float16x8_t _sum5; + float16x8_t _sum6; + float16x8_t _sum7; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1q_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + else + { + _sum0 = vdupq_n_f16(0.f); + _sum1 = vdupq_n_f16(0.f); + _sum2 = vdupq_n_f16(0.f); + _sum3 = vdupq_n_f16(0.f); + _sum4 = vdupq_n_f16(0.f); + _sum5 = vdupq_n_f16(0.f); + _sum6 = vdupq_n_f16(0.f); + _sum7 = vdupq_n_f16(0.f); + } + } + else + { + _sum0 = vld1q_f16(outptr); + _sum1 = vld1q_f16(outptr + 8); + _sum2 = vld1q_f16(outptr + 8 * 2); + _sum3 = vld1q_f16(outptr + 8 * 3); + _sum4 = vld1q_f16(outptr + 8 * 4); + _sum5 = vld1q_f16(outptr + 8 * 5); + _sum6 = vld1q_f16(outptr + 8 * 6); + _sum7 = vld1q_f16(outptr + 8 * 7); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x8_t _pA = vld1q_f16(pA); + + float16x8_t _pB = vld1q_f16(pB); + + _sum0 = vfmaq_laneq_f16(_sum0, _pA, _pB, 0); + _sum1 = vfmaq_laneq_f16(_sum1, _pA, _pB, 1); + _sum2 = vfmaq_laneq_f16(_sum2, _pA, _pB, 2); + _sum3 = vfmaq_laneq_f16(_sum3, _pA, _pB, 3); + _sum4 = vfmaq_laneq_f16(_sum4, _pA, _pB, 4); + _sum5 = vfmaq_laneq_f16(_sum5, _pA, _pB, 5); + _sum6 = vfmaq_laneq_f16(_sum6, _pA, _pB, 6); + _sum7 = vfmaq_laneq_f16(_sum7, _pA, _pB, 7); + + pA += 8; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 8) + { + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr0 + 8, _sum1); + vst1q_f16(outptr0 + 8 * 2, _sum2); + vst1q_f16(outptr0 + 8 * 3, _sum3); + vst1q_f16(outptr0 + 8 * 4, _sum4); + vst1q_f16(outptr0 + 8 * 5, _sum5); + vst1q_f16(outptr0 + 8 * 6, _sum6); + vst1q_f16(outptr0 + 8 * 7, _sum7); + outptr0 += 64; + } + if (out_elempack == 4) + { + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + 4, vget_low_f16(_sum1)); + vst1_f16(outptr0 + 4 * 2, vget_low_f16(_sum2)); + vst1_f16(outptr0 + 4 * 3, vget_low_f16(_sum3)); + vst1_f16(outptr0 + 4 * 4, vget_low_f16(_sum4)); + vst1_f16(outptr0 + 4 * 5, vget_low_f16(_sum5)); + vst1_f16(outptr0 + 4 * 6, vget_low_f16(_sum6)); + vst1_f16(outptr0 + 4 * 7, vget_low_f16(_sum7)); + + vst1_f16(outptr0 + out_hstep * 4, vget_high_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 4 + 4, vget_high_f16(_sum1)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 2, vget_high_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 3, vget_high_f16(_sum3)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 4, vget_high_f16(_sum4)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 5, vget_high_f16(_sum5)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 6, vget_high_f16(_sum6)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 7, vget_high_f16(_sum7)); + + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose8x8_ph(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr0 + out_hstep, _sum1); + vst1q_f16(outptr0 + out_hstep * 2, _sum2); + vst1q_f16(outptr0 + out_hstep * 3, _sum3); + vst1q_f16(outptr0 + out_hstep * 4, _sum4); + vst1q_f16(outptr0 + out_hstep * 5, _sum5); + vst1q_f16(outptr0 + out_hstep * 6, _sum6); + vst1q_f16(outptr0 + out_hstep * 7, _sum7); + + outptr0 += 8; + } + } + else + { + vst1q_f16(outptr, _sum0); + vst1q_f16(outptr + 8, _sum1); + vst1q_f16(outptr + 8 * 2, _sum2); + vst1q_f16(outptr + 8 * 3, _sum3); + vst1q_f16(outptr + 8 * 4, _sum4); + vst1q_f16(outptr + 8 * 5, _sum5); + vst1q_f16(outptr + 8 * 6, _sum6); + vst1q_f16(outptr + 8 * 7, _sum7); + } + + outptr += 64; + } + for (; jj + 3 < max_jj; jj += 4) + { + const __fp16* pA = pAT; + + float16x8_t _sum0; + float16x8_t _sum1; + float16x8_t _sum2; + float16x8_t _sum3; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1q_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + else + { + _sum0 = vdupq_n_f16(0.f); + _sum1 = vdupq_n_f16(0.f); + _sum2 = vdupq_n_f16(0.f); + _sum3 = vdupq_n_f16(0.f); + } + } + else + { + _sum0 = vld1q_f16(outptr); + _sum1 = vld1q_f16(outptr + 8); + _sum2 = vld1q_f16(outptr + 8 * 2); + _sum3 = vld1q_f16(outptr + 8 * 3); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x8_t _pA = vld1q_f16(pA); + + float16x4_t _pB = vld1_f16(pB); + + _sum0 = vfmaq_lane_f16(_sum0, _pA, _pB, 0); + _sum1 = vfmaq_lane_f16(_sum1, _pA, _pB, 1); + _sum2 = vfmaq_lane_f16(_sum2, _pA, _pB, 2); + _sum3 = vfmaq_lane_f16(_sum3, _pA, _pB, 3); + + pA += 8; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 8) + { + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr0 + 8, _sum1); + vst1q_f16(outptr0 + 8 * 2, _sum2); + vst1q_f16(outptr0 + 8 * 3, _sum3); + outptr0 += 32; + } + if (out_elempack == 4) + { + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + 4, vget_low_f16(_sum1)); + vst1_f16(outptr0 + 4 * 2, vget_low_f16(_sum2)); + vst1_f16(outptr0 + 4 * 3, vget_low_f16(_sum3)); + + vst1_f16(outptr0 + out_hstep * 4, vget_high_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 4 + 4, vget_high_f16(_sum1)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 2, vget_high_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 4 + 4 * 3, vget_high_f16(_sum3)); + + outptr0 += 16; + } + if (out_elempack == 1) + { + transpose8x4_ph(_sum0, _sum1, _sum2, _sum3); + + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 1, vget_high_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 2, vget_low_f16(_sum1)); + vst1_f16(outptr0 + out_hstep * 3, vget_high_f16(_sum1)); + vst1_f16(outptr0 + out_hstep * 4, vget_low_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 5, vget_high_f16(_sum2)); + vst1_f16(outptr0 + out_hstep * 6, vget_low_f16(_sum3)); + vst1_f16(outptr0 + out_hstep * 7, vget_high_f16(_sum3)); + + outptr0 += 4; + } + } + else + { + vst1q_f16(outptr, _sum0); + vst1q_f16(outptr + 8, _sum1); + vst1q_f16(outptr + 8 * 2, _sum2); + vst1q_f16(outptr + 8 * 3, _sum3); + } + + outptr += 32; + } + for (; jj + 1 < max_jj; jj += 2) + { + const __fp16* pA = pAT; + + float16x8_t _sum0; + float16x8_t _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1q_f16(pC); + _sum1 = _sum0; + } + else + { + _sum0 = vdupq_n_f16(0.f); + _sum1 = vdupq_n_f16(0.f); + } + } + else + { + _sum0 = vld1q_f16(outptr); + _sum1 = vld1q_f16(outptr + 8); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x8_t _pA = vld1q_f16(pA); + + float16x4_t _pB = vld1_f16(pB); + + _sum0 = vfmaq_lane_f16(_sum0, _pA, _pB, 0); + _sum1 = vfmaq_lane_f16(_sum1, _pA, _pB, 1); + + pA += 8; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 8) + { + vst1q_f16(outptr0, _sum0); + vst1q_f16(outptr0 + 8, _sum1); + outptr0 += 16; + } + if (out_elempack == 4) + { + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + 4, vget_low_f16(_sum1)); + + vst1_f16(outptr0 + out_hstep * 4, vget_high_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 4 + 4, vget_high_f16(_sum1)); + outptr0 += 8; + } + if (out_elempack == 1) + { + __fp16 sum0[8]; + __fp16 sum1[8]; + vst1q_f16(sum0, _sum0); + vst1q_f16(sum1, _sum1); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0[out_hstep * 4 + 1] = sum1[4]; + outptr0[out_hstep * 5 + 1] = sum1[5]; + outptr0[out_hstep * 6 + 1] = sum1[6]; + outptr0[out_hstep * 7 + 1] = sum1[7]; + outptr0 += 2; + } + } + else + { + vst1q_f16(outptr, _sum0); + vst1q_f16(outptr + 8, _sum1); + } + + outptr += 16; + } + for (; jj < max_jj; jj += 1) + { + const __fp16* pA = pAT; + + float16x8_t _sum0; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1q_f16(pC); + } + else + { + _sum0 = vdupq_n_f16(0.f); + } + } + else + { + _sum0 = vld1q_f16(outptr); + } + + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x8_t _pA = vld1q_f16(pA); + + float16x8_t _pB = vld1q_dup_f16(pB); + + _sum0 = vfmaq_f16(_sum0, _pA, _pB); + + pA += 8; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 8) + { + vst1q_f16(outptr0, _sum0); + outptr0 += 8; + } + if (out_elempack == 4) + { + vst1_f16(outptr0, vget_low_f16(_sum0)); + vst1_f16(outptr0 + out_hstep * 4, vget_high_f16(_sum0)); + outptr0 += 4; + } + if (out_elempack == 1) + { + __fp16 sum0[8]; + vst1q_f16(sum0, _sum0); + + outptr0[0] = sum0[0]; + outptr0[out_hstep * 1] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[out_hstep * 4] = sum0[4]; + outptr0[out_hstep * 5] = sum0[5]; + outptr0[out_hstep * 6] = sum0[6]; + outptr0[out_hstep * 7] = sum0[7]; + outptr0++; + } + } + else + { + vst1q_f16(outptr, _sum0); + } + + outptr += 8; + } + + pAT += max_kk * 8; + } + for (; ii + 3 < max_ii; ii += 4) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j * out_elempack; + + const __fp16* pB = pBT; + + if (pC) + { + pC = (const __fp16*)CT_tile + i + ii; + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + float16x4_t _sum0; + float16x4_t _sum1; + float16x4_t _sum2; + float16x4_t _sum3; + float16x4_t _sum4; + float16x4_t _sum5; + float16x4_t _sum6; + float16x4_t _sum7; + float16x4_t _sum8; + float16x4_t _sum9; + float16x4_t _suma; + float16x4_t _sumb; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + _sum8 = _sum0; + _sum9 = _sum0; + _suma = _sum0; + _sumb = _sum0; + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + _sum2 = vdup_n_f16(0.f); + _sum3 = vdup_n_f16(0.f); + _sum4 = vdup_n_f16(0.f); + _sum5 = vdup_n_f16(0.f); + _sum6 = vdup_n_f16(0.f); + _sum7 = vdup_n_f16(0.f); + _sum8 = vdup_n_f16(0.f); + _sum9 = vdup_n_f16(0.f); + _suma = vdup_n_f16(0.f); + _sumb = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4 * 1); + _sum2 = vld1_f16(outptr + 4 * 2); + _sum3 = vld1_f16(outptr + 4 * 3); + _sum4 = vld1_f16(outptr + 4 * 4); + _sum5 = vld1_f16(outptr + 4 * 5); + _sum6 = vld1_f16(outptr + 4 * 6); + _sum7 = vld1_f16(outptr + 4 * 7); + _sum8 = vld1_f16(outptr + 4 * 8); + _sum9 = vld1_f16(outptr + 4 * 9); + _suma = vld1_f16(outptr + 4 * 10); + _sumb = vld1_f16(outptr + 4 * 11); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + float16x4_t _pB2 = vld1_f16(pB + 8); + + _sum0 = vfma_lane_f16(_sum0, _pA, _pB0, 0); + _sum1 = vfma_lane_f16(_sum1, _pA, _pB0, 1); + _sum2 = vfma_lane_f16(_sum2, _pA, _pB0, 2); + _sum3 = vfma_lane_f16(_sum3, _pA, _pB0, 3); + _sum4 = vfma_lane_f16(_sum4, _pA, _pB1, 0); + _sum5 = vfma_lane_f16(_sum5, _pA, _pB1, 1); + _sum6 = vfma_lane_f16(_sum6, _pA, _pB1, 2); + _sum7 = vfma_lane_f16(_sum7, _pA, _pB1, 3); + _sum8 = vfma_lane_f16(_sum8, _pA, _pB2, 0); + _sum9 = vfma_lane_f16(_sum9, _pA, _pB2, 1); + _suma = vfma_lane_f16(_suma, _pA, _pB2, 2); + _sumb = vfma_lane_f16(_sumb, _pA, _pB2, 3); + + pA += 4; + pB += 12; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 4 * 2, _sum2); + vst1_f16(outptr0 + 4 * 3, _sum3); + vst1_f16(outptr0 + 4 * 4, _sum4); + vst1_f16(outptr0 + 4 * 5, _sum5); + vst1_f16(outptr0 + 4 * 6, _sum6); + vst1_f16(outptr0 + 4 * 7, _sum7); + vst1_f16(outptr0 + 4 * 8, _sum8); + vst1_f16(outptr0 + 4 * 9, _sum9); + vst1_f16(outptr0 + 4 * 10, _suma); + vst1_f16(outptr0 + 4 * 11, _sumb); + outptr0 += 48; + } + if (out_elempack == 1) + { + transpose4x12_ph(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7, _sum8, _sum9, _suma, _sumb); + + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 8, _sum2); + vst1_f16(outptr0 + out_hstep, _sum3); + vst1_f16(outptr0 + out_hstep + 4, _sum4); + vst1_f16(outptr0 + out_hstep + 8, _sum5); + vst1_f16(outptr0 + out_hstep * 2, _sum6); + vst1_f16(outptr0 + out_hstep * 2 + 4, _sum7); + vst1_f16(outptr0 + out_hstep * 2 + 8, _sum8); + vst1_f16(outptr0 + out_hstep * 3, _sum9); + vst1_f16(outptr0 + out_hstep * 3 + 4, _suma); + vst1_f16(outptr0 + out_hstep * 3 + 8, _sumb); + outptr0 += 12; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + vst1_f16(outptr + 4 * 2, _sum2); + vst1_f16(outptr + 4 * 3, _sum3); + vst1_f16(outptr + 4 * 4, _sum4); + vst1_f16(outptr + 4 * 5, _sum5); + vst1_f16(outptr + 4 * 6, _sum6); + vst1_f16(outptr + 4 * 7, _sum7); + vst1_f16(outptr + 4 * 8, _sum8); + vst1_f16(outptr + 4 * 9, _sum9); + vst1_f16(outptr + 4 * 10, _suma); + vst1_f16(outptr + 4 * 11, _sumb); + } + + outptr += 48; + } + for (; jj + 7 < max_jj; jj += 8) + { + float16x4_t _sum0; + float16x4_t _sum1; + float16x4_t _sum2; + float16x4_t _sum3; + float16x4_t _sum4; + float16x4_t _sum5; + float16x4_t _sum6; + float16x4_t _sum7; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + _sum4 = _sum0; + _sum5 = _sum0; + _sum6 = _sum0; + _sum7 = _sum0; + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + _sum2 = vdup_n_f16(0.f); + _sum3 = vdup_n_f16(0.f); + _sum4 = vdup_n_f16(0.f); + _sum5 = vdup_n_f16(0.f); + _sum6 = vdup_n_f16(0.f); + _sum7 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4 * 1); + _sum2 = vld1_f16(outptr + 4 * 2); + _sum3 = vld1_f16(outptr + 4 * 3); + _sum4 = vld1_f16(outptr + 4 * 4); + _sum5 = vld1_f16(outptr + 4 * 5); + _sum6 = vld1_f16(outptr + 4 * 6); + _sum7 = vld1_f16(outptr + 4 * 7); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + + _sum0 = vfma_lane_f16(_sum0, _pA, _pB0, 0); + _sum1 = vfma_lane_f16(_sum1, _pA, _pB0, 1); + _sum2 = vfma_lane_f16(_sum2, _pA, _pB0, 2); + _sum3 = vfma_lane_f16(_sum3, _pA, _pB0, 3); + _sum4 = vfma_lane_f16(_sum4, _pA, _pB1, 0); + _sum5 = vfma_lane_f16(_sum5, _pA, _pB1, 1); + _sum6 = vfma_lane_f16(_sum6, _pA, _pB1, 2); + _sum7 = vfma_lane_f16(_sum7, _pA, _pB1, 3); + + pA += 4; + pB += 8; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 4 * 2, _sum2); + vst1_f16(outptr0 + 4 * 3, _sum3); + vst1_f16(outptr0 + 4 * 4, _sum4); + vst1_f16(outptr0 + 4 * 5, _sum5); + vst1_f16(outptr0 + 4 * 6, _sum6); + vst1_f16(outptr0 + 4 * 7, _sum7); + outptr0 += 32; + } + if (out_elempack == 1) + { + transpose4x8_ph(_sum0, _sum1, _sum2, _sum3, _sum4, _sum5, _sum6, _sum7); + + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + out_hstep, _sum2); + vst1_f16(outptr0 + out_hstep + 4, _sum3); + vst1_f16(outptr0 + out_hstep * 2, _sum4); + vst1_f16(outptr0 + out_hstep * 2 + 4, _sum5); + vst1_f16(outptr0 + out_hstep * 3, _sum6); + vst1_f16(outptr0 + out_hstep * 3 + 4, _sum7); + outptr0 += 8; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + vst1_f16(outptr + 4 * 2, _sum2); + vst1_f16(outptr + 4 * 3, _sum3); + vst1_f16(outptr + 4 * 4, _sum4); + vst1_f16(outptr + 4 * 5, _sum5); + vst1_f16(outptr + 4 * 6, _sum6); + vst1_f16(outptr + 4 * 7, _sum7); + } + + outptr += 32; + } + for (; jj + 3 < max_jj; jj += 4) + { + float16x4_t _sum0; + float16x4_t _sum1; + float16x4_t _sum2; + float16x4_t _sum3; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + _sum1 = _sum0; + _sum2 = _sum0; + _sum3 = _sum0; + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + _sum2 = vdup_n_f16(0.f); + _sum3 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4); + _sum2 = vld1_f16(outptr + 4 * 2); + _sum3 = vld1_f16(outptr + 4 * 3); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + float16x4_t _pB = vld1_f16(pB); + + _sum0 = vfma_lane_f16(_sum0, _pA, _pB, 0); + _sum1 = vfma_lane_f16(_sum1, _pA, _pB, 1); + _sum2 = vfma_lane_f16(_sum2, _pA, _pB, 2); + _sum3 = vfma_lane_f16(_sum3, _pA, _pB, 3); + + pA += 4; + pB += 4; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 4 * 2, _sum2); + vst1_f16(outptr0 + 4 * 3, _sum3); + outptr0 += 16; + } + if (out_elempack == 1) + { + transpose4x4_ph(_sum0, _sum1, _sum2, _sum3); + + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + out_hstep, _sum1); + vst1_f16(outptr0 + out_hstep * 2, _sum2); + vst1_f16(outptr0 + out_hstep * 3, _sum3); + outptr0 += 4; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + vst1_f16(outptr + 4 * 2, _sum2); + vst1_f16(outptr + 4 * 3, _sum3); + } + + outptr += 16; + } + for (; jj + 1 < max_jj; jj += 2) + { + float16x4_t _sum0; + float16x4_t _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + _sum1 = _sum0; + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + + _sum0 = vfma_n_f16(_sum0, _pA, pB[0]); + _sum1 = vfma_n_f16(_sum1, _pA, pB[1]); + + pA += 4; + pB += 2; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + outptr0 += 8; + } + if (out_elempack == 1) + { + __fp16 sum0[4]; + __fp16 sum1[4]; + vst1_f16(sum0, _sum0); + vst1_f16(sum1, _sum1); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0[1] = sum1[0]; + outptr0[out_hstep + 1] = sum1[1]; + outptr0[out_hstep * 2 + 1] = sum1[2]; + outptr0[out_hstep * 3 + 1] = sum1[3]; + outptr0 += 2; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + } + + outptr += 8; + } + for (; jj < max_jj; jj += 1) + { + float16x4_t _sum0; + + if (k == 0) + { + if (pC) + { + _sum0 = vld1_f16(pC); + } + else + { + _sum0 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pA = vld1_f16(pA); + float16x4_t _pB = vdup_n_f16(pB[0]); + + _sum0 = vfma_f16(_sum0, _pA, _pB); + + pA += 4; + pB += 1; + } + + if (k_end) + { + if (out_elempack == 4) + { + vst1_f16(outptr0, _sum0); + outptr0 += 4; + } + if (out_elempack == 1) + { + __fp16 sum0[4]; + vst1_f16(sum0, _sum0); + + outptr0[0] = sum0[0]; + outptr0[out_hstep] = sum0[1]; + outptr0[out_hstep * 2] = sum0[2]; + outptr0[out_hstep * 3] = sum0[3]; + outptr0++; + } + } + else + { + vst1_f16(outptr, _sum0); + } + + outptr += 4; + } + + pAT += max_kk * 4; + } + for (; ii + 1 < max_ii; ii += 2) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j; + + const __fp16* pB = pBT; + + if (pC) + { + pC = (const __fp16*)CT_tile + i + ii; + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + float16x4_t _sum00; + float16x4_t _sum01; + float16x4_t _sum02; + float16x4_t _sum10; + float16x4_t _sum11; + float16x4_t _sum12; + + if (k == 0) + { + if (pC) + { + _sum00 = vdup_n_f16(pC[0]); + _sum01 = vdup_n_f16(pC[0]); + _sum02 = vdup_n_f16(pC[0]); + _sum10 = vdup_n_f16(pC[1]); + _sum11 = vdup_n_f16(pC[1]); + _sum12 = vdup_n_f16(pC[1]); + } + else + { + _sum00 = vdup_n_f16(0.f); + _sum01 = vdup_n_f16(0.f); + _sum02 = vdup_n_f16(0.f); + _sum10 = vdup_n_f16(0.f); + _sum11 = vdup_n_f16(0.f); + _sum12 = vdup_n_f16(0.f); + } + } + else + { + float16x4x2_t _tmp01 = vld2_f16(outptr); + float16x4x2_t _tmp23 = vld2_f16(outptr + 8); + float16x4x2_t _tmp45 = vld2_f16(outptr + 16); + _sum00 = _tmp01.val[0]; + _sum01 = _tmp23.val[0]; + _sum02 = _tmp45.val[0]; + _sum10 = _tmp01.val[1]; + _sum11 = _tmp23.val[1]; + _sum12 = _tmp45.val[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + float16x4_t _pB2 = vld1_f16(pB + 8); + + float16x4_t _pA0 = vld1_dup_f16(pA); + float16x4_t _pA1 = vld1_dup_f16(pA + 1); + + _sum00 = vfma_f16(_sum00, _pB0, _pA0); + _sum01 = vfma_f16(_sum01, _pB1, _pA0); + _sum02 = vfma_f16(_sum02, _pB2, _pA0); + _sum10 = vfma_f16(_sum10, _pB0, _pA1); + _sum11 = vfma_f16(_sum11, _pB1, _pA1); + _sum12 = vfma_f16(_sum12, _pB2, _pA1); + + pA += 2; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum00); + vst1_f16(outptr0 + 4, _sum01); + vst1_f16(outptr0 + 8, _sum02); + vst1_f16(outptr0 + out_hstep, _sum10); + vst1_f16(outptr0 + out_hstep + 4, _sum11); + vst1_f16(outptr0 + out_hstep + 8, _sum12); + outptr0 += 12; + } + } + else + { + float16x4x2_t _tmp01; + _tmp01.val[0] = _sum00; + _tmp01.val[1] = _sum10; + float16x4x2_t _tmp23; + _tmp23.val[0] = _sum01; + _tmp23.val[1] = _sum11; + float16x4x2_t _tmp45; + _tmp45.val[0] = _sum02; + _tmp45.val[1] = _sum12; + vst2_f16(outptr, _tmp01); + vst2_f16(outptr + 8, _tmp23); + vst2_f16(outptr + 16, _tmp45); + } + + outptr += 24; + } + for (; jj + 7 < max_jj; jj += 8) + { + float16x4_t _sum00; + float16x4_t _sum01; + float16x4_t _sum10; + float16x4_t _sum11; + + if (k == 0) + { + if (pC) + { + _sum00 = vdup_n_f16(pC[0]); + _sum01 = vdup_n_f16(pC[0]); + _sum10 = vdup_n_f16(pC[1]); + _sum11 = vdup_n_f16(pC[1]); + } + else + { + _sum00 = vdup_n_f16(0.f); + _sum01 = vdup_n_f16(0.f); + _sum10 = vdup_n_f16(0.f); + _sum11 = vdup_n_f16(0.f); + } + } + else + { + float16x4x2_t _tmp01 = vld2_f16(outptr); + float16x4x2_t _tmp23 = vld2_f16(outptr + 8); + _sum00 = _tmp01.val[0]; + _sum01 = _tmp23.val[0]; + _sum10 = _tmp01.val[1]; + _sum11 = _tmp23.val[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + + float16x4_t _pA0 = vld1_dup_f16(pA); + float16x4_t _pA1 = vld1_dup_f16(pA + 1); + + _sum00 = vfma_f16(_sum00, _pB0, _pA0); + _sum01 = vfma_f16(_sum01, _pB1, _pA0); + _sum10 = vfma_f16(_sum10, _pB0, _pA1); + _sum11 = vfma_f16(_sum11, _pB1, _pA1); + + pA += 2; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum00); + vst1_f16(outptr0 + 4, _sum01); + vst1_f16(outptr0 + out_hstep, _sum10); + vst1_f16(outptr0 + out_hstep + 4, _sum11); + outptr0 += 8; + } + } + else + { + float16x4x2_t _tmp01; + _tmp01.val[0] = _sum00; + _tmp01.val[1] = _sum10; + float16x4x2_t _tmp23; + _tmp23.val[0] = _sum01; + _tmp23.val[1] = _sum11; + vst2_f16(outptr, _tmp01); + vst2_f16(outptr + 8, _tmp23); + } + + outptr += 16; + } + for (; jj + 3 < max_jj; jj += 4) + { + float16x4_t _sum0; + float16x4_t _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = vdup_n_f16(pC[0]); + _sum1 = vdup_n_f16(pC[1]); + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + } + } + else + { + float16x4x2_t _tmp01 = vld2_f16(outptr); + _sum0 = _tmp01.val[0]; + _sum1 = _tmp01.val[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB = vld1_f16(pB); + + _sum0 = vfma_n_f16(_sum0, _pB, pA[0]); + _sum1 = vfma_n_f16(_sum1, _pB, pA[1]); + + pA += 2; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, (_sum0)); + vst1_f16(outptr0 + out_hstep, (_sum1)); + outptr0 += 4; + } + } + else + { + float16x4x2_t _tmp01; + _tmp01.val[0] = _sum0; + _tmp01.val[1] = _sum1; + vst2_f16(outptr, _tmp01); + } + + outptr += 8; + } + for (; jj + 1 < max_jj; jj += 2) + { + __fp16 sum00; + __fp16 sum01; + __fp16 sum10; + __fp16 sum11; + + if (k == 0) + { + if (pC) + { + sum00 = pC[0]; + sum01 = pC[1]; + sum10 = pC[0]; + sum11 = pC[1]; + } + else + { + sum00 = 0.f; + sum01 = 0.f; + sum10 = 0.f; + sum11 = 0.f; + } + } + else + { + sum00 = outptr[0]; + sum01 = outptr[1]; + sum10 = outptr[2]; + sum11 = outptr[3]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum00 += pA[0] * pB[0]; + sum01 += pA[1] * pB[0]; + sum10 += pA[0] * pB[1]; + sum11 += pA[1] * pB[1]; + + pA += 2; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum00; + outptr0[1] = sum10; + outptr0[out_hstep] = sum01; + outptr0[out_hstep + 1] = sum11; + outptr0 += 2; + } + } + else + { + outptr[0] = sum00; + outptr[1] = sum01; + outptr[2] = sum10; + outptr[3] = sum11; + } + + outptr += 4; + } + for (; jj < max_jj; jj += 1) + { + __fp16 sum0; + __fp16 sum1; + + if (k == 0) + { + if (pC) + { + sum0 = pC[0]; + sum1 = pC[1]; + } + else + { + sum0 = 0.f; + sum1 = 0.f; + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[1] * pB[0]; + pA += 2; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[out_hstep] = sum1; + outptr0++; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + + pAT += max_kk * 2; + } + for (; ii < max_ii; ii += 1) + { + __fp16* outptr0 = (__fp16*)top_blob + (i + ii) * out_hstep + j; + + const __fp16* pB = pBT; + + if (pC) + { + pC = (const __fp16*)CT_tile + i + ii; + } + + int jj = 0; + for (; jj + 11 < max_jj; jj += 12) + { + float16x4_t _sum0; + float16x4_t _sum1; + float16x4_t _sum2; + + if (k == 0) + { + if (pC) + { + _sum0 = vdup_n_f16(pC[0]); + _sum1 = vdup_n_f16(pC[0]); + _sum2 = vdup_n_f16(pC[0]); + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + _sum2 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4); + _sum2 = vld1_f16(outptr + 8); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + float16x4_t _pB2 = vld1_f16(pB + 8); + + float16x4_t _pA0 = vdup_n_f16(pA[0]); + + _sum0 = vfma_f16(_sum0, _pA0, _pB0); + _sum1 = vfma_f16(_sum1, _pA0, _pB1); + _sum2 = vfma_f16(_sum2, _pA0, _pB2); + + pA += 1; + pB += 12; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + vst1_f16(outptr0 + 8, _sum2); + outptr0 += 12; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + vst1_f16(outptr + 8, _sum2); + } + + outptr += 12; + } + for (; jj + 7 < max_jj; jj += 8) + { + float16x4_t _sum0; + float16x4_t _sum1; + + if (k == 0) + { + if (pC) + { + _sum0 = vdup_n_f16(pC[0]); + _sum1 = vdup_n_f16(pC[0]); + } + else + { + _sum0 = vdup_n_f16(0.f); + _sum1 = vdup_n_f16(0.f); + } + } + else + { + _sum0 = vld1_f16(outptr); + _sum1 = vld1_f16(outptr + 4); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB0 = vld1_f16(pB); + float16x4_t _pB1 = vld1_f16(pB + 4); + + float16x4_t _pA0 = vdup_n_f16(pA[0]); + + _sum0 = vfma_f16(_sum0, _pA0, _pB0); + _sum1 = vfma_f16(_sum1, _pA0, _pB1); + + pA += 1; + pB += 8; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum0); + vst1_f16(outptr0 + 4, _sum1); + outptr0 += 8; + } + } + else + { + vst1_f16(outptr, _sum0); + vst1_f16(outptr + 4, _sum1); + } + + outptr += 8; + } + for (; jj + 3 < max_jj; jj += 4) + { + float16x4_t _sum; + + if (k == 0) + { + if (pC) + { + _sum = vdup_n_f16(pC[0]); + } + else + { + _sum = vdup_n_f16(0.f); + } + } + else + { + _sum = vld1_f16(outptr); + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + float16x4_t _pB = vld1_f16(pB); + float16x4_t _pA = vdup_n_f16(pA[0]); + + _sum = vfma_f16(_sum, _pA, _pB); + + pA += 1; + pB += 4; + } + + if (k_end) + { + // if (out_elempack == 1) + { + vst1_f16(outptr0, _sum); + outptr0 += 4; + } + } + else + { + vst1_f16(outptr, _sum); + } + + outptr += 4; + } + for (; jj + 1 < max_jj; jj += 2) + { + __fp16 sum0; + __fp16 sum1; + + if (k == 0) + { + if (pC) + { + sum0 = pC[0]; + sum1 = pC[0]; + } + else + { + sum0 = 0.f; + sum1 = 0.f; + } + } + else + { + sum0 = outptr[0]; + sum1 = outptr[1]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum0 += pA[0] * pB[0]; + sum1 += pA[0] * pB[1]; + + pA += 1; + pB += 2; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum0; + outptr0[1] = sum1; + outptr0 += 2; + } + } + else + { + outptr[0] = sum0; + outptr[1] = sum1; + } + + outptr += 2; + } + for (; jj < max_jj; jj += 1) + { + __fp16 sum; + + if (k == 0) + { + if (pC) + { + sum = pC[0]; + } + else + { + sum = 0.f; + } + } + else + { + sum = outptr[0]; + } + + const __fp16* pA = pAT; + int kk = 0; + for (; kk < max_kk; kk += 1) + { + sum += pA[0] * pB[0]; + + pA += 1; + pB += 1; + } + + if (k_end) + { + // if (out_elempack == 1) + { + outptr0[0] = sum; + outptr0++; + } + } + else + { + outptr[0] = sum; + } + + outptr += 1; + } + + pAT += max_kk; + } + amx_clr(); +} +#endif // __aarch64__ && NCNN_APPLE_AMX + static void convolution_im2col_gemm_get_optimal_tile_mnk_fp16sa(int M, int N, int K, int& TILE_M, int& TILE_N, int& TILE_K, int nT) { // resolve optimal tile size from cache size @@ -3195,6 +4885,10 @@ static int convolution_im2col_gemm_fp16sa(const Mat& bottom_blob, Mat& top_blob, return -100; } +#if __aarch64__ && NCNN_APPLE_AMX + bool amx_supported = cpu_support_arm_amx(); +#endif + #pragma omp parallel for num_threads(nT) for (int ppj = 0; ppj < nn_M; ppj++) { @@ -3220,7 +4914,17 @@ static int convolution_im2col_gemm_fp16sa(const Mat& bottom_blob, Mat& top_blob, bool k_end = k + TILE_K >= K; - convolution_gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end, opt.use_a53_a55_optimized_kernel); +#if __aarch64__ && NCNN_APPLE_AMX +// #if 0 + if (amx_supported) + { + convolution_gemm_transB_packed_tile_fp16sa_amx(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end); + } + else +#endif + { + convolution_gemm_transB_packed_tile_fp16sa(AT_tile, BT_tile, bias, topT_tile, top_blob, i, max_ii, j, max_jj, k, max_kk, k_end, opt.use_a53_a55_optimized_kernel); + } } } }