// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved
#pragma once
#include "kernels/hybrid_ep_backend_configs.hpp"
#include "kernels/hybrid_ep_backend.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>
#include <torch/torch.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>
#include <algorithm>

// Convert a TOKEN_DATA_TYPE enum value to a human-readable type name.
inline std::string type_to_string(TOKEN_DATA_TYPE token_data_type) {
  switch (token_data_type) {
    case TOKEN_DATA_TYPE::UINT16:
      return "uint16_t";
    case TOKEN_DATA_TYPE::UINT8:
      return "uint8_t";
    default:
      return "unknown";
  }
}

// Shareable handle payload: either a CUDA IPC handle or a CUDA fabric handle,
// depending on which allocation path is in use.
union MemHandleInner {
  cudaIpcMemHandle_t cuda_ipc_mem_handle;
  CUmemFabricHandle cu_mem_fabric_handle;
};

// Shareable handle plus the size of the underlying allocation.
struct MemHandle {
  MemHandleInner inner;
  size_t size;
};
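// Note: MemHandle is a plain, trivially copyable struct, so a handle can be
// exchanged between ranks as raw bytes. A minimal sketch, assuming the handle
// is shipped inside a CPU byte tensor (the transport itself is up to the
// caller, e.g. a torch.distributed all-gather):
//
//   MemHandle handle;                  // filled by get_device_mem_handle below
//   auto bytes = torch::from_blob(&handle, {(int64_t)sizeof(MemHandle)},
//                                 torch::kUInt8).clone();  // copy into own storage
//   // ... all-gather `bytes` across ranks, then reinterpret each blob as a MemHandle ...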

// Utility function to get the size in bytes of a token data type.
inline size_t get_token_data_type_size(TOKEN_DATA_TYPE data_type) {
  switch (data_type) {
    case TOKEN_DATA_TYPE::UINT8:
      return sizeof(uint8_t);
    case TOKEN_DATA_TYPE::UINT16:
      return sizeof(uint16_t);
    default:
      throw std::runtime_error("Invalid token data type: " + std::to_string(static_cast<int>(data_type)));
  }
}

// Round up an allocation size to the fabric allocation granularity.
inline size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
  size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
  if (size == 0) size = granularity;
  return size;
}
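// Example (illustrative values only; the real granularity is queried via
// cuMemGetAllocationGranularity at runtime): with a 2 MiB granularity, a
// 3 MiB request rounds up to 4 MiB, and a 0-byte request still returns one
// full 2 MiB granule.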

// Device memory allocator for local device memory. Supports both plain
// cudaMalloc and the fabric (cuMem) allocator.
inline void device_mem_malloc(void** ptr, size_t size_raw, bool enable_fabric) {
  if (enable_fabric) {
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));

    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
    prop.location.id = device;

    size_t granularity = 0;
    CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

    size_t size = get_size_align_to_granularity(size_raw, granularity);

    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemCreate(&handle, size, &prop, 0));

    CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0));
    CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
    CUmemAccessDesc access_desc = {};
    access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access_desc.location.id = device;
    access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1));
  } else {
    CUDA_CHECK(cudaMalloc(ptr, size_raw));
  }
}

// Get a shareable memory handle of local device memory so that remote ranks
// can access it. Supports both CUDA IPC handles and fabric handles.
inline void get_device_mem_handle(MemHandle* mem_handle, void* ptr, bool enable_fabric) {
  size_t size = 0;
  CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

  mem_handle->size = size;

  if (enable_fabric) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
    CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
  } else {
    CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
  }
}

// Open a shareable memory handle received from a remote rank and map it for
// local device access. Supports both CUDA IPC handles and fabric handles.
inline void open_device_mem_handle(void** ptr, MemHandle* mem_handle, bool enable_fabric) {
  if (enable_fabric) {
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));
    size_t size = mem_handle->size;

    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));

    CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0));
    CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
    CUmemAccessDesc access_desc = {};
    access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access_desc.location.id = device;
    access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1));
  } else {
    CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
  }
}

// Close and unmap a shareable memory handle opened from a remote rank.
// Supports both CUDA IPC handles and fabric handles.
inline void close_device_mem_handle(void* ptr, bool enable_fabric) {
  if (enable_fabric) {
    size_t size = 0;
    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
  } else {
    CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
  }
}

// Free local device memory allocated by device_mem_malloc.
inline void device_mem_free(void* ptr, bool enable_fabric) {
  if (enable_fabric) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

    size_t size = 0;
    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
    CU_CHECK(cuMemRelease(handle));
  } else {
    CUDA_CHECK(cudaFree(ptr));
  }
}
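
// Usage sketch for the helpers above (illustrative only; how the MemHandle is
// transported between ranks is left to the caller and is an assumption here):
//
//   const bool enable_fabric = false;   // true where fabric handles are supported
//
//   // Rank A: allocate a buffer and export a shareable handle.
//   void* local_ptr = nullptr;
//   device_mem_malloc(&local_ptr, 1 << 20, enable_fabric);
//   MemHandle handle;
//   get_device_mem_handle(&handle, local_ptr, enable_fabric);
//   // ... send `handle` to rank B over any side channel ...
//
//   // Rank B: import the handle and map the remote buffer locally.
//   void* peer_ptr = nullptr;
//   open_device_mem_handle(&peer_ptr, &handle, enable_fabric);
//   // ... kernels on rank B can now read/write rank A's buffer via peer_ptr ...
//   close_device_mem_handle(peer_ptr, enable_fabric);
//
//   // Rank A: free the allocation once every peer has closed its mapping.
//   device_mem_free(local_ptr, enable_fabric);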
154
+
155
+ class HybridEpBuffer {
156
+ public:
157
+ HybridEpBuffer (HybridEpConfigInstance config, int local_rank, int node_rank,
158
+ int num_of_ranks_per_node);
159
+ ~HybridEpBuffer ();
160
+
161
+ // Exchange IPC addresses using C++ distributed communication
162
+ void exchange_ipc_address (pybind11::object process_group);
163
+
164
+ void update_num_of_tokens_per_rank (int num_of_tokens_per_rank);
165
+
166
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
167
+ torch::Tensor>
168
+ metadata_preprocessing (torch::Tensor routing_map, int64_t node_rank,
169
+ int64_t local_rank);
170
+
171
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
172
+ dispatch (torch::Tensor hidden, c10::optional<torch::Tensor> probs,
173
+ c10::optional<torch::Tensor> scaling_factor,
174
+ torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map,
175
+ torch::Tensor attn_to_rdma_map, int64_t num_of_tokens_for_experts,
176
+ bool with_probs);
177
+
178
+ std::tuple<torch::Tensor, torch::Tensor>
179
+ combine (torch::Tensor hidden, c10::optional<torch::Tensor> probs,
180
+ torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map,
181
+ torch::Tensor attn_to_rdma_map, bool with_probs);
182
+
183
+ private:
184
+ void allocate_buffer ();
185
+ void allocate_buffer_for_preprocessing ();
186
+ void allocate_buffer_for_dispatch ();
187
+ void allocate_buffer_for_combine ();
188
+ void open_handles_from_other_ranks (std::vector<torch::Tensor> dispatch_handles,
189
+ std::vector<torch::Tensor> combine_handles);
190
+
191
+ HybridEpConfigInstance config;
192
+ int rank;
193
+ int group_size;
194
+ int local_rank;
195
+ int node_rank;
196
+ int num_of_ranks_per_node;
197
+
198
+ int64_t max_num_of_tokens_for_experts;
199
+
200
+ hybrid_ep::tmp_state_t *preprocessing_tmp;
201
+
202
+ struct DispatchBuffers {
203
+ TOKEN_DATA_TYPE data_type;
204
+
205
+ void *expert_output_token;
206
+
207
+ void **expert_output_token_all_ranks;
208
+
209
+ float *expert_output_prob;
210
+
211
+ float **expert_output_prob_all_ranks;
212
+
213
+ float *expert_output_scaling_factor;
214
+
215
+ float **expert_output_scaling_factor_all_ranks;
216
+
217
+ void *rdma_inter_node_group_token;
218
+
219
+ float *rdma_inter_node_group_prob;
220
+
221
+ float *rdma_inter_node_group_scaling_factor;
222
+
223
+ uint64_t *rdma_inter_node_group_flags;
224
+
225
+ uint32_t *intra_node_write_completion_flags;
226
+
227
+ uint64_t *expected_rdma_flag_value;
228
+
229
+ uint32_t *expected_intra_node_flag_value;
230
+
231
+ } dispatch_buffers;
232
+
233
+ torch::Tensor
234
+ dispatch_memory_handles;
235
+
236
+ struct CombineBuffers {
237
+
238
+ uint16_t *expert_input_token;
239
+
240
+ uint16_t **expert_input_token_all_ranks;
241
+
242
+ float *expert_input_prob;
243
+
244
+ float **expert_input_prob_all_ranks;
245
+
246
+ uint16_t *rdma_intra_node_red_token;
247
+
248
+ float *rdma_intra_node_red_prob;
249
+
250
+ uint16_t *rdma_inter_node_group_token;
251
+
252
+ float
253
+ *rdma_inter_node_group_prob;
254
+
255
+ uint64_t
256
+ *rdma_inter_node_group_flags;
257
+
258
+ uint32_t *intra_node_write_completion_flags;
259
+
260
+ uint64_t *expected_rdma_flag_value;
261
+
262
+ uint32_t *expected_intra_node_flag_value;
263
+
264
+
265
+ } combine_buffers;
266
+
267
+ torch::Tensor
268
+ combine_memory_handles;
269
+
270
+ };
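
// Rough call sequence, inferred from the interface above (a sketch, not a
// prescribed contract; the argument names below are placeholders):
//
//   HybridEpBuffer buffer(config, local_rank, node_rank, num_of_ranks_per_node);
//   buffer.exchange_ipc_address(process_group);  // process_group: a Python-side
//                                                // torch.distributed group object
//
//   // Per step: build routing metadata, dispatch tokens to the experts,
//   // run the expert computation, then combine the expert outputs back.
//   auto meta = buffer.metadata_preprocessing(routing_map, node_rank, local_rank);
//   auto dispatched = buffer.dispatch(hidden, probs, scaling_factor,
//                                     sparse_to_dense_map, rdma_to_attn_map,
//                                     attn_to_rdma_map, num_of_tokens_for_experts,
//                                     /*with_probs=*/true);
//   // ... expert computation on the dispatched tokens ...
//   auto combined = buffer.combine(expert_hidden, expert_probs,
//                                  sparse_to_dense_map, rdma_to_attn_map,
//                                  attn_to_rdma_map, /*with_probs=*/true);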