Commit d0fed69: fix bug
RaymondWang0 committed Feb 25, 2024 (1 parent: cbde34b)
Showing 1 changed file with 130 additions and 0 deletions.
kernels/neon/matmul_neon_int8_int4.cc (130 additions, 0 deletions)

@@ -232,6 +232,104 @@ static void* matmul_int8_int4_no_offset_over_column(void* args) {
    return NULL;
}

// GEMV kernel: each thread computes output elements C[j] for j in
// [start_j, end_j), unrolling four 32-element quantization blocks
// (128 weights) per loop iteration.
inline static void* gemv_int8_int4_no_offset_over_column_unroll128(void* args) {
    struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args;
    const struct matmul_params* params = mat_args->params;
    int n = params->C.column, m = params->C.row, k = params->A.column, block_size = params->block_size;
    const int num_block = k / block_size;
    assert(m == 1);

    for (int j = mat_args->start_j; j < mat_args->end_j; j++) {
        float32x4_t sumv0 = vdupq_n_f32(0.0f);
        float32x4_t sumv1 = vdupq_n_f32(0.0f);
        float32x4_t sumv2 = vdupq_n_f32(0.0f);
        float32x4_t sumv3 = vdupq_n_f32(0.0f);
        const unsigned char* w_start = &params->B.int4_data_ptr[j * k / 2];
        const signed char* a_start = &params->A.int8_data_ptr[0];
        float* s_a = &params->A_scales[0];
        float* s_w = &params->scales[j * k / 32];

        const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);
        const int8x16_t offsets = vdupq_n_s8(8);  // zero point of the 4-bit weights
        for (int q = 0; q < num_block; q += 4) {
            int32x4_t int_sum0 = vdupq_n_s32(0);
            int32x4_t int_sum1 = vdupq_n_s32(0);
            int32x4_t int_sum2 = vdupq_n_s32(0);
            int32x4_t int_sum3 = vdupq_n_s32(0);
            // per-block dequantization scales: activation scale times weight scale
            float s_0 = *s_a++ * *s_w++;
            float s_1 = *s_a++ * *s_w++;
            float s_2 = *s_a++ * *s_w++;
            float s_3 = *s_a++ * *s_w++;

            const uint8x16_t w0 = vld1q_u8(w_start);       // 32 4-bit weights
            const uint8x16_t w1 = vld1q_u8(w_start + 16);  // 32 4-bit weights
            const uint8x16_t w2 = vld1q_u8(w_start + 32);  // 32 4-bit weights
            const uint8x16_t w3 = vld1q_u8(w_start + 48);  // 32 4-bit weights
            w_start += 64;

            // Quantization method QM_ARM: convert 64 4-bit weights into 64 8-bit weights.
            // Sequential order of the original weights: (0, 1), (2, 3), (4, 5), (6, 7), ... per 128 bits.
            // Expected packed layout of inB: (0, 16), (1, 17), (2, 18), (3, 19), ...
            // low nibbles:  weights 0, 1, 2, 3, ...
            // high nibbles: weights 16, 17, 18, 19, ...
            // (a scalar reference sketch of this unpacking follows this function)
            int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit));
            int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4));
            int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit));
            int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4));
            int8x16_t w2_low = vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit));
            int8x16_t w2_high = vreinterpretq_s8_u8(vshrq_n_u8(w2, 4));
            int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit));
            int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4));

            // subtract the zero-point offset of 8 to recenter the weights around zero
            w0_low = vsubq_s8(w0_low, offsets);
            w0_high = vsubq_s8(w0_high, offsets);
            w1_low = vsubq_s8(w1_low, offsets);
            w1_high = vsubq_s8(w1_high, offsets);
            w2_low = vsubq_s8(w2_low, offsets);
            w2_high = vsubq_s8(w2_high, offsets);
            w3_low = vsubq_s8(w3_low, offsets);
            w3_high = vsubq_s8(w3_high, offsets);

            // load 128 8-bit activations
            const int8x16_t a0 = vld1q_s8(a_start);
            const int8x16_t a1 = vld1q_s8(a_start + 16);
            const int8x16_t a2 = vld1q_s8(a_start + 32);
            const int8x16_t a3 = vld1q_s8(a_start + 48);
            const int8x16_t a4 = vld1q_s8(a_start + 64);
            const int8x16_t a5 = vld1q_s8(a_start + 80);
            const int8x16_t a6 = vld1q_s8(a_start + 96);
            const int8x16_t a7 = vld1q_s8(a_start + 112);
            a_start += 128;

            // int8 dot products accumulated into int32x4_t lanes
            int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0);
            int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1);
            int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2);
            int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3);
            int_sum2 = my_vdotq_s32(int_sum2, w2_low, a4);
            int_sum2 = my_vdotq_s32(int_sum2, w2_high, a5);
            int_sum3 = my_vdotq_s32(int_sum3, w3_low, a6);
            int_sum3 = my_vdotq_s32(int_sum3, w3_high, a7);

            // scale each block's integer sum and accumulate in fp32
            sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);
            sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1);
            sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2);
            sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3);
        }
        // horizontal reduction of the four fp32 accumulators, plus optional bias
        if (params->bias.data_ptr) {
            params->C.data_ptr[j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) +
                                    vaddvq_f32(sumv2) + vaddvq_f32(sumv3);
        } else {
            params->C.data_ptr[j] =
                vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3);
        }
    }

    return NULL;
}
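
// A scalar reference sketch (not part of this commit) of what the unrolled
// NEON kernel above computes for one output column, assuming the QM_ARM
// packing described in the comments: byte b of each 16-byte group holds
// weight b in its low nibble and weight b + 16 in its high nibble, with one
// scale pair per 32-element block. The function name and parameters here are
// hypothetical, for illustration only.
static float gemv_int8_int4_reference_column(const unsigned char* w,  // k/2 packed weight bytes
                                             const signed char* a,    // k int8 activations
                                             const float* s_a,        // k/32 activation scales
                                             const float* s_w,        // k/32 weight scales
                                             int k) {
    float acc = 0.0f;
    for (int q = 0; q < k / 32; q++) {            // one scale pair per 32-weight block
        int int_sum = 0;
        const unsigned char* wb = w + q * 16;     // 16 bytes hold 32 4-bit weights
        const signed char* ab = a + q * 32;
        for (int b = 0; b < 16; b++) {
            int w_low = (wb[b] & 0x0f) - 8;       // weight b (low nibble), zero point 8
            int w_high = (wb[b] >> 4) - 8;        // weight b + 16 (high nibble)
            int_sum += w_low * ab[b] + w_high * ab[b + 16];
        }
        acc += (float)int_sum * s_a[q] * s_w[q];  // per-block dequantization
    }
    return acc;
}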

inline static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args) {
    struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args;
    const struct matmul_params* params = mat_args->params;

@@ -594,4 +692,36 @@ void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_
    // for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL);
    pool_wait(pool);
}

void MatmulOperator::gemv_accelerator_int8_int4_fast_no_offset(struct matmul_params* params) {
    int j;
    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
    const int block_size = params->block_size;
    assert(params->block_size % 32 == 0);  // block size must be a multiple of 32
    assert(A->row == C->row);              // A and C must have the same number of rows
    assert(A->row == 1);                   // GEMV path: a single activation row

    // quantize the fp32 activation row to int8, one scale per block
    quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, block_size);

    const int num_thread = params->opt_params.num_thread;
    struct a8w4_thread_args threads_args[num_thread];
    assert(params->block_size == 32);  // only block size 32 is supported for now

    static void* pool = pool_start(gemv_int8_int4_no_offset_over_column_unroll128, num_thread);

    // Partition the output columns evenly across threads; the last thread also
    // takes the remainder, e.g. 4097 columns on 4 threads -> [0, 1024),
    // [1024, 2048), [2048, 3072), [3072, 4097).
    for (j = 0; j < num_thread; j++) {
        threads_args[j].start_j = j * (params->C.column / num_thread);
        if (j == num_thread - 1) {
            threads_args[j].end_j = params->C.column;
        } else {
            threads_args[j].end_j = (j + 1) * (params->C.column / num_thread);
        }
        threads_args[j].params = params;
        pool_enqueue(pool, &threads_args[j], '\0');
    }
    // Wait for all enqueued jobs to finish
    pool_wait(pool);
}
} // namespace matmul
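
A minimal call-site sketch (assumed setup, not part of this commit): the caller
fills a matmul_params with a 1 x k fp32 activation row in A, QM_ARM-packed
4-bit weights and per-block scales for B, and a 1 x n output buffer in C, then
invokes the new entry point.

    // hypothetical call site; matmul_params field setup elided
    matmul::MatmulOperator op;
    op.gemv_accelerator_int8_int4_fast_no_offset(&params);  // asserts A.row == 1 and block_size == 32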
