-
Notifications
You must be signed in to change notification settings - Fork 346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fused out correction in CP #1248
base: main
Are you sure you want to change the base?
Changes from 15 commits
49a738e
314c841
c0f1998
c8af3f7
c66d980
dafbd5a
96dfd5e
2df8bd4
4f397da
e33b2a2
35c9836
fd83104
f977de3
97882f1
9d7e3cf
89bbeb7
d346d9c
b2a2fc2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,58 +102,187 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, | |
} | ||
|
||
/*************************************************************************************************** | ||
* Support THD format for Context Parallel: Out correction in forward | ||
* Support THD(including SBHD and BSHD) format for Context Parallel: Out correction in forward | ||
**************************************************************************************************/ | ||
|
||
template <typename dtype, int only_second_half, int tile_size, bool lse_packed> | ||
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, | ||
float *lse_per_step, int *cu_seqlens, int batch, | ||
int num_heads, int dim_per_head, int lse_seqlen) { | ||
// format of out and lse, ignoring d as it’s always the last dimension. | ||
enum QKVFormat { SBH, BSH, BHS, HBS, TH, HT }; | ||
|
||
template <int n> | ||
struct TensorList { | ||
void *addresses_out[n]; | ||
void *addresses_lse[n]; | ||
int start_tensor_this_launch; | ||
}; | ||
|
||
// describe tensor format for simplified computation. | ||
template <QKVFormat format> | ||
struct TensorFormat { | ||
// store the bsht order for simplified computation, where bsht corresponds to 0, 1, 2, 3, and store_format[3] marks whether bs is fused into t | ||
int8_t store_format[4]; | ||
int *cu_seqlens_s; | ||
// size of tensor, b s h t | ||
int size[4]; | ||
__forceinline__ __device__ TensorFormat(int size_kernel[4], int *cu_seqlens = nullptr) { | ||
for (int i = 0; i < 4; i++) { | ||
size[i] = size_kernel[i]; | ||
} | ||
|
||
if constexpr (format == QKVFormat::TH) { | ||
cu_seqlens_s = cu_seqlens; | ||
store_format[0] = 3; | ||
store_format[1] = 2; | ||
store_format[3] = 1; | ||
} else if constexpr (format == QKVFormat::HT) { | ||
cu_seqlens_s = cu_seqlens; | ||
store_format[0] = 2; | ||
store_format[1] = 3; | ||
store_format[3] = 1; | ||
} else if constexpr (format == QKVFormat::SBH) { | ||
store_format[0] = 1; | ||
store_format[1] = 0; | ||
store_format[2] = 2; | ||
store_format[3] = 0; | ||
} else if constexpr (format == QKVFormat::HBS) { | ||
store_format[0] = 2; | ||
store_format[1] = 0; | ||
store_format[2] = 1; | ||
store_format[3] = 0; | ||
} else if constexpr (format == QKVFormat::BSH) { | ||
store_format[0] = 0; | ||
store_format[1] = 1; | ||
store_format[2] = 2; | ||
store_format[3] = 0; | ||
} else if constexpr (format == QKVFormat::BHS) { | ||
store_format[0] = 0; | ||
store_format[1] = 2; | ||
store_format[2] = 1; | ||
store_format[3] = 0; | ||
} | ||
} | ||
|
||
// calculate address according to index | ||
__forceinline__ __device__ int compute_address(int id[4]) { | ||
int address; | ||
if (store_format[3] == 1) { | ||
address = id[store_format[0]] * size[store_format[1]] + id[store_format[1]]; | ||
} else { | ||
address = id[store_format[0]] * size[store_format[1]] + id[store_format[1]]; | ||
address = address * size[store_format[2]] + id[store_format[2]]; | ||
} | ||
return address; | ||
} | ||
|
||
// compute half right index | ||
__forceinline__ __device__ void compute_half_right(int id[4]) { | ||
if constexpr (format == QKVFormat::TH) { | ||
id[1] -= (cu_seqlens_s[id[0] + 1] - cu_seqlens_s[id[0]]) / 2; | ||
id[3] -= cu_seqlens_s[id[0] + 1] / 2; | ||
} else if constexpr (format == QKVFormat::BSH || format == QKVFormat::SBH) { | ||
id[1] -= size[1] / 2; | ||
} | ||
} | ||
}; | ||
|
||
template <typename dtype, int tile_size, bool causal, QKVFormat out_format, QKVFormat lse_format, | ||
int max_tensors> | ||
__global__ void fused_out_correction_kernel(dtype *out, TensorList<max_tensors> tensors, float *lse, | ||
int *cu_seqlens, int batch, int num_heads, | ||
int dim_per_head, int lse_seqlen, int cp_size, int rank, | ||
int start) { | ||
extern __shared__ int cu_seqlens_s[]; | ||
for (int i = threadIdx.x; i <= batch; i += blockDim.x) { | ||
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); | ||
int full_num; | ||
int num_total_tokens; | ||
|
||
if constexpr (out_format == QKVFormat::TH) { | ||
for (int i = threadIdx.x; i <= batch; i += blockDim.x) { | ||
cu_seqlens_s[i] = cu_seqlens[i]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need this to initialize TensorFormat when THD format is applied |
||
} | ||
__syncthreads(); | ||
num_total_tokens = cu_seqlens_s[batch]; | ||
} else if constexpr (out_format == QKVFormat::SBH || out_format == QKVFormat::BSH) { | ||
num_total_tokens = lse_seqlen * batch; | ||
} | ||
|
||
if constexpr (causal) { | ||
full_num = min(start + tensors.start_tensor_this_launch, max(rank + 1, start)); | ||
} else { | ||
full_num = start + tensors.start_tensor_this_launch; | ||
} | ||
__syncthreads(); | ||
|
||
int size[4] = {batch, lse_seqlen, num_heads, lse_seqlen}; | ||
TensorFormat<out_format> out_full(size, cu_seqlens_s); | ||
TensorFormat<lse_format> lse_full(size); | ||
|
||
int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; | ||
int lane_id = threadIdx.x % tile_size; | ||
int num_tiles = (blockDim.x * gridDim.x) / tile_size; | ||
int num_total_tokens = cu_seqlens_s[batch]; | ||
int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); | ||
|
||
for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { | ||
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); | ||
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { | ||
size_t idx, idx_per_step; | ||
size_t idx_out_full, idx_lse_full, idx_out_half, idx_lse_half; | ||
|
||
if constexpr (lse_packed) { | ||
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; | ||
idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; | ||
} else { | ||
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id; | ||
int col = token_id - cu_seqlens_s[seq_id]; | ||
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; | ||
idx = row * lse_seqlen + col + seq_len * only_second_half; | ||
idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; | ||
for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { | ||
int head_id = blockIdx.y; | ||
int id[4]; | ||
if constexpr (out_format == QKVFormat::TH) { | ||
id[0] = binary_search(token_id, cu_seqlens_s, batch + 1); | ||
id[1] = token_id - cu_seqlens_s[id[0]]; | ||
} else if constexpr (out_format == QKVFormat::BSH) { | ||
id[0] = token_id / lse_seqlen; | ||
id[1] = token_id - id[0] * lse_seqlen; | ||
} else if constexpr (out_format == QKVFormat::SBH) { | ||
id[1] = token_id / batch; | ||
id[0] = token_id - id[1] * batch; | ||
} | ||
id[2] = head_id; | ||
id[3] = token_id; | ||
|
||
idx_out_full = out_full.compute_address(id); | ||
idx_lse_full = lse_full.compute_address(id); | ||
|
||
dtype *cur_out = out + idx_out_full * dim_per_head; | ||
float lse_temp = lse[idx_lse_full]; | ||
|
||
int end = full_num; | ||
|
||
if (start + tensors.start_tensor_this_launch > full_num) { | ||
out_full.compute_half_right(id); | ||
if (id[1] >= 0) { | ||
int size_half[4] = {batch, lse_seqlen / 2, num_heads, lse_seqlen / 2}; | ||
TensorFormat<out_format> out_half(size_half); | ||
TensorFormat<lse_format> lse_half(size_half); | ||
idx_out_half = out_half.compute_address(id); | ||
idx_lse_half = lse_half.compute_address(id); | ||
end = start + tensors.start_tensor_this_launch; | ||
} | ||
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); | ||
} | ||
|
||
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; | ||
idx = (idx * num_heads + head_id) * dim_per_head; | ||
idx_per_step = (static_cast<size_t>(token_id) * num_heads + head_id) * dim_per_head; | ||
dtype *cur_out = out + idx; | ||
dtype *cur_out_per_step = out_per_step + idx_per_step; | ||
for (int j = lane_id; j < num_loops_per_head; j += tile_size) { | ||
size_t idx_out; | ||
size_t idx_lse; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okok |
||
float4 data = reinterpret_cast<float4 *>(cur_out)[j]; | ||
dtype *p = reinterpret_cast<dtype *>(&data); | ||
|
||
for (int i = start; i < end; i++) { | ||
if (id[1] >= 0 && start + tensors.start_tensor_this_launch > full_num && i > rank) { | ||
idx_out = idx_out_half; | ||
idx_lse = idx_lse_half; | ||
} else { | ||
idx_out = idx_out_full; | ||
idx_lse = idx_lse_full; | ||
} | ||
|
||
for (int j = lane_id; j < num_loops_per_head; j += tile_size) { | ||
dtype *cur_out_per_step = | ||
reinterpret_cast<dtype *>(tensors.addresses_out[i]) + idx_out * dim_per_head; | ||
float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j]; | ||
float4 data = reinterpret_cast<float4 *>(cur_out)[j]; | ||
float lse_corrected_exp = | ||
exp(reinterpret_cast<float *>(tensors.addresses_lse[i])[idx_lse] - lse_temp); | ||
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step); | ||
dtype *p = reinterpret_cast<dtype *>(&data); | ||
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { | ||
p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); | ||
} | ||
reinterpret_cast<float4 *>(cur_out)[j] = data; | ||
} | ||
reinterpret_cast<float4 *>(cur_out)[j] = data; | ||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1473,24 +1473,6 @@ def flash_attn_p2p_communicate( | |
return send_recv_reqs | ||
|
||
|
||
@jit_fuser | ||
def flash_attn_fwd_out_correction( | ||
out: torch.Tensor, | ||
out_per_step: torch.Tensor, | ||
softmax_lse: torch.Tensor, | ||
softmax_lse_per_step: torch.Tensor, | ||
movedim_src: int, | ||
movedim_dst: int, | ||
): | ||
"""Merge partial outputs of each step in Attention with context parallelism""" | ||
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim( | ||
movedim_src, movedim_dst | ||
) | ||
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) | ||
out_corrected = out_per_step * softmax_lse_corrected_exp | ||
out.add_(out_corrected) | ||
|
||
|
||
@jit_fuser | ||
def flash_attn_fwd_softmax_lse_correction( | ||
softmax_lse: torch.Tensor, | ||
|
@@ -2351,57 +2333,46 @@ def forward( | |
torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) | ||
|
||
softmax_lse = softmax_lse.to(torch.float) | ||
for i in range(cp_size): | ||
out_ = None | ||
if qkv_format == "bshd": | ||
out_per_step[i] = out_per_step[i].view( | ||
out.shape[0], -1, *out.shape[-2:] | ||
) # pylint: disable=used-before-assignment | ||
out_ = out[:, 1, ...] | ||
elif qkv_format == "sbhd": | ||
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) | ||
out_ = out[1] | ||
|
||
if i <= rank or not causal: | ||
if qkv_format in ["bshd", "sbhd"]: | ||
flash_attn_fwd_out_correction( | ||
out.view(*out_per_step[i].shape), | ||
out_per_step[i], | ||
softmax_lse, | ||
softmax_lse_per_step[i], | ||
0 if softmax_lse_in_packed_format else 2, | ||
2 if softmax_lse_in_packed_format else seq_dim, | ||
) | ||
elif qkv_format == "thd": | ||
tex.thd_out_correction( | ||
out, | ||
out_per_step[i], | ||
softmax_lse, | ||
softmax_lse_per_step[i], | ||
cu_seqlens_q_padded, | ||
False, | ||
softmax_lse_in_packed_format, | ||
) | ||
else: | ||
if qkv_format in ["bshd", "sbhd"]: | ||
flash_attn_fwd_out_correction( | ||
out_, | ||
out_per_step[i], | ||
softmax_lse_[..., 1, :], | ||
softmax_lse_per_step[i], | ||
0 if softmax_lse_in_packed_format else 2, | ||
2 if softmax_lse_in_packed_format else seq_dim, | ||
) | ||
elif qkv_format == "thd": | ||
tex.thd_out_correction( | ||
out, | ||
out_per_step[i], | ||
softmax_lse, | ||
softmax_lse_per_step[i], | ||
cu_seqlens_q_padded, | ||
True, | ||
softmax_lse_in_packed_format, | ||
) | ||
|
||
if qkv_format == "sbhd": | ||
tex.fused_out_correction( | ||
out.view(-1, *out.shape[-3:]), | ||
out_per_step, | ||
softmax_lse, | ||
softmax_lse_per_step, | ||
cu_seqlens_q_padded, | ||
qkv_format, | ||
cp_size, | ||
rank, | ||
causal, | ||
softmax_lse_in_packed_format, | ||
) | ||
elif qkv_format == "bshd": | ||
tex.fused_out_correction( | ||
out.view(out.shape[-4], -1, *out.shape[-2:]), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, i will fix this |
||
out_per_step, | ||
softmax_lse, | ||
softmax_lse_per_step, | ||
cu_seqlens_q_padded, | ||
qkv_format, | ||
cp_size, | ||
rank, | ||
causal, | ||
softmax_lse_in_packed_format, | ||
) | ||
else: | ||
tex.fused_out_correction( | ||
out, | ||
out_per_step, | ||
softmax_lse, | ||
softmax_lse_per_step, | ||
cu_seqlens_q_padded, | ||
qkv_format, | ||
cp_size, | ||
rank, | ||
causal, | ||
softmax_lse_in_packed_format, | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we merger this two function into one? I see they are almost same. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok they are merged, and remove softmax_lse_ in cuda code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lots of arguments are same, only first one is different? should remove the duplicated code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, thanks for your suggestion |
||
|
||
if qkv_format != "thd" and softmax_lse_in_packed_format: | ||
# [np, b, sq] -> [np, t] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,6 +136,11 @@ std::vector<at::Tensor> fused_attn_bwd( | |
at::Tensor fa_prepare_fwd(at::Tensor qkvi); | ||
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); | ||
|
||
void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step, | ||
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step, | ||
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size, | ||
int rank, bool causal, bool softmax_lse_in_packed_format); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. put these two functions close to other CP helper functions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why don't you fix this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i will fix this |
||
/*************************************************************************************************** | ||
* GEMM | ||
**************************************************************************************************/ | ||
|
@@ -446,10 +451,6 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st | |
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, | ||
bool lse_packed); | ||
|
||
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, | ||
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, | ||
bool only_second_half, bool lse_packed); | ||
|
||
void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, | ||
const at::Tensor &cu_seqlens, const std::string &first_half, | ||
const std::string &second_half); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change this to "Support BSHD, SBHD, and THD formats for Context Parallel: Out correction in forward"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok