From 6740231207ab5caf25cc7e398c343c1beb3858f3 Mon Sep 17 00:00:00 2001 From: wangshuai09 <391746016@qq.com> Date: Mon, 1 Jul 2024 07:26:34 +0000 Subject: [PATCH] fix softmax --- ggml/src/ggml-cann/aclnn_ops.cpp | 35 ++++++++++++++------------------ 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index d377bc723982f..8798dbcf375de 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1753,7 +1753,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // mask aclTensor* acl_src1_fp32_tensor = nullptr; - aclTensor* tmp_permute_tenosr = nullptr; + aclTensor* tmp_mask_tensor = nullptr; if (src1) { const bool use_f16 = src1->type == GGML_TYPE_F16; if (use_f16) { @@ -1781,23 +1781,18 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { acl_src1_fp32_tensor = create_acl_tensor(src1); } - // broadcast the mask across rows + // broadcast the mask across rows, only use ne11 of ne01 in mask if (src1->ne[1] != src0->ne[1]) { - // mask shape: [1,1,a,b] - int64_t tmp_permute_ne[] = {src1->ne[0], src1->ne[2], src1->ne[1], src1->ne[3]}; - size_t tmp_permute_nb[GGML_MAX_DIMS]; - tmp_permute_nb[0] = sizeof(float_t); + // mask shape: [1,1,ne11,ne10] + int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1}; + size_t tmp_mask_nb[GGML_MAX_DIMS]; + tmp_mask_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { - tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1]; + tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1]; } - - void* tmp_permute_buffer = ctx.alloc_buffer(dst, ggml_nbytes(src1)); - tmp_permute_tenosr = create_acl_tensor( - tmp_permute_buffer, ACL_FLOAT, sizeof(float), - tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - int64_t permute_dim[] = {0, 2, 1, 3}; - int64_t num_dims = 4; - aclnn_permute(ctx, acl_src1_fp32_tensor, tmp_permute_tenosr, permute_dim, num_dims, dst); + tmp_mask_tensor = create_acl_tensor( + src1->data, ACL_FLOAT, sizeof(float), + tmp_mask_ne, tmp_mask_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); } // alibi @@ -1814,8 +1809,8 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { GGML_MAX_DIMS); if (max_bias <=0.0f) { // slope = 1.0 - if (tmp_permute_tenosr) { - aclnn_add(ctx, tmp_permute_tenosr, acl_input_mul_scale_tensor, + if (tmp_mask_tensor) { + aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor, alibi_output_tensor, dst); } else { @@ -1826,8 +1821,8 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { } else { // slope != 1.0 - if (tmp_permute_tenosr) { - aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_permute_tenosr, + if (tmp_mask_tensor) { + aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor, alibi_output_tensor, n_head, src0->ne, src_nb0, max_bias, dst); } @@ -1851,7 +1846,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ACL_CHECK(aclDestroyTensor(acl_dst)); ACL_CHECK(aclDestroyScalar(acl_scale)); ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr)); + ACL_CHECK(aclDestroyTensor(tmp_mask_tensor)); }