Skip to content

Commit

Permalink
fix softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
wangshuai09 committed Jul 1, 2024
1 parent f7a6bf8 commit 6740231
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions ggml/src/ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1753,7 +1753,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

// mask
aclTensor* acl_src1_fp32_tensor = nullptr;
aclTensor* tmp_permute_tenosr = nullptr;
aclTensor* tmp_mask_tensor = nullptr;
if (src1) {
const bool use_f16 = src1->type == GGML_TYPE_F16;
if (use_f16) {
Expand Down Expand Up @@ -1781,23 +1781,18 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
acl_src1_fp32_tensor = create_acl_tensor(src1);
}

// broadcast the mask across rows
// broadcast the mask across rows, only use ne11 of ne01 in mask
if (src1->ne[1] != src0->ne[1]) {
// mask shape: [1,1,a,b]
int64_t tmp_permute_ne[] = {src1->ne[0], src1->ne[2], src1->ne[1], src1->ne[3]};
size_t tmp_permute_nb[GGML_MAX_DIMS];
tmp_permute_nb[0] = sizeof(float_t);
// mask shape: [1,1,ne11,ne10]
int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
size_t tmp_mask_nb[GGML_MAX_DIMS];
tmp_mask_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
}

void* tmp_permute_buffer = ctx.alloc_buffer(dst, ggml_nbytes(src1));
tmp_permute_tenosr = create_acl_tensor(
tmp_permute_buffer, ACL_FLOAT, sizeof(float),
tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
int64_t permute_dim[] = {0, 2, 1, 3};
int64_t num_dims = 4;
aclnn_permute(ctx, acl_src1_fp32_tensor, tmp_permute_tenosr, permute_dim, num_dims, dst);
tmp_mask_tensor = create_acl_tensor(
src1->data, ACL_FLOAT, sizeof(float),
tmp_mask_ne, tmp_mask_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
}

// alibi
Expand All @@ -1814,8 +1809,8 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
GGML_MAX_DIMS);
if (max_bias <=0.0f) {
// slope = 1.0
if (tmp_permute_tenosr) {
aclnn_add(ctx, tmp_permute_tenosr, acl_input_mul_scale_tensor,
if (tmp_mask_tensor) {
aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor, dst);
}
else {
Expand All @@ -1826,8 +1821,8 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
}
else {
// slope != 1.0
if (tmp_permute_tenosr) {
aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_permute_tenosr,
if (tmp_mask_tensor) {
aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
alibi_output_tensor, n_head, src0->ne, src_nb0, max_bias,
dst);
}
Expand All @@ -1851,7 +1846,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ACL_CHECK(aclDestroyTensor(acl_dst));
ACL_CHECK(aclDestroyScalar(acl_scale));
ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));
ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr));
ACL_CHECK(aclDestroyTensor(tmp_mask_tensor));

}

Expand Down

0 comments on commit 6740231

Please sign in to comment.