diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 54582a7877..359ffafd70 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post10"
+version = "0.0.2.post11"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
index 0a7cfdb537..b4d17ded19 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu
@@ -41,6 +41,16 @@ static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
   return flag;
 }
 
+static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) {
+  asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+}
+
+static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) {
+  uint32_t flag;
+  asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+  return flag;
+}
+
 namespace trt_llm {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -117,7 +127,11 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const
 }
 
 __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
-                                         size_t const world_size, int const tidx, int const bidx, int const grid_size) {
+                                         size_t const world_size, int const tidx, int const bidx, int const grid_size,
+                                         bool start = true, bool need_fence = false) {
+  if (!start) {
+    __syncthreads();
+  }
   // After this function, the block of id == bidx of each GPU has reached the barrier
   if (tidx < world_size) {
     // we can think of signals having the shape [world_size, 2, num_blocks, world_size]
@@ -131,12 +145,20 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
       flag_block_offset += (grid_size + 1) * world_size;
     }
 
-    st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
-
+    if (need_fence) {
+      st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
+    } else {
+      st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank);
+    }
     // Blocks check that corresponding blocks on other GPUs have also set the flag
     uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
-    while (ld_flag_acquire(peer_barrier_d) != flag) {
+    if (need_fence) {
+      while (ld_flag_acquire(peer_barrier_d) != flag) {
+      }
+    } else {
+      while (ld_flag_volatile(peer_barrier_d) != flag) {
+      }
     }
   }
 
@@ -217,7 +239,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
 }
 
 template <typename T, int RANKS_PER_NODE>
-static __global__ void twoShotAllReduceKernel(AllReduceParams params) {
+static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) {
   // Suppose that two GPUs participate in the AR exchange, and we start two blocks.
   // The message is partitioned into chunks as detailed below:
   //     message
@@ -313,7 +335,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params) {
   }
 
   block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx,
-                grid_size);
+                grid_size, false, true);
 
   // Gather all needed elts from other intra-node ranks
   for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) {
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
index e707592613..1c7c714dc4 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -25,9 +25,9 @@ namespace trt_llm {
 constexpr size_t WARP_SIZE = 32;
-constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
+constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
 constexpr size_t MAX_RANKS_PER_NODE = 8;
-constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
+constexpr size_t DEFAULT_BLOCK_SIZE = 512;
 
 enum class AllReduceStrategyType : int8_t {
   RING = 0,
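Context for the `block_barrier` changes above: the new `need_fence` parameter selects between release/acquire flag accesses (`st_flag_release`/`ld_flag_acquire`), which also order the data writes issued before the barrier, and the cheaper volatile accesses (`st_flag_volatile`/`ld_flag_volatile`), which only keep the flag word itself from being cached in registers. The two-shot kernel passes `start=false, need_fence=true` at its second barrier because the gather phase reads partial sums written by peer GPUs, so those writes must be published before the flag is raised. Below is a minimal standalone sketch of the same pattern; `st_flag`, `ld_flag`, `spin_flag_demo`, and `main` are hypothetical names for illustration, not part of the patch, and the acquire/release PTX assumes sm_70 or newer.

```cuda
// Minimal sketch of the two flag protocols used by block_barrier above.
// need_fence == true : release store / acquire load, which also orders the
//                      payload writes made before the flag (requires sm_70+).
// need_fence == false: volatile store/load, cheaper but with no ordering
//                      guarantee beyond the flag word itself.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

static __device__ void st_flag(uint32_t flag, uint32_t* addr, bool need_fence) {
  if (need_fence) {
    asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(addr));
  } else {
    asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(addr));
  }
}

static __device__ uint32_t ld_flag(uint32_t* addr, bool need_fence) {
  uint32_t flag;
  if (need_fence) {
    asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(addr));
  } else {
    asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(addr));
  }
  return flag;
}

// Toy single-GPU demo: block 0 publishes a payload, block 1 spins on the flag.
// With need_fence == true the acquire load pairs with the release store, so
// block 1 is guaranteed to observe data == 42 once it sees the flag.
__global__ void spin_flag_demo(uint32_t* flag, int* data) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    *data = 42;               // payload written before the flag
    st_flag(1u, flag, true);  // release store publishes the payload
  } else if (blockIdx.x == 1 && threadIdx.x == 0) {
    while (ld_flag(flag, true) != 1u) {
    }
    printf("saw data = %d\n", *data);
  }
}

int main() {
  uint32_t* flag = nullptr;
  int* data = nullptr;
  cudaMalloc(&flag, sizeof(uint32_t));
  cudaMalloc(&data, sizeof(int));
  cudaMemset(flag, 0, sizeof(uint32_t));
  spin_flag_demo<<<2, 1>>>(flag, data);
  cudaDeviceSynchronize();
  cudaFree(flag);
  cudaFree(data);
  return 0;
}
```

Compiled with e.g. `nvcc -arch=sm_80 demo.cu`, block 1 should always print `saw data = 42`; swapping the release/acquire pair for the volatile pair keeps the spin loop working but drops the payload-ordering guarantee, which is exactly why the patch keeps the fenced path for the barrier before the gather phase.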
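On the tuning side, the `__launch_bounds__(512, 1)` annotation works together with `DEFAULT_BLOCK_SIZE` dropping from 1024 to 512: it promises the compiler that the kernel is never launched with more than 512 threads per block and requests at least one resident block per SM, so the register allocator is not throttled for occupancy the kernel no longer needs. A hypothetical toy kernel showing the same annotation (the kernel and names below are not from the patch):

```cuda
#include <cuda_runtime.h>

// The (512, 1) bounds promise at most 512 threads per block and ask for a
// minimum of 1 resident block per SM; launching with more than 512 threads
// per block would then fail at launch time.
__global__ void __launch_bounds__(512, 1) toy_kernel(float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  out[i] = static_cast<float>(i);
}

int main() {
  constexpr int kBlockSize = 512;  // matches the launch bound above
  constexpr int kBlocks = 36;      // mirrors the new MAX_ALL_REDUCE_BLOCKS
  float* out = nullptr;
  cudaMalloc(&out, kBlocks * kBlockSize * sizeof(float));
  toy_kernel<<<kBlocks, kBlockSize>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}
```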