269 changes: 208 additions & 61 deletions examples/ops/dispatch_combine/test_dispatch_combine_internode.py

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions include/mori/application/bootstrap/torch_bootstrap.hpp
@@ -21,8 +21,6 @@
// SOFTWARE.
#pragma once

#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include "mori/application/bootstrap/base_bootstrap.hpp"

namespace mori {
@@ -41,7 +39,7 @@ class TorchBootstrapNetwork : public BootstrapNetwork {
void Barrier();

private:
c10::intrusive_ptr<c10d::ProcessGroup> group;
std::string groupName;
};

} // namespace application
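
The bootstrap header now stores only the process-group name instead of holding a c10::intrusive_ptr<c10d::ProcessGroup>, which is what lets it drop the heavyweight ProcessGroup.hpp include. A minimal sketch of how the implementation side could resolve the group lazily from that name, assuming PyTorch's c10d::resolve_process_group() is available (the actual lookup used by the .cpp is not shown in this diff):

// Hedged sketch, not the actual mori implementation: resolve the ProcessGroup
// by name only where it is needed, keeping the header free of the dependency.
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>  // included only in the .cpp

void TorchBootstrapNetwork::Barrier() {
  // resolve_process_group() looks up a group registered under this name
  // (available in recent PyTorch releases; assumed here).
  auto group = c10d::resolve_process_group(groupName);
  group->barrier()->wait();
}
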
1 change: 0 additions & 1 deletion include/mori/core/core.hpp
@@ -21,7 +21,6 @@
// SOFTWARE.
#pragma once

#include "mori/core/lock.hpp"
#include "mori/core/transport/p2p/p2p.hpp"
#include "mori/core/transport/rdma/rdma.hpp"
#include "mori/core/utils.hpp"
5 changes: 3 additions & 2 deletions include/mori/core/transport/p2p/device_primitives.hpp
@@ -205,7 +205,8 @@ inline __device__ void ThreadCopy(T* dst, T* src, size_t nelems) {
}

template <typename T, int Unroll>
inline __device__ void WarpCopyImpl(T* dst, const T* src, size_t& offset, size_t nelems) {
inline __device__ void WarpCopyImpl(T* __restrict__ dst, const T* __restrict__ src, size_t& offset,
size_t nelems) {
constexpr int VecBytes = 16;
constexpr int vecSize = VecBytes / sizeof(T);
int laneId = threadIdx.x & (warpSize - 1);
@@ -230,7 +231,7 @@ inline __device__ void WarpCopyImpl(T* dst, const T* src, size_t& offset, size_t
}

template <typename T, int Unroll = 1>
inline __device__ void WarpCopy(T* dst, const T* src, size_t nelems) {
inline __device__ void WarpCopy(T* __restrict__ dst, const T* __restrict__ src, size_t nelems) {
int laneId = threadIdx.x & (warpSize - 1);

size_t offset = 0;
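
The __restrict__ qualifiers added above tell the compiler that dst and src never alias, which helps it keep the 16-byte vectorized path inside WarpCopyImpl. For context, a minimal sketch of how a warp-cooperative copy such as WarpCopy<T, Unroll>(dst, src, nelems) is typically driven from a kernel; the kernel below is illustrative only and not part of this PR, and the namespace of WarpCopy is assumed to be in scope.

#include <hip/hip_runtime.h>

// Illustrative only: each warp copies one token's hidden vector with WarpCopy,
// so all lanes of a warp cooperate on a single contiguous range.
template <typename T>
__global__ void CopyTokens(T* __restrict__ dst, const T* __restrict__ src,
                           int numTokens, int hiddenDim) {
  int warpsPerBlock = blockDim.x / warpSize;
  int warpId = blockIdx.x * warpsPerBlock + threadIdx.x / warpSize;
  int warpStride = gridDim.x * warpsPerBlock;
  for (int tok = warpId; tok < numTokens; tok += warpStride) {
    // Every lane of the warp makes the same call; WarpCopy strides the range
    // across lanes in 16-byte vectors, per the implementation above.
    WarpCopy(dst + static_cast<size_t>(tok) * hiddenDim,
             src + static_cast<size_t>(tok) * hiddenDim,
             static_cast<size_t>(hiddenDim));
  }
}
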
@@ -230,6 +230,9 @@ inline __device__ uint64_t Mlx5PostWriteInline(WorkQueueHandle& wq, uint32_t cur
if (bytes == 4) {
AtomicStoreRelaxed(reinterpret_cast<uint32_t*>(wqeDataPtr),
reinterpret_cast<uint32_t*>(val)[0]);
} else if (bytes == 8) {
AtomicStoreRelaxed(reinterpret_cast<uint64_t*>(wqeDataPtr),
reinterpret_cast<uint64_t*>(val)[0]);
} else {
for (int i = 0; i < bytes; i++) {
AtomicStoreRelaxed(reinterpret_cast<uint8_t*>(wqeDataPtr) + i,
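
The new branch above adds an 8-byte fast path for inline WQE payloads, so a 64-bit value (for example the signal word used by the write-with-signal path later in this PR) is copied into the WQE with a single relaxed store instead of the byte-by-byte fallback. A standalone sketch of that size dispatch, using plain HIP atomics in place of mori's AtomicStoreRelaxed wrapper:

#include <cstdint>

// Sketch of the inline-payload copy pattern (not the mori source): use the widest
// single store that exactly matches the payload size, otherwise fall back to bytes.
__device__ void StoreInlinePayload(void* wqeDataPtr, const void* val, int bytes) {
  if (bytes == 4) {
    __hip_atomic_store(reinterpret_cast<uint32_t*>(wqeDataPtr),
                       *reinterpret_cast<const uint32_t*>(val),
                       __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
  } else if (bytes == 8) {  // the fast path this hunk adds
    __hip_atomic_store(reinterpret_cast<uint64_t*>(wqeDataPtr),
                       *reinterpret_cast<const uint64_t*>(val),
                       __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
  } else {
    for (int i = 0; i < bytes; i++) {
      reinterpret_cast<uint8_t*>(wqeDataPtr)[i] =
          reinterpret_cast<const uint8_t*>(val)[i];
    }
  }
}
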
68 changes: 31 additions & 37 deletions include/mori/core/utils.hpp
@@ -56,30 +56,28 @@ inline __device__ int FlatBlockWarpId() { return FlatBlockThreadId() / DeviceWar

inline __device__ int WarpLaneId() { return FlatBlockThreadId() & (DeviceWarpSize() - 1); }

inline __device__ bool IsThreadZeroInBlock() {
return (FlatBlockThreadId() % DeviceWarpSize()) == 0;
}

inline __device__ uint64_t GetActiveLaneMask() {
return __ballot(true);
}

inline __device__ unsigned int GetActiveLaneCount(uint64_t activeLaneMask) {
return __popcll(activeLaneMask);
}

inline __device__ unsigned int GetActiveLaneCount() {
return GetActiveLaneCount(GetActiveLaneMask());
}

inline __device__ unsigned int GetActiveLaneNum(uint64_t activeLaneMask) {
return __popcll(activeLaneMask & __lanemask_lt());
}

inline __device__ unsigned int GetActiveLaneNum() {
return GetActiveLaneNum(GetActiveLaneMask());
inline __device__ int WarpLaneId1D() { return threadIdx.x & (warpSize - 1); }

inline __device__ bool IsThreadZeroInBlock() {
return (FlatBlockThreadId() % DeviceWarpSize()) == 0;
}

inline __device__ uint64_t GetActiveLaneMask() { return __ballot(true); }

inline __device__ unsigned int GetActiveLaneCount(uint64_t activeLaneMask) {
return __popcll(activeLaneMask);
}

inline __device__ unsigned int GetActiveLaneCount() {
return GetActiveLaneCount(GetActiveLaneMask());
}

inline __device__ unsigned int GetActiveLaneNum(uint64_t activeLaneMask) {
return __popcll(activeLaneMask & __lanemask_lt());
}

inline __device__ unsigned int GetActiveLaneNum() { return GetActiveLaneNum(GetActiveLaneMask()); }

inline __device__ int GetFirstActiveLaneID(uint64_t activeLaneMask) {
return activeLaneMask ? __ffsll((unsigned long long int)activeLaneMask) - 1 : -1;
}
@@ -92,21 +90,17 @@ inline __device__ int GetLastActiveLaneID(uint64_t activeLaneMask) {

inline __device__ int GetLastActiveLaneID() { return GetLastActiveLaneID(GetActiveLaneMask()); }

inline __device__ bool IsFirstActiveLane(uint64_t activeLaneMask) {
return GetActiveLaneNum(activeLaneMask) == 0;
}

inline __device__ bool IsFirstActiveLane() {
return IsFirstActiveLane(GetActiveLaneMask());
}

inline __device__ bool IsLastActiveLane(uint64_t activeLaneMask) {
return GetActiveLaneNum(activeLaneMask) == GetActiveLaneCount(activeLaneMask) - 1;
}

inline __device__ bool IsLastActiveLane() {
return IsLastActiveLane(GetActiveLaneMask());
}
inline __device__ bool IsFirstActiveLane(uint64_t activeLaneMask) {
return GetActiveLaneNum(activeLaneMask) == 0;
}

inline __device__ bool IsFirstActiveLane() { return IsFirstActiveLane(GetActiveLaneMask()); }

inline __device__ bool IsLastActiveLane(uint64_t activeLaneMask) {
return GetActiveLaneNum(activeLaneMask) == GetActiveLaneCount(activeLaneMask) - 1;
}

inline __device__ bool IsLastActiveLane() { return IsLastActiveLane(GetActiveLaneMask()); }

/* ---------------------------------------------------------------------------------------------- */
/* Atomic Operations */
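
The reshuffled lane helpers above (GetActiveLaneMask, GetActiveLaneCount, GetActiveLaneNum, IsFirstActiveLane, GetFirstActiveLaneID, ...) exist to support warp-aggregated patterns in which one lane acts on behalf of all currently active lanes. A hedged sketch of the canonical use, assuming these helpers are in scope: reserving output slots with a single atomic per group of active lanes instead of one atomic per lane.

// Illustrative only: warp-aggregated slot reservation built on the helpers above.
__device__ int ReserveSlot(int* counter) {
  uint64_t mask = GetActiveLaneMask();              // lanes that reached this point together
  unsigned int numActive = GetActiveLaneCount(mask);
  unsigned int myRank = GetActiveLaneNum(mask);     // my index among the active lanes
  int base = 0;
  if (IsFirstActiveLane(mask)) {
    base = atomicAdd(counter, numActive);           // one atomic for the whole group
  }
  base = __shfl(base, GetFirstActiveLaneID(mask));  // broadcast the leader's base index
  return base + static_cast<int>(myRank);           // each active lane gets a distinct slot
}
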
35 changes: 35 additions & 0 deletions include/mori/ops/dispatch_combine/dispatch_combine.hpp
@@ -36,6 +36,7 @@ namespace moe {
enum KernelType {
IntraNode = 0,
InterNode = 1,
InterNodeV1 = 2,
};

inline const char* HipDataTypeToString(hipDataType dtype) {
@@ -84,6 +85,8 @@ struct EpDispatchCombineConfig {
// If true, use external buffer which incurs extra copy overhead; otherwise, the kernel assumes
// the provided buffer is shmemInpTokMemObj
bool useExternalInpBuffer{true};
int gpuPerNode{8};
int rdmaBlockNum{1};

inline __host__ __device__ int MaxNumTokensToSendPerRank() const { return maxNumInpTokenPerRank; }

@@ -217,6 +220,22 @@ class EpDispatchCombineHandle {
index_t* totalRecvTokenNum{nullptr};
mori::application::SymmMemObjPtr crossDeviceBarrierMemObj;
uint32_t* crossDeviceBarrierFlag{nullptr};

// Inter-node v1 kernel parameters
// Signal the completion of inter-node token transfer
mori::application::SymmMemObjPtr interNodeChunkFlagMemObj;
// Signal the number of tokens transferred from other nodes
mori::application::SymmMemObjPtr nodeRecvTokenNumMemObj;
// Count the number of tokens sent to other nodes
index_t* destNodeTokenCounter{nullptr};
// Counter used to order inter-node token chunk transfers
index_t* blockFlagCounter{nullptr};
// Barrier across the blocks involved in inter-node transfer
uint32_t* interNodeBlocksBarrier{nullptr};

index_t* interNodeDispDestTokIdMap{nullptr};
index_t* interNodeChunkFlagCombine{nullptr};
index_t* interNodeDispSendMap{nullptr};
};

template <typename T>
@@ -256,6 +275,14 @@ struct EpDispatchCombineArgs {
index_t* totalRecvTokenNum{nullptr};
mori::application::SymmMemObjPtr crossDeviceBarrierMemObj;
uint32_t* crossDeviceBarrierFlag{nullptr};
mori::application::SymmMemObjPtr interNodeChunkFlagMemObj;
index_t* destNodeTokenCounter{nullptr};
mori::application::SymmMemObjPtr nodeRecvTokenNumMemObj;
index_t* blockFlagCounter{nullptr};
uint32_t* interNodeBlocksBarrier{nullptr};
index_t* interNodeDispDestTokIdMap{nullptr};
index_t* interNodeChunkFlagCombine{nullptr};
index_t* interNodeDispSendMap{nullptr};
};

using EpDispatchCombineArgsVariant =
@@ -299,6 +326,14 @@ EpDispatchCombineArgs<T> GetEpDispatchCombineArgs(const EpDispatchCombineHandle&
args.totalRecvTokenNum = handle.totalRecvTokenNum;
args.crossDeviceBarrierMemObj = handle.crossDeviceBarrierMemObj;
args.crossDeviceBarrierFlag = handle.crossDeviceBarrierFlag;
args.interNodeChunkFlagMemObj = handle.interNodeChunkFlagMemObj;
args.destNodeTokenCounter = handle.destNodeTokenCounter;
args.nodeRecvTokenNumMemObj = handle.nodeRecvTokenNumMemObj;
args.blockFlagCounter = handle.blockFlagCounter;
args.interNodeBlocksBarrier = handle.interNodeBlocksBarrier;
args.interNodeDispDestTokIdMap = handle.interNodeDispDestTokIdMap;
args.interNodeChunkFlagCombine = handle.interNodeChunkFlagCombine;
args.interNodeDispSendMap = handle.interNodeDispSendMap;
return args;
}

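
For the new InterNodeV1 path, the host-side config gains gpuPerNode and rdmaBlockNum alongside the existing fields. A hedged sketch of filling them follows; only members visible in this diff are set, the struct has further members (ranks, token counts, data type, ...) that a real caller must also populate, the enclosing mori namespace is assumed, and the meaning of rdmaBlockNum is inferred from its name.

// Illustrative only: populate the fields this PR adds before creating the handle.
mori::moe::EpDispatchCombineConfig MakeInterNodeV1Config() {
  mori::moe::EpDispatchCombineConfig config{};
  config.maxNumInpTokenPerRank = 4096;  // illustrative capacity
  config.useExternalInpBuffer = true;   // copy user buffers into shmemInpTokMemObj
  config.gpuPerNode = 8;                // GPUs per node, splits intra- vs inter-node traffic
  config.rdmaBlockNum = 16;             // blocks reserved for RDMA transfer (assumed semantics)
  // The kernel variant itself is selected with KernelType::InterNodeV1 at launch time.
  return config;
}
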
133 changes: 133 additions & 0 deletions include/mori/shmem/shmem_ibgda_kernels.hpp
@@ -369,6 +369,139 @@ inline __device__ void ShmemPutMemNbiThreadKernelImpl(const application::SymmMem
// __threadfence_system();
}

template <core::ProviderType PrvdType>
inline __device__ void ShmemPutMemNbiThreadKernelImpl(
const application::SymmMemObjPtr dest, size_t destOffset,
const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
const application::SymmMemObjPtr signal, size_t signalOffset, void* signalVal,
size_t signalBytes, const application::SymmMemObjPtr atomic, int pe) {
if (bytes == 0) return;
uintptr_t laddr = source.addr + sourceOffset;
uintptr_t raddr = dest->peerPtrs[pe] + destOffset;
uintptr_t rkey = dest->peerRkeys[pe];

GpuStates* globalGpuStates = GetGlobalGpuStatesPtr();
application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints;
core::WorkQueueHandle* wq = &ep[pe].wqHandle;
core::CompletionQueueHandle* cq = &ep[pe].cqHandle;

uint64_t activemask = core::GetActiveLaneMask();
uint8_t num_active_lanes = core::GetActiveLaneCount(activemask);
uint8_t my_logical_lane_id = core::GetActiveLaneNum(activemask);
bool is_leader{my_logical_lane_id == num_active_lanes - 1};
const uint64_t leader_phys_lane_id = core::GetLastActiveLaneID(activemask);
uint8_t num_wqes{num_active_lanes};
num_wqes += 1;
uint32_t warp_sq_counter{0};
uint32_t warp_msntbl_counter{0}, warp_psn_counter{0};
uint32_t my_sq_counter{0}, my_msntbl_counter{0}, my_psn_counter{0};

if (is_leader) {
if constexpr (PrvdType == core::ProviderType::MLX5) {
warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_wqes, __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_AGENT);
} else if constexpr (PrvdType == core::ProviderType::BNXT) {
uint32_t psnCnt = (bytes + wq->mtuSize - 1) / wq->mtuSize;
atomic_add_packed_msn_and_psn(&wq->msnPack, num_wqes, psnCnt * num_wqes, &warp_msntbl_counter,
&warp_psn_counter);
// TODO: if warp_msntbl_counter overflows 32 bits, the sq_slot calculation will be wrong
warp_sq_counter = warp_msntbl_counter * BNXT_RE_NUM_SLOT_PER_WQE;
__hip_atomic_fetch_max(&wq->postIdx,
(warp_msntbl_counter + num_wqes) * BNXT_RE_NUM_SLOT_PER_WQE,
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
}
warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id);
if constexpr (PrvdType == core::ProviderType::MLX5) {
my_sq_counter = warp_sq_counter + my_logical_lane_id;
} else if constexpr (PrvdType == core::ProviderType::BNXT) {
my_sq_counter = warp_sq_counter + my_logical_lane_id * BNXT_RE_NUM_SLOT_PER_WQE;
warp_msntbl_counter = __shfl(warp_msntbl_counter, leader_phys_lane_id);
warp_psn_counter = __shfl(warp_psn_counter, leader_phys_lane_id);
my_msntbl_counter = warp_msntbl_counter + my_logical_lane_id;
my_psn_counter = warp_psn_counter + my_logical_lane_id;
} else {
assert(false);
}

while (true) {
uint64_t db_touched =
__hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
uint64_t db_done = __hip_atomic_load(&wq->doneIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
uint64_t num_active_sq_entries = db_touched - db_done;
uint64_t num_free_entries = wq->sqWqeNum - num_active_sq_entries;
uint64_t num_entries_until_warp_last_entry;
if constexpr (PrvdType == core::ProviderType::MLX5) {
num_entries_until_warp_last_entry = warp_sq_counter + num_wqes - db_touched;
} else if constexpr (PrvdType == core::ProviderType::BNXT) {
num_entries_until_warp_last_entry =
warp_sq_counter + num_active_lanes * BNXT_RE_NUM_SLOT_PER_WQE - db_touched;
} else {
assert(false);
}
if (num_free_entries > num_entries_until_warp_last_entry) {
break;
}
ShmemQuietThreadKernelImpl<PrvdType>(pe);
}
uint64_t dbr_val;
if constexpr (PrvdType == core::ProviderType::MLX5) {
wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter;
dbr_val = core::PostWrite<PrvdType>(*wq, my_sq_counter, my_sq_counter, my_sq_counter, false,
ep[pe].handle.qpn, laddr, source.lkey, raddr, rkey, bytes);
if (is_leader) {
// int rank = GetGlobalGpuStatesPtr()->rank;
// uint64_t atomic_sq_counter = my_sq_counter + 1;
// uintptr_t atomic_raddr = atomic->peerPtrs[pe];
// uintptr_t atomic_rkey = atomic->peerRkeys[pe];
// wq->outstandingWqe[atomic_sq_counter % OUTSTANDING_TABLE_SIZE] = atomic_sq_counter;
// uint64_t val = 0;
// dbr_val = core::PostAtomic<PrvdType>(*wq, atomic_sq_counter, atomic_sq_counter,
// atomic_sq_counter, false, ep[pe].handle.qpn,
// atomic->GetRdmaMemoryRegion(rank).addr,
// atomic->GetRdmaMemoryRegion(rank).lkey, atomic_raddr,
// atomic_rkey, &val, &val, 8, core::AMO_ADD);
// uint64_t signal_sq_counter = warp_sq_counter + num_wqes - 1;
uint64_t signal_sq_counter = my_sq_counter + 1;
uintptr_t signal_raddr = signal->peerPtrs[pe] + signalOffset;
uintptr_t signal_rkey = signal->peerRkeys[pe];
wq->outstandingWqe[signal_sq_counter % OUTSTANDING_TABLE_SIZE] = signal_sq_counter;
dbr_val = core::PostWriteInline<PrvdType>(*wq, signal_sq_counter, signal_sq_counter,
signal_sq_counter, is_leader, ep[pe].handle.qpn,
signalVal, signal_raddr, signal_rkey, signalBytes);
}
} else if constexpr (PrvdType == core::ProviderType::BNXT) {
wq->outstandingWqe[my_sq_counter % wq->sqWqeNum] = my_sq_counter;
dbr_val =
core::PostWrite<PrvdType>(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, is_leader,
ep[pe].handle.qpn, laddr, source.lkey, raddr, rkey, bytes);
} else {
assert(false);
}
__threadfence_system();
if (is_leader) {
uint64_t db_touched{0};
do {
db_touched = __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
} while (db_touched != warp_sq_counter);

core::UpdateSendDbrRecord<PrvdType>(wq->dbrRecAddr, warp_sq_counter + num_wqes);
// __threadfence_system();
core::RingDoorbell<PrvdType>(wq->dbrAddr, dbr_val);
__threadfence_system();

__hip_atomic_fetch_add(&cq->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
if constexpr (PrvdType == core::ProviderType::MLX5) {
__hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_AGENT);
} else if constexpr (PrvdType == core::ProviderType::BNXT) {
__hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes * BNXT_RE_NUM_SLOT_PER_WQE,
__ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
}
}
// __threadfence_system();
}

template <>
inline __device__ void ShmemPutMemNbiThreadKernel<application::TransportType::RDMA>(
const application::SymmMemObjPtr dest, size_t destOffset,
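
The new ShmemPutMemNbiThreadKernelImpl overload above posts the payload write and then, from the leader lane, an inline write of signalVal into the signal buffer on the destination PE, so a receiver can learn that a chunk has landed by polling a flag instead of running a separate barrier. A hedged sketch of the receiver side of that pattern; the names and flag layout are illustrative, not the mori API.

// Illustrative only: spin on the per-chunk flag the sender updates via the inline
// signal write, then fence before reading the payload.
__device__ void WaitForChunk(uint64_t* recvFlags, int chunkId, uint64_t expected) {
  while (__hip_atomic_load(&recvFlags[chunkId], __ATOMIC_RELAXED,
                           __HIP_MEMORY_SCOPE_SYSTEM) != expected) {
    // busy-wait until the RDMA write-with-signal for this chunk completes
  }
  __threadfence_system();  // order the flag read before reads of the chunk payload
}
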
9 changes: 8 additions & 1 deletion python/mori/ops/dispatch_combine.py
@@ -47,6 +47,8 @@ class EpDispatchCombineConfig:
block_num: int = 80
use_external_inp_buf: bool = True
kernel_type: EpDispatchCombineKernelType = EpDispatchCombineKernelType.IntraNode
gpu_per_node: int = 8
rdma_block_num: int = 0


def _cpp_dispatch_combine_factory(entity_name):
@@ -72,6 +74,8 @@ def __init__(self, config):
warp_num_per_block=config.warp_num_per_block,
block_num=config.block_num,
use_external_inp_buf=config.use_external_inp_buf,
gpu_per_node=config.gpu_per_node,
rdma_block_num=config.rdma_block_num,
)
)

@@ -176,7 +180,10 @@ def _allgather_with_token_num_padding(self, input, max_token_num):
def get_dispatch_src_token_pos(self):
torch.cuda.synchronize()

if self.config.kernel_type.value == EpDispatchCombineKernelType.IntraNode.value:
if self.config.kernel_type.value in (
EpDispatchCombineKernelType.IntraNode.value,
EpDispatchCombineKernelType.InterNodeV1.value,
):
return self._get_dispatch_src_token_pos_func(self._handle)

dispatch_sender_token_id_map = self._get_dispatch_sender_token_idx_map_func(