Skip to content

Commit 3f601f7

Browse files
authored
Improve Jit (#457)
1 parent 1242eb4 commit 3f601f7

File tree

11 files changed

+70
-39
lines changed

11 files changed

+70
-39
lines changed

csrc/hybrid_ep/executor/executor.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include "executor.cuh"
55

6-
Executor::Executor(int local_rank, int node_rank) : local_rank(local_rank), node_rank(node_rank), kernel_cache(local_rank) {}
6+
// Stores the rank pair and forwards local_rank and base_path to the kernel cache.
Executor::Executor(int local_rank, int node_rank, std::string base_path)
    : local_rank(local_rank),
      node_rank(node_rank),
      kernel_cache(local_rank, base_path) {
}
77

88
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
99
Executor::metadata_preprocess_core(

csrc/hybrid_ep/executor/executor.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
#include <torch/torch.h>
88

99
#include "utils.cuh"
10-
#include "backend/hybrid_ep_backend.cuh"
10+
#include "hybrid_ep_backend.cuh"
1111
#include "jit/compiler.cuh"
1212
#include "extension/permute.cuh"
1313

1414
class Executor {
1515
public:
16-
Executor(int local_rank, int node_rank);
16+
Executor(int local_rank, int node_rank, std::string base_path);
1717

1818
struct DispatchArgs {
1919
// Input tensors

csrc/hybrid_ep/hybrid_ep.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved
33
#include "hybrid_ep.cuh"
44

5-
HybridEPBuffer::HybridEPBuffer(BufferConfig config, int local_rank, int node_rank, int group_size)
5+
HybridEPBuffer::HybridEPBuffer(BufferConfig config, int local_rank, int node_rank, int group_size, std::string base_path)
66
: buffer_config(config), local_rank(local_rank), node_rank(node_rank), group_size(group_size),
7-
executor(local_rank, node_rank) {
7+
executor(local_rank, node_rank, base_path) {
88
if(group_size <= buffer_config.num_of_ranks_per_node) {
99
// If used on only intra-node communication, the dispatch/combine can share same buffers.
1010
use_shared_buffer = true;
@@ -145,6 +145,7 @@ void HybridEPBuffer::allocate_buffer_for_dispatch() {
145145
// Allocate and initialize synchronization buffers
146146
if (local_rank == 0) {
147147
remote_allocator.allocate((void**)&dispatch_buffers.intra_node_write_completion_flags, sizeof(uint32_t));
148+
CUDA_CHECK(cudaMemset(dispatch_buffers.intra_node_write_completion_flags, 0, sizeof(uint32_t)));
148149
}
149150

150151
CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.expected_rdma_flag_value, sizeof(uint64_t)));
@@ -207,6 +208,7 @@ void HybridEPBuffer::allocate_buffer_for_combine() {
207208
// Allocate and initialize synchronization buffers
208209
if (local_rank == 0) {
209210
remote_allocator.allocate((void**)&combine_buffers.intra_node_write_completion_flags, sizeof(uint32_t));
211+
CUDA_CHECK(cudaMemset(combine_buffers.intra_node_write_completion_flags, 0, sizeof(uint32_t)));
210212
}
211213

212214
CUDA_CHECK(cudaMalloc((void**)&combine_buffers.expected_rdma_flag_value, sizeof(uint64_t)));

csrc/hybrid_ep/hybrid_ep.cuh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved
33
#pragma once
44
#include "config.cuh"
5-
#include "backend/hybrid_ep_backend.cuh"
5+
#include "hybrid_ep_backend.cuh"
66
#include "allocator/allocator.cuh"
77
#include "utils.cuh"
88
#include "executor/executor.cuh"
@@ -11,10 +11,11 @@
1111
#include <torch/torch.h>
1212
#include <vector>
1313
#include <algorithm>
14+
#include <string>
1415

1516
class HybridEPBuffer {
1617
public:
17-
HybridEPBuffer(BufferConfig config, int local_rank, int node_rank, int group_size);
18+
HybridEPBuffer(BufferConfig config, int local_rank, int node_rank, int group_size, std::string base_path);
1819
~HybridEPBuffer();
1920
bool update_buffer(HybridEpConfigInstance config); // True means the buffer is reallocated.
2021

csrc/hybrid_ep/jit/compiler.cu

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ inline std::string get_env(std::string name) {
1111
return std::string(env);
1212
}
1313

14-
NVCCCompiler::NVCCCompiler() {
14+
NVCCCompiler::NVCCCompiler(std::string base_path): base_path(base_path) {
1515
nvcc_path = get_env("CUDA_HOME") + "/bin/nvcc";
1616

1717
// Init the flags to compiler
@@ -21,10 +21,7 @@ NVCCCompiler::NVCCCompiler() {
2121
" -Xcompiler -fPIC -shared";
2222

2323
// Add the include path of the hybrid-ep library
24-
std::string base_path = BASE_PATH;
25-
26-
// Add the include path of the hybrid-ep library
27-
include = " -I" + base_path + "/csrc/hybrid_ep" + " -I" + get_env("CUDA_HOME") + "/include";
24+
include = " -I" + base_path + "/backend" + " -I" + get_env("CUDA_HOME") + "/include";
2825

2926
// Add the library path of the hybrid-ep library
3027
library = "-L" + get_env("CUDA_HOME") + "/lib64 -lcudart";
@@ -40,8 +37,6 @@ NVCCCompiler::NVCCCompiler() {
4037

4138

4239
std::string NVCCCompiler::build(std::string code, std::string signature, int local_rank) {
43-
std::string base_path = BASE_PATH;
44-
4540
// Create the source directory
4641
std::string jit_dir = base_path + "/build/jit";
4742
std::filesystem::create_directories(jit_dir);
@@ -80,7 +75,7 @@ std::string NVCCCompiler::build(std::string code, std::string signature, int loc
8075
return output_path;
8176
}
8277

83-
std::any NVCCCompiler::get_instance(std::string library_path) {
78+
std::any NVCCCompiler::get_instance(std::string library_path, std::string kernel_key) {
8479
// Open the compiled library with RTLD_GLOBAL for symbol visibility
8580
void* handle = dlopen(library_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
8681
if (handle == nullptr) {
@@ -98,8 +93,15 @@ std::any NVCCCompiler::get_instance(std::string library_path) {
9893
library_path);
9994
}
10095

101-
// After using the library, clear the built library
102-
remove(library_path.c_str());
96+
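// Rename the built library to a kernel-keyed path (<base_path>/build/jit/<kernel_key>.so) so that builds from different ranks end up as one cached file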
97+
std::string unique_library_path = base_path + "/build/jit/" + kernel_key + ".so";
98+
std::string unique_command = "mv " + library_path + " " + unique_library_path;
99+
if(library_path != unique_library_path) {
100+
auto ret = std::system(unique_command.c_str());
101+
if (ret != 0) {
102+
throw std::runtime_error("Failed to unique the library: " + unique_command);
103+
}
104+
}
103105

104106
// Run the get_function_ptr, then we get the compiled template
105107
std::any func_ptr = get_ptr();
@@ -109,7 +111,7 @@ std::any NVCCCompiler::get_instance(std::string library_path) {
109111

110112
std::string NVCCCompiler::get_metadata_preprocessing_code(HybridEpConfigInstance config) {
111113
return R"(
112-
#include "backend/hybrid_ep_backend.cuh"
114+
#include "hybrid_ep_backend.cuh"
113115
#include <any>
114116
115117
extern "C" {
@@ -130,7 +132,7 @@ std::string NVCCCompiler::get_dispatch_code(HybridEpConfigInstance config) {
130132
(config.token_data_type == TOKEN_DATA_TYPE::UINT8) ? "uint8_t" : "uint16_t";
131133

132134
return R"(
133-
#include "backend/hybrid_ep_backend.cuh"
135+
#include "hybrid_ep_backend.cuh"
134136
#include <any>
135137
136138
extern "C" {
@@ -150,7 +152,7 @@ std::string NVCCCompiler::get_dispatch_code(HybridEpConfigInstance config) {
150152

151153
std::string NVCCCompiler::get_combine_code(HybridEpConfigInstance config) {
152154
return R"(
153-
#include "backend/hybrid_ep_backend.cuh"
155+
#include "hybrid_ep_backend.cuh"
154156
#include <any>
155157
156158
extern "C" {
@@ -171,7 +173,18 @@ std::string NVCCCompiler::get_combine_code(HybridEpConfigInstance config) {
171173
)";
172174
}
173175

174-
176+
/**
 * @brief Construct the kernel cache and preload previously compiled kernels.
 *
 * Ensures `<base_path>/build/jit` exists, then registers every regular `*.so`
 * file found there in `kernel_cache`, keyed by the file name without its
 * extension.
 *
 * A cached library that cannot be loaded is reported on stderr and skipped
 * rather than aborting construction: the run_*_kernel paths compile a kernel
 * whenever its key is absent from `kernel_cache`, so a bad entry only costs a
 * rebuild.
 *
 * @param local_rank Rank later passed to NVCCCompiler::build when compiling.
 * @param base_path  Root directory that contains the `build/jit` cache directory.
 */
KernelCache::KernelCache(int local_rank, std::string base_path)
    // Initializers are listed in member declaration order (nvcc_compiler first).
    : nvcc_compiler(base_path), base_path(base_path), local_rank(local_rank) {
    const std::string cache_dir = base_path + "/build/jit";
    std::filesystem::create_directories(cache_dir);

    for (const auto& entry : std::filesystem::directory_iterator(cache_dir)) {
        const std::filesystem::path& lib_path = entry.path();
        if (lib_path.extension() != ".so") {
            continue;
        }

        const std::string kernel_key = lib_path.stem().string();
        try {
            // Directories or other non-regular entries named *.so are not libraries.
            if (!entry.is_regular_file()) {
                continue;
            }
            // Load into a local first so a throwing load never leaves an empty
            // entry behind in kernel_cache (which would suppress the rebuild).
            std::any instance = nvcc_compiler.get_instance(lib_path.string(), kernel_key);
            kernel_cache[kernel_key] = std::move(instance);
        } catch (const std::exception& e) {
            // NOTE(review): presumably this can happen for a stale library or one
            // that another process is still writing - confirm against build().
            std::cerr << "[HybridEP JIT] Skipping unusable cached kernel '" << lib_path.string()
                      << "': " << e.what() << std::endl;
        }
    }
}
175188

176189
void KernelCache::run_proprecess_kernel(
177190
HybridEpConfigInstance config,
@@ -202,7 +215,7 @@ void KernelCache::run_proprecess_kernel(
202215
if (it == kernel_cache.end()) {
203216
auto preprocessing_code = nvcc_compiler.get_metadata_preprocessing_code(config);
204217
auto preprocessing_path = nvcc_compiler.build(preprocessing_code, preprocess_kernel_key, local_rank);
205-
kernel_cache[preprocess_kernel_key] = nvcc_compiler.get_instance(preprocessing_path);
218+
kernel_cache[preprocess_kernel_key] = nvcc_compiler.get_instance(preprocessing_path, preprocess_kernel_key);
206219
}
207220
auto preprocessing_instance = kernel_cache[preprocess_kernel_key];
208221

@@ -255,7 +268,7 @@ void KernelCache::run_dispatch_kernel(
255268
// JIT Compile the kernel
256269
auto dispatch_code = nvcc_compiler.get_dispatch_code(config);
257270
auto dispatch_path = nvcc_compiler.build(dispatch_code, dispatch_kernel_key, local_rank);
258-
kernel_cache[dispatch_kernel_key] = nvcc_compiler.get_instance(dispatch_path);
271+
kernel_cache[dispatch_kernel_key] = nvcc_compiler.get_instance(dispatch_path, dispatch_kernel_key);
259272
}
260273
auto dispatch_instance = kernel_cache[dispatch_kernel_key];
261274

@@ -295,7 +308,7 @@ void KernelCache::run_combine_kernel(
295308
// JIT Compile the kernel
296309
auto combine_code = nvcc_compiler.get_combine_code(config);
297310
auto combine_path = nvcc_compiler.build(combine_code, combine_kernel_key, local_rank);
298-
kernel_cache[combine_kernel_key] = nvcc_compiler.get_instance(combine_path);
311+
kernel_cache[combine_kernel_key] = nvcc_compiler.get_instance(combine_path, combine_kernel_key);
299312
}
300313
auto combine_instance = kernel_cache[combine_kernel_key];
301314

csrc/hybrid_ep/jit/compiler.cuh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
#include <iostream>
1818

1919
#include "config.cuh"
20-
#include "backend/hybrid_ep_backend.cuh"
20+
#include "hybrid_ep_backend.cuh"
2121
#include "utils.cuh"
2222

2323
class NVCCCompiler{
2424
public:
2525
// Init the flags required by nvcc compiler
26-
NVCCCompiler();
26+
NVCCCompiler(std::string base_path);
2727

2828
// Generate the code for jit compile
2929
std::string get_metadata_preprocessing_code(HybridEpConfigInstance config);
@@ -45,12 +45,14 @@ public:
4545
* @brief Get the compiled function pointer from the compiled .so file
4646
*
4747
* @param library_path The path of the compiled .so file
48+
* @param kernel_key The key of the kernel, used to cache the compiled function pointer
4849
* @return std::any The function pointer
4950
*/
50-
std::any get_instance(std::string library_path);
51+
std::any get_instance(std::string library_path, std::string kernel_key);
5152

5253

5354
private:
55+
std::string base_path; // The path of the installed package
5456
std::string flags; // The flags required by nvcc compiler, which contains the
5557
// base flags(-O3, -arch...), include files, library files
5658
std::string nvcc_path; // The path of the nvcc compiler
@@ -60,8 +62,7 @@ private:
6062

6163
class KernelCache{
6264
public:
63-
KernelCache() = default;
64-
KernelCache(int local_rank) : local_rank(local_rank) {}
65+
KernelCache(int local_rank, std::string base_path);
6566

6667
void run_proprecess_kernel(
6768
HybridEpConfigInstance config,
@@ -94,5 +95,6 @@ public:
9495
private:
9596
NVCCCompiler nvcc_compiler;
9697
std::unordered_map<std::string, std::any> kernel_cache;
98+
std::string base_path; // The path of the installed package
9799
int local_rank; // Used to generate the unique signature for each rank
98100
};

csrc/hybrid_ep/pybind_hybrid_ep.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
101101
});
102102

103103
pybind11::class_<HybridEPBuffer>(m, "HybridEPBuffer")
104-
.def(py::init<BufferConfig, int, int, int>())
104+
.def(py::init<BufferConfig, int, int, int, std::string>())
105105
.def("update_buffer", &HybridEPBuffer::update_buffer, py::arg("config"))
106106
.def("exchange_ipc_address", &HybridEPBuffer::exchange_ipc_address)
107107
.def("metadata_preprocessing", &HybridEPBuffer::metadata_preprocessing,
File renamed without changes.

deep_ep/hybrid_ep_buffer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,19 @@ def __init__(
111111

112112
# Create C++ buffer - this will allocate all buffers during construction
113113
self.runtime = hybrid_ep_cpp.HybridEPBuffer(
114-
self.config, self.local_rank, self.node_rank, self.group_size
114+
self.config, self.local_rank, self.node_rank, self.group_size, os.path.dirname(os.path.abspath(__file__))
115115
)
116116
# Exchange IPC addresses using C++ distributed communication
117117
self.runtime.exchange_ipc_address(self.group)
118118

119+
def empty_jit_cache(self):
    '''
    Delete the on-disk JIT kernel cache directory.

    Removes the ``build/jit`` directory located next to this Python file,
    together with everything inside it. A missing directory is not an error,
    including the case where it disappears while the removal is in progress
    (for example when several processes call this method at the same time).
    '''
    # Imported locally so the method does not rely on a module-level import.
    import shutil

    jit_cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "build", "jit")
    try:
        # Attempt the removal directly instead of checking existence first:
        # a separate exists() check can go stale before rmtree() runs.
        shutil.rmtree(jit_cache_path)
    except FileNotFoundError:
        # Already gone (never created, or removed concurrently) - nothing to do.
        pass
126+
119127
def update_template_config(
120128
self,
121129
hidden_dim: int = None,

0 commit comments

Comments
 (0)