
Commit 235c78d

Add HybridEP
1 parent abba6ad commit 235c78d

File tree

9 files changed: +4460, -1 lines changed


csrc/hybrid_ep.cu

Lines changed: 845 additions & 0 deletions
Large diffs are not rendered by default.

csrc/hybrid_ep.cuh

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved
#pragma once
#include "kernels/hybrid_ep_backend_configs.hpp"
#include "kernels/hybrid_ep_backend.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>
#include <torch/torch.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>
#include <algorithm>

inline std::string type_to_string(TOKEN_DATA_TYPE token_data_type) {
  switch (token_data_type) {
    case TOKEN_DATA_TYPE::UINT16:
      return "uint16_t";
    case TOKEN_DATA_TYPE::UINT8:
      return "uint8_t";
    default:
      return "unknown";
  }
}

union MemHandleInner {
  cudaIpcMemHandle_t cuda_ipc_mem_handle;
  CUmemFabricHandle cu_mem_fabric_handle;
};

struct MemHandle {
  MemHandleInner inner;
  size_t size;
};

// Utility function to get the size of a token data type.
inline size_t get_token_data_type_size(TOKEN_DATA_TYPE data_type) {
  switch (data_type) {
    case TOKEN_DATA_TYPE::UINT8:
      return sizeof(uint8_t);
    case TOKEN_DATA_TYPE::UINT16:
      return sizeof(uint16_t);
    default:
      throw std::runtime_error("Invalid token data type: " + std::to_string(static_cast<int>(data_type)));
  }
}

// Round up the allocation size to the fabric granularity.
inline size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
  size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
  if (size == 0) size = granularity;
  return size;
}

// Device memory allocator for local device memory. Supports both plain cudaMalloc and the fabric allocator.
inline void device_mem_malloc(void** ptr, size_t size_raw, bool enable_fabric) {
  if (enable_fabric) {
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));

    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
    prop.location.id = device;

    size_t granularity = 0;
    CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

    size_t size = get_size_align_to_granularity(size_raw, granularity);

    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemCreate(&handle, size, &prop, 0));

    CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0));
    CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
    CUmemAccessDesc access_desc = {};
    access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access_desc.location.id = device;
    access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1));
  } else {
    CUDA_CHECK(cudaMalloc(ptr, size_raw));
  }
}

// Get a sharable memory handle of local device memory for remote ranks to access. Supports both IPC and fabric handles.
inline void get_device_mem_handle(MemHandle* mem_handle, void* ptr, bool enable_fabric) {
  size_t size = 0;
  CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

  mem_handle->size = size;

  if (enable_fabric) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
    CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
  } else {
    CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
  }
}

// Open a sharable memory handle from another rank and map it for local device access. Supports both IPC and fabric handles.
inline void open_device_mem_handle(void** ptr, MemHandle* mem_handle, bool enable_fabric) {
  if (enable_fabric) {
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));
    size_t size = mem_handle->size;

    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));

    CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0));
    CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
    CUmemAccessDesc access_desc = {};
    access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access_desc.location.id = device;
    access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1));
  } else {
    CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
  }
}

// Close and unmap a sharable memory handle received from another rank. Supports both IPC and fabric handles.
inline void close_device_mem_handle(void* ptr, bool enable_fabric) {
  if (enable_fabric) {
    size_t size = 0;
    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
  } else {
    CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
  }
}

// Free local device memory allocated by device_mem_malloc.
inline void device_mem_free(void* ptr, bool enable_fabric) {
  if (enable_fabric) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

    size_t size = 0;
    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemRelease(handle));
  } else {
    CUDA_CHECK(cudaFree(ptr));
  }
}
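
The helpers above are meant to compose into an allocate → export → exchange → map → release sequence across ranks. A minimal sketch of that lifecycle follows; it is not part of the commit, and exchange_handle_with_peer is a hypothetical stand-in for whatever out-of-band transport (in this codebase, the torch process group used by HybridEpBuffer::exchange_ipc_address below) actually moves the handle bytes between ranks.

// Sketch only: illustrates the intended lifecycle of the helpers above.
// exchange_handle_with_peer() is assumed to exist; it is not defined in this commit.
MemHandle exchange_handle_with_peer(const MemHandle& local_handle);

inline void example_share_device_buffer(bool enable_fabric) {
  const size_t nbytes = 16 * 1024 * 1024;

  // Owner rank: allocate local device memory and export a sharable handle.
  void* local_ptr = nullptr;
  device_mem_malloc(&local_ptr, nbytes, enable_fabric);
  MemHandle local_handle{};
  get_device_mem_handle(&local_handle, local_ptr, enable_fabric);

  // Handles are plain bytes; they are exchanged out of band (assumed step).
  MemHandle peer_handle = exchange_handle_with_peer(local_handle);

  // Peer rank: map the remote allocation into the local address space.
  // (Owner and peer steps run in different processes; shown together for brevity.)
  void* peer_ptr = nullptr;
  open_device_mem_handle(&peer_ptr, &peer_handle, enable_fabric);

  // ... kernels on the peer may now read/write peer_ptr as if it were local ...

  // Teardown: unmap the imported handle first, then free the local allocation.
  close_device_mem_handle(peer_ptr, enable_fabric);
  device_mem_free(local_ptr, enable_fabric);
}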

class HybridEpBuffer {
 public:
  HybridEpBuffer(HybridEpConfigInstance config, int local_rank, int node_rank,
                 int num_of_ranks_per_node);
  ~HybridEpBuffer();

  // Exchange IPC addresses using C++ distributed communication
  void exchange_ipc_address(pybind11::object process_group);

  void update_num_of_tokens_per_rank(int num_of_tokens_per_rank);

  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
             torch::Tensor>
  metadata_preprocessing(torch::Tensor routing_map, int64_t node_rank,
                         int64_t local_rank);

  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
  dispatch(torch::Tensor hidden, c10::optional<torch::Tensor> probs,
           c10::optional<torch::Tensor> scaling_factor,
           torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map,
           torch::Tensor attn_to_rdma_map, int64_t num_of_tokens_for_experts,
           bool with_probs);

  std::tuple<torch::Tensor, torch::Tensor>
  combine(torch::Tensor hidden, c10::optional<torch::Tensor> probs,
          torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map,
          torch::Tensor attn_to_rdma_map, bool with_probs);

 private:
  void allocate_buffer();
  void allocate_buffer_for_preprocessing();
  void allocate_buffer_for_dispatch();
  void allocate_buffer_for_combine();
  void open_handles_from_other_ranks(std::vector<torch::Tensor> dispatch_handles,
                                     std::vector<torch::Tensor> combine_handles);

  HybridEpConfigInstance config;
  int rank;
  int group_size;
  int local_rank;
  int node_rank;
  int num_of_ranks_per_node;

  int64_t max_num_of_tokens_for_experts;

  hybrid_ep::tmp_state_t *preprocessing_tmp;

  struct DispatchBuffers {
    TOKEN_DATA_TYPE data_type;

    void *expert_output_token;
    void **expert_output_token_all_ranks;
    float *expert_output_prob;
    float **expert_output_prob_all_ranks;
    float *expert_output_scaling_factor;
    float **expert_output_scaling_factor_all_ranks;

    void *rdma_inter_node_group_token;
    float *rdma_inter_node_group_prob;
    float *rdma_inter_node_group_scaling_factor;
    uint64_t *rdma_inter_node_group_flags;

    uint32_t *intra_node_write_completion_flags;
    uint64_t *expected_rdma_flag_value;
    uint32_t *expected_intra_node_flag_value;
  } dispatch_buffers;

  torch::Tensor dispatch_memory_handles;

  struct CombineBuffers {
    uint16_t *expert_input_token;
    uint16_t **expert_input_token_all_ranks;
    float *expert_input_prob;
    float **expert_input_prob_all_ranks;

    uint16_t *rdma_intra_node_red_token;
    float *rdma_intra_node_red_prob;

    uint16_t *rdma_inter_node_group_token;
    float *rdma_inter_node_group_prob;
    uint64_t *rdma_inter_node_group_flags;

    uint32_t *intra_node_write_completion_flags;
    uint64_t *expected_rdma_flag_value;
    uint32_t *expected_intra_node_flag_value;
  } combine_buffers;

  torch::Tensor combine_memory_handles;
};
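
For orientation, a rough sketch of how a caller might drive this class in a MoE layer (construct once, exchange addresses, then preprocess → dispatch → expert compute → combine per forward pass). It is not part of the commit: make_config() is a hypothetical helper, the value used for num_of_tokens_for_experts is a placeholder, and the names and order assigned to the tensors returned by metadata_preprocessing are assumptions, since only the signatures are visible in this header.

// Sketch only: assumed call order for HybridEpBuffer; not part of this commit.
HybridEpConfigInstance make_config();  // hypothetical; real fields live in hybrid_ep_backend_configs.hpp

inline void example_moe_forward(pybind11::object process_group,
                                torch::Tensor routing_map, torch::Tensor hidden,
                                int local_rank, int node_rank,
                                int num_of_ranks_per_node) {
  // One-time setup: allocate buffers and exchange IPC/fabric addresses across ranks.
  HybridEpBuffer buffer(make_config(), local_rank, node_rank, num_of_ranks_per_node);
  buffer.exchange_ipc_address(process_group);

  // Per forward pass: derive routing metadata from the routing map.
  // Names and order of the returned tensors are assumed here.
  auto [sparse_to_dense_map, rdma_to_attn_map, attn_to_rdma_map, meta0, meta1] =
      buffer.metadata_preprocessing(routing_map, node_rank, local_rank);
  int64_t num_of_tokens_for_experts = hidden.size(0);  // placeholder value

  // Dispatch tokens to the expert ranks.
  auto [expert_tokens, expert_probs, expert_scales] = buffer.dispatch(
      hidden, /*probs=*/c10::nullopt, /*scaling_factor=*/c10::nullopt,
      sparse_to_dense_map, rdma_to_attn_map, attn_to_rdma_map,
      num_of_tokens_for_experts, /*with_probs=*/false);

  // ... run the expert MLPs on expert_tokens here ...
  torch::Tensor expert_output = expert_tokens;  // stand-in for the expert computation

  // Combine the expert outputs back to the original token order.
  auto [combined_hidden, combined_probs] = buffer.combine(
      expert_output, /*probs=*/c10::nullopt, sparse_to_dense_map,
      rdma_to_attn_map, attn_to_rdma_map, /*with_probs=*/false);
}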
