
Commit

updata rope
wangshuai09 committed Jun 29, 2024
1 parent 8e62bbc commit f7a6bf8
Showing 3 changed files with 268 additions and 52 deletions.
43 changes: 29 additions & 14 deletions ggml/src/ggml-cann/aclnn_ops.cpp
@@ -31,6 +31,8 @@
#include <exception>
#include "kernels/ascendc_kernels.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))

static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst, int64_t* repeat_array,
ggml_tensor* bind_tensor) {
@@ -2218,6 +2220,7 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position
ggml_tensor* src2 = dst->src[2]; // freq_factors

// param init
rope_param param;
@@ -2228,9 +2231,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

const int mode = ((int32_t *) dst->op_params)[2];
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
param.is_neox = is_neox;
param.is_glm = is_glm;

memcpy(&param.freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&param.freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -2248,9 +2249,18 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
float corr_dims[2];
ggml_rope_yarn_corr_dims(param.n_dims, param.n_orig_ctx, param.freq_base,
param.beta_fast, param.beta_slow, corr_dims);
param.corr_dims[0] = corr_dims[0];
param.corr_dims[1] = corr_dims[1];

param.corr_dims0 = corr_dims[0];
param.corr_dims1 = corr_dims[1];
param.corr_dims0_neg = -1.0f * corr_dims[0];
param.rope_yarn_ramp_max_inv = MAX(0.001f, corr_dims[1] - corr_dims[0]);
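// The precomputed fields above feed the device-side YaRN ramp. In ggml's CPU
// rope path the ramp for dimension pair i0 is
//   1 - clamp((i0/2 - corr_dims[0]) / MAX(0.001f, corr_dims[1] - corr_dims[0]), 0.f, 1.f),
// so corr_dims0_neg and rope_yarn_ramp_max_inv presumably let the kernel form
// (i0/2 + corr_dims0_neg) / rope_yarn_ramp_max_inv and then clamp, without
// re-deriving the correction bounds per element.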

param.use_freq_factors = false;
if (src2 != NULL) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src2->ne[0] >= param.n_dims / 2);
param.use_freq_factors = true;
}
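// freq_factors follow the same convention as ggml's CPU rope: when src2 is
// present, the base angle for dimension pair i0 is divided by
// freq_factors[i0/2] (used by LongRoPE-style models), which is presumably why
// src2->data is forwarded to the init-cache kernel launch below.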
// printf("\nrepo param: mode: %d, n_dims: %d, fs: %f, ef: %f, af: %f, ff: %d\n", mode, param.n_dims, param.freq_scale, param.ext_factor, param.attn_factor, param.use_freq_factors);
// param copy
void *param_buffer;
ACL_CHECK(aclrtMalloc(&param_buffer, sizeof(rope_param),
@@ -2284,15 +2294,24 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
* sizeof(float_t));
void* cos_buffer = ctx.alloc_buffer(dst, src0->ne[0] * src0->ne[2]
* sizeof(float_t));


aclrtlaunch_ascendc_rope_init_cache(param.position_ne[0], ctx.stream(),
position_cast_buffer,
if (src2 != NULL) {
aclrtlaunch_ascendc_rope_init_cache(param.position_ne[0], ctx.stream(),
position_cast_buffer, src2->data,
sin_buffer, cos_buffer,
param_buffer,
((ggml_tensor*)src0->extra)->ne);
}
else {
aclrtlaunch_ascendc_rope_init_cache(param.position_ne[0], ctx.stream(),
position_cast_buffer, nullptr,
sin_buffer, cos_buffer,
param_buffer,
((ggml_tensor*)src0->extra)->ne);
}
ACL_CHECK(aclrtFree(param_buffer));

GGML_ASSERT(param.ext_factor==0);
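// ext_factor != 0 would enable YaRN's extrapolation/interpolation mix (plus the
// extra attention-magnitude scaling); the assert suggests the ascendc kernel
// only handles the plain interpolated angle for now.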

// reshape sin&cos
// TODO: ne[3] != 0
int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
@@ -2316,11 +2335,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// roll input
void* input_roll_buffer;
aclTensor* acl_minus_one_tensor;
if (is_glm) {
// TODO
GGML_ASSERT(false);
}
else if (!is_neox) {
if (!is_neox) {
// roll input: [q0,q1,q2,...] -> [q1,q0,q3,q2...]
input_roll_buffer = ctx.alloc_buffer(dst, ggml_nbytes(src0));
int64_t input_roll_ne[4] = {2, src0->ne[1]*(src0->ne[0]/2), src0->ne[2],
7 changes: 5 additions & 2 deletions ggml/src/ggml-cann/kernels/rope.h
@@ -9,9 +9,12 @@ typedef struct {
int n_dims;
int n_orig_ctx;
float theta_scale;
float corr_dims[2];
float corr_dims0;
float corr_dims1;
float corr_dims0_neg;
float rope_yarn_ramp_max_inv;
bool is_neox;
bool is_glm;
bool use_freq_factors;

} rope_param;
#pragma pack(pop)
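
For context, the corr_dims0 / corr_dims1 / corr_dims0_neg / rope_yarn_ramp_max_inv fields mirror the YaRN helpers in ggml's CPU rope path. Below is a minimal reference sketch of those helpers, assuming the ascendc init-cache kernel computes the equivalent on device; the C code is illustrative and not part of this commit:

#include <math.h>

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Ramp between the two correction dims; rope_yarn_ramp_max_inv above caches
// the MAX(0.001f, high - low) denominator.
static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
    return 1.0f - MIN(1.0f, MAX(0.0f, y));
}

// YaRN-corrected rotation angle for dimension pair i0 (CPU reference).
static void rope_yarn(float theta_extrap, float freq_scale, const float corr_dims[2],
                      int64_t i0, float ext_factor, float mscale,
                      float * cos_theta, float * sin_theta) {
    float theta = freq_scale * theta_extrap;           // interpolated angle
    if (ext_factor != 0.0f) {                          // YaRN extrapolation mix
        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }
    *cos_theta = cosf(theta) * mscale;
    *sin_theta = sinf(theta) * mscale;
}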
