From a0c1635a95ec00dc5057a2877d365da6c2f6f2f5 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Wed, 10 Jun 2026 07:21:48 -0400 Subject: [PATCH 1/4] ENH: Replace ANTS with Greedy registration; standardize Greedy+ICON for time-series MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove RegisterImagesANTS from RegisterTimeSeriesImages and all dependent workflows/CLIs, replacing every ANTS_ICON and greedy_ICON path with the unified Greedy_ICON pipeline. Normalize method-name casing throughout (greedy → Greedy, greedy_ICON → Greedy_ICON). In the longitudinal registration experiment scripts: - 2-finetune_icon.py: support multiple pre-registration directories, require masks already on disk (drop load_or_derive_mask), rename "warped" to "init" throughout. - 3-recon_4d_icon_eval.py → 3-eval_icon.py: add warped-reference re-segmentation via SegmentHeartSimpleware and a second landmark-error table comparing warped-reference landmarks against time-point landmarks. - Add composite_time_series_mid_slice.py: stacks middle Z slices from a directory of MHA images into a composite volume for visual QA, with adjacent-slice RMSE reporting. --- .../0-cardiacGatedCT_segment_and_landmark.py | 2 +- .../1-initial_registration.py | 14 +- .../2-finetune_icon.py | 113 +++++------ ...3-recon_4d_icon_eval.py => 3-eval_icon.py} | 172 ++++++++++++++-- .../composite_time_series_mid_slice.py | 159 +++++++++++++++ .../cli/reconstruct_highres_4d_ct.py | 22 +- .../register_time_series_images.py | 164 +++------------ .../workflow_convert_image_to_usd.py | 20 +- .../workflow_fine_tune_icon_registration.py | 188 +++++++++++------- ...rkflow_fit_statistical_model_to_patient.py | 16 +- .../workflow_reconstruct_highres_4d_ct.py | 40 ++-- .../tutorial_06_reconstruct_highres_4d_ct.py | 8 +- 12 files changed, 564 insertions(+), 354 deletions(-) rename experiments/LongitudinalRegistration/{3-recon_4d_icon_eval.py => 3-eval_icon.py} (61%) create mode 100644 experiments/LongitudinalRegistration/composite_time_series_mid_slice.py diff --git a/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py b/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py index 834df73..a45098f 100644 --- a/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py +++ b/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py @@ -32,7 +32,7 @@ ref_data_dir = "d:/PhysioMotion4D/duke_data/ref_images" src_data_dir_base = "d:/PhysioMotion4D/duke_data/gated_nii" -segmentation_dir_base = "d:/PhysioMotion4D/duke_data/simple_ascardio" +segmentation_dir_base = "d:/PhysioMotion4D/duke_data/ascardio" ref_files = [ os.path.join(ref_data_dir, f) diff --git a/experiments/LongitudinalRegistration/1-initial_registration.py b/experiments/LongitudinalRegistration/1-initial_registration.py index 8775814..c718cf3 100644 --- a/experiments/LongitudinalRegistration/1-initial_registration.py +++ b/experiments/LongitudinalRegistration/1-initial_registration.py @@ -27,22 +27,22 @@ src_data_dir_base = Path("d:/PhysioMotion4D/duke_data/gated_nii") segmentation_dir_base = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") -use_mask_list = [False] -use_labelmap_list = [False] +use_mask_list = [False, False] +use_labelmap_list = [False, True] # ICON only -use_mass_list = [False] +use_mass_list = [False, False] methods_list = [ ["Greedy"], + ["Greedy"], ] number_of_iterations_ANTS_list = [ [40, 20, 10], + [40, 20, 10], ] -number_of_iterations_greedy_list = [ - [60, 20, 10], -] -number_of_iterations_ICON_list = [100] +number_of_iterations_greedy_list = [[80, 40, 5], [40, 40, 10]] +number_of_iterations_ICON_list = [100, 100] exclude_tokens = ["nop"] ref_suffix = "_ref" diff --git a/experiments/LongitudinalRegistration/2-finetune_icon.py b/experiments/LongitudinalRegistration/2-finetune_icon.py index 05e55db..7cc3900 100644 --- a/experiments/LongitudinalRegistration/2-finetune_icon.py +++ b/experiments/LongitudinalRegistration/2-finetune_icon.py @@ -17,9 +17,9 @@ # Frames whose labelmap is missing on disk are dropped from the dataset. # # In addition to the original ``gated_nii`` frames, each patient's training -# group is augmented with that patient's ANTS- and Greedy-warped frames -# written by ``1-initial_registration.py`` (warped image + labelmap per gated -# frame, under ``output_dir / / ``). Because the warped +# group is augmented with that patient's ANTS- and Greedy-init frames +# written by ``1-initial_registration.py`` (init image + labelmap per gated +# frame, under ``output_dir / / ``). Because the init # frames are merged into the *same* ``subject_id`` group, uniGradICON pairs the # original gated frames and both backends' pre-registered frames together. @@ -28,8 +28,6 @@ from pathlib import Path from typing import Optional -import itk - from physiomotion4d import WorkflowFineTuneICONRegistration from physiomotion4d.labelmap_tools import LabelmapTools @@ -49,12 +47,14 @@ fine_tune_name = "icon_finetuning" # Pre-registration augmentation: ``1-initial_registration.py`` warps every gated -# moving frame into reference space with these backends and writes the warped +# moving frame into reference space with these backends and writes the init # image + labelmap under ``initial_registration_dir / .lower() / -# ``. Those warped frames are merged into each patient's training +# ``. Those init frames are merged into each patient's training # group below (section 4b). -initial_registration_dir = output_dir -initial_registration_methods = ["Greedy"] +initial_registration_dirs = [ + Path("d:/PhysioMotion4D/duke_data/greedy_registrations/results_l/greedy_40.40.10"), + Path("d:/PhysioMotion4D/duke_data/greedy_registrations/results_raw/greedy_80.40.5"), +] # Fixed train/test split: sort patients in ``ref_data_dir`` by filename; # first 80% are train, last 20% are test. ``2-recon_4d_icon_eval.py`` applies @@ -120,49 +120,20 @@ # %% -def load_or_derive_mask(labelmap_path: Path) -> Optional[Path]: - """Create (or reuse) a loss-function mask next to ``labelmap_path``. - - Thresholds the labelmap at ``>0`` and dilates by ``mask_dilation_mm`` mm - via :meth:`LabelmapTools.convert_labelmap_to_mask`, writing the result as - ``_mask.nii.gz`` in the labelmap's own directory. Handles - both ``.nii.gz`` (original Simpleware labelmaps) and ``.mha`` - (pre-registration warped labelmaps). Returns the mask path; existing - masks on disk are reused unmodified. - """ - if not labelmap_path.exists(): - return None - - name = labelmap_path.name - if name.endswith(".nii.gz"): - stem = name[:-7] - elif name.endswith(".mha"): - stem = name[:-4] - else: - stem = labelmap_path.stem - mask_p = labelmap_path.parent / f"{stem}_mask.nii.gz" - if not mask_p.exists(): - mask = labelmap_tools.convert_labelmap_to_mask( - itk.imread(str(labelmap_path)), dilation_in_mm=mask_dilation_mm - ) - itk.imwrite(mask, str(mask_p), compression=True) - return mask_p - - -# %% -def gather_warped_frames( - method_dir: Path, +def gather_init_frames( + initial_registration_dir: Path, ) -> tuple[list[Path], list[Optional[Path]], list[Optional[Path]]]: - """Return ``(warped_image_paths, warped_labelmap_paths, warped_mask_paths)`` for one - ``initial_registration_dir / / `` directory. + """Return ``(init_image_paths, init_labelmap_paths, init_mask_paths)`` for one + ``initial_registration_dir / `` directory. - Enumerates the warped moving images (``.mha``), excluding the + Enumerates the init moving images (``.mha``), excluding the ``_labelmap.mha`` and ``_mask.mha`` companions, and pairs each with its ``_labelmap.mha`` (``None`` when - that labelmap is absent). Returns empty lists when ``method_dir`` does - not exist. + that labelmap is absent). Returns empty lists when + ``initial_registration_dir`` does not exist. """ - if not method_dir.is_dir(): + if not initial_registration_dir.is_dir(): + print(f" {patient_id}: registration dir {initial_registration_dir} not found") return [], [], [] companion_suffixes = ( "_labelmap.mha", @@ -171,15 +142,23 @@ def gather_warped_frames( image_paths: list[Path] = [] labelmap_paths: list[Optional[Path]] = [] mask_paths: list[Optional[Path]] = [] - for image in sorted(method_dir.glob("*.mha")): + for image in sorted(initial_registration_dir.glob("*.mha")): if image.name.endswith(companion_suffixes): continue stem = image.name[:-4] - labelmap = method_dir / f"{stem}_labelmap.mha" - mask = method_dir / f"{stem}_mask.mha" + labelmap = initial_registration_dir / f"{stem}_labelmap.mha" + mask = initial_registration_dir / f"{stem}_mask.mha" + if not image.exists() or not labelmap.exists() or not mask.exists(): + print( + f" {patient_id}: image {image} or labelmap {labelmap} or mask {mask} not found" + ) + continue image_paths.append(image) - labelmap_paths.append(labelmap if labelmap.exists() else None) - mask_paths.append(mask if mask.exists() else None) + labelmap_paths.append(labelmap) + mask_paths.append(mask) + print( + f" {patient_id}: {len(image_paths)} init frames, {len(labelmap_paths)} with labelmap, {len(mask_paths)} with mask" + ) return image_paths, labelmap_paths, mask_paths @@ -205,7 +184,7 @@ def gather_warped_frames( for f in frame_names: labelmap = seg_dir / f.replace(".nii.gz", "_labelmap.nii.gz") labelmap_paths.append(labelmap if labelmap.exists() else None) - mask = load_or_derive_mask(labelmap) + mask = seg_dir / f.replace(".nii.gz", "_mask.nii.gz") mask_paths.append(mask) train_image_files.append(image_paths) @@ -218,23 +197,21 @@ def gather_warped_frames( for subject_index, patient_id in enumerate(valid_train_subjects): - for method_name in initial_registration_methods: - method_dir = initial_registration_dir / method_name.lower() / patient_id - warped_images, warped_labelmaps, warped_masks = gather_warped_frames(method_dir) - if not warped_images: + for initial_registration_dir in initial_registration_dirs: + initial_registration_patient_dir = initial_registration_dir / patient_id + + init_images, init_labelmaps, init_masks = gather_init_frames( + initial_registration_patient_dir + ) + if not init_images: print( - f" {patient_id}/{method_name}: no initial-registered frames " - f"in {method_dir}" + f" {patient_id}: no initial-registered frames " + f"in {initial_registration_dir}" ) continue - train_image_files[subject_index].extend(warped_images) - train_labelmap_files[subject_index].extend(warped_labelmaps) - train_mask_files[subject_index].extend(warped_masks) - n_warped = sum(1 for labelmap in warped_labelmaps if labelmap is not None) - print( - f" {patient_id}/{method_name}: +{len(warped_images)} warped frames, " - f"{n_warped} with labelmap" - ) + train_image_files[subject_index].extend(init_images) + train_labelmap_files[subject_index].extend(init_labelmaps) + train_mask_files[subject_index].extend(init_masks) # %% workflow = WorkflowFineTuneICONRegistration( @@ -256,7 +233,7 @@ def gather_warped_frames( [str(mask_path) if mask_path is not None else None for mask_path in mask_paths] for mask_paths in train_mask_files ], - mask_dilation_mm=mask_dilation_mm, + mask_dilation_mm=0, # masks are already dilated unigradicon_src_path=unigradicon_src_path, epochs=500, ) diff --git a/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py b/experiments/LongitudinalRegistration/3-eval_icon.py similarity index 61% rename from experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py rename to experiments/LongitudinalRegistration/3-eval_icon.py index 6b61734..b67e553 100644 --- a/experiments/LongitudinalRegistration/3-recon_4d_icon_eval.py +++ b/experiments/LongitudinalRegistration/3-eval_icon.py @@ -24,9 +24,10 @@ import itk import numpy as np -from physiomotion4d import RegisterTimeSeriesImages +from physiomotion4d import RegisterTimeSeriesImages, SegmentHeartSimpleware from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.landmark_tools import LandmarkTools +from physiomotion4d.transform_tools import TransformTools # %% [markdown] # ## 1. Hard-coded paths and configuration @@ -37,11 +38,12 @@ segmentation_base_dir = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") _HERE = Path(__file__).parent -output_dir = _HERE / "results" +output_dir = _HERE / "results_icon_eval" finetuned_weights_path = ( - output_dir - / "icon_finetuned" - / "icon_finetuned_model" + _HERE + / "results_finetuning" + / "icon_finetuning" + / "icon_finetuning_model-2" / "checkpoints" / "network_weights_final.trch" ) @@ -60,15 +62,11 @@ output_dir.mkdir(parents=True, exist_ok=True) detail_file = output_dir / "landmark_errors_by_point.csv" summary_file = output_dir / "registration_summary.csv" +warped_ref_detail_file = output_dir / "warped_ref_landmark_errors_by_point.csv" if detail_file.exists(): detail_file.unlink() - -# %% [markdown] -# ## 2. Derive the held-out test cohort -# -# The fixed split is: sort ``ref_data_dir`` by filename, take the *first* -# 80% of patients as train, the *last* 20% as test. ``1-finetune_icon.py`` -# applies the same rule so the two scripts agree without any cached record. +if warped_ref_detail_file.exists(): + warped_ref_detail_file.unlink() # %% ref_files = sorted( @@ -100,6 +98,9 @@ # %% landmark_tools = LandmarkTools() labelmap_tools = LabelmapTools() +transform_tools = TransformTools() +segmenter = SegmentHeartSimpleware() +segmenter.set_trim_branches(False) # %% [markdown] @@ -186,27 +187,21 @@ method_dir.mkdir(parents=True, exist_ok=True) for index, image_file in enumerate(image_files): - if index == reference_index: - continue timepoint = timepoints[index] - timepoint_dir = method_dir / timepoint - timepoint_dir.mkdir(parents=True, exist_ok=True) inverse_transform = result["inverse_transforms"][index] + itk.transformwrite( result["forward_transforms"][index], - str(timepoint_dir / "time_to_reference.hdf"), + str(method_dir / f"{subject_id}_g{timepoint}_forward_tfm.hdf"), compression=True, ) itk.transformwrite( inverse_transform, - str(timepoint_dir / "reference_to_time.hdf"), + str(method_dir / f"{subject_id}_g{timepoint}_inverse_tfm.hdf"), compression=True, ) - # inverse_transform follows the ITK resampler convention — it maps - # moving-grid points back to reference-grid points, which is what - # we need to warp time-point landmarks into reference space. timepoint_landmarks = moving_landmarks[index] shared = sorted(timepoint_landmarks.keys() & reference_landmarks.keys()) errors: list[tuple[str, float]] = [] @@ -244,6 +239,95 @@ } ) + # ------------------------------------------------------------------ + # Warp the reference image back onto each time-point's grid, re- + # segment with SegmentHeartSimpleware, and compare the resulting + # landmarks with the time-point's own precomputed landmarks. + # + # Per transform_conventions.rst: + # - warp fixed image -> moving grid => inverse_transform + + # TransformTools.transform_image (pull-back) + # - warp moving points -> fixed space => inverse_transform + + # .TransformPoint() (push-forward) + # Both use inverse_transform, but for opposite purposes. + # Here we use inverse_transform for image warping (row 3 of the + # table), placing the reference image in time-point space so + # Simpleware sees anatomy at the correct cardiac phase. + # + # Skip the reference frame — warping it to itself is trivial and + # its own landmarks are already the "ground truth" reference. + # ------------------------------------------------------------------ + if index == reference_index: + continue + + warped_ref = transform_tools.transform_image( + fixed_image, + inverse_transform, + moving_images[index], + interpolation_method="linear", + ) + itk.imwrite( + warped_ref, + method_dir / f"{subject_id}_g{timepoint}_warped_ref.mha", + compression=True, + ) + + seg_result = segmenter.segment(warped_ref, contrast_enhanced_study=False) + warped_ref_labelmap = seg_result["labelmap"] + warped_ref_landmarks = segmenter.get_landmarks() + + itk.imwrite( + warped_ref_labelmap, + str(method_dir / f"{subject_id}_g{timepoint}_warped_ref_labelmap.mha"), + compression=True, + ) + landmark_tools.write_landmarks_3dslicer( + warped_ref_landmarks, + str( + method_dir + / f"{subject_id}_g{timepoint}_warped_ref_landmarks.mrk.json" + ), + ) + + # Both warped_ref_landmarks and timepoint_landmarks are in the + # time-point (moving) image space — compare directly. + tp_landmarks = moving_landmarks[index] + shared_warp = sorted(warped_ref_landmarks.keys() & tp_landmarks.keys()) + warp_errors: list[tuple[str, float]] = [] + for name in shared_warp: + err = float( + np.linalg.norm( + np.asarray(warped_ref_landmarks[name], dtype=np.float64) + - np.asarray(tp_landmarks[name], dtype=np.float64) + ) + ) + warp_errors.append((name, err)) + print(f" Warped-ref landmark {name}: {err:.3f} mm") + + with warped_ref_detail_file.open("a", newline="", encoding="utf-8") as fh: + writer = csv.writer(fh) + if fh.tell() == 0: + writer.writerow( + [ + "subject_id", + "method", + "timepoint", + "name", + "error_mm", + ] + ) + for name, err in warp_errors: + writer.writerow([subject_id, method_name, timepoint, name, err]) + + warp_vals = np.asarray([e for _, e in warp_errors], dtype=np.float64) + if warp_vals.size: + print( + f" Warped-ref landmark errors ({timepoint}): " + f"mean={float(np.mean(warp_vals)):.3f} mm " + f"median={float(np.median(warp_vals)):.3f} mm " + f"max={float(np.max(warp_vals)):.3f} mm" + ) + # %% [markdown] # ## 5. Write the wide-form per-timepoint summary CSV @@ -291,3 +375,49 @@ f"{float(np.max(arr)):>12.3f}" ) print("=" * len(header)) + +# %% [markdown] +# ## 7. Per-method aggregate table: warped-reference landmark errors +# +# Compares landmarks extracted from the reference image warped back to each +# time-point's grid (via ``inverse_transform``) against that time-point's own +# precomputed landmarks. Both sets are in the moving (time-point) image space, +# so errors are Euclidean distances without any additional transform. + +# %% +if warped_ref_detail_file.exists(): + warp_groups: dict[str, list[float]] = {} + with warped_ref_detail_file.open(newline="", encoding="utf-8") as fh: + for row in csv.DictReader(fh): + warp_groups.setdefault(row["method"], []).append(float(row["error_mm"])) + + warp_header = ( + f"{'Method':<18}{'N':>8}{'Mean (mm)':>12}" + f"{'Median (mm)':>14}{'P95 (mm)':>12}{'Max (mm)':>12}" + ) + print() + print("=" * len(warp_header)) + print( + f"Warped-reference landmark error summary ({len(test_subjects)} test subjects)" + ) + print("=" * len(warp_header)) + print(warp_header) + print("-" * len(warp_header)) + for method_name, _ in methods: + arr = np.asarray(warp_groups.get(method_name, []), dtype=np.float64) + if arr.size == 0: + print(f"{method_name:<18}{0:>8}{'':>12}{'':>14}{'':>12}{'':>12}") + continue + print( + f"{method_name:<18}" + f"{arr.size:>8}" + f"{float(np.mean(arr)):>12.3f}" + f"{float(np.median(arr)):>14.3f}" + f"{float(np.percentile(arr, 95)):>12.3f}" + f"{float(np.max(arr)):>12.3f}" + ) + print("=" * len(warp_header)) +else: + print( + "No warped-reference landmark errors written (all frames were reference frames)." + ) diff --git a/experiments/LongitudinalRegistration/composite_time_series_mid_slice.py b/experiments/LongitudinalRegistration/composite_time_series_mid_slice.py new file mode 100644 index 0000000..f84dc8f --- /dev/null +++ b/experiments/LongitudinalRegistration/composite_time_series_mid_slice.py @@ -0,0 +1,159 @@ +"""Create a mid-slice composite volume from a directory of MHA images.""" + +import argparse +import re +import tkinter as tk +from pathlib import Path +from tkinter import filedialog +from typing import Optional + +import itk +import numpy as np + + +DEFAULT_IMAGE_REGEX = r"^pm00.*_init\.mha$" +OUTPUT_FILENAME = "composite.mha" + + +def select_directory() -> Optional[Path]: + """Open a directory chooser and return the selected directory.""" + root = tk.Tk() + root.withdraw() + root.update() + selected_dir = filedialog.askdirectory( + title="Select directory containing time-series MHA images" + ) + root.destroy() + if not selected_dir: + return None + return Path(selected_dir) + + +def find_image_files(input_dir: Path, image_regex: str) -> list[Path]: + """Return sorted image paths whose filename matches ``image_regex``.""" + pattern = re.compile(image_regex) + return sorted( + path + for path in input_dir.iterdir() + if path.is_file() and pattern.fullmatch(path.name) + ) + + +def extract_middle_slice(image_path: Path) -> tuple[np.ndarray, itk.Image]: + """Read ``image_path`` and return its middle Z slice and ITK image.""" + image = itk.imread(str(image_path)) + image_array = itk.array_from_image(image) + if image_array.ndim != 3: + raise ValueError( + f"Expected 3D image at {image_path}, got array shape {image_array.shape}" + ) + middle_slice_index = image_array.shape[0] // 2 + return image_array[middle_slice_index, :, :], image + + +def create_composite_volume(image_files: list[Path]) -> itk.Image: + """Stack middle Z slices from 3D ITK images into a composite volume.""" + if not image_files: + raise ValueError("No input images were provided") + + middle_slices: list[np.ndarray] = [] + first_image: Optional[itk.Image] = None + expected_shape: Optional[tuple[int, ...]] = None + for image_path in image_files: + middle_slice, image = extract_middle_slice(image_path) + if first_image is None: + first_image = image + expected_shape = middle_slice.shape + elif middle_slice.shape != expected_shape: + raise ValueError( + f"Middle slice shape mismatch for {image_path}: " + f"expected {expected_shape}, got {middle_slice.shape}" + ) + middle_slices.append(middle_slice) + + composite_array = np.stack(middle_slices, axis=0) + composite_image = itk.image_from_array(composite_array) + if first_image is None: + raise ValueError("No input images were read") + + input_spacing = first_image.GetSpacing() + composite_image.SetSpacing((float(input_spacing[0]), float(input_spacing[1]), 1.0)) + return composite_image + + +def adjacent_slice_rmse(composite_array: np.ndarray) -> list[float]: + """Return RMSE values between each adjacent slice in ``composite_array``.""" + if composite_array.ndim != 3: + raise ValueError( + f"Expected 3D composite array, got shape {composite_array.shape}" + ) + + rmse_values: list[float] = [] + for slice_index in range(composite_array.shape[0] - 1): + difference = composite_array[slice_index + 1].astype( + np.float64 + ) - composite_array[slice_index].astype(np.float64) + rmse_values.append(float(np.sqrt(np.mean(difference**2)))) + return rmse_values + + +def print_adjacent_slice_rmse(composite_array: np.ndarray) -> None: + """Print per-pair and total adjacent-slice RMSE values.""" + rmse_values = adjacent_slice_rmse(composite_array) + for slice_index, rmse_value in enumerate(rmse_values): + print(f"RMSE slice {slice_index} to {slice_index + 1}: {rmse_value:.6g}") + print(f"Total adjacent-slice RMSE: {sum(rmse_values):.6g}") + + +def write_composite(input_dir: Path, image_regex: str) -> Path: + """Create ``composite.mha`` in ``input_dir`` from matching image files.""" + if not input_dir.is_dir(): + raise NotADirectoryError(f"Input directory does not exist: {input_dir}") + + image_files = find_image_files(input_dir, image_regex) + if not image_files: + raise FileNotFoundError( + f"No image files in {input_dir} matched regex {image_regex!r}" + ) + + composite_image = create_composite_volume(image_files) + print_adjacent_slice_rmse(itk.array_from_image(composite_image)) + output_path = input_dir / OUTPUT_FILENAME + itk.imwrite(composite_image, str(output_path), compression=True) + return output_path + + +def main(argv: Optional[list[str]] = None) -> int: + """Run the mid-slice composite volume command.""" + parser = argparse.ArgumentParser( + description=( + "Create composite.mha from middle Z slices of images matching a regex." + ) + ) + parser.add_argument( + "directory", + nargs="?", + type=Path, + help="Directory containing input images. Opens a dialog when omitted.", + ) + parser.add_argument( + "--regex", + default=DEFAULT_IMAGE_REGEX, + help=f"Filename regex to match input images. Default: {DEFAULT_IMAGE_REGEX}", + ) + args = parser.parse_args(argv) + + input_dir = args.directory + if input_dir is None: + input_dir = select_directory() + if input_dir is None: + print("No directory selected") + return 1 + + output_path = write_composite(input_dir, args.regex) + print(f"Wrote {output_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/physiomotion4d/cli/reconstruct_highres_4d_ct.py b/src/physiomotion4d/cli/reconstruct_highres_4d_ct.py index 295de5e..13ee832 100644 --- a/src/physiomotion4d/cli/reconstruct_highres_4d_ct.py +++ b/src/physiomotion4d/cli/reconstruct_highres_4d_ct.py @@ -4,7 +4,7 @@ This script provides a CLI to reconstruct high-resolution 4D CT time series from lower-resolution time-series images and a single high-resolution reference image -using combined ANTS+ICON registration. +using combined Greedy+ICON registration. """ import argparse @@ -47,8 +47,8 @@ def main() -> int: %(prog)s \\ --time-series-images frame_*.mha \\ --fixed-image highres.mha \\ - --registration-method ANTS_ICON \\ - --ANTS-iterations 30 15 7 3 \\ + --registration-method Greedy_ICON \\ + --Greedy-iterations 30 15 7 3 \\ --ICON-iterations 20 \\ --output-dir ./results @@ -81,9 +81,9 @@ def main() -> int: # Registration configuration parser.add_argument( "--registration-method", - choices=["ANTS", "ICON", "ANTS_ICON"], - default="ANTS_ICON", - help="Registration method to use (default: ANTS_ICON)", + choices=["Greedy", "ICON", "Greedy_ICON"], + default="Greedy_ICON", + help="Registration method to use (default: Greedy_ICON)", ) parser.add_argument( "--reference-frame", @@ -106,10 +106,10 @@ def main() -> int: # Registration iterations parser.add_argument( - "--ANTS-iterations", + "--Greedy-iterations", nargs="+", type=int, - help="ANTs multi-resolution iterations (e.g., 30 15 7 3). Default: [30, 15, 7, 3]", + help="Greedy multi-resolution iterations (e.g., 30 15 7 3). Default: [30, 15, 7, 3]", ) parser.add_argument( "--ICON-iterations", @@ -291,10 +291,10 @@ def main() -> int: workflow.set_moving_masks(moving_masks) # Set number of iterations based on registration method and CLI arguments - if args.ants_iterations: - workflow.set_number_of_iterations_ANTS(args.ants_iterations) + if args.Greedy_iterations: + workflow.set_number_of_iterations_Greedy(args.Greedy_iterations) else: - workflow.set_number_of_iterations_ANTS([30, 15, 7, 3]) + workflow.set_number_of_iterations_Greedy([30, 15, 7, 3]) if args.icon_iterations: workflow.set_number_of_iterations_ICON(args.icon_iterations) diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py index 6ced80c..194cd63 100644 --- a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -1,8 +1,8 @@ """Time series image registration implementation. This module provides the RegisterTimeSeriesImages class for registering an ordered -sequence of images (time series) to a fixed image. It supports ANTs, Greedy, -ICON, and combined ANTs/Greedy initialization followed by ICON refinement. +sequence of images (time series) to a fixed image. It supports Greedy, ICON, and +combined Greedy initialization followed by ICON refinement. The class is particularly useful for 4D medical imaging applications such as cardiac CT where sequential frames need to be registered to a common frame. @@ -13,18 +13,15 @@ import itk -from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_base import RegisterImagesBase from physiomotion4d.register_images_greedy import RegisterImagesGreedy from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.transform_tools import TransformTools REGISTRATION_METHODS: list[str] = [ - "ANTS", - "greedy", + "Greedy", "ICON", - "ANTS_ICON", - "greedy_ICON", + "Greedy_ICON", ] @@ -32,8 +29,8 @@ class RegisterTimeSeriesImages(RegisterImagesBase): """Register a time series of images to a fixed image. This class extends RegisterImagesBase to provide sequential registration of - multiple images (time series) to a fixed image. It supports ANTs, Greedy, - ICON, and combined ANTs/Greedy initialization followed by ICON refinement. + multiple images (time series) to a fixed image. It supports Greedy, ICON, and + combined Greedy initialization followed by ICON refinement. It can propagate information from prior registrations to initialize subsequent ones. @@ -46,26 +43,25 @@ class RegisterTimeSeriesImages(RegisterImagesBase): Key features: - Sequential registration of ordered image lists - - Support for ANTs, Greedy, ICON, ANTs+ICON, and Greedy+ICON backends + - Support for Greedy, ICON, and Greedy+ICON backends - Optional use of prior transforms to initialize next registration - Configurable starting point in the time series - Returns all transforms and loss values for the entire series Attributes: - registration_method_name (str): Registration method in use ('ANTS', - 'greedy', 'ICON', 'ANTS_ICON', or 'greedy_ICON'). - registrar_ANTS (RegisterImagesANTS): Internal ANTs registrar. + registration_method_name (str): Registration method in use ('Greedy', + 'ICON', or 'Greedy_ICON'). registrar_greedy (RegisterImagesGreedy): Internal Greedy registrar. registrar_ICON (RegisterImagesICON): Internal ICON registrar (also used - as the refinement stage for 'ANTS_ICON' and 'greedy_ICON'). + as the refinement stage for 'Greedy_ICON'). transform_tools (TransformTools): Utility for transform operations. Example: >>> # Register a cardiac CT time series - >>> registrar = RegisterTimeSeriesImages(registration_method='ANTS') + >>> registrar = RegisterTimeSeriesImages(registration_method='Greedy') >>> registrar.set_modality('ct') >>> registrar.set_fixed_image(fixed_image) - >>> registrar.set_number_of_iterations_ANTS([40, 20, 10]) + >>> registrar.set_number_of_iterations_greedy([40, 20, 10]) >>> >>> # Register all time points to fixed image >>> result = registrar.register_time_series( @@ -88,14 +84,13 @@ class RegisterTimeSeriesImages(RegisterImagesBase): """ def __init__( - self, registration_method: str = "ANTS", log_level: int | str = logging.INFO + self, registration_method: str = "Greedy", log_level: int | str = logging.INFO ) -> None: """Initialize the time series image registration class. Args: registration_method (str): Registration method to use. - Options: 'ANTS', 'greedy', 'ICON', 'ANTS_ICON', or - 'greedy_ICON'. Default: 'ANTS' + Options: 'Greedy', 'ICON', or 'Greedy_ICON'. Default: 'Greedy' log_level: Logging level (default: logging.INFO) Raises: @@ -105,12 +100,10 @@ def __init__( self.registration_method_name: str = registration_method - self.registrar_ANTS = RegisterImagesANTS(log_level=log_level) self.registrar_greedy = RegisterImagesGreedy(log_level=log_level) self.registrar_ICON = RegisterImagesICON(log_level=log_level) # Set default iterations based on registration method - self.number_of_iterations_ANTS: list[int] = [40, 20, 10] self.number_of_iterations_greedy: list[int] = [40, 20, 10] self.number_of_iterations_ICON: int = 50 @@ -124,17 +117,6 @@ def __init__( self.smooth_prior_transform_sigma: float = 0.5 - def set_number_of_iterations_ANTS( - self, number_of_iterations_ANTS: list[int] - ) -> None: - """Set the number of iterations for ANTs registration. - - Args: - number_of_iterations_ANTS: List of iterations for ANTs multi-resolution - (e.g., [40, 20, 10] for three resolution levels) - """ - self.number_of_iterations_ANTS = number_of_iterations_ANTS - def set_number_of_iterations_ICON(self, number_of_iterations_ICON: int) -> None: """Set the number of iterations for ICON registration. @@ -282,10 +264,10 @@ def register_time_series( calling this method. Example: - >>> registrar = RegisterTimeSeriesImages(registration_method='ANTS') + >>> registrar = RegisterTimeSeriesImages(registration_method='Greedy') >>> registrar.set_fixed_image(fixed_image) >>> registrar.set_fixed_mask(fixed_mask) # Optional - >>> registrar.set_number_of_iterations_ANTS([30, 15, 5]) + >>> registrar.set_number_of_iterations_greedy([30, 15, 5]) >>> >>> # Use new intuitive parameter names >>> result = registrar.register_time_series( @@ -309,14 +291,7 @@ def register_time_series( if self.fixed_image is None: raise ValueError("Fixed image must be set before registering time series") - if self.registration_method_name == "ANTS": - self.registrar_ANTS.set_fixed_image(self.fixed_image) - self.registrar_ANTS.set_modality(self.modality) - self.registrar_ANTS.set_mask_dilation(self.mask_dilation_mm) - self.registrar_ANTS.set_number_of_iterations(self.number_of_iterations_ANTS) - self.registrar_ANTS.set_fixed_mask(self.fixed_mask) - self.registrar_ANTS.set_fixed_labelmap(self.fixed_labelmap) - elif self.registration_method_name == "greedy": + if self.registration_method_name in ["Greedy", "Greedy_ICON"]: self.registrar_greedy.set_fixed_image(self.fixed_image) self.registrar_greedy.set_modality(self.modality) self.registrar_greedy.set_mask_dilation(self.mask_dilation_mm) @@ -325,32 +300,7 @@ def register_time_series( ) self.registrar_greedy.set_fixed_mask(self.fixed_mask) self.registrar_greedy.set_fixed_labelmap(self.fixed_labelmap) - elif self.registration_method_name == "ICON": - self.registrar_ICON.set_fixed_image(self.fixed_image) - self.registrar_ICON.set_modality(self.modality) - self.registrar_ICON.set_mask_dilation(self.mask_dilation_mm) - self.registrar_ICON.set_number_of_iterations(self.number_of_iterations_ICON) - self.registrar_ICON.set_fixed_mask(self.fixed_mask) - self.registrar_ICON.set_fixed_labelmap(self.fixed_labelmap) - elif self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: - if self.registration_method_name == "ANTS_ICON": - self.registrar_ANTS.set_fixed_image(self.fixed_image) - self.registrar_ANTS.set_modality(self.modality) - self.registrar_ANTS.set_mask_dilation(self.mask_dilation_mm) - self.registrar_ANTS.set_number_of_iterations( - self.number_of_iterations_ANTS - ) - self.registrar_ANTS.set_fixed_mask(self.fixed_mask) - self.registrar_ANTS.set_fixed_labelmap(self.fixed_labelmap) - else: - self.registrar_greedy.set_fixed_image(self.fixed_image) - self.registrar_greedy.set_modality(self.modality) - self.registrar_greedy.set_mask_dilation(self.mask_dilation_mm) - self.registrar_greedy.set_number_of_iterations( - self.number_of_iterations_greedy - ) - self.registrar_greedy.set_fixed_mask(self.fixed_mask) - self.registrar_greedy.set_fixed_labelmap(self.fixed_labelmap) + if self.registration_method_name in ["ICON", "Greedy_ICON"]: self.registrar_ICON.set_fixed_image(self.fixed_image) self.registrar_ICON.set_modality(self.modality) self.registrar_ICON.set_mask_dilation(self.mask_dilation_mm) @@ -403,13 +353,7 @@ def register_time_series( if moving_labelmaps is not None else None ) - if self.registration_method_name == "ANTS": - result = self.registrar_ANTS.register( - moving_images[reference_frame], - moving_mask=reference_mask, - moving_labelmap=reference_labelmap, - ) - elif self.registration_method_name == "greedy": + if self.registration_method_name == "Greedy": result = self.registrar_greedy.register( moving_images[reference_frame], moving_mask=reference_mask, @@ -421,13 +365,8 @@ def register_time_series( moving_mask=reference_mask, moving_labelmap=reference_labelmap, ) - elif self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: - registrar_initial = ( - self.registrar_ANTS - if self.registration_method_name == "ANTS_ICON" - else self.registrar_greedy - ) - result = registrar_initial.register( + elif self.registration_method_name == "Greedy_ICON": + result = self.registrar_greedy.register( moving_images[reference_frame], moving_mask=reference_mask, moving_labelmap=reference_labelmap, @@ -489,13 +428,7 @@ def register_time_series( ) # Try registration with identity initialization - if self.registration_method_name == "ANTS": - result_init_identity = self.registrar_ANTS.register( - moving_image=moving_image, - moving_mask=moving_mask, - moving_labelmap=moving_labelmap, - ) - elif self.registration_method_name == "greedy": + if self.registration_method_name == "Greedy": result_init_identity = self.registrar_greedy.register( moving_image=moving_image, moving_mask=moving_mask, @@ -507,13 +440,8 @@ def register_time_series( moving_mask=moving_mask, moving_labelmap=moving_labelmap, ) - elif self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: - registrar_initial = ( - self.registrar_ANTS - if self.registration_method_name == "ANTS_ICON" - else self.registrar_greedy - ) - result_init_identity = registrar_initial.register( + elif self.registration_method_name == "Greedy_ICON": + result_init_identity = self.registrar_greedy.register( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, @@ -536,14 +464,7 @@ def register_time_series( # Select best result based on prior usage if prior_weight > 0.0: # Try with prior transform initialization - if self.registration_method_name == "ANTS": - result_init_prior = self.registrar_ANTS.register( - moving_image=moving_image, - moving_mask=moving_mask, - moving_labelmap=moving_labelmap, - initial_forward_transform=prior_forward, - ) - elif self.registration_method_name == "greedy": + if self.registration_method_name == "Greedy": result_init_prior = self.registrar_greedy.register( moving_image=moving_image, moving_mask=moving_mask, @@ -557,13 +478,8 @@ def register_time_series( moving_labelmap=moving_labelmap, initial_forward_transform=prior_forward, ) - elif self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: - registrar_initial = ( - self.registrar_ANTS - if self.registration_method_name == "ANTS_ICON" - else self.registrar_greedy - ) - result_init_prior = registrar_initial.register( + elif self.registration_method_name == "Greedy_ICON": + result_init_prior = self.registrar_greedy.register( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, @@ -660,7 +576,7 @@ def reconstruct_time_series( ValueError: If lengths of moving_images and inverse_transforms don't match Example: - >>> registrar = RegisterTimeSeriesImages(registration_method='ANTS') + >>> registrar = RegisterTimeSeriesImages(registration_method='Greedy') >>> registrar.set_fixed_image(fixed_image) >>> >>> result = registrar.register_time_series( @@ -782,20 +698,7 @@ def registration_method( Returns: dict: Registration result with forward_transform, inverse_transform, and loss """ - if self.registration_method_name == "ANTS": - res = self.registrar_ANTS.registration_method( - moving_image=moving_image, - moving_mask=moving_mask, - moving_labelmap=moving_labelmap, - moving_image_pre=moving_image_pre, - initial_forward_transform=initial_forward_transform, - ) - return { - "forward_transform": cast(itk.Transform, res["forward_transform"]), - "inverse_transform": cast(itk.Transform, res["inverse_transform"]), - "loss": float(cast(float, res["loss"])), - } - if self.registration_method_name == "greedy": + if self.registration_method_name == "Greedy": res = self.registrar_greedy.registration_method( moving_image=moving_image, moving_mask=moving_mask, @@ -821,13 +724,8 @@ def registration_method( "inverse_transform": cast(itk.Transform, res["inverse_transform"]), "loss": float(cast(float, res["loss"])), } - if self.registration_method_name in ["ANTS_ICON", "greedy_ICON"]: - registrar_initial = ( - self.registrar_ANTS - if self.registration_method_name == "ANTS_ICON" - else self.registrar_greedy - ) - initial_res = registrar_initial.registration_method( + if self.registration_method_name == "Greedy_ICON": + initial_res = self.registrar_greedy.registration_method( moving_image=moving_image, moving_mask=moving_mask, moving_labelmap=moving_labelmap, diff --git a/src/physiomotion4d/workflow_convert_image_to_usd.py b/src/physiomotion4d/workflow_convert_image_to_usd.py index 640cf68..554eb69 100644 --- a/src/physiomotion4d/workflow_convert_image_to_usd.py +++ b/src/physiomotion4d/workflow_convert_image_to_usd.py @@ -19,8 +19,8 @@ from physiomotion4d.contour_tools import ContourTools from physiomotion4d.convert_image_4d_to_3d import ConvertImage4DTo3D from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.register_images_ants import RegisterImagesANTS from physiomotion4d.register_images_base import RegisterImagesBase +from physiomotion4d.register_images_greedy import RegisterImagesGreedy from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.segment_anatomy_base import SegmentAnatomyBase from physiomotion4d.segment_chest_total_segmentator import SegmentChestTotalSegmentator @@ -36,7 +36,7 @@ ) #: Supported registration backend identifiers. -REGISTRATION_METHODS: tuple[str, ...] = ("ANTS", "ICON") +REGISTRATION_METHODS: tuple[str, ...] = ("Greedy", "ICON") class WorkflowConvertImageToUSD(PhysioMotion4DBase): @@ -95,7 +95,7 @@ def __init__( or ``'HeartSimplewareTrimmedBranches'`` (HeartSimpleware with pulmonary/great-vessel branches trimmed to the cardiac region). registration_method (str): Registration method to use: - ``'ANTS'`` or ``'ICON'`` (default: ``'ICON'``). + ``'Greedy'`` or ``'ICON'`` (default: ``'ICON'``). times_per_second: Frames per second for animated USD time series. Defaults to 24.0, matching the underlying VTK-to-USD converter. log_level: Logging level (default: logging.INFO) @@ -159,23 +159,23 @@ def __init__( # Initialize registration method self.registrar: RegisterImagesBase - if self.registration_method == "ANTS": - self.log_info("Initializing ANTs registration...") - ants_registrar = RegisterImagesANTS(log_level=log_level) - ants_registrar.set_modality("ct") - ants_registrar.set_transform_type("Deformable") + if self.registration_method == "Greedy": + self.log_info("Initializing Greedy registration...") + greedy_registrar = RegisterImagesGreedy(log_level=log_level) + greedy_registrar.set_modality("ct") + greedy_registrar.set_transform_type("Deformable") if ( number_of_registration_iterations is not None and number_of_registration_iterations > 0 ): - ants_registrar.set_number_of_iterations( + greedy_registrar.set_number_of_iterations( [ number_of_registration_iterations, number_of_registration_iterations // 2, 0, ] ) - self.registrar = ants_registrar + self.registrar = greedy_registrar else: # ICON (default) self.log_info("Initializing ICON registration...") icon_registrar = RegisterImagesICON(log_level=log_level) diff --git a/src/physiomotion4d/workflow_fine_tune_icon_registration.py b/src/physiomotion4d/workflow_fine_tune_icon_registration.py index c16e9f9..c6ee539 100644 --- a/src/physiomotion4d/workflow_fine_tune_icon_registration.py +++ b/src/physiomotion4d/workflow_fine_tune_icon_registration.py @@ -5,7 +5,7 @@ from ``experiments/LongitudinalRegistration``: 1. **Fine-tuning**: build a paired dataset JSON and YAML config from per-subject - lists of image files (with optional segmentation labelmaps and landmark CSVs) + lists of image files (with optional labelmaps and landmark CSVs) and launch ``unigradicon.finetuning.finetune`` as a subprocess. Mirrors ``experiments/LongitudinalRegistration/1-finetune_icon.py``. 2. **Apply**: load a fine-tuned uniGradICON checkpoint and register a list of @@ -26,7 +26,7 @@ transform that maps moving-grid points back to reference-grid points; ``forward_transform`` is the inverse direction (reference grid → moving grid). Landmarks are warped using ``TransformPoint`` and - images/segmentations are resampled via + images/labelmaps are resampled via :meth:`TransformTools.transform_image`. """ @@ -57,7 +57,7 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): **Stage 1: Fine-tuning** (file-based) Build a paired dataset JSON and YAML config from per-subject lists of - image, segmentation, and landmark files, then launch + image, labelmap, and landmark files, then launch ``unigradicon.finetuning.finetune`` as a subprocess. Each subject's time-point images form one paired group (they share a ``subject_id``). @@ -65,8 +65,8 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): Register a list of moving images to a single reference image using the fine-tuned ICON weights and return both directions of the warp: - - moving images / segmentations / landmarks warped into reference space - - the reference image / segmentation / landmarks warped into each + - moving images / labelmaps / landmarks warped into reference space + - the reference image / labelmap / landmarks warped into each moving-image space Attributes: @@ -80,9 +80,9 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): identifiers). Written into the dataset JSON's ``subject_id`` field; falls back to synthetic ``subject_NNNN`` when ``None``. subject_labelmap_files (Optional[list[list[Optional[str]]]]): - Per-subject multi-label segmentation/labelmap paths aligned with + Per-subject multi-label labelmap paths aligned with ``subject_image_files``. ``None`` (or per-image ``None``) means no - segmentation for that image. If supplied for at least one image, + labelmap for that image. If supplied for at least one image, paired-with-seg training is enabled. subject_mask_files (Optional[list[list[Optional[str]]]]): Per-subject binary mask paths aligned with ``subject_image_files``. @@ -104,7 +104,7 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): registrar (RegisterTimeSeriesImages): ICON-backend registrar used in :meth:`apply_registration`. transform_tools (TransformTools): Utility for resampling images and - segmentations. + labelmaps. Example: >>> # Stage 1: fine-tune @@ -127,8 +127,8 @@ class WorkflowFineTuneICONRegistration(PhysioMotion4DBase): ... reference_image=ref_image, ... moving_images=moving_images, ... weights_path=weights_path, - ... reference_segmentation=ref_seg, - ... moving_segmentations=moving_segs, + ... reference_labelmap=ref_seg, + ... moving_labelmaps=moving_segs, ... ) >>> warped_to_ref = result['moving_to_reference_images'] >>> warped_to_moving = result['reference_to_moving_images'] @@ -143,7 +143,7 @@ def __init__( subject_labelmap_files: Optional[list[list[Optional[str]]]] = None, subject_mask_files: Optional[list[list[Optional[str]]]] = None, subject_landmark_files: Optional[list[list[Optional[str]]]] = None, - epochs: int = 2000, + epochs: int = 500, batch_size: int = 4, learning_rate: float = 5e-5, input_shape: tuple[int, int, int] = (175, 175, 175), @@ -262,7 +262,7 @@ def __init__( self.subject_mask_files = subject_mask_files self.subject_landmark_files = subject_landmark_files - self.use_segmentations: bool = subject_labelmap_files is not None + self.use_labelmaps: bool = subject_labelmap_files is not None self.use_masks: bool = ( subject_mask_files is not None or subject_labelmap_files is not None ) @@ -295,7 +295,7 @@ def __init__( self.labelmap_tools = LabelmapTools(log_level=log_level) self.registrar: Optional[RegisterTimeSeriesImages] = None - self._use_segmentations: bool = self.use_segmentations + self._use_labelmaps: bool = self.use_labelmaps self._use_masks: bool = self.use_masks self._dataset_json_path: Optional[Path] = None @@ -374,7 +374,7 @@ def _derive_mask( def prepare_dataset( self, - use_segmentations: Optional[bool] = None, + use_labelmaps: Optional[bool] = None, use_masks: Optional[bool] = None, ) -> Path: """Write the uniGradICON dataset JSON from the configured file lists. @@ -399,12 +399,12 @@ def prepare_dataset( """ self.experiment_dir.mkdir(parents=True, exist_ok=True) - if use_segmentations is None: - use_segmentations = self.use_segmentations + if use_labelmaps is None: + use_labelmaps = self.use_labelmaps if use_masks is None: use_masks = self.use_masks - self._use_segmentations = use_segmentations + self._use_labelmaps = use_labelmaps self._use_masks = use_masks dataset_entries: list[dict[str, str]] = [] @@ -415,7 +415,7 @@ def prepare_dataset( else f"subject_{subject_index:04d}" ) seg_list: list[Optional[str]] - if not use_segmentations: + if not use_labelmaps: seg_list = [None] * len(image_files) else: seg_list = ( @@ -450,7 +450,7 @@ def prepare_dataset( "subject_id": subject_id, } - if use_segmentations: + if use_labelmaps: if seg_file is None or not Path(seg_file).exists(): self.log_warning( "Skipping %s: segmentation missing for paired-with-seg " @@ -536,7 +536,7 @@ def prepare_config(self, dataset_json_path: Optional[Path] = None) -> Path: "dice_loss_weight": self.dice_loss_weight, "lncc_sigma": self.lncc_sigma, "loss_function_masking": self._use_masks, - "use_label": self._use_segmentations, + "use_label": False, "roi_masking": False, }, "datasets": [ @@ -637,11 +637,13 @@ def apply_registration( reference_image: itk.Image, moving_images: list[itk.Image], weights_path: Optional[Union[str, Path]] = None, - reference_segmentation: Optional[itk.Image] = None, + reference_labelmap: Optional[itk.Image] = None, + reference_mask: Optional[itk.Image] = None, reference_landmarks: Optional[Landmarks] = None, - moving_segmentations: Optional[list[Optional[itk.Image]]] = None, + moving_labelmaps: Optional[list[Optional[itk.Image]]] = None, + moving_masks: Optional[list[Optional[itk.Image]]] = None, moving_landmarks: Optional[list[Optional[Landmarks]]] = None, - number_of_iterations: int = 20, + number_of_iterations: int = 100, modality: str = "ct", ) -> dict[str, Any]: """Register each moving image to the reference using fine-tuned ICON weights. @@ -654,12 +656,12 @@ def apply_registration( ROI; the same is done for the reference segmentation (used as the fixed mask). - Warps the moving image, segmentation, and landmarks into reference - space using ``forward_transform``. Segmentations use nearest-neighbor + space using ``forward_transform``. Labelmaps use nearest-neighbor interpolation. Landmarks use ``inverse_transform.TransformPoint`` (resampler-convention transform: maps moving-grid points back to reference-grid points). - Warps the reference image, segmentation, and landmarks into each - moving-image space using ``inverse_transform`` for image/segmentation + moving-image space using ``inverse_transform`` for image/labelmap resampling and ``forward_transform.TransformPoint`` for landmarks. Args: @@ -669,15 +671,22 @@ def apply_registration( weights_path: Path to a uniGradICON checkpoint (e.g. ``Finetune_multi_final.trch``). ``None`` uses the default pretrained uniGradICON weights. - reference_segmentation: Optional multi-label labelmap aligned with + reference_labelmap: Optional multi-label labelmap aligned with ``reference_image``. Used to derive the fixed-image mask and returned warped into each moving-image space. + reference_mask: Optional binary mask aligned with ``reference_image``. + Used to derive the fixed-image mask and returned warped into each + moving-image space. reference_landmarks: Optional ``{name: (x, y, z)}`` landmark dict in LPS that will be warped into each moving-image space. - moving_segmentations: Optional per-moving multi-label labelmaps + moving_labelmaps: Optional per-moving multi-label labelmaps aligned with ``moving_images``. Used to derive per-moving masks and returned warped into reference space. Per-image ``None`` entries are allowed. + moving_masks: Optional per-moving binary mask paths aligned with + ``moving_images``. Used to derive per-moving masks and returned + warped into reference space. Per-image ``None`` entries are + allowed. moving_landmarks: Optional per-moving landmark dicts in LPS. Each set is warped into reference space. Per-image ``None`` entries are allowed. @@ -697,8 +706,8 @@ def apply_registration( - ``losses`` (``list[float]``): per-moving registration loss. - ``moving_to_reference_images`` (``list[itk.Image]``): each moving image resampled onto the reference grid. - - ``moving_to_reference_segmentations`` (``list[Optional[itk.Image]]``): - each moving segmentation resampled onto the reference grid + - ``moving_to_reference_labelmaps`` (``list[Optional[itk.Image]]``): + each moving labelmap resampled onto the reference grid with nearest-neighbor interpolation. ``None`` when the input was ``None``. - ``moving_to_reference_landmarks`` (``list[Optional[Landmarks]]``): @@ -706,27 +715,32 @@ def apply_registration( ``None`` when the input was ``None``. - ``reference_to_moving_images`` (``list[itk.Image]``): the reference image resampled onto each moving grid. - - ``reference_to_moving_segmentations`` (``list[Optional[itk.Image]]``): + - ``reference_to_moving_labelmaps`` (``list[Optional[itk.Image]]``): the reference segmentation resampled onto each moving grid with nearest-neighbor interpolation. ``None`` for every - entry when ``reference_segmentation`` was ``None``. + entry when ``reference_labelmap`` was ``None``. - ``reference_to_moving_landmarks`` (``list[Optional[Landmarks]]``): reference landmarks warped into each moving space. ``None`` for every entry when ``reference_landmarks`` was ``None``. Raises: ValueError: If ``moving_images`` is empty. - ValueError: If ``moving_segmentations`` or ``moving_landmarks`` is + ValueError: If ``moving_labelmaps`` or ``moving_landmarks`` is supplied with a length that does not match ``moving_images``. """ if not moving_images: raise ValueError("moving_images must not be empty") num_moving = len(moving_images) - if moving_segmentations is not None and len(moving_segmentations) != num_moving: + if moving_labelmaps is not None and len(moving_labelmaps) != num_moving: raise ValueError( - f"moving_segmentations length ({len(moving_segmentations)}) must " + f"moving_labelmaps length ({len(moving_labelmaps)}) must " f"match moving_images length ({num_moving})" ) + if moving_masks is not None and len(moving_masks) != num_moving: + raise ValueError( + f"moving_masks length ({len(moving_masks)}) must match " + f"moving_images length ({num_moving})" + ) if moving_landmarks is not None and len(moving_landmarks) != num_moving: raise ValueError( f"moving_landmarks length ({len(moving_landmarks)}) must match " @@ -740,40 +754,43 @@ def apply_registration( else: self.log_info("ICON weights: %s", weights_path) - fixed_mask = ( - self.labelmap_tools.convert_labelmap_to_mask( - reference_segmentation, dilation_in_mm=self.mask_dilation_mm + if reference_mask is None: + reference_mask = ( + self.labelmap_tools.convert_labelmap_to_mask( + reference_labelmap, dilation_in_mm=self.mask_dilation_mm + ) + if reference_labelmap is not None + else None ) - if reference_segmentation is not None - else None - ) - moving_masks: Optional[list[Optional[itk.Image]]] = None - if moving_segmentations is not None: - moving_masks = [ - ( - self.labelmap_tools.convert_labelmap_to_mask( - seg, dilation_in_mm=self.mask_dilation_mm + + if moving_masks is None: + if moving_labelmaps is not None: + moving_masks = [ + ( + self.labelmap_tools.convert_labelmap_to_mask( + labelmap, dilation_in_mm=self.mask_dilation_mm + ) + if labelmap is not None + else None ) - if seg is not None - else None - ) - for seg in moving_segmentations - ] + for labelmap in moving_labelmaps + ] self.registrar = RegisterTimeSeriesImages( registration_method="ICON", log_level=self.log_level ) self.registrar.set_modality(modality) self.registrar.set_fixed_image(reference_image) - self.registrar.set_fixed_mask(fixed_mask) + self.registrar.set_fixed_mask(reference_mask) self.registrar.set_number_of_iterations_ICON(number_of_iterations) if weights_path is not None: self.registrar.registrar_ICON.set_weights_path(str(weights_path)) + # TODO: set reference frame and register reference result = self.registrar.register_time_series( moving_images=moving_images, moving_masks=moving_masks, - moving_labelmaps=None, + moving_labelmaps=moving_labelmaps, reference_frame=0, register_reference=True, prior_weight=0.0, @@ -783,10 +800,12 @@ def apply_registration( losses = result["losses"] moving_to_reference_images: list[itk.Image] = [] - moving_to_reference_segmentations: list[Optional[itk.Image]] = [] + moving_to_reference_labelmaps: list[Optional[itk.Image]] = [] + moving_to_reference_masks: list[Optional[itk.Image]] = [] moving_to_reference_landmarks: list[Optional[Landmarks]] = [] reference_to_moving_images: list[itk.Image] = [] - reference_to_moving_segmentations: list[Optional[itk.Image]] = [] + reference_to_moving_labelmaps: list[Optional[itk.Image]] = [] + reference_to_moving_masks: list[Optional[itk.Image]] = [] reference_to_moving_landmarks: list[Optional[Landmarks]] = [] for index in range(num_moving): @@ -805,41 +824,64 @@ def apply_registration( ) ) - moving_seg = ( - moving_segmentations[index] - if moving_segmentations is not None - else None + moving_labelmap = ( + moving_labelmaps[index] if moving_labelmaps is not None else None ) - if moving_seg is not None: - moving_to_reference_segmentations.append( + if moving_labelmap is not None: + moving_to_reference_labelmaps.append( + self.transform_tools.transform_image( + moving_labelmap, + forward_tfm, + reference_image, + interpolation_method="nearest", + ) + ) + else: + moving_to_reference_labelmaps.append(None) + + if reference_labelmap is not None: + reference_to_moving_labelmaps.append( + self.transform_tools.transform_image( + reference_labelmap, + inverse_tfm, + moving_image, + interpolation_method="nearest", + ) + ) + else: + reference_to_moving_labelmaps.append(None) + + moving_mask = moving_masks[index] if moving_masks is not None else None + if moving_mask is not None: + moving_to_reference_masks.append( self.transform_tools.transform_image( - moving_seg, + moving_mask, forward_tfm, reference_image, interpolation_method="nearest", ) ) else: - moving_to_reference_segmentations.append(None) + moving_to_reference_masks.append(None) - if reference_segmentation is not None: - reference_to_moving_segmentations.append( + if reference_mask is not None: + reference_to_moving_masks.append( self.transform_tools.transform_image( - reference_segmentation, + reference_mask, inverse_tfm, moving_image, interpolation_method="nearest", ) ) else: - reference_to_moving_segmentations.append(None) + reference_to_moving_masks.append(None) - moving_lms = ( + moving_lndmrks = ( moving_landmarks[index] if moving_landmarks is not None else None ) - if moving_lms is not None: + if moving_lndmrks is not None: moving_to_reference_landmarks.append( - self._transform_landmarks(moving_lms, inverse_tfm) + self._transform_landmarks(moving_lndmrks, inverse_tfm) ) else: moving_to_reference_landmarks.append(None) @@ -863,9 +905,11 @@ def apply_registration( "inverse_transforms": inverse_transforms, "losses": losses, "moving_to_reference_images": moving_to_reference_images, - "moving_to_reference_segmentations": moving_to_reference_segmentations, + "moving_to_reference_labelmaps": moving_to_reference_labelmaps, + "moving_to_reference_masks": moving_to_reference_masks, "moving_to_reference_landmarks": moving_to_reference_landmarks, "reference_to_moving_images": reference_to_moving_images, - "reference_to_moving_segmentations": reference_to_moving_segmentations, + "reference_to_moving_labelmaps": reference_to_moving_labelmaps, + "reference_to_moving_masks": reference_to_moving_masks, "reference_to_moving_landmarks": reference_to_moving_landmarks, } diff --git a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py index 0e2c384..0764c99 100644 --- a/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py +++ b/src/physiomotion4d/workflow_fit_statistical_model_to_patient.py @@ -32,7 +32,7 @@ from physiomotion4d.contour_tools import ContourTools from physiomotion4d.labelmap_tools import LabelmapTools from physiomotion4d.physiomotion4d_base import PhysioMotion4DBase -from physiomotion4d.register_images_ants import RegisterImagesANTS +from physiomotion4d.register_images_greedy import RegisterImagesGreedy from physiomotion4d.register_images_icon import RegisterImagesICON from physiomotion4d.register_models_distance_maps import RegisterModelsDistanceMaps from physiomotion4d.register_models_icp import RegisterModelsICP @@ -62,7 +62,7 @@ class WorkflowFitStatisticalModelToPatient(PhysioMotion4DBase): (e.g., cardiac models) to patient-specific surface models and images. The registration pipeline combines: - Initial model alignment using RegisterModelsICP (centroid + affine ICP) - - Mask-based deformable registration using RegisterModelsDistanceMaps (ANTs/ICON) + - Mask-based deformable registration using RegisterModelsDistanceMaps (Greedy/ICON) - Optional final mask-to-image refinement using Icon registration **Registration Pipeline:** @@ -94,7 +94,7 @@ class WorkflowFitStatisticalModelToPatient(PhysioMotion4DBase): roi_dilation_mm (float): Dilation for ROI mask transform_tools (TransformTools): Transform utilities registrar_ICON (RegisterImagesICON): ICON registration instance - registrar_ANTS (RegisterImagesANTS): ANTs registration instance + registrar_Greedy (RegisterImagesGreedy): Greedy registration instance use_pca_registration (bool): Whether PCA registration is enabled (set via set_use_pca_registration) pca_model (dict): PCA model dict when PCA enabled; same structure as WorkflowCreateStatisticalModel output pca_number_of_modes (int): Number of PCA modes when PCA enabled @@ -218,8 +218,8 @@ def __init__( ptype=itk.F, ) - self.registrar_ANTS = RegisterImagesANTS() - self.registrar_ANTS.set_number_of_iterations([5, 2, 5]) + self.registrar_Greedy = RegisterImagesGreedy() + self.registrar_Greedy.set_number_of_iterations([5, 2, 5]) # Icon registration for final mask-to-image step self.registrar_ICON = RegisterImagesICON() self.registrar_ICON.set_modality("ct") @@ -812,10 +812,10 @@ def register_labelmap_to_image( ) patient_roi = self._auto_generate_roi_mask(patient_mask) - self.registrar_ANTS.set_fixed_image(self.patient_image) - self.registrar_ANTS.set_fixed_mask(patient_roi) + self.registrar_Greedy.set_fixed_image(self.patient_image) + self.registrar_Greedy.set_fixed_mask(patient_roi) - result = self.registrar_ANTS.register( + result = self.registrar_Greedy.register( moving_image=labelmap, moving_mask=labelmap_roi ) self.m2i_inverse_transform = result["inverse_transform"] diff --git a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py index b3105c5..8b12136 100644 --- a/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py +++ b/src/physiomotion4d/workflow_reconstruct_highres_4d_ct.py @@ -4,7 +4,7 @@ a high-resolution 4D CT time series from lower-resolution time-series images and a single high-resolution reference image. -The workflow uses ANTS+ICON combined registration to: +The workflow uses Greedy+ICON combined registration to: 1. Register each time-series image to the high-resolution reference 2. Apply inverse transforms to reconstruct high-resolution time series 3. Optionally upsample to the reference image resolution @@ -16,7 +16,7 @@ Key Features: - Sequential time-series registration using RegisterTimeSeriesImages - - Combined ANTS+ICON registration for optimal results + - Combined Greedy+ICON registration for optimal results - Bidirectional registration from reference frame - Optional temporal smoothing with prior transforms - High-resolution reconstruction with optional upsampling @@ -37,11 +37,11 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): This class implements a workflow for reconstructing high-resolution dynamic CT images by registering low-resolution time-series images to a high-resolution - reference image using combined ANTS+ICON registration. + reference image using combined Greedy+ICON registration. **Registration Pipeline:** 1. **Time Series Registration**: Register each time-series image to the - high-resolution reference using RegisterTimeSeriesImages with ANTS_ICON method + high-resolution reference using RegisterTimeSeriesImages with Greedy_ICON method 2. **Reconstruction**: Apply inverse transforms to reconstruct high-resolution time series 3. **Optional Upsampling**: Resample to isotropic high resolution @@ -58,7 +58,7 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): register_reference (bool): Whether to register reference frame prior_weight (float): Weight for temporal smoothing (0.0-1.0) upsample_to_fixed_resolution (bool): Whether to upsample reconstruction - registration_method (str): Registration method ('ANTS', 'ICON', or 'ANTS_ICON') + registration_method (str): Registration method ('Greedy', 'ICON', or 'Greedy_ICON') number_of_iterations: Iterations for registration registrar (RegisterTimeSeriesImages): Internal registration object forward_transforms (list[itk.Transform]): one per frame; each warps its @@ -74,11 +74,11 @@ class WorkflowReconstructHighres4DCT(PhysioMotion4DBase): ... time_series_images=lowres_images, ... fixed_image=highres_reference, ... reference_frame=3, - ... registration_method='ANTS_ICON', + ... registration_method='Greedy_ICON', ... ) >>> >>> # Configure registration parameters - >>> workflow.set_number_of_iterations_ANTS([30, 15, 7]) + >>> workflow.set_number_of_iterations_Greedy([30, 15, 7]) >>> workflow.set_number_of_iterations_ICON(20) >>> workflow.set_prior_weight(0.5) >>> @@ -97,7 +97,7 @@ def __init__( fixed_image: itk.Image, reference_frame: int = 0, register_reference: bool = False, - registration_method: str = "ANTS_ICON", + registration_method: str = "Greedy_ICON", log_level: int | str = logging.INFO, ): """Initialize the high-resolution 4D CT reconstruction workflow. @@ -113,7 +113,7 @@ def __init__( to the fixed image. If False, use identity transform for reference. Default: False registration_method (str, optional): Registration method to use. - Options: 'ANTS', 'ICON', or 'ANTS_ICON'. Default: 'ANTS_ICON' + Options: 'Greedy', 'ICON', or 'Greedy_ICON'. Default: 'Greedy_ICON' log_level: Logging level (logging.DEBUG, logging.INFO, etc.). Default: logging.INFO @@ -137,9 +137,9 @@ def __init__( f"[0, {len(time_series_images) - 1}]" ) - if registration_method not in ["ANTS", "ICON", "ANTS_ICON"]: + if registration_method not in ["Greedy", "ICON", "Greedy_ICON"]: raise ValueError( - f"registration_method must be 'ANTS', 'ICON', or 'ANTS_ICON', " + f"registration_method must be 'Greedy', 'ICON', or 'Greedy_ICON', " f"got '{registration_method}'" ) @@ -159,7 +159,7 @@ def __init__( self.moving_masks: Optional[list[Optional[itk.Image]]] = None # Set default number of iterations based on registration method - self.number_of_iterations_ANTS: list[int] = [30, 15, 7, 3] + self.number_of_iterations_Greedy: list[int] = [30, 15, 7, 3] self.number_of_iterations_ICON: int = 20 # Initialize registrar @@ -173,16 +173,16 @@ def __init__( self.losses: Optional[list[float]] = None self.reconstructed_images: Optional[list[itk.Image]] = None - def set_number_of_iterations_ANTS( - self, number_of_iterations_ANTS: list[int] + def set_number_of_iterations_Greedy( + self, number_of_iterations_Greedy: list[int] ) -> None: - """Set the number of iterations for ANTs registration. + """Set the number of iterations for Greedy registration. Args: - number_of_iterations_ANTS: List of iterations for ANTs multi-resolution + number_of_iterations_Greedy: List of iterations for Greedy multi-resolution (e.g., [30, 15, 7, 3] for four resolution levels) """ - self.number_of_iterations_ANTS = number_of_iterations_ANTS + self.number_of_iterations_Greedy = number_of_iterations_Greedy def set_number_of_iterations_ICON(self, number_of_iterations_ICON: int) -> None: """Set the number of iterations for ICON registration. @@ -280,7 +280,7 @@ def register_time_series(self) -> dict: self.registrar.set_fixed_image(self.fixed_image) self.registrar.set_modality(self.modality) self.registrar.set_mask_dilation(self.mask_dilation_mm) - self.registrar.set_number_of_iterations_ANTS(self.number_of_iterations_ANTS) + self.registrar.set_number_of_iterations_greedy(self.number_of_iterations_Greedy) self.registrar.set_number_of_iterations_ICON(self.number_of_iterations_ICON) self.registrar.set_fixed_mask(self.fixed_mask) @@ -289,7 +289,9 @@ def register_time_series(self) -> dict: self.log_info(f"Reference frame: {self.reference_frame}") self.log_info(f"Register reference: {self.register_reference}") self.log_info(f"Prior weight: {self.prior_weight}") - self.log_info(f"Number of iterations (ANTs): {self.number_of_iterations_ANTS}") + self.log_info( + f"Number of iterations (Greedy): {self.number_of_iterations_Greedy}" + ) self.log_info(f"Number of iterations (ICON): {self.number_of_iterations_ICON}") # Perform registration diff --git a/tutorials/tutorial_06_reconstruct_highres_4d_ct.py b/tutorials/tutorial_06_reconstruct_highres_4d_ct.py index 09ddc09..386ad9e 100644 --- a/tutorials/tutorial_06_reconstruct_highres_4d_ct.py +++ b/tutorials/tutorial_06_reconstruct_highres_4d_ct.py @@ -45,7 +45,7 @@ OUTPUT_DIR = TUTORIALS_DIR / "output" / "tutorial_06" BASELINES_DIR = REPO_ROOT / "tests" / "baselines" MAX_FRAMES = 4 - REGISTRATION_METHOD = "ANTS" + REGISTRATION_METHOD = "Greedy_ICON" LOG_LEVEL = logging.INFO # %% @@ -59,10 +59,10 @@ if test_mode: max_frames = min(MAX_FRAMES, 3) - number_of_iterations_ANTS = [1, 0] + number_of_iterations_Greedy = [1, 0] else: max_frames = MAX_FRAMES - number_of_iterations_ANTS = [30, 15, 7, 3] + number_of_iterations_Greedy = [30, 15, 7, 3] output_dir.mkdir(parents=True, exist_ok=True) @@ -87,7 +87,7 @@ log_level=log_level, ) workflow.set_modality("ct") - workflow.set_number_of_iterations_ANTS(number_of_iterations_ANTS) + workflow.set_number_of_iterations_Greedy(number_of_iterations_Greedy) # %% # Workflow execution From 0a98f20007e422706aaf198967bcd79680d53281 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Wed, 10 Jun 2026 10:23:37 -0400 Subject: [PATCH 2/4] ENH: Coderabbit comments --- .../0-cardiacGatedCT_segment_and_landmark.py | 2 +- .../composite_time_series_mid_slice.py | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py b/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py index a45098f..834df73 100644 --- a/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py +++ b/experiments/LongitudinalRegistration/0-cardiacGatedCT_segment_and_landmark.py @@ -32,7 +32,7 @@ ref_data_dir = "d:/PhysioMotion4D/duke_data/ref_images" src_data_dir_base = "d:/PhysioMotion4D/duke_data/gated_nii" -segmentation_dir_base = "d:/PhysioMotion4D/duke_data/ascardio" +segmentation_dir_base = "d:/PhysioMotion4D/duke_data/simple_ascardio" ref_files = [ os.path.join(ref_data_dir, f) diff --git a/experiments/LongitudinalRegistration/composite_time_series_mid_slice.py b/experiments/LongitudinalRegistration/composite_time_series_mid_slice.py index f84dc8f..975bc66 100644 --- a/experiments/LongitudinalRegistration/composite_time_series_mid_slice.py +++ b/experiments/LongitudinalRegistration/composite_time_series_mid_slice.py @@ -10,20 +10,24 @@ import itk import numpy as np - DEFAULT_IMAGE_REGEX = r"^pm00.*_init\.mha$" OUTPUT_FILENAME = "composite.mha" def select_directory() -> Optional[Path]: """Open a directory chooser and return the selected directory.""" - root = tk.Tk() - root.withdraw() - root.update() - selected_dir = filedialog.askdirectory( - title="Select directory containing time-series MHA images" - ) - root.destroy() + try: + root = tk.Tk() + except tk.TclError: + return None + try: + root.withdraw() + root.update() + selected_dir = filedialog.askdirectory( + title="Select directory containing time-series MHA images" + ) + finally: + root.destroy() if not selected_dir: return None return Path(selected_dir) From f942eb2573ab9764a22c1e500553244d62d79af9 Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Thu, 11 Jun 2026 06:14:36 -0400 Subject: [PATCH 3/4] ENH: Simplify identification of reference images in finetuning eval --- .../LongitudinalRegistration/3-eval_icon.py | 244 ++++++++++-------- src/physiomotion4d/register_images_icon.py | 4 +- .../register_time_series_images.py | 7 +- ...st_workflow_fine_tune_icon_registration.py | 18 +- 4 files changed, 154 insertions(+), 119 deletions(-) diff --git a/experiments/LongitudinalRegistration/3-eval_icon.py b/experiments/LongitudinalRegistration/3-eval_icon.py index b67e553..d06277f 100644 --- a/experiments/LongitudinalRegistration/3-eval_icon.py +++ b/experiments/LongitudinalRegistration/3-eval_icon.py @@ -1,17 +1,16 @@ # %% [markdown] # # Evaluate ICON default vs finetuned weights on held-out longitudinal CT # -# Enumerates the Duke patient cohort by sorting ``ref_images/`` and uses the -# *last 20%* of patients as the held-out test set — the same fixed split -# applied by ``2-finetune_icon.py`` (first 80% train, last 20% test). For -# each test subject the 70th-percentile gated frame is selected as the -# reference and every other frame is registered to it twice with -# ``RegisterTimeSeriesImages``: once with the default uniGradICON weights and -# once with the finetuned checkpoint from ``2-finetune_icon.py``. The -# resampler-convention inverse transform (which maps moving-grid points back -# to reference-grid points) is applied to each time-point's precomputed -# landmarks to land them in reference space, and the Euclidean error against -# the reference landmarks is recorded. +# Enumerates the Duke patient cohort from ``timepoint_base_dir`` subdirectories +# and uses the *last 20%* of patients as the held-out test set — the same fixed +# split applied by ``2-finetune_icon.py`` (first 80% train, last 20% test). +# Each subject directory must contain exactly one file whose name ends with +# ``ref.nii.gz``; that file is the fixed image for all registration methods. +# All gated frames (``_g[0-9]{3}.nii.gz``) are registered to it as moving +# images. The resampler-convention inverse transform (which maps moving-grid +# points back to reference-grid points) is applied to each time-point's +# precomputed landmarks to land them in reference space, and the Euclidean +# error against the reference landmarks is recorded. # # Run interactively cell-by-cell; all paths are hard-coded. @@ -19,7 +18,6 @@ import csv import re from pathlib import Path -from typing import Optional import itk import numpy as np @@ -33,7 +31,6 @@ # ## 1. Hard-coded paths and configuration # %% -ref_data_dir = Path("d:/PhysioMotion4D/duke_data/ref_images") timepoint_base_dir = Path("d:/PhysioMotion4D/duke_data/gated_nii") segmentation_base_dir = Path("d:/PhysioMotion4D/duke_data/simple_ascardio") @@ -50,13 +47,40 @@ train_fraction = 0.8 icon_iterations = None -reference_percentile = 0.70 exclude_tokens = ["nop"] timepoint_re = re.compile(r"_g(?P[0-9]{3})") -methods: list[tuple[str, Optional[Path]]] = [ - ("icon_default", None), - ("icon_finetuned", finetuned_weights_path), +# Each entry: name, reg_method, weights_path, use_mask, greedy_iters +# All methods register gated frames to the 70th-percentile gated frame. +all_methods = [ + { + "name": "icon_default", + "reg_method": "ICON", + "weights_path": None, + "use_mask": True, + "greedy_iters": None, + }, + { + "name": "icon_finetuned", + "reg_method": "ICON", + "weights_path": finetuned_weights_path, + "use_mask": True, + "greedy_iters": None, + }, + { + "name": "Greedy", + "reg_method": "Greedy", + "weights_path": None, + "use_mask": True, + "greedy_iters": [80, 40, 5], + }, + { + "name": "Greedy_ICON", + "reg_method": "Greedy_ICON", + "weights_path": finetuned_weights_path, + "use_mask": True, + "greedy_iters": [80, 40, 5], + }, ] output_dir.mkdir(parents=True, exist_ok=True) @@ -69,12 +93,11 @@ warped_ref_detail_file.unlink() # %% -ref_files = sorted( - p - for p in ref_data_dir.iterdir() - if p.name.startswith("pm00") and p.suffixes[-2:] == [".nii", ".gz"] +all_patient_ids = sorted( + p.name + for p in timepoint_base_dir.iterdir() + if p.is_dir() and p.name.startswith("pm00") ) -all_patient_ids = [p.name[:6] for p in ref_files] n_train = max( 1, min(len(all_patient_ids) - 1, round(train_fraction * len(all_patient_ids))) ) @@ -85,6 +108,25 @@ ) print(f"Held-out test subjects: {test_subjects}") +# %% [markdown] +# ## 2. Validate that every test subject has exactly one reference file + +# %% +missing = [] +for subject_id in test_subjects: + ref_candidates = list((timepoint_base_dir / subject_id).glob("*ref.nii.gz")) + if len(ref_candidates) != 1: + missing.append( + f"{subject_id}: found {len(ref_candidates)} ref file(s)" + + (f" {ref_candidates}" if ref_candidates else "") + ) +if missing: + raise FileNotFoundError( + "Missing or ambiguous ref.nii.gz for test subjects:\n" + + "\n".join(f" {m}" for m in missing) + ) +print("All test subjects have exactly one ref.nii.gz") + # %% [markdown] # ## 3. Reader instance used in the per-frame inner loop # @@ -104,48 +146,52 @@ # %% [markdown] -# ## 4. Register and score every test subject under both ICON methods +# ## 4. Register and score every test subject under all methods +# +# All methods register each gated frame to the 70th-percentile gated frame as +# the fixed image. The per-frame metric pipeline is shared: landmark error and +# warped-reference re-segmentation. # %% summary_rows: list[dict[str, object]] = [] for subject_id in test_subjects: - source_dir = timepoint_base_dir / subject_id - print(f"Source directory: {source_dir}") - seg_dir = segmentation_base_dir / subject_id - print(f"Segmentation directory: {seg_dir}") + source_dir = timepoint_base_dir / subject_id + print(f"\nSubject {subject_id}") + + # --- Per-subject reference scan (fixed image for all methods) --- + ref_file = next(source_dir.glob("*ref.nii.gz")) + ref_stem = ref_file.name[:-7] + fixed_image = itk.imread(str(ref_file), pixel_type=itk.F) + fixed_mask_path = seg_dir / f"{ref_stem}_labelmap_mask.nii.gz" + fixed_labelmap_path = seg_dir / f"{ref_stem}_labelmap.nii.gz" + if fixed_mask_path.exists(): + fixed_mask = itk.imread(str(fixed_mask_path)) + else: + fixed_labelmap = itk.imread(str(fixed_labelmap_path)) + fixed_mask = labelmap_tools.convert_labelmap_to_mask( + fixed_labelmap, dilation_in_mm=3.0 + ) + itk.imwrite(fixed_mask, str(fixed_mask_path), compression=True) + fixed_landmarks = landmark_tools.read_landmarks_3dslicer( + str(seg_dir / f"{ref_stem}_landmark.mrk.json") + ) + print(f" Fixed: {ref_file.name}") + # --- Gated frames (moving images, shared across all methods) --- image_files = [ p for p in sorted(source_dir.glob("*.nii.gz")) if not any(t in p.name for t in exclude_tokens) + and timepoint_re.search(p.name) is not None ] - print(f"Found {len(image_files)} image files") stems = [p.name[:-7] for p in image_files] + timepoints = [timepoint_re.search(p.name).group("timepoint") for p in image_files] labelmap_files = [seg_dir / f"{s}_labelmap.nii.gz" for s in stems] mask_files = [seg_dir / f"{s}_labelmap_mask.nii.gz" for s in stems] landmark_files = [seg_dir / f"{s}_landmark.mrk.json" for s in stems] - timepoints = [timepoint_re.search(p.name).group("timepoint") for p in image_files] - - reference_index = int(round(reference_percentile * (len(image_files) - 1))) - print( - f"\nSubject {subject_id}: {len(image_files)} time points, " - f"reference index {reference_index} (g{timepoints[reference_index]})" - ) - - fixed_image = itk.imread(str(image_files[reference_index]), pixel_type=itk.F) - fixed_labelmap = itk.imread(str(labelmap_files[reference_index])) - if mask_files[reference_index].exists(): - fixed_mask = itk.imread(str(mask_files[reference_index])) - else: - fixed_mask = labelmap_tools.convert_labelmap_to_mask( - fixed_labelmap, dilation_in_mm=5.0 - ) - itk.imwrite(fixed_mask, str(mask_files[reference_index]), compression=True) - reference_landmarks = landmark_tools.read_landmarks_3dslicer( - landmark_files[reference_index] - ) + print(f" {len(image_files)} gated frames") moving_images = [itk.imread(str(p), pixel_type=itk.F) for p in image_files] moving_labelmaps = [itk.imread(str(p)) for p in labelmap_files] @@ -153,46 +199,57 @@ landmark_tools.read_landmarks_3dslicer(str(p)) for p in landmark_files ] moving_masks = [] - for index, p in enumerate(mask_files): + for i, p in enumerate(mask_files): if not p.exists(): mask = labelmap_tools.convert_labelmap_to_mask( - moving_labelmaps[index], dilation_in_mm=5.0 + moving_labelmaps[i], dilation_in_mm=3.0 ) itk.imwrite(mask, str(p), compression=True) moving_masks.append(mask) else: - mask = itk.imread(str(p)) - moving_masks.append(mask) + moving_masks.append(itk.imread(str(p))) + + # --- Per-method registration and scoring --- + for method_cfg in all_methods: + method_name = str(method_cfg["name"]) + reg_method = str(method_cfg["reg_method"]) + weights_path = method_cfg["weights_path"] + use_mask = bool(method_cfg["use_mask"]) + greedy_iters = method_cfg["greedy_iters"] - for method_name, weights_path in methods: print(f" Method: {method_name}") - registrar = RegisterTimeSeriesImages(registration_method="ICON") + registrar = RegisterTimeSeriesImages(registration_method=reg_method) registrar.set_modality("ct") registrar.set_fixed_image(fixed_image) - registrar.set_fixed_mask(fixed_mask) - registrar.set_number_of_iterations_ICON(icon_iterations) - if weights_path is not None: - registrar.registrar_ICON.set_weights_path(str(weights_path)) + if use_mask: + registrar.set_fixed_mask(fixed_mask) + if greedy_iters is not None: + registrar.set_number_of_iterations_greedy(greedy_iters) + if reg_method in ("ICON", "Greedy_ICON"): + registrar.set_number_of_iterations_ICON(icon_iterations) + if weights_path is not None: + registrar.registrar_ICON.set_weights_path(str(weights_path)) result = registrar.register_time_series( moving_images=moving_images, - moving_masks=moving_masks, - moving_labelmaps=moving_labelmaps, - reference_frame=reference_index, - register_reference=False, - prior_weight=0.0, + moving_masks=moving_masks if use_mask else None, + moving_labelmaps=moving_labelmaps if use_mask else None, + register_reference=True, + reference_frame=0, # Not used + prior_weight=0.0, # Not used ) method_dir = output_dir / method_name / subject_id method_dir.mkdir(parents=True, exist_ok=True) - for index, image_file in enumerate(image_files): + for index in range(len(image_files)): timepoint = timepoints[index] - + forward_transform = result["forward_transforms"][index] inverse_transform = result["inverse_transforms"][index] + loss = float(result["losses"][index]) itk.transformwrite( - result["forward_transforms"][index], + forward_transform, str(method_dir / f"{subject_id}_g{timepoint}_forward_tfm.hdf"), compression=True, ) @@ -202,15 +259,16 @@ compression=True, ) + # Landmark error: warp gated landmarks into fixed-image space. timepoint_landmarks = moving_landmarks[index] - shared = sorted(timepoint_landmarks.keys() & reference_landmarks.keys()) - errors: list[tuple[str, float]] = [] + shared = sorted(timepoint_landmarks.keys() & fixed_landmarks.keys()) + errors = [] for name in shared: warped = inverse_transform.TransformPoint(timepoint_landmarks[name]) err = float( np.linalg.norm( np.asarray(warped, dtype=np.float64) - - np.asarray(reference_landmarks[name], dtype=np.float64) + - np.asarray(fixed_landmarks[name], dtype=np.float64) ) ) errors.append((name, err)) @@ -229,9 +287,9 @@ { "subject_id": subject_id, "method": method_name, - "reference_timepoint": timepoints[reference_index], + "reference_timepoint": ref_stem, "timepoint": timepoint, - "loss": float(result["losses"][index]), + "loss": loss, "n_landmarks": int(values.size), "mean_mm": float(np.mean(values)) if values.size else "", "median_mm": float(np.median(values)) if values.size else "", @@ -239,27 +297,9 @@ } ) - # ------------------------------------------------------------------ - # Warp the reference image back onto each time-point's grid, re- - # segment with SegmentHeartSimpleware, and compare the resulting - # landmarks with the time-point's own precomputed landmarks. - # - # Per transform_conventions.rst: - # - warp fixed image -> moving grid => inverse_transform + - # TransformTools.transform_image (pull-back) - # - warp moving points -> fixed space => inverse_transform + - # .TransformPoint() (push-forward) - # Both use inverse_transform, but for opposite purposes. - # Here we use inverse_transform for image warping (row 3 of the - # table), placing the reference image in time-point space so - # Simpleware sees anatomy at the correct cardiac phase. - # - # Skip the reference frame — warping it to itself is trivial and - # its own landmarks are already the "ground truth" reference. - # ------------------------------------------------------------------ - if index == reference_index: - continue - + # Warp fixed image back onto the gated frame's grid, re-segment, and + # compare the resulting landmarks against the gated frame's own + # precomputed landmarks. warped_ref = transform_tools.transform_image( fixed_image, inverse_transform, @@ -268,7 +308,7 @@ ) itk.imwrite( warped_ref, - method_dir / f"{subject_id}_g{timepoint}_warped_ref.mha", + str(method_dir / f"{subject_id}_g{timepoint}_warped_ref.mha"), compression=True, ) @@ -289,11 +329,9 @@ ), ) - # Both warped_ref_landmarks and timepoint_landmarks are in the - # time-point (moving) image space — compare directly. tp_landmarks = moving_landmarks[index] shared_warp = sorted(warped_ref_landmarks.keys() & tp_landmarks.keys()) - warp_errors: list[tuple[str, float]] = [] + warp_errors = [] for name in shared_warp: err = float( np.linalg.norm( @@ -308,13 +346,7 @@ writer = csv.writer(fh) if fh.tell() == 0: writer.writerow( - [ - "subject_id", - "method", - "timepoint", - "name", - "error_mm", - ] + ["subject_id", "method", "timepoint", "name", "error_mm"] ) for name, err in warp_errors: writer.writerow([subject_id, method_name, timepoint, name, err]) @@ -361,7 +393,8 @@ print("=" * len(header)) print(header) print("-" * len(header)) -for method_name, _ in methods: +for method_cfg in all_methods: + method_name = str(method_cfg["name"]) arr = np.asarray(groups.get(method_name, []), dtype=np.float64) if arr.size == 0: print(f"{method_name:<18}{0:>8}{'':>12}{'':>14}{'':>12}{'':>12}") @@ -403,7 +436,8 @@ print("=" * len(warp_header)) print(warp_header) print("-" * len(warp_header)) - for method_name, _ in methods: + for method_cfg in all_methods: + method_name = str(method_cfg["name"]) arr = np.asarray(warp_groups.get(method_name, []), dtype=np.float64) if arr.size == 0: print(f"{method_name:<18}{0:>8}{'':>12}{'':>14}{'':>12}{'':>12}") diff --git a/src/physiomotion4d/register_images_icon.py b/src/physiomotion4d/register_images_icon.py index 65c527e..e157b70 100644 --- a/src/physiomotion4d/register_images_icon.py +++ b/src/physiomotion4d/register_images_icon.py @@ -76,7 +76,7 @@ def __init__(self, log_level: int | str = logging.INFO) -> None: super().__init__(log_level=log_level) self.net = None - self.number_of_iterations: int = 50 + self.number_of_iterations: Optional[int] = 50 self.use_multi_modality: bool = False self.use_mass_preservation: bool = False self.weights_path: Optional[str] = None @@ -99,7 +99,7 @@ def set_weights_path(self, weights_path: str) -> None: self.weights_path = weights_path self.net = None # force reload on next register() call - def set_number_of_iterations(self, number_of_iterations: int) -> None: + def set_number_of_iterations(self, number_of_iterations: Optional[int]) -> None: """Set the number of iterations for ICON registration. Args: diff --git a/src/physiomotion4d/register_time_series_images.py b/src/physiomotion4d/register_time_series_images.py index 194cd63..092ec00 100644 --- a/src/physiomotion4d/register_time_series_images.py +++ b/src/physiomotion4d/register_time_series_images.py @@ -105,7 +105,7 @@ def __init__( # Set default iterations based on registration method self.number_of_iterations_greedy: list[int] = [40, 20, 10] - self.number_of_iterations_ICON: int = 50 + self.number_of_iterations_ICON: Optional[int] = 50 if self.registration_method_name not in REGISTRATION_METHODS: raise ValueError( @@ -117,7 +117,9 @@ def __init__( self.smooth_prior_transform_sigma: float = 0.5 - def set_number_of_iterations_ICON(self, number_of_iterations_ICON: int) -> None: + def set_number_of_iterations_ICON( + self, number_of_iterations_ICON: Optional[int] + ) -> None: """Set the number of iterations for ICON registration. Args: @@ -236,7 +238,6 @@ def register_time_series( is used (each registration starts from identity). Higher values provide more temporal smoothness but may propagate errors. Default: 0.0 - Returns: dict: Dictionary containing results: - "forward_transforms" (list[itk.Transform]): one per image; diff --git a/tests/test_workflow_fine_tune_icon_registration.py b/tests/test_workflow_fine_tune_icon_registration.py index 311fd96..902a201 100644 --- a/tests/test_workflow_fine_tune_icon_registration.py +++ b/tests/test_workflow_fine_tune_icon_registration.py @@ -127,7 +127,7 @@ def test_init_rejects_mismatched_subject_ids_length(tmp_path: Path) -> None: ) -def test_use_segmentations_and_use_masks_flags(tmp_path: Path) -> None: +def test_use_labelmaps_and_use_masks_flags(tmp_path: Path) -> None: """The two helper flags reflect supplied companions independently.""" base: dict[str, Any] = { "subject_image_files": [["a"]], @@ -135,19 +135,19 @@ def test_use_segmentations_and_use_masks_flags(tmp_path: Path) -> None: "fine_tune_name": "x", } none_wf = WorkflowFineTuneICONRegistration(**base) - assert not none_wf.use_segmentations + assert not none_wf.use_labelmaps assert not none_wf.use_masks seg_only = WorkflowFineTuneICONRegistration( **base, subject_labelmap_files=[["seg.nii.gz"]] ) - assert seg_only.use_segmentations + assert seg_only.use_labelmaps assert seg_only.use_masks # derived from segs mask_only = WorkflowFineTuneICONRegistration( **base, subject_mask_files=[["mask.nii.gz"]] ) - assert not mask_only.use_segmentations + assert not mask_only.use_labelmaps assert mask_only.use_masks @@ -349,8 +349,8 @@ def test_prepare_config_emits_uniGradICON_yaml( assert training["learning_rate"] == 1e-4 assert training["input_shape"] == [64, 64, 64] assert training["gpus"] == [1] - # Driven by data availability. - assert training["use_label"] is True + # loss_function_masking is driven by data availability; use_label is always False. + assert training["use_label"] is False assert training["loss_function_masking"] is True assert training["roi_masking"] is False @@ -505,7 +505,7 @@ def test_apply_registration_rejects_empty_moving(tmp_path: Path) -> None: def test_apply_registration_rejects_mismatched_companions(tmp_path: Path) -> None: - """moving_segmentations / moving_landmarks length must match moving_images.""" + """moving_labelmaps / moving_landmarks length must match moving_images.""" workflow = WorkflowFineTuneICONRegistration( subject_image_files=[["a"]], output_dir=tmp_path, @@ -514,11 +514,11 @@ def test_apply_registration_rejects_mismatched_companions(tmp_path: Path) -> Non ) ref = itk.image_from_array(np.zeros((3, 3, 3), dtype=np.float32)) mov = itk.image_from_array(np.zeros((3, 3, 3), dtype=np.float32)) - with pytest.raises(ValueError, match="moving_segmentations length"): + with pytest.raises(ValueError, match="moving_labelmaps length"): workflow.apply_registration( reference_image=ref, moving_images=[mov], - moving_segmentations=[], + moving_labelmaps=[], ) with pytest.raises(ValueError, match="moving_landmarks length"): workflow.apply_registration( From 837e2182ac4f164da13330c3e188c8d37242521f Mon Sep 17 00:00:00 2001 From: Stephen Aylward Date: Thu, 11 Jun 2026 06:47:54 -0400 Subject: [PATCH 4/4] ENH: coderabbit tweaks --- experiments/LongitudinalRegistration/2-finetune_icon.py | 3 ++- experiments/LongitudinalRegistration/3-eval_icon.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/experiments/LongitudinalRegistration/2-finetune_icon.py b/experiments/LongitudinalRegistration/2-finetune_icon.py index 7cc3900..a9e5073 100644 --- a/experiments/LongitudinalRegistration/2-finetune_icon.py +++ b/experiments/LongitudinalRegistration/2-finetune_icon.py @@ -122,6 +122,7 @@ # %% def gather_init_frames( initial_registration_dir: Path, + patient_id: str, ) -> tuple[list[Path], list[Optional[Path]], list[Optional[Path]]]: """Return ``(init_image_paths, init_labelmap_paths, init_mask_paths)`` for one ``initial_registration_dir / `` directory. @@ -201,7 +202,7 @@ def gather_init_frames( initial_registration_patient_dir = initial_registration_dir / patient_id init_images, init_labelmaps, init_masks = gather_init_frames( - initial_registration_patient_dir + initial_registration_patient_dir, patient_id ) if not init_images: print( diff --git a/experiments/LongitudinalRegistration/3-eval_icon.py b/experiments/LongitudinalRegistration/3-eval_icon.py index d06277f..ef6ffb6 100644 --- a/experiments/LongitudinalRegistration/3-eval_icon.py +++ b/experiments/LongitudinalRegistration/3-eval_icon.py @@ -51,7 +51,7 @@ timepoint_re = re.compile(r"_g(?P[0-9]{3})") # Each entry: name, reg_method, weights_path, use_mask, greedy_iters -# All methods register gated frames to the 70th-percentile gated frame. +# All methods register gated frames to the per-subject reference scan (*ref.nii.gz). all_methods = [ { "name": "icon_default",