Skip to content

Commit

Permalink
Merge pull request #320 from khanlab/bugfix-T1T2w
Browse files Browse the repository at this point in the history
FIX: input images when using T1T2w model
  • Loading branch information
akhanf authored Dec 9, 2024
2 parents 8230c03 + 2d0b10c commit 95f9f33
Showing 1 changed file with 42 additions and 27 deletions.
69 changes: 42 additions & 27 deletions hippunfold/workflow/rules/nnunet.smk
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,35 @@ import re


def get_nnunet_input(wildcards):
if config["modality"] == "T2w":
nii = (
bids(
root=work,
datatype="anat",
**config["subj_wildcards"],
suffix="T2w.nii.gz",
space="corobl",
desc="preproc",
hemi="{hemi}",
),
)
T1w_nii = bids(
root=work,
datatype="anat",
**config["subj_wildcards"],
suffix="T1w.nii.gz",
space="corobl",
desc="preproc",
hemi="{hemi}",
)
T2w_nii = bids(
root=work,
datatype="anat",
**config["subj_wildcards"],
suffix="T2w.nii.gz",
space="corobl",
desc="preproc",
hemi="{hemi}",
)
if (config["modality"] == "T1w" or config["modality"] == "T2w") and config[
"force_nnunet_model"
] == "T1T2w":
return (T1w_nii, T2w_nii)

elif config["modality"] == "T2w":
return T2w_nii
elif config["modality"] == "T1w":
nii = (
bids(
root=work,
datatype="anat",
**config["subj_wildcards"],
suffix="T1w.nii.gz",
space="corobl",
desc="preproc",
hemi="{hemi}",
),
)
return T1w_nii
elif config["modality"] == "hippb500":
nii = bids(
return bids(
root=work,
datatype="dwi",
hemi="{hemi}",
Expand All @@ -37,7 +40,6 @@ def get_nnunet_input(wildcards):
)
else:
raise ValueError("modality not supported for nnunet!")
return nii


def get_model_tar():
Expand Down Expand Up @@ -94,14 +96,27 @@ def parse_trainer_from_tar(wildcards, input):
return trainer


def get_cmd_copy_inputs(wildcards, input):
in_img = input.in_img
if isinstance(in_img, str):
# we have one input image
return f"cp {in_img} tempimg/temp_0000.nii.gz"
else:
cmd = []
# we have multiple input images
for i, img in enumerate(input.in_img):
cmd.append(f"cp {img} tempimg/temp_{i:04d}.nii.gz")
return " && ".join(cmd)


rule run_inference:
""" This rule uses either GPU or CPU .
It also runs in an isolated folder (shadow), with symlinks to inputs in that folder, copying over outputs once complete, so temp files are not retained"""
input:
in_img=get_nnunet_input,
model_tar=get_model_tar(),
params:
temp_img="tempimg/temp_0000.nii.gz",
cmd_copy_inputs=get_cmd_copy_inputs,
temp_lbl="templbl/temp.nii.gz",
model_dir="tempmodel",
in_folder="tempimg",
Expand Down Expand Up @@ -148,7 +163,7 @@ rule run_inference:
# run inference
#copy from temp output folder to final output
"mkdir -p {params.model_dir} {params.in_folder} {params.out_folder} && "
"cp {input.in_img} {params.temp_img} && "
"{params.cmd_copy_inputs} && "
"tar -xf {input.model_tar} -C {params.model_dir} && "
"export RESULTS_FOLDER={params.model_dir} && "
"export nnUNet_n_proc_DA={threads} && "
Expand Down

0 comments on commit 95f9f33

Please sign in to comment.