Add rotate_half implementation for fused_rope (#56401)
* add rotate_half in fused_rope

* add position_ids in fused_rope

* modified examples about fused_rope

* add set_device in examples
tianhaodongbd authored Sep 4, 2023
1 parent be9cb94 commit c089a2a
Showing 11 changed files with 459 additions and 114 deletions.
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/fused_backward.yaml
@@ -17,10 +17,10 @@
   support_dygraph_mode : true

 - backward_op : fused_rotary_position_embedding_grad
-  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
-  args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad, Tensor out_v_grad)
+  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
+  args : (Tensor sin, Tensor cos, Tensor position_ids, Tensor out_q_grad, Tensor out_k_grad, Tensor out_v_grad, bool use_neox_rotary_style)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
-  optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad
+  optional : sin, cos, position_ids, out_k_grad, out_v_grad, k_grad, v_grad
   infer_meta :
     func : FusedRopeGradInferMeta
   kernel :
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -149,11 +149,11 @@
   optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index

 - op : fused_rotary_position_embedding
-  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
+  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true)
   output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
   infer_meta :
     func : FusedRopeInferMeta
-  optional : k, v, sin, cos, out_k, out_v
+  optional : k, v, sin, cos, position_ids, out_k, out_v
   kernel :
     func : fused_rotary_position_embedding
     data_type : q
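For context, a minimal usage sketch of the op after this change. This is not an example from the commit: the import path paddle.incubate.nn.functional.fused_rotary_position_embedding and the keyword names are assumed to mirror the YAML signature above.

# Hypothetical usage sketch; argument names follow the YAML signature above
# and are not verified against this commit.
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

paddle.set_device("gpu")  # the commit also adds set_device to the docs examples

batch, seq_len, num_heads, head_dim = 2, 8, 4, 64
shape = [batch, seq_len, num_heads, head_dim]
q = paddle.randn(shape, dtype="float16")
k = paddle.randn(shape, dtype="float16")
v = paddle.randn(shape, dtype="float16")

# sin/cos are broadcast over batch and heads: [1, seq_len, 1, head_dim].
sin = paddle.randn([1, seq_len, 1, head_dim], dtype="float16")
cos = paddle.randn([1, seq_len, 1, head_dim], dtype="float16")

# Optional per-token indices into the sin/cos tables: [batch, seq_len], int64.
position_ids = paddle.arange(seq_len, dtype="int64").unsqueeze(0).tile([batch, 1])

out_q, out_k, out_v = fused_rotary_position_embedding(
    q, k, v, sin=sin, cos=cos,
    position_ids=position_ids,
    use_neox_rotary_style=False,  # False selects the new rotate_half path
)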
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -1219,9 +1219,11 @@ void IndexPutGradInferMeta(const MetaTensor& x,

 void FusedRopeGradInferMeta(const MetaTensor& sin,
                             const MetaTensor& cos,
+                            const MetaTensor& position_ids,
                             const MetaTensor& dout_q,
                             const MetaTensor& dout_k,
                             const MetaTensor& dout_v,
+                            bool use_neox_rotary_style,
                             MetaTensor* dq,
                             MetaTensor* dk,
                             MetaTensor* dv) {
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -186,9 +186,11 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,

 void FusedRopeGradInferMeta(const MetaTensor& sin,
                             const MetaTensor& cos,
+                            const MetaTensor& position_ids,
                             const MetaTensor& dout_q,
                             const MetaTensor& dout_k,
                             const MetaTensor& dout_v,
+                            bool use_neox_rotary_style,
                             MetaTensor* dq,
                             MetaTensor* dk,
                             MetaTensor* dv);
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -4041,6 +4041,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         const MetaTensor& v,
                         const MetaTensor& sin,
                         const MetaTensor& cos,
+                        const MetaTensor& position_ids,
+                        bool use_neox_rotary_style,
                         MetaTensor* out_q,
                         MetaTensor* out_k,
                         MetaTensor* out_v) {
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -807,6 +807,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         const MetaTensor& v,
                         const MetaTensor& sin,
                         const MetaTensor& cos,
+                        const MetaTensor& position_ids,
+                        bool use_neox_rotary_style,
                         MetaTensor* out_q,
                         MetaTensor* out_k,
                         MetaTensor* out_v);
48 changes: 36 additions & 12 deletions paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu
@@ -27,9 +27,11 @@ template <typename T, typename Context>
 void FusedRopeGradKernel(const Context& dev_ctx,
                          const paddle::optional<DenseTensor>& sin,
                          const paddle::optional<DenseTensor>& cos,
+                         const paddle::optional<DenseTensor>& position_ids,
                          const DenseTensor& dout_q,
                          const paddle::optional<DenseTensor>& dout_k,
                          const paddle::optional<DenseTensor>& dout_v,
+                         bool use_neox_rotary_style,
                          DenseTensor* dq,
                          DenseTensor* dk,
                          DenseTensor* dv) {
@@ -58,6 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   phi::Array<T*, 3> outs_data;
   phi::Array<const T*, 3> ins_data;
   phi::Array<const T*, 2> sin_cos_data;
+  const int64_t* position_ids_data = NULL;

   ins_data[0] = dout_q.data<T>();
   outs_data[0] = dq->data<T>();
@@ -86,21 +89,42 @@ void FusedRopeGradKernel(const Context& dev_ctx,
     sin_cos_data[1] = cos->data<T>();

     flag_sin_cos = true;
+
+    if (position_ids.get_ptr()) {
+      position_ids_data = position_ids->data<int64_t>();
+    }
   }

   int sign = -1;
-  VectorizedFusedRopeKernel<T, MPType, vec_size>
-      <<<grid, block, 0, stream>>>(ins_data,
-                                   sin_cos_data,
-                                   flag_sin_cos,
-                                   sign,
-                                   batch_size,
-                                   seq_len,
-                                   num_heads,
-                                   head_dim,
-                                   outs_data,
-                                   num_inputs,
-                                   div_c);
+  if (use_neox_rotary_style) {
+    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  } else {
+    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  }
 }

 }  // namespace fusion
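A note on the sign = -1 above: the backward pass reuses the forward rotation because a rotation matrix is orthogonal, so the gradient operator is its transpose, i.e. the same rotation with the angle negated. A small NumPy sketch of that identity for one (even, odd) feature pair; the names here are illustrative, not the CUDA code:

import numpy as np

def rotate_pair(x0, x1, sin, cos, sign):
    # sign = +1 is the forward RoPE rotation on one (even, odd) pair;
    # sign = -1 applies the transposed (inverse) rotation.
    return x0 * cos - sign * x1 * sin, x1 * cos + sign * x0 * sin

theta = 0.3
x0, x1 = 1.5, -0.4
y0, y1 = rotate_pair(x0, x1, np.sin(theta), np.cos(theta), sign=+1)
# Rotating back with sign = -1 recovers the inputs, which is exactly what
# propagating gradients through an orthogonal map requires.
r0, r1 = rotate_pair(y0, y1, np.sin(theta), np.cos(theta), sign=-1)
assert np.allclose([r0, r1], [x0, x1])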
99 changes: 78 additions & 21 deletions paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu
@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
                      const paddle::optional<DenseTensor>& v,
                      const paddle::optional<DenseTensor>& sin,
                      const paddle::optional<DenseTensor>& cos,
+                     const paddle::optional<DenseTensor>& position_ids,
+                     bool use_neox_rotary_style,
                      DenseTensor* out_q,
                      DenseTensor* out_k,
                      DenseTensor* out_v) {
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
   phi::Array<T*, 3> outs_data;
   phi::Array<const T*, 3> ins_data;
   phi::Array<const T*, 2> sin_cos_data;
+  const int64_t* position_ids_data = NULL;

   ins_data[0] = q.data<T>();
   outs_data[0] = out_q->data<T>();
@@ -109,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx,
              "The batch_size and num_heads of sin and cos must be 1."));
     }
     int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0;
-    PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
-                       sin_dims[sin_seq_len_dim] == seq_len),
-                      true,
-                      phi::errors::InvalidArgument(
-                          "The seq_len and head_dim of sin and cos "
-                          "must be the same as those of q. But received sin's "
-                          "shape is {%s}, q's shape is {%s}.",
-                          sin_dims,
-                          q.dims()));
+
+    if (position_ids.get_ptr()) {
+      PADDLE_ENFORCE_EQ(
+          (sin_dims[dims_size - 1] == head_dim &&
+           sin_dims[sin_seq_len_dim] >= seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The seq_len of sin and cos must be greater than or equal to "
+              "that of q, and the head_dim of sin and cos must be the same "
+              "as that of q. But received sin's "
+              "shape is {%s}, q's shape is {%s}.",
+              sin_dims,
+              q.dims()));
+
+      auto position_ids_dims = position_ids.get_ptr()->dims();
+      PADDLE_ENFORCE_EQ(position_ids_dims.size(),
+                        2,
+                        phi::errors::InvalidArgument(
+                            "The number of dims of position_ids is "
+                            "expected to be 2, but received %d.",
+                            position_ids_dims.size()));
+
+      PADDLE_ENFORCE_EQ(
+          (position_ids_dims[0] == batch_size &&
+           position_ids_dims[1] == seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The batch_size and seq_len of position_ids must be the same "
+              "as those of q. But received position_ids's "
+              "shape is {%s}, q's shape is {%s}.",
+              position_ids_dims,
+              q.dims()));
+
+      position_ids_data = position_ids->data<int64_t>();
+    } else {
+      PADDLE_ENFORCE_EQ(
+          (sin_dims[dims_size - 1] == head_dim &&
+           sin_dims[sin_seq_len_dim] == seq_len),
+          true,
+          phi::errors::InvalidArgument(
+              "The seq_len and head_dim of sin and cos "
+              "must be the same as those of q. But received sin's "
+              "shape is {%s}, q's shape is {%s}.",
+              sin_dims,
+              q.dims()));
+    }

     sin_cos_data[0] = sin->data<T>();
     sin_cos_data[1] = cos->data<T>();
@@ -126,18 +166,35 @@ void FusedRopeKernel(const Context& dev_ctx,
   }

   int sign = 1;
-  VectorizedFusedRopeKernel<T, MPType, vec_size>
-      <<<grid, block, 0, stream>>>(ins_data,
-                                   sin_cos_data,
-                                   flag_sin_cos,
-                                   sign,
-                                   batch_size,
-                                   seq_len,
-                                   num_heads,
-                                   head_dim,
-                                   outs_data,
-                                   num_inputs,
-                                   div_c);
+  if (use_neox_rotary_style) {
+    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  } else {
+    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  }
 }
 }  // namespace fusion
 }  // namespace phi
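To summarize the two rotation layouts the kernels above dispatch between, here is a NumPy reference sketch based on my reading of the code; shapes and helper names are illustrative assumptions, not Paddle API. The NeoX style (use_neox_rotary_style=True) rotates each adjacent (even, odd) feature pair, while rotate_half pairs element i with element i + head_dim/2. When position_ids is supplied it gathers per-token rows from the sin/cos tables, which is why their seq_len may exceed q's.

import numpy as np

def apply_rope_rotate_half(x, sin, cos, position_ids=None):
    # x: [batch, seq_len, num_heads, head_dim]
    # sin, cos: [max_seq_len, head_dim]; position_ids: [batch, seq_len] int64
    if position_ids is not None:
        sin, cos = sin[position_ids], cos[position_ids]  # -> [b, s, head_dim]
    else:
        sin, cos = sin[None], cos[None]                  # broadcast over batch
    sin, cos = sin[:, :, None, :], cos[:, :, None, :]    # broadcast over heads
    half = x.shape[-1] // 2
    rotated = np.concatenate([-x[..., half:], x[..., :half]], axis=-1)
    return x * cos + rotated * sin

def apply_rope_rotate_every_two(x, sin, cos):
    # NeoX style; sin/cos already broadcast to x's shape, and each pair's
    # angle assumed stored twice (sin[..., 2*i] == sin[..., 2*i+1]), one
    # common convention for this layout.
    out = np.empty_like(x)
    x0, x1 = x[..., 0::2], x[..., 1::2]
    s, c = sin[..., 0::2], cos[..., 0::2]
    out[..., 0::2] = x0 * c - x1 * s
    out[..., 1::2] = x1 * c + x0 * s
    return out

The fused CUDA kernels compute the same maps in vectorized form directly on q, k and v; the sketch only gives the reference semantics.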
