[Common] Moved framework agnostic THD kernels to common. #1339
Merged: mgoldfarb-nvidia merged 4 commits into NVIDIA:main from mgoldfarb-nvidia:mgoldfarb/move_thd_fused_attn_utils_to_common on Nov 25, 2024.
Changes from all commits (4):
89fcca9  Moved framework agnostic THD kernels to common. (mgoldfarb-nvidia)
ff1729d  Force max jobs to 1 (mgoldfarb-nvidia)
136604a  Fix namespace issues. (mgoldfarb-nvidia)
f114214  Merge branch 'main' into mgoldfarb/move_thd_fused_attn_utils_to_common (mgoldfarb-nvidia)
@@ -0,0 +1,76 @@
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../cudnn_utils.h"
#include "thd_utils.h"

namespace transformer_engine {
namespace fused_attn {

__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
                                             int total_tokens, int world_size, int rank) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    int seqlen = cu_seqlens[i];
    // Currently we assume that each sequence length is divisible by (world_size*2) since we have
    // to distribute each sequence evenly to different GPUs.
    assert(seqlen % (world_size * 2) == 0);
    cu_seqlens_s[i] = seqlen / world_size;
  }
  __syncthreads();

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int num_threads = blockDim.x * gridDim.x;

  for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
    int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
    int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
    int index = token_id - cu_seqlens_s[seq_id];
    int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
    index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
    output[token_id] = index;
  }
}
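// Note on the mapping above: each sequence is viewed as 2 * world_size equally sized chunks,
// and rank `rank` is assigned the chunk pair (rank, 2 * world_size - 1 - rank), the usual
// load-balancing layout for causal attention under context parallelism.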

__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
                                            int hidden_size_in_bytes, int half_idx,
                                            int dim_size_of_token) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    cu_seqlens_s[i] = cu_seqlens[i] / 2;
  }
  __syncthreads();

  int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  int laneid = threadIdx.x % 32;
  int num_warps = (blockDim.x * gridDim.x) / 32;
  int num_total_tokens = cu_seqlens_s[batch];
  int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);

  size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size_in_bytes;
  half = reinterpret_cast<void *>(reinterpret_cast<char *>(half) + offset / 2 * blockIdx.y);
  tensor = reinterpret_cast<void *>(reinterpret_cast<char *>(tensor) + offset * blockIdx.y);

  for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
    int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);

    size_t offset_in_bytes = static_cast<size_t>(token_id) * hidden_size_in_bytes;
    float4 *cur_half_token =
        reinterpret_cast<float4 *>(reinterpret_cast<char *>(half) + offset_in_bytes);

    offset_in_bytes =
        (static_cast<size_t>(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes;
    float4 *cur_token =
        reinterpret_cast<float4 *>(reinterpret_cast<char *>(tensor) + offset_in_bytes);

    for (int idx = laneid; idx < num_float4s_per_token; idx += 32) {
      cur_half_token[idx] = cur_token[idx];
    }
  }
}

}  // namespace fused_attn
}  // namespace transformer_engine
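For reference, a minimal host-side launch sketch for thd_read_half_tensor_kernel (not part of this diff; the wrapper name, grid dimensions, and the num_tensors parameter are illustrative assumptions). The kernel expects dynamic shared memory large enough for batch + 1 cumulative sequence lengths, and blockIdx.y selects which stacked tensor slice a block copies:

// Hypothetical host-side wrapper (illustrative, not taken from this PR).
// Copies the half_idx-th half of every sequence of a THD tensor into `half`.
void launch_thd_read_half_tensor(void *half, void *tensor, int *cu_seqlens, int batch,
                                 int hidden_size_in_bytes, int half_idx, int dim_size_of_token,
                                 int num_tensors, cudaStream_t stream) {
  constexpr unsigned int threads = 256;
  const unsigned int blocks_x = 64;               // illustrative; sized for occupancy in practice
  const dim3 grid(blocks_x, num_tensors);         // blockIdx.y picks the stacked tensor slice
  const size_t smem = sizeof(int) * (batch + 1);  // shared copy of cu_seqlens
  transformer_engine::fused_attn::thd_read_half_tensor_kernel<<<grid, threads, smem, stream>>>(
      half, tensor, cu_seqlens, batch, hidden_size_in_bytes, half_idx, dim_size_of_token);
}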
@@ -0,0 +1,249 @@
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_

#include <cuda.h>
#include <cuda_bf16.h>

namespace transformer_engine {
namespace fused_attn {

/***************************************************************************************************
 * Support THD format for Context Parallel: Binary search an array for a target value
 **************************************************************************************************/

__forceinline__ __device__ int binary_search(int target, int *array, int len) {
  int left = 1, right = len - 1;
  while (left < right) {
    int mid = (left + right) / 2;
    if (array[mid] <= target) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }
  return left - 1;
}
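// Illustrative example: with cu_seqlens_s = {0, 4, 10, 16} and len = 4,
// binary_search(6, cu_seqlens_s, 4) returns 1, i.e. token 6 belongs to the
// second sequence, which spans tokens [4, 10).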

/***************************************************************************************************
 * Support THD format for Context Parallel: Generate partitioned indices for input tokens
 **************************************************************************************************/

__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
                                             int total_tokens, int world_size, int rank);

/***************************************************************************************************
 * Support THD format for Context Parallel: Read the half of a THD tensor
 **************************************************************************************************/

__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
                                            int hidden_size_in_bytes, int half_idx,
                                            int dim_size_of_token);

/***************************************************************************************************
 * Support THD format for Context Parallel: softmax_lse related operations
 **************************************************************************************************/

struct LseCorrectionFunctor {
  __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx,
                                             size_t half_idx) {
    double val = lse[idx];
    float val_per_step = half_lse[half_idx];
    double max_scale = max(val, val_per_step);
    double min_scale = min(val, val_per_step);
    lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
  }
};
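// Note: LseCorrectionFunctor's update max_scale + log(1 + exp(min_scale - max_scale)) is the
// numerically stable form of log(exp(lse[idx]) + exp(half_lse[half_idx])), i.e. a log-sum-exp
// merge of the two softmax statistics that avoids overflowing exp().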

struct ReadLseFunctor {
  __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
                                             size_t half_idx) {
    half_lse[half_idx] = lse[idx];
  }
};

template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
                               int num_heads, int total_tokens) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    cu_seqlens_s[i] = cu_seqlens[i] / 2;
  }
  __syncthreads();

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int num_threads = blockDim.x * gridDim.x;
  int num_total_tokens = cu_seqlens_s[batch];

  for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
    int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
    for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
      size_t idx, half_idx;
      if constexpr (lse_packed) {
        idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
        half_idx = head_id * total_tokens / 2 + token_id;
      } else {
        size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
        int col = token_id - cu_seqlens_s[seq_id];
        int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];

        idx = row * total_tokens + col + seq_len;
        half_idx = row * total_tokens / 2 + col;
      }

      Functor::run(lse, half_lse, idx, half_idx);
    }
  }
}

/***************************************************************************************************
 * Support THD format for Context Parallel: Out correction in forward
 **************************************************************************************************/

template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
                                          float *lse_per_step, int *cu_seqlens, int batch,
                                          int num_heads, int dim_per_head, int lse_seqlen) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
  }
  __syncthreads();

  int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size;
  int lane_id = threadIdx.x % tile_size;
  int num_tiles = (blockDim.x * gridDim.x) / tile_size;
  int num_total_tokens = cu_seqlens_s[batch];
  int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4);

  for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) {
    int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
    for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
      size_t idx, idx_per_step;

      if constexpr (lse_packed) {
        idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
        idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
      } else {
        size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
        int col = token_id - cu_seqlens_s[seq_id];
        int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
        idx = row * lse_seqlen + col + seq_len * only_second_half;
        idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
      }
      float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);

      idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
      idx = (idx * num_heads + head_id) * dim_per_head;
      idx_per_step = (static_cast<size_t>(token_id) * num_heads + head_id) * dim_per_head;
      dtype *cur_out = out + idx;
      dtype *cur_out_per_step = out_per_step + idx_per_step;

      for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
        float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j];
        float4 data = reinterpret_cast<float4 *>(cur_out)[j];
        dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
        dtype *p = reinterpret_cast<dtype *>(&data);
        for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
          p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
        }
        reinterpret_cast<float4 *>(cur_out)[j] = data;
      }
    }
  }
}
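// Per element, the inner loop above computes out += out_per_step * exp(lse_per_step - lse):
// each per-step partial output is rescaled by its softmax weight relative to the merged LSE
// before being accumulated, and zero entries are skipped.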

/***************************************************************************************************
 * Support THD format for Context Parallel: Gradients correction in backward
 **************************************************************************************************/

struct EmptyFunctor {
  __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {}
};

struct CopyFunctor {
  __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {
    reinterpret_cast<float4 *>(token)[idx] = reinterpret_cast<float4 *>(token_per_step)[idx];
  }
};

template <typename dtype>
struct AddFunctor {
  __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) {
    float4 d_ = reinterpret_cast<float4 *>(token)[idx];
    dtype *p_ = reinterpret_cast<dtype *>(&d_);

    float4 d = reinterpret_cast<float4 *>(token_per_step)[idx];
    dtype *p = reinterpret_cast<dtype *>(&d);

#pragma unroll
    for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
      p_[i] += p[i];
    }

    reinterpret_cast<float4 *>(token)[idx] = d_;
  }
};

template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx, int group_size>
__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens,
                                           int batch, int hidden_size, int dim_size_of_token) {
  extern __shared__ int cu_seqlens_s[];
  for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
    if constexpr (functor_idx < 2) {
      cu_seqlens_s[i] = cu_seqlens[i] / 2;
    } else {
      cu_seqlens_s[i] = cu_seqlens[i];
    }
  }
  __syncthreads();

  int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size;
  int lane_id = threadIdx.x % group_size;
  int num_groups = (blockDim.x * gridDim.x) / group_size;
  int num_total_tokens = cu_seqlens_s[batch];
  int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4);

  size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size;
  if constexpr (functor_idx < 2) {
    grad_per_step = grad_per_step + offset / 2 * blockIdx.y;
  } else {
    grad_per_step = grad_per_step + offset * blockIdx.y;
  }
  grad = grad + offset * blockIdx.y;

  for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) {
    int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);

    int token_offset;
    bool is_first_half;
    if constexpr (functor_idx < 2) {
      token_offset = cu_seqlens_s[seq_id + functor_idx];
      is_first_half = (functor_idx == 0);
    } else {
      token_offset = 0;
      int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
      is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2);
    }

    dtype *token = &grad[(token_id + token_offset) * static_cast<size_t>(hidden_size)];
    dtype *token_per_step = &grad_per_step[token_id * static_cast<size_t>(hidden_size)];
    for (int idx = lane_id; idx < num_inner_loops; idx += group_size) {
      if (is_first_half) {
        Functor_0::run(token, token_per_step, idx);
      } else {
        Functor_1::run(token, token_per_step, idx);
      }
    }
  }
}

}  // namespace fused_attn
}  // namespace transformer_engine

#endif
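For illustration, a hedged sketch of how the templated thd_lse_kernel might be instantiated with ReadLseFunctor to gather the second-half softmax_lse entries in the packed layout (the wrapper name and launch dimensions are assumptions, not code from this PR):

// Hypothetical launch (illustrative values): copy second-half LSE entries into half_lse.
void read_second_half_lse(float *lse, float *half_lse, int *cu_seqlens, int batch, int num_heads,
                          int total_tokens, cudaStream_t stream) {
  using namespace transformer_engine::fused_attn;
  const dim3 block(256);
  const dim3 grid(16, num_heads);                 // blockIdx.y strides over attention heads
  const size_t smem = sizeof(int) * (batch + 1);  // halved cu_seqlens staged in shared memory
  thd_lse_kernel<float, /*lse_packed=*/true, ReadLseFunctor>
      <<<grid, block, smem, stream>>>(lse, half_lse, cu_seqlens, batch, num_heads, total_tokens);
}

Swapping ReadLseFunctor for LseCorrectionFunctor (with lse_dtype = double) would turn the same launch into the log-sum-exp merge described above.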
cc @timmoon10