diff --git a/examples/ops/dispatch_combine/test_dispatch_combine_internode.py b/examples/ops/dispatch_combine/test_dispatch_combine_internode.py index 03dcb702..0d37a239 100644 --- a/examples/ops/dispatch_combine/test_dispatch_combine_internode.py +++ b/examples/ops/dispatch_combine/test_dispatch_combine_internode.py @@ -46,9 +46,11 @@ def __init__( # num_experts_per_rank=256 // world_size, num_experts_per_token=8, warp_num_per_block=16, - block_num=64, + block_num=32, max_token_type_size=2, - kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode, + kernel_type=mori.ops.EpDispatchCombineKernelType.InterNodeV1, + gpu_per_node=self.gpu_per_node, + rdma_block_num=16, ) def setup(self): @@ -75,7 +77,7 @@ def setup(self): self.rng = torch.Generator(device=self.device) # self.rng.manual_seed(int(time.time()) + self.rank) - self.rng.manual_seed(123) + self.rng.manual_seed(3210) def cleanup(self): mori.shmem.shmem_finalize() @@ -145,6 +147,56 @@ def gen_test_data(self, use_max_token_num=False): indices[i] = perm[: self.config.num_experts_per_token] all_rank_indices.append(indices.to(torch.int32).to(self.device)) + num_total_experts = self.config.num_experts_per_rank * self.config.world_size + num_nodes = self.config.world_size // self.config.gpu_per_node + + # Per-rank counts + rank_counts = torch.zeros( + self.config.world_size, dtype=torch.int32, device=self.device + ) + rank_counts_remote_recv = torch.zeros( + self.config.world_size, dtype=torch.int32, device=self.device + ) + rank_counts_remote_send = torch.zeros( + self.config.world_size, dtype=torch.int32, device=self.device + ) + + for src_rank, indices in enumerate(all_rank_indices): + src_node = src_rank // self.config.gpu_per_node + + # Map expert IDs to rank IDs + token_ranks = ( + indices // self.config.num_experts_per_rank + ) # [num_tokens, num_experts_per_token] + + # Deduplicate rank IDs per token + unique_ranks_per_token = [torch.unique(row) for row in token_ranks] + + # For each token, update counts + for ur in unique_ranks_per_token: + rank_counts[ur] += 1 # All ranks that receive this token + + dst_nodes = { + dst_rank // self.config.gpu_per_node for dst_rank in ur.tolist() + } + + for dst_rank in ur.tolist(): + dst_node = dst_rank // self.config.gpu_per_node + if dst_node != src_node: + # Receiving side + rank_counts_remote_recv[dst_rank] += 1 + + # Sending side (dedup by node: count once if token goes to a remote node) + for dst_node in dst_nodes: + if dst_node != src_node: + rank_counts_remote_send[src_rank] += 1 + + if self.config.rank == 0: + print("Rank counts (deduplicated):", rank_counts) + print("Rank counts local nodes:", rank_counts - rank_counts_remote_recv) + print("Rank counts from other nodes:", rank_counts_remote_recv) + # print("Rank counts to other nodes:", rank_counts_remote_send) + # even_indices = ( # torch.arange( # self.config.max_num_inp_token_per_rank @@ -226,6 +278,8 @@ def run_test_once(self, op, test_data, error_round, round): # None, all_rank_scales[self.rank], all_rank_indices[self.rank], + block_num=self.config.block_num, + warp_per_block=16, ) torch.cuda.synchronize() dist.barrier() @@ -243,8 +297,8 @@ def run_test_once(self, op, test_data, error_round, round): print( f"rank {self.rank} token {i} assert {is_pass} expected { all_rank_input[src_pe][src_tok_id]} got {dispatch_output[i]}" ) - # assert False - error_round.add(round) + assert False + # error_round.add(round) if dispatch_weights is not None: assert torch.equal( dispatch_weights[i], all_rank_weights[src_pe][src_tok_id] @@ -257,14 
+311,18 @@ def run_test_once(self, op, test_data, error_round, round): if self.rank % self.gpu_per_node == 0: print(f"Node {self.rank // self.gpu_per_node} Dispatch Pass") + torch.cuda.synchronize() dist.barrier() combine_output, combine_output_weight = op.combine( dispatch_output, dispatch_weights, all_rank_indices[self.rank], + block_num=self.config.block_num, + warp_per_block=16, ) torch.cuda.synchronize() + dist.barrier() for i in range(all_rank_num_token[self.rank]): pes = [ @@ -272,55 +330,79 @@ def run_test_once(self, op, test_data, error_round, round): for idx in all_rank_indices[self.rank][i].cpu().tolist() ] unique_pes = len(set(pes)) + unique_innode_pes = len( + [ + pe + for pe in set(pes) + if (pe // self.gpu_per_node == self.rank // self.gpu_per_node) + ] + ) + final_unique_pes = unique_pes + # print( + # self.rank, + # f"token {i} pes {pes} unique pes {unique_pes} unique innode pes {unique_innode_pes}", + # ) + if final_unique_pes == 0: + continue got, expected = combine_output[i], ( - all_rank_input[self.rank][i].to(torch.float32) * unique_pes + all_rank_input[self.rank][i].to(torch.float32) * final_unique_pes ).to(self.config.data_type) ok = torch.allclose(got.float(), expected.float(), atol=1e-2, rtol=1e-2) if not ok: - print(self.rank, "got: ", got) - print(self.rank, "expected: ", expected) - print(self.rank, "delta:", got - expected) - assert False + print( + self.rank, + f"token {i} pes {pes} unique pes {unique_pes} unique innode pes {unique_innode_pes}", + ) + # print(self.rank, "got: ", got) + # print(self.rank, "expected: ", expected, all_rank_input[self.rank][i]) + delta = got - expected + # print(self.rank, i, "delta:", delta.nonzero()) + # assert False error_round.add(round) - if dispatch_weights is not None: - got_weight, expected_weight = ( - combine_output_weight[i], - all_rank_weights[self.rank][i] * unique_pes, - ) - weight_match = torch.allclose( - got_weight, expected_weight, atol=1e-5, rtol=1e-5 - ) - if not weight_match and self.config.rank == 0: - print(f"Weight mismatch for token {i}:") - print( - f" indices[{i}]: {all_rank_indices[self.rank][i].cpu().tolist()}" - ) - print(f" pes: {pes}") - print(f" unique_pes: {unique_pes}") - print(f" got_weight: {got_weight}") - print( - f" expected_weight (weights[{i}] * {unique_pes}): {expected_weight}" - ) - print(f" original weights[{i}]: {all_rank_weights[self.rank][i]}") - print(f" diff: {torch.abs(got_weight - expected_weight)}") - print( - f" max_diff: {torch.abs(got_weight - expected_weight).max()}" - ) - assert weight_match, f"Weight assertion failed for token {i}" + if len(error_round) > 0: + assert False + + # if dispatch_weights is not None: + # got_weight, expected_weight = ( + # combine_output_weight[i], + # all_rank_weights[self.rank][i] * unique_pes, + # ) + # weight_match = torch.allclose( + # got_weight, expected_weight, atol=1e-5, rtol=1e-5 + # ) + # if not weight_match and self.config.rank == 0: + # print(f"Weight mismatch for token {i}:") + # print( + # f" indices[{i}]: {all_rank_indices[self.rank][i].cpu().tolist()}" + # ) + # print(f" pes: {pes}") + # print(f" unique_pes: {unique_pes}") + # print(f" got_weight: {got_weight}") + # print( + # f" expected_weight (weights[{i}] * {unique_pes}): {expected_weight}" + # ) + # print(f" original weights[{i}]: {all_rank_weights[self.rank][i]}") + # print(f" diff: {torch.abs(got_weight - expected_weight)}") + # print( + # f" max_diff: {torch.abs(got_weight - expected_weight).max()}" + # ) + # assert weight_match, f"Weight assertion failed for token {i}" 
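Note on the check above: combine accumulates exactly one copy of a token per unique destination PE, which is why the expected value scales the input by final_unique_pes; the same expert -> rank -> node mapping also drives the per-node counters added to gen_test_data earlier in this file. A minimal standalone sketch of that bookkeeping, using illustrative names rather than the test's API:

# Minimal sketch of the per-token bookkeeping described above (illustrative names,
# not the test's API). Assumption: experts are laid out contiguously per rank, so
# rank = expert_id // experts_per_rank and node = rank // gpu_per_node.
import torch

def expected_combine_scale(indices_row: torch.Tensor, experts_per_rank: int) -> int:
    """Combine returns one copy per *unique* destination rank, so the expected
    combine output is input * number_of_unique_ranks."""
    return len(torch.unique(indices_row // experts_per_rank))

def tokens_leaving_node(indices: torch.Tensor, src_rank: int,
                        experts_per_rank: int, gpu_per_node: int) -> int:
    """Count tokens whose unique destination ranks include at least one other node."""
    src_node = src_rank // gpu_per_node
    remote = 0
    for row in indices:
        nodes = {int(r) // gpu_per_node for r in torch.unique(row // experts_per_rank)}
        remote += any(node != src_node for node in nodes)
    return remote

indices = torch.tensor([[0, 9, 17, 65], [2, 3, 4, 5]])  # top-k expert ids for two tokens
print(expected_combine_scale(indices[0], experts_per_rank=8))               # -> 4
print(tokens_leaving_node(indices, 0, experts_per_rank=8, gpu_per_node=8))  # -> 1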
if self.rank % self.gpu_per_node == 0: print(f"Node {self.rank // self.gpu_per_node} Combine Pass") def test_dispatch_combine(self): - op = mori.ops.EpDispatchCombineOp(self.config) error_round = set() - for i in range(500): + for i in range(5000): + if i < 1: + continue + op = mori.ops.EpDispatchCombineOp(self.config) if self.rank == 0: print(f"Round {i} begin") - test_data = self.gen_test_data() + test_data = self.gen_test_data(use_max_token_num=False) if self.rank == 0: print(f"Round {i} gen test_data done") self.run_test_once(op, test_data, error_round, i) @@ -347,6 +429,8 @@ def run_bench_once(self, op, test_data): start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) + if self.rank == 0: + print("dispatch") torch.cuda.synchronize() dist.barrier() start_event.record() @@ -361,6 +445,8 @@ def run_bench_once(self, op, test_data): all_rank_weights[self.rank], all_rank_scales[self.rank], all_rank_indices[self.rank], + block_num=self.config.block_num, + warp_per_block=16, ) end_event.record() torch.cuda.synchronize() @@ -378,14 +464,24 @@ def run_bench_once(self, op, test_data): src_node = src_pe // self.gpu_per_node if src_node != my_node: total_rdma_recv_num_token += 1 + if self.config.kernel_type is mori.ops.EpDispatchCombineKernelType.InterNodeV1: + total_rdma_recv_num_token = ( + self.config.max_num_inp_token_per_rank * self.config.world_size // 8 + ) print( f"rank {self.rank} recv {total_recv_num_token} tokens {total_rdma_recv_num_token} rdma tokens" ) element_size = all_rank_input[self.rank].element_size() total_bytes = total_recv_num_token * self.config.hidden_dim * element_size + total_rdma_bytes = ( + total_rdma_recv_num_token * self.config.hidden_dim * element_size + ) + disp_rdma_bandwidth = total_rdma_bytes / (1000**3) / (disp_duration / (10**3)) disp_bandwidth = total_bytes / (1000**3) / (disp_duration / (10**3)) + if self.rank == 0: + print("combine") torch.cuda.synchronize() dist.barrier() start_event.record() @@ -393,34 +489,39 @@ def run_bench_once(self, op, test_data): dispatch_output, None, all_rank_indices[self.rank], + block_num=self.config.block_num, + warp_per_block=16, ) end_event.record() torch.cuda.synchronize() comb_duration = start_event.elapsed_time(end_event) + comb_rdma_bandwidth = total_rdma_bytes / (1000**3) / (comb_duration / (10**3)) comb_bandwidth = total_bytes / (1000**3) / (comb_duration / (10**3)) op.reset() torch.cuda.synchronize() - return disp_duration, disp_bandwidth, comb_duration, comb_bandwidth + return ( + disp_duration, + disp_rdma_bandwidth, + disp_bandwidth, + comb_duration, + comb_rdma_bandwidth, + comb_bandwidth, + ) def bench_dispatch_combine(self): op = mori.ops.EpDispatchCombineOp(self.config) test_data = self.gen_test_data(use_max_token_num=True) disp_duration_us_list = [] + disp_rdma_bandwidth_GB_list = [] disp_bandwidth_GB_list = [] comb_duration_us_list = [] + comb_rdma_bandwidth_GB_list = [] comb_bandwidth_GB_list = [] - # for i in range(10): - # if self.rank == 0: - # print(f"WarmUp Round {i} begin") - # _, _, _, _ = ( - # self.run_bench_once(op, test_data) - # ) - error_round = set() - for i in range(10): + for i in range(1): if self.rank == 0: print(f"WarmUp Round {i} begin") self.run_test_once(op, test_data, error_round, i) @@ -428,60 +529,103 @@ def bench_dispatch_combine(self): len(error_round) == 0 ), f"Warmup failed with errors in rounds: {error_round}" - for i in range(50): + for i in range(20): if self.rank == 0: print(f"Round {i} begin") - disp_duration, disp_bandwidth, 
comb_duration, comb_bandwidth = ( - self.run_bench_once(op, test_data) - ) + ( + disp_duration, + disp_rdma_bandwidth, + disp_bandwidth, + comb_duration, + comb_rdma_bandwidth, + comb_bandwidth, + ) = self.run_bench_once(op, test_data) disp_duration_output = [torch.zeros(1) for _ in range(self.world_size)] + disp_rdma_bandwidth_output = [ + torch.zeros(1) for _ in range(self.world_size) + ] disp_bandwidth_output = [torch.zeros(1) for _ in range(self.world_size)] comb_duration_output = [torch.zeros(1) for _ in range(self.world_size)] + comb_rdma_bandwidth_output = [ + torch.zeros(1) for _ in range(self.world_size) + ] comb_bandwidth_output = [torch.zeros(1) for _ in range(self.world_size)] dist.all_gather(disp_duration_output, torch.tensor([disp_duration * 1000])) + dist.all_gather( + disp_rdma_bandwidth_output, torch.tensor([disp_rdma_bandwidth]) + ) dist.all_gather(disp_bandwidth_output, torch.tensor([disp_bandwidth])) dist.all_gather(comb_duration_output, torch.tensor([comb_duration * 1000])) + dist.all_gather( + comb_rdma_bandwidth_output, torch.tensor([comb_rdma_bandwidth]) + ) dist.all_gather(comb_bandwidth_output, torch.tensor([comb_bandwidth])) disp_duration_us_list.append([int(t.item()) for t in disp_duration_output]) + disp_rdma_bandwidth_GB_list.append( + [int(t.item()) for t in disp_rdma_bandwidth_output] + ) disp_bandwidth_GB_list.append( [int(t.item()) for t in disp_bandwidth_output] ) comb_duration_us_list.append([int(t.item()) for t in comb_duration_output]) + comb_rdma_bandwidth_GB_list.append( + [int(t.item()) for t in comb_rdma_bandwidth_output] + ) comb_bandwidth_GB_list.append( [int(t.item()) for t in comb_bandwidth_output] ) if self.rank == 0: for i in range(len(disp_duration_us_list)): + print(f"Round {i}") + print( + f" dispatch duration {disp_duration_us_list[i]} avg {sum(disp_duration_us_list[i]) / self.config.world_size:.2f} µs" + ) + print( + f" rdma bandwidth {disp_rdma_bandwidth_GB_list[i]} avg {sum(disp_rdma_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s" + ) print( - f"Round {i} dispatch duration {disp_duration_us_list[i]} " - f"bandwidth {disp_bandwidth_GB_list[i]} " - f"avg {sum(disp_duration_us_list[i]) / self.config.world_size:.2f} µs " - f"avg {sum(disp_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s" + f" bandwidth {disp_bandwidth_GB_list[i]} avg {sum(disp_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s" ) for i in range(len(comb_duration_us_list)): + print(f"Round {i}") print( - f"Round {i} combine duration {comb_duration_us_list[i]} " - f"bandwidth {comb_bandwidth_GB_list[i]} " - f"avg {sum(comb_duration_us_list[i]) / self.config.world_size:.2f} µs " - f"avg {sum(comb_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s" + f" combine duration {comb_duration_us_list[i]} avg {sum(comb_duration_us_list[i]) / self.config.world_size:.2f} µs" + ) + print( + f" rdma bandwidth {comb_rdma_bandwidth_GB_list[i]} avg {sum(comb_rdma_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s" + ) + print( + f" bandwidth {comb_bandwidth_GB_list[i]} avg {sum(comb_bandwidth_GB_list[i]) / self.config.world_size:.2f} GB/s" ) disp_bandwidth_GB_list = disp_bandwidth_GB_list[0:] avg_disp_bw_per_round = [ (sum(round_bw) / len(round_bw)) for round_bw in disp_bandwidth_GB_list ] + avg_disp_rdma_bw_per_round = [ + (sum(round_bw) / len(round_bw)) for round_bw in disp_rdma_bandwidth_GB_list + ] avg_disp_bw = sum(avg_disp_bw_per_round) / len(avg_disp_bw_per_round) + avg_disp_rdma_bw = sum(avg_disp_rdma_bw_per_round) / len( + 
avg_disp_rdma_bw_per_round + ) comb_bandwidth_GB_list = comb_bandwidth_GB_list[0:] avg_comb_bw_per_round = [ (sum(round_bw) / len(round_bw)) for round_bw in comb_bandwidth_GB_list ] + avg_comb_rdma_bw_per_round = [ + (sum(round_bw) / len(round_bw)) for round_bw in comb_rdma_bandwidth_GB_list + ] avg_comb_bw = sum(avg_comb_bw_per_round) / len(avg_comb_bw_per_round) + avg_comb_rdma_bw = sum(avg_comb_rdma_bw_per_round) / len( + avg_comb_rdma_bw_per_round + ) disp_duration_us_list = disp_duration_us_list[0:] avg_disp_lat_per_round = [ @@ -498,16 +642,18 @@ def bench_dispatch_combine(self): avg_comb_lat = sum(avg_comb_lat_per_round) / len(avg_comb_lat_per_round) best_disp_bw = max(avg_disp_bw_per_round) + best_disp_rdma_bw = max(avg_disp_rdma_bw_per_round) best_comb_bw = max(avg_comb_bw_per_round) + best_comb_rdma_bw = max(avg_comb_rdma_bw_per_round) best_disp_lat = min(avg_disp_lat_per_round) best_comb_lat = min(avg_comb_lat_per_round) if self.rank == 0: print( - f"dispatch: best/avg bandwidth {best_disp_bw:.2f} / {avg_disp_bw:.2f} GB/s | " + f"dispatch: best/avg RDMA bandwidth {best_disp_rdma_bw:.2f} / {avg_disp_rdma_bw:.2f} XGMI bandwidth {best_disp_bw:.2f} / {avg_disp_bw:.2f} GB/s | " f"best/avg latency {best_disp_lat:.2f} / {avg_disp_lat:.2f} µs\n" - f"combine : best/avg bandwidth {best_comb_bw:.2f} / {avg_comb_bw:.2f} GB/s | " + f"combine: best/avg RDMA bandwidth {best_comb_rdma_bw:.2f} / {avg_comb_rdma_bw:.2f} XGMI bandwidth {best_comb_bw:.2f} / {avg_comb_bw:.2f} GB/s | " f"best/avg latency {best_comb_lat:.2f} / {avg_comb_lat:.2f} µs" ) del op @@ -526,6 +672,7 @@ def test_dispatch_combine( world_size, max_tokens, torch.bfloat16, # torch.float8_e4m3fnuz + # torch.float8_e4m3fnuz, ) test_case.setup() if is_bench: diff --git a/include/mori/application/bootstrap/torch_bootstrap.hpp b/include/mori/application/bootstrap/torch_bootstrap.hpp index 1cff551b..3474e7d6 100644 --- a/include/mori/application/bootstrap/torch_bootstrap.hpp +++ b/include/mori/application/bootstrap/torch_bootstrap.hpp @@ -21,8 +21,6 @@ // SOFTWARE. #pragma once -#include - #include "mori/application/bootstrap/base_bootstrap.hpp" namespace mori { @@ -41,7 +39,7 @@ class TorchBootstrapNetwork : public BootstrapNetwork { void Barrier(); private: - c10::intrusive_ptr group; + std::string groupName; }; } // namespace application diff --git a/include/mori/core/core.hpp b/include/mori/core/core.hpp index 4eb367aa..90fff535 100644 --- a/include/mori/core/core.hpp +++ b/include/mori/core/core.hpp @@ -21,7 +21,6 @@ // SOFTWARE. 
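Note on the benchmark summary above: the RDMA figures come from an estimated byte count rather than a measured one. For InterNodeV1 the benchmark assumes max_num_inp_token_per_rank * world_size // 8 tokens cross the node boundary, then converts bytes and milliseconds into GB/s. A standalone sketch of that arithmetic, with made-up example values:

# Sketch of the GB/s arithmetic used in run_bench_once above; all inputs are example values.
def gb_per_s(num_tokens: int, hidden_dim: int, element_size: int, duration_ms: float) -> float:
    total_bytes = num_tokens * hidden_dim * element_size
    return total_bytes / (1000 ** 3) / (duration_ms / 1e3)

max_num_inp_token_per_rank, world_size, hidden_dim = 4096, 16, 7168
element_size = 2  # bf16

# InterNodeV1: estimated RDMA traffic, mirroring the benchmark's estimate.
rdma_tokens = max_num_inp_token_per_rank * world_size // 8
print(f"dispatch RDMA ~{gb_per_s(rdma_tokens, hidden_dim, element_size, 0.5):.1f} GB/s")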
#pragma once -#include "mori/core/lock.hpp" #include "mori/core/transport/p2p/p2p.hpp" #include "mori/core/transport/rdma/rdma.hpp" #include "mori/core/utils.hpp" diff --git a/include/mori/core/transport/p2p/device_primitives.hpp b/include/mori/core/transport/p2p/device_primitives.hpp index 00fc79d0..e058e151 100644 --- a/include/mori/core/transport/p2p/device_primitives.hpp +++ b/include/mori/core/transport/p2p/device_primitives.hpp @@ -205,7 +205,8 @@ inline __device__ void ThreadCopy(T* dst, T* src, size_t nelems) { } template -inline __device__ void WarpCopyImpl(T* dst, const T* src, size_t& offset, size_t nelems) { +inline __device__ void WarpCopyImpl(T* __restrict__ dst, const T* __restrict__ src, size_t& offset, + size_t nelems) { constexpr int VecBytes = 16; constexpr int vecSize = VecBytes / sizeof(T); int laneId = threadIdx.x & (warpSize - 1); @@ -230,7 +231,7 @@ inline __device__ void WarpCopyImpl(T* dst, const T* src, size_t& offset, size_t } template -inline __device__ void WarpCopy(T* dst, const T* src, size_t nelems) { +inline __device__ void WarpCopy(T* __restrict__ dst, const T* __restrict__ src, size_t nelems) { int laneId = threadIdx.x & (warpSize - 1); size_t offset = 0; diff --git a/include/mori/core/transport/rdma/providers/mlx5/mlx5_device_primitives.hpp b/include/mori/core/transport/rdma/providers/mlx5/mlx5_device_primitives.hpp index 2e79a1b8..94292da6 100644 --- a/include/mori/core/transport/rdma/providers/mlx5/mlx5_device_primitives.hpp +++ b/include/mori/core/transport/rdma/providers/mlx5/mlx5_device_primitives.hpp @@ -230,6 +230,9 @@ inline __device__ uint64_t Mlx5PostWriteInline(WorkQueueHandle& wq, uint32_t cur if (bytes == 4) { AtomicStoreRelaxed(reinterpret_cast(wqeDataPtr), reinterpret_cast(val)[0]); + } else if (bytes == 8) { + AtomicStoreRelaxed(reinterpret_cast(wqeDataPtr), + reinterpret_cast(val)[0]); } else { for (int i = 0; i < bytes; i++) { AtomicStoreRelaxed(reinterpret_cast(wqeDataPtr) + i, diff --git a/include/mori/core/utils.hpp b/include/mori/core/utils.hpp index dbf20771..2f48764f 100644 --- a/include/mori/core/utils.hpp +++ b/include/mori/core/utils.hpp @@ -56,30 +56,28 @@ inline __device__ int FlatBlockWarpId() { return FlatBlockThreadId() / DeviceWar inline __device__ int WarpLaneId() { return FlatBlockThreadId() & (DeviceWarpSize() - 1); } -inline __device__ bool IsThreadZeroInBlock() { - return (FlatBlockThreadId() % DeviceWarpSize()) == 0; -} - -inline __device__ uint64_t GetActiveLaneMask() { - return __ballot(true); -} - -inline __device__ unsigned int GetActiveLaneCount(uint64_t activeLaneMask) { - return __popcll(activeLaneMask); -} - -inline __device__ unsigned int GetActiveLaneCount() { - return GetActiveLaneCount(GetActiveLaneMask()); -} - -inline __device__ unsigned int GetActiveLaneNum(uint64_t activeLaneMask) { - return __popcll(activeLaneMask & __lanemask_lt()); -} - -inline __device__ unsigned int GetActiveLaneNum() { - return GetActiveLaneNum(GetActiveLaneMask()); +inline __device__ int WarpLaneId1D() { return threadIdx.x & (warpSize - 1); } + +inline __device__ bool IsThreadZeroInBlock() { + return (FlatBlockThreadId() % DeviceWarpSize()) == 0; +} + +inline __device__ uint64_t GetActiveLaneMask() { return __ballot(true); } + +inline __device__ unsigned int GetActiveLaneCount(uint64_t activeLaneMask) { + return __popcll(activeLaneMask); +} + +inline __device__ unsigned int GetActiveLaneCount() { + return GetActiveLaneCount(GetActiveLaneMask()); +} + +inline __device__ unsigned int GetActiveLaneNum(uint64_t activeLaneMask) { 
+ return __popcll(activeLaneMask & __lanemask_lt()); } +inline __device__ unsigned int GetActiveLaneNum() { return GetActiveLaneNum(GetActiveLaneMask()); } + inline __device__ int GetFirstActiveLaneID(uint64_t activeLaneMask) { return activeLaneMask ? __ffsll((unsigned long long int)activeLaneMask) - 1 : -1; } @@ -92,21 +90,17 @@ inline __device__ int GetLastActiveLaneID(uint64_t activeLaneMask) { inline __device__ int GetLastActiveLaneID() { return GetLastActiveLaneID(GetActiveLaneMask()); } -inline __device__ bool IsFirstActiveLane(uint64_t activeLaneMask) { - return GetActiveLaneNum(activeLaneMask) == 0; -} - -inline __device__ bool IsFirstActiveLane() { - return IsFirstActiveLane(GetActiveLaneMask()); -} - -inline __device__ bool IsLastActiveLane(uint64_t activeLaneMask) { - return GetActiveLaneNum(activeLaneMask) == GetActiveLaneCount(activeLaneMask) - 1; -} - -inline __device__ bool IsLastActiveLane() { - return IsLastActiveLane(GetActiveLaneMask()); -} +inline __device__ bool IsFirstActiveLane(uint64_t activeLaneMask) { + return GetActiveLaneNum(activeLaneMask) == 0; +} + +inline __device__ bool IsFirstActiveLane() { return IsFirstActiveLane(GetActiveLaneMask()); } + +inline __device__ bool IsLastActiveLane(uint64_t activeLaneMask) { + return GetActiveLaneNum(activeLaneMask) == GetActiveLaneCount(activeLaneMask) - 1; +} + +inline __device__ bool IsLastActiveLane() { return IsLastActiveLane(GetActiveLaneMask()); } /* ---------------------------------------------------------------------------------------------- */ /* Atomic Operations */ diff --git a/include/mori/ops/dispatch_combine/dispatch_combine.hpp b/include/mori/ops/dispatch_combine/dispatch_combine.hpp index 0dd55622..6499c6ce 100644 --- a/include/mori/ops/dispatch_combine/dispatch_combine.hpp +++ b/include/mori/ops/dispatch_combine/dispatch_combine.hpp @@ -36,6 +36,7 @@ namespace moe { enum KernelType { IntraNode = 0, InterNode = 1, + InterNodeV1 = 2, }; inline const char* HipDataTypeToString(hipDataType dtype) { @@ -84,6 +85,8 @@ struct EpDispatchCombineConfig { // If true, use external buffer which incurs extra copy overhead; otherwise, the kernel assumes // the provided buffer is shmemInpTokMemObj bool useExternalInpBuffer{true}; + int gpuPerNode{8}; + int rdmaBlockNum{1}; inline __host__ __device__ int MaxNumTokensToSendPerRank() const { return maxNumInpTokenPerRank; } @@ -217,6 +220,22 @@ class EpDispatchCombineHandle { index_t* totalRecvTokenNum{nullptr}; mori::application::SymmMemObjPtr crossDeviceBarrierMemObj; uint32_t* crossDeviceBarrierFlag{nullptr}; + + // Inter-node v1 kernel parameters + // Signal the completion of inter-node token transfer + mori::application::SymmMemObjPtr interNodeChunkFlagMemObj; + // Signal the number of tokens transfered from other nodes + mori::application::SymmMemObjPtr nodeRecvTokenNumMemObj; + // Count the number of tokens sent to other nodes + index_t* destNodeTokenCounter{nullptr}; + // Counter that is used to sort the ordering of inter-node token chunk transfer + index_t* blockFlagCounter{nullptr}; + // + uint32_t* interNodeBlocksBarrier{nullptr}; + + index_t* interNodeDispDestTokIdMap{nullptr}; + index_t* interNodeChunkFlagCombine{nullptr}; + index_t* interNodeDispSendMap{nullptr}; }; template @@ -256,6 +275,14 @@ struct EpDispatchCombineArgs { index_t* totalRecvTokenNum{nullptr}; mori::application::SymmMemObjPtr crossDeviceBarrierMemObj; uint32_t* crossDeviceBarrierFlag{nullptr}; + mori::application::SymmMemObjPtr interNodeChunkFlagMemObj; + index_t* destNodeTokenCounter{nullptr}; 
+ mori::application::SymmMemObjPtr nodeRecvTokenNumMemObj; + index_t* blockFlagCounter{nullptr}; + uint32_t* interNodeBlocksBarrier{nullptr}; + index_t* interNodeDispDestTokIdMap{nullptr}; + index_t* interNodeChunkFlagCombine{nullptr}; + index_t* interNodeDispSendMap{nullptr}; }; using EpDispatchCombineArgsVariant = @@ -299,6 +326,14 @@ EpDispatchCombineArgs GetEpDispatchCombineArgs(const EpDispatchCombineHandle& args.totalRecvTokenNum = handle.totalRecvTokenNum; args.crossDeviceBarrierMemObj = handle.crossDeviceBarrierMemObj; args.crossDeviceBarrierFlag = handle.crossDeviceBarrierFlag; + args.interNodeChunkFlagMemObj = handle.interNodeChunkFlagMemObj; + args.destNodeTokenCounter = handle.destNodeTokenCounter; + args.nodeRecvTokenNumMemObj = handle.nodeRecvTokenNumMemObj; + args.blockFlagCounter = handle.blockFlagCounter; + args.interNodeBlocksBarrier = handle.interNodeBlocksBarrier; + args.interNodeDispDestTokIdMap = handle.interNodeDispDestTokIdMap; + args.interNodeChunkFlagCombine = handle.interNodeChunkFlagCombine; + args.interNodeDispSendMap = handle.interNodeDispSendMap; return args; } diff --git a/include/mori/shmem/shmem_ibgda_kernels.hpp b/include/mori/shmem/shmem_ibgda_kernels.hpp index 118e6c9d..50c4cd0b 100644 --- a/include/mori/shmem/shmem_ibgda_kernels.hpp +++ b/include/mori/shmem/shmem_ibgda_kernels.hpp @@ -369,6 +369,139 @@ inline __device__ void ShmemPutMemNbiThreadKernelImpl(const application::SymmMem // __threadfence_system(); } +template +inline __device__ void ShmemPutMemNbiThreadKernelImpl( + const application::SymmMemObjPtr dest, size_t destOffset, + const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, + const application::SymmMemObjPtr signal, size_t signalOffset, void* signalVal, + size_t signalBytes, const application::SymmMemObjPtr atomic, int pe) { + if (bytes == 0) return; + uintptr_t laddr = source.addr + sourceOffset; + uintptr_t raddr = dest->peerPtrs[pe] + destOffset; + uintptr_t rkey = dest->peerRkeys[pe]; + + GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); + application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; + core::WorkQueueHandle* wq = &ep[pe].wqHandle; + core::CompletionQueueHandle* cq = &ep[pe].cqHandle; + + uint64_t activemask = core::GetActiveLaneMask(); + uint8_t num_active_lanes = core::GetActiveLaneCount(activemask); + uint8_t my_logical_lane_id = core::GetActiveLaneNum(activemask); + bool is_leader{my_logical_lane_id == num_active_lanes - 1}; + const uint64_t leader_phys_lane_id = core::GetLastActiveLaneID(activemask); + uint8_t num_wqes{num_active_lanes}; + num_wqes += 1; + uint32_t warp_sq_counter{0}; + uint32_t warp_msntbl_counter{0}, warp_psn_counter{0}; + uint32_t my_sq_counter{0}, my_msntbl_counter{0}, my_psn_counter{0}; + + if (is_leader) { + if constexpr (PrvdType == core::ProviderType::MLX5) { + warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_wqes, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } else if constexpr (PrvdType == core::ProviderType::BNXT) { + uint32_t psnCnt = (bytes + wq->mtuSize - 1) / wq->mtuSize; + atomic_add_packed_msn_and_psn(&wq->msnPack, num_wqes, psnCnt * num_wqes, &warp_msntbl_counter, + &warp_psn_counter); + // TODO: if warp_msntbl_counter overflow 32bit, sq_slot's calculation will be wrong + warp_sq_counter = warp_msntbl_counter * BNXT_RE_NUM_SLOT_PER_WQE; + __hip_atomic_fetch_max(&wq->postIdx, + (warp_msntbl_counter + num_wqes) * BNXT_RE_NUM_SLOT_PER_WQE, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } + } + warp_sq_counter = __shfl(warp_sq_counter, 
leader_phys_lane_id); + if constexpr (PrvdType == core::ProviderType::MLX5) { + my_sq_counter = warp_sq_counter + my_logical_lane_id; + } else if constexpr (PrvdType == core::ProviderType::BNXT) { + my_sq_counter = warp_sq_counter + my_logical_lane_id * BNXT_RE_NUM_SLOT_PER_WQE; + warp_msntbl_counter = __shfl(warp_msntbl_counter, leader_phys_lane_id); + warp_psn_counter = __shfl(warp_psn_counter, leader_phys_lane_id); + my_msntbl_counter = warp_msntbl_counter + my_logical_lane_id; + my_psn_counter = warp_psn_counter + my_logical_lane_id; + } else { + assert(false); + } + + while (true) { + uint64_t db_touched = + __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + uint64_t db_done = __hip_atomic_load(&wq->doneIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + uint64_t num_active_sq_entries = db_touched - db_done; + uint64_t num_free_entries = wq->sqWqeNum - num_active_sq_entries; + uint64_t num_entries_until_warp_last_entry; + if constexpr (PrvdType == core::ProviderType::MLX5) { + num_entries_until_warp_last_entry = warp_sq_counter + num_wqes - db_touched; + } else if constexpr (PrvdType == core::ProviderType::BNXT) { + num_entries_until_warp_last_entry = + warp_sq_counter + num_active_lanes * BNXT_RE_NUM_SLOT_PER_WQE - db_touched; + } else { + assert(false); + } + if (num_free_entries > num_entries_until_warp_last_entry) { + break; + } + ShmemQuietThreadKernelImpl(pe); + } + uint64_t dbr_val; + if constexpr (PrvdType == core::ProviderType::MLX5) { + wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; + dbr_val = core::PostWrite(*wq, my_sq_counter, my_sq_counter, my_sq_counter, false, + ep[pe].handle.qpn, laddr, source.lkey, raddr, rkey, bytes); + if (is_leader) { + // int rank = GetGlobalGpuStatesPtr()->rank; + // uint64_t atomic_sq_counter = my_sq_counter + 1; + // uintptr_t atomic_raddr = atomic->peerPtrs[pe]; + // uintptr_t atomic_rkey = atomic->peerRkeys[pe]; + // wq->outstandingWqe[atomic_sq_counter % OUTSTANDING_TABLE_SIZE] = atomic_sq_counter; + // uint64_t val = 0; + // dbr_val = core::PostAtomic(*wq, atomic_sq_counter, atomic_sq_counter, + // atomic_sq_counter, false, ep[pe].handle.qpn, + // atomic->GetRdmaMemoryRegion(rank).addr, + // atomic->GetRdmaMemoryRegion(rank).lkey, atomic_raddr, + // atomic_rkey, &val, &val, 8, core::AMO_ADD); + // uint64_t signal_sq_counter = warp_sq_counter + num_wqes - 1; + uint64_t signal_sq_counter = my_sq_counter + 1; + uintptr_t signal_raddr = signal->peerPtrs[pe] + signalOffset; + uintptr_t signal_rkey = signal->peerRkeys[pe]; + wq->outstandingWqe[signal_sq_counter % OUTSTANDING_TABLE_SIZE] = signal_sq_counter; + dbr_val = core::PostWriteInline(*wq, signal_sq_counter, signal_sq_counter, + signal_sq_counter, is_leader, ep[pe].handle.qpn, + signalVal, signal_raddr, signal_rkey, signalBytes); + } + } else if constexpr (PrvdType == core::ProviderType::BNXT) { + wq->outstandingWqe[my_sq_counter % wq->sqWqeNum] = my_sq_counter; + dbr_val = + core::PostWrite(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, is_leader, + ep[pe].handle.qpn, laddr, source.lkey, raddr, rkey, bytes); + } else { + assert(false); + } + __threadfence_system(); + if (is_leader) { + uint64_t db_touched{0}; + do { + db_touched = __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } while (db_touched != warp_sq_counter); + + core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_wqes); + // __threadfence_system(); + core::RingDoorbell(wq->dbrAddr, dbr_val); + __threadfence_system(); 
+ + __hip_atomic_fetch_add(&cq->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + if constexpr (PrvdType == core::ProviderType::MLX5) { + __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } else if constexpr (PrvdType == core::ProviderType::BNXT) { + __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes * BNXT_RE_NUM_SLOT_PER_WQE, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } + } + // __threadfence_system(); +} + template <> inline __device__ void ShmemPutMemNbiThreadKernel( const application::SymmMemObjPtr dest, size_t destOffset, diff --git a/python/mori/ops/dispatch_combine.py b/python/mori/ops/dispatch_combine.py index c5bad46e..4d9b52b2 100644 --- a/python/mori/ops/dispatch_combine.py +++ b/python/mori/ops/dispatch_combine.py @@ -47,6 +47,8 @@ class EpDispatchCombineConfig: block_num: int = 80 use_external_inp_buf: bool = True kernel_type: EpDispatchCombineKernelType = EpDispatchCombineKernelType.IntraNode + gpu_per_node: int = 8 + rdma_block_num: int = 0 def _cpp_dispatch_combine_factory(entity_name): @@ -72,6 +74,8 @@ def __init__(self, config): warp_num_per_block=config.warp_num_per_block, block_num=config.block_num, use_external_inp_buf=config.use_external_inp_buf, + gpu_per_node=config.gpu_per_node, + rdma_block_num=config.rdma_block_num, ) ) @@ -176,7 +180,10 @@ def _allgather_with_token_num_padding(self, input, max_token_num): def get_dispatch_src_token_pos(self): torch.cuda.synchronize() - if self.config.kernel_type.value == EpDispatchCombineKernelType.IntraNode.value: + if self.config.kernel_type.value in ( + EpDispatchCombineKernelType.IntraNode.value, + EpDispatchCombineKernelType.InterNodeV1.value, + ): return self._get_dispatch_src_token_pos_func(self._handle) dispatch_sender_token_id_map = self._get_dispatch_sender_token_idx_map_func( diff --git a/src/application/bootstrap/torch_bootstrap.cpp b/src/application/bootstrap/torch_bootstrap.cpp index e125ec8f..0b6e8ec9 100644 --- a/src/application/bootstrap/torch_bootstrap.cpp +++ b/src/application/bootstrap/torch_bootstrap.cpp @@ -30,13 +30,12 @@ namespace mori { namespace application { -TorchBootstrapNetwork::TorchBootstrapNetwork(const std::string& groupName) { - this->group = c10d::resolve_process_group(groupName); -} +TorchBootstrapNetwork::TorchBootstrapNetwork(const std::string& name) : groupName(name) {} TorchBootstrapNetwork::~TorchBootstrapNetwork() { Finalize(); } void TorchBootstrapNetwork::Initialize() { + c10::intrusive_ptr group = c10d::resolve_process_group(groupName); this->worldSize = group->getSize(); this->localRank = group->getRank(); } @@ -44,6 +43,8 @@ void TorchBootstrapNetwork::Initialize() { void TorchBootstrapNetwork::Finalize() {} void TorchBootstrapNetwork::Allgather(void* sendbuf, void* recvbuf, size_t sendcount) { + c10::intrusive_ptr group = c10d::resolve_process_group(groupName); + std::vector inputTensors = { at::from_blob(sendbuf, {1, (int)sendcount}, at::TensorOptions().dtype(at::kByte))}; @@ -56,6 +57,8 @@ void TorchBootstrapNetwork::Allgather(void* sendbuf, void* recvbuf, size_t sendc } void TorchBootstrapNetwork::AllToAll(void* sendbuf, void* recvbuf, size_t sendcount) { + c10::intrusive_ptr group = c10d::resolve_process_group(groupName); + at::Tensor inputTensor = at::from_blob(sendbuf, {worldSize, (int)sendcount}, at::TensorOptions().dtype(at::kByte)); @@ -70,6 +73,8 @@ void TorchBootstrapNetwork::AllToAll(void* sendbuf, void* recvbuf, size_t sendco } void TorchBootstrapNetwork::Barrier() { + 
c10::intrusive_ptr group = c10d::resolve_process_group(groupName); + auto work = group->barrier(); work->wait(); } diff --git a/src/ops/CMakeLists.txt b/src/ops/CMakeLists.txt index 489fa397..74a1dc42 100644 --- a/src/ops/CMakeLists.txt +++ b/src/ops/CMakeLists.txt @@ -1,4 +1,6 @@ -add_library(mori_ops dispatch_combine/dispatch_combine.cpp) +add_library(mori_ops dispatch_combine/dispatch_combine.cpp + dispatch_combine/internode_v1.cpp) +target_include_directories(mori_ops PUBLIC ${CMAKE_SOURCE_DIR}/) target_link_libraries(mori_ops mori_application mori_shmem mori_logging hip::host hip::device) target_include_directories(mori_ops PUBLIC ${CMAKE_SOURCE_DIR}/include) diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp index f7b2572b..42479167 100644 --- a/src/ops/dispatch_combine/dispatch_combine.cpp +++ b/src/ops/dispatch_combine/dispatch_combine.cpp @@ -29,6 +29,7 @@ #include "mori/core/core.hpp" #include "mori/shmem/shmem.hpp" #include "src/ops/dispatch_combine/internode.hpp" +#include "src/ops/dispatch_combine/internode_v1.hpp" #include "src/ops/dispatch_combine/intranode.hpp" namespace mori { @@ -42,6 +43,7 @@ using namespace mori::shmem; /* EpDispatchCombineHandle */ /* ---------------------------------------------------------------------------------------------- */ EpDispatchCombineHandle::EpDispatchCombineHandle(EpDispatchCombineConfig config) : config(config) { + assert(IsPowerOf2(config.gpuPerNode) && (config.worldSize % config.gpuPerNode == 0)); InitializeShmemBuf(); InitializeTokenNumSignalBuf(); InitializeOrderMapBuf(); @@ -113,14 +115,20 @@ void EpDispatchCombineHandle::InitializeTokenNumSignalBuf() { sendAtomicSignalMemObj = ShmemMallocAndReturnMemObjPtr( (config.worldSize * 2) * sizeof(int64_t) * 2, hipDeviceMallocUncached); - HIP_RUNTIME_CHECK(hipMalloc(&totalRecvTokenNum, sizeof(index_t))); + HIP_RUNTIME_CHECK( + hipExtMallocWithFlags((void**)&totalRecvTokenNum, sizeof(index_t), hipDeviceMallocUncached)); HIP_RUNTIME_CHECK(hipMemset(totalRecvTokenNum, 0, sizeof(index_t))); + + size_t nodeTokenNumSignalSize = config.worldSize / config.gpuPerNode * sizeof(index_t); + nodeRecvTokenNumMemObj = + ShmemMallocAndReturnMemObjPtr(nodeTokenNumSignalSize, hipDeviceMallocUncached); } void EpDispatchCombineHandle::FinalizeTokenNumSignalBuf() { ShmemFree(recvTokenNumMemObj->localPtr); ShmemFree(sendTokenNumMemObj->localPtr); ShmemFree(sendAtomicSignalMemObj->localPtr); + ShmemFree(nodeRecvTokenNumMemObj->localPtr); HIP_RUNTIME_CHECK(hipFree(totalRecvTokenNum)); } @@ -138,11 +146,27 @@ void EpDispatchCombineHandle::InitializeOrderMapBuf() { HIP_RUNTIME_CHECK(hipMalloc(&srcPeTokenIdxMap, maxNumOutToken * sizeof(index_t))); HIP_RUNTIME_CHECK(hipMemset(srcPeTokenIdxMap, -1, maxNumOutToken * sizeof(index_t))); - HIP_RUNTIME_CHECK(hipMalloc(&destPeTokenCounter, config.worldSize * sizeof(index_t))); + HIP_RUNTIME_CHECK(hipExtMallocWithFlags((void**)&destPeTokenCounter, + config.worldSize * sizeof(index_t) * 1024, + hipDeviceMallocUncached)); + // HIP_RUNTIME_CHECK(hipMalloc(&destPeTokenCounter, config.worldSize * sizeof(index_t))); HIP_RUNTIME_CHECK(hipMemset(destPeTokenCounter, 0, config.worldSize * sizeof(index_t))); - HIP_RUNTIME_CHECK(hipMalloc(&localPeTokenCounter, config.numExpertPerRank * sizeof(index_t))); - HIP_RUNTIME_CHECK(hipMemset(localPeTokenCounter, 0, config.numExpertPerRank * sizeof(index_t))); + HIP_RUNTIME_CHECK(hipExtMallocWithFlags( + (void**)&destNodeTokenCounter, config.worldSize * sizeof(index_t), 
hipDeviceMallocUncached)); + HIP_RUNTIME_CHECK(hipMemset(destNodeTokenCounter, 0, config.worldSize * sizeof(index_t))); + + // HIP_RUNTIME_CHECK( + // hipMalloc(&destNodeTokenCounter, config.worldSize / config.gpuPerNode * + // sizeof(index_t))); + HIP_RUNTIME_CHECK( + hipMemset(destNodeTokenCounter, 0, config.worldSize / config.gpuPerNode * sizeof(index_t))); + + HIP_RUNTIME_CHECK(hipExtMallocWithFlags( + (void**)&localPeTokenCounter, config.worldSize * sizeof(index_t), hipDeviceMallocUncached)); + + // HIP_RUNTIME_CHECK(hipMalloc(&localPeTokenCounter, config.worldSize * sizeof(index_t))); + HIP_RUNTIME_CHECK(hipMemset(localPeTokenCounter, 0, config.worldSize * sizeof(index_t))); dispTokOffsetMemObj = ShmemMallocAndReturnMemObjPtr(sizeof(index_t), hipDeviceMallocUncached); dispTokIdToSrcTokIdMemObj = @@ -150,6 +174,20 @@ void EpDispatchCombineHandle::InitializeOrderMapBuf() { HIP_RUNTIME_CHECK(hipMalloc(&dispDestTokIdMap, maxNumOutToken * sizeof(index_t))); HIP_RUNTIME_CHECK(hipMemset(dispDestTokIdMap, 0, maxNumOutToken * sizeof(index_t))); + + size_t maxNumInterNodeToken = config.worldSize / config.gpuPerNode * + config.maxNumInpTokenPerRank * config.numExpertPerToken; + HIP_RUNTIME_CHECK(hipMalloc(&interNodeDispDestTokIdMap, maxNumInterNodeToken * sizeof(index_t))); + HIP_RUNTIME_CHECK( + hipMemset(interNodeDispDestTokIdMap, 0, maxNumInterNodeToken * sizeof(index_t))); + + HIP_RUNTIME_CHECK(hipMalloc(&blockFlagCounter, sizeof(index_t))); + HIP_RUNTIME_CHECK(hipMemset(blockFlagCounter, 0, sizeof(index_t))); + + size_t interNodeDispSendMapSize = + config.worldSize / config.gpuPerNode * config.maxNumInpTokenPerRank * sizeof(index_t); + HIP_RUNTIME_CHECK(hipMalloc(&interNodeDispSendMap, interNodeDispSendMapSize)); + HIP_RUNTIME_CHECK(hipMemset(interNodeDispSendMap, 0, interNodeDispSendMapSize)); } void EpDispatchCombineHandle::FinalizeOrderMapBuf() { @@ -158,29 +196,52 @@ void EpDispatchCombineHandle::FinalizeOrderMapBuf() { HIP_RUNTIME_CHECK(hipFree(destPeTokenIdxMap)); HIP_RUNTIME_CHECK(hipFree(srcPeTokenIdxMap)); HIP_RUNTIME_CHECK(hipFree(destPeTokenCounter)); + HIP_RUNTIME_CHECK(hipFree(destNodeTokenCounter)); HIP_RUNTIME_CHECK(hipFree(localPeTokenCounter)); ShmemFree(dispTokOffsetMemObj->localPtr); ShmemFree(dispTokIdToSrcTokIdMemObj->localPtr); HIP_RUNTIME_CHECK(hipFree(dispDestTokIdMap)); + HIP_RUNTIME_CHECK(hipFree(interNodeDispDestTokIdMap)); + HIP_RUNTIME_CHECK(hipFree(blockFlagCounter)); + HIP_RUNTIME_CHECK(hipFree(interNodeDispSendMap)); } void EpDispatchCombineHandle::InitializeBarrier() { size_t barrierSize = config.worldSize * sizeof(uint32_t); - HIP_RUNTIME_CHECK(hipMalloc(&dispatchGridBarrier, barrierSize)); + HIP_RUNTIME_CHECK( + hipExtMallocWithFlags((void**)&dispatchGridBarrier, barrierSize, hipDeviceMallocUncached)); HIP_RUNTIME_CHECK(hipMemset(dispatchGridBarrier, 0, barrierSize)); - HIP_RUNTIME_CHECK(hipMalloc(&combineGridBarrier, barrierSize)); + HIP_RUNTIME_CHECK( + hipExtMallocWithFlags((void**)&combineGridBarrier, barrierSize, hipDeviceMallocUncached)); HIP_RUNTIME_CHECK(hipMemset(combineGridBarrier, 0, barrierSize)); HIP_RUNTIME_CHECK(hipMalloc(&crossDeviceBarrierFlag, sizeof(uint32_t))); HIP_RUNTIME_CHECK(hipMemsetD32(crossDeviceBarrierFlag, 1, 1)); crossDeviceBarrierMemObj = ShmemMallocAndReturnMemObjPtr( barrierSize * 2 * sizeof(uint64_t) / sizeof(uint32_t), hipDeviceMallocUncached); + + // We allocate one flag for each token, this ensure that we can use all chunk size(>=1) + size_t interNodeChunkFlagSize = + config.worldSize / config.gpuPerNode * 
config.MaxNumTokensToRecvPerRank() * sizeof(index_t); + interNodeChunkFlagMemObj = + ShmemMallocAndReturnMemObjPtr(interNodeChunkFlagSize, hipDeviceMallocUncached); + + HIP_RUNTIME_CHECK(hipMalloc(&interNodeChunkFlagCombine, interNodeChunkFlagSize)); + HIP_RUNTIME_CHECK(hipMemset(interNodeChunkFlagCombine, 0, interNodeChunkFlagSize)); + + HIP_RUNTIME_CHECK(hipExtMallocWithFlags((void**)&interNodeBlocksBarrier, sizeof(index_t), + hipDeviceMallocUncached)); + // HIP_RUNTIME_CHECK(hipMalloc(&interNodeBlocksBarrier, sizeof(index_t))); + HIP_RUNTIME_CHECK(hipMemset(interNodeBlocksBarrier, 0, sizeof(index_t))); } void EpDispatchCombineHandle::FinalizeBarrier() { HIP_RUNTIME_CHECK(hipFree(dispatchGridBarrier)); HIP_RUNTIME_CHECK(hipFree(combineGridBarrier)); HIP_RUNTIME_CHECK(hipFree(crossDeviceBarrierFlag)); + HIP_RUNTIME_CHECK(hipFree(interNodeChunkFlagCombine)); + HIP_RUNTIME_CHECK(hipFree(interNodeBlocksBarrier)); ShmemFree(crossDeviceBarrierMemObj->localPtr); + ShmemFree(interNodeChunkFlagMemObj->localPtr); } void EpDispatchCombineHandle::LaunchIntraNodeDispatch(int blockNum, int warpPerBlock, @@ -222,6 +283,10 @@ void EpDispatchCombineHandle::LaunchDispatch(KernelType kernelType, int blockNum if (kernelType == KernelType::InterNode) { assert(config.useExternalInpBuffer); EpDispatchInterNodeKernel<<>>(args); + } else if (kernelType == KernelType::InterNodeV1) { + EpDispatchInterNodeV1Kernel<<>>(args); + // hipDeviceSynchronize(); + // DispatchSyncKernel<<>>(args); } else if (kernelType == KernelType::IntraNode) { EpDispatchIntraNodeKernel<<>>(args); } else { @@ -248,6 +313,9 @@ void EpDispatchCombineHandle::LaunchCombine(KernelType kernelType, int blockNum, if (kernelType == KernelType::InterNode) { assert(config.useExternalInpBuffer); EpCombineInterNodeKernel<<>>(args); + } else if (kernelType == KernelType::InterNodeV1) { + assert(config.useExternalInpBuffer); + EpCombineInterNodeV1Kernel<<>>(args); } else if (kernelType == KernelType::IntraNode) { EpCombineIntraNodeKernel<<>>(args); } else { diff --git a/src/ops/dispatch_combine/internode_v1.cpp b/src/ops/dispatch_combine/internode_v1.cpp new file mode 100644 index 00000000..3a1f6863 --- /dev/null +++ b/src/ops/dispatch_combine/internode_v1.cpp @@ -0,0 +1,659 @@ +// Copyright © Advanced Micro Devices, Inc. All rights reserved. +// +// MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
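Note on the new kernel source that follows: the grid is split by config.rdmaBlockNum. Blocks below that threshold run the inter-node RDMA send/receive path, the remaining blocks handle intra-node XGMI copies, and RDMA traffic is pushed in chunks of warpSize tokens with one completion flag per chunk (which is why the flag buffer above is sized per token). A pure-Python sketch of that partitioning, with example numbers only:

# Pure-Python sketch of how the V1 dispatch kernel below divides work across blocks
# and token chunks. Variable names mirror the kernel's; the numeric values are examples.
import math

block_num, rdma_block_num, warp_size = 32, 16, 64
cur_rank_num_token, max_num_inp_token_per_rank = 3000, 4096

# Blocks [0, rdma_block_num) run DispatchInterNodeSend/Recv; the rest run DispatchIntraNode.
rdma_blocks = range(rdma_block_num)
xgmi_blocks = range(rdma_block_num, block_num)

# RDMA path: tokens go out in chunks of warp_size, one completion flag per chunk,
# so max_chunk_num flag slots are reserved per remote node.
max_chunk_num = math.ceil(max_num_inp_token_per_rank / warp_size)   # flag slots per node
total_chunk_num = math.ceil(cur_rank_num_token / warp_size)         # chunks actually sent
chunks_per_rdma_block = math.ceil(total_chunk_num / rdma_block_num)

# XGMI path: tokens are split evenly over the remaining blocks.
tokens_per_xgmi_block = math.ceil(cur_rank_num_token / len(xgmi_blocks))

print(max_chunk_num, total_chunk_num, chunks_per_rdma_block, tokens_per_xgmi_block)
# -> 64 47 3 188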
+ +#include "src/ops/dispatch_combine/internode_v1.hpp" + +#include "mori/core/core.hpp" +#include "mori/ops/dispatch_combine/dispatch_combine.hpp" +#include "mori/shmem/shmem.hpp" + +namespace mori { +namespace moe { + +/* ---------------------------------------------------------------------------------------------- */ +/* EpDispatchInterNodeV1Kernel */ +/* ---------------------------------------------------------------------------------------------- */ +#define DEF_COMMON_VARS \ + const EpDispatchCombineConfig& config = args.config; \ + int thdId = threadIdx.x; \ + int thdNum = blockDim.x; \ + int laneId = threadIdx.x & (warpSize - 1); \ + int warpId = thdId / warpSize; \ + int warpNum = blockDim.x / warpSize; \ + int blockNum = gridDim.x; \ + int blockId = blockIdx.x; \ + int globalThdId = blockIdx.x * blockDim.x + threadIdx.x; \ + int globalThdNum = gridDim.x * blockDim.x; \ + int globalWarpId = blockIdx.x * warpNum + warpId; \ + int globalWarpNum = gridDim.x * warpNum; \ + int nullTokenId = config.worldSize * config.MaxNumTokensToRecv(); \ + int myPe = config.rank; \ + int npes = config.worldSize; \ + int myNode = myPe / config.gpuPerNode; \ + int nNodes = npes / config.gpuPerNode; \ + int numExpertPerToken = config.numExpertPerToken; \ + assert(numExpertPerToken < warpSize); \ + size_t hiddenBytes = config.hiddenDim * sizeof(T); \ + size_t indexBytes = config.numExpertPerToken * sizeof(index_t); \ + size_t weightBytes = config.numExpertPerToken * sizeof(float); \ + size_t srcTokenIdBytes = sizeof(index_t); \ + size_t xferBytes = hiddenBytes + indexBytes + weightBytes + srcTokenIdBytes; + +namespace v1 { +template +inline __device__ void DispatchIntraNodeBlock(EpDispatchCombineArgs& args, int tokenId, + int expId, int destPe) { + DEF_COMMON_VARS; + + index_t tokenExpertId = tokenId * args.config.numExpertPerToken + expId; + index_t destTokId = 0; + if (laneId == 0) { + // decide token id in dest pe + destTokId = atomicAdd(args.dispTokOffsetMemObj->template GetAs(destPe), 1); + atomicAdd(args.destPeTokenCounter + destPe, 1); + args.dispDestTokIdMap[tokenExpertId] = destPe * config.MaxNumTokensToRecv() + destTokId; + + core::AtomicStoreRelaxedSystem( + args.dispTokIdToSrcTokIdMemObj->template GetAs(destPe) + destTokId, + config.rank * config.maxNumInpTokenPerRank + tokenId); + } + destTokId = __shfl(destTokId, 0); + size_t srcTokOffset = tokenId * config.hiddenDim; + size_t destTokOffset = destTokId * config.hiddenDim; + + T* remoteTokenPtr = args.shmemOutTokMemObj->template GetAs(destPe); + const T* localTokenPtr = args.inpTokenBuf; + core::WarpCopy(remoteTokenPtr + destTokOffset, localTokenPtr + srcTokOffset, config.hiddenDim); + + index_t* remoteIndexPtr = args.shmemOutIndicesMemObj->template GetAs(destPe); + const index_t* localIndexPtr = args.tokenIndices; + core::WarpCopy(remoteIndexPtr + destTokId * config.numExpertPerToken, + localIndexPtr + tokenId * config.numExpertPerToken, config.numExpertPerToken); + + float* remoteWeightPtr = args.shmemOutWeightsMemObj->template GetAs(destPe); + const float* localWeightPtr = args.weightsBuf; + core::WarpCopy(remoteWeightPtr + destTokId * config.numExpertPerToken, + localWeightPtr + tokenId * config.numExpertPerToken, config.numExpertPerToken); +} + +template +inline __device__ void DispatchIntraNode(EpDispatchCombineArgs& args) { + DEF_COMMON_VARS; + + // Distribute tokens evenly to all blocks + int blockOffset = config.rdmaBlockNum; + int xgmiBlockNum = blockNum - config.rdmaBlockNum; + int tokenPerBlock = (args.curRankNumToken + 
xgmiBlockNum - 1) / xgmiBlockNum; + int startTokenIdx = (blockId - blockOffset) * tokenPerBlock; + int endTokenIdx = std::min(startTokenIdx + tokenPerBlock, args.curRankNumToken); + + for (int tokenId = startTokenIdx + warpId; tokenId < endTokenIdx; tokenId += warpNum) { + int lanePe = -1, laneNode = -1; + if (laneId < numExpertPerToken) { + lanePe = (args.tokenIndices[tokenId * numExpertPerToken + laneId] / config.numExpertPerRank); + laneNode = lanePe / config.gpuPerNode; + }; + + // Send to other pes in myNode + for (int e = 0; e < config.numExpertPerToken; e++) { + int tokenExpertId = tokenId * config.numExpertPerToken + e; + int destPe = __shfl(lanePe, e); + int destNode = destPe / config.gpuPerNode; + if (destNode == myNode) { + if (__any((laneId < e) && (destPe == lanePe))) { + if (laneId == 0) args.dispDestTokIdMap[tokenExpertId] = nullTokenId; + continue; + } + DispatchIntraNodeBlock(args, tokenId, e, destPe); + } + } + } +} + +template +inline __device__ void DispatchInterNodeSend(EpDispatchCombineArgs& args) { + DEF_COMMON_VARS; + + // Distribute tokens evenly to all blocks + int maxChunkNum = core::CeilDiv(config.maxNumInpTokenPerRank, warpSize); + int totalChunkNum = core::CeilDiv(args.curRankNumToken, warpSize); + int blockChunkNum = core::CeilDiv(totalChunkNum, config.rdmaBlockNum); + + int startTokenIdx = blockChunkNum * blockId * warpSize; + int endTokenIdx = std::min(startTokenIdx + blockChunkNum * warpSize, args.curRankNumToken); + + // First copy to staging buffer + for (int tokenId = startTokenIdx + warpId; tokenId < endTokenIdx; tokenId += warpNum) { + uint8_t* stagingPtr = args.shmemStagingTokMemObj->template GetAs(); + size_t stagingTokOffset = tokenId * xferBytes; + core::WarpCopy(stagingPtr + stagingTokOffset, + reinterpret_cast(args.inpTokenBuf) + tokenId * hiddenBytes, + hiddenBytes); + core::WarpCopy(stagingPtr + stagingTokOffset + hiddenBytes, + reinterpret_cast(args.tokenIndices) + tokenId * indexBytes, + indexBytes); + core::WarpCopy(stagingPtr + stagingTokOffset + hiddenBytes + indexBytes, + reinterpret_cast(args.weightsBuf) + tokenId * weightBytes, + weightBytes); + if (laneId == 0) + reinterpret_cast(stagingPtr + stagingTokOffset + hiddenBytes + indexBytes + + weightBytes)[0] = + tokenId + config.rank * config.maxNumInpTokenPerRank; + } + __syncthreads(); + + // Then send to other nodes + for (int i = warpId; i < nNodes; i += warpNum) { + if (i == myNode) continue; + int proxyPe = i * config.gpuPerNode + (config.rank % config.gpuPerNode); + for (int tokenId = startTokenIdx + laneId; tokenId < endTokenIdx; tokenId += warpSize) { + bool shouldSend = false; + for (int e = 0; e < config.numExpertPerToken; e++) { + int destNode = args.tokenIndices[tokenId * numExpertPerToken + e] / + config.numExpertPerRank / config.gpuPerNode; + if (destNode == i) { + shouldSend |= true; + args.dispDestTokIdMap[tokenId * numExpertPerToken + e] = nullTokenId; + } + } + uint64_t mask = __ballot(shouldSend) & __activemask(); + uint64_t num = __popcll(mask); + + index_t flag = 0; + index_t flagSlotId = 0; + if (laneId == 0) { + flagSlotId = atomicAdd(args.blockFlagCounter, 1); + atomicAdd(args.destNodeTokenCounter + i, num); + flag = num + 1; + } + flag = __shfl(flag, 0); + flagSlotId = __shfl(flagSlotId, 0); + + index_t destTokIdOffset = flagSlotId * warpSize; + + uint64_t warpOffset = 0; + if (laneId > 0) warpOffset = __popcll(mask << (warpSize - laneId)); + index_t destTokId = destTokIdOffset + warpOffset; + + if (shouldSend) { + bool prev = (laneId > 0) ? 
((mask >> (laneId - 1)) & 1ULL) : 0;
+        int count = 0;
+        if (!prev) {
+          count = 1;
+          for (int i = laneId + 1; i < warpSize; i++) {
+            if ((mask >> i) & 1ULL) {
+              count++;
+            } else {
+              break;
+            }
+          }
+        }
+        size_t remoteIdx = (myNode * config.MaxNumTokensToRecvPerRank() + destTokId);
+        if (count > 0) {
+          size_t stagingTokOffset = tokenId * xferBytes;
+          shmem::ShmemPutMemNbiThreadKernelImpl(
+              args.shmemInpTokMemObj, remoteIdx * xferBytes,
+              args.shmemStagingTokMemObj->GetRdmaMemoryRegion(shmem::GetGlobalGpuStatesPtr()->rank),
+              stagingTokOffset, count * xferBytes, args.interNodeChunkFlagMemObj,
+              (myNode * maxChunkNum + flagSlotId) * sizeof(index_t), &flag, sizeof(index_t),
+              args.sendTokenNumMemObj, proxyPe);
+        }
+        args.interNodeDispSendMap[nNodes * tokenId + i] = destTokId;
+      }
+    }
+  }
+
+  // The last warp of the RDMA blocks publishes the per-node token counts (stored as count + 1).
+  int finishedWarp = 0;
+  if (laneId == 0) finishedWarp = atomicAdd(args.interNodeBlocksBarrier, 1);
+  finishedWarp = __shfl(finishedWarp, 0);
+  if ((finishedWarp + 1) == (config.rdmaBlockNum * warpNum)) {
+    if (laneId < nNodes) {
+      int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode);
+      index_t numTokenSignal = core::AtomicLoadSeqCstSystem(args.destNodeTokenCounter + laneId) + 1;
+      shmem::ShmemPutInt32ImmNbiThread(args.nodeRecvTokenNumMemObj, myNode * sizeof(index_t),
+                                       numTokenSignal, proxyPe);
+    }
+    if (laneId == 0) args.interNodeBlocksBarrier[0] = 0;
+  }
+}
+
+template <typename T>
+inline __device__ void DispatchInterNodeRecv(EpDispatchCombineArgs<T>& args) {
+  DEF_COMMON_VARS;
+
+  constexpr int numRecvBlock = 1;
+  int maxChunkNum = core::CeilDiv(config.maxNumInpTokenPerRank, warpSize);
+
+  index_t* chunkFlag = args.interNodeChunkFlagMemObj->template GetAs<index_t>();
+  index_t* nodeRecvTokenNum = args.nodeRecvTokenNumMemObj->template GetAs<index_t>();
+  uint8_t* stagingPtr = args.shmemInpTokMemObj->template GetAs<uint8_t>();
+
+  for (int k = blockId / numRecvBlock; k < maxChunkNum; k += (config.rdmaBlockNum / numRecvBlock)) {
+    for (int i = 0; i < nNodes; i++) {
+      if (i == myNode) continue;
+      int startTokenIdx = k * warpSize;
+
+      // Poll completion flags
+      index_t thisChunkTokenNum = 0;
+      index_t nodeFlag = 0;
+      if (laneId == 0) {
+        while (1) {
+          thisChunkTokenNum = core::AtomicLoadRelaxedSystem(&chunkFlag[i * maxChunkNum + k]);
+          if (thisChunkTokenNum > 0) break;
+
+          nodeFlag = core::AtomicLoadRelaxedSystem(&nodeRecvTokenNum[i]);
+          if ((nodeFlag > 0) && (startTokenIdx >= (nodeFlag - 1))) {
+            thisChunkTokenNum = 1;
+            break;
+          }
+        }
+      }
+      // Flags are published as count + 1, so a value of 0 means "not written yet".
+      thisChunkTokenNum = __shfl(thisChunkTokenNum, 0) - 1;
+      nodeFlag = __shfl(nodeFlag, 0) - 1;
+
+      int endTokenIdx = startTokenIdx + thisChunkTokenNum;
+
+      for (int j = startTokenIdx + (blockId % numRecvBlock) * warpNum + warpId; j < endTokenIdx;
+           j += numRecvBlock * warpNum) {
+        int tokIdx = i * config.MaxNumTokensToRecvPerRank() + j;
+        index_t* indices =
+            reinterpret_cast<index_t*>(stagingPtr + tokIdx * xferBytes + hiddenBytes);
+        int lanePe = -1;
+        if (laneId < config.numExpertPerToken) {
+          lanePe = indices[laneId] / config.numExpertPerRank;
+          assert((lanePe < config.worldSize) && (lanePe >= 0));
+        }
+        index_t srcTokId = reinterpret_cast<index_t*>(stagingPtr + tokIdx * xferBytes +
+                                                      hiddenBytes + indexBytes + weightBytes)[0];
+
+        for (int e = 0; e < config.numExpertPerToken; e++) {
+          int destPe = __shfl(lanePe, e);
+          int destNode = destPe / config.gpuPerNode;
+
+          bool shouldSkip = (destNode != myNode) || __any((laneId < e) && (destPe == lanePe));
+          if (shouldSkip) {
+            if (laneId == 0)
+              args.interNodeDispDestTokIdMap[tokIdx * config.numExpertPerToken + e] = nullTokenId;
+            continue;
+          }
+          int destTokId = 0;
+          if (laneId == 0) {
+            destTokId = atomicAdd(args.dispTokOffsetMemObj->template GetAs<index_t>(destPe), 1);
+            atomicAdd(args.destPeTokenCounter + destPe, 1);
+            args.interNodeDispDestTokIdMap[tokIdx * config.numExpertPerToken + e] =
+                destPe * config.MaxNumTokensToRecv() + destTokId;
+            args.dispTokIdToSrcTokIdMemObj->template GetAs<index_t>(destPe)[destTokId] = srcTokId;
+          }
+          destTokId = __shfl(destTokId, 0);
+          core::WarpCopy(
+              args.shmemOutTokMemObj->template GetAs<uint8_t>(destPe) + destTokId * hiddenBytes,
+              stagingPtr + tokIdx * xferBytes, hiddenBytes);
+          core::WarpCopy(
+              args.shmemOutIndicesMemObj->template GetAs<uint8_t>(destPe) + destTokId * indexBytes,
+              stagingPtr + tokIdx * xferBytes + hiddenBytes, indexBytes);
+          core::WarpCopy(args.shmemOutWeightsMemObj->template GetAs<uint8_t>(destPe) +
+                             destTokId * weightBytes,
+                         stagingPtr + tokIdx * xferBytes + hiddenBytes + indexBytes, weightBytes);
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+inline __device__ void DispatchSync(EpDispatchCombineArgs<T>& args) {
+  DEF_COMMON_VARS;
+
+  int nodePeOffset = myNode * config.gpuPerNode;
+  int finishedWarp = 0;
+  if (laneId == 0) finishedWarp = atomicAdd(args.dispatchGridBarrier, 1);
+  finishedWarp = __shfl(finishedWarp, 0);
+  if ((finishedWarp + 1) == globalWarpNum) {
+    if (laneId < config.gpuPerNode) {
+      int destPe = myNode * config.gpuPerNode + laneId;
+      index_t numTokenSignal = core::AtomicLoadSeqCstSystem(args.destPeTokenCounter + destPe) + 1;
+      index_t* signal = args.recvTokenNumMemObj->template GetAs<index_t>(destPe) + myPe;
+      core::AtomicStoreSeqCstSystem(signal, numTokenSignal);
+    }
+    if (laneId == 0) args.dispatchGridBarrier[0] = 0;
+  }
+
+  // Each warp waits until the senders have finished by polling the token-number signals
+  index_t* recvTokenNums = args.recvTokenNumMemObj->template GetAs<index_t>();
+  if (globalWarpId == 0) {
+    for (int destPe = nodePeOffset + laneId; destPe < (nodePeOffset + config.gpuPerNode);
+         destPe += warpSize) {
+      index_t* signal = recvTokenNums + destPe;
+      index_t recvTokenNum = shmem::ShmemInt32WaitUntilGreaterThan(signal, 0) - 1;
+      core::AtomicStoreRelaxedSystem(signal, 0);
+      atomicAdd(args.totalRecvTokenNum, recvTokenNum);
+      // printf("myPe %d recv token %d from pe %d\n", myPe, recvTokenNum, destPe);
+
+      // reset local counter
+      args.destPeTokenCounter[destPe] = 0;
+      recvTokenNums[destPe] = 0;
+    }
+
+    // reset counter
+    if (core::WarpLaneId1D() == 0) {
+      args.dispTokOffsetMemObj->template GetAs<index_t>()[0] = 0;
+    }
+
+    if (laneId < nNodes) {
+      // printf("myPe %d recv token %d from node %d\n", myPe,
+      //        args.nodeRecvTokenNumMemObj->template GetAs<index_t>()[laneId], laneId);
+      core::AtomicStoreRelaxedSystem(
+          args.nodeRecvTokenNumMemObj->template GetAs<index_t>() + laneId, 0);
+      core::AtomicStoreRelaxedSystem(args.destNodeTokenCounter + laneId, 0);
+    }
+  }
+
+  // if ((globalWarpId < config.worldSize) && (laneId == 0)) shmem::ShmemQuietThread(globalWarpId);
+}
+
+}  // namespace v1
+
+template <typename T>
+__global__ void EpDispatchInterNodeV1Kernel(EpDispatchCombineArgs<T> args) {
+  DEF_COMMON_VARS;
+  if (blockId < config.rdmaBlockNum) {
+    v1::DispatchInterNodeSend(args);
+    v1::DispatchInterNodeRecv(args);
+  } else {
+    v1::DispatchIntraNode(args);
+  }
+  v1::DispatchSync(args);
+}
+
+/* ---------------------------------------------------------------------------------------------- */
+/*                                   EpCombineInterNodeV1Kernel                                    */
+/* ---------------------------------------------------------------------------------------------- */
+namespace v1 {
+
+template <typename T>
+inline __device__ void CombineSync(EpDispatchCombineArgs<T>& args) {
+  DEF_COMMON_VARS;
+
+  // Copy input to shmem registered buffer so that other GPUs can access directly
+  index_t totalRecvTokenNum = args.totalRecvTokenNum[0];
+  int tokenPerBlock = core::CeilDiv(totalRecvTokenNum, blockNum);
+  int startTokenIdx = blockId * tokenPerBlock;
+  int endTokenIdx = std::min(startTokenIdx + tokenPerBlock, totalRecvTokenNum);
+  for (int tokenId = startTokenIdx + warpId; tokenId < endTokenIdx; tokenId += warpNum) {
+    core::WarpCopy(args.shmemInpTokMemObj->template GetAs<T>() + tokenId * config.hiddenDim,
+                   args.inpTokenBuf + tokenId * config.hiddenDim, config.hiddenDim);
+  }
+  // After all warps have finished copying, set the barrier flag
+  int finishedWarp = 0;
+  if (laneId == 0) finishedWarp = atomicAdd(args.combineGridBarrier, 1);
+  finishedWarp = __shfl(finishedWarp, 0);
+  if ((finishedWarp + 1) == (blockNum * warpNum)) {
+    if (laneId < config.gpuPerNode) {
+      int destPe = myNode * config.gpuPerNode + laneId;
+      core::AtomicStoreRelaxedSystem(
+          args.crossDeviceBarrierMemObj->template GetAs<uint32_t>(destPe) + args.config.rank,
+          args.crossDeviceBarrierFlag);
+    }
+    if (laneId == 0) args.combineGridBarrier[0] = 0;
+  }
+  // Wait for the other PEs on this node to set their flags
+  uint32_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs<uint32_t>();
+  if (laneId < config.gpuPerNode) {
+    int destPe = myNode * config.gpuPerNode + laneId;
+    while (core::AtomicLoadRelaxedSystem(localBarrierPtr + destPe) != args.crossDeviceBarrierFlag) {
+    }
+  }
+}
+
+template <typename T>
+inline __device__ void CombineIntraNode(EpDispatchCombineArgs<T>& args) {
+  DEF_COMMON_VARS;
+
+  // Distribute tokens evenly to all blocks
+  int blockOffset = config.rdmaBlockNum;
+  int xgmiBlockNum = blockNum - config.rdmaBlockNum;
+
+  extern __shared__ char sharedMem[];
+  T** srcPtrs = reinterpret_cast<T**>(sharedMem) + warpId * config.numExpertPerToken;
+  float** srcWeightsPtr = reinterpret_cast<float**>(sharedMem) +
+                          warpNum * config.numExpertPerToken + warpId * config.numExpertPerToken;
+  T* stagingPtr = args.shmemStagingTokMemObj->template GetAs<T>() +
+                  (nNodes + myNode) * config.MaxNumTokensToRecvPerRank() * config.hiddenDim;
+
+  int tokenPerBlock = (args.curRankNumToken + xgmiBlockNum - 1) / xgmiBlockNum;
+  int startTokenIdx = (blockId - blockOffset) * tokenPerBlock;
+  int endTokenIdx = std::min(startTokenIdx + tokenPerBlock, args.curRankNumToken);
+
+  for (int tokenId = startTokenIdx + warpId; tokenId < endTokenIdx; tokenId += warpNum) {
+    if (laneId < config.numExpertPerToken) {
+      srcPtrs[laneId] = nullptr;
+      srcWeightsPtr[laneId] = nullptr;
+      index_t destTokId = args.dispDestTokIdMap[tokenId * config.numExpertPerToken + laneId];
+      index_t destPe = destTokId / config.MaxNumTokensToRecv();
+      index_t destNode = destPe / config.gpuPerNode;
+      if (destNode == myNode) {
+        index_t destLocalTokId = destTokId - destPe * config.MaxNumTokensToRecv();
+        srcPtrs[laneId] =
+            args.shmemInpTokMemObj->template GetAs<T>(destPe) + destLocalTokId * config.hiddenDim;
+        srcWeightsPtr[laneId] = args.shmemInpWeightsMemObj->template GetAs<float>(destPe) +
+                                destLocalTokId * config.numExpertPerToken;
+      }
+    }
+    core::WarpAccum(stagingPtr + tokenId * config.hiddenDim, srcPtrs, nullptr,
+                    config.numExpertPerToken, config.hiddenDim);
+  }
+}
+
+template <typename T>
+inline __device__ void CombineInterNode(EpDispatchCombineArgs<T>& args) {
+  DEF_COMMON_VARS;
+
+  constexpr int numRecvBlock = 4;
+  int maxChunkNum = core::CeilDiv(config.maxNumInpTokenPerRank, warpSize);
+
+  index_t* chunkFlag = args.interNodeChunkFlagMemObj->template GetAs<index_t>();
+
+  extern __shared__ char sharedMem[];
+  T** srcPtrs = reinterpret_cast<T**>(sharedMem) + warpId * config.numExpertPerToken;
+  float** srcWeightsPtr = reinterpret_cast<float**>(sharedMem) +
+                          warpNum * config.numExpertPerToken + warpId * config.numExpertPerToken;
+
+  for (int k = blockId / numRecvBlock; k < maxChunkNum; k += (config.rdmaBlockNum / numRecvBlock)) {
+    for (int i = 0; i < nNodes; i++) {
+      if (i == myNode) continue;
+
+      index_t thisChunkTokenNum = chunkFlag[i * maxChunkNum + k];
+      thisChunkTokenNum -= (thisChunkTokenNum > 0) ? 1 : 0;
+      int startTokenIdx = k * warpSize;
+      int endTokenIdx = startTokenIdx + thisChunkTokenNum;
+
+      for (int j = startTokenIdx + (blockId % numRecvBlock) * warpNum + warpId; j < endTokenIdx;
+           j += numRecvBlock * warpNum) {
+        int tokIdx = i * config.MaxNumTokensToRecvPerRank() + j;
+        if (laneId < config.numExpertPerToken) {
+          srcPtrs[laneId] = nullptr;
+          srcWeightsPtr[laneId] = nullptr;
+          index_t destTokId =
+              args.interNodeDispDestTokIdMap[tokIdx * config.numExpertPerToken + laneId];
+          index_t destPe = destTokId / config.MaxNumTokensToRecv();
+          index_t destNode = destPe / config.gpuPerNode;
+          if (destNode == myNode) {
+            index_t destLocalTokId = destTokId - destPe * config.MaxNumTokensToRecv();
+            srcPtrs[laneId] = args.shmemInpTokMemObj->template GetAs<T>(destPe) +
+                              destLocalTokId * config.hiddenDim;
+            srcWeightsPtr[laneId] = args.shmemInpWeightsMemObj->template GetAs<float>(destPe) +
+                                    destLocalTokId * config.numExpertPerToken;
+          }
+          args.interNodeDispDestTokIdMap[tokIdx * config.numExpertPerToken + laneId] = 0;
+        }
+        core::WarpAccum(
+            args.shmemStagingTokMemObj->template GetAs<T>() + tokIdx * config.hiddenDim, srcPtrs,
+            nullptr, config.numExpertPerToken, config.hiddenDim);
+      }
+
+      index_t finished = 0;
+      if (laneId == 0)
+        finished = atomicAdd(&args.interNodeChunkFlagCombine[i * maxChunkNum + k], 1);
+      finished = __shfl(finished, 0);
+      if ((finished + 1) < (numRecvBlock * warpNum)) continue;
+
+      // The last warp to finish this chunk returns the partial sums to the source node's proxy PE.
+      int proxyPe = i * config.gpuPerNode + (config.rank % config.gpuPerNode);
+      if (laneId == 0)
+        shmem::ShmemPutTypeNbiThread<T>(
+            args.shmemStagingTokMemObj,
+            ((myNode + nNodes) * config.MaxNumTokensToRecvPerRank() + startTokenIdx) *
+                config.hiddenDim,
+            args.shmemStagingTokMemObj,
+            (i * config.MaxNumTokensToRecvPerRank() + startTokenIdx) * config.hiddenDim,
+            thisChunkTokenNum * config.hiddenDim, proxyPe);
+    }
+  }
+  int finishedWarp = 0;
+  if (laneId == 0) finishedWarp = atomicAdd(args.interNodeBlocksBarrier, 1);
+  finishedWarp = __shfl(finishedWarp, 0);
+  if ((finishedWarp + 1) == (config.rdmaBlockNum * warpNum)) {
+    if ((laneId < nNodes) &&
+        (laneId != myNode)) {  // avoid setting myNode, it will be set in the intra-node branch
+      int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode);
+      shmem::ShmemPutUint32ImmNbiThread(args.crossDeviceBarrierMemObj,
+                                        args.config.rank * sizeof(uint32_t),
+                                        args.crossDeviceBarrierFlag, proxyPe);
+    }
+    if (laneId == 0) args.interNodeBlocksBarrier[0] = 0;
+  }
+}
+
+template <typename T>
+inline __device__ void CombineAll(EpDispatchCombineArgs<T>& args) {
+  DEF_COMMON_VARS;
+
+  // Wait for all warps
+  uint32_t finishedWarps = 0;
+  if (laneId == 0) {
+    finishedWarps = atomicAdd(&args.combineGridBarrier[1], 1);
+    shmem::ShmemUint32WaitUntilEquals(&args.combineGridBarrier[1], globalWarpNum);
+  }
+  finishedWarps = __shfl(finishedWarps, 0);
+  // while (core::AtomicLoadRelaxed(&args.combineGridBarrier[1]) != globalWarpNum) {
+  // }
+  if (((finishedWarps + 1) == globalWarpNum) && (laneId == 0)) args.combineGridBarrier[1] = 0;
+
+  // Wait for the other PEs to set their flags
+  uint32_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs<uint32_t>();
+  if (laneId < nNodes) {
+    int proxyPe = laneId * config.gpuPerNode + (config.rank % config.gpuPerNode);
+    while (core::AtomicLoadRelaxedSystem(localBarrierPtr + proxyPe) !=
+           args.crossDeviceBarrierFlag) {
+    }
+  }
+
+  extern __shared__ char sharedMem[];
+  T** srcPtrs = reinterpret_cast<T**>(sharedMem) + warpId * config.numExpertPerToken;
+  float** srcWeightsPtr = reinterpret_cast<float**>(sharedMem) +
+                          warpNum * config.numExpertPerToken + warpId * config.numExpertPerToken;
+  T* stagingPtr = args.shmemStagingTokMemObj->template GetAs<T>() +
+                  nNodes * config.MaxNumTokensToRecvPerRank() * config.hiddenDim;
+
+  int tokenPerBlock = (args.curRankNumToken + blockNum - 1) / blockNum;
+  int startTokenIdx = blockId * tokenPerBlock;
+  int endTokenIdx = std::min(startTokenIdx + tokenPerBlock, args.curRankNumToken);
+
+  for (int tokenId = startTokenIdx + warpId; tokenId < endTokenIdx; tokenId += warpNum) {
+    int lanePe = -1, laneNode = -1;
+    if (laneId < config.numExpertPerToken) {
+      lanePe = (args.tokenIndices[tokenId * numExpertPerToken + laneId] / config.numExpertPerRank);
+      laneNode = lanePe / config.gpuPerNode;
+    }
+
+    if (laneId < nNodes) srcPtrs[laneId] = nullptr;
+    for (int i = 0; i < nNodes; i++) {
+      if (__any(laneNode == i) && (laneId == 0)) {
+        int mappedId = (i == myNode) ? tokenId : args.interNodeDispSendMap[nNodes * tokenId + i];
+        srcPtrs[i] =
+            stagingPtr + (i * config.MaxNumTokensToRecvPerRank() + mappedId) * config.hiddenDim;
+      }
+    }
+    if (laneId < nNodes) args.interNodeDispSendMap[nNodes * tokenId + laneId] = 0;
+    core::WarpAccum(args.shmemOutTokMemObj->template GetAs<T>() + tokenId * config.hiddenDim,
+                    srcPtrs, nullptr, nNodes, config.hiddenDim);
+  }
+}
+}  // namespace v1
+
+template <typename T>
+__global__ void EpCombineInterNodeV1Kernel(EpDispatchCombineArgs<T> args) {
+  DEF_COMMON_VARS;
+
+  v1::CombineSync(args);
+  if (blockId < config.rdmaBlockNum) {
+    v1::CombineInterNode(args);
+  } else {
+    v1::CombineIntraNode(args);
+  }
+  v1::CombineAll(args);
+
+  // TODO: refactor following state reset code
+  if (laneId == 0) {
+    args.totalRecvTokenNum[0] = 0;
+    args.blockFlagCounter[0] = 0;
+    // for (int i = 0; i < config.worldSize; i++) shmem::ShmemQuietThread(i);
+  }
+
+  if (globalThdId < nNodes)
+    args.nodeRecvTokenNumMemObj->template GetAs<index_t>()[globalThdId] = 0;
+
+  uint32_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs<uint32_t>();
+  if (globalThdId < config.worldSize) {
+    localBarrierPtr[globalThdId] = 0;
+  }
+
+  int maxChunkNum = core::CeilDiv(config.maxNumInpTokenPerRank, warpSize);
+  for (int i = globalThdId; i < (config.maxNumInpTokenPerRank * nNodes); i += globalThdNum) {
+    args.interNodeChunkFlagMemObj->template GetAs<index_t>()[i] = 0;
+    args.interNodeChunkFlagCombine[i] = 0;
+    // args.interNodeDispSendMap[i] = 0;
+  }
+
+  for (int i = globalThdId; i < (config.maxNumInpTokenPerRank * nNodes * config.numExpertPerToken);
+       i += globalThdNum) {
+    args.interNodeDispDestTokIdMap[i] = 0;
+  }
+  // if ((globalWarpId < config.worldSize) && (laneId == 0)) shmem::ShmemQuietThread(globalWarpId);
+}
+
+/* ---------------------------------------------------------------------------------------------- */
+/*                                     Template Specialization                                     */
+/* ---------------------------------------------------------------------------------------------- */
+template __global__ void EpDispatchInterNodeV1Kernel(
+    EpDispatchCombineArgs args);
+template __global__ void EpDispatchInterNodeV1Kernel<__hip_fp8_e4m3_fnuz>(
+    EpDispatchCombineArgs<__hip_fp8_e4m3_fnuz> args);
+template __global__ void EpDispatchInterNodeV1Kernel(EpDispatchCombineArgs args);
+
+template __global__ void EpCombineInterNodeV1Kernel(
+    EpDispatchCombineArgs args);
+template __global__ void EpCombineInterNodeV1Kernel<__hip_fp8_e4m3_fnuz>(
+    EpDispatchCombineArgs<__hip_fp8_e4m3_fnuz> args);
+template __global__ void EpCombineInterNodeV1Kernel(EpDispatchCombineArgs args);
+
+}  // namespace moe
+}  // namespace mori
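The dispatch and combine paths above never signal with a raw token count: every completion flag is published as count + 1 (see numTokenSignal and the "- 1" after each poll), so a value of 0 can always be read as "not written yet", even for an empty chunk. A minimal host-side sketch of that convention, using std::atomic in place of the RDMA-backed flag buffers (names here are illustrative, not from the library):

#include <atomic>
#include <cassert>
#include <thread>

// 0 means "not written yet"; a producer that finished n tokens stores n + 1,
// so even an empty chunk (n == 0) produces a visible signal.
std::atomic<int> chunkFlag{0};

void producer(int tokenCount) {
  chunkFlag.store(tokenCount + 1, std::memory_order_release);
}

int consumer() {
  int v;
  while ((v = chunkFlag.load(std::memory_order_acquire)) == 0) {
  }              // spin until the producer publishes
  return v - 1;  // decode back to the real token count
}

int main() {
  std::thread t(producer, 0);  // even zero tokens is distinguishable from "no signal"
  int n = consumer();
  t.join();
  assert(n == 0);
  return 0;
}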
diff --git a/include/mori/core/lock.hpp b/src/ops/dispatch_combine/internode_v1.hpp
similarity index 77%
rename from include/mori/core/lock.hpp
rename to src/ops/dispatch_combine/internode_v1.hpp
index a261cbad..5edc1256 100644
--- a/include/mori/core/lock.hpp
+++ b/src/ops/dispatch_combine/internode_v1.hpp
@@ -21,25 +21,18 @@
 // SOFTWARE.
 
 #pragma once
 
-namespace mori {
-namespace core {
-
-class GpuLock {
- public:
-  __device__ GpuLock(uint32_t* lockMem) : lock(lockMem) {}
-  __device__ ~GpuLock() = default;
+#include "mori/core/core.hpp"
+#include "mori/ops/dispatch_combine/dispatch_combine.hpp"
+#include "mori/shmem/shmem.hpp"
 
-  __device__ void Lock() {
-    while (!atomicCAS(lock, 0, 1)) {
-    }
-    __threadfence_system();
-  }
+namespace mori {
+namespace moe {
 
-  __device__ void Unlock() { atomicCAS(lock, 1, 0); }
+template <typename T>
+__global__ void EpDispatchInterNodeV1Kernel(EpDispatchCombineArgs<T> args);
 
- private:
-  uint32_t* lock{nullptr};
-};
+template <typename T>
+__global__ void EpCombineInterNodeV1Kernel(EpDispatchCombineArgs<T> args);
 
-}  // namespace core
+}  // namespace moe
 }  // namespace mori
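Both kernels declared in this new header lean on one indexing convention: ranks are laid out node-major (rank = node * gpu_per_node + local index), a node is recovered with an integer division, and the RDMA peer on a remote node (the "proxy PE") is the rank with the same local index, as in laneId * config.gpuPerNode + (config.rank % config.gpuPerNode) above. A small standalone sketch of that arithmetic (helper names are mine, not part of mori):

#include <cassert>

int NodeOf(int rank, int gpuPerNode) { return rank / gpuPerNode; }
int LocalIdxOf(int rank, int gpuPerNode) { return rank % gpuPerNode; }

// The peer on destNode that this rank talks to over RDMA: same local index, different node.
int ProxyPeOf(int myRank, int destNode, int gpuPerNode) {
  return destNode * gpuPerNode + LocalIdxOf(myRank, gpuPerNode);
}

int main() {
  // Example: 16 ranks, 8 GPUs per node -> rank 10 is local GPU 2 on node 1.
  assert(NodeOf(10, 8) == 1);
  assert(LocalIdxOf(10, 8) == 2);
  // Its proxy on node 0 is rank 2.
  assert(ProxyPeOf(10, 0, 8) == 2);
  return 0;
}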
diff --git a/src/ops/dispatch_combine/intranode.hpp b/src/ops/dispatch_combine/intranode.hpp
index 924caac8..5443437d 100644
--- a/src/ops/dispatch_combine/intranode.hpp
+++ b/src/ops/dispatch_combine/intranode.hpp
@@ -81,7 +81,7 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs<T> args) {
   int myPe = config.rank;
   int npes = config.worldSize;
 
-  size_t maxNumOutTokenPerRank = config.MaxNumTokensToSend();
+  size_t MaxNumTokensToSend = config.MaxNumTokensToSend();
 
   if (args.tokenIndices && args.inpTokenBuf) {
     // Phase1: send token
@@ -103,7 +103,7 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs<T> args) {
       if (__any(condition)) {
         // Indicate that this token is already sent to the destination PE by setting an overflow
         // token index
-        if (laneId == 0) args.dispDestTokIdMap[i] = config.worldSize * maxNumOutTokenPerRank;
+        if (laneId == 0) args.dispDestTokIdMap[i] = config.worldSize * MaxNumTokensToSend;
         continue;
       }
 
@@ -111,7 +111,7 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs<T> args) {
         // decide token id in dest pe
         destTokId = atomicAdd(args.dispTokOffsetMemObj->template GetAs<index_t>(destPe), 1);
         atomicAdd(args.destPeTokenCounter + destPe, 1);
-        args.dispDestTokIdMap[i] = destPe * maxNumOutTokenPerRank + destTokId;
+        args.dispDestTokIdMap[i] = destPe * MaxNumTokensToSend + destTokId;
 
         // TODO: use a switch to control the writing of this buffer, should only turn on for testing
         args.dispTokIdToSrcTokIdMemObj->template GetAs<index_t>(destPe)[destTokId] =
@@ -206,7 +206,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
   int myPe = config.rank;
   int npes = config.worldSize;
 
-  size_t maxNumOutTokenPerRank = config.MaxNumTokensToSend();
+  size_t MaxNumTokensToSend = config.MaxNumTokensToSend();
   // Copy input to shmem registered buffer so that other GPUs can access directly
   index_t totalRecvTokenNum = args.totalRecvTokenNum[0];
   if (args.config.useExternalInpBuffer) {
@@ -248,7 +248,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
       // Prepare data pointers on different GPUs
      for (int j = laneId; j < config.numExpertPerToken; j += warpSize) {
        index_t destTokId = args.dispDestTokIdMap[tokenId * config.numExpertPerToken + j];
-        index_t destPe = destTokId / maxNumOutTokenPerRank;
+        index_t destPe = destTokId / MaxNumTokensToSend;
        if (destPe < config.worldSize) {
-          index_t destLocalTokId = destTokId - destPe * maxNumOutTokenPerRank;
+          index_t destLocalTokId = destTokId - destPe * MaxNumTokensToSend;
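For reference, the dispDestTokIdMap entries touched in these hunks pack a destination PE and a slot on that PE into one integer: encode with destPe * MaxNumTokensToSend + destTokId, decode with a division and a subtraction, and treat anything at or beyond worldSize * MaxNumTokensToSend as the "already sent" overflow marker. A small standalone sketch of that encoding (the struct and helper names are illustrative):

#include <cassert>
#include <cstddef>

struct Packed {
  int destPe;
  int destLocalTokId;
};

// Encode: global id = destination PE * capacity per rank + slot on that PE.
size_t Encode(int destPe, int destLocalTokId, size_t maxTokensPerRank) {
  return destPe * maxTokensPerRank + destLocalTokId;
}

// Decode reverses it; ids >= worldSize * maxTokensPerRank never decode to a valid PE.
Packed Decode(size_t id, size_t maxTokensPerRank) {
  Packed p;
  p.destPe = static_cast<int>(id / maxTokensPerRank);
  p.destLocalTokId = static_cast<int>(id - p.destPe * maxTokensPerRank);
  return p;
}

int main() {
  const size_t maxTokensPerRank = 4096;
  const int worldSize = 8;
  const size_t sentinel = worldSize * maxTokensPerRank;  // "token already sent" marker

  size_t id = Encode(/*destPe=*/3, /*destLocalTokId=*/17, maxTokensPerRank);
  assert(id < sentinel);
  Packed p = Decode(id, maxTokensPerRank);
  assert(p.destPe == 3 && p.destLocalTokId == 17);
  return 0;
}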
diff --git a/src/pybind/mori.cpp b/src/pybind/mori.cpp
index eaadfe9c..3837522e 100644
--- a/src/pybind/mori.cpp
+++ b/src/pybind/mori.cpp
@@ -242,16 +242,18 @@ void RegisterMoriOps(py::module_& m) {
   pybind11::enum_<mori::moe::KernelType>(m, "EpDispatchCombineKernelType")
       .value("IntraNode", mori::moe::KernelType::IntraNode)
       .value("InterNode", mori::moe::KernelType::InterNode)
+      .value("InterNodeV1", mori::moe::KernelType::InterNodeV1)
       .export_values();
 
   pybind11::class_<mori::moe::EpDispatchCombineConfig>(m, "EpDispatchCombineConfig")
-      .def(pybind11::init(),
+      .def(pybind11::init(),
           py::arg("rank") = 0, py::arg("world_size") = 0, py::arg("hidden_dim") = 0,
           py::arg("scale_dim") = 0, py::arg("scale_type_size") = 0,
           py::arg("max_token_type_size") = 0, py::arg("max_num_inp_token_per_rank") = 0,
           py::arg("num_experts_per_rank") = 0, py::arg("num_experts_per_token") = 0,
           py::arg("warp_num_per_block") = 0, py::arg("block_num") = 0,
-          py::arg("use_external_inp_buf") = true)
+          py::arg("use_external_inp_buf") = true, py::arg("gpu_per_node") = 8,
+          py::arg("rdma_block_num") = 0)
       .def_readwrite("rank", &mori::moe::EpDispatchCombineConfig::rank)
       .def_readwrite("world_size", &mori::moe::EpDispatchCombineConfig::worldSize)
       .def_readwrite("hidden_dim", &mori::moe::EpDispatchCombineConfig::hiddenDim)
@@ -264,7 +266,9 @@ void RegisterMoriOps(py::module_& m) {
       .def_readwrite("num_experts_per_token",
                      &mori::moe::EpDispatchCombineConfig::numExpertPerToken)
       .def_readwrite("warp_num_per_block", &mori::moe::EpDispatchCombineConfig::warpNumPerBlock)
-      .def_readwrite("block_num", &mori::moe::EpDispatchCombineConfig::blockNum);
+      .def_readwrite("block_num", &mori::moe::EpDispatchCombineConfig::blockNum)
+      .def_readwrite("gpu_per_node", &mori::moe::EpDispatchCombineConfig::gpuPerNode)
+      .def_readwrite("rdma_block_num", &mori::moe::EpDispatchCombineConfig::rdmaBlockNum);
 
   DeclareEpDispatchCombineHandle(m);
 }
diff --git a/tests/python/ops/bench_dispatch_combine.py b/tests/python/ops/bench_dispatch_combine.py
index 94ca1c40..ded034ce 100644
--- a/tests/python/ops/bench_dispatch_combine.py
+++ b/tests/python/ops/bench_dispatch_combine.py
@@ -92,7 +92,7 @@ def run_once(self, op, test_data, check_result):
                 None,
                 dispatch_indices,
                 block_num=80,
-                warp_per_block=4,
+                warp_per_block=16,
             )
             end_event.record()
             self.sync()
@@ -181,9 +181,9 @@ def _bench_dispatch_combine(
     rank,
     world_size,
     port,
-    max_num_inp_token_per_rank=4096,
-    data_type=torch.bfloat16,
-    hidden_dim=4096,
+    max_num_inp_token_per_rank=128,
+    data_type=torch.float8_e4m3fnuz,
+    hidden_dim=7168,
     scale_dim=0,
     scale_type_size=0,
     num_experts_per_rank=16,