diff --git a/include/mori/ops/dispatch_combine/dispatch_combine.hpp b/include/mori/ops/dispatch_combine/dispatch_combine.hpp index 897d0568..f97e8ce7 100644 --- a/include/mori/ops/dispatch_combine/dispatch_combine.hpp +++ b/include/mori/ops/dispatch_combine/dispatch_combine.hpp @@ -79,6 +79,7 @@ struct EpDispatchCombineConfig { int maxNumInpTokenPerRank{128}; int numExpertPerRank{1}; int numExpertPerToken{2}; + int numWorstToken{0}; int warpNumPerBlock{1}; int blockNum{1}; // If true, use external buffer which incurs extra copy overhead; otherwise, the kernel assumes @@ -98,6 +99,9 @@ struct EpDispatchCombineConfig { } inline __host__ __device__ int MaxNumTokensToRecv() const { + if (numWorstToken != 0) { + return numWorstToken; + } return worldSize * MaxNumTokensToRecvPerRank(); } }; diff --git a/python/mori/ops/dispatch_combine.py b/python/mori/ops/dispatch_combine.py index 7167bdc0..3fe66b3f 100644 --- a/python/mori/ops/dispatch_combine.py +++ b/python/mori/ops/dispatch_combine.py @@ -45,6 +45,7 @@ class EpDispatchCombineConfig: num_experts_per_token: int warp_num_per_block: int = 8 block_num: int = 80 + num_worst_token: int = 0 use_external_inp_buf: bool = True kernel_type: EpDispatchCombineKernelType = EpDispatchCombineKernelType.IntraNode @@ -71,6 +72,7 @@ def __init__(self, config): num_experts_per_token=config.num_experts_per_token, warp_num_per_block=config.warp_num_per_block, block_num=config.block_num, + num_worst_token=config.num_worst_token, use_external_inp_buf=config.use_external_inp_buf, ) ) diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp index 32400f9a..ece6c579 100644 --- a/src/ops/dispatch_combine/dispatch_combine.cpp +++ b/src/ops/dispatch_combine/dispatch_combine.cpp @@ -66,10 +66,12 @@ mori::application::SymmMemObjPtr ShmemMallocAndReturnMemObjPtr(size_t size, unsi void EpDispatchCombineHandle::InitializeShmemBuf() { size_t maxTokenSize = static_cast<size_t>(config.MaxNumTokensToRecv()) * 
config.hiddenDim * config.maxTokenTypeSize; + size_t maxStagingTokSize = static_cast<size_t>(config.MaxNumTokensToRecv()) * (config.hiddenDim * config.maxTokenTypeSize + (sizeof(float) + sizeof(index_t)) * config.numExpertPerToken + config.scaleDim * config.scaleTypeSize); + shmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); shmemOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached); shmemStagingTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); diff --git a/src/pybind/mori.cpp b/src/pybind/mori.cpp index d5d0cf23..6e63b457 100644 --- a/src/pybind/mori.cpp +++ b/src/pybind/mori.cpp @@ -234,13 +234,13 @@ void RegisterMoriOps(py::module_& m) { .export_values(); pybind11::class_<mori::moe::EpDispatchCombineConfig>(m, "EpDispatchCombineConfig") - .def(pybind11::init<int, int, int, int, int, int, int, int, int, int, int, bool>(), + .def(pybind11::init<int, int, int, int, int, int, int, int, int, int, int, int, bool>(), py::arg("rank") = 0, py::arg("world_size") = 0, py::arg("hidden_dim") = 0, py::arg("scale_dim") = 0, py::arg("scale_type_size") = 0, py::arg("max_token_type_size") = 0, py::arg("max_num_inp_token_per_rank") = 0, py::arg("num_experts_per_rank") = 0, py::arg("num_experts_per_token") = 0, - py::arg("warp_num_per_block") = 0, py::arg("block_num") = 0, - py::arg("use_external_inp_buf") = true) + py::arg("num_worst_token") = 0, py::arg("warp_num_per_block") = 0, + py::arg("block_num") = 0, py::arg("use_external_inp_buf") = true) .def_readwrite("rank", &mori::moe::EpDispatchCombineConfig::rank) .def_readwrite("world_size", &mori::moe::EpDispatchCombineConfig::worldSize) .def_readwrite("hidden_dim", &mori::moe::EpDispatchCombineConfig::hiddenDim) @@ -252,6 +252,7 @@ void RegisterMoriOps(py::module_& m) { .def_readwrite("num_experts_per_rank", &mori::moe::EpDispatchCombineConfig::numExpertPerRank) .def_readwrite("num_experts_per_token", &mori::moe::EpDispatchCombineConfig::numExpertPerToken) + .def_readwrite("num_worst_token", &mori::moe::EpDispatchCombineConfig::numWorstToken) .def_readwrite("warp_num_per_block", 
&mori::moe::EpDispatchCombineConfig::warpNumPerBlock) .def_readwrite("block_num", &mori::moe::EpDispatchCombineConfig::blockNum); diff --git a/tests/python/ops/bench_dispatch_combine.py b/tests/python/ops/bench_dispatch_combine.py index 6f73a664..b1118f64 100644 --- a/tests/python/ops/bench_dispatch_combine.py +++ b/tests/python/ops/bench_dispatch_combine.py @@ -189,6 +189,7 @@ def _bench_dispatch_combine( num_experts_per_rank=16, num_experts_per_token=8, ): + num_worst_token = max_num_inp_token_per_rank * world_size config = mori.ops.EpDispatchCombineConfig( data_type=data_type, rank=rank, @@ -200,6 +201,7 @@ def _bench_dispatch_combine( max_num_inp_token_per_rank=max_num_inp_token_per_rank, num_experts_per_rank=num_experts_per_rank, num_experts_per_token=num_experts_per_token, + num_worst_token=num_worst_token, warp_num_per_block=16, block_num=80, use_external_inp_buf=False,