Commit e87e627: optimize twoshot

yizhang2077 committed Dec 31, 2024
1 parent 2f421b5 commit e87e627
Showing 3 changed files with 31 additions and 9 deletions.
sgl-kernel/pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post10"
+version = "0.0.2.post11"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu (34 changes: 28 additions & 6 deletions)

@@ -41,6 +41,16 @@ static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
   return flag;
 }
 
+static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
+  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+}
+
+static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
+  uint32_t flag;
+  asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+  return flag;
+}
+
 namespace trt_llm {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

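The new `st_flag_volatile` / `ld_flag_volatile` helpers mirror the existing `st_flag_release` / `ld_flag_acquire` pair but use plain volatile PTX accesses: each read or write still goes to global memory, so a spin loop observes peer updates, yet no acquire/release ordering is imposed on surrounding loads and stores, which is cheaper when the barrier only needs to rendezvous. A minimal sketch of a relaxed spin-wait built on the same instruction (hypothetical helper name, not part of this commit):

```cuda
#include <cstdint>

// Hypothetical sketch: spin on a flag with relaxed (volatile) semantics.
// ld.volatile.global.u32 forces a fresh read from global memory on every
// iteration, but unlike ld.acquire it does not order later memory accesses.
__device__ void wait_flag_relaxed(uint32_t* flag_addr, uint32_t expected) {
  uint32_t seen;
  do {
    asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(seen) : "l"(flag_addr));
  } while (seen != expected);
}
```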
@@ -117,7 +127,11 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
 }
 
 __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
-                                         size_t const world_size, int const tidx, int const bidx, int const grid_size) {
+                                         size_t const world_size, int const tidx, int const bidx, int const grid_size,
+                                         bool start = true, bool need_fence = false) {
+  if (!start) {
+    __syncthreads();
+  }
   // After this function, the block of id == bidx of each GPU has reached the barrier
   if (tidx < world_size) {
     // we can think of signals having the shape [world_size, 2, num_blocks, world_size]
@@ -131,12 +145,20 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
       flag_block_offset += (grid_size + 1) * world_size;
     }
 
-    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
-
+    if (need_fence) {
+      st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
+    } else {
+      st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
+    }
     // Blocks check that corresponding blocks on other GPUs have also set the flag
     uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
 
-    while (ld_flag_acquire(peer_barrier_d) != flag) {
+    if (need_fence) {
+      while (ld_flag_acquire(peer_barrier_d) != flag) {
+      }
+    } else {
+      while (ld_flag_volatile(peer_barrier_d) != flag) {
+      }
     }
   }
 }

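`block_barrier` gains two knobs: `start=false` prepends a `__syncthreads()` so every thread in the block has finished its stores before the flag is published, and `need_fence=true` selects the release/acquire flag operations, needed when data written before the barrier must be visible to peers afterward, instead of the cheaper volatile ones. A sketch of the resulting calling convention in a two-shot all-reduce; the phase comments and the first call's arguments are our reading of the code, only the second call appears verbatim in this commit:

```cuda
#include <cstdint>

// Signature as introduced by this commit (body elided here).
__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
                                         size_t const world_size, int const tidx, int const bidx, int const grid_size,
                                         bool start = true, bool need_fence = false);

__device__ void two_shot_phases_sketch(uint32_t** in_sig, uint32_t** out_sig, uint32_t flag,
                                       size_t rank, size_t world, int tidx, int bidx, int grid) {
  // Rendezvous before reduce-scatter: a volatile flag is enough, since no
  // peer data written inside the kernel is being ordered yet.
  block_barrier(in_sig, flag, rank, world, tidx, bidx, grid);
  // ... each rank reduces its assigned chunk ...
  // Publish the reduced chunk: sync the block first (start=false), then use
  // release stores / acquire loads (need_fence=true) so peers are guaranteed
  // to see the reduced data before the all-gather reads it.
  block_barrier(out_sig, flag, rank, world, tidx, bidx, grid, false, true);
  // ... all-gather: copy peers' reduced chunks into the local output ...
}
```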
@@ -217,7 +239,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 }
 
 template <typename T, int RANKS_PER_NODE>
-static __global__ void twoShotAllReduceKernel(AllReduceParams params) {
+static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
   // The message is partitioned into chunks as detailed below:
   //                  message
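`__launch_bounds__(512, 1)` promises the compiler that `twoShotAllReduceKernel` is never launched with more than 512 threads per block (matching the new `DEFAULT_BLOCK_SIZE` below) and requests at least one resident block per SM, so nvcc can budget registers for the smaller block. A generic illustration of the qualifier, not taken from this commit:

```cuda
// Capping threads-per-block at 512 lets the compiler assume
// blockDim.x * blockDim.y * blockDim.z <= 512 and allocate registers
// accordingly, typically reducing register pressure and spills.
__global__ void __launch_bounds__(512, 1) scale_kernel(float* data, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= s;
}
```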
@@ -313,7 +335,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params) {
   }
 
   block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size);
+                grid_size, false, true);
 
   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh (4 changes: 2 additions & 2 deletions)

@@ -25,9 +25,9 @@
 
 namespace trt_llm {
 constexpr size_t WARP_SIZE = 32;
-constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
+constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
 constexpr size_t MAX_RANKS_PER_NODE = 8;
-constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
+constexpr size_t DEFAULT_BLOCK_SIZE = 512;
 
 enum class AllReduceStrategyType : int8_t {
   RING = 0,
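The launch budget moves from at most 24 blocks of up to 1024 threads (24,576 threads) to 36 blocks of up to 512 threads (18,432): more, smaller blocks suit the barrier-heavy two-shot kernel and line up with the `__launch_bounds__(512, 1)` annotation above. A hypothetical sizing helper in the spirit of these constants (the name, packing factor, and logic are assumptions, not code from this commit):

```cpp
#include <algorithm>
#include <cstddef>

constexpr std::size_t MAX_ALL_REDUCE_BLOCKS = 36;
constexpr std::size_t DEFAULT_BLOCK_SIZE = 512;

// Hypothetical: cover num_elts fp16 elements, assuming each thread moves one
// packed 16-byte vector (8 fp16 values), capped by the block budget above.
inline std::size_t pick_grid_size(std::size_t num_elts) {
  const std::size_t elts_per_block = DEFAULT_BLOCK_SIZE * 8;
  const std::size_t needed = (num_elts + elts_per_block - 1) / elts_per_block;
  return std::min(MAX_ALL_REDUCE_BLOCKS, std::max<std::size_t>(1, needed));
}
```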
