From 81a0aace7b5cd313e1d3218b5f4c17fa2e8d11c4 Mon Sep 17 00:00:00 2001 From: wangshuai09 <391746016@qq.com> Date: Tue, 25 Jun 2024 08:50:54 +0000 Subject: [PATCH] update softmax --- ggml-cann/aclnn_ops.cpp | 234 +++++++++++++++++++++++++------------ ggml.c | 164 +++++++++++++++++--------- ggml.h | 1 - llama.cpp | 4 +- tests/test-backend-ops.cpp | 53 +++++---- 5 files changed, 301 insertions(+), 155 deletions(-) diff --git a/ggml-cann/aclnn_ops.cpp b/ggml-cann/aclnn_ops.cpp index e8c52a621db4a..7025842e4bae3 100644 --- a/ggml-cann/aclnn_ops.cpp +++ b/ggml-cann/aclnn_ops.cpp @@ -1271,22 +1271,36 @@ void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream())); } -void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, float scale, - ggml_tensor* bind_tensor) { +static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, + float scale, aclTensor* acl_dst, bool inplace, + ggml_tensor* bind_tensor) { aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); uint64_t workspaceSize = 0; aclOpExecutor* executor; void* workspaceAddr = nullptr; - ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale, - &workspaceSize, &executor)); - if (workspaceSize > 0) { - workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize); + if (inplace) { + ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale, + &workspaceSize, &executor)); + if (workspaceSize > 0) { + workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize); + } + + ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor, + ctx.stream())); } + else { + ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst, + &workspaceSize, &executor)); + if (workspaceSize > 0) { + workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize); + } - ACL_CHECK( - aclnnInplaceMuls(workspaceAddr, workspaceSize, executor, ctx.stream())); + ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor, + ctx.stream())); + } + ACL_CHECK(aclDestroyScalar(acl_scale)); } @@ -1384,7 +1398,8 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, // freq float freq_param = -logf(max_period) / half; - aclnn_muls(ctx, tmp_arange_tensor, freq_param, dst); + bool inplace = true; + aclnn_muls(ctx, tmp_arange_tensor, freq_param, nullptr, inplace, dst); aclnn_exp(ctx, tmp_arange_tensor, dst); // permute: src [0,1,2,3]->[0,1,3,2] @@ -1479,8 +1494,9 @@ void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, ACL_CHECK(aclDestroyScalar(acl_scalar)); } -void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst, - aclTensor* acl_exp, ggml_tensor* bind_tensor) { +static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, + aclTensor* acl_dst, aclTensor* acl_exp, + ggml_tensor* bind_tensor) { // acl_dst = acl_dst^acl_exp uint64_t workspaceSize = 0; @@ -1499,17 +1515,17 @@ void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst, void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_position, aclTensor* acl_dst, const int n_head, - const int64_t src_ne0, const int64_t src_ne1, - const int64_t src_ne2, const int64_t src_ne3, - const size_t src_nb0, float max_bias, ggml_tensor* dst) { - const int64_t ne2_ne3 = src_ne2 * src_ne3; + int64_t* src_ne, const size_t src_nb0, float max_bias, + ggml_tensor* dst) { + GGML_UNUSED(src_ne[1]); + const int64_t ne2_ne3 = src_ne[2] * src_ne[3]; GGML_ASSERT(src_nb0 == sizeof(float)); - 
GGML_ASSERT(n_head == src_ne2); + GGML_ASSERT(n_head == src_ne[2]); - const int n_heads_log2_floor = 1 << (int)floor(log2(n_head)); + const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); // init arange void* tmp_arange_buffer = @@ -1560,7 +1576,7 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, tmp_mk_base1_ne, tmp_mk_base1_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor, dst); - + aclTensor* tmp_mk_base2_tensor = nullptr; if (n_heads_log2_floor < ne2_ne3) { int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; @@ -1584,7 +1600,7 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor, dst); // reshape mk - int64_t tmp_mk_ne[] = {1, 1, src_ne2, src_ne3}; + int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]}; size_t tmp_mk_nb[GGML_MAX_DIMS]; tmp_mk_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -1595,7 +1611,7 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); // acl_position * mk - int64_t tmp_output_ne[] = {src_ne0, 1, src_ne2, src_ne3}; + int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]}; size_t tmp_output_nb[GGML_MAX_DIMS]; tmp_output_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -1660,8 +1676,7 @@ void ggml_cann_alibi(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // call alibi aclTensor* acl_src = create_acl_tensor(src); aclTensor* acl_dst = create_acl_tensor(dst); - aclnn_alibi(ctx, acl_src, tmp_position_tensor, acl_dst, n_head, ne0, ne1, - ne2, ne3, nb0, max_bias, dst); + aclnn_alibi(ctx, acl_src, tmp_position_tensor, acl_dst, n_head, src->ne, nb0, max_bias, dst); ACL_CHECK(aclDestroyTensor(tmp_position_tensor)); ACL_CHECK(aclDestroyTensor(acl_src)); @@ -1696,6 +1711,25 @@ void aclnn_inplace_add(ggml_backend_cann_context& ctx, aclTensor* acl_src, ACL_CHECK(aclDestroyScalar(alpha)); } +static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src, + int64_t dim, aclTensor* acl_dst, + ggml_tensor* bind_tensor) { + + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + void* workspaceAddr = nullptr; + + ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst, + &workspaceSize, &executor)); + + if (workspaceSize > 0) { + workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize); + } + + aclrtStream stream = ctx.stream(); + ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream)); +} + void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src0 = dst->src[0]; ggml_tensor* src1 = dst->src[1]; // mask @@ -1710,78 +1744,124 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { memcpy(&scale, (float*)dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float)); + // input mul scale aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); size_t n_bytes = ggml_nbytes(src0); - void* buffer = ctx.alloc_buffer(dst, n_bytes); - aclTensor* temp_tensor = - create_acl_tensor(buffer, ACL_FLOAT, ggml_type_size(src0->type), - src0->ne, src0->nb, GGML_MAX_DIMS); - - // aclnn 
scale - uint64_t workspaceSize = 0; - aclOpExecutor* executor; - void* workspaceAddr = nullptr; - - aclnnMulsGetWorkspaceSize(acl_src0, acl_scale, temp_tensor, &workspaceSize, - &executor); - if (workspaceSize > 0) { - workspaceAddr = ctx.alloc_buffer(dst, workspaceSize); - } + void* input_mul_scale_buffer = ctx.alloc_buffer(dst, n_bytes); + aclTensor* acl_input_mul_scale_tensor = create_acl_tensor( + input_mul_scale_buffer, + ACL_FLOAT, + ggml_type_size(src0->type), + src0->ne, src0->nb, + GGML_MAX_DIMS); - aclrtStream stream = ctx.stream(); - aclnnMuls(workspaceAddr, workspaceSize, executor, stream); + bool inplace = false; + aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace, + dst); // mask - aclTensor* acl_src1 = nullptr; + aclTensor* acl_src1_fp32_tensor = nullptr; + aclTensor* tmp_permute_tenosr = nullptr; if (src1) { - acl_src1 = create_acl_tensor(src1); - aclnn_inplace_add(ctx, acl_src1, temp_tensor, dst); - } + const bool use_f16 = src1->type == GGML_TYPE_F16; + if (use_f16) { + // cast to fp32 + size_t n_bytes = ggml_nelements(src1) * sizeof(float_t); + size_t src1_fp32_nb[GGML_MAX_DIMS]; + src1_fp32_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1]; + } + void* src1_fp32_buffer = ctx.alloc_buffer(dst, n_bytes); + acl_src1_fp32_tensor = create_acl_tensor(src1_fp32_buffer, + ACL_FLOAT, + sizeof(float), + src1->ne, + src1_fp32_nb, + GGML_MAX_DIMS); + aclTensor* acl_src1 = create_acl_tensor(src1); + aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT, dst); + + ACL_CHECK(aclDestroyTensor(acl_src1)); + + } + else { + acl_src1_fp32_tensor = create_acl_tensor(src1); + } + + // broadcast the mask across rows + if (src1->ne[1] != src0->ne[1]) { + // mask shape: [1,1,a,b] + int64_t tmp_permute_ne[] = {src1->ne[0], src1->ne[2], src1->ne[1], src1->ne[3]}; + size_t tmp_permute_nb[GGML_MAX_DIMS]; + tmp_permute_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1]; + } - aclTensor* temp_output_tensor = nullptr; - aclTensor* acl_src2 = nullptr; - if (max_bias > 0.0f) { - acl_src2 = create_acl_tensor(src2); + void* tmp_permute_buffer = ctx.alloc_buffer(dst, ggml_nbytes(src1)); + tmp_permute_tenosr = create_acl_tensor( + tmp_permute_buffer, ACL_FLOAT, sizeof(float), + tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + int64_t permute_dim[] = {0, 2, 1, 3}; + int64_t num_dims = 4; + aclnn_permute(ctx, acl_src1_fp32_tensor, tmp_permute_tenosr, permute_dim, num_dims, dst); + } + + // alibi const int n_head = src0->ne[2]; - const int64_t src_ne0 = src0->ne[0]; - const int64_t src_ne1 = src0->ne[1]; - const int64_t src_ne2 = src0->ne[2]; - const int64_t src_ne3 = src0->ne[3]; const size_t src_nb0 = src0->nb[0]; - // alibi n_bytes = ggml_nbytes(dst); void* output_buffer = ctx.alloc_buffer(dst, n_bytes); - temp_output_tensor = create_acl_tensor(output_buffer, ACL_FLOAT, - ggml_type_size(dst->type), - dst->ne, dst->nb, GGML_MAX_DIMS); - aclnn_alibi(ctx, temp_tensor, acl_src2, temp_output_tensor, n_head, - src_ne0, src_ne1, src_ne2, src_ne3, src_nb0, max_bias, dst); - - // softmax - ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(temp_output_tensor, 3, acl_dst, - &workspaceSize, &executor)); - - } else { + aclTensor* alibi_output_tensor = create_acl_tensor( + output_buffer, + ACL_FLOAT, + ggml_type_size(dst->type), + dst->ne, dst->nb, + GGML_MAX_DIMS); + if (max_bias <=0.0f) { + // slope = 1.0 + if (tmp_permute_tenosr) 
{ + aclnn_add(ctx, tmp_permute_tenosr, acl_input_mul_scale_tensor, + alibi_output_tensor, dst); + } + else { + aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor, + alibi_output_tensor, dst); + } + + } + else { + // slope != 1.0 + if (tmp_permute_tenosr) { + aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_permute_tenosr, + alibi_output_tensor, n_head, src0->ne, src_nb0, max_bias, + dst); + } + else { + aclnn_alibi(ctx, acl_input_mul_scale_tensor, acl_src1_fp32_tensor, + alibi_output_tensor, n_head, src0->ne, src_nb0, max_bias, + dst); + } + } + // softmax - ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(temp_tensor, 3, acl_dst, - &workspaceSize, &executor)); - } - - if (workspaceSize > 0) { - workspaceAddr = ctx.alloc_buffer(dst, workspaceSize); + aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst, dst); + ACL_CHECK(aclDestroyTensor(alibi_output_tensor)); + } + else { + aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst, dst); } - ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream)); - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(acl_src2)); + ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor)); ACL_CHECK(aclDestroyTensor(acl_dst)); ACL_CHECK(aclDestroyScalar(acl_scale)); - ACL_CHECK(aclDestroyTensor(temp_tensor)); - ACL_CHECK(aclDestroyTensor(temp_output_tensor)); + ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor)); + ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr)); + } void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { diff --git a/ggml.c b/ggml.c index f6dcd06b15777..19bc01661e8f9 100644 --- a/ggml.c +++ b/ggml.c @@ -1891,6 +1891,36 @@ inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const flo } #endif +#if defined(__ARM_NEON) && defined(__aarch64__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static float32x4_t ggml_v_expf(float32x4_t x) { + const float32x4_t r = vdupq_n_f32(0x1.8p23f); + const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); + const float32x4_t n = vsubq_f32(z, r); + const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, + vdupq_n_f32(0x1.7f7d1cp-20f)); + const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); + const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); + const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); + const float32x4_t u = vmulq_f32(b, b); + const float32x4_t j = vfmaq_f32( + vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), + vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), + vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); + if (!vpaddd_u64(vreinterpretq_u64_u32(c))) + return vfmaq_f32(k, j, k); + const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); + const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); + const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); + return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), + vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); +} +#endif + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; @@ -2379,6 +2409,60 @@ bool ggml_is_numa(void) { return g_state.numa.n_nodes > 1; } + +static ggml_float 
ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { + int i = 0; + ggml_float sum = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(max))); + _mm512_storeu_ps(y + i, val); + sum += (ggml_float)_mm512_reduce_add_ps(val); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(max))); + _mm256_storeu_ps(y + i, val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (ggml_float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(max))); + _mm_storeu_ps(y + i, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif + sum += (ggml_float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(max))); + vst1q_f32(y + i, val); + sum += (ggml_float)vaddvq_f32(val); + } +#endif + for (; i < n; ++i) { + float val = expf(x[i] - max); + sum += (ggml_float)val; + y[i] = val; + } + return sum; +} + //////////////////////////////////////////////////////////////////////////////// void ggml_print_object(const struct ggml_object * obj) { @@ -5402,26 +5486,21 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias, bool inplace) { GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); - } - - if (pos) { - GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F32); - GGML_ASSERT(pos->ne[0] == a->ne[0]); + GGML_ASSERT(mask->ne[0] == a->ne[0]); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (max_bias > 0.0f) { - GGML_ASSERT(pos); + GGML_ASSERT(mask); } bool is_node = false; @@ -5439,7 +5518,6 @@ static struct ggml_tensor * ggml_soft_max_impl( result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = mask; - result->src[2] = pos; return result; } @@ -5447,23 +5525,22 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * ggml_soft_max( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false); } struct ggml_tensor * ggml_soft_max_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true); } struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias) { - return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false); + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } // ggml_soft_max_back @@ -12245,15 +12322,10 @@ static void ggml_compute_forward_soft_max_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src2 = dst->src[2]; assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - float scale = 1.0f; float max_bias = 0.0f; @@ -12267,12 +12339,12 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 - const uint32_t n_head_kv = ne02; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv)); + const uint32_t n_head = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -12289,29 +12361,31 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; - // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - float * pos = src2 ? (float *) src2->data : src0->data; + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { + // ALiBi + const uint32_t h = (i1/ne01)%ne02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); - if (mp) { - ggml_vec_acc_f32(nc, wp, mp); - } - - // ALiBi bias - if (max_bias > 0.0f) { - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - - for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*pos[i]; + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*mp_f32[i]; + } } } @@ -12325,22 +12399,7 @@ static void ggml_compute_forward_soft_max_f32( float max = -INFINITY; ggml_vec_max_f32(nc, &max, wp); - ggml_float sum = 0.0; - - uint16_t scvt; - for (int i = 0; i < nc; i++) { - if (wp[i] == -INFINITY) { - dp[i] = 0.0f; - } else { - // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max); - ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max); - memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]); - sum += (ggml_float)val; - dp[i] = val; - } - } - + ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); assert(sum > 0.0); sum = 1.0/sum; @@ -12373,6 +12432,7 @@ static void ggml_compute_forward_soft_max( } } + // ggml_compute_forward_soft_max_back static void ggml_compute_forward_soft_max_back_f32( diff --git a/ggml.h b/ggml.h index 4507a38d66c24..ead49dcdedc12 100644 --- a/ggml.h +++ b/ggml.h @@ -1427,7 +1427,6 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias); diff --git a/llama.cpp b/llama.cpp index 320f4bdb8e217..9f76007708584 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6433,7 +6433,7 @@ static struct ggml_tensor * llm_build_kqv( } else #endif { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); } @@ -8060,7 +8060,7 @@ struct llm_build_context { struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); cb(kq, "kq", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4963a7a845f67..855d58c644f7a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1098,34 +1098,38 @@ struct test_diag_mask_inf : public test_case { // GGML_OP_SOFT_MAX struct test_soft_max : public test_case { - const ggml_type type; + const ggml_type type_src; + const ggml_type type_mask; const std::array ne; const bool mask; const float scale; const float max_bias; std::string vars() override { - return VARS_TO_STR5(type, ne, mask, scale, max_bias); + return VARS_TO_STR6(type_src, type_mask, ne, mask, scale, max_bias); } - test_soft_max(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 10, 10, 10}, - bool mask = false, - float scale = 1.0f, - float max_bias = 0.0f) - : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {} + // the 1024 test with bias occasionally fails: + // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL + virtual double max_nmse_err() override { + return 1e-6; + } + + test_soft_max(ggml_type type_src = GGML_TYPE_F32, + ggml_type type_mask = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}, + bool mask = false, + float scale = 1.0f, + float max_bias = 0.0f) + : 
type_src(type_src), type_mask(type_mask), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * a = ggml_new_tensor(ctx, type_src, 4, ne.data()); ggml_tensor * mask = nullptr; if (this->mask) { - mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); + mask = ggml_new_tensor_2d(ctx, type_mask, ne[0], ne[1]); } - ggml_tensor * pos = nullptr; - if (max_bias > 0.0f) { - pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); - } - ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias); + ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias); return out; } }; @@ -1618,7 +1622,7 @@ struct test_llm : public test_case { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f); // split cached v into n_head heads struct ggml_tensor * v = @@ -2139,7 +2143,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (int n = 0; n < 10; ++n) { int64_t ne0 = dist_ne0(rng); int64_t ne1 = dist_ne1(rng); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f)); } exponent <<= 1; @@ -2147,21 +2151,24 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op #endif for (bool mask : {false, true}) { for (float max_bias : {0.0f, 8.0f}) { + if (!mask && max_bias > 0.0f) continue; for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) { for (int64_t ne1 : {16, 1024}) { - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias)); + for (ggml_type mask_type: {GGML_TYPE_F32, GGML_TYPE_F16}) { + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, mask_type, {ne0, ne1, 1, 1}, mask, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, mask_type, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias)); + } } } } } } + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F16, {16, 2, 32, 1}, true, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f)); for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
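
Note: the per-head ALiBi slope used by both the CPU path (ggml_compute_forward_soft_max_f32) and the CANN path (aclnn_alibi) follows the usual ALiBi recipe: heads below the largest power-of-two head count get powers of m0, the remaining heads get odd powers of m1. A minimal standalone sketch of that formula, assuming only libm; the helper name alibi_slope is illustrative and not part of the patch:

    #include <math.h>
    #include <stdio.h>
    #include <stdint.h>

    // Mirrors the patch: n_head_log2 = 2^floor(log2(n_head)),
    //   m0 = 2^(-max_bias / n_head_log2),
    //   m1 = 2^(-(max_bias/2) / n_head_log2)
    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        if (max_bias <= 0.0f) {
            return 1.0f; // no ALiBi: the mask is added with slope 1
        }
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -max_bias          / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, (float)(h + 1))
                               : powf(m1, (float)(2*(h - n_head_log2) + 1));
    }

    int main(void) {
        // e.g. 12 heads with max_bias = 8: heads 0..7 use m0, heads 8..11 use m1
        for (uint32_t h = 0; h < 12; ++h) {
            printf("head %2u -> slope %.6f\n", (unsigned) h, alibi_slope(h, 12, 8.0f));
        }
        return 0;
    }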
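
After this change, ggml_soft_max_ext fuses the scaling, the (optionally f16) mask and the ALiBi bias into one op: each row computes softmax(x*scale + slope*mask). A scalar reference for one row, matching the wdata loop plus the scalar tail of the new ggml_vec_soft_max_f32; the function name soft_max_row_ref and the plain-float mask are assumptions for the sketch (in ggml an f16 mask goes through GGML_FP16_TO_FP32 element-wise first):

    #include <math.h>

    // dst = softmax(src*scale + slope*mask) over one row of length n.
    // mask may be NULL; slope is the per-head ALiBi slope (1.0f when max_bias <= 0).
    static void soft_max_row_ref(const float * src, const float * mask,
                                 float scale, float slope, float * dst, int n) {
        float max = -INFINITY;
        for (int i = 0; i < n; ++i) {
            dst[i] = src[i]*scale + (mask ? slope*mask[i] : 0.0f);
            if (dst[i] > max) max = dst[i];
        }
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            dst[i] = expf(dst[i] - max);    // -INFINITY mask entries become 0
            sum += dst[i];
        }
        for (int i = 0; i < n; ++i) {
            dst[i] = (float)(dst[i] / sum); // sum > 0 because exp(max - max) == 1
        }
    }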
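
The new NEON ggml_v_expf (and the AVX variants it pairs with inside ggml_vec_soft_max_f32) uses the standard split of exp into an exponent shift plus a small polynomial on a reduced argument. A scalar outline of the same idea, assuming only libm; the constants and the low-order Taylor polynomial below are for illustration only, while the real routine uses a tuned minimax polynomial and explicit overflow/underflow handling:

    #include <math.h>

    static float expf_sketch(float x) {
        // exp(x) = 2^n * exp(r), with n = round(x / ln 2) and |r| <= ln(2)/2
        const float n = nearbyintf(x * 1.44269504f);   // x * log2(e)
        const float r = x - n * 0.693147181f;          // x - n * ln(2)
        // exp(r) ~= 1 + r + r^2/2 + r^3/6 + r^4/24 on the reduced range
        const float p = 1.0f + r*(1.0f + r*(0.5f + r*(1.0f/6.0f + r*(1.0f/24.0f))));
        return ldexpf(p, (int) n);                     // scale by 2^n via the exponent
    }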
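
With the pos tensor removed from ggml_soft_max_ext, callers pass only the mask (F32 or F16, ne[0] equal to the input's ne[0] and ne[1] at least the input's ne[1]) together with scale and max_bias, as llm_build_kqv now does. A minimal graph-building sketch of the updated call; the wrapper name and the shape annotations are illustrative, and context/backend setup is omitted:

    #include "ggml.h"

    // scores: F32 [n_kv, n_tokens, n_head, 1]; mask: F32 or F16 [n_kv, >= n_tokens]
    static struct ggml_tensor * build_masked_softmax(struct ggml_context * ctx,
                                                     struct ggml_tensor * scores,
                                                     struct ggml_tensor * mask,
                                                     float kq_scale,
                                                     float max_alibi_bias) {
        // before this patch: ggml_soft_max_ext(ctx, scores, mask, pos, kq_scale, max_alibi_bias)
        return ggml_soft_max_ext(ctx, scores, mask, kq_scale, max_alibi_bias);
    }

test-backend-ops.cpp exercises both mask types through the new type_mask parameter of test_soft_max, so backends are checked against f32 and f16 masks with and without ALiBi bias.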