Support twoshot kernel #2688
Merged · 3 commits · Jan 5, 2025
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
version = "0.0.2.post10"
version = "0.0.2.post11"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
206 changes: 204 additions & 2 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
@@ -41,6 +41,16 @@ static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
return flag;
}

static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
uint32_t flag;
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
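For context (illustrative, not part of the diff): these volatile helpers only guarantee that the flag load/store bypasses register caching; they impose no ordering on surrounding data accesses. A minimal sketch of how they pair up in a spin-wait handshake, where toy_flag_handshake, my_flag_on_peer, and peer_flag_local are hypothetical names for peer-mapped uint32_t words (e.g. exchanged via CUDA IPC):

// Sketch only: two peers rendezvous on a pair of peer-mapped flag words.
static inline __device__ void toy_flag_handshake(uint32_t* my_flag_on_peer, uint32_t* peer_flag_local,
                                                 uint32_t flag) {
  // Publish our arrival to the peer; no ordering w.r.t. earlier data stores.
  st_flag_volatile(flag, my_flag_on_peer);
  // Spin until the peer has published the same flag value.
  while (ld_flag_volatile(peer_flag_local) != flag) {
  }
  // When earlier data stores must be visible to the peer before the flag is,
  // use st_flag_release / ld_flag_acquire instead (see need_fence in block_barrier below).
}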

namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -116,6 +126,45 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
__syncthreads();
}

__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
size_t const world_size, int const tidx, int const bidx, int const grid_size,
bool start = true, bool need_fence = false) {
if (!start) {
__syncthreads();
}
// After this function, the block of id == bidx of each GPU has reached the barrier
if (tidx < world_size) {
// we can think of signals having the shape [world_size, 2, num_blocks, world_size]
// (+ an offset on dim 2 to account for flags used in multi_gpu_barrier)
// Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension

// The block broadcasts its flag (local_rank on the emitting dimension) to all receivers
uint32_t flag_block_offset = world_size + bidx * world_size;

if (flag % 2 == 1) {
flag_block_offset += (grid_size + 1) * world_size;
}

if (need_fence) {
st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
} else {
st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
}
// Blocks check that corresponding blocks on other GPUs have also set the flag
uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;

if (need_fence) {
while (ld_flag_acquire(peer_barrier_d) != flag) {
}
} else {
while (ld_flag_volatile(peer_barrier_d) != flag) {
}
}
}

__syncthreads();
}
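Aside (illustrative, not part of the diff): the indexing above treats each rank's signal buffer as a [2, grid_size + 1, world_size] array of uint32_t flags, where the leading row of each parity half is reserved for the flags used by multi_gpu_barrier. A small sketch of the same arithmetic, with hypothetical helper names:

// Sketch only: offset of the flag word written by emitter_rank for block bidx,
// mirroring the flag_block_offset computation in block_barrier above.
__host__ __device__ inline size_t barrier_flag_offset(uint32_t flag, int bidx, int emitter_rank,
                                                      int grid_size, int world_size) {
  size_t offset = world_size + static_cast<size_t>(bidx) * world_size;  // skip the multi_gpu_barrier row
  if (flag % 2 == 1) {
    offset += static_cast<size_t>(grid_size + 1) * world_size;  // odd flags use the second parity half
  }
  return offset + static_cast<size_t>(emitter_rank);
}

// Minimum number of uint32_t flag words each rank's signal buffer needs under this layout.
inline size_t barrier_flag_words(int grid_size, int world_size) {
  return 2ull * static_cast<size_t>(grid_size + 1) * static_cast<size_t>(world_size);
}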

template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
@@ -189,6 +238,124 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
}
}

template <typename T, int RANKS_PER_NODE>
static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// |--GPU 0--|--GPU 1--| (GPU responsibility parts)
// GPU 0 | B0 | B1 | B0 | B1 |
// GPU 1 | B0 | B1 | B0 | B1 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies all chunks it is responsible for from local_input to the shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
// 3. B0 on GPU 0 gathers and sums the B0 chunks from GPU 1 that fall in GPU 0's responsibility
//    part (the first half of the message, see the GPU responsibility row above)
// 3bis. Likewise, B0 on GPU 1 gathers and sums the chunks from GPU 0
//       for which GPU 1 is responsible: the second half of the message.
// 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
// 5. B0 writes the result to local_output. It gathers each chunk from the GPU responsible for it.
//    For example, here it reads the first chunk from GPU 0 and the second chunk from GPU 1.
//
// With COPY_INPUT == false, step 1 is skipped and gpu_barrier is used instead of the block barrier in step 2.
// We only need to know that the other GPU has arrived at the AR kernel, which means its data is ready
// to be read.
//
// Note that compared to one-shot, one block (CTA) writes multiple input chunks and writes multiple output chunks.
// However, it's only responsible for the summation of a single chunk.
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunks it is responsible for to the corresponding GPUs:
// params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
// 2. block barrier, so the chunks pushed by the other GPUs have arrived
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
// 4. block barrier (corresponding blocks have finished reduction)
// 5. pull and write to the local buffer by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (the reduction
//    result is written at index 0 of the 2nd dim)

int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
int const grid_size = gridDim.x;

// The number of elements packed into one 16-byte access for comms
static constexpr int PACKED_ELTS = 16 / sizeof(T);
using PackedType = typename PackedOn16Bytes<T>::Type;

T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);

size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);

T* buffers[RANKS_PER_NODE];
int ranks[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// Map the ranks so that reads are scattered across peers as much as possible
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
ranks[ii] = rank;
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
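// (Device-side hooks for CUDA Programmatic Dependent Launch on sm_90+: the call above waits for the
// grid this launch depends on, and cudaTriggerProgrammaticLaunchCompletion() at the end of this kernel
// allows a dependent grid to start early.)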

block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
grid_size);

// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
size_t const responsible_block_offset = local_offset + params.rank_offset;

// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
}

// Sum the values from the different ranks.
PackedType sums;
sums.packed = {0, 0, 0, 0};
#pragma unroll
for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
// Always reduce from rank 0 to ensure stable reduce order.
int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
sums.packed = add128b(sums, vals[ii]);
}

// Store to the local buffer.
*reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
}

block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
grid_size, false, true);
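// (need_fence = true here: the release/acquire flag pair guarantees that each rank's reduced sums,
// stored to its shared buffer above, are visible to peers before the gather loop below reads them.)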

// Gather all needed elts from other intra-node ranks
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// use round-robin gathering from other ranks
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total) {
continue;
}

*reinterpret_cast<int4*>(&local_output_buffer[offset_rank]) = *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
}
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
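For readers new to the schedule (illustrative, not part of the diff): a serial host-side reference of the two-shot pattern the kernel above implements, i.e. a reduce-scatter followed by an all-gather. Here bufs[r] stands in for rank r's shared buffer, already holding rank r's input, and sizes are assumed to divide evenly by world_size:

#include <vector>

void twoshot_reference(std::vector<std::vector<float>>& bufs, std::vector<std::vector<float>>& outs) {
  int const world_size = static_cast<int>(bufs.size());
  size_t const n = bufs[0].size();
  size_t const per_rank = n / world_size;

  // Shot 1 (reduce-scatter): rank r sums everyone's copy of its own slice and writes
  // the result back into its shared buffer (the store to local_shared_buffer above).
  for (int r = 0; r < world_size; ++r) {
    for (size_t i = r * per_rank; i < (r + 1) * per_rank; ++i) {
      float sum = 0.f;
      for (int peer = 0; peer < world_size; ++peer) {
        sum += bufs[peer][i];
      }
      bufs[r][i] = sum;
    }
  }

  // Shot 2 (all-gather): every rank copies each reduced slice from the rank that owns it
  // (the round-robin gather loop above).
  for (int r = 0; r < world_size; ++r) {
    outs[r].assign(n, 0.f);
    for (int owner = 0; owner < world_size; ++owner) {
      for (size_t i = owner * per_rank; i < (owner + 1) * per_rank; ++i) {
        outs[r][i] = bufs[owner][i];
      }
    }
  }
}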

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int divUp(int a, int b) {
@@ -211,6 +378,33 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
params.elts_per_rank = params.elts_total;
break;
}
case AllReduceStrategyType::TWOSHOT: {
assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0);
size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE);

/*
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
*/
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
blocks_per_grid += 1;
}

threads_per_block = total_threads / blocks_per_grid;

// NOTE: cap blocks_per_grid if the search above exceeded MAX_ALL_REDUCE_BLOCKS; this heuristic may need tuning
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
size_t iter_factor = 1;
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
iter_factor += 1;
}
blocks_per_grid /= iter_factor;
}
params.elts_per_rank = params.elts_total / params.ranks_per_node;
params.rank_offset = params.local_rank * params.elts_per_rank;
params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread);
break;
}
default:
assert(false && "Algorithm not supported here.");
}
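For intuition (illustrative, not part of the diff): a standalone sketch of the TWOSHOT configuration arithmetic above for one concrete case, fp16 input (elts_per_thread = 16 / sizeof(half) = 8), 8 ranks, 1,048,576 elements. It assumes blocks_per_grid starts at 1 (its initializer sits outside this hunk) and that divUp/roundUp have the usual ceil-divide and round-up-to-multiple meaning:

#include <cstdio>

int main() {
  size_t const WARP_SIZE = 32, DEFAULT_BLOCK_SIZE = 512, MAX_ALL_REDUCE_BLOCKS = 36;
  size_t const elts_total = 1048576, ranks_per_node = 8, elts_per_thread = 8;

  auto div_up = [](size_t a, size_t b) { return (a + b - 1) / b; };
  auto round_up = [&](size_t a, size_t b) { return div_up(a, b) * b; };

  // roundUp(1048576 / (8 * 8), 32) = 16384 threads in total.
  size_t const total_threads = round_up(elts_total / (elts_per_thread * ranks_per_node), WARP_SIZE);

  // The divisibility search stops at 32 blocks of 512 threads (16384 % 32 == 0, 16384 / 32 == 512).
  size_t blocks_per_grid = 1;
  while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
    blocks_per_grid += 1;
  }
  size_t const threads_per_block = total_threads / blocks_per_grid;

  // Not hit in this example (32 <= 36), but mirrors the cap in the diff.
  if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
    size_t iter_factor = 1;
    while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
      iter_factor += 1;
    }
    blocks_per_grid /= iter_factor;
  }

  size_t const elts_per_rank = elts_total / ranks_per_node;                                       // 131072
  size_t const elts_per_block = round_up(div_up(elts_per_rank, blocks_per_grid), elts_per_thread);  // 4096
  printf("blocks=%zu threads=%zu elts_per_rank=%zu elts_per_block=%zu\n", blocks_per_grid,
         threads_per_block, elts_per_rank, elts_per_block);
  return 0;
}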
@@ -223,7 +417,16 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
template <typename T, int RANKS_PER_NODE>
void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
cudaStream_t stream) {
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
switch (algo) {
case AllReduceStrategyType::ONESHOT: {
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
break;
}
case AllReduceStrategyType::TWOSHOT: {
twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
break;
}
}
}

template <typename T>
@@ -233,7 +436,6 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy
CHECK_CUDA_SUCCESS(
cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));

assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot");
CHECK_CUDA_SUCCESS(cudaGetLastError());

size_t elts_per_thread = 16 / sizeof(T);
14 changes: 6 additions & 8 deletions sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -25,9 +25,9 @@

namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;

enum class AllReduceStrategyType : int8_t {
RING = 0,
@@ -53,9 +53,9 @@ struct AllReduceParams {

inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
if (world_size <= 2) {
return 16 * 1000 * 1000;
return 16 * 1024 * 1024;
}
return 8 * 1000 * 1000;
return 8 * 1024 * 1024;
}

inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
@@ -71,17 +71,15 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world
}

if (world_size <= 4) {
if (message_size < 1 * 1000 * 1000) {
if (message_size < 1 * 1024 * 1024) {
return AllReduceStrategyType::ONESHOT;
}
assert(false && "Custom allreduce do not twoshot currently");
return AllReduceStrategyType::TWOSHOT;
}

if (message_size < 500 * 1000) {
if (message_size < 512 * 1024) {
return AllReduceStrategyType::ONESHOT;
}
assert(false && "Custom allreduce do not twoshot currently");
return AllReduceStrategyType::TWOSHOT;
}
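Spot-check (illustrative, not part of the PR): how the thresholds above now map message sizes to strategies, assuming this header is included, the call sits inside namespace trt_llm, and the earlier workspace/support checks outside this hunk do not override the choice for these sizes:

#include <cassert>

inline void check_strategy_thresholds() {
  // world_size <= 4: one-shot below 1 MiB, two-shot from 1 MiB upwards.
  assert(SelectImplementation(512 * 1024, 4) == AllReduceStrategyType::ONESHOT);
  assert(SelectImplementation(2 * 1024 * 1024, 4) == AllReduceStrategyType::TWOSHOT);
  // world_size > 4: the one-shot cutoff drops to 512 KiB.
  assert(SelectImplementation(256 * 1024, 8) == AllReduceStrategyType::ONESHOT);
  assert(SelectImplementation(1024 * 1024, 8) == AllReduceStrategyType::TWOSHOT);
}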

2 changes: 1 addition & 1 deletion sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
@@ -71,7 +71,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);

// should be guaranteed in Python code
assert(strategy == AllReduceStrategyType::ONESHOT);
assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT);
assert(CanApplyCustomAllReduce(num_elements, dtype));

// Initialize the all-reduce kernel arguments.
13 changes: 4 additions & 9 deletions sgl-kernel/tests/test_trt_reduce.py
@@ -55,13 +55,8 @@ class TestCustomAllReduce(unittest.TestCase):
@classmethod
def setUpClass(cls):
random.seed(42)
cls.test_sizes = {
2: [512, 4096, 32768, 262144, 2097152],
4: [512, 4096, 32768, 131072],
6: [512, 4096, 32768, 65536],
8: [512, 4096, 32768, 65536],
}
cls.world_sizes = [2, 4, 6, 8]
cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
cls.world_sizes = [2, 4, 8]

@staticmethod
def create_shared_buffer(
@@ -194,7 +189,7 @@ def correctness(self, world_size, rank, distributed_init_port):
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

test_loop = 10
for sz in self.test_sizes[world_size]:
for sz in self.test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
for _ in range(test_loop):
inp1 = torch.randint(
@@ -216,7 +211,7 @@ def performance(self, world_size, rank, distributed_init_port):
self.init_vllm_allreduce(rank, group)
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

for sz in self.test_sizes[world_size]:
for sz in self.test_sizes:
inp1 = torch.randint(
1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
)