From 0222393735b07b90350390dfea7c62ab143196c5 Mon Sep 17 00:00:00 2001 From: Ali Khan Date: Sun, 8 Dec 2024 22:35:10 -0500 Subject: [PATCH 1/2] fix for using T1T2w model run_inference was only accepting a single input image, this makes it accept a list of images too, and makes it so --force-nnunet-model T1T2w grabs both T1w and T2w as inputs. --- hippunfold/workflow/rules/nnunet.smk | 71 +++++++++++++++++----------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/hippunfold/workflow/rules/nnunet.smk b/hippunfold/workflow/rules/nnunet.smk index a86a6623..6bb452ff 100644 --- a/hippunfold/workflow/rules/nnunet.smk +++ b/hippunfold/workflow/rules/nnunet.smk @@ -2,32 +2,37 @@ import re def get_nnunet_input(wildcards): - if config["modality"] == "T2w": - nii = ( - bids( - root=work, - datatype="anat", - **config["subj_wildcards"], - suffix="T2w.nii.gz", - space="corobl", - desc="preproc", - hemi="{hemi}", - ), - ) + T1w_nii = bids( + root=work, + datatype="anat", + **config["subj_wildcards"], + suffix="T1w.nii.gz", + space="corobl", + desc="preproc", + hemi="{hemi}", + ) + T2w_nii = bids( + root=work, + datatype="anat", + **config["subj_wildcards"], + suffix="T2w.nii.gz", + space="corobl", + desc="preproc", + hemi="{hemi}", + ) + if ( + config["modality"] == "T1w" + or config["modality"] == "T2w" + and config["force_nnunet_model"] == "T1T2w" + ): + return (T1w_nii, T2w_nii) + + elif config["modality"] == "T2w": + return T2w_nii elif config["modality"] == "T1w": - nii = ( - bids( - root=work, - datatype="anat", - **config["subj_wildcards"], - suffix="T1w.nii.gz", - space="corobl", - desc="preproc", - hemi="{hemi}", - ), - ) + return T1w_nii elif config["modality"] == "hippb500": - nii = bids( + return bids( root=work, datatype="dwi", hemi="{hemi}", @@ -37,7 +42,6 @@ def get_nnunet_input(wildcards): ) else: raise ValueError("modality not supported for nnunet!") - return nii def get_model_tar(): @@ -94,6 +98,19 @@ def parse_trainer_from_tar(wildcards, input): return trainer +def get_cmd_copy_inputs(wildcards, input): + in_img = input.in_img + if isinstance(in_img, str): + # we have one input image + return f"cp {in_img} tempimg/temp_0000.nii.gz" + else: + cmd = [] + # we have multiple input images + for i, img in enumerate(input.in_img): + cmd.append(f"cp {img} tempimg/temp_{i:04d}.nii.gz") + return " && ".join(cmd) + + rule run_inference: """ This rule uses either GPU or CPU . It also runs in an isolated folder (shadow), with symlinks to inputs in that folder, copying over outputs once complete, so temp files are not retained""" @@ -101,7 +118,7 @@ rule run_inference: in_img=get_nnunet_input, model_tar=get_model_tar(), params: - temp_img="tempimg/temp_0000.nii.gz", + cmd_copy_inputs=get_cmd_copy_inputs, temp_lbl="templbl/temp.nii.gz", model_dir="tempmodel", in_folder="tempimg", @@ -148,7 +165,7 @@ rule run_inference: # run inference #copy from temp output folder to final output "mkdir -p {params.model_dir} {params.in_folder} {params.out_folder} && " - "cp {input.in_img} {params.temp_img} && " + "{params.cmd_copy_inputs} && " "tar -xf {input.model_tar} -C {params.model_dir} && " "export RESULTS_FOLDER={params.model_dir} && " "export nnUNet_n_proc_DA={threads} && " From 2d0b10cf34164ca4db530986e154261b4bb1a2bc Mon Sep 17 00:00:00 2001 From: Ali Khan Date: Sun, 8 Dec 2024 22:48:03 -0500 Subject: [PATCH 2/2] fix logical error --- hippunfold/workflow/rules/nnunet.smk | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/hippunfold/workflow/rules/nnunet.smk b/hippunfold/workflow/rules/nnunet.smk index 6bb452ff..ab21a4b0 100644 --- a/hippunfold/workflow/rules/nnunet.smk +++ b/hippunfold/workflow/rules/nnunet.smk @@ -20,11 +20,9 @@ def get_nnunet_input(wildcards): desc="preproc", hemi="{hemi}", ) - if ( - config["modality"] == "T1w" - or config["modality"] == "T2w" - and config["force_nnunet_model"] == "T1T2w" - ): + if (config["modality"] == "T1w" or config["modality"] == "T2w") and config[ + "force_nnunet_model" + ] == "T1T2w": return (T1w_nii, T2w_nii) elif config["modality"] == "T2w":