diff --git a/src/main.rs b/src/main.rs index a6868ac..b607152 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ struct Config { mainClass: String, vmArgs: Vec, useZgcIfSupportedOs: bool, + useMainAsContextClassLoader: bool, } // Picks discrete GPU on Windows, if possible @@ -41,9 +42,10 @@ const JVM_LOCATION: [&str; 3] = ["jdk", "lib", "server"]; fn start_jvm( jvm_location: &Path, class_path: Vec, - main_class: &str, + main_class_name: &str, vm_args: Vec, use_zgc_if_supported: bool, + use_main_as_context_class_loader: bool, args: Vec, ) { let mut args_builder = InitArgsBuilder::new() @@ -75,6 +77,44 @@ fn start_jvm( .attach_current_thread() .expect("Failed to attach the current thread"); + if use_main_as_context_class_loader { + // Class mainClass = MainClass.class; + let main_class = env + .find_class(main_class_name) + .expect("Failed to get main class"); + + // ClassLoader loader = mainClass.getClassLoader() + let class_loader = env + .call_method( + main_class, + "getClassLoader", + "()Ljava/lang/ClassLoader;", + &[], + ) + .and_then(|it| it.l()) + .expect("Failed to get class loader from main class"); + + // Thread thread = Thread.currentThread() + let current_thread = env + .call_static_method( + "java/lang/Thread", + "currentThread", + "()Ljava/lang/Thread;", + &[], + ) + .and_then(|it| it.l()) + .expect("Failed to get current thread"); + + // thread.setContextClassLoader(loader) + env.call_method( + current_thread, + "setContextClassLoader", + "(Ljava/lang/ClassLoader;)V", + &[(&class_loader).into()], + ) + .expect("Failed to set class loader"); + } + let jstrings: Vec = args .iter() .map(|s| env.new_string(s)) // Convert to JString (maybe) @@ -92,7 +132,7 @@ fn start_jvm( i = i + 1; } env.call_static_method( - main_class, + main_class_name, "main", "([Ljava/lang/String;)V", &[(&method_args).into()], @@ -168,13 +208,24 @@ fn main() { let config_file_path = current_location.join("config.json"); let data = fs::read_to_string(config_file_path).expect("Unable to read config file"); let config: Config = serde_json::from_str(&data).expect("Invalid config json"); - let class_path: Vec = config.classPath.into_iter().map(|it| current_location.join(it).into_os_string().into_string().unwrap()).collect(); + let class_path: Vec = config + .classPath + .into_iter() + .map(|it| { + current_location + .join(it) + .into_os_string() + .into_string() + .unwrap() + }) + .collect(); start_jvm( &jvm_location, class_path, &config.mainClass.replace(".", "/"), config.vmArgs, config.useZgcIfSupportedOs, + config.useMainAsContextClassLoader, args, ); }