Skip to content

Commit

Permalink
Merge pull request #62 from BrainLesion/61-transform-resampling-does-…
Browse files Browse the repository at this point in the history
…not-work

Ensuring transform and log files are generated
  • Loading branch information
neuronflow authored Apr 16, 2024
2 parents 760f5e7 + 81cb2a3 commit 5133c97
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 174 deletions.
Binary file added data/tcia_aaac_t1ce_transform.mat
Binary file not shown.
2 changes: 1 addition & 1 deletion ereg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# TODO do we need this?
from .functional import registration_function
from .functional import registration_function, resample_function
from .registration import RegistrationClass
from .utils.io import read_image_and_cast_to_32bit_float
48 changes: 40 additions & 8 deletions ereg/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import SimpleITK as sitk

from ereg.registration import RegistrationClass
from ereg.utils import initialize_configuration


def registration_function(
Expand All @@ -28,13 +29,7 @@ def registration_function(
Returns:
float: The structural similarity index.
"""
if configuration is not None:
if isinstance(configuration, str):
assert os.path.isfile(configuration), "Config file does not exist."
elif isinstance(configuration, dict):
pass
else:
raise ValueError("Config file must be a string or dictionary.")
configuration = initialize_configuration(configuration)

registration_obj = RegistrationClass(configuration)
registration_obj.register(
Expand All @@ -48,4 +43,41 @@ def registration_function(
return registration_obj.ssim_score


# TODO we also need a transformation/resample function
def resample_function(
target_image: Union[str, sitk.Image],
moving_image: Union[str, sitk.Image],
output_image: str,
transform_file: str,
configuration: Union[str, dict] = None,
log_file: str = None,
**kwargs,
) -> float:
"""
Resample the moving image onto the space of the target image using a given transformation.
Args:
target_image (Union[str, sitk.Image]): The target image onto which the moving image will be resampled.
moving_image (Union[str, sitk.Image]): The image to be resampled.
output_image (str): The filename or path where the resampled image will be saved.
transform_file (str): The file containing the transformation to be applied.
configuration (Union[str, dict], optional): The configuration file or dictionary. Defaults to None.
log_file (str, optional): The file to log progress and details of the resampling process. Defaults to None.
**kwargs: Additional keyword arguments to be passed to the resampling function.
Returns:
float: The structural similarity index (SSIM) between the resampled image and the target image.
"""
configuration = initialize_configuration(configuration)

registration_obj = RegistrationClass(configuration)

registration_obj.resample_image(
target_image=target_image,
moving_image=moving_image,
output_image=output_image,
transform_file=transform_file,
log_file=log_file,
**kwargs,
)

return registration_obj.ssim_score
140 changes: 89 additions & 51 deletions ereg/registration.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Union

import numpy as np

# from pprint import pprint
import SimpleITK as sitk
import yaml

from ereg.utils.io import read_image_and_cast_to_32bit_float
from ereg.utils.io import initialize_configuration, read_image_and_cast_to_32bit_float
from ereg.utils.metrics import get_ssim

logger = logging.getLogger(__name__)


class RegistrationClass:
def __init__(
Expand All @@ -25,6 +26,9 @@ def __init__(
Args:
config_file (Union[str, dict]): The config file or dictionary.
"""

self._setup_logger()

self.available_metrics = [
"mattes_mutual_information",
"ants_neighborhood_correlation",
Expand Down Expand Up @@ -67,10 +71,11 @@ def __init__(
self.total_attempts = 5
self.transform = None

configuration = initialize_configuration(configuration)
self.parameters = self._generate_default_parameters()
self.ssim_score = None
if configuration is not None:
self.update_parameters(configuration)
else:
self.parameters = self._generate_default_parameters()

def _generate_default_parameters(self) -> dict:
python_file_path = Path(os.path.normpath(os.path.abspath(__file__)))
Expand Down Expand Up @@ -134,24 +139,17 @@ def register(
output_image (str): The output image.
transform_file (str, optional): The transform file. Defaults to None.
"""
log_file = self._get_log_file(output_image, log_file)
self._set_log_file(log_file)

if log_file is None:
# TODO this will create trouble for non ".nii.gz" files
log_file = output_image.replace(".nii.gz", ".log")
logging.basicConfig(
filename=log_file,
format="%(asctime)s,%(name)s,%(levelname)s,%(message)s",
datefmt="%H:%M:%S",
level=logging.DEBUG,
)
self.logger = logging.getLogger("registration")
logger.info(f"{'register: Starting registration':=^100}")

self.logger.info(f"Target image: {target_image}, Moving image: {moving_image}")
logger.info(f"Target image: {target_image}, Moving image: {moving_image}")
target_image = read_image_and_cast_to_32bit_float(target_image)
moving_image = read_image_and_cast_to_32bit_float(moving_image)

if self.parameters.get("bias"):
self.logger.info("Bias correcting images.")
logger.info("Bias correcting images.")
target_image = self._bias_correct_image(target_image)
moving_image = self._bias_correct_image(moving_image)

Expand All @@ -162,15 +160,14 @@ def register(
try:
self.transform = sitk.ReadTransform(transform_file)
compute_transform = False
logger.info("Specified transform file already exists.")
except:
self.logger.info(
logger.warning(
"Could not read transform file. Computing transform."
)
pass
if compute_transform:
self.logger.info(
f"Starting registration with parameters:: {self.parameters}"
)
logger.info(f"Starting registration with parameters:: {self.parameters}")
self.transform = self._register_image_and_get_transform(
target_image=target_image,
moving_image=moving_image,
Expand All @@ -183,15 +180,15 @@ def register(
"composite_transform", None
)
if self.parameters.get("composite_transform"):
self.logger.info("Applying composite transform.")
logger.info("Applying composite transform.")
transform_composite = sitk.ReadTransform(
self.parameters["composite_transform"]
)
self.transform = sitk.CompositeTransform(
transform_composite, self.transform
)

self.logger.info("Applying previous transforms.")
logger.info("Applying previous transforms.")
current_transform = None
for previous_transform in self.parameters["previous_transforms"]:
previous_transform = sitk.ReadTransform(previous_transform)
Expand All @@ -203,15 +200,13 @@ def register(

self.transform = current_transform

# no need for logging since resample_image will log by itself
logging.shutdown()

# resample the moving image to the target image
self.resample_image(
target_image=target_image,
moving_image=moving_image,
output_image=output_image,
transform_file=transform_file,
log_file=log_file,
)

def resample_image(
Expand All @@ -227,34 +222,38 @@ def resample_image(
Resample the moving image to the target image.
Args:
logger (logging.Logger): The logger to use.
target_image (Union[str, sitk.Image]): The target image.
moving_image (Union[str, sitk.Image]): The moving image.
output_image (str): The output image.
transform_file (str, optional): The transform file. Defaults to None.
"""

log_file = self._get_log_file(output_image, log_file)
self._set_log_file(log_file)

logger.info(f"{'resample_image: Starting transformation':=^100}")

# check if output image exists
if not os.path.exists(output_image):
if self.transform is not None:
if log_file is None:
# TODO this will create trouble for non ".nii.gz" file
log_file = output_image.replace(".nii.gz", ".log")
logging.basicConfig(
filename=log_file,
format="%(asctime)s,%(name)s,%(levelname)s,%(message)s",
datefmt="%H:%M:%S",
level=logging.DEBUG,
)
self.logger = logging.getLogger("registration")

self.logger.info(
if transform_file is not None:
assert os.path.isfile(transform_file), "Transform file does not exist."
transform_from_file = None
try:
transform_from_file = sitk.ReadTransform(transform_file)
assert (
transform_from_file is not None
), "Transform could not be read."
except Exception as e:
logger.error(f"Could not read transform file: {e}")
return None

logger.info(
f"Target image: {target_image}, Moving image: {moving_image}, Transform file: {transform_file}"
)
target_image = read_image_and_cast_to_32bit_float(target_image)
moving_image = read_image_and_cast_to_32bit_float(moving_image)

self.logger.info("Resampling image.")
logger.info("Resampling image.")
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(target_image)
interpolator_type = self.interpolator_type.get(
Expand All @@ -263,14 +262,13 @@ def resample_image(
)
resampler.SetInterpolator(interpolator_type)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(self.transform)
resampler.SetTransform(transform_from_file)
output_image_struct = resampler.Execute(moving_image)
sitk.WriteImage(output_image_struct, output_image)
self.ssim_score = get_ssim(target_image, output_image_struct)
self.logger.info(
logger.info(
f"SSIM score of moving against target image: {self.ssim_score}"
)
logging.shutdown()

return self.ssim_score

Expand Down Expand Up @@ -390,7 +388,7 @@ def _register_image_and_get_transform(
for dim in range(dimension):
physical_units *= target_image.GetSpacing()[dim]

self.logger.info("Initializing registration.")
logger.info("Initializing registration.")
registration = sitk.ImageRegistrationMethod()
self.parameters["metric_parameters"] = self.parameters.get(
"metric_parameters", {}
Expand Down Expand Up @@ -761,21 +759,19 @@ def _register_image_and_get_transform(
registration.SetInterpolator(sitk.sitkLinear)

# registration.AddCommand(sitk.sitkIterationEvent, lambda: R)
self.logger.info("Starting registration.")
logger.info("Starting registration.")
output_transform = None
for _ in range(self.parameters["attempts"]):
try:
output_transform = registration.Execute(target_image, moving_image)
break
except RuntimeError as e:
self.logger.warning(
"Registration failed with error: %s. Retrying." % (e)
)
logger.warning("Registration failed with error: %s. Retrying." % (e))
continue

assert output_transform is not None, "Registration failed."

self.logger.info(
logger.info(
f"Final Optimizer Parameters:: convergence={registration.GetOptimizerConvergenceValue()}, iterations={registration.GetOptimizerIteration()}, metric={registration.GetMetricValue()}, stop condition={registration.GetOptimizerStopConditionDescription()}"
)

Expand All @@ -801,3 +797,45 @@ def _register_image_and_get_transform(
tmp.SetCenter(registration_transform_sitk.GetCenter())
registration_transform_sitk = tmp
return registration_transform_sitk

def _setup_logger(self):
logging.basicConfig(
format="[%(levelname)-8s | %(module)-15s | L%(lineno)-5d] | %(asctime)s: %(message)s",
datefmt="%H:%M:%S",
level=logging.DEBUG,
)
self.log_file_handler = None

def _set_log_file(self, log_file: str | Path) -> None:
"""Set the log file for the logger.
Args:
log_file (str | Path): log file path
"""
if self.log_file_handler:
logging.getLogger().removeHandler(self.log_file_handler)

parent_dir = os.path.dirname(log_file)
# create parent dir if the path is more than just a file name
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
self.log_file_handler = logging.FileHandler(log_file)
self.log_file_handler.setFormatter(
logging.Formatter(
"[%(levelname)-8s | %(module)-15s | L%(lineno)-5d] | %(asctime)s: %(message)s",
"%Y-%m-%dT%H:%M:%S%z",
)
)

# Add the file handler to the !root! logger
logging.getLogger().addHandler(self.log_file_handler)

def _get_log_file(self, output_image: str, log_file: str = None) -> str:
if log_file is None:
extensions = Path(output_image).suffixes
output_image_base = output_image
for ext in extensions:
output_image_base = output_image_base.replace(ext, "")
return output_image_base + ".log"
else:
return log_file
1 change: 1 addition & 0 deletions ereg/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .io import initialize_configuration, read_image_and_cast_to_32bit_float
23 changes: 22 additions & 1 deletion ereg/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# import os, tempfile, requestsw
import os
from typing import Union

import SimpleITK as sitk
Expand Down Expand Up @@ -42,3 +42,24 @@ def read_image_and_cast_to_32bit_float(
caster = sitk.CastImageFilter()
caster.SetOutputPixelType(sitk.sitkFloat32)
return caster.Execute(input_image)


def initialize_configuration(configuration: Union[str, dict]) -> dict:
"""
Initialize the configuration dictionary.
Args:
configuration (Union[str, dict]): The configuration file or dictionary.
Returns:
dict: The configuration dictionary.
"""
if configuration is not None:
if isinstance(configuration, str):
assert os.path.isfile(configuration), "Config file does not exist."
elif isinstance(configuration, dict):
pass
else:
raise ValueError("Config file must be a string or dictionary.")

return configuration
Loading

0 comments on commit 5133c97

Please sign in to comment.