diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 51594633d5da4..d377bc723982f 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -31,6 +31,8 @@ #include #include "kernels/ascendc_kernels.h" +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t* repeat_array, ggml_tensor* bind_tensor) { @@ -2218,6 +2220,7 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src, void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src1 = dst->src[1]; // position + ggml_tensor* src2 = dst->src[2]; // freq_factors // param init rope_param param; @@ -2228,9 +2231,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const int mode = ((int32_t *) dst->op_params)[2]; const bool is_neox = mode & 2; - const bool is_glm = mode & 4; param.is_neox = is_neox; - param.is_glm = is_glm; memcpy(¶m.freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(¶m.freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -2248,9 +2249,18 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { float corr_dims[2]; ggml_rope_yarn_corr_dims(param.n_dims, param.n_orig_ctx, param.freq_base, param.beta_fast, param.beta_slow, corr_dims); - param.corr_dims[0] = corr_dims[0]; - param.corr_dims[1] = corr_dims[1]; - + param.corr_dims0 = corr_dims[0]; + param.corr_dims1 = corr_dims[1]; + param.corr_dims0_neg = -1.0f * corr_dims[0]; + param.rope_yarn_ramp_max_inv = MAX(0.001f, corr_dims[1] - corr_dims[0]); + + param.use_freq_factors = false; + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= param.n_dims / 2); + param.use_freq_factors = true; + } + // printf("\nrepo param: mode: %d, n_dims: %d, fs: %f, ef: %f, af: %f, ff: %d\n", mode, param.n_dims, param.freq_scale, param.ext_factor, param.attn_factor, param.use_freq_factors); // param copy void *param_buffer; ACL_CHECK(aclrtMalloc(¶m_buffer, sizeof(rope_param), @@ -2284,15 +2294,24 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * sizeof(float_t)); void* cos_buffer = ctx.alloc_buffer(dst, src0->ne[0] * src0->ne[2] * sizeof(float_t)); - - - aclrtlaunch_ascendc_rope_init_cache(param.position_ne[0], ctx.stream(), - position_cast_buffer, + if (src2 != NULL) { + aclrtlaunch_ascendc_rope_init_cache(param.position_ne[0], ctx.stream(), + position_cast_buffer, src2->data, sin_buffer, cos_buffer, param_buffer, ((ggml_tensor*)src0->extra)->ne); + } + else { + aclrtlaunch_ascendc_rope_init_cache(param.position_ne[0], ctx.stream(), + position_cast_buffer, nullptr, + sin_buffer, cos_buffer, + param_buffer, + ((ggml_tensor*)src0->extra)->ne); + } ACL_CHECK(aclrtFree(param_buffer)); + GGML_ASSERT(param.ext_factor==0); + // reshape sin&cos // TODO: ne[3] != 0 int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; @@ -2316,11 +2335,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // roll input void* input_roll_buffer; aclTensor* acl_minus_one_tensor; - if (is_glm) { - // TODO - GGML_ASSERT(false); - } - else if (!is_neox) { + if (!is_neox) { // roll input: [q0,q1,q2,...] -> [q1,q0,q3,q2...] input_roll_buffer = ctx.alloc_buffer(dst, ggml_nbytes(src0)); int64_t input_roll_ne[4] = {2, src0->ne[1]*(src0->ne[0]/2), src0->ne[2], diff --git a/ggml/src/ggml-cann/kernels/rope.h b/ggml/src/ggml-cann/kernels/rope.h index 223da0595b0a7..6d2114cd0fc06 100644 --- a/ggml/src/ggml-cann/kernels/rope.h +++ b/ggml/src/ggml-cann/kernels/rope.h @@ -9,9 +9,12 @@ typedef struct { int n_dims; int n_orig_ctx; float theta_scale; - float corr_dims[2]; + float corr_dims0; + float corr_dims1; + float corr_dims0_neg; + float rope_yarn_ramp_max_inv; bool is_neox; - bool is_glm; + bool use_freq_factors; } rope_param; #pragma pack(pop) diff --git a/ggml/src/ggml-cann/kernels/rope_init_cache.cpp b/ggml/src/ggml-cann/kernels/rope_init_cache.cpp index 1d55f8969a483..ed2fb46baacb4 100644 --- a/ggml/src/ggml-cann/kernels/rope_init_cache.cpp +++ b/ggml/src/ggml-cann/kernels/rope_init_cache.cpp @@ -11,6 +11,7 @@ class InitCache { public: __aicore__ inline InitCache() {} __aicore__ inline void init(GM_ADDR position, + GM_ADDR freq_factors, GM_ADDR sin_output, GM_ADDR cos_output, rope_param& param, @@ -60,13 +61,26 @@ class InitCache { // other param attn_factor = param.attn_factor; freq_scale = param.freq_scale; + ext_factor = param.ext_factor; is_neox = param.is_neox; - is_glm = param.is_glm; + use_freq_factors = param.use_freq_factors; + corr_dims0 = param.corr_dims0; + corr_dims1 = param.corr_dims1; + corr_dims0_neg = param.corr_dims0_neg; + rope_yarn_ramp_max_inv = param.rope_yarn_ramp_max_inv; // stride position_stride = op_block_idx; output_stride = op_block_idx * broadcast_size; + if (use_freq_factors) { + freq_factors_gm.SetGlobalBuffer((__gm__ float_t*)freq_factors, count); + pipe.InitBuffer(freq_factors_queue, BUFFER_NUM, + (sizeof(float_t)*count+32-1)/32*32); + pipe.InitBuffer(freq_factors_brcast_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + } + position_gm.SetGlobalBuffer((__gm__ float_t*)position + position_stride, 1); output_sin_gm.SetGlobalBuffer((__gm__ float_t*)sin_output + @@ -88,6 +102,39 @@ class InitCache { (sizeof(float_t)*broadcast_size+32-1)/32*32); pipe.InitBuffer(broadcast_power_buffer, (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(theta_base_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(theta_div_ff_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(theta_interp_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(arange_brcast_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(arange_brcast_div_buff, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(arange_brcast_div_add_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(rope_yarn_ramp_y_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(rope_yarn_ramp_y_max_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(rope_yarn_ramp_y_maxmin_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(rope_yarn_ramp_y_maxmin_neg_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(ramp_mix_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(theta0_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(theta1_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(arange_brcast_add_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + pipe.InitBuffer(mscale_buffer, + (sizeof(float_t)*broadcast_size+32-1)/32*32); + + + pipe.InitBuffer(theta_buffer, (sizeof(float_t)*broadcast_size+32-1)/32*32); pipe.InitBuffer(sin_buffer, @@ -96,7 +143,7 @@ class InitCache { (sizeof(float_t)*broadcast_size+32-1)/32*32); } - __aicore__ inline void copy_in() { + __aicore__ inline void copy_position_in() { LocalTensor input_local = position_queue.AllocTensor(); @@ -109,6 +156,24 @@ class InitCache { position_queue.EnQue(input_local); } + __aicore__ inline void copy_freq_factors_in() { + LocalTensor freq_factors_local = + freq_factors_queue.AllocTensor(); + + DataCopyExtParams dataCopyParams; + dataCopyParams.blockCount = 1; + dataCopyParams.blockLen = count * sizeof(float_t); + DataCopyPadExtParams padParams; + DataCopyPad(freq_factors_local, freq_factors_gm, dataCopyParams, + padParams); + PRINTF("\nfreq_scale: \n"); + for (int i =0 ; i<64; i++) { + PRINTF("%f,", freq_factors_local.GetValue(i)); + } + + freq_factors_queue.EnQue(freq_factors_local); + } + __aicore__ inline void copy_out() { LocalTensor sin_local = sin_mul_mscale_queue.DeQue(); int32_t BLOCK_NUM = 32 / sizeof(float_t); @@ -129,16 +194,14 @@ class InitCache { LocalTensor arange_local = arange_queue.AllocTensor(); ArithProgression(arange_local, first_value, diff_value, count); - // pow + // theta stride LocalTensor power_local = power_queue.AllocTensor(); Power(power_local, static_cast(theta_scale), arange_local); LocalTensor power_brcast_local = broadcast_power_buffer.Get(); - - //TODO: is_glm==true. - if (!is_glm && !is_neox) { + if (!is_neox) { // for :dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; // dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; // the value of 0,1 or 2,3, ..., should be same. @@ -146,12 +209,6 @@ class InitCache { // broadcast: e.g. arange [64, 1] -> [64, 2] BroadCast(power_brcast_local, power_local, broadcast_shape0, arange_shape); - // position: [1] - copy_in(); - LocalTensor position_local = - position_queue.DeQue(); - position_value = position_local.GetValue(0); - position_queue.FreeTensor(position_local); } else { // for: dst_data[0] = x0*cos_theta - x1*sin_theta; @@ -161,34 +218,151 @@ class InitCache { // broadcast: e.g. arange [1, 64] -> [2, 64] BroadCast(power_brcast_local, power_local, broadcast_shape2, arange_shape1); - - // position * freq_scale - copy_in(); - LocalTensor position_local = - position_queue.DeQue(); - position_value = position_local.GetValue(0); - position_value = position_value * freq_scale; - position_queue.FreeTensor(position_local); } - - // theta - LocalTensor theta_local = theta_buffer.Get(); - Muls(theta_local, power_brcast_local, position_value, + + // position + copy_position_in(); + LocalTensor position_local = + position_queue.DeQue(); + position_value = position_local.GetValue(0); + position_queue.FreeTensor(position_local); + + // theta_base + LocalTensor theta_base_local = theta_base_buffer.Get(); + Muls(theta_base_local, power_brcast_local, position_value, broadcast_size); - // sin & cos - // TODO: if ext_factor != 0 - LocalTensor sin_local = sin_buffer.Get(); - Sin(sin_local, theta_local); + // theta_extrap = theta_base/ff + LocalTensor theta_div_ff_local; + if (use_freq_factors) { + copy_freq_factors_in(); + LocalTensor freq_factors_local = + freq_factors_queue.DeQue(); + LocalTensor freq_factors_brcast_local = freq_factors_brcast_buffer.Get(); + if (!is_neox) { + BroadCast(freq_factors_brcast_local, + freq_factors_local, broadcast_shape0, + arange_shape); + } + else { + BroadCast(freq_factors_brcast_local, + freq_factors_local, broadcast_shape2, arange_shape1); + } + + theta_div_ff_local = theta_div_ff_buffer.Get(); + Div(theta_div_ff_local, theta_base_local, freq_factors_brcast_local, + broadcast_size); + freq_factors_queue.FreeTensor(freq_factors_local); + } + else { + theta_div_ff_local = theta_base_local; + } + + // theta_interp + LocalTensor theta_interp_local = theta_interp_buffer.Get(); + PRINTF("freq_scale: %f \n", freq_scale); + PRINTF("mscale: %f \n", attn_factor); + Muls(theta_interp_local, theta_div_ff_local, freq_scale, broadcast_size); + PRINTF("\n"); + for (int i =0; i<128; i++) { + PRINTF("%d: %f, ", i, theta_interp_local.GetValue(i)); + } + + LocalTensor theta_local = theta_buffer.Get(); LocalTensor sin_mul_mscale_local = sin_mul_mscale_queue.AllocTensor(); - Muls(sin_mul_mscale_local, sin_local, attn_factor, broadcast_size); - - LocalTensor cos_local = cos_buffer.Get(); - Cos(cos_local, theta_local); LocalTensor cos_mul_mscale_local = cos_mul_mscale_queue.AllocTensor(); - Muls(cos_mul_mscale_local, cos_local, attn_factor, broadcast_size); + float_t mscale = 0; + if (ext_factor != 0.0f && freq_scale != 1.0f) { + // Need to check + // rope_yarn_ramp + LocalTensor arange_brcast_local = + arange_brcast_buffer.Get(); + BroadCast(arange_brcast_local, arange_local, + broadcast_shape0, arange_shape); + + LocalTensor arange_brcast_div_local = + arange_brcast_div_buff.Get(); + float_t div_param = 0.5; + Muls(arange_brcast_div_local, arange_brcast_local, div_param, broadcast_size); + + LocalTensor arange_brcast_div_add_local = + arange_brcast_div_add_buffer.Get(); + Adds(arange_brcast_div_add_local, arange_brcast_div_local, corr_dims0_neg, broadcast_size); + + LocalTensor rope_yarn_ramp_y_local = + rope_yarn_ramp_y_buffer.Get(); + Muls(rope_yarn_ramp_y_local, arange_brcast_div_add_local, rope_yarn_ramp_max_inv, broadcast_size); + + LocalTensor rope_yarn_ramp_y_max_local = + rope_yarn_ramp_y_max_buffer.Get(); + float_t max_param = 0.0f; + Maxs(rope_yarn_ramp_y_max_local, rope_yarn_ramp_y_local, max_param, broadcast_size); + + LocalTensor rope_yarn_ramp_y_maxmin_local = + rope_yarn_ramp_y_maxmin_buffer.Get(); + float_t mins_param = 1.0f; + Mins(rope_yarn_ramp_y_maxmin_local, rope_yarn_ramp_y_max_local, mins_param, broadcast_size); + + LocalTensor rope_yarn_ramp_y_maxmin_neg_local = + rope_yarn_ramp_y_maxmin_neg_buffer.Get(); + float_t neg_param = -1.0f; + Muls(rope_yarn_ramp_y_maxmin_neg_local, rope_yarn_ramp_y_maxmin_local, neg_param, broadcast_size); + + LocalTensor ramp_mix_local = ramp_mix_buffer.Get(); + float_t add_param = 1.0f; + Adds(ramp_mix_local, rope_yarn_ramp_y_maxmin_neg_local, add_param, broadcast_size); + + // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + LocalTensor theta0_local = theta0_buffer.Get(); + Mul(theta0_local, theta_interp_local, rope_yarn_ramp_y_maxmin_local, broadcast_size); + + LocalTensor theta1_local = theta1_buffer.Get(); + Mul(theta1_local, theta_div_ff_local, ramp_mix_local, broadcast_size); + + Add(theta_local, theta0_local, theta1_local, broadcast_size); + + // mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale) + LocalTensor arange_brcast_add_local = + arange_brcast_add_buffer.Get(); + add_param = 1.0f; + Adds(arange_brcast_add_local, arange_brcast_local, add_param, broadcast_size); + + LocalTensor mscale_local = mscale_buffer.Get(); + Power(mscale_local, static_cast(attn_factor), + arange_brcast_add_local); + + // + LocalTensor sin_local = sin_buffer.Get(); + LocalTensor cos_local = cos_buffer.Get(); + Sin(sin_local, theta_div_ff_local); + Cos(cos_local, theta_div_ff_local); + + Mul(sin_mul_mscale_local, sin_local, mscale_local, broadcast_size); + Mul(cos_mul_mscale_local, cos_local, mscale_local, broadcast_size); + + } + else { + theta_local = theta_interp_local; + mscale = attn_factor; + + LocalTensor sin_local = sin_buffer.Get(); + LocalTensor cos_local = cos_buffer.Get(); + Sin(sin_local, theta_local); + Cos(cos_local, theta_local); + + Muls(sin_mul_mscale_local, sin_local, mscale, broadcast_size); + Muls(cos_mul_mscale_local, cos_local, mscale, broadcast_size); + // PRINTF("\n sin: \n"); + // for (int i =0; i<128; i++) { + // PRINTF("%d: %f, ", i, sin_local.GetValue(i)); + // } + // PRINTF("\n sin * mscale: \n"); + // for (int i =0; i<128; i++) { + // PRINTF("%d: %f, ", i, sin_mul_mscale_local.GetValue(i)); + // } + } // release, VECCALC not need. arange_queue.FreeTensor(arange_local); @@ -209,8 +383,14 @@ class InitCache { float_t theta_scale; float_t attn_factor; float_t freq_scale; + float_t ext_factor; bool is_neox; - bool is_glm; + bool use_freq_factors; + float_t corr_dims0; + float_t corr_dims1; + float_t corr_dims0_neg; + float_t rope_yarn_ramp_max_inv; + uint32_t broadcast_shape0[2]; uint32_t broadcast_shape1[2]; @@ -226,18 +406,35 @@ class InitCache { TPipe pipe; GlobalTensor position_gm; + GlobalTensor freq_factors_gm; GlobalTensor output_sin_gm; GlobalTensor output_cos_gm; TQue arange_queue; TQue power_queue; TQue position_queue; + TQue freq_factors_queue; TQue sin_mul_mscale_queue; TQue cos_mul_mscale_queue; TBuf broadcast_power_buffer; + TBuf theta_base_buffer; TBuf theta_buffer; TBuf sin_buffer; TBuf cos_buffer; - + TBuf freq_factors_brcast_buffer; + TBuf theta_div_ff_buffer; + TBuf theta_interp_buffer; + TBuf arange_brcast_buffer; + TBuf arange_brcast_div_buff; + TBuf arange_brcast_div_add_buffer; + TBuf rope_yarn_ramp_y_buffer; + TBuf rope_yarn_ramp_y_max_buffer; + TBuf rope_yarn_ramp_y_maxmin_buffer; + TBuf rope_yarn_ramp_y_maxmin_neg_buffer; + TBuf ramp_mix_buffer; + TBuf theta0_buffer; + TBuf theta1_buffer; + TBuf arange_brcast_add_buffer; + TBuf mscale_buffer; }; template @@ -251,6 +448,7 @@ __aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, int32_t size) { extern "C" __global__ __aicore__ void ascendc_rope_init_cache( GM_ADDR position_gm, + GM_ADDR freq_factors_gm, GM_ADDR output_sin_gm, GM_ADDR output_cos_gm, GM_ADDR param, @@ -271,6 +469,6 @@ extern "C" __global__ __aicore__ void ascendc_rope_init_cache( copy_to_ub(input_ne_gm, input_ne_ub, 32); InitCache op; - op.init(position_gm, output_sin_gm, output_cos_gm, param_ub, input_ne_ub); + op.init(position_gm, freq_factors_gm, output_sin_gm, output_cos_gm, param_ub, input_ne_ub); op.calculate(); } \ No newline at end of file