Skip to content

Commit

Permalink
inline envvars for new CUDA 12.6 behavior where cicc subcommand is pr…
Browse files Browse the repository at this point in the history
…efixed with $CICC_PATH
  • Loading branch information
trxcllnt committed Aug 12, 2024
1 parent b12d8d9 commit 4b5d35c
Showing 1 changed file with 43 additions and 24 deletions.
67 changes: 43 additions & 24 deletions src/compiler/nvcc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,12 @@ pub fn generate_compile_commands(
.context("Missing object file output").unwrap()
.path.clone();

let output = if output.is_absolute() {
output
} else {
cwd.join(output)
};

arguments.extend(vec![
parsed_args.compilation_flag.clone(),
"-o".into(),
Expand Down Expand Up @@ -427,7 +433,7 @@ pub struct NvccCompileCommand {

#[derive(Clone, Debug)]
pub struct NvccGeneratedCommand {
pub exe: String,
pub exe: PathBuf,
pub args: Vec<String>,
pub cwd: PathBuf,
pub cacheable: Cacheable,
Expand Down Expand Up @@ -501,18 +507,13 @@ impl CompileCommandImpl for NvccCompileCommand {
let grouped_subcommands = get_nvcc_subcommand_groups(
creator,
executable,
&arguments,
arguments,
cwd,
temp.as_path(),
keep.clone(),
&mut env_vars
).await?;

let env_path = env_vars.iter()
.find(|(k, _)| k == "PATH")
.map(|(_, p)| p.to_owned())
.unwrap();

let mut output = process::Output {
status: process::ExitStatus::default(),
stdout: vec![],
Expand All @@ -535,7 +536,7 @@ impl CompileCommandImpl for NvccCompileCommand {

let results = futures::future::join_all(
command_groups.iter().map(|commands|
run_nvcc_subcommands(service, creator, cwd, &env_path, &env_vars, commands)
run_nvcc_subcommands(service, creator, cwd, &env_vars, commands)
)
)
.await;
Expand Down Expand Up @@ -603,8 +604,11 @@ where

fn select_nvcc_env_vars(
env_vars: &mut Vec<(OsString, OsString)>,
line: &str
) -> Option<String> {
line: &mut String,
cwd: &Path,
) -> Option<(PathBuf, Vec<String>)> {

// Intercept the environment variable lines and add them to the env_vars list
if let Some(var) = Regex::new(r"^([_A-Z]+)=(.*)$").unwrap().captures(line) {
let (_, [var, val]) = var.extract();

Expand All @@ -629,7 +633,30 @@ where
env_vars.splice(loc, [pair]);
return None;
}
Some(line.to_string())

// The rest of the lines are subcommands, so parse into a vec of [cmd, args..]

// Expand envvars in nvcc subcommands, i.e. "$CICC_PATH/cicc ..."
if let Some(env_vars) = dist::osstring_tuples_to_strings(env_vars) {
for (key, val) in env_vars {
let var = "$".to_owned() + &key;
*line = line.replace(&var, &val);
}
}

let args = shlex::split(line)?;
let (exe, args) = args.split_first()?;

let env_path = env_vars.iter()
.find(|(k, _)| k == "PATH")
.map(|(_, p)| p.to_owned())
.unwrap();

if let Ok(exe) = which_in(exe, env_path.into(), cwd) {
return Some((exe.clone(), args.to_vec()))
}

None
}

let nvcc_subcommand_groups = nvcc_dryrun_cmd.take_stderr();
Expand All @@ -654,13 +681,8 @@ where
// Select lines that match the `#$ ` prefix from nvcc --dryrun
.map(|line| line.and_then(select_lines_with_hash_dollar_space))
// Intercept the environment variable lines and add them to the env_vars list
.filter_map_ok(|line: String| select_nvcc_env_vars(env_vars, &line))
// The rest of the lines are subcommands, so parse into a vec of [cmd, args..]
.filter_map_ok(|line| {
let args = shlex::split(&line)?;
let (exe, args) = args.split_first()?;
Some((exe.clone(), args.to_vec()))
})
.filter_map_ok(|mut line| select_nvcc_env_vars(env_vars, &mut line, cwd))
.fold_ok((
vec![],
HashMap::<String, String>::new(),
Expand Down Expand Up @@ -714,12 +736,12 @@ where

let tmp_name = |name: &String| tmp.join(name).into_os_string().into_string().ok();

let (dir, cacheable) = match exe.as_str() {
let (dir, cacheable) = match exe.file_name().and_then(|s| s.to_str()) {
// cicc, ptxas are cacheable
"cicc" | "ptxas" => (tmp, Cacheable::Yes),
Some("cicc") | Some("ptxas") => (tmp, Cacheable::Yes),
// cudafe++ and fatbinary are not cacheable
"cudafe++" => (tmp, Cacheable::No),
"fatbinary" => {
Some("cudafe++") => (tmp, Cacheable::No),
Some("fatbinary") => {
// The fatbinary command represents the start of the last group
command_groups.push(vec![]);
(tmp, Cacheable::No)
Expand Down Expand Up @@ -789,7 +811,6 @@ async fn run_nvcc_subcommands<T>(
service: &server::SccacheService<T>,
creator: &T,
cwd: &Path,
env_path: &OsStr,
env_vars: &[(OsString, OsString)],
commands: &[NvccGeneratedCommand],
) -> Result<process::Output>
Expand All @@ -810,8 +831,6 @@ where
cacheable,
} = cmd;

let exe = which_in(exe, env_path.into(), cwd)?;

if log_enabled!(log::Level::Trace) {
debug!("run_commands_sequential cwd={:?}, cmd={:?}", cwd, [
vec![exe.clone().into_os_string().into_string().unwrap()],
Expand Down

0 comments on commit 4b5d35c

Please sign in to comment.