Skip to content

Commit

Permalink
amx innerproduct simple attempt
Browse files Browse the repository at this point in the history
Signed-off-by: Molly Sophia <[email protected]>
  • Loading branch information
MollySophia committed Aug 28, 2024
1 parent 725846b commit bb58f2e
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 2 deletions.
72 changes: 70 additions & 2 deletions src/layer/arm/amx_usability.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,80 @@
// Each macro below issues one AMX operation identified by its opcode number;
// the *_GPR forms take the 64-bit operand in a general-purpose register.
#define AMX_MAC16(gpr) AMX_OP_GPR(14, gpr)
#define AMX_FMA16(gpr) AMX_OP_GPR(15, gpr)
#define AMX_FMS16(gpr) AMX_OP_GPR(16, gpr)
// Opcode 17 takes a 5-bit immediate instead of a register: 0 = set, 1 = clr.
#define AMX_SET() AMX_NOP_OP_IMM5(17, 0)
#define AMX_CLR() AMX_NOP_OP_IMM5(17, 1)
#define AMX_VECINT(gpr) AMX_OP_GPR(18, gpr)
#define AMX_VECFP(gpr) AMX_OP_GPR(19, gpr)
#define AMX_MATINT(gpr) AMX_OP_GPR(20, gpr)
#define AMX_MATFP(gpr) AMX_OP_GPR(21, gpr)
#define AMX_GENLUT(gpr) AMX_OP_GPR(22, gpr)
// Builds a 64-bit operand from a pointer: the value (row + flags * 64) is
// shifted left by 56 bits and added to the address of *ptr, so the row index
// and flag bits end up in the top byte of the operand.
#define PTR_ROW_FLAGS(ptr, row, flags) (((uint64_t)&*(ptr)) + (((uint64_t)((row) + (flags) * 64)) << 56))
// Issues the AMX "set" operation (opcode 17, immediate 0), i.e. the same
// encoding as the AMX_SET() macro. Expected to be paired with amx_clr().
//
// Declared static inline because this definition lives in a header: without
// it, including the header from a second translation unit produces
// duplicate-symbol link errors.
static inline void amx_set()
{
    AMX_NOP_OP_IMM5(17, 0);
}

// Issues the AMX "clr" operation (opcode 17, immediate 1), i.e. the same
// encoding as the AMX_CLR() macro. Counterpart of amx_set().
//
// Declared static inline because this definition lives in a header: without
// it, including the header from a second translation unit produces
// duplicate-symbol link errors.
static inline void amx_clr()
{
    AMX_NOP_OP_IMM5(17, 1);
}

// Issues AMX_LDX with an operand built from a source address and an X row.
//
// pair  - when true, operand bit 62 is set (pair mode; presumably loads two
//         consecutive rows - confirm against the AMX instruction reference).
// x_row - target X row, encoded starting at bit 56. Values >= 8 are rejected
//         and the function returns without issuing any instruction.
// ptr   - source address; occupies the low bits of the operand.
//
// Declared static inline so the header can be included from several
// translation units without duplicate-symbol link errors.
static inline void amx_ldx(bool pair, unsigned int x_row, const void* ptr)
{
    if (x_row >= 8)
        return;

    uint64_t operand = (uint64_t)ptr + ((uint64_t)x_row << 56);
    if (pair)
        operand |= 1ULL << 62;

    AMX_LDX(operand);
}

// Issues AMX_LDY with an operand built from a source address and a Y row.
//
// pair  - when true, operand bit 62 is set (pair mode; presumably loads two
//         consecutive rows - confirm against the AMX instruction reference).
// y_row - target Y row, encoded starting at bit 56. Values >= 8 are rejected
//         and the function returns without issuing any instruction.
// ptr   - source address; occupies the low bits of the operand.
//
// Declared static inline so the header can be included from several
// translation units without duplicate-symbol link errors.
static inline void amx_ldy(bool pair, unsigned int y_row, const void* ptr)
{
    if (y_row >= 8)
        return;

    uint64_t operand = (uint64_t)ptr + ((uint64_t)y_row << 56);
    if (pair)
        operand |= 1ULL << 62;

    AMX_LDY(operand);
}

// Issues AMX_LDZ with an operand built from a source address and a Z row.
//
// pair  - when true, operand bit 62 is set (pair mode; presumably loads two
//         consecutive rows - confirm against the AMX instruction reference).
// z_row - target Z row, encoded starting at bit 56. Values >= 64 are rejected
//         and the function returns without issuing any instruction.
// ptr   - source address; occupies the low bits of the operand.
//
// Declared static inline so the header can be included from several
// translation units without duplicate-symbol link errors.
static inline void amx_ldz(bool pair, unsigned int z_row, const void* ptr)
{
    if (z_row >= 64)
        return;

    uint64_t operand = (uint64_t)ptr + ((uint64_t)z_row << 56);
    if (pair)
        operand |= 1ULL << 62;

    AMX_LDZ(operand);
}

// Issues AMX_STZ with an operand built from a destination address and a Z row.
//
// pair  - when true, operand bit 62 is set (pair mode; presumably stores two
//         consecutive rows - confirm against the AMX instruction reference).
// z_row - source Z row, encoded starting at bit 56. Values >= 64 are rejected
//         and the function returns without issuing any instruction.
// ptr   - destination address; occupies the low bits of the operand.
//         NOTE(review): the parameter is const-qualified although the store
//         writes to this memory. The type is kept for source compatibility
//         with existing callers; callers must pass writable storage.
//
// Declared static inline so the header can be included from several
// translation units without duplicate-symbol link errors.
static inline void amx_stz(bool pair, unsigned int z_row, const void* ptr)
{
    if (z_row >= 64)
        return;

    uint64_t operand = (uint64_t)ptr + ((uint64_t)z_row << 56);
    if (pair)
        operand |= 1ULL << 62;

    AMX_STZ(operand);
}

// Issues AMX_FMA32 with an operand assembled from the arguments.
//
// vector   - when true, operand bit 63 is set (vector mode).
// x_offset - offset into X, 9 bits, encoded at bits 10-18.
// y_offset - offset into Y, 9 bits, encoded at bits 0-8.
// z_row    - destination Z row, 6 bits, encoded at bits 20-25.
//
// The Z row used to be accepted but never encoded, so every call accumulated
// into Z row 0 regardless of the argument. It is now placed in the operand.
// The bit position follows the publicly documented reverse-engineered fma32
// operand layout - confirm against that reference.
//
// Declared static inline so the header can be included from several
// translation units without duplicate-symbol link errors.
static inline void amx_fma32(bool vector, unsigned int x_offset, unsigned int y_offset, int z_row)
{
    uint64_t operand = 0;
    if (vector)
        operand |= 1ULL << 63;

    operand |= (uint64_t)y_offset & 0x1FF;
    operand |= ((uint64_t)x_offset & 0x1FF) << 10;
    // Mask after the cast so a negative z_row cannot spill into other fields.
    operand |= ((uint64_t)z_row & 0x3F) << 20;

    AMX_FMA32(operand);
}

#endif // AMX_USABILITY_H
72 changes: 72 additions & 0 deletions src/layer/arm/innerproduct_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#include "arm_activation.h"
#include "arm_usability.h"

#if __aarch64__ && NCNN_APPLE_AMX
#include "amx_usability.h"
#endif

#include "cpu.h"

namespace ncnn {
Expand Down Expand Up @@ -142,6 +146,12 @@ int InnerProduct_arm::destroy_pipeline(const Option& opt)

int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
// #if __aarch64__ && NCNN_APPLE_AMX
// AMX_SET();
// uint8_t arr[256];
// AMX_LDX(arr);
// AMX_CLR();
// #endif
#if NCNN_INT8
if (opt.use_int8_inference && int8_scale_term)
{
Expand Down Expand Up @@ -369,6 +379,68 @@ int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Optio
}

int i = 0;
#if __aarch64__ && NCNN_APPLE_AMX
// Apple AMX path: accumulates products of m and kptr in AMX Z rows, first 32
// floats per iteration, then 16 floats per iteration, and adds the reduced
// partial sums to `sum`. Elements not consumed here (i < num_input) are left
// for the loops that follow this #if block.
amx_set();

// Staging buffer: loaded into Z to start from zeroed accumulators, and later
// receives the partial sums stored back from Z.
// NOTE(review): the pair-mode load/store moves two rows at once; whether the
// hardware requires this stack buffer to be 128-byte aligned should be confirmed.
float _sums[32] = {0};
amx_ldz(true, 0, _sums);
for (; i + 31 < num_input; i += 32)
{
// Pair-mode loads into X/Y starting at row 0, then two vector-mode fma32
// operations addressing byte offsets 0 and 64 of X and Y.
amx_ldx(true, 0, m);
amx_ldy(true, 0, kptr);
amx_fma32(true, 0, 0, 0);
amx_fma32(true, 64, 64, 1);

m += 32;
kptr += 32;
}
// Store the pair of Z rows starting at row 0 back into _sums (32 floats).
amx_stz(true, 0, _sums);
#if __ARM_NEON
{
// Reduce the 32 partial sums: eight 4-lane adds, then a horizontal add.
float32x4_t _sum = vdupq_n_f32(0.f);
for (int q = 0; q < 8; q++)
{
_sum = vaddq_f32(_sum, vld1q_f32(_sums + q * 4));
}
sum += vaddvq_f32(_sum);
// Only the first 16 floats are re-zeroed here, whereas the scalar branch
// below clears all 32. NOTE(review): presumably sufficient because the next
// amx_ldz is issued in non-pair mode - confirm.
memset(_sums, 0, 16 * sizeof(float));
}
#else
// Scalar reduction of the 32 partial sums; clears the buffer as it goes.
for (int q = 0; q < 32; q++)
{
sum += _sums[q];
_sums[q] = 0;
}
#endif
// Second stage: reload Z row 0 from the zeroed buffer and process the
// remaining input 16 floats at a time using non-pair loads.
amx_ldz(false, 0, _sums);
for (; i + 15 < num_input; i += 16)
{
amx_ldx(false, 0, m);
amx_ldy(false, 0, kptr);
amx_fma32(true, 0, 0, 0);

m += 16;
kptr += 16;
}
// Store Z row 0 back into the first 16 floats of _sums.
amx_stz(false, 0, _sums);
#if __ARM_NEON
{
// Reduce the 16 partial sums: four 4-lane adds, then a horizontal add.
float32x4_t _sum = vdupq_n_f32(0.f);
for (int q = 0; q < 4; q++)
{
_sum = vaddq_f32(_sum, vld1q_f32(_sums + q * 4));
}
sum += vaddvq_f32(_sum);
}
#else
// Scalar reduction of the 16 partial sums.
for (int q = 0; q < 16; q++)
{
sum += _sums[q];
}
#endif
// Matches the amx_set() at the top of this block.
amx_clr();
#endif

#if __ARM_NEON
float32x4_t _sum = vdupq_n_f32(0.f);
for (; i + 3 < num_input; i += 4)
Expand Down

0 comments on commit bb58f2e

Please sign in to comment.