[Common] Moved framework-agnostic THD kernels to common. #1339

Merged
1 change: 1 addition & 0 deletions .github/workflows/build.yml
@@ -70,6 +70,7 @@ jobs:
run: pip install . -v
env:
NVTE_FRAMEWORK: jax
MAX_JOBS: 1
- name: 'Sanity check'
run: python tests/jax/test_sanity_import.py
paddle:
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -61,6 +61,7 @@ list(APPEND transformer_engine_SOURCES
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/thd_utils.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp
76 changes: 76 additions & 0 deletions transformer_engine/common/fused_attn/thd_utils.cu
@@ -0,0 +1,76 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../cudnn_utils.h"
#include "thd_utils.h"

namespace transformer_engine {
namespace fused_attn {

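// Maps each token owned by `rank` to its index in the full (gathered) sequence, under the
// load-balanced context-parallel sharding in which every sequence is split into
// 2 * world_size chunks and rank r owns chunks r and 2 * world_size - 1 - r.
// Requires sizeof(int) * (batch + 1) bytes of dynamic shared memory at launch.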
__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
int total_tokens, int world_size, int rank) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
int seqlen = cu_seqlens[i];
// Currently we assume that each sequence length is divisible by (world_size*2) since we have
// to distribute each sequence evenly to different GPUs.
assert(seqlen % (world_size * 2) == 0);
cu_seqlens_s[i] = seqlen / world_size;
}
__syncthreads();

int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;

for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
int index = token_id - cu_seqlens_s[seq_id];
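    // First-half tokens map to chunk `rank`; second-half tokens map to chunk
    // 2 * world_size - 1 - rank (the extra seq_len / 2 is already contained in `index`).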
int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
output[token_id] = index;
}
}

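// Copies the first (half_idx == 0) or second (half_idx == 1) half of every sequence from
// `tensor` into the contiguous buffer `half`. One warp handles one token, copying
// hidden_size_in_bytes in 16-byte (float4) chunks, so the hidden size in bytes must be a
// multiple of sizeof(float4).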
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();

int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int laneid = threadIdx.x % 32;
int num_warps = (blockDim.x * gridDim.x) / 32;
int num_total_tokens = cu_seqlens_s[batch];
int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);

size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size_in_bytes;
half = reinterpret_cast<void *>(reinterpret_cast<char *>(half) + offset / 2 * blockIdx.y);
tensor = reinterpret_cast<void *>(reinterpret_cast<char *>(tensor) + offset * blockIdx.y);
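  // blockIdx.y advances `tensor` by one full [dim_size_of_token, hidden] slab and `half`
  // by half a slab, so a single launch can process gridDim.y such tensors laid out
  // contiguously.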

for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);

size_t offset_in_bytes = static_cast<size_t>(token_id) * hidden_size_in_bytes;
float4 *cur_half_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(half) + offset_in_bytes);

offset_in_bytes =
(static_cast<size_t>(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes;
float4 *cur_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(tensor) + offset_in_bytes);

for (int idx = laneid; idx < num_float4s_per_token; idx += 32) {
cur_half_token[idx] = cur_token[idx];
}
}
}

} // namespace fused_attn
} // namespace transformer_engine
249 changes: 249 additions & 0 deletions transformer_engine/common/fused_attn/thd_utils.h
@@ -0,0 +1,249 @@
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_

#include <cuda.h>
#include <cuda_bf16.h>

namespace transformer_engine {
namespace fused_attn {

/***************************************************************************************************
* Support THD format for Context Parallel: Binary search an array for a target value
**************************************************************************************************/

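// Given an ascending array of cumulative sequence lengths (array[0] == 0, len == batch + 1),
// returns the index i with array[i] <= target < array[i + 1], i.e. the sequence that token
// `target` belongs to.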
__forceinline__ __device__ int binary_search(int target, int *array, int len) {
int left = 1, right = len - 1;
while (left < right) {
int mid = (left + right) / 2;
if (array[mid] <= target) {
left = mid + 1;
} else {
right = mid;
}
}
return left - 1;
}

/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/

__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
int total_tokens, int world_size, int rank);

/***************************************************************************************************
 * Support THD format for Context Parallel: Read one half of a THD tensor
**************************************************************************************************/

__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token);

/***************************************************************************************************
 * Support THD format for Context Parallel: softmax_lse-related operations
**************************************************************************************************/

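// Accumulates a per-step LSE into the running LSE via a numerically stable log-sum-exp:
// lse = log(exp(lse) + exp(lse_per_step)).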
struct LseCorrectionFunctor {
__forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx,
size_t half_idx) {
double val = lse[idx];
float val_per_step = half_lse[half_idx];
double max_scale = max(val, val_per_step);
double min_scale = min(val, val_per_step);
lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
}
};

struct ReadLseFunctor {
__forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
size_t half_idx) {
half_lse[half_idx] = lse[idx];
}
};

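// Applies Functor to the LSE entry of every second-half token. With lse_packed the LSE
// tensors are addressed as [num_heads, total_tokens] and [num_heads, total_tokens / 2];
// otherwise as [batch, num_heads, total_tokens] and [batch, num_heads, total_tokens / 2].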
template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int total_tokens) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();

int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
int num_total_tokens = cu_seqlens_s[batch];

for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, half_idx;
if constexpr (lse_packed) {
idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
half_idx = head_id * total_tokens / 2 + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];

idx = row * total_tokens + col + seq_len;
half_idx = row * total_tokens / 2 + col;
}

Functor::run(lse, half_lse, idx, half_idx);
}
}
}

/***************************************************************************************************
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/

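// Merges a per-step attention output into the accumulated output with the usual softmax
// rescaling, out += out_per_step * exp(lse_per_step - lse). With only_second_half == 1 the
// per-step tensors cover only the second half of each sequence. A tile of `tile_size`
// threads handles one token, moving 16 bytes (float4) per thread, so
// dim_per_head * sizeof(dtype) must be a multiple of sizeof(float4).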
template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
float *lse_per_step, int *cu_seqlens, int batch,
int num_heads, int dim_per_head, int lse_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
}
__syncthreads();

int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size;
int lane_id = threadIdx.x % tile_size;
int num_tiles = (blockDim.x * gridDim.x) / tile_size;
int num_total_tokens = cu_seqlens_s[batch];
int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4);

for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, idx_per_step;

if constexpr (lse_packed) {
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);

idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx = (idx * num_heads + head_id) * dim_per_head;
idx_per_step = (static_cast<size_t>(token_id) * num_heads + head_id) * dim_per_head;
dtype *cur_out = out + idx;
dtype *cur_out_per_step = out_per_step + idx_per_step;

for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j];
float4 data = reinterpret_cast<float4 *>(cur_out)[j];
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
dtype *p = reinterpret_cast<dtype *>(&data);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
}
reinterpret_cast<float4 *>(cur_out)[j] = data;
}
}
}
}

/***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/

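// Per-float4 operations applied by thd_grad_correction_kernel: skip the token, overwrite it
// with the per-step gradient, or accumulate the per-step gradient into it.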
struct EmptyFunctor {
__forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {}
};

struct CopyFunctor {
__forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {
reinterpret_cast<float4 *>(token)[idx] = reinterpret_cast<float4 *>(token_per_step)[idx];
}
};

template <typename dtype>
struct AddFunctor {
__forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) {
float4 d_ = reinterpret_cast<float4 *>(token)[idx];
dtype *p_ = reinterpret_cast<dtype *>(&d_);

float4 d = reinterpret_cast<float4 *>(token_per_step)[idx];
dtype *p = reinterpret_cast<dtype *>(&d);

#pragma unroll
for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
p_[i] += p[i];
}

reinterpret_cast<float4 *>(token)[idx] = d_;
}
};

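// Combines per-step gradients into the full gradient tensor. For functor_idx 0 or 1,
// grad_per_step holds half-length sequences and Functor_0 / Functor_1 is applied to the
// first / second half of each sequence in `grad`; for functor_idx >= 2, grad_per_step is
// full length and Functor_0 handles first-half tokens while Functor_1 handles second-half
// tokens.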
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx, int group_size>
__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens,
int batch, int hidden_size, int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
if constexpr (functor_idx < 2) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
} else {
cu_seqlens_s[i] = cu_seqlens[i];
}
}
__syncthreads();

int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size;
int lane_id = threadIdx.x % group_size;
int num_groups = (blockDim.x * gridDim.x) / group_size;
int num_total_tokens = cu_seqlens_s[batch];
int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4);

size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size;
if constexpr (functor_idx < 2) {
grad_per_step = grad_per_step + offset / 2 * blockIdx.y;
} else {
grad_per_step = grad_per_step + offset * blockIdx.y;
}
grad = grad + offset * blockIdx.y;

for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);

int token_offset;
bool is_first_half;
if constexpr (functor_idx < 2) {
token_offset = cu_seqlens_s[seq_id + functor_idx];
is_first_half = (functor_idx == 0);
} else {
token_offset = 0;
int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2);
}

dtype *token = &grad[(token_id + token_offset) * static_cast<size_t>(hidden_size)];
dtype *token_per_step = &grad_per_step[token_id * static_cast<size_t>(hidden_size)];
for (int idx = lane_id; idx < num_inner_loops; idx += group_size) {
if (is_first_half) {
Functor_0::run(token, token_per_step, idx);
} else {
Functor_1::run(token, token_per_step, idx);
}
}
}
}

} // namespace fused_attn
} // namespace transformer_engine

#endif
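
For orientation, a minimal host-side launch sketch for thd_partition_indices_kernel, not part of this PR: the launcher name, block size, and include path are illustrative assumptions, it must live in a .cu translation unit, and it passes the sizeof(int) * (batch + 1) bytes of dynamic shared memory the kernels above expect.

#include <cuda_runtime.h>

#include "thd_utils.h"  // transformer_engine/common/fused_attn/thd_utils.h

// Hypothetical launcher; grid/block sizing is an assumption, not part of this PR.
void launch_thd_partition_indices(int *output, int *cu_seqlens, int batch, int total_tokens,
                                  int world_size, int rank, cudaStream_t stream) {
  constexpr int block = 256;                     // assumed threads per block
  int local_tokens = total_tokens / world_size;  // tokens owned by this rank
  int grid = (local_tokens + block - 1) / block;
  // The kernel caches cu_seqlens in dynamic shared memory: (batch + 1) ints.
  size_t smem = sizeof(int) * (batch + 1);
  transformer_engine::fused_attn::thd_partition_indices_kernel<<<grid, block, smem, stream>>>(
      output, cu_seqlens, batch, total_tokens, world_size, rank);
}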