Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions experiments/LongitudinalRegistration/1-initial_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@
src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii")
segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio")

use_mask_list = [False]
use_labelmap_list = [False]
use_mask_list = [False, False]
use_labelmap_list = [False, True]

# ICON only
use_mass_list = [False]
use_mass_list = [False, False]

methods_list = [
["Greedy"],
["Greedy"],
]
number_of_iterations_ANTS_list = [
[40, 20, 10],
[40, 20, 10],
]
number_of_iterations_greedy_list = [
[60, 20, 10],
]
number_of_iterations_ICON_list = [100]
number_of_iterations_greedy_list = [[80, 40, 5], [40, 40, 10]]
number_of_iterations_ICON_list = [100, 100]

exclude_tokens = ["nop"]
ref_suffix = "_ref"
Expand Down
114 changes: 46 additions & 68 deletions experiments/LongitudinalRegistration/2-finetune_icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
# Frames whose labelmap is missing on disk are dropped from the dataset.
#
# In addition to the original ``gated_nii`` frames, each patient's training
# group is augmented with that patient's ANTS- and Greedy-warped frames
# written by ``1-initial_registration.py`` (warped image + labelmap per gated
# frame, under ``output_dir / <method> / <patient_id>``). Because the warped
# group is augmented with that patient's ANTS- and Greedy-init frames
# written by ``1-initial_registration.py`` (init image + labelmap per gated
# frame, under ``output_dir / <method> / <patient_id>``). Because the init
# frames are merged into the *same* ``subject_id`` group, uniGradICON pairs the
# original gated frames and both backends' pre-registered frames together.

Expand All @@ -28,8 +28,6 @@
from pathlib import Path
from typing import Optional

import itk

from physiomotion4d import WorkflowFineTuneICONRegistration
from physiomotion4d.labelmap_tools import LabelmapTools

Expand All @@ -49,12 +47,14 @@
fine_tune_name = "icon_finetuning"

# Pre-registration augmentation: ``1-initial_registration.py`` warps every gated
# moving frame into reference space with these backends and writes the warped
# moving frame into reference space with these backends and writes the init
# image + labelmap under ``initial_registration_dir / <method>.lower() /
# <patient_id>``. Those warped frames are merged into each patient's training
# <patient_id>``. Those init frames are merged into each patient's training
# group below (section 4b).
initial_registration_dir = output_dir
initial_registration_methods = ["Greedy"]
initial_registration_dirs = [
Path("d:/PhysioMotion4D/duke_data/greedy_registrations/results_l/greedy_40.40.10"),
Path("d:/PhysioMotion4D/duke_data/greedy_registrations/results_raw/greedy_80.40.5"),
]

# Fixed train/test split: sort patients in ``ref_data_dir`` by filename;
# first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies
Expand Down Expand Up @@ -120,49 +120,21 @@


# %%
def load_or_derive_mask(labelmap_path: Path) -> Optional[Path]:
"""Create (or reuse) a loss-function mask next to ``labelmap_path``.

Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm
via :meth:`LabelmapTools.convert_labelmap_to_mask`, writing the result as
``<labelmap_stem>_mask.nii.gz`` in the labelmap's own directory. Handles
both ``.nii.gz`` (original Simpleware labelmaps) and ``.mha``
(pre-registration warped labelmaps). Returns the mask path; existing
masks on disk are reused unmodified.
"""
if not labelmap_path.exists():
return None

name = labelmap_path.name
if name.endswith(".nii.gz"):
stem = name[:-7]
elif name.endswith(".mha"):
stem = name[:-4]
else:
stem = labelmap_path.stem
mask_p = labelmap_path.parent / f"{stem}_mask.nii.gz"
if not mask_p.exists():
mask = labelmap_tools.convert_labelmap_to_mask(
itk.imread(str(labelmap_path)), dilation_in_mm=mask_dilation_mm
)
itk.imwrite(mask, str(mask_p), compression=True)
return mask_p


# %%
def gather_warped_frames(
method_dir: Path,
def gather_init_frames(
initial_registration_dir: Path,
patient_id: str,
) -> tuple[list[Path], list[Optional[Path]], list[Optional[Path]]]:
"""Return ``(warped_image_paths, warped_labelmap_paths, warped_mask_paths)`` for one
``initial_registration_dir / <method> / <patient_id>`` directory.
"""Return ``(init_image_paths, init_labelmap_paths, init_mask_paths)`` for one
Comment on lines +123 to +127
``initial_registration_dir / <patient_id>`` directory.

Enumerates the warped moving images (``<stem>.mha``), excluding the
Enumerates the init moving images (``<stem>.mha``), excluding the
``_labelmap.mha`` and ``_mask.mha``
companions, and pairs each with its ``<stem>_labelmap.mha`` (``None`` when
that labelmap is absent). Returns empty lists when ``method_dir`` does
not exist.
that labelmap is absent). Returns empty lists when
``initial_registration_dir`` does not exist.
"""
if not method_dir.is_dir():
if not initial_registration_dir.is_dir():
print(f" {patient_id}: registration dir {initial_registration_dir} not found")
return [], [], []
Comment on lines +123 to 138
companion_suffixes = (
"_labelmap.mha",
Expand All @@ -171,15 +143,23 @@ def gather_warped_frames(
image_paths: list[Path] = []
labelmap_paths: list[Optional[Path]] = []
mask_paths: list[Optional[Path]] = []
for image in sorted(method_dir.glob("*.mha")):
for image in sorted(initial_registration_dir.glob("*.mha")):
if image.name.endswith(companion_suffixes):
continue
stem = image.name[:-4]
labelmap = method_dir / f"{stem}_labelmap.mha"
mask = method_dir / f"{stem}_mask.mha"
labelmap = initial_registration_dir / f"{stem}_labelmap.mha"
mask = initial_registration_dir / f"{stem}_mask.mha"
if not image.exists() or not labelmap.exists() or not mask.exists():
print(
f" {patient_id}: image {image} or labelmap {labelmap} or mask {mask} not found"
)
continue
image_paths.append(image)
labelmap_paths.append(labelmap if labelmap.exists() else None)
mask_paths.append(mask if mask.exists() else None)
labelmap_paths.append(labelmap)
mask_paths.append(mask)
print(
f" {patient_id}: {len(image_paths)} init frames, {len(labelmap_paths)} with labelmap, {len(mask_paths)} with mask"
)
return image_paths, labelmap_paths, mask_paths


Expand All @@ -205,7 +185,7 @@ def gather_warped_frames(
for f in frame_names:
labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz")
labelmap_paths.append(labelmap if labelmap.exists() else None)
mask = load_or_derive_mask(labelmap)
mask = seg_dir / f.replace(".nii.gz", "_mask.nii.gz")
mask_paths.append(mask)
Comment on lines +188 to 189

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Mask path appended without existence check, unlike labelmap handling.

Line 186 correctly appends None when the labelmap doesn't exist, but line 188 unconditionally appends the mask Path even if the file is missing. When mask_dilation_mm=0 (line 236), the workflow won't derive masks and will fail if it tries to read a non-existent mask file.

         labelmap_paths.append(labelmap if labelmap.exists() else None)
         mask = seg_dir / f.replace(".nii.gz", "_mask.nii.gz")
-        mask_paths.append(mask)
+        mask_paths.append(mask if mask.exists() else None)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
mask = seg_dir / f.replace(".nii.gz", "_mask.nii.gz")
mask_paths.append(mask)
mask = seg_dir / f.replace(".nii.gz", "_mask.nii.gz")
mask_paths.append(mask if mask.exists() else None)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@experiments/LongitudinalRegistration/2-finetune_icon.py` around lines 187 -
188, The mask Path is appended unconditionally which can lead to attempts to
read missing mask files; modify the code that builds mask_paths to mirror the
labelmap handling: construct mask = seg_dir / f.replace(".nii.gz",
"_mask.nii.gz"), check mask.exists(), and append mask if present or append None
when missing (especially important when mask_dilation_mm == 0) so downstream
code can skip/derive masks safely; update any downstream expectations of
mask_paths to accept None just like the labelmap list.


train_image_files.append(image_paths)
Expand All @@ -218,23 +198,21 @@ def gather_warped_frames(


for subject_index, patient_id in enumerate(valid_train_subjects):
for method_name in initial_registration_methods:
method_dir = initial_registration_dir / method_name.lower() / patient_id
warped_images, warped_labelmaps, warped_masks = gather_warped_frames(method_dir)
if not warped_images:
for initial_registration_dir in initial_registration_dirs:
initial_registration_patient_dir = initial_registration_dir / patient_id

init_images, init_labelmaps, init_masks = gather_init_frames(
initial_registration_patient_dir, patient_id
)
if not init_images:
print(
f" {patient_id}/{method_name}: no initial-registered frames "
f"in {method_dir}"
f" {patient_id}: no initial-registered frames "
f"in {initial_registration_dir}"
)
continue
train_image_files[subject_index].extend(warped_images)
train_labelmap_files[subject_index].extend(warped_labelmaps)
train_mask_files[subject_index].extend(warped_masks)
n_warped = sum(1 for labelmap in warped_labelmaps if labelmap is not None)
print(
f" {patient_id}/{method_name}: +{len(warped_images)} warped frames, "
f"{n_warped} with labelmap"
)
train_image_files[subject_index].extend(init_images)
train_labelmap_files[subject_index].extend(init_labelmaps)
train_mask_files[subject_index].extend(init_masks)

# %%
workflow = WorkflowFineTuneICONRegistration(
Expand All @@ -256,7 +234,7 @@ def gather_warped_frames(
[str(mask_path) if mask_path is not None else None for mask_path in mask_paths]
for mask_paths in train_mask_files
],
mask_dilation_mm=mask_dilation_mm,
mask_dilation_mm=0, # masks are already dilated
unigradicon_src_path=unigradicon_src_path,
epochs=500,
)
Expand Down
Loading
Loading