@@ -11,7 +11,7 @@ inline std::string get_env(std::string name) {
1111 return std::string (env);
1212}
1313
14- NVCCCompiler::NVCCCompiler () {
14+ NVCCCompiler::NVCCCompiler (std::string base_path): base_path(base_path ) {
1515 nvcc_path = get_env (" CUDA_HOME" ) + " /bin/nvcc" ;
1616
1717 // Initialize the compiler flags
@@ -21,10 +21,7 @@ NVCCCompiler::NVCCCompiler() {
2121 " -Xcompiler -fPIC -shared" ;
2222
2323 // Add the include path of the hybrid-ep library
24- std::string base_path = BASE_PATH;
25-
26- // Add the include path of the hybrid-ep library
27- include = " -I" + base_path + " /csrc/hybrid_ep" + " -I" + get_env (" CUDA_HOME" ) + " /include" ;
24+ include = " -I" + base_path + " /backend" + " -I" + get_env (" CUDA_HOME" ) + " /include" ;
2825
2926 // Add the library path of the hybrid-ep library
3027 library = " -L" + get_env (" CUDA_HOME" ) + " /lib64 -lcudart" ;
@@ -40,8 +37,6 @@ NVCCCompiler::NVCCCompiler() {
4037
4138
4239std::string NVCCCompiler::build (std::string code, std::string signature, int local_rank) {
43- std::string base_path = BASE_PATH;
44-
4540 // Create the source directory
4641 std::string jit_dir = base_path + " /build/jit" ;
4742 std::filesystem::create_directories (jit_dir);
@@ -80,7 +75,7 @@ std::string NVCCCompiler::build(std::string code, std::string signature, int loc
8075 return output_path;
8176}
8277
83- std::any NVCCCompiler::get_instance (std::string library_path) {
78+ std::any NVCCCompiler::get_instance (std::string library_path, std::string kernel_key ) {
8479 // Open the compiled library with RTLD_GLOBAL for symbol visibility
8580 void * handle = dlopen (library_path.c_str (), RTLD_NOW | RTLD_GLOBAL);
8681 if (handle == nullptr ) {
@@ -98,8 +93,15 @@ std::any NVCCCompiler::get_instance(std::string library_path) {
9893 library_path);
9994 }
10095
101- // After using the library, clear the built library
102- remove (library_path.c_str ());
96+ // Rename the compiled library to a rank-independent path derived from kernel_key
97+ std::string unique_library_path = base_path + " /build/jit/" + kernel_key + " .so" ;
98+ std::string unique_command = " mv " + library_path + " " + unique_library_path;
99+ if (library_path != unique_library_path) {
100+ auto ret = std::system (unique_command.c_str ());
101+ if (ret != 0 ) {
102+ throw std::runtime_error (" Failed to unique the library: " + unique_command);
103+ }
104+ }
103105
104106 // Run the get_function_ptr, then we get the compiled template
105107 std::any func_ptr = get_ptr ();
@@ -109,7 +111,7 @@ std::any NVCCCompiler::get_instance(std::string library_path) {
109111
110112std::string NVCCCompiler::get_metadata_preprocessing_code (HybridEpConfigInstance config) {
111113 return R"(
112- #include "backend/ hybrid_ep_backend.cuh"
114+ #include "hybrid_ep_backend.cuh"
113115 #include <any>
114116
115117 extern "C" {
@@ -130,7 +132,7 @@ std::string NVCCCompiler::get_dispatch_code(HybridEpConfigInstance config) {
130132 (config.token_data_type == TOKEN_DATA_TYPE::UINT8) ? " uint8_t" : " uint16_t" ;
131133
132134 return R"(
133- #include "backend/ hybrid_ep_backend.cuh"
135+ #include "hybrid_ep_backend.cuh"
134136 #include <any>
135137
136138 extern "C" {
@@ -150,7 +152,7 @@ std::string NVCCCompiler::get_dispatch_code(HybridEpConfigInstance config) {
150152
151153std::string NVCCCompiler::get_combine_code (HybridEpConfigInstance config) {
152154 return R"(
153- #include "backend/ hybrid_ep_backend.cuh"
155+ #include "hybrid_ep_backend.cuh"
154156 #include <any>
155157
156158 extern "C" {
@@ -171,7 +173,18 @@ std::string NVCCCompiler::get_combine_code(HybridEpConfigInstance config) {
171173 )" ;
172174}
173175
174-
// KernelCache constructor.
// Stores local_rank and base_path, constructs nvcc_compiler from base_path,
// and then pre-populates kernel_cache from the shared-library files already
// present in the build/jit directory under base_path.
// NOTE(review): get_instance appears to throw when a library cannot be loaded;
// if so, one unreadable file in the cache directory would make this
// constructor throw -- confirm whether that is intended.
176+ KernelCache::KernelCache (int local_rank, std::string base_path):
177+ local_rank(local_rank), base_path(base_path), nvcc_compiler(base_path) {
178+ // Load all cached kernels from the cache directory (created first if it does not exist)
179+ std::string cache_dir = base_path + " /build/jit" ;
180+ std::filesystem::create_directories (cache_dir);
181+ for (const auto & entry : std::filesystem::directory_iterator (cache_dir)) { // non-recursive scan of cache_dir
182+ if (entry.path ().extension () == " .so" ) { // only consider shared-library files
183+ std::string kernel_key = entry.path ().stem ().string (); // file name without its extension is used as the cache key
184+ kernel_cache[kernel_key] = nvcc_compiler.get_instance (entry.path ().string (), kernel_key); // load the library and store the returned instance under that key
185+ }
186+ }
187+ }
175188
176189void KernelCache::run_proprecess_kernel (
177190 HybridEpConfigInstance config,
@@ -202,7 +215,7 @@ void KernelCache::run_proprecess_kernel(
202215 if (it == kernel_cache.end ()) {
203216 auto preprocessing_code = nvcc_compiler.get_metadata_preprocessing_code (config);
204217 auto preprocessing_path = nvcc_compiler.build (preprocessing_code, preprocess_kernel_key, local_rank);
205- kernel_cache[preprocess_kernel_key] = nvcc_compiler.get_instance (preprocessing_path);
218+ kernel_cache[preprocess_kernel_key] = nvcc_compiler.get_instance (preprocessing_path, preprocess_kernel_key );
206219 }
207220 auto preprocessing_instance = kernel_cache[preprocess_kernel_key];
208221
@@ -255,7 +268,7 @@ void KernelCache::run_dispatch_kernel(
255268 // JIT Compile the kernel
256269 auto dispatch_code = nvcc_compiler.get_dispatch_code (config);
257270 auto dispatch_path = nvcc_compiler.build (dispatch_code, dispatch_kernel_key, local_rank);
258- kernel_cache[dispatch_kernel_key] = nvcc_compiler.get_instance (dispatch_path);
271+ kernel_cache[dispatch_kernel_key] = nvcc_compiler.get_instance (dispatch_path, dispatch_kernel_key );
259272 }
260273 auto dispatch_instance = kernel_cache[dispatch_kernel_key];
261274
@@ -295,7 +308,7 @@ void KernelCache::run_combine_kernel(
295308 // JIT Compile the kernel
296309 auto combine_code = nvcc_compiler.get_combine_code (config);
297310 auto combine_path = nvcc_compiler.build (combine_code, combine_kernel_key, local_rank);
298- kernel_cache[combine_kernel_key] = nvcc_compiler.get_instance (combine_path);
311+ kernel_cache[combine_kernel_key] = nvcc_compiler.get_instance (combine_path, combine_kernel_key );
299312 }
300313 auto combine_instance = kernel_cache[combine_kernel_key];
301314
0 commit comments