update softmax

wangshuai09 committed Jun 26, 2024
1 parent 3feb9b2 commit 81a0aac

Showing 5 changed files with 301 additions and 155 deletions.
234 changes: 157 additions & 77 deletions ggml-cann/aclnn_ops.cpp
@@ -1271,22 +1271,36 @@ void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream()));
}

void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, float scale,
ggml_tensor* bind_tensor) {
static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
float scale, aclTensor* acl_dst, bool inplace,
ggml_tensor* bind_tensor) {
aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);

uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;

ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale,
&workspaceSize, &executor));
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
if (inplace) {
ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale,
&workspaceSize, &executor));
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
}

ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor,
ctx.stream()));
}
else {
ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst,
&workspaceSize, &executor));
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
}

ACL_CHECK(
aclnnInplaceMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor,
ctx.stream()));
}


ACL_CHECK(aclDestroyScalar(acl_scale));
}
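
The reworked aclnn_muls now covers both the in-place and the out-of-place scaling paths. A hedged usage sketch of the two call patterns, reusing names that appear in the surrounding code (ctx, acl_src, acl_dst, dst); the 0.5f factor is only an illustrative value:

// in-place: acl_src *= 0.5f; the destination argument is ignored
aclnn_muls(ctx, acl_src, 0.5f, /*acl_dst=*/nullptr, /*inplace=*/true, dst);
// out-of-place: acl_dst = acl_src * 0.5f
aclnn_muls(ctx, acl_src, 0.5f, acl_dst, /*inplace=*/false, dst);
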
@@ -1384,7 +1398,8 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,

// freq
float freq_param = -logf(max_period) / half;
aclnn_muls(ctx, tmp_arange_tensor, freq_param, dst);
bool inplace = true;
aclnn_muls(ctx, tmp_arange_tensor, freq_param, nullptr, inplace, dst);
aclnn_exp(ctx, tmp_arange_tensor, dst);

// permute: src [0,1,2,3]->[0,1,3,2]
@@ -1479,8 +1494,9 @@ void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
ACL_CHECK(aclDestroyScalar(acl_scalar));
}

void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
aclTensor* acl_exp, ggml_tensor* bind_tensor) {
static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
aclTensor* acl_dst, aclTensor* acl_exp,
ggml_tensor* bind_tensor) {
// acl_dst = acl_dst^acl_exp

uint64_t workspaceSize = 0;
@@ -1499,17 +1515,17 @@ void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst,

void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_position, aclTensor* acl_dst, const int n_head,
const int64_t src_ne0, const int64_t src_ne1,
const int64_t src_ne2, const int64_t src_ne3,
const size_t src_nb0, float max_bias, ggml_tensor* dst) {
const int64_t ne2_ne3 = src_ne2 * src_ne3;
int64_t* src_ne, const size_t src_nb0, float max_bias,
ggml_tensor* dst) {
GGML_UNUSED(src_ne[1]);
const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
GGML_ASSERT(src_nb0 == sizeof(float));
GGML_ASSERT(n_head == src_ne2);
GGML_ASSERT(n_head == src_ne[2]);

const int n_heads_log2_floor = 1 << (int)floor(log2(n_head));
const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));

const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
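// For reference, the ALiBi slope layout these values produce: with
// F = 2^floor(log2(n_head)), head h (0-based) gets slope m0^(h+1) when h < F and
// m1^(2*(h-F)+1) otherwise. Worked example with assumed values n_head = 12,
// max_bias = 8.0f, ne3 = 1: F = 8, m0 = 0.5, m1 = 2^-0.5 ~= 0.7071, so heads 0..7
// get 0.5^1 .. 0.5^8 and heads 8..11 get 0.7071^1, 0.7071^3, 0.7071^5, 0.7071^7.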

// init arange
void* tmp_arange_buffer =
@@ -1560,7 +1576,7 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
tmp_mk_base1_ne, tmp_mk_base1_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);

aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor, dst);

aclTensor* tmp_mk_base2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
@@ -1584,7 +1600,7 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor, dst);

// reshape mk
int64_t tmp_mk_ne[] = {1, 1, src_ne2, src_ne3};
int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
size_t tmp_mk_nb[GGML_MAX_DIMS];
tmp_mk_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -1595,7 +1611,7 @@ void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);

// acl_position * mk
int64_t tmp_output_ne[] = {src_ne0, 1, src_ne2, src_ne3};
int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
size_t tmp_output_nb[GGML_MAX_DIMS];
tmp_output_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -1660,8 +1676,7 @@ void ggml_cann_alibi(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// call alibi
aclTensor* acl_src = create_acl_tensor(src);
aclTensor* acl_dst = create_acl_tensor(dst);
aclnn_alibi(ctx, acl_src, tmp_position_tensor, acl_dst, n_head, ne0, ne1,
ne2, ne3, nb0, max_bias, dst);
aclnn_alibi(ctx, acl_src, tmp_position_tensor, acl_dst, n_head, src->ne, nb0, max_bias, dst);

ACL_CHECK(aclDestroyTensor(tmp_position_tensor));
ACL_CHECK(aclDestroyTensor(acl_src));
@@ -1696,6 +1711,25 @@ void aclnn_inplace_add(ggml_backend_cann_context& ctx, aclTensor* acl_src,
ACL_CHECK(aclDestroyScalar(alpha));
}

static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
int64_t dim, aclTensor* acl_dst,
ggml_tensor* bind_tensor) {

uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;

ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst,
&workspaceSize, &executor));

if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
}

aclrtStream stream = ctx.stream();
ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream));
}
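
The new helper follows the same three-step pattern as the other aclnn wrappers in this file: query the workspace size, allocate it from the pool bound to bind_tensor if needed, then launch on the context stream. A hedged usage sketch, with the argument names taken from the caller below (dim 3 matches the call made in ggml_cann_softmax):

aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3 /*dim*/, acl_dst, dst);
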

void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1]; // mask
@@ -1710,78 +1744,124 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));

// input mul scale
aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);

size_t n_bytes = ggml_nbytes(src0);
void* buffer = ctx.alloc_buffer(dst, n_bytes);
aclTensor* temp_tensor =
create_acl_tensor(buffer, ACL_FLOAT, ggml_type_size(src0->type),
src0->ne, src0->nb, GGML_MAX_DIMS);

// aclnn scale
uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;

aclnnMulsGetWorkspaceSize(acl_src0, acl_scale, temp_tensor, &workspaceSize,
&executor);
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(dst, workspaceSize);
}
void* input_mul_scale_buffer = ctx.alloc_buffer(dst, n_bytes);
aclTensor* acl_input_mul_scale_tensor = create_acl_tensor(
input_mul_scale_buffer,
ACL_FLOAT,
ggml_type_size(src0->type),
src0->ne, src0->nb,
GGML_MAX_DIMS);

aclrtStream stream = ctx.stream();
aclnnMuls(workspaceAddr, workspaceSize, executor, stream);
bool inplace = false;
aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace,
dst);

// mask
aclTensor* acl_src1 = nullptr;
aclTensor* acl_src1_fp32_tensor = nullptr;
aclTensor* tmp_permute_tenosr = nullptr;
if (src1) {
acl_src1 = create_acl_tensor(src1);
aclnn_inplace_add(ctx, acl_src1, temp_tensor, dst);
}
const bool use_f16 = src1->type == GGML_TYPE_F16;
if (use_f16) {
// cast to fp32
size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
size_t src1_fp32_nb[GGML_MAX_DIMS];
src1_fp32_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
}
void* src1_fp32_buffer = ctx.alloc_buffer(dst, n_bytes);
acl_src1_fp32_tensor = create_acl_tensor(src1_fp32_buffer,
ACL_FLOAT,
sizeof(float),
src1->ne,
src1_fp32_nb,
GGML_MAX_DIMS);
aclTensor* acl_src1 = create_acl_tensor(src1);
aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT, dst);

ACL_CHECK(aclDestroyTensor(acl_src1));

}
else {
acl_src1_fp32_tensor = create_acl_tensor(src1);
}

// broadcast the mask across rows
if (src1->ne[1] != src0->ne[1]) {
// mask shape: [1,1,a,b]
int64_t tmp_permute_ne[] = {src1->ne[0], src1->ne[2], src1->ne[1], src1->ne[3]};
size_t tmp_permute_nb[GGML_MAX_DIMS];
tmp_permute_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
}

aclTensor* temp_output_tensor = nullptr;
aclTensor* acl_src2 = nullptr;
if (max_bias > 0.0f) {
acl_src2 = create_acl_tensor(src2);
void* tmp_permute_buffer = ctx.alloc_buffer(dst, ggml_nbytes(src1));
tmp_permute_tenosr = create_acl_tensor(
tmp_permute_buffer, ACL_FLOAT, sizeof(float),
tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
int64_t permute_dim[] = {0, 2, 1, 3};
int64_t num_dims = 4;
aclnn_permute(ctx, acl_src1_fp32_tensor, tmp_permute_tenosr, permute_dim, num_dims, dst);
}

// alibi
const int n_head = src0->ne[2];
const int64_t src_ne0 = src0->ne[0];
const int64_t src_ne1 = src0->ne[1];
const int64_t src_ne2 = src0->ne[2];
const int64_t src_ne3 = src0->ne[3];
const size_t src_nb0 = src0->nb[0];

// alibi
n_bytes = ggml_nbytes(dst);
void* output_buffer = ctx.alloc_buffer(dst, n_bytes);
temp_output_tensor = create_acl_tensor(output_buffer, ACL_FLOAT,
ggml_type_size(dst->type),
dst->ne, dst->nb, GGML_MAX_DIMS);
aclnn_alibi(ctx, temp_tensor, acl_src2, temp_output_tensor, n_head,
src_ne0, src_ne1, src_ne2, src_ne3, src_nb0, max_bias, dst);

// softmax
ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(temp_output_tensor, 3, acl_dst,
&workspaceSize, &executor));

} else {
aclTensor* alibi_output_tensor = create_acl_tensor(
output_buffer,
ACL_FLOAT,
ggml_type_size(dst->type),
dst->ne, dst->nb,
GGML_MAX_DIMS);
if (max_bias <=0.0f) {
// slope = 1.0
if (tmp_permute_tenosr) {
aclnn_add(ctx, tmp_permute_tenosr, acl_input_mul_scale_tensor,
alibi_output_tensor, dst);
}
else {
aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor, dst);
}

}
else {
// slope != 1.0
if (tmp_permute_tenosr) {
aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_permute_tenosr,
alibi_output_tensor, n_head, src0->ne, src_nb0, max_bias,
dst);
}
else {
aclnn_alibi(ctx, acl_input_mul_scale_tensor, acl_src1_fp32_tensor,
alibi_output_tensor, n_head, src0->ne, src_nb0, max_bias,
dst);
}
}

// softmax
ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(temp_tensor, 3, acl_dst,
&workspaceSize, &executor));
}

if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(dst, workspaceSize);
aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst, dst);
ACL_CHECK(aclDestroyTensor(alibi_output_tensor));
}
else {
aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst, dst);
}

ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream));

ACL_CHECK(aclDestroyTensor(acl_src0));
ACL_CHECK(aclDestroyTensor(acl_src1));
ACL_CHECK(aclDestroyTensor(acl_src2));
ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor));
ACL_CHECK(aclDestroyTensor(acl_dst));
ACL_CHECK(aclDestroyScalar(acl_scale));
ACL_CHECK(aclDestroyTensor(temp_tensor));
ACL_CHECK(aclDestroyTensor(temp_output_tensor));
ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));
ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr));

}
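
Taken together, the rewritten path computes dst = softmax(scale * src0 + slope * mask) along each row, with the slope coming from ALiBi when max_bias > 0 and being 1.0 otherwise. A minimal CPU sketch of that per-row semantics, assuming the mask has already been cast to fp32 and broadcast to the row shape; soft_max_row_ref is a hypothetical name and not part of the CANN backend:

#include <math.h>
#include <stdint.h>

static void soft_max_row_ref(const float* src, const float* mask, float scale,
                             float slope, float* out, int64_t ne0) {
    // scaled logits plus (optionally) the slope-weighted mask, tracking the row max
    float max_val = -INFINITY;
    for (int64_t i = 0; i < ne0; i++) {
        out[i] = src[i] * scale + (mask ? slope * mask[i] : 0.0f);
        if (out[i] > max_val) max_val = out[i];
    }
    // numerically stable softmax: exponentiate shifted values, then normalize
    float sum = 0.0f;
    for (int64_t i = 0; i < ne0; i++) {
        out[i] = expf(out[i] - max_val);
        sum += out[i];
    }
    for (int64_t i = 0; i < ne0; i++) {
        out[i] /= sum;
    }
}
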

void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {