Skip to content

Commit

Permalink
cache cudafe++ invocations so we restore the .module_id file
Browse files Browse the repository at this point in the history
  • Loading branch information
trxcllnt committed Aug 20, 2024
1 parent 1f5f816 commit a876df8
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 69 deletions.
4 changes: 4 additions & 0 deletions src/compiler/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ pub enum CCompilerKind {
Msvc,
/// NVIDIA CUDA compiler
Nvcc,
/// NVIDIA CUDA compiler front-end
Cudafe,
/// NVIDIA CUDA optimizer and PTX generator
Cicc,
/// NVIDIA CUDA PTX assembler
Expand Down Expand Up @@ -1383,6 +1385,8 @@ impl pkg::ToolchainPackager for CToolchainPackager {
add_named_file(&mut package_builder, "liblto_plugin.so")?;
}

CCompilerKind::Cudafe => {}

CCompilerKind::Cicc => {}

CCompilerKind::Ptxas => {}
Expand Down
94 changes: 61 additions & 33 deletions src/compiler/cicc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl CCompilerImpl for Cicc {
arguments: &[OsString],
cwd: &Path,
) -> CompilerArguments<ParsedArguments> {
parse_arguments(arguments, cwd, Language::Ptx, &ARGS[..])
parse_arguments(arguments, cwd, Language::Ptx, 3, &ARGS[..])
}
#[allow(clippy::too_many_arguments)]
async fn preprocess<T>(
Expand All @@ -74,11 +74,6 @@ impl CCompilerImpl for Cicc {
where
T: CommandCreatorSync,
{
trace!(
"cicc preprocessed input file: cwd={:?} path={:?}",
cwd,
&parsed_args.input
);
preprocess(cwd, parsed_args).await
}
fn generate_compile_commands<T>(
Expand Down Expand Up @@ -109,16 +104,20 @@ pub fn parse_arguments<S>(
arguments: &[OsString],
cwd: &Path,
language: Language,
input_distance_from_end: usize,
arg_info: S,
) -> CompilerArguments<ParsedArguments>
where
S: SearchableArgInfo<ArgData>,
{
let mut args = arguments.to_vec();
let input_loc = arguments.len() - 3;
let input_loc = arguments.len() - input_distance_from_end;
let input = args.splice(input_loc..input_loc + 1, []).next().unwrap();

let mut take_next = false;
let mut gen_module_id_file = false;
let mut module_id_file_name = None;

let mut extra_inputs = vec![];
let mut outputs = HashMap::new();

Expand All @@ -129,46 +128,46 @@ where
match arg {
Ok(arg) => {
let args = match arg.get_data() {
Some(GenModuleIdFileFlag) => {
take_next = false;
gen_module_id_file = true;
&mut common_args
}
Some(ModuleIdFileName(o)) => {
take_next = false;
module_id_file_name = Some(cwd.join(o));
&mut common_args
}

Some(PassThrough(_)) => {
take_next = false;
&mut common_args
}
Some(Output(o)) => {
take_next = false;
let path = cwd.join(o);
outputs.insert(
"obj",
ArtifactDescriptor {
path,
path: cwd.join(o),
optional: false,
},
);
continue;
}
Some(UnhashedInput(o)) => {
take_next = false;
let path = cwd.join(o);
if !path.exists() {
continue;
}
extra_inputs.push(path);
&mut unhashed_args
}
Some(UnhashedOutput(o)) => {
take_next = false;
let path = cwd.join(o);
if let Some(flag) = arg.flag_str() {
outputs.insert(
flag,
ArtifactDescriptor {
path,
path: cwd.join(o),
optional: false,
},
);
}
&mut unhashed_args
}
Some(UnhashedFlag) | Some(Unhashed(_)) => {
Some(Unhashed(_)) => {
take_next = false;
&mut unhashed_args
}
Expand All @@ -195,6 +194,20 @@ where
};
}

if let Some(path) = module_id_file_name {
if gen_module_id_file {
outputs.insert(
"--gen_module_id_file",
ArtifactDescriptor {
path,
optional: language == Language::Ptx,
},
);
} else {
extra_inputs.push(path);
}
}

CompilerArguments::Ok(ParsedArguments {
input: input.into(),
outputs,
Expand All @@ -207,18 +220,17 @@ where
common_args,
arch_args: vec![],
unhashed_args,
extra_dist_files: extra_inputs,
extra_hash_files: vec![],
extra_dist_files: extra_inputs.clone(),
extra_hash_files: extra_inputs.clone(),
msvc_show_includes: false,
profile_generate: false,
color_mode: ColorMode::Off,
color_mode: ColorMode::Auto,
suppress_rewrite_includes_only: false,
too_hard_for_preprocessor_cache_mode: None,
})
}

pub async fn preprocess(cwd: &Path, parsed_args: &ParsedArguments) -> Result<process::Output> {
// cicc and ptxas expect input to be an absolute path
let input = if parsed_args.input.is_absolute() {
parsed_args.input.clone()
} else {
Expand Down Expand Up @@ -250,8 +262,6 @@ pub fn generate_compile_commands(
let _ = path_transformer;
}

trace!("compile");

let lang_str = &parsed_args.language.as_str();
let out_file = match parsed_args.outputs.get("obj") {
Some(obj) => &obj.path,
Expand All @@ -263,17 +273,35 @@ pub fn generate_compile_commands(
arguments.extend_from_slice(&parsed_args.unhashed_args);
arguments.extend(vec![
(&parsed_args.input).into(),
"-o".into(),
out_file.into(),
]);

// hack -- don't add `-o` for cudafe++
if parsed_args.language != Language::Cxx {
arguments.extend(vec![
"-o".into(),
out_file.into(),
]);
}


let command = SingleCompileCommand {
executable: executable.to_owned(),
arguments,
env_vars: env_vars.to_owned(),
cwd: cwd.to_owned(),
};

if log_enabled!(log::Level::Trace) {
trace!("cicc::generate_compile_commands {:?}",
[
vec![command.executable.as_os_str().to_owned()],
command.arguments.to_owned()
]
.concat()
.join(std::ffi::OsStr::new(" "))
);
}

#[cfg(not(feature = "dist-client"))]
let dist_command = None;
#[cfg(feature = "dist-client")]
Expand All @@ -298,10 +326,10 @@ pub fn generate_compile_commands(
}

ArgData! { pub
GenModuleIdFileFlag,
ModuleIdFileName(PathBuf),
Output(PathBuf),
UnhashedInput(PathBuf),
UnhashedOutput(PathBuf),
UnhashedFlag,
PassThrough(OsString),
Unhashed(OsString),
}
Expand All @@ -311,9 +339,9 @@ use self::ArgData::*;
counted_array!(pub static ARGS: [ArgInfo<ArgData>; _] = [
take_arg!("--gen_c_file_name", PathBuf, Separated, UnhashedOutput),
take_arg!("--gen_device_file_name", PathBuf, Separated, UnhashedOutput),
flag!("--gen_module_id_file", UnhashedFlag),
flag!("--gen_module_id_file", GenModuleIdFileFlag),
take_arg!("--include_file_name", OsString, Separated, PassThrough),
take_arg!("--module_id_file_name", PathBuf, Separated, UnhashedInput),
take_arg!("--module_id_file_name", PathBuf, Separated, ModuleIdFileName),
take_arg!("--stub_file_name", PathBuf, Separated, UnhashedOutput),
take_arg!("-o", PathBuf, Separated, Output),
]);
25 changes: 25 additions & 0 deletions src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::compiler::nvcc::Nvcc;
use crate::compiler::nvcc::NvccHostCompiler;
use crate::compiler::nvhpc::Nvhpc;
use crate::compiler::ptxas::Ptxas;
use crate::compiler::cudafe::Cudafe;
use crate::compiler::rust::{Rust, RustupProxy};
use crate::compiler::tasking_vx::TaskingVX;
#[cfg(feature = "dist-client")]
Expand Down Expand Up @@ -300,6 +301,7 @@ impl CompilerKind {
CompilerKind::C(CCompilerKind::Msvc) => textual_lang + " [msvc]",
CompilerKind::C(CCompilerKind::Nvcc) => textual_lang + " [nvcc]",
CompilerKind::C(CCompilerKind::Cicc) => textual_lang + " [cicc]",
CompilerKind::C(CCompilerKind::Cudafe) => textual_lang + " [cudafe]",
CompilerKind::C(CCompilerKind::Ptxas) => textual_lang + " [ptxas]",
CompilerKind::C(CCompilerKind::Nvhpc) => textual_lang + " [nvhpc]",
CompilerKind::C(CCompilerKind::TaskingVX) => textual_lang + " [taskingvx]",
Expand Down Expand Up @@ -1117,6 +1119,17 @@ fn is_rustc_like<P: AsRef<Path>>(p: P) -> bool {
)
}

/// Reports whether `p` names the `cudafe++` executable.
///
/// Only the file stem of the path is inspected (so any directory prefix and
/// the final extension are ignored), and the comparison is done on the
/// lower-cased, lossily UTF-8 decoded stem.
fn is_nvidia_cudafe<P: AsRef<Path>>(p: P) -> bool {
    let stem = p
        .as_ref()
        .file_stem()
        .map(|name| name.to_string_lossy().to_lowercase());
    // A path without a file stem yields `None`, which never equals the target.
    stem.as_deref() == Some("cudafe++")
}

/// Returns true if the given path looks like cicc
fn is_nvidia_cicc<P: AsRef<Path>>(p: P) -> bool {
matches!(
Expand Down Expand Up @@ -1200,6 +1213,18 @@ where

let rustc_executable = if let Some(ref rustc_executable) = maybe_rustc_executable {
rustc_executable
} else if is_nvidia_cudafe(executable) {
debug!("Found cudafe++");
return CCompiler::new(
Cudafe {
// TODO: Use nvcc --version
version: Some(String::new()),
},
executable.to_owned(),
&pool,
)
.await
.map(|c| (Box::new(c) as Box<dyn Compiler<T>>, None));
} else if is_nvidia_cicc(executable) {
debug!("Found cicc");
return CCompiler::new(
Expand Down
Loading

0 comments on commit a876df8

Please sign in to comment.