fused out correction in CP #1248

Open · wants to merge 18 commits into base: main

Changes from all commits
194 changes: 161 additions & 33 deletions transformer_engine/common/fused_attn/thd_utils.h
@@ -102,58 +102,186 @@ __global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens,
}

/***************************************************************************************************
* Support BSHD, SBHD, and THD formats for Context Parallel: Out correction in forward
**************************************************************************************************/

// Format of out and lse, ignoring d, as it is always the last dimension.
enum QKVFormat { SBH, BSH, BHS, HBS, TH, HT };

template <int n>
struct TensorList {
void *addresses_out[n];
void *addresses_lse[n];
int start_tensor_this_launch;
};
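// Note: addresses_out and addresses_lse carry the device pointers of the per-step partial
// outputs and their softmax LSE tensors handled by one kernel launch (at most n per launch);
// start_tensor_this_launch bounds how many of them this launch consumes (see full_num below).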

// Describes a tensor's memory format to simplify index computation.
template <QKVFormat format>
struct TensorFormat {
// Stores the b/s/h/t traversal order (b, s, h, t map to 0, 1, 2, 3); store_format[3] marks whether b and s are fused into t.
int8_t store_format[4];
int *cu_seqlens_s;
// Sizes of the tensor in the order b, s, h, t.
int size[4];
__forceinline__ __device__ TensorFormat(int size_kernel[4], int *cu_seqlens = nullptr) {
for (int i = 0; i < 4; i++) {
size[i] = size_kernel[i];
}

if constexpr (format == QKVFormat::TH) {
cu_seqlens_s = cu_seqlens;
store_format[0] = 3;
store_format[1] = 2;
store_format[3] = 1;
} else if constexpr (format == QKVFormat::HT) {
cu_seqlens_s = cu_seqlens;
store_format[0] = 2;
store_format[1] = 3;
store_format[3] = 1;
} else if constexpr (format == QKVFormat::SBH) {
store_format[0] = 1;
store_format[1] = 0;
store_format[2] = 2;
store_format[3] = 0;
} else if constexpr (format == QKVFormat::HBS) {
store_format[0] = 2;
store_format[1] = 0;
store_format[2] = 1;
store_format[3] = 0;
} else if constexpr (format == QKVFormat::BSH) {
store_format[0] = 0;
store_format[1] = 1;
store_format[2] = 2;
store_format[3] = 0;
} else if constexpr (format == QKVFormat::BHS) {
store_format[0] = 0;
store_format[1] = 2;
store_format[2] = 1;
store_format[3] = 0;
}
}

// Compute the flattened element address for the given index.
__forceinline__ __device__ int compute_address(int id[4]) {
int address;
if (store_format[3] == 1) {
address = id[store_format[0]] * size[store_format[1]] + id[store_format[1]];
} else {
address = id[store_format[0]] * size[store_format[1]] + id[store_format[1]];
address = address * size[store_format[2]] + id[store_format[2]];
}
return address;
}

// Adjust the index to address the second (right) half of the sequence.
__forceinline__ __device__ void compute_half_right(int id[4]) {
if constexpr (format == QKVFormat::TH) {
id[1] -= (cu_seqlens_s[id[0] + 1] - cu_seqlens_s[id[0]]) / 2;
id[3] -= cu_seqlens_s[id[0] + 1] / 2;
} else if constexpr (format == QKVFormat::BSH || format == QKVFormat::SBH) {
id[1] -= size[1] / 2;
}
}
};
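// Example: with size = {b, s, h, t} as passed by the kernel below, BSH yields
// address = (id_b * s + id_s) * h + id_h, while the fused TH layout yields
// address = id_t * h + id_h; cu_seqlens_s is only needed to map a fused token index
// back to its sequence (see compute_half_right and the kernel below).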

template <typename dtype, int tile_size, bool causal, QKVFormat out_format, QKVFormat lse_format,
int max_tensors>
__global__ void fused_out_correction_kernel(dtype *out, TensorList<max_tensors> tensors, float *lse,
int *cu_seqlens, int batch, int num_heads,
int dim_per_head, int lse_seqlen, int cp_size, int rank,
int start) {
extern __shared__ int cu_seqlens_s[];
int full_num;
int num_total_tokens;

if constexpr (out_format == QKVFormat::TH) {
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i];
Collaborator: why is this needed?

Author: We need this to initialize TensorFormat when the THD format is used.
}
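// For THD, the raw cu_seqlens are kept in shared memory so that a token index can be
// mapped back to its sequence id and in-sequence offset via binary_search, and so that
// TensorFormat<TH>::compute_half_right can locate each sequence's second half.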
__syncthreads();
num_total_tokens = cu_seqlens_s[batch];
} else if constexpr (out_format == QKVFormat::SBH || out_format == QKVFormat::BSH) {
num_total_tokens = lse_seqlen * batch;
}
__syncthreads();

if constexpr (causal) {
full_num = min(start + tensors.start_tensor_this_launch, max(rank + 1, start));
} else {
full_num = start + tensors.start_tensor_this_launch;
}
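// full_num marks the end of the steps whose partial results cover the full local sequence;
// with causal masking, steps beyond `rank` only attended to the second half, so they are
// read through the half-tensor indices computed below.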

int size[4] = {batch, lse_seqlen, num_heads, lse_seqlen};
TensorFormat<out_format> out_full(size, cu_seqlens_s);
TensorFormat<lse_format> lse_full(size);

int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size;
int lane_id = threadIdx.x % tile_size;
int num_tiles = (blockDim.x * gridDim.x) / tile_size;
int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4);

size_t idx_out_full, idx_lse_full, idx_out_half, idx_lse_half;

for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) {
int head_id = blockIdx.y;
int id[4];
if constexpr (out_format == QKVFormat::TH) {
id[0] = binary_search(token_id, cu_seqlens_s, batch + 1);
id[1] = token_id - cu_seqlens_s[id[0]];
} else if constexpr (out_format == QKVFormat::BSH) {
id[0] = token_id / lse_seqlen;
id[1] = token_id - id[0] * lse_seqlen;
} else if constexpr (out_format == QKVFormat::SBH) {
id[1] = token_id / batch;
id[0] = token_id - id[1] * batch;
}
id[2] = head_id;
id[3] = token_id;

idx_out_full = out_full.compute_address(id);
idx_lse_full = lse_full.compute_address(id);

dtype *cur_out = out + idx_out_full * dim_per_head;
float lse_temp = lse[idx_lse_full];

int end = full_num;

if (start + tensors.start_tensor_this_launch > full_num) {
out_full.compute_half_right(id);
if (id[1] >= 0) {
int size_half[4] = {batch, lse_seqlen / 2, num_heads, lse_seqlen / 2};
TensorFormat<out_format> out_half(size_half);
TensorFormat<lse_format> lse_half(size_half);
idx_out_half = out_half.compute_address(id);
idx_lse_half = lse_half.compute_address(id);
end = start + tensors.start_tensor_this_launch;
}
}
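// Accumulation: each float4 chunk of the final output gathers every per-step partial output,
// scaled by exp(lse_per_step - lse); causal steps with i > rank are read through the
// half-tensor indices.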

for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
float4 data = reinterpret_cast<float4 *>(cur_out)[j];
dtype *p = reinterpret_cast<dtype *>(&data);

for (int i = start; i < end; i++) {
size_t idx_out;
size_t idx_lse;
if (causal && id[1] >= 0 && i > rank) {
idx_out = idx_out_half;
idx_lse = idx_lse_half;
} else {
idx_out = idx_out_full;
idx_lse = idx_lse_full;
}
dtype *cur_out_per_step =
reinterpret_cast<dtype *>(tensors.addresses_out[i]) + idx_out * dim_per_head;
float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j];
float lse_corrected_exp =
exp(reinterpret_cast<float *>(tensors.addresses_lse[i])[idx_lse] - lse_temp);
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
}
}
reinterpret_cast<float4 *>(cur_out)[j] = data;
}
}
}
82 changes: 13 additions & 69 deletions transformer_engine/pytorch/attention.py
@@ -1560,24 +1560,6 @@ def flash_attn_p2p_communicate(
return send_recv_reqs


@jit_fuser
def flash_attn_fwd_out_correction(
out: torch.Tensor,
out_per_step: torch.Tensor,
softmax_lse: torch.Tensor,
softmax_lse_per_step: torch.Tensor,
movedim_src: int,
movedim_dst: int,
):
"""Merge partial outputs of each step in Attention with context parallelism"""
softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(
movedim_src, movedim_dst
)
softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
out_corrected = out_per_step * softmax_lse_corrected_exp
out.add_(out_corrected)


@jit_fuser
def flash_attn_fwd_softmax_lse_correction(
softmax_lse: torch.Tensor,
@@ -2438,57 +2420,6 @@ def forward(
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])

softmax_lse = softmax_lse.to(torch.float)
for i in range(cp_size):
out_ = None
if qkv_format == "bshd":
out_per_step[i] = out_per_step[i].view(
out.shape[0], -1, *out.shape[-2:]
) # pylint: disable=used-before-assignment
out_ = out[:, 1, ...]
elif qkv_format == "sbhd":
out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:])
out_ = out[1]

if i <= rank or not causal:
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(
out.view(*out_per_step[i].shape),
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
)
elif qkv_format == "thd":
tex.thd_out_correction(
out,
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q_padded,
False,
softmax_lse_in_packed_format,
)
else:
if qkv_format in ["bshd", "sbhd"]:
flash_attn_fwd_out_correction(
out_,
out_per_step[i],
softmax_lse_[..., 1, :],
softmax_lse_per_step[i],
0 if softmax_lse_in_packed_format else 2,
2 if softmax_lse_in_packed_format else seq_dim,
)
elif qkv_format == "thd":
tex.thd_out_correction(
out,
out_per_step[i],
softmax_lse,
softmax_lse_per_step[i],
cu_seqlens_q_padded,
True,
softmax_lse_in_packed_format,
)

if qkv_format != "thd" and softmax_lse_in_packed_format:
# [np, b, sq] -> [np, t]
@@ -2501,6 +2432,19 @@
out = out.view(-1, *out.shape[-3:])
ctx.batch_size = out.shape[1]

tex.fused_out_correction(
out,
out_per_step,
softmax_lse,
softmax_lse_per_step,
cu_seqlens_q_padded,
qkv_format,
cp_size,
rank,
causal,
softmax_lse_in_packed_format,
)
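# Fused equivalent of the per-step correction removed above: for each step i it applies
# out += out_per_step[i] * exp(softmax_lse_per_step[i] - softmax_lse), restricted to the
# second half of each sequence for steps i > rank in the causal case.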

if cp_size_a2a > 1:
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False)
out = flash_attn_a2a_communicate(
9 changes: 5 additions & 4 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -446,17 +446,18 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
bool lse_packed);

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half, bool lse_packed);

void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half,
const std::string &second_half);

at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank);

void fused_out_correction(at::Tensor out, const std::vector<at::Tensor> &out_per_step,
const at::Tensor &lse, const std::vector<at::Tensor> &lse_per_step,
const at::Tensor &cu_seqlens, std::string qkv_format, int cp_size,
int rank, bool causal, bool softmax_lse_in_packed_format);
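// Exposed to Python as tex.fused_out_correction (see attention.py above); it subsumes the
// per-step thd_out_correction / flash_attn_fwd_out_correction calls for the BSHD, SBHD, and
// THD layouts.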

/***************************************************************************************************
* multi_tensor_* kernels
**************************************************************************************************/