diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 0ddf3b0ef1..54582a7877 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post9"
+version = "0.0.2.post10"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
index 15c6bf4710..795f9157d2 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
@@ -118,31 +118,19 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
 }
 
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
-                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids,
-                          torch::Tensor num_tokens_post_pad) {
+                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
+                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
     // tensors
     const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
-    const int32_t mem_tokens_cnts = ((num_experts + 1) * num_experts) * sizeof(int32_t);
-    const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
-
-    // allocate global memory
-    int32_t* tokens_cnts;
-    int32_t* cumsum;
-    cudaMalloc(&tokens_cnts, mem_tokens_cnts);
-    cudaMalloc(&cumsum, mem_cumsum);
-
-    // set dynamic shared mem
 
     auto kernel = moe_align_block_size_kernel<scalar_t>;
     kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                                          experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
-                                         num_experts, block_size, topk_ids.numel(), tokens_cnts, cumsum);
-
-    cudaFree(tokens_cnts);
-    cudaFree(cumsum);
+                                         num_experts, block_size, topk_ids.numel(),
+                                         token_cnts_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>());
   });
 }
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index a620f58a58..55318879ae 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -8,6 +8,8 @@ def moe_align_block_size(
     sorted_token_ids,
     experts_ids,
     num_tokens_post_pad,
+    token_cnts_buffer,
+    cumsum_buffer,
 ):
     _moe_align_block_size(
         topk_ids,
@@ -16,4 +18,6 @@
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
     )
diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py
index 5503cea0f3..92596a47e5 100644
--- a/sgl-kernel/tests/test_moe_align.py
+++ b/sgl-kernel/tests/test_moe_align.py
@@ -18,8 +18,22 @@ def test_moe_align_block_size():
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
 
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum_buffer = torch.empty(
+        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+    )
+
     moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
    )
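
Note for callers: after this change, moe_align_block_size no longer cudaMalloc/cudaFrees its tokens_cnts and cumsum scratch space on every call; the caller allocates the two buffers and passes them in. Below is a minimal caller-side sketch mirroring the updated test. The import path sgl_kernel and the max_num_tokens_padded sizing are illustrative assumptions, not part of this diff; only the two scratch-buffer shapes, (num_experts + 1) * num_experts and num_experts + 1, come from the patch above.

    # Caller-side sketch (assumes a CUDA device and that the op imports as below).
    import torch

    from sgl_kernel import moe_align_block_size

    num_experts, block_size, topk = 256, 128, 8
    topk_ids = torch.randint(
        0, num_experts, (16, topk), dtype=torch.int32, device="cuda"
    )

    # Worst case: every expert's token count gets padded up to a block boundary
    # (conventional moe_align sizing; an assumption here, not part of this diff).
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda")
    expert_ids = torch.empty(
        (max_num_tokens_padded + block_size - 1) // block_size,
        dtype=torch.int32,
        device="cuda",
    )
    num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

    # Scratch buffers that were previously cudaMalloc'd inside the op; a caller
    # can now allocate them once and reuse them across calls.
    token_cnts_buffer = torch.empty(
        (num_experts + 1) * num_experts, dtype=torch.int32, device="cuda"
    )
    cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device="cuda")

    moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )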