From 685ad1b5afc0dfd321b673516359608342784361 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Fri, 15 Nov 2024 22:06:06 +0000 Subject: [PATCH] Moved framework agnostic THD kernels to common. Signed-off-by: Michael Goldfarb --- .../common/fused_attn/thd_utils.cu | 82 ++++++ .../common/fused_attn/thd_utils.h | 222 +++++++++++++++ .../pytorch/csrc/extensions/attention.cu | 265 ------------------ 3 files changed, 304 insertions(+), 265 deletions(-) create mode 100644 transformer_engine/common/fused_attn/thd_utils.cu create mode 100644 transformer_engine/common/fused_attn/thd_utils.h diff --git a/transformer_engine/common/fused_attn/thd_utils.cu b/transformer_engine/common/fused_attn/thd_utils.cu new file mode 100644 index 0000000000..cdb3f82b5b --- /dev/null +++ b/transformer_engine/common/fused_attn/thd_utils.cu @@ -0,0 +1,82 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "thd_utils.h" + +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + int seqlen = cu_seqlens[i]; + // Currently we assume that each sequence length is divisible by (world_size*2) since we have + // to distribute each sequence evenly to different GPUs. + assert(seqlen % (world_size * 2) == 0); + cu_seqlens_s[i] = seqlen / world_size; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + + for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + int index = token_id - cu_seqlens_s[seq_id]; + int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; + index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; + output[token_id] = index; + } +} + +__forceinline__ __device__ int binary_search(int target, int *array, int len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_total_tokens = cu_seqlens_s[batch]; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; + half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); + + for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { + int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); + + size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; + float4 *cur_half_token = + reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); + + offset_in_bytes = + (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; + float4 *cur_token = + reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); + + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { + cur_half_token[idx] = cur_token[idx]; + } + } +} diff --git a/transformer_engine/common/fused_attn/thd_utils.h b/transformer_engine/common/fused_attn/thd_utils.h new file mode 100644 index 0000000000..362c66ff21 --- /dev/null +++ b/transformer_engine/common/fused_attn/thd_utils.h @@ -0,0 +1,222 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_THS_UTILS_H_ + +#include +#include + +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank); + +/*************************************************************************************************** + * Support THD format for Context Parallel: Read the half of a THD tensor + **************************************************************************************************/ + +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token); + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related operations + **************************************************************************************************/ + +struct LseCorrectionFunctor { + __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, + size_t half_idx) { + double val = lse[idx]; + float val_per_step = half_lse[half_idx]; + double max_scale = max(val, val_per_step); + double min_scale = min(val, val_per_step); + lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); + } +}; + +struct ReadLseFunctor { + __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, + size_t half_idx) { + half_lse[half_idx] = lse[idx]; + } +}; + +template +__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, + int num_heads, int total_tokens) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + int num_total_tokens = cu_seqlens_s[batch]; + + for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, half_idx; + if constexpr (lse_packed) { + idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; + half_idx = head_id * total_tokens / 2 + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + + idx = row * total_tokens + col + seq_len; + half_idx = row * total_tokens / 2 + col; + } + + Functor::run(lse, half_lse, idx, half_idx); + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Out correction in forward + **************************************************************************************************/ + +template +__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, + float *lse_per_step, int *cu_seqlens, int batch, + int num_heads, int dim_per_head, int lse_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); + } + __syncthreads(); + + int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; + int lane_id = threadIdx.x % tile_size; + int num_tiles = (blockDim.x * gridDim.x) / tile_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); + + for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, idx_per_step; + + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * lse_seqlen + col + seq_len * only_second_half; + idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; + } + float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); + + idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx = (idx * num_heads + head_id) * dim_per_head; + idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; + dtype *cur_out = out + idx; + dtype *cur_out_per_step = out_per_step + idx_per_step; + + for (int j = lane_id; j < num_loops_per_head; j += tile_size) { + float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; + float4 data = reinterpret_cast(cur_out)[j]; + dtype *p_per_step = reinterpret_cast(&data_per_step); + dtype *p = reinterpret_cast(&data); + for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { + p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); + } + reinterpret_cast(cur_out)[j] = data; + } + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Gradients correction in backward + **************************************************************************************************/ + +struct EmptyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} +}; + +struct CopyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + } +}; + +template +struct AddFunctor { + __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); + + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); + +#pragma unroll + for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { + p_[i] += p[i]; + } + + reinterpret_cast(token)[idx] = d_; + } +}; + +template +__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, + int batch, int hidden_size, int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + if constexpr (functor_idx < 2) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } else { + cu_seqlens_s[i] = cu_seqlens[i]; + } + } + __syncthreads(); + + int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; + int lane_id = threadIdx.x % group_size; + int num_groups = (blockDim.x * gridDim.x) / group_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size; + if constexpr (functor_idx < 2) { + grad_per_step = grad_per_step + offset / 2 * blockIdx.y; + } else { + grad_per_step = grad_per_step + offset * blockIdx.y; + } + grad = grad + offset * blockIdx.y; + + for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + + int token_offset; + bool is_first_half; + if constexpr (functor_idx < 2) { + token_offset = cu_seqlens_s[seq_id + functor_idx]; + is_first_half = (functor_idx == 0); + } else { + token_offset = 0; + int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); + } + + dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; + dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; + for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { + if (is_first_half) { + Functor_0::run(token, token_per_step, idx); + } else { + Functor_1::run(token, token_per_step, idx); + } + } + } +} + +#endif diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 8088a2b8f1..9d657f2704 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1359,64 +1359,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } -/*************************************************************************************************** - * Support THD format for Context Parallel: Binary search - **************************************************************************************************/ - -__forceinline__ __device__ int binary_search(int target, int *array, int len) { - int left = 1, right = len - 1; - while (left < right) { - int mid = (left + right) / 2; - if (array[mid] <= target) { - left = mid + 1; - } else { - right = mid; - } - } - return left - 1; -} - /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ -__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, - int hidden_size_in_bytes, int half_idx, - int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; - int laneid = threadIdx.x % 32; - int num_warps = (blockDim.x * gridDim.x) / 32; - int num_total_tokens = cu_seqlens_s[batch]; - int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; - half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); - tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); - - for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { - int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); - - size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; - float4 *cur_half_token = - reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); - - offset_in_bytes = - (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; - float4 *cur_token = - reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); - - for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { - cur_half_token[idx] = cur_token[idx]; - } - } -} - at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, int half_idx) { NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); @@ -1464,51 +1410,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s * Support THD format for Context Parallel: softmax_lse related operations **************************************************************************************************/ -template -__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, - int num_heads, int total_tokens) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - int num_total_tokens = cu_seqlens_s[batch]; - - for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t idx, half_idx; - if constexpr (lse_packed) { - idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; - half_idx = head_id * total_tokens / 2 + token_id; - } else { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - - idx = row * total_tokens + col + seq_len; - half_idx = row * total_tokens / 2 + col; - } - - Functor::run(lse, half_lse, idx, half_idx); - } - } -} - -struct LseCorrectionFunctor { - __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, - size_t half_idx) { - double val = lse[idx]; - float val_per_step = half_lse[half_idx]; - double max_scale = max(val, val_per_step); - double min_scale = min(val, val_per_step); - lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); - } -}; - void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); @@ -1559,13 +1460,6 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st } } -struct ReadLseFunctor { - __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, - size_t half_idx) { - half_lse[half_idx] = lse[idx]; - } -}; - at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); @@ -1620,59 +1514,6 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ * Support THD format for Context Parallel: Out correction in forward **************************************************************************************************/ -template -__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, - float *lse_per_step, int *cu_seqlens, int batch, - int num_heads, int dim_per_head, int lse_seqlen) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); - } - __syncthreads(); - - int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; - int lane_id = threadIdx.x % tile_size; - int num_tiles = (blockDim.x * gridDim.x) / tile_size; - int num_total_tokens = cu_seqlens_s[batch]; - int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); - - for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t idx, idx_per_step; - - if constexpr (lse_packed) { - idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; - } else { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - idx = row * lse_seqlen + col + seq_len * only_second_half; - idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; - } - float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); - - idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx = (idx * num_heads + head_id) * dim_per_head; - idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; - dtype *cur_out = out + idx; - dtype *cur_out_per_step = out_per_step + idx_per_step; - - for (int j = lane_id; j < num_loops_per_head; j += tile_size) { - float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; - float4 data = reinterpret_cast(cur_out)[j]; - dtype *p_per_step = reinterpret_cast(&data_per_step); - dtype *p = reinterpret_cast(&data); - for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { - p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); - } - reinterpret_cast(cur_out)[j] = data; - } - } - } -} - template static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, @@ -1773,87 +1614,6 @@ void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at * Support THD format for Context Parallel: Gradients correction in backward **************************************************************************************************/ -template -__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, - int batch, int hidden_size, int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - if constexpr (functor_idx < 2) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } else { - cu_seqlens_s[i] = cu_seqlens[i]; - } - } - __syncthreads(); - - int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; - int lane_id = threadIdx.x % group_size; - int num_groups = (blockDim.x * gridDim.x) / group_size; - int num_total_tokens = cu_seqlens_s[batch]; - int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * hidden_size; - if constexpr (functor_idx < 2) { - grad_per_step = grad_per_step + offset / 2 * blockIdx.y; - } else { - grad_per_step = grad_per_step + offset * blockIdx.y; - } - grad = grad + offset * blockIdx.y; - - for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - - int token_offset; - bool is_first_half; - if constexpr (functor_idx < 2) { - token_offset = cu_seqlens_s[seq_id + functor_idx]; - is_first_half = (functor_idx == 0); - } else { - token_offset = 0; - int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); - } - - dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; - dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; - for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { - if (is_first_half) { - Functor_0::run(token, token_per_step, idx); - } else { - Functor_1::run(token, token_per_step, idx); - } - } - } -} - -struct EmptyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} -}; - -struct CopyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { - reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; - } -}; - -template -struct AddFunctor { - __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { - float4 d_ = reinterpret_cast(token)[idx]; - dtype *p_ = reinterpret_cast(&d_); - - float4 d = reinterpret_cast(token_per_step)[idx]; - dtype *p = reinterpret_cast(&d); - -#pragma unroll - for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { - p_[i] += p[i]; - } - - reinterpret_cast(token)[idx] = d_; - } -}; - template static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step, const at::Tensor &cu_seqlens) { @@ -1945,31 +1705,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, * Support THD format for Context Parallel: Generate partitioned indices for input tokens **************************************************************************************************/ -__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, - int total_tokens, int world_size, int rank) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - int seqlen = cu_seqlens[i]; - // Currently we assume that each sequence length is divisible by (world_size*2) since we have - // to distribute each sequence evenly to different GPUs. - assert(seqlen % (world_size * 2) == 0); - cu_seqlens_s[i] = seqlen / world_size; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - - for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - int index = token_id - cu_seqlens_s[seq_id]; - int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; - index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; - output[token_id] = index; - } -} - at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, int world_size, int rank) { NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);