Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensuring transform and log files are generated #62

Merged
merged 33 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5a4f5d3
updated check
sarthakpati Apr 2, 2024
51568da
added check for transform file creation
sarthakpati Apr 3, 2024
8072831
apparently a `reload` was needed!
sarthakpati Apr 4, 2024
c68d039
updated test to only remove specific files
sarthakpati Apr 4, 2024
017cb6c
ensure logger pushes to console as well
sarthakpati Apr 4, 2024
386b3ef
added a check to see if log file was created
sarthakpati Apr 4, 2024
42d0e89
add resample function
neuronflow Apr 4, 2024
c49bc70
removed code duplication
sarthakpati Apr 4, 2024
46fd927
add to submodule
sarthakpati Apr 4, 2024
a0bd874
call `resample_function` from tests
sarthakpati Apr 4, 2024
c8a4882
moved the config initialization function under utils
sarthakpati Apr 4, 2024
1415cc0
using the function
sarthakpati Apr 4, 2024
ffef980
initialize `self.parameters` and `self.ssim_score`
sarthakpati Apr 4, 2024
c5db080
separate registration and transform test
neuronflow Apr 4, 2024
cf88264
Autoformat with black
brainless-bot[bot] Apr 4, 2024
b162819
do not clean tf file
neuronflow Apr 4, 2024
293b68c
combined test crashes
neuronflow Apr 4, 2024
2196163
create logger no matter whether output file exists
neuronflow Apr 4, 2024
fe7841b
fix log naming todo to also support non .nii.gz files
neuronflow Apr 4, 2024
2934eff
bump version
neuronflow Apr 4, 2024
8d97b7f
reverting because this was causing pip install issues
sarthakpati Apr 5, 2024
29550a9
we should **not** add dependencies for such trivial tasks
sarthakpati Apr 5, 2024
dfa72d5
added missing return
sarthakpati Apr 5, 2024
9836d01
black .
sarthakpati Apr 5, 2024
4b6bf56
updated tests
sarthakpati Apr 5, 2024
d56b54d
Merge branch 'main' into 61-transform-resampling-does-not-work
neuronflow Apr 8, 2024
91fb002
added check for empty log files
sarthakpati Apr 8, 2024
ccd2aa2
check is fixed
sarthakpati Apr 8, 2024
7c1320f
lint removed
sarthakpati Apr 8, 2024
543a49f
- small rewrite of logger setup
MarcelRosier Apr 16, 2024
71bc3ff
improve: log format
MarcelRosier Apr 16, 2024
6d99009
migrate test to unit tests
MarcelRosier Apr 16, 2024
81cb2a3
fux: add future import to support python <3.10
MarcelRosier Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added data/tcia_aaac_t1ce_transform.mat
Binary file not shown.
2 changes: 1 addition & 1 deletion ereg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# TODO do we need this?
from .functional import registration_function
from .functional import registration_function, resample_function
from .registration import RegistrationClass
from .utils.io import read_image_and_cast_to_32bit_float
48 changes: 40 additions & 8 deletions ereg/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import SimpleITK as sitk

from ereg.registration import RegistrationClass
from ereg.utils import initialize_configuration


def registration_function(
Expand All @@ -28,13 +29,7 @@ def registration_function(
Returns:
float: The structural similarity index.
"""
if configuration is not None:
if isinstance(configuration, str):
assert os.path.isfile(configuration), "Config file does not exist."
elif isinstance(configuration, dict):
pass
else:
raise ValueError("Config file must be a string or dictionary.")
configuration = initialize_configuration(configuration)

registration_obj = RegistrationClass(configuration)
registration_obj.register(
Expand All @@ -48,4 +43,41 @@ def registration_function(
return registration_obj.ssim_score


# TODO we also need a transformation/resample function
def resample_function(
target_image: Union[str, sitk.Image],
moving_image: Union[str, sitk.Image],
output_image: str,
transform_file: str,
configuration: Union[str, dict] = None,
log_file: str = None,
**kwargs,
) -> float:
"""
Resample the moving image onto the space of the target image using a given transformation.
Args:
target_image (Union[str, sitk.Image]): The target image onto which the moving image will be resampled.
moving_image (Union[str, sitk.Image]): The image to be resampled.
output_image (str): The filename or path where the resampled image will be saved.
transform_file (str): The file containing the transformation to be applied.
configuration (Union[str, dict], optional): The configuration file or dictionary. Defaults to None.
log_file (str, optional): The file to log progress and details of the resampling process. Defaults to None.
**kwargs: Additional keyword arguments to be passed to the resampling function.
Returns:
float: The structural similarity index (SSIM) between the resampled image and the target image.
"""
configuration = initialize_configuration(configuration)

registration_obj = RegistrationClass(configuration)

registration_obj.resample_image(
target_image=target_image,
moving_image=moving_image,
output_image=output_image,
transform_file=transform_file,
log_file=log_file,
**kwargs,
)

return registration_obj.ssim_score
140 changes: 89 additions & 51 deletions ereg/registration.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Union

import numpy as np

# from pprint import pprint
import SimpleITK as sitk
import yaml

from ereg.utils.io import read_image_and_cast_to_32bit_float
from ereg.utils.io import initialize_configuration, read_image_and_cast_to_32bit_float
from ereg.utils.metrics import get_ssim

logger = logging.getLogger(__name__)


class RegistrationClass:
def __init__(
Expand All @@ -25,6 +26,9 @@ def __init__(
Args:
config_file (Union[str, dict]): The config file or dictionary.
"""

self._setup_logger()

self.available_metrics = [
"mattes_mutual_information",
"ants_neighborhood_correlation",
Expand Down Expand Up @@ -67,10 +71,11 @@ def __init__(
self.total_attempts = 5
self.transform = None

configuration = initialize_configuration(configuration)
self.parameters = self._generate_default_parameters()
self.ssim_score = None
if configuration is not None:
self.update_parameters(configuration)
else:
self.parameters = self._generate_default_parameters()

def _generate_default_parameters(self) -> dict:
python_file_path = Path(os.path.normpath(os.path.abspath(__file__)))
Expand Down Expand Up @@ -134,24 +139,17 @@ def register(
output_image (str): The output image.
transform_file (str, optional): The transform file. Defaults to None.
"""
log_file = self._get_log_file(output_image, log_file)
self._set_log_file(log_file)

if log_file is None:
# TODO this will create trouble for non ".nii.gz" files
log_file = output_image.replace(".nii.gz", ".log")
logging.basicConfig(
filename=log_file,
format="%(asctime)s,%(name)s,%(levelname)s,%(message)s",
datefmt="%H:%M:%S",
level=logging.DEBUG,
)
self.logger = logging.getLogger("registration")
logger.info(f"{'register: Starting registration':=^100}")

self.logger.info(f"Target image: {target_image}, Moving image: {moving_image}")
logger.info(f"Target image: {target_image}, Moving image: {moving_image}")
target_image = read_image_and_cast_to_32bit_float(target_image)
moving_image = read_image_and_cast_to_32bit_float(moving_image)

if self.parameters.get("bias"):
self.logger.info("Bias correcting images.")
logger.info("Bias correcting images.")
target_image = self._bias_correct_image(target_image)
moving_image = self._bias_correct_image(moving_image)

Expand All @@ -162,15 +160,14 @@ def register(
try:
self.transform = sitk.ReadTransform(transform_file)
compute_transform = False
logger.info("Specified transform file already exists.")
except:
self.logger.info(
logger.warning(
"Could not read transform file. Computing transform."
)
pass
if compute_transform:
self.logger.info(
f"Starting registration with parameters:: {self.parameters}"
)
logger.info(f"Starting registration with parameters:: {self.parameters}")
self.transform = self._register_image_and_get_transform(
target_image=target_image,
moving_image=moving_image,
Expand All @@ -183,15 +180,15 @@ def register(
"composite_transform", None
)
if self.parameters.get("composite_transform"):
self.logger.info("Applying composite transform.")
logger.info("Applying composite transform.")
transform_composite = sitk.ReadTransform(
self.parameters["composite_transform"]
)
self.transform = sitk.CompositeTransform(
transform_composite, self.transform
)

self.logger.info("Applying previous transforms.")
logger.info("Applying previous transforms.")
current_transform = None
for previous_transform in self.parameters["previous_transforms"]:
previous_transform = sitk.ReadTransform(previous_transform)
Expand All @@ -203,15 +200,13 @@ def register(

self.transform = current_transform

# no need for logging since resample_image will log by itself
logging.shutdown()

# resample the moving image to the target image
self.resample_image(
target_image=target_image,
moving_image=moving_image,
output_image=output_image,
transform_file=transform_file,
log_file=log_file,
)

def resample_image(
Expand All @@ -227,34 +222,38 @@ def resample_image(
Resample the moving image to the target image.
Args:
logger (logging.Logger): The logger to use.
target_image (Union[str, sitk.Image]): The target image.
moving_image (Union[str, sitk.Image]): The moving image.
output_image (str): The output image.
transform_file (str, optional): The transform file. Defaults to None.
"""

log_file = self._get_log_file(output_image, log_file)
self._set_log_file(log_file)

logger.info(f"{'resample_image: Starting transformation':=^100}")

# check if output image exists
if not os.path.exists(output_image):
if self.transform is not None:
if log_file is None:
# TODO this will create trouble for non ".nii.gz" file
log_file = output_image.replace(".nii.gz", ".log")
logging.basicConfig(
filename=log_file,
format="%(asctime)s,%(name)s,%(levelname)s,%(message)s",
datefmt="%H:%M:%S",
level=logging.DEBUG,
)
self.logger = logging.getLogger("registration")

self.logger.info(
if transform_file is not None:
assert os.path.isfile(transform_file), "Transform file does not exist."
transform_from_file = None
try:
transform_from_file = sitk.ReadTransform(transform_file)
assert (
transform_from_file is not None
), "Transform could not be read."
except Exception as e:
logger.error(f"Could not read transform file: {e}")
return None

logger.info(
f"Target image: {target_image}, Moving image: {moving_image}, Transform file: {transform_file}"
)
target_image = read_image_and_cast_to_32bit_float(target_image)
moving_image = read_image_and_cast_to_32bit_float(moving_image)

self.logger.info("Resampling image.")
logger.info("Resampling image.")
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(target_image)
interpolator_type = self.interpolator_type.get(
Expand All @@ -263,14 +262,13 @@ def resample_image(
)
resampler.SetInterpolator(interpolator_type)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(self.transform)
resampler.SetTransform(transform_from_file)
output_image_struct = resampler.Execute(moving_image)
sitk.WriteImage(output_image_struct, output_image)
self.ssim_score = get_ssim(target_image, output_image_struct)
self.logger.info(
logger.info(
f"SSIM score of moving against target image: {self.ssim_score}"
)
logging.shutdown()

return self.ssim_score

Expand Down Expand Up @@ -390,7 +388,7 @@ def _register_image_and_get_transform(
for dim in range(dimension):
physical_units *= target_image.GetSpacing()[dim]

self.logger.info("Initializing registration.")
logger.info("Initializing registration.")
registration = sitk.ImageRegistrationMethod()
self.parameters["metric_parameters"] = self.parameters.get(
"metric_parameters", {}
Expand Down Expand Up @@ -761,21 +759,19 @@ def _register_image_and_get_transform(
registration.SetInterpolator(sitk.sitkLinear)

# registration.AddCommand(sitk.sitkIterationEvent, lambda: R)
self.logger.info("Starting registration.")
logger.info("Starting registration.")
output_transform = None
for _ in range(self.parameters["attempts"]):
try:
output_transform = registration.Execute(target_image, moving_image)
break
except RuntimeError as e:
self.logger.warning(
"Registration failed with error: %s. Retrying." % (e)
)
logger.warning("Registration failed with error: %s. Retrying." % (e))
continue

assert output_transform is not None, "Registration failed."

self.logger.info(
logger.info(
f"Final Optimizer Parameters:: convergence={registration.GetOptimizerConvergenceValue()}, iterations={registration.GetOptimizerIteration()}, metric={registration.GetMetricValue()}, stop condition={registration.GetOptimizerStopConditionDescription()}"
)

Expand All @@ -801,3 +797,45 @@ def _register_image_and_get_transform(
tmp.SetCenter(registration_transform_sitk.GetCenter())
registration_transform_sitk = tmp
return registration_transform_sitk

def _setup_logger(self):
logging.basicConfig(
format="[%(levelname)-8s | %(module)-15s | L%(lineno)-5d] | %(asctime)s: %(message)s",
datefmt="%H:%M:%S",
level=logging.DEBUG,
)
self.log_file_handler = None

def _set_log_file(self, log_file: str | Path) -> None:
"""Set the log file for the logger.
Args:
log_file (str | Path): log file path
"""
if self.log_file_handler:
logging.getLogger().removeHandler(self.log_file_handler)

parent_dir = os.path.dirname(log_file)
# create parent dir if the path is more than just a file name
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
self.log_file_handler = logging.FileHandler(log_file)
self.log_file_handler.setFormatter(
logging.Formatter(
"[%(levelname)-8s | %(module)-15s | L%(lineno)-5d] | %(asctime)s: %(message)s",
"%Y-%m-%dT%H:%M:%S%z",
)
)

# Add the file handler to the !root! logger
logging.getLogger().addHandler(self.log_file_handler)

def _get_log_file(self, output_image: str, log_file: str = None) -> str:
if log_file is None:
extensions = Path(output_image).suffixes
output_image_base = output_image
for ext in extensions:
output_image_base = output_image_base.replace(ext, "")
return output_image_base + ".log"
else:
return log_file
1 change: 1 addition & 0 deletions ereg/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .io import initialize_configuration, read_image_and_cast_to_32bit_float
23 changes: 22 additions & 1 deletion ereg/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# import os, tempfile, requestsw
import os
from typing import Union

import SimpleITK as sitk
Expand Down Expand Up @@ -42,3 +42,24 @@ def read_image_and_cast_to_32bit_float(
caster = sitk.CastImageFilter()
caster.SetOutputPixelType(sitk.sitkFloat32)
return caster.Execute(input_image)


def initialize_configuration(configuration: Union[str, dict]) -> dict:
"""
Initialize the configuration dictionary.
Args:
configuration (Union[str, dict]): The configuration file or dictionary.
Returns:
dict: The configuration dictionary.
"""
if configuration is not None:
if isinstance(configuration, str):
assert os.path.isfile(configuration), "Config file does not exist."
elif isinstance(configuration, dict):
pass
else:
raise ValueError("Config file must be a string or dictionary.")

return configuration
Loading
Loading