diff --git a/hippunfold/workflow/rules/nnunet.smk b/hippunfold/workflow/rules/nnunet.smk
index a86a6623..ab21a4b0 100644
--- a/hippunfold/workflow/rules/nnunet.smk
+++ b/hippunfold/workflow/rules/nnunet.smk
@@ -2,32 +2,35 @@ import re
 
 
 def get_nnunet_input(wildcards):
-    if config["modality"] == "T2w":
-        nii = (
-            bids(
-                root=work,
-                datatype="anat",
-                **config["subj_wildcards"],
-                suffix="T2w.nii.gz",
-                space="corobl",
-                desc="preproc",
-                hemi="{hemi}",
-            ),
-        )
+    T1w_nii = bids(
+        root=work,
+        datatype="anat",
+        **config["subj_wildcards"],
+        suffix="T1w.nii.gz",
+        space="corobl",
+        desc="preproc",
+        hemi="{hemi}",
+    )
+    T2w_nii = bids(
+        root=work,
+        datatype="anat",
+        **config["subj_wildcards"],
+        suffix="T2w.nii.gz",
+        space="corobl",
+        desc="preproc",
+        hemi="{hemi}",
+    )
+    if (config["modality"] == "T1w" or config["modality"] == "T2w") and config[
+        "force_nnunet_model"
+    ] == "T1T2w":
+        return (T1w_nii, T2w_nii)
+
+    elif config["modality"] == "T2w":
+        return T2w_nii
     elif config["modality"] == "T1w":
-        nii = (
-            bids(
-                root=work,
-                datatype="anat",
-                **config["subj_wildcards"],
-                suffix="T1w.nii.gz",
-                space="corobl",
-                desc="preproc",
-                hemi="{hemi}",
-            ),
-        )
+        return T1w_nii
     elif config["modality"] == "hippb500":
-        nii = bids(
+        return bids(
             root=work,
             datatype="dwi",
             hemi="{hemi}",
@@ -37,7 +40,6 @@ def get_nnunet_input(wildcards):
         )
     else:
         raise ValueError("modality not supported for nnunet!")
-    return nii
 
 
 def get_model_tar():
@@ -94,6 +96,19 @@ def parse_trainer_from_tar(wildcards, input):
     return trainer
 
 
+def get_cmd_copy_inputs(wildcards, input):
+    in_img = input.in_img
+    if isinstance(in_img, str):
+        # we have one input image
+        return f"cp {in_img} tempimg/temp_0000.nii.gz"
+    else:
+        cmd = []
+        # we have multiple input images
+        for i, img in enumerate(input.in_img):
+            cmd.append(f"cp {img} tempimg/temp_{i:04d}.nii.gz")
+        return " && ".join(cmd)
+
+
 rule run_inference:
     """ This rule uses either GPU or CPU .  It also runs in an isolated folder (shadow),
     with symlinks to inputs in that folder, copying over outputs once complete, so temp files are not retained"""
@@ -101,7 +116,7 @@ rule run_inference:
         in_img=get_nnunet_input,
         model_tar=get_model_tar(),
     params:
-        temp_img="tempimg/temp_0000.nii.gz",
+        cmd_copy_inputs=get_cmd_copy_inputs,
         temp_lbl="templbl/temp.nii.gz",
         model_dir="tempmodel",
         in_folder="tempimg",
@@ -148,7 +163,7 @@ rule run_inference:
         # run inference
         #copy from temp output folder to final output
         "mkdir -p {params.model_dir} {params.in_folder} {params.out_folder} && "
-        "cp {input.in_img} {params.temp_img} && "
+        "{params.cmd_copy_inputs} && "
         "tar -xf {input.model_tar} -C {params.model_dir} && "
         "export RESULTS_FOLDER={params.model_dir} && "
         "export nnUNet_n_proc_DA={threads} && "
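
Aside (not part of the patch): a minimal sketch of the shell fragment that get_cmd_copy_inputs would emit for the two-channel T1T2w case, using the same " && ".join logic as the new function. The filenames here are made-up examples; at runtime they come from the bids() calls above, and nnU-Net's _0000/_0001 channel suffixes are produced by the enumerate index.

    # Standalone illustration only -- example filenames are hypothetical.
    inputs = ["sub-01_hemi-L_T1w.nii.gz", "sub-01_hemi-L_T2w.nii.gz"]
    cmd = " && ".join(
        f"cp {img} tempimg/temp_{i:04d}.nii.gz" for i, img in enumerate(inputs)
    )
    print(cmd)
    # cp sub-01_hemi-L_T1w.nii.gz tempimg/temp_0000.nii.gz && cp sub-01_hemi-L_T2w.nii.gz tempimg/temp_0001.nii.gz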