Commit

update softmax

wangshuai09 committed Jun 25, 2024
1 parent 8cb2abc commit decc078
Showing 2 changed files with 137 additions and 94 deletions.
195 changes: 117 additions & 78 deletions ggml-cann/aclnn_ops.cpp
@@ -1262,22 +1262,36 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream()));
}

static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, float scale,
ggml_tensor* bind_tensor) {
static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
float scale, aclTensor* acl_dst, bool inplace,
ggml_tensor* bind_tensor) {
aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);

uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;

ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale,
&workspaceSize, &executor));
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
if (inplace) {
ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale,
&workspaceSize, &executor));
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
}

ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor,
ctx.stream()));
}
else {
ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst,
&workspaceSize, &executor));
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(bind_tensor, workspaceSize);
}

ACL_CHECK(
aclnnInplaceMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor,
ctx.stream()));
}


ACL_CHECK(aclDestroyScalar(acl_scale));
}
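
For readers skimming the diff, a minimal sketch of the two call modes the refactored helper now supports (the tensor names are illustrative, not taken from this commit):

    // In-place: scales acl_src directly; acl_dst is ignored and may be nullptr.
    aclnn_muls(ctx, acl_src, /*scale=*/0.5f, /*acl_dst=*/nullptr, /*inplace=*/true, dst);

    // Out-of-place: writes acl_src * scale into acl_dst and leaves acl_src untouched.
    aclnn_muls(ctx, acl_src, /*scale=*/0.5f, acl_dst, /*inplace=*/false, dst);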
@@ -1375,7 +1389,8 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,

// freq
float freq_param = -logf(max_period) / half;
aclnn_muls(ctx, tmp_arange_tensor, freq_param, dst);
bool inplace = true;
aclnn_muls(ctx, tmp_arange_tensor, freq_param, nullptr, inplace, dst);
aclnn_exp(ctx, tmp_arange_tensor, dst);

// permute: src [0,1,2,3]->[0,1,3,2]
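
In effect, element k of tmp_arange becomes exp(-k * ln(max_period) / half) = max_period^(-k/half), the standard sinusoidal timestep-embedding frequency schedule (a worked restatement of the two calls above, not part of the commit).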
Expand Down Expand Up @@ -1470,8 +1485,9 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
ACL_CHECK(aclDestroyScalar(acl_scalar));
}

static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
aclTensor* acl_exp, ggml_tensor* bind_tensor) {
static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
aclTensor* acl_dst, aclTensor* acl_exp,
ggml_tensor* bind_tensor) {
// acl_dst = acl_dst^acl_exp

uint64_t workspaceSize = 0;
@@ -1490,18 +1506,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* a

static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_position, aclTensor* acl_dst, const int n_head,
const int64_t src_ne0, const int64_t src_ne1,
const int64_t src_ne2, const int64_t src_ne3,
const size_t src_nb0, float max_bias, ggml_tensor* dst) {
GGML_UNUSED(src_ne1);
const int64_t ne2_ne3 = src_ne2 * src_ne3;
const int64_t* src_ne, const size_t src_nb0, float max_bias,
ggml_tensor* dst) {
GGML_UNUSED(src_ne[1]);
const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
GGML_ASSERT(src_nb0 == sizeof(float));
GGML_ASSERT(n_head == src_ne2);
GGML_ASSERT(n_head == src_ne[2]);

const int n_heads_log2_floor = 1 << (int)floor(log2(n_head));
const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));

const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

// init arange
void* tmp_arange_buffer =
@@ -1552,7 +1567,7 @@ static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
tmp_mk_base1_ne, tmp_mk_base1_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);

aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor, dst);

aclTensor* tmp_mk_base2_tensor = nullptr;
if (n_heads_log2_floor < ne2_ne3) {
int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
@@ -1576,7 +1591,7 @@ static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor, dst);

// reshape mk
int64_t tmp_mk_ne[] = {1, 1, src_ne2, src_ne3};
int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
size_t tmp_mk_nb[GGML_MAX_DIMS];
tmp_mk_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -1587,7 +1602,7 @@ static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);

// acl_position * mk
int64_t tmp_output_ne[] = {src_ne0, 1, src_ne2, src_ne3};
int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
size_t tmp_output_nb[GGML_MAX_DIMS];
tmp_output_nb[0] = ggml_type_size(dst->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -1640,6 +1655,21 @@ static void aclnn_inplace_add(ggml_backend_cann_context& ctx, aclTensor* acl_src
ACL_CHECK(aclDestroyScalar(alpha));
}

static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
int64_t dim, aclTensor* acl_dst,
ggml_tensor* bind_tensor) {

uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;

ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst,
&workspaceSize, &executor));

aclrtStream stream = ctx.stream();
ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream));
}

void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1]; // mask
@@ -1653,78 +1683,87 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));

// input mul scale
aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);

size_t n_bytes = ggml_nbytes(src0);
void* buffer = ctx.alloc_buffer(dst, n_bytes);
aclTensor* temp_tensor =
create_acl_tensor(buffer, ACL_FLOAT, ggml_type_size(src0->type),
src0->ne, src0->nb, GGML_MAX_DIMS);

// aclnn scale
uint64_t workspaceSize = 0;
aclOpExecutor* executor;
void* workspaceAddr = nullptr;

aclnnMulsGetWorkspaceSize(acl_src0, acl_scale, temp_tensor, &workspaceSize,
&executor);
if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(dst, workspaceSize);
}

aclrtStream stream = ctx.stream();
aclnnMuls(workspaceAddr, workspaceSize, executor, stream);
void* input_mul_scale_buffer = ctx.alloc_buffer(dst, n_bytes);
aclTensor* acl_input_mul_scale_tensor = create_acl_tensor(
input_mul_scale_buffer,
ACL_FLOAT,
ggml_type_size(src0->type),
src0->ne, src0->nb,
GGML_MAX_DIMS);

bool inplace = false;
aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace,
dst);

// mask
aclTensor* acl_src1 = nullptr;
aclTensor* acl_src1_fp32_tensor = nullptr;
if (src1) {
acl_src1 = create_acl_tensor(src1);
aclnn_inplace_add(ctx, acl_src1, temp_tensor, dst);
}

aclTensor* temp_output_tensor = nullptr;
aclTensor* acl_src2 = nullptr;
if (max_bias > 0.0f) {
acl_src2 = create_acl_tensor(src2);
const bool use_f16 = src1->type == GGML_TYPE_F16;
if (use_f16) {
// cast to fp32
size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
size_t src1_fp32_nb[GGML_MAX_DIMS];
src1_fp32_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
}
void* src1_fp32_buffer = ctx.alloc_buffer(dst, n_bytes);
acl_src1_fp32_tensor = create_acl_tensor(src1_fp32_buffer,
ACL_FLOAT,
sizeof(float),
src1->ne,
src1_fp32_nb,
GGML_MAX_DIMS);
aclTensor* acl_src1 = create_acl_tensor(src1);
aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT, dst);
ACL_CHECK(aclDestroyTensor(acl_src1));
}
else {
acl_src1_fp32_tensor = create_acl_tensor(src1);
}

// alibi
const int n_head = src0->ne[2];
const int64_t src_ne0 = src0->ne[0];
const int64_t src_ne1 = src0->ne[1];
const int64_t src_ne2 = src0->ne[2];
const int64_t src_ne3 = src0->ne[3];
const size_t src_nb0 = src0->nb[0];

// alibi
n_bytes = ggml_nbytes(dst);
void* output_buffer = ctx.alloc_buffer(dst, n_bytes);
temp_output_tensor = create_acl_tensor(output_buffer, ACL_FLOAT,
ggml_type_size(dst->type),
dst->ne, dst->nb, GGML_MAX_DIMS);
aclnn_alibi(ctx, temp_tensor, acl_src2, temp_output_tensor, n_head,
src_ne0, src_ne1, src_ne2, src_ne3, src_nb0, max_bias, dst);

// softmax
ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(temp_output_tensor, 3, acl_dst,
&workspaceSize, &executor));

} else {
aclTensor* alibi_output_tensor = create_acl_tensor(
output_buffer,
ACL_FLOAT,
ggml_type_size(dst->type),
dst->ne, dst->nb,
GGML_MAX_DIMS);
if (max_bias <=0.0f) {
// slope = 1.0
aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
alibi_output_tensor, dst);
}
else {
// slope != 0
aclnn_alibi(ctx, acl_input_mul_scale_tensor, acl_src1_fp32_tensor,
alibi_output_tensor, n_head, src0->ne, src_nb0, max_bias,
dst);
}

// softmax
ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(temp_tensor, 3, acl_dst,
&workspaceSize, &executor));
aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst, dst);
ACL_CHECK(aclDestroyTensor(alibi_output_tensor));
}
else {
aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst, dst);
}

if (workspaceSize > 0) {
workspaceAddr = ctx.alloc_buffer(dst, workspaceSize);
}

ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream));

ACL_CHECK(aclDestroyTensor(acl_src0));
ACL_CHECK(aclDestroyTensor(acl_src1));
ACL_CHECK(aclDestroyTensor(acl_src2));
ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor));
ACL_CHECK(aclDestroyTensor(acl_dst));
ACL_CHECK(aclDestroyScalar(acl_scale));
ACL_CHECK(aclDestroyTensor(temp_tensor));
ACL_CHECK(aclDestroyTensor(temp_output_tensor));
ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));

}
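
As a sanity reference, here is a plain CPU sketch of the per-row computation this kernel is intended to match: dst = softmax(src0*scale + slope*mask), with the ALiBi slope derived from the head index when max_bias > 0. The helper is illustrative only (it is not part of this commit or of ggml) and assumes the usual ggml soft_max_ext semantics:

    #include <math.h>
    #include <stdint.h>

    // Reference for one row of length ne0, belonging to head h out of n_head heads.
    // dst = softmax(src*scale + slope*mask); slope follows the ALiBi convention.
    static void soft_max_ext_row_ref(const float * src, const float * mask, float * dst,
                                     int64_t ne0, int h, int n_head,
                                     float scale, float max_bias) {
        float slope = 1.0f;
        if (max_bias > 0.0f) {
            const int n_heads_log2_floor = 1 << (int) floorf(log2f((float) n_head));
            const float m0 = powf(2.0f, -(max_bias)        / n_heads_log2_floor);
            const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
            slope = h < n_heads_log2_floor ? powf(m0, h + 1)
                                           : powf(m1, 2*(h - n_heads_log2_floor) + 1);
        }

        float maxv = -INFINITY;
        for (int64_t i = 0; i < ne0; i++) {
            dst[i] = src[i]*scale + (mask ? slope*mask[i] : 0.0f);
            maxv = fmaxf(maxv, dst[i]);
        }
        float sum = 0.0f;
        for (int64_t i = 0; i < ne0; i++) {
            dst[i] = expf(dst[i] - maxv);
            sum += dst[i];
        }
        for (int64_t i = 0; i < ne0; i++) {
            dst[i] /= sum;
        }
    }

The kernel above assembles the same result from ACL primitives: aclnn_muls for src0*scale, an f16-to-f32 cast of the mask when needed, aclnn_add or aclnn_alibi for the bias term, and aclnn_softmax with dim = 3.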

void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
36 changes: 20 additions & 16 deletions tests/test-backend-ops.cpp
@@ -1155,14 +1155,15 @@ struct test_diag_mask_inf : public test_case {

// GGML_OP_SOFT_MAX
struct test_soft_max : public test_case {
const ggml_type type;
const ggml_type type_src;
const ggml_type type_mask;
const std::array<int64_t, 4> ne;
const bool mask;
const float scale;
const float max_bias;

std::string vars() override {
return VARS_TO_STR5(type, ne, mask, scale, max_bias);
return VARS_TO_STR6(type_src, type_mask, ne, mask, scale, max_bias);
}

// the 1024 test with bias occasionally fails:
@@ -1171,18 +1172,19 @@ struct test_soft_max : public test_case {
return 1e-6;
}

test_soft_max(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10},
bool mask = false,
float scale = 1.0f,
float max_bias = 0.0f)
: type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
test_soft_max(ggml_type type_src = GGML_TYPE_F32,
ggml_type type_mask = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10},
bool mask = false,
float scale = 1.0f,
float max_bias = 0.0f)
: type_src(type_src), type_mask(type_mask), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * a = ggml_new_tensor(ctx, type_src, 4, ne.data());
ggml_tensor * mask = nullptr;
if (this->mask) {
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
mask = ggml_new_tensor_2d(ctx, type_mask, ne[0], ne[1]);
}
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
return out;
@@ -2272,7 +2274,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int n = 0; n < 10; ++n) {
int64_t ne0 = dist_ne0(rng);
int64_t ne1 = dist_ne1(rng);
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
}

exponent <<= 1;
@@ -2284,17 +2286,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (float scale : {1.0f, 0.1f}) {
for (int64_t ne0 : {16, 1024}) {
for (int64_t ne1 : {16, 1024}) {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
for (ggml_type mask_type: {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, mask_type, {ne0, ne1, 1, 1}, mask, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, mask_type, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
}
}
}
}
}
}

test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
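
An f16 mask can be exercised the same way through the new constructor, e.g. (an illustrative case, not one added by this commit):

    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F16, {32, 2, 32, 1}, true, 0.1f, 8.0f));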

{
bool all = true;
