From 8cb0c15bd1edc146fa1c943c0e861bbaf6ac11d3 Mon Sep 17 00:00:00 2001
From: camillebrianceau
Date: Thu, 31 Oct 2024 16:45:44 +0100
Subject: [PATCH 1/2] some ideas
---
 clinicadl/API_test.py | 39 +-
 .../pipelines/generate/artifacts/cli.py | 2 +-
 .../pipelines/generate/hypometabolic/cli.py | 2 +-
 .../pipelines/generate/random/cli.py | 2 +-
 .../pipelines/generate/trivial/cli.py | 2 +-
 .../prepare_data/prepare_data_cli.py | 2 +-
 .../prepare_data_from_bids_cli.py | 2 +-
 clinicadl/dataset/caps_dataset.py | 612 ++++++------
 clinicadl/dataset/caps_dataset_config.py | 127 ----
 clinicadl/dataset/caps_dataset_utils.py | 193 ------
 clinicadl/dataset/caps_reader.py | 426 +++++++++++-
 clinicadl/dataset/concat.py | 129 +++-
 clinicadl/dataset/config/data.py | 75 +++
 clinicadl/dataset/config/extraction.py | 474 +++++++++++++-
 clinicadl/dataset/config/preprocessing.py | 172 ++++-
 clinicadl/dataset/config/utils.py | 87 +++
 clinicadl/dataset/data_config.py | 164 -----
 clinicadl/dataset/dataloader_config.py | 18 -
 clinicadl/dataset/prepare_data/__init__.py | 0
 .../dataset/prepare_data/prepare_data.py | 230 -------
 .../prepare_data/prepare_data_utils.py | 442 -------------
 clinicadl/dataset/utils.py | 162 +----
 clinicadl/experiment_manager/maps_manager.py | 2 +-
 clinicadl/generate/generate_utils.py | 6 +-
 clinicadl/hugging_face/hugging_face.py | 2 +-
 clinicadl/model/clinicadl_model.py | 2 +-
 clinicadl/quality_check/t1_linear/utils.py | 4 +-
 clinicadl/trainer/old_trainer.py | 2 +-
 clinicadl/trainer/trainer.py | 2 +-
 clinicadl/transforms/transforms.py | 28 +-
 clinicadl/utils/enum.py | 26 +
 clinicadl/utils/exceptions.py | 4 +
 clinicadl/utils/iotools/__init__.py | 42 --
 clinicadl/utils/iotools/clinica_utils.py | 7 +-
 clinicadl/utils/iotools/train_utils.py | 4 +-
 clinicadl/utils/iotools/utils.py | 66 +-
 36 files changed, 1645 insertions(+), 1914 deletions(-)
 delete mode 100644 clinicadl/dataset/caps_dataset_config.py
 delete mode 100644 clinicadl/dataset/caps_dataset_utils.py
 create mode 100644 clinicadl/dataset/config/data.py
 create mode 100644 clinicadl/dataset/config/utils.py
 delete mode 100644 clinicadl/dataset/data_config.py
 delete mode 100644 clinicadl/dataset/dataloader_config.py
 delete mode 100644 clinicadl/dataset/prepare_data/__init__.py
 delete mode 100644 clinicadl/dataset/prepare_data/prepare_data.py
 delete mode 100644 clinicadl/dataset/prepare_data/prepare_data_utils.py

diff --git a/clinicadl/API_test.py b/clinicadl/API_test.py
index a6eb9fa72..a18440f93 100644
--- a/clinicadl/API_test.py
+++ b/clinicadl/API_test.py
@@ -1,6 +1,6 @@
 from pathlib import Path
-import torchio
+import torchio.transforms as transforms
 from clinicadl.dataset.caps_dataset import (
     CapsDatasetPatch,
@@ -30,26 +30,32 @@
 from clinicadl.splitter.kfold import KFolder
 from clinicadl.splitter.split import get_single_split, split_tsv
 from clinicadl.trainer.trainer import Trainer
-from clinicadl.transforms.transforms import Transforms
+from clinicadl.transforms.config import TransformsConfig
 # Create the Maps Manager / Read/write manager /
 maps_path = Path("/")
-manager = ExperimentManager(maps_path, overwrite=False)
+manager = ExperimentManager(
+    maps_path, overwrite=False
+)  # to add in the manager: mlflow/ profiler/ etc ...
 caps_directory = Path("caps_directory")  # output of clinica pipelines
 caps_reader = CapsReader(caps_directory, manager=manager)
 preprocessing_1 = caps_reader.get_preprocessing("t1-linear")
-extraction_1 = caps_reader.extract_slice(preprocessing=preprocessing_1, arg_slice=2)
-transforms_1 = Transforms(
-    data_augmentation=[torchio.t1, torchio.t2],
-    image_transforms=[torchio.t1, torchio.t2],
-    object_transforms=[torchio.t1, torchio.t2],
+caps_reader.prepare_data(
+    preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2
+)  # doesn't return anything -> just extracts the image tensor and computes some information for each image
+
+
+transforms_1 = TransformsConfig(
+    data_augmentation=[transforms.Crop, transforms.Transform],
+    image_transforms=[transforms.Blur, transforms.Ghosting],
+    object_transforms=[transforms.BiasField, transforms.Motion],
 )  # not mandatory
 preprocessing_2 = caps_reader.get_preprocessing("pet-linear")
 extraction_2 = caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch=2)
-transforms_2 = Transforms(
+transforms_2 = TransformsConfig(
     data_augmentation=[torchio.t2],
     image_transforms=[torchio.t1],
     object_transforms=[torchio.t1, torchio.t2],
@@ -151,16 +157,23 @@
 caps_directory = Path("caps_directory")  # output of clinica pipelines
 caps_reader = CapsReader(caps_directory, manager=manager)
-extraction_1 = caps_reader.extract_image(preprocessing=T1PreprocessingConfig())
-transforms_1 = Transforms(
-    data_augmentation=[torchio.transforms.RandomMotion]
+caps_reader.prepare_data(
+    preprocessing=T1PreprocessingConfig(),
+    data_tsv=Path(""),
+    n_proc=2,
+    use_uncropped_images=False,
+)
+transforms_1 = TransformsConfig(
+    data_augmentation=[transforms.RandomMotion],  # default = no transforms
+    image_transforms=[transforms.Noise],  # default = MinMax
+    extraction=ExtractionMethod.PATCH,  # default = Image
+    object_transforms=[transforms.BiasField],  # default = none
 )  # not mandatory
 sub_ses_tsv = Path("")
 split_dir = split_tsv(sub_ses_tsv)  # -> creer un test.tsv et un train.tsv
 dataset_t1_image = caps_reader.get_dataset(
-    extraction=extraction_1,
     preprocessing=T1PreprocessingConfig(),
     sub_ses_tsv=split_dir / "train.tsv",
     transforms=transforms_1,
diff --git a/clinicadl/commandline/pipelines/generate/artifacts/cli.py b/clinicadl/commandline/pipelines/generate/artifacts/cli.py
index 68d1ec869..5cf5387fa 100644
--- a/clinicadl/commandline/pipelines/generate/artifacts/cli.py
+++ b/clinicadl/commandline/pipelines/generate/artifacts/cli.py
@@ -13,8 +13,8 @@
     preprocessing,
 )
 from clinicadl.commandline.pipelines.generate.artifacts import options as artifacts
+from clinicadl.dataset.caps_dataset.caps_dataset_utils import find_file_type
 from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from clinicadl.dataset.caps_dataset_utils import find_file_type
 from clinicadl.generate.generate_config import GenerateArtifactsConfig
 from clinicadl.generate.generate_utils import (
     load_and_check_tsv,
diff --git a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py
index 82c4e5cb3..9fe878578 100644
--- a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py
+++ b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py
@@ -12,8 +12,8 @@
 from clinicadl.commandline.pipelines.generate.hypometabolic import (
     options as hypometabolic,
 )
+from clinicadl.dataset.caps_dataset.caps_dataset_utils import find_file_type
 from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig
-from 
clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateHypometabolicConfig from clinicadl.generate.generate_utils import ( load_and_check_tsv, diff --git a/clinicadl/commandline/pipelines/generate/random/cli.py b/clinicadl/commandline/pipelines/generate/random/cli.py index 8ea26a5d0..f14ccb6bd 100644 --- a/clinicadl/commandline/pipelines/generate/random/cli.py +++ b/clinicadl/commandline/pipelines/generate/random/cli.py @@ -14,8 +14,8 @@ preprocessing, ) from clinicadl.commandline.pipelines.generate.random import options as random +from clinicadl.dataset.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateRandomConfig from clinicadl.generate.generate_utils import ( load_and_check_tsv, diff --git a/clinicadl/commandline/pipelines/generate/trivial/cli.py b/clinicadl/commandline/pipelines/generate/trivial/cli.py index d683865f2..3cdc16267 100644 --- a/clinicadl/commandline/pipelines/generate/trivial/cli.py +++ b/clinicadl/commandline/pipelines/generate/trivial/cli.py @@ -13,8 +13,8 @@ preprocessing, ) from clinicadl.commandline.pipelines.generate.trivial import options as trivial +from clinicadl.dataset.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateTrivialConfig from clinicadl.generate.generate_utils import ( im_loss_roi_gaussian_distribution, diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py index d162dcf97..17b418249 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py @@ -8,7 +8,7 @@ preprocessing, ) from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.prepare_data.prepare_data import DeepLearningPrepareData +from clinicadl.dataset.caps_reader.prepare_data import DeepLearningPrepareData from clinicadl.utils.enum import ExtractionMethod diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py index f4f888a71..a71cc1ca2 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py @@ -8,7 +8,7 @@ preprocessing, ) from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.prepare_data.prepare_data import DeepLearningPrepareData +from clinicadl.dataset.caps_reader.prepare_data import DeepLearningPrepareData from clinicadl.utils.enum import ExtractionMethod diff --git a/clinicadl/dataset/caps_dataset.py b/clinicadl/dataset/caps_dataset.py index d45dc5aa6..a10794f03 100644 --- a/clinicadl/dataset/caps_dataset.py +++ b/clinicadl/dataset/caps_dataset.py @@ -5,76 +5,65 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union +import nibabel as nib import numpy as np import pandas as pd import torch from torch.utils.data import Dataset -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.dataset.config.extraction import ( + 
ExtractionConfig, ExtractionImageConfig, ExtractionPatchConfig, ExtractionROIConfig, ExtractionSliceConfig, ) -from clinicadl.dataset.prepare_data.prepare_data_utils import ( - compute_discarded_slices, - extract_patch_path, - extract_patch_tensor, - extract_roi_path, - extract_roi_tensor, - extract_slice_path, - extract_slice_tensor, - find_mask_path, -) +from clinicadl.dataset.config.preprocessing import PreprocessingConfig +from clinicadl.dataset.utils import CapsDatasetOutput from clinicadl.transforms.config import TransformsConfig from clinicadl.utils.enum import ( + STR, + ExtractionMethod, Pattern, Preprocessing, SliceDirection, SliceMode, + SubFolder, + Suffix, Template, ) from clinicadl.utils.exceptions import ( ClinicaDLCAPSError, ClinicaDLTSVError, ) +from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader logger = getLogger("clinicadl") -################################# -# Datasets loaders -################################# class CapsDataset(Dataset): """Abstract class for all derived CapsDatasets.""" def __init__( self, - config: CapsDatasetConfig, - label_presence: bool, - preprocessing_dict: Dict[str, Any], + caps_directory: Path, + data_df: pd.DataFrame, + extraction: ExtractionConfig, + preprocessing: PreprocessingConfig, + transforms: TransformsConfig, ): - self.label_presence = label_presence - self.eval_mode = False - self.config = config - self.preprocessing_dict = preprocessing_dict - - if not hasattr(self, "elem_index"): - raise AttributeError( - "Child class of CapsDataset must set elem_index attribute." - ) - if not hasattr(self, "mode"): - raise AttributeError("Child class of CapsDataset, must set mode attribute.") + self.caps_directory = caps_directory + self.subjects_directory = caps_directory / STR.SUBJECTS.value + self.extraction = extraction + self.preprocessing = preprocessing + self.transforms = transforms - self.df = self.config.data.data_df + self.df = data_df mandatory_col = { - "participant_id", - "session_id", - "cohort", + STR.PARTICIPANT_ID.value, + STR.SESSION_ID.value, + STR.COHORT.value, } - if label_presence and self.config.data.label is not None: - mandatory_col.add(self.config.data.label) if not mandatory_col.issubset(set(self.df.columns.values)): raise ClinicaDLTSVError( @@ -82,40 +71,13 @@ def __init__( f"Columns should include {mandatory_col}" ) self.elem_per_image = self.num_elem_per_image() - self.size = self[0]["image"].size() + self.size = self[0].image.size() @property @abc.abstractmethod def elem_index(self): pass - def label_fn(self, target: Union[str, float, int]) -> Union[float, int, None]: - """ - Returns the label value usable in criterion. - - Args: - target: value of the target. - Returns: - label: value of the label usable in criterion. - """ - # Reconstruction case (no label) - if self.config.data.label is None: - return None - # Regression case (no label code) - elif self.config.data.label_code is None: - return np.float32([target]) - # Classification case (label + label_code dict) - else: - return self.config.data.label_code[str(target)] - - def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]: - """ - Returns the label value usable in criterion. - - """ - domain_code = {"t1": 0, "flair": 1} - return domain_code[str(target)] - def __len__(self) -> int: return len(self.df) * self.elem_per_image @@ -130,50 +92,48 @@ def _get_image_path(self, participant: str, session: str, cohort: str) -> Path: Returns: image_path: path to the tensor containing the whole image. 
""" - from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader # Try to find .nii.gz file try: - folder, file_type = self.config.compute_folder_and_file_type() - results = clinicadl_file_reader( [participant], [session], - self.config.data.caps_dict[cohort], - file_type.model_dump(), + self.caps_directory, + self.preprocessing.file_type.model_dump(), ) logger.debug(f"clinicadl_file_reader output: {results}") filepath = Path(results[0][0]) - image_filename = filepath.name.replace(".nii.gz", ".pt") + image_filename = filepath.name.replace(Suffix.PT.value, Suffix.NII_GZ.value) image_dir = ( - self.config.data.caps_dict[cohort] - / "subjects" + self.caps_directory + / STR.SUBJECTS.value / participant / session - / "deeplearning_prepare_data" - / "image_based" - / folder + / STR.DEEP_L_P_DATA.value + / SubFolder.IMAGE.value + / self.preprocessing.compute_folder() ) image_path = image_dir / image_filename # Try to find .pt file except ClinicaDLCAPSError: - folder, file_type = self.config.compute_folder_and_file_type() - file_type.pattern = file_type.pattern.replace(".nii.gz", ".pt") + self.preprocessing.file_type.pattern = ( + self.preprocessing.file_type.pattern.replace( + Suffix.NII_GZ.value, Suffix.PT.value + ) + ) results = clinicadl_file_reader( [participant], [session], - self.config.data.caps_dict[cohort], - file_type.model_dump(), + self.caps_directory, + self.preprocessing.file_type.model_dump(), ) filepath = results[0] image_path = Path(filepath[0]) return image_path - def _get_meta_data( - self, idx: int - ) -> Tuple[str, str, str, Union[float, int, None], int]: + def _get_meta_data(self, idx: int) -> Tuple[str, str, str, Union[float, int, None]]: """ Gets all meta data necessary to compute the path with _get_image_path @@ -187,26 +147,16 @@ def _get_meta_data( label (str or float or int): value of the label to be used in criterion. """ image_idx = idx // self.elem_per_image - participant = self.df.at[image_idx, "participant_id"] - session = self.df.at[image_idx, "session_id"] - cohort = self.df.at[image_idx, "cohort"] + participant = self.df.at[image_idx, STR.PARTICIPANT_ID.value] + session = self.df.at[image_idx, STR.SESSION_ID.value] + cohort = self.df.at[image_idx, STR.COHORT.value] if self.elem_index is None: elem_idx = idx % self.elem_per_image else: elem_idx = self.elem_index - if self.label_presence and self.config.data.label is not None: - target = self.df.at[image_idx, self.config.data.label] - label = self.label_fn(target) - else: - label = -1 - if "domain" in self.df.columns: - domain = self.df.at[image_idx, "domain"] - domain = self.domain_fn(domain) - else: - domain = "" # TO MODIFY - return participant, session, cohort, elem_idx, label, domain + return participant, session, cohort, elem_idx def _get_full_image(self) -> torch.Tensor: """ @@ -216,24 +166,20 @@ def _get_full_image(self) -> torch.Tensor: Returns: image tensor of the full image first image. 
""" - import nibabel as nib - from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader - - participant_id = self.df.loc[0, "participant_id"] - session_id = self.df.loc[0, "session_id"] - cohort = self.df.loc[0, "cohort"] + participant_id = self.df.at[0, STR.PARTICIPANT_ID.value] + session_id = self.df.at[0, STR.SESSION_ID.value] + cohort = self.df.at[0, STR.COHORT.value] try: image_path = self._get_image_path(participant_id, session_id, cohort) image = torch.load(image_path, weights_only=True) except IndexError: - file_type = self.config.extraction.file_type results = clinicadl_file_reader( [participant_id], [session_id], - self.config.data.caps_dict[cohort], - file_type.model_dump(), + self.caps_directory, + self.preprocessing.file_type.model_dump(), ) image_nii = nib.loadsave.load(results[0]) image_np = image_nii.get_fdata() @@ -242,7 +188,7 @@ def _get_full_image(self) -> torch.Tensor: return image @abc.abstractmethod - def __getitem__(self, idx: int) -> Dict[str, Any]: + def __getitem__(self, idx: int) -> CapsDatasetOutput: """ Gets the sample containing all the information needed for training and testing tasks. @@ -252,8 +198,8 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: dictionary with following items: - "image" (torch.Tensor): the input given to the model, - "label" (int or float): the label used in criterion, - - "participant_id" (str): ID of the participant, - - "session_id" (str): ID of the session, + - PARTICIPANT_ID (str): ID of the participant, + - SESSION_ID (str): ID of the session, - f"{self.mode}_id" (int): number of the element, - "image_path": path to the image loaded in CAPS. @@ -265,15 +211,15 @@ def num_elem_per_image(self) -> int: """Computes the number of elements per image based on the full image.""" pass - def eval(self): - """Put the dataset on evaluation mode (data augmentation is not performed).""" - self.eval_mode = True - return self + # def eval(self): + # """Put the dataset on evaluation mode (data augmentation is not performed).""" + # self.eval_mode = True + # return self - def train(self): - """Put the dataset on training mode (data augmentation is performed).""" - self.eval_mode = False - return self + # def train(self): + # """Put the dataset on training mode (data augmentation is performed).""" + # self.eval_mode = False + # return self class CapsDatasetImage(CapsDataset): @@ -281,9 +227,11 @@ class CapsDatasetImage(CapsDataset): def __init__( self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - label_presence: bool = True, + caps_directory: Path, + data_df: pd.DataFrame, + extraction: ExtractionImageConfig, + preprocessing: PreprocessingConfig, + transforms: TransformsConfig, ): """ Args: @@ -298,14 +246,12 @@ def __init__( multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. 
""" - - self.mode = "image" - self.config = config - self.label_presence = label_presence super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, + caps_directory=caps_directory, + data_df=data_df, + extraction=extraction, + preprocessing=preprocessing, + transforms=transforms, ) @property @@ -313,26 +259,26 @@ def elem_index(self): return None def __getitem__(self, idx): - participant, session, cohort, _, label, domain = self._get_meta_data(idx) + participant, session, cohort, _ = self._get_meta_data(idx) image_path = self._get_image_path(participant, session, cohort) image = torch.load(image_path, weights_only=True) - train_trf, trf = self.config.transforms.get_transforms() + train_trf, trf = self.transforms.get_transforms() image = trf(image) - if self.config.transforms.train_transformations and not self.eval_mode: + if self.transforms.train_transformations and not self.eval_mode: # train_mode image = train_trf(image) - sample = { - "image": image, - "label": label, - "participant_id": participant, - "session_id": session, - "image_id": 0, - "image_path": image_path.as_posix(), - "domain": domain, - } + sample = CapsDatasetOutput( + image=image, + label=label, + participant_id=participant, + session_id=session, + image_id=0, + image_path=image_path, + mode=ExtractionMethod.IMAGE, + ) return sample @@ -343,10 +289,12 @@ def num_elem_per_image(self): class CapsDatasetPatch(CapsDataset): def __init__( self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], + caps_directory: Path, + data_df: pd.DataFrame, + extraction: ExtractionPatchConfig, + preprocessing: PreprocessingConfig, + transforms: TransformsConfig, patch_index: Optional[int] = None, - label_presence: bool = True, ): """ caps_directory: Directory of all the images. @@ -354,35 +302,30 @@ def __init__( preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. train_transformations: Optional transform to be applied only on training mode. 
""" - self.patch_index = patch_index - self.mode = "patch" - self.config = config - self.label_presence = label_presence - + # self.patch_index = patch_index + self.extraction = extraction super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, + caps_directory=caps_directory, + data_df=data_df, + extraction=extraction, + preprocessing=preprocessing, + transforms=transforms, ) - @property - def elem_index(self): - return self.patch_index + # @property + # def elem_index(self): + # return self.patch_index def __getitem__(self, idx): - participant, session, cohort, patch_idx, label, domain = self._get_meta_data( - idx - ) + participant, session, cohort, patch_idx = self._get_meta_data(idx) image_path = self._get_image_path(participant, session, cohort) - if self.config.extraction.save_features: + if self.extraction.save_features: patch_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" + SubFolder.IMAGE.value, SubFolder.PATCH.value ) - patch_filename = extract_patch_path( + patch_filename = self.extraction.extract_patch_path( image_path, - self.config.extraction.patch_size, - self.config.extraction.stride_size, patch_idx, ) patch_tensor = torch.load( @@ -391,70 +334,38 @@ def __getitem__(self, idx): else: image = torch.load(image_path, weights_only=True) - patch_tensor = extract_patch_tensor( + patch_tensor = self.extraction.extract_patch_tensor( image, - self.config.extraction.patch_size, - self.config.extraction.stride_size, patch_idx, ) - train_trf, trf = self.config.transforms.get_transforms() + train_trf, trf = self.transforms.get_transforms() patch_tensor = trf(patch_tensor) - if self.config.transforms.train_transformations and not self.eval_mode: + if self.transforms.train_transformations and not self.eval_mode: patch_tensor = train_trf(patch_tensor) - sample = { - "image": patch_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "patch_id": patch_idx, - } + sample = CapsDatasetOutput( + image=patch_tensor, + label=label, + participant_id=participant, + session_id=session, + image_id=patch_idx, + mode=ExtractionMethod.PATCH, + ) return sample - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - - image = self._get_full_image() - - patches_tensor = ( - image.unfold( - 1, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .unfold( - 2, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .unfold( - 3, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .contiguous() - ) - patches_tensor = patches_tensor.view( - -1, - self.config.extraction.patch_size, - self.config.extraction.patch_size, - self.config.extraction.patch_size, - ) - num_patches = patches_tensor.shape[0] - return num_patches - class CapsDatasetRoi(CapsDataset): def __init__( self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], + caps_directory: Path, + data_df: pd.DataFrame, + extraction: ExtractionROIConfig, + preprocessing: PreprocessingConfig, + transforms: TransformsConfig, roi_index: Optional[int] = None, - label_presence: bool = True, ): """ Args: @@ -472,126 +383,83 @@ def __init__( """ self.roi_index = roi_index - self.mode = "roi" - self.config = config - self.label_presence = label_presence - self.mask_paths, self.mask_arrays = self._get_mask_paths_and_tensors( - self.config.data.caps_directory, preprocessing_dict - ) + self.extraction = extraction super().__init__( 
- config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, + caps_directory=caps_directory, + data_df=data_df, + extraction=extraction, + preprocessing=preprocessing, + transforms=transforms, ) + self.mask_paths, self.mask_arrays = self._get_mask_paths_and_tensors() + @property def elem_index(self): return self.roi_index def __getitem__(self, idx): - participant, session, cohort, roi_idx, label, domain = self._get_meta_data(idx) + participant, session, cohort, roi_idx = self._get_meta_data(idx) image_path = self._get_image_path(participant, session, cohort) - if self.config.extraction.roi_list is None: + if self.extraction.roi_list is None: raise NotImplementedError( "Default regions are not available anymore in ClinicaDL. " "Please define appropriate masks and give a roi_list." ) - if self.config.extraction.save_features: + if self.extraction.save_features: mask_path = self.mask_paths[roi_idx] roi_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - roi_filename = extract_roi_path( - image_path, mask_path, self.config.extraction.roi_uncrop_output + SubFolder.IMAGE.value, SubFolder.ROI.value ) + roi_filename = self.extraction.extract_roi_path(image_path, mask_path) roi_tensor = torch.load(Path(roi_dir) / roi_filename, weights_only=True) else: image = torch.load(image_path, weights_only=True) mask_array = self.mask_arrays[roi_idx] - roi_tensor = extract_roi_tensor( - image, mask_array, self.config.extraction.uncropped_roi - ) + roi_tensor = self.extraction.extract_roi_tensor(image, mask_array) - train_trf, trf = self.config.transforms.get_transforms() + train_trf, trf = self.transforms.get_transforms() roi_tensor = trf(roi_tensor) - if self.config.transforms.train_transformations and not self.eval_mode: + if self.transforms.train_transformations and not self.eval_mode: roi_tensor = train_trf(roi_tensor) - sample = { - "image": roi_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "roi_id": roi_idx, - } + sample = CapsDatasetOutput( + image=roi_tensor, + label=label, + participant_id=participant, + session_id=session, + image_id=roi_idx, + mode=ExtractionMethod.ROI, + ) return sample def num_elem_per_image(self): if self.elem_index is not None: return 1 - if self.config.extraction.roi_list is None: + if self.extraction.roi_list is None: return 2 else: - return len(self.config.extraction.roi_list) + return len(self.extraction.roi_list) def _get_mask_paths_and_tensors( self, - caps_directory: Path, - preprocessing_dict: Dict[str, Any], ) -> Tuple[List[str], List]: """Loads the masks necessary to regions extraction""" - import nibabel as nib - - caps_dict = self.config.data.caps_dict - if len(caps_dict) > 1: - caps_directory = caps_dict[next(iter(caps_dict))] - logger.warning( - f"The equality of masks is not assessed for multi-cohort training. " - f"The masks stored in {caps_directory} will be used." - ) - try: - preprocessing_ = Preprocessing(preprocessing_dict["preprocessing"]) - except NotImplementedError: - print( - f"Template of preprocessing {preprocessing_dict['preprocessing']} " - f"is not defined." - ) - # Find template name and pattern - if preprocessing_.value == "custom": - template_name = preprocessing_dict["roi_custom_template"] - if template_name is None: - raise ValueError( - "Please provide a name for the template when preprocessing is `custom`." 
- ) - - pattern = preprocessing_dict["roi_custom_mask_pattern"] - if pattern is None: - raise ValueError( - "Please provide a pattern for the masks when preprocessing is `custom`." - ) - - else: - for template_ in Template: - if preprocessing_.name == template_.name: - template_name = template_ - - for pattern_ in Pattern: - if preprocessing_.name == pattern_.name: - pattern = pattern_ - - mask_location = caps_directory / "masks" / f"tpl-{template_name}" + mask_location = ( + self.caps_directory / "masks" / f"tpl-{self.extraction.roi_template}" + ) mask_paths, mask_arrays = list(), list() - for roi in self.config.extraction.roi_list: + for roi in self.extraction.roi_list: logger.info(f"Find mask for roi {roi}.") - mask_path, desc = find_mask_path(mask_location, roi, pattern, True) + mask_path, desc = self.extraction.find_mask_path(mask_location, roi) if mask_path is None: raise FileNotFoundError(desc) mask_nii = nib.loadsave.load(mask_path) @@ -604,10 +472,12 @@ def _get_mask_paths_and_tensors( class CapsDatasetSlice(CapsDataset): def __init__( self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], + caps_directory: Path, + data_df: pd.DataFrame, + extraction: ExtractionSliceConfig, + preprocessing: PreprocessingConfig, + transforms: TransformsConfig, slice_index: Optional[int] = None, - label_presence: bool = True, ): """ Args: @@ -624,14 +494,13 @@ def __init__( multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. """ self.slice_index = slice_index - self.mode = "slice" - self.config = config - self.label_presence = label_presence - self.preprocessing_dict = preprocessing_dict + self.extraction = extraction super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, + caps_directory=caps_directory, + data_df=data_df, + extraction=extraction, + preprocessing=preprocessing, + transforms=transforms, ) @property @@ -639,20 +508,16 @@ def elem_index(self): return self.slice_index def __getitem__(self, idx): - participant, session, cohort, slice_idx, label, domain = self._get_meta_data( - idx - ) - slice_idx = slice_idx + self.config.extraction.discarded_slices[0] + participant, session, cohort, slice_idx = self._get_meta_data(idx) + slice_idx = slice_idx + self.extraction.discarded_slices[0] image_path = self._get_image_path(participant, session, cohort) - if self.config.extraction.save_features: + if self.extraction.save_features: slice_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" + SubFolder.IMAGE.value, SubFolder.SLICE.value ) - slice_filename = extract_slice_path( + slice_filename = self.extraction.extract_slice_path( image_path, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, slice_idx, ) slice_tensor = torch.load( @@ -662,27 +527,26 @@ def __getitem__(self, idx): else: image_path = self._get_image_path(participant, session, cohort) image = torch.load(image_path, weights_only=True) - slice_tensor = extract_slice_tensor( + slice_tensor = self.extraction.extract_slice_tensor( image, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, slice_idx, ) - train_trf, trf = self.config.transforms.get_transforms() + train_trf, trf = self.transforms.get_transforms() slice_tensor = trf(slice_tensor) - if self.config.transforms.train_transformations and not self.eval_mode: + if self.transforms.train_transformations and not self.eval_mode: slice_tensor = train_trf(slice_tensor) - sample = { - "image": 
slice_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "slice_id": slice_idx, - } + sample = CapsDatasetOutput( + image=slice_tensor, + label=label, + participant_id=participant, + session_id=session, + image_id=slice_idx, + mode=ExtractionMethod.SLICE, + ) return sample @@ -690,128 +554,12 @@ def num_elem_per_image(self): if self.elem_index is not None: return 1 - if self.config.extraction.num_slices is not None: - return self.config.extraction.num_slices + if self.extraction.num_slices is not None: + return self.extraction.num_slices image = self._get_full_image() return ( - image.size(int(self.config.extraction.slice_direction) + 1) - - self.config.extraction.discarded_slices[0] - - self.config.extraction.discarded_slices[1] - ) - - -def return_dataset( - input_dir: Path, - data_df: pd.DataFrame, - preprocessing_dict: Dict[str, Any], - transforms_config: TransformsConfig, - label: Optional[str] = None, - label_code: Optional[Dict[str, int]] = None, - cnn_index: Optional[int] = None, - label_presence: bool = True, - multi_cohort: bool = False, -) -> CapsDataset: - """ - Return appropriate Dataset according to given options. - Args: - input_dir: path to a directory containing a CAPS structure. - data_df: List subjects, sessions and diagnoses. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied during training only. - all_transformations: Optional transform to be applied during training and evaluation. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - cnn_index: Index of the CNN in a multi-CNN paradigm (optional). - label_presence: If True the diagnosis will be extracted from the given DataFrame. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - Returns: - the corresponding dataset. - """ - if cnn_index is not None and preprocessing_dict["mode"] == "image": - raise NotImplementedError( - f"Multi-CNN is not implemented for {preprocessing_dict['mode']} mode." 
- ) - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - preprocessing_type=preprocessing_dict["preprocessing"], - preprocessing=preprocessing_dict["preprocessing"], - extraction=preprocessing_dict["mode"], - caps_directory=input_dir, - data_df=data_df, - label=label, - label_code=label_code, - multi_cohort=multi_cohort, - ) - config.transforms = transforms_config - - if preprocessing_dict["mode"] == "image": - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetImage( - config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "patch": - assert isinstance(config.extraction, ExtractionPatchConfig) - config.extraction.patch_size = preprocessing_dict["patch_size"] - config.extraction.stride_size = preprocessing_dict["stride_size"] - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetPatch( - config, - patch_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "roi": - assert isinstance(config.extraction, ExtractionROIConfig) - config.extraction.roi_list = preprocessing_dict["roi_list"] - config.extraction.roi_uncrop_output = preprocessing_dict["uncropped_roi"] - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetRoi( - config, - roi_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "slice": - assert isinstance(config.extraction, ExtractionSliceConfig) - config.extraction.slice_direction = SliceDirection( - str(preprocessing_dict["slice_direction"]) - ) - config.extraction.slice_mode = SliceMode(preprocessing_dict["slice_mode"]) - config.extraction.discarded_slices = compute_discarded_slices( - preprocessing_dict["discarded_slices"] - ) - config.extraction.num_slices = ( - None - if "num_slices" not in preprocessing_dict - else preprocessing_dict["num_slices"] - ) - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetSlice( - config, - slice_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - else: - raise NotImplementedError( - f"Mode {preprocessing_dict['mode']} is not implemented." 
+ image.size(int(self.extraction.slice_direction) + 1) + - self.extraction.discarded_slices[0] + - self.extraction.discarded_slices[1] ) diff --git a/clinicadl/dataset/caps_dataset_config.py b/clinicadl/dataset/caps_dataset_config.py deleted file mode 100644 index 0eac3ffd3..000000000 --- a/clinicadl/dataset/caps_dataset_config.py +++ /dev/null @@ -1,127 +0,0 @@ -from pathlib import Path -from typing import Optional, Tuple, Union - -from pydantic import BaseModel, ConfigDict - -from clinicadl.dataset.config import extraction -from clinicadl.dataset.config.preprocessing import ( - CustomPreprocessingConfig, - DTIPreprocessingConfig, - FlairPreprocessingConfig, - PETPreprocessingConfig, - PreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.data_config import DataConfig -from clinicadl.dataset.dataloader_config import DataLoaderConfig -from clinicadl.dataset.utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, -) -from clinicadl.transforms.config import TransformsConfig -from clinicadl.utils.enum import ExtractionMethod, Preprocessing -from clinicadl.utils.iotools.clinica_utils import FileType - - -def get_extraction(extract_method: ExtractionMethod): - if extract_method == ExtractionMethod.ROI: - return extraction.ExtractionROIConfig - elif extract_method == ExtractionMethod.SLICE: - return extraction.ExtractionSliceConfig - elif extract_method == ExtractionMethod.IMAGE: - return extraction.ExtractionImageConfig - elif extract_method == ExtractionMethod.PATCH: - return extraction.ExtractionPatchConfig - else: - raise ValueError(f"Preprocessing {extract_method.value} is not implemented.") - - -def get_preprocessing(preprocessing_type: Preprocessing): - if preprocessing_type == Preprocessing.T1_LINEAR: - return T1PreprocessingConfig - elif preprocessing_type == Preprocessing.PET_LINEAR: - return PETPreprocessingConfig - elif preprocessing_type == Preprocessing.FLAIR_LINEAR: - return FlairPreprocessingConfig - elif preprocessing_type == Preprocessing.CUSTOM: - return CustomPreprocessingConfig - elif preprocessing_type == Preprocessing.DWI_DTI: - return DTIPreprocessingConfig - else: - raise ValueError( - f"Preprocessing {preprocessing_type.value} is not implemented." - ) - - -class CapsDatasetConfig(BaseModel): - """Config class for CapsDataset object. - - caps_directory, preprocessing_json, extract_method, preprocessing - are arguments that must be passed by the user. 
- - transforms isn't optional because there is always at least one transform (NanRemoval) - """ - - data: DataConfig - dataloader: DataLoaderConfig - extraction: extraction.ExtractionConfig - preprocessing: PreprocessingConfig - transforms: TransformsConfig - - # pydantic config - model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) - - @classmethod - def from_preprocessing_and_extraction_method( - cls, - preprocessing_type: Union[str, Preprocessing], - extraction: Union[str, ExtractionMethod], - **kwargs, - ): - return cls( - data=DataConfig(**kwargs), - dataloader=DataLoaderConfig(**kwargs), - preprocessing=get_preprocessing(Preprocessing(preprocessing_type))( - **kwargs - ), - extraction=get_extraction(ExtractionMethod(extraction))(**kwargs), - transforms=TransformsConfig(**kwargs), - ) - - def compute_folder_and_file_type( - self, from_bids: Optional[Path] = None - ) -> Tuple[str, FileType]: - preprocessing = self.preprocessing.preprocessing - if from_bids is not None: - if isinstance(self.preprocessing, CustomPreprocessingConfig): - mod_subfolder = Preprocessing.CUSTOM.value - file_type = FileType( - pattern=f"*{self.preprocessing.custom_suffix}", - description="Custom suffix", - ) - else: - mod_subfolder = preprocessing - file_type = bids_nii(self.preprocessing) - - elif preprocessing not in Preprocessing: - raise NotImplementedError( - f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if isinstance(self.preprocessing, T1PreprocessingConfig) or isinstance( - self.preprocessing, FlairPreprocessingConfig - ): - file_type = linear_nii(self.preprocessing) - elif isinstance(self.preprocessing, PETPreprocessingConfig): - file_type = pet_linear_nii(self.preprocessing) - elif isinstance(self.preprocessing, DTIPreprocessingConfig): - file_type = dwi_dti(self.preprocessing) - elif isinstance(self.preprocessing, CustomPreprocessingConfig): - file_type = FileType( - pattern=f"*{self.preprocessing.custom_suffix}", - description="Custom suffix", - ) - return mod_subfolder, file_type diff --git a/clinicadl/dataset/caps_dataset_utils.py b/clinicadl/dataset/caps_dataset_utils.py deleted file mode 100644 index b54ba373d..000000000 --- a/clinicadl/dataset/caps_dataset_utils.py +++ /dev/null @@ -1,193 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.config.preprocessing import ( - CustomPreprocessingConfig, - DTIPreprocessingConfig, - FlairPreprocessingConfig, - PETPreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, -) -from clinicadl.utils.enum import Preprocessing -from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.iotools.clinica_utils import FileType - - -def compute_folder_and_file_type( - config: CapsDatasetConfig, from_bids: Optional[Path] = None -) -> Tuple[str, FileType]: - preprocessing = config.preprocessing.preprocessing - if from_bids is not None: - if isinstance(config.preprocessing, CustomPreprocessingConfig): - mod_subfolder = Preprocessing.CUSTOM.value - file_type = FileType( - pattern=f"*{config.preprocessing.custom_suffix}", - description="Custom suffix", - ) - else: - mod_subfolder = preprocessing - file_type = bids_nii(config.preprocessing) - - elif preprocessing not in Preprocessing: - 
raise NotImplementedError( - f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if isinstance(config.preprocessing, T1PreprocessingConfig) or isinstance( - config.preprocessing, FlairPreprocessingConfig - ): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - file_type = pet_linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, DTIPreprocessingConfig): - file_type = dwi_dti(config.preprocessing) - elif isinstance(config.preprocessing, CustomPreprocessingConfig): - file_type = FileType( - pattern=f"*{config.preprocessing.custom_suffix}", - description="Custom suffix", - ) - return mod_subfolder, file_type - - -def find_file_type(config: CapsDatasetConfig) -> FileType: - if isinstance(config.preprocessing, T1PreprocessingConfig): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - if ( - config.preprocessing.tracer is None - or config.preprocessing.suvr_reference_region is None - ): - raise ClinicaDLArgumentError( - "`tracer` and `suvr_reference_region` must be defined " - "when using `pet-linear` preprocessing." - ) - file_type = pet_linear_nii(config.preprocessing) - else: - raise NotImplementedError( - f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.preprocessing.value}" - ) - - return file_type - - -def read_json(json_path: Path) -> Dict[str, Any]: - """ - Ensures retro-compatibility between the different versions of ClinicaDL. - - Parameters - ---------- - json_path: Path - path to the JSON file summing the parameters of a MAPS. - - Returns - ------- - A dictionary of training parameters. 
- """ - from clinicadl.utils.iotools.utils import path_decoder - - with json_path.open(mode="r") as f: - parameters = json.load(f, object_hook=path_decoder) - # Types of retro-compatibility - # Change arg name: ex network --> model - # Change arg value: ex for preprocessing: mni --> t1-extensive - # New arg with default hard-coded value --> discarded_slice --> 20 - retro_change_name = { - "model": "architecture", - "multi": "multi_network", - "minmaxnormalization": "normalize", - "num_workers": "n_proc", - "mode": "extract_method", - } - - retro_add = { - "optimizer": "Adam", - "loss": None, - } - - for old_name, new_name in retro_change_name.items(): - if old_name in parameters: - parameters[new_name] = parameters[old_name] - del parameters[old_name] - - for name, value in retro_add.items(): - if name not in parameters: - parameters[name] = value - - if "extract_method" in parameters: - parameters["mode"] = parameters["extract_method"] - # Value changes - if "use_cpu" in parameters: - parameters["gpu"] = not parameters["use_cpu"] - del parameters["use_cpu"] - if "nondeterministic" in parameters: - parameters["deterministic"] = not parameters["nondeterministic"] - del parameters["nondeterministic"] - - # Build preprocessing_dict - if "preprocessing_dict" not in parameters: - parameters["preprocessing_dict"] = {"mode": parameters["mode"]} - preprocessing_options = [ - "preprocessing", - "use_uncropped_image", - "prepare_dl", - "custom_suffix", - "tracer", - "suvr_reference_region", - "patch_size", - "stride_size", - "slice_direction", - "slice_mode", - "discarded_slices", - "roi_list", - "uncropped_roi", - "roi_custom_suffix", - "roi_custom_template", - "roi_custom_mask_pattern", - ] - for preprocessing_var in preprocessing_options: - if preprocessing_var in parameters: - parameters["preprocessing_dict"][preprocessing_var] = parameters[ - preprocessing_var - ] - del parameters[preprocessing_var] - - # Add missing parameters in previous version of extract - if "use_uncropped_image" not in parameters["preprocessing_dict"]: - parameters["preprocessing_dict"]["use_uncropped_image"] = False - - if ( - "prepare_dl" not in parameters["preprocessing_dict"] - and parameters["mode"] != "image" - ): - parameters["preprocessing_dict"]["prepare_dl"] = False - - if ( - parameters["mode"] == "slice" - and "slice_mode" not in parameters["preprocessing_dict"] - ): - parameters["preprocessing_dict"]["slice_mode"] = "rgb" - - if "preprocessing" not in parameters: - parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"] - - from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=parameters["mode"], - preprocessing_type=parameters["preprocessing"], - **parameters, - ) - if "file_type" not in parameters["preprocessing_dict"]: - _, file_type = compute_folder_and_file_type(config) - parameters["preprocessing_dict"]["file_type"] = file_type.model_dump() - - return parameters diff --git a/clinicadl/dataset/caps_reader.py b/clinicadl/dataset/caps_reader.py index 14199616e..b11a8c553 100644 --- a/clinicadl/dataset/caps_reader.py +++ b/clinicadl/dataset/caps_reader.py @@ -1,7 +1,22 @@ +import json +from enum import Enum +from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Optional, Union -from clinicadl.dataset.caps_dataset import CapsDataset +import nibabel as nib +import pandas as pd +import torch +from joblib import Parallel, delayed +from 
torch import save as save_tensor
+
+from clinicadl.dataset.caps_dataset import (
+    CapsDataset,
+    CapsDatasetImage,
+    CapsDatasetPatch,
+    CapsDatasetRoi,
+    CapsDatasetSlice,
+)
 from clinicadl.dataset.config.extraction import (
     ExtractionConfig,
     ExtractionImageConfig,
@@ -9,54 +24,419 @@
     ExtractionROIConfig,
     ExtractionSliceConfig,
 )
-from clinicadl.dataset.config.preprocessing import PreprocessingConfig
+from clinicadl.dataset.config.preprocessing import (
+    CustomPreprocessingConfig,
+    DTIPreprocessingConfig,
+    FlairPreprocessingConfig,
+    PETPreprocessingConfig,
+    PreprocessingConfig,
+    T1PreprocessingConfig,
+    T2PreprocessingConfig,
+)
+from clinicadl.dataset.config.utils import get_preprocessing
 from clinicadl.experiment_manager.experiment_manager import ExperimentManager
 from clinicadl.transforms.config import TransformsConfig
+from clinicadl.utils.enum import (
+    DTIMeasure,
+    DTISpace,
+    Preprocessing,
+    SliceDirection,
+    SliceMode,
+    SubFolder,
+    Suffix,
+    SUVRReferenceRegions,
+    Tracer,
+)
+from clinicadl.utils.exceptions import ClinicaDLArgumentError
+from clinicadl.utils.iotools.clinica_utils import (
+    check_caps_folder,
+    clinicadl_file_reader,
+    container_from_filename,
+    create_subs_sess_list,
+    determine_caps_or_bids,
+    get_subject_session_list,
+)
+from clinicadl.utils.iotools.utils import path_encoder
+
+logger = getLogger("clinicadl.caps_reader")
 class CapsReader:
-    def __init__(self, caps_directory: Path, manager: ExperimentManager):
+    def __init__(
+        self,
+        caps_directory: Path,
+        manager: ExperimentManager,
+        from_bids: Optional[Path] = None,
+    ):
+        """TO COMPLETE"""
+
+        self.manager = manager
+        self.get_input_directory(caps_directory, from_bids)
+
+    def create_caps_json(self):
+        caps_json = self.input_directory / "caps.json"
+        if caps_json.is_file():
+            with open(caps_json, "r") as f:
+                caps_data = json.load(f)
+            return caps_data
+
+        else:
+            with open(caps_json, "w") as f:
+                json.dump({}, f)
+            return {}
+
+    def get_input_directory(
+        self, caps_directory: Path, from_bids: Optional[Path] = None
+    ):
+        # Get subject and session list
+        if from_bids is not None:
+            try:
+                self.input_directory = Path(from_bids)
+            except ClinicaDLArgumentError:
+                logger.warning("Your BIDS directory doesn't exist.")
+            logger.debug(f"BIDS directory: {self.input_directory}.")
+            self.bids = True
+        else:
+            self.input_directory = caps_directory
+            check_caps_folder(self.input_directory)
+            logger.debug(f"CAPS directory: {self.input_directory}.")
+            self.bids = False
+
+    def prepare_extraction(
+        self, preprocessing: PreprocessingConfig, data_tsv: Optional[Path] = None
+    ):
+        subjects, sessions = get_subject_session_list(
+            self.input_directory, data_tsv, self.bids, False, None
+        )
+        logger.debug(f"List of subjects: \n{subjects}.")
+        logger.debug(f"List of sessions: \n{sessions}.")
+
+        file_type = preprocessing.get_filetype()
+
+        input_files = clinicadl_file_reader(
+            subjects, sessions, self.input_directory, file_type.model_dump()
+        )[0]
+        logger.debug(f"Selected image file name list: {input_files}.")
+
+        return input_files
+
+    def prepare_data(
+        self,
+        preprocessing: PreprocessingConfig,
+        data_tsv: Optional[Path] = None,
+        n_proc: int = 2,
+        use_uncropped_images: bool = False,
+    ) -> ExtractionImageConfig:
+        """TO COMPLETE"""
+
+        extraction = ExtractionImageConfig(use_uncropped_image=use_uncropped_images)
+        input_files = self.prepare_extraction(preprocessing, data_tsv=data_tsv)
+
+        def prepare_image(file: Path):
+            output_file_dir = (
+                self.input_directory
+                / container_from_filename(file)
+                / 
"deeplearning_prepare_data" + / SubFolder.IMAGE.value + / preprocessing.compute_folder(self.bids) + ) + + output_file_dir.mkdir(parents=True, exist_ok=True) + output_file = output_file_dir / file.name.replace( + Suffix.NII_GZ.value, Suffix.PT.value + ) + + logger.debug(f"Processing of {file}.") + image_array = nib.loadsave.load(file).get_fdata(dtype="float32") + + # get some important infos about the image + info_df = pd.DataFrame(columns=["mean", "std", "max", "min"]) + info_df.loc[0] = [ + image_array.mean(), + image_array.std(), + image_array.max(), + image_array.min(), + ] + info_df.to_csv("image_info.tsv", sep="\t") + + # extract and save the image tensor + image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() + save_tensor(image_tensor.clone(), output_file) + logger.debug(f"Output tensor saved at {output_file}") + + Parallel(n_jobs=n_proc)(delayed(prepare_image)(file) for file in input_files) + + return extraction + + def write_output_imgs( + self, + output_mode: list, + file: Path, + subfolder: SubFolder, + preprocessing: PreprocessingConfig, + ): + # Write the extracted tensor on a .pt file + container = container_from_filename(file) + mod_subfolder = preprocessing.compute_folder(self.bids) + + for filename, tensor in output_mode: + output_file_dir = ( + self.input_directory + / container + / "deeplearning_prepare_data" + / subfolder.value + / mod_subfolder + ) + output_file_dir.mkdir(parents=True, exist_ok=True) + output_file = output_file_dir / filename + save_tensor(tensor, output_file) + logger.debug(f"Output tensor saved at {output_file}") + + def write_preprocessing( + self, preprocessing: PreprocessingConfig, extraction: ExtractionConfig + ) -> Path: + extract_dir = self.input_directory / "tensor_extraction" + extract_dir.mkdir(parents=True, exist_ok=True) + + json_path = extract_dir / extraction.extract_json + + if json_path.is_file(): + raise FileExistsError( + f"JSON file at {json_path} already exists. " + f"Please choose another name for your preprocessing file." 
+ ) + + preprocessing_dict = preprocessing.model_dump() + preprocessing_dict.update(extraction.model_dump()) + + with json_path.open(mode="w") as json_file: + json.dump(preprocessing_dict, json_file, default=path_encoder) + return json_path + + def get_preprocessing( + self, preprocessing: Union[str, Preprocessing] + ) -> PreprocessingConfig: """TO COMPLETE""" - pass + + preprocessing_ = Preprocessing(preprocessing) + print(preprocessing_) + subjects, sessions = get_subject_session_list( + input_dir=self.input_directory, is_bids_dir=self.bids + ) + if ( + self.input_directory + / "subjects" + / subjects[0] + / sessions[0] + / (preprocessing_.value).replace("-", "_") + ).is_dir(): + preprocessing_config = get_preprocessing(preprocessing_)() + preprocessing_config.from_bids = self.bids + pattern = preprocessing_config.file_type.pattern + + def get_value(enum, pattern: str): + for value in enum: + if value.value in pattern: + return value + raise ValueError( + f"We can't find a value matching {[e.value for e in enum]} in the pattern {pattern}" + ) + + if isinstance(preprocessing_config, PETPreprocessingConfig): + preprocessing_config.tracer = get_value(Tracer, pattern) + preprocessing_config.suvr_reference_region = get_value( + SUVRReferenceRegions, pattern + ) + + elif isinstance(preprocessing_config, DTIPreprocessingConfig): + preprocessing_config.dti_measure = get_value(DTIMeasure, pattern) + preprocessing_config.dti_space = get_value(DTISpace, pattern) + + elif isinstance(preprocessing_config, CustomPreprocessingConfig): + # TODO: add something to find the custom pattern + pass + else: + raise FileNotFoundError( + f"The preprocessing folder {preprocessing} does not exist." + ) + return preprocessing_config def get_dataset( self, - extraction: ExtractionConfig, preprocessing: PreprocessingConfig, - sub_ses_tsv: Path, - transforms: TransformsConfig, + sub_ses_tsv: Optional[Path] = None, + transforms: Optional[TransformsConfig] = None, ) -> CapsDataset: - return CapsDataset(extraction, preprocessing, sub_ses_tsv, transforms) + if sub_ses_tsv is None: + sub_ses_tsv = create_subs_sess_list( + self.input_directory, output_dir=self.input_directory + ) + elif not sub_ses_tsv.is_file(): + raise FileNotFoundError( + f"The provided sub_ses_tsv file {sub_ses_tsv} does not exist." + ) - def get_preprocessing(self, preprocessing: str) -> PreprocessingConfig: - """TO COMPLETE""" + data_df = pd.read_csv(sub_ses_tsv, sep="\t") + + if transforms is None: + logger.info( + "No transforms was provided. We will use the default transforms. 
Check the documentation for more information" + ) + transforms = TransformsConfig() + + if isinstance(extraction, ExtractionImageConfig): + return CapsDatasetImage( + caps_directory=self.input_directory, + extraction=extraction, + preprocessing=preprocessing, + data_df=data_df, + transforms=transforms, + ) + + elif isinstance(extraction, ExtractionSliceConfig): + return CapsDatasetSlice( + caps_directory=self.input_directory, + extraction=extraction, + preprocessing=preprocessing, + data_df=data_df, + transforms=transforms, + ) + + elif isinstance(extraction, ExtractionPatchConfig): + return CapsDatasetPatch( + caps_directory=self.input_directory, + extraction=extraction, + preprocessing=preprocessing, + data_df=data_df, + transforms=transforms, + ) + + elif isinstance(extraction, ExtractionROIConfig): + return CapsDatasetRoi( + caps_directory=self.input_directory, + extraction=extraction, + preprocessing=preprocessing, + data_df=data_df, + transforms=transforms, + ) - return PreprocessingConfig() + else: + raise NotImplementedError( + f"Mode {extraction.extract_method.value} is not implemented." + ) def extract_slice( - self, preprocessing: PreprocessingConfig, arg_slice: Optional[int] = None + self, + preprocessing: PreprocessingConfig, + data_tsv: Optional[Path] = None, + n_proc: int = 2, + extract_json: Optional[str] = None, + slice_direction: Optional[SliceDirection] = None, + slice_mode: Optional[SliceMode] = None, + discarded_slices: Optional[Union[int, tuple]] = None, ) -> ExtractionSliceConfig: """TO COMPLETE""" - return ExtractionSliceConfig() + input_files = self.prepare_extraction(preprocessing, data_tsv=data_tsv) + extraction = ExtractionSliceConfig( + extract_json=extract_json, + slice_direction=slice_direction, + slice_mode=slice_mode, + discarded_slices=discarded_slices, + ) + + def prepare_slice(file): + logger.debug(f" Processing of {file}.") + output_mode = extraction.extract_slices(file) + logger.debug(f"{len(output_mode)} slices extracted.") + + self.write_output_imgs( + output_mode=output_mode, + file=file, + subfolder=SubFolder.SLICE, + preprocessing=preprocessing, + ) + + Parallel(n_jobs=n_proc)(delayed(prepare_slice)(file) for file in input_files) + + return extraction def extract_patch( - self, preprocessing: PreprocessingConfig, arg_patch: Optional[int] = None + self, + preprocessing: PreprocessingConfig, + data_tsv: Optional[Path] = None, + n_proc: int = 2, + extract_json: Optional[str] = None, + patch_size: Optional[int] = None, + stride_size: Optional[int] = None, ) -> ExtractionPatchConfig: """TO COMPLETE""" - return ExtractionPatchConfig() + input_files = self.prepare_extraction(preprocessing, data_tsv=data_tsv) + extraction = ExtractionPatchConfig( + extract_json=extract_json, patch_size=patch_size, stride_size=stride_size + ) + + def prepare_patch(file): + logger.debug(f" Processing of {file}.") + output_mode = extraction.extract_patches(file) + logger.debug(f"{len(output_mode)} patches extracted.") + self.write_output_imgs( + output_mode=output_mode, + file=file, + subfolder=SubFolder.PATCH, + preprocessing=preprocessing, + ) + + Parallel(n_jobs=n_proc)(delayed(prepare_patch)(file) for file in input_files) + + return extraction def extract_roi( - self, preprocessing: PreprocessingConfig, arg_roi: Optional[int] = None + self, + preprocessing: PreprocessingConfig, + data_tsv: Optional[Path] = None, + n_proc: int = 2, + extract_json: Optional[str] = None, + roi_list: Optional[list[str]] = None, + roi_crop_input: Optional[bool] = None, + roi_crop_output: 
Optional[bool] = None,
+        roi_custom_template: Optional[str] = None,
+        roi_custom_pattern: Optional[str] = None,
+        roi_custom_suffix: Optional[str] = None,
+        roi_custom_mask_pattern: Optional[str] = None,
+        roi_background_value: Optional[int] = None,
     ) -> ExtractionROIConfig:
         """TO COMPLETE"""
-        return ExtractionROIConfig()
+        input_files = self.prepare_extraction(preprocessing, data_tsv=data_tsv)
+        extraction = ExtractionROIConfig(
+            extract_json=extract_json,
+            roi_list=roi_list,
+            roi_crop_input=roi_crop_input,
+            roi_crop_output=roi_crop_output,
+            roi_custom_template=roi_custom_template,
+            roi_custom_mask_pattern=roi_custom_mask_pattern,
+            roi_custom_suffix=roi_custom_suffix,
+            roi_background_value=roi_background_value,
+        )
+        extraction.check_with_preprocessing(preprocessing=preprocessing.preprocessing)
 
-    def extract_image(
-        self, preprocessing: PreprocessingConfig, arg_image: Optional[int] = None
-    ) -> ExtractionImageConfig:
-        """TO COMPLETE"""
+        def prepare_roi(file):
+            logger.debug(f" Processing of {file}.")
+            masks_location = (
+                self.input_directory / "masks" / f"tpl-{extraction.roi_template}"
+            )
+            extraction.check_mask_list(masks_location=masks_location)
+            output_mode = extraction.extract_roi(file, masks_location=masks_location)
+            logger.debug(f"{len(output_mode)} ROIs extracted.")
+            self.write_output_imgs(
+                output_mode=output_mode,
+                file=file,
+                subfolder=SubFolder.ROI,
+                preprocessing=preprocessing,
+            )
+
+        Parallel(n_jobs=n_proc)(delayed(prepare_roi)(file) for file in input_files)
 
-        return ExtractionImageConfig()
+        return extraction
diff --git a/clinicadl/dataset/concat.py b/clinicadl/dataset/concat.py
index f0b420dfe..884d317f2 100644
--- a/clinicadl/dataset/concat.py
+++ b/clinicadl/dataset/concat.py
@@ -1,6 +1,131 @@
+# coding: utf8
+# TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects?
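+#
+# Minimal usage sketch (illustrative only; the variable names below are assumptions):
+#
+#     dataset_a = caps_reader.get_dataset(preprocessing=prep, sub_ses_tsv=tsv_a, transforms=transforms)
+#     dataset_b = caps_reader.get_dataset(preprocessing=prep, sub_ses_tsv=tsv_b, transforms=transforms)
+#     merged = ConcatDataset([dataset_a, dataset_b])
+#     len(merged)   # sum of the lengths of both datasets
+#     merged[0]     # item routed to the dataset that owns this index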
+import abc +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from pydantic import BaseModel +from torch.utils.data import Dataset + from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.dataset.config.extraction import ( + ExtractionConfig, + ExtractionImageConfig, + ExtractionPatchConfig, + ExtractionROIConfig, + ExtractionSliceConfig, +) +from clinicadl.dataset.config.preprocessing import PreprocessingConfig +from clinicadl.dataset.config.utils import ( + get_preprocessing_and_mode_from_json, +) +from clinicadl.transforms.config import TransformsConfig +from clinicadl.utils.enum import ( + ExtractionMethod, + Pattern, + Preprocessing, + SliceDirection, + SliceMode, + Template, +) +from clinicadl.utils.exceptions import ( + ClinicaDLCAPSError, + ClinicaDLConcatError, + ClinicaDLTSVError, +) +from clinicadl.utils.iotools.clinica_utils import check_caps_folder +from clinicadl.utils.iotools.utils import path_decoder, read_json + +logger = getLogger("clinicadl") class ConcatDataset(CapsDataset): - def __init__(self, list_: list[CapsDataset]): - """TO COMPLETE""" + def __init__(self, datasets: List[CapsDataset]): + self._datasets = datasets + self._len = sum(len(dataset) for dataset in datasets) + self._indexes = [] + + # Calculate distribution of indexes in all datasets + cumulative_index = 0 + for idx, dataset in enumerate(datasets): + next_cumulative_index = cumulative_index + len(dataset) + self._indexes.append((cumulative_index, next_cumulative_index, idx)) + cumulative_index = next_cumulative_index + + logger.debug(f"Datasets summary length: {self._len}") + logger.debug(f"Datasets indexes: {self._indexes}") + + self.caps_dict = self.compute_caps_dict() + self.check_configs() + + self.eval_mode = False + + def __getitem__(self, index: int) -> Tuple[List[int], List[int]]: + for start, stop, dataset_index in self._indexes: + if start <= index < stop: + dataset = self._datasets[dataset_index] + return dataset[index - start] + + def __len__(self) -> int: + return self._len + + def check_configs(self): + extraction = self._datasets[len(self._datasets) - 1].extraction + preprocessing = self._datasets[len(self._datasets) - 1].preprocessing + transforms = self._datasets[len(self._datasets) - 1].transforms + size = self._datasets[len(self._datasets) - 1].size + elem_per_image = self._datasets[len(self._datasets) - 1].elem_per_image + + for idx in range(len(self._datasets) - 1): + if self._datasets[idx].extraction != extraction: + raise ClinicaDLConcatError( + f"Different extraction modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].extraction}, " + f"Dataset {len(self._datasets)}: {extraction}" + ) + + if self._datasets[idx].preprocessing != preprocessing: + raise ClinicaDLConcatError( + f"Different preprocessing modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].preprocessing}, " + f"Dataset {len(self._datasets)}: {preprocessing}" + ) + + if self._datasets[idx].transforms != transforms: + raise ClinicaDLConcatError( + f"Different transforms modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].transforms}, " + f"Dataset {len(self._datasets)}: {transforms}" + ) + if self._datasets[idx].size != size: + raise ClinicaDLConcatError( + f"Different size modes found in datasets. 
" + f"Dataset {idx+1}: {self._datasets[idx].size}, " + f"Dataset {len(self._datasets)}: {size}" + ) + if self._datasets[idx].elem_per_image != elem_per_image: + raise ClinicaDLConcatError( + f"Different elem_per_image modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].elem_per_image}, " + f"Dataset {len(self._datasets)}: {elem_per_image}" + ) + + self.extraction = extraction + self.preprocessing = preprocessing + self.transforms = transforms + self.size = size + self.elem_per_image = elem_per_image + + def compute_caps_dict(self) -> Dict[str, Path]: + caps_dict = dict() + for idx in range(len(self._datasets)): + cohort = idx + caps_path = self._datasets[idx].caps_dict["caps_directory"] + check_caps_folder(caps_path) + caps_dict[cohort] = caps_path + + return caps_dict diff --git a/clinicadl/dataset/config/data.py b/clinicadl/dataset/config/data.py new file mode 100644 index 000000000..62ec340ad --- /dev/null +++ b/clinicadl/dataset/config/data.py @@ -0,0 +1,75 @@ +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import pandas as pd +from pydantic import BaseModel, ConfigDict, field_validator + +from clinicadl.utils.exceptions import ( + ClinicaDLArgumentError, + ClinicaDLTSVError, +) +from clinicadl.utils.iotools.data_utils import load_data_test + +logger = getLogger("clinicadl.data_config") + + +class DataConfig(BaseModel): # TODO : put in data module + """Config class to specify the data. + + caps_directory and preprocessing_json are arguments + that must be passed by the user. + """ + + caps_directory: Optional[Path] = None + baseline: bool = False + mask_path: Optional[Path] = None + data_tsv: Optional[Path] = None + n_subjects: int = 300 + # pydantic config + model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + @field_validator("diagnoses", mode="before") + def validator_diagnoses(cls, v): + """Transforms a list to a tuple.""" + if isinstance(v, list): + return tuple(v) + return v # TODO : check if columns are in tsv + + def create_groupe_df(self): + group_df = None + if self.data_tsv is not None and self.data_tsv.is_file(): + group_df = load_data_test( + self.data_tsv, + self.diagnoses, + multi_cohort=self.multi_cohort, + ) + return group_df + + def is_given_label_code(self, _label: str, _label_code: Union[str, Dict[str, int]]): + return ( + self.label is not None + and self.label != "" + and self.label != _label + and _label_code == "default" + ) + + def check_label(self, _label: str): + if not self.label: + self.label = _label + + @field_validator("data_tsv", mode="before") + @classmethod + def check_data_tsv(cls, v) -> Path: + if v is not None: + if not isinstance(v, Path): + v = Path(v) + if not v.is_file(): + raise ClinicaDLTSVError( + "The participants_list you gave is not a file. Please give an existing file." + ) + if v.stat().st_size == 0: + raise ClinicaDLTSVError( + "The participants_list you gave is empty. Please give a non-empty file." 
+            )
+        return v
diff --git a/clinicadl/dataset/config/extraction.py b/clinicadl/dataset/config/extraction.py
index f3619590f..7fdc321f8 100644
--- a/clinicadl/dataset/config/extraction.py
+++ b/clinicadl/dataset/config/extraction.py
@@ -1,15 +1,24 @@
 from logging import getLogger
+from pathlib import Path
 from time import time
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
+import nibabel as nib
+import numpy as np
+import torch
 from pydantic import BaseModel, ConfigDict, field_validator
 from pydantic.types import NonNegativeInt
 
 from clinicadl.utils.enum import (
     ExtractionMethod,
+    Pattern,
+    Preprocessing,
     SliceDirection,
     SliceMode,
+    Suffix,
+    Template,
 )
+from clinicadl.utils.exceptions import ClinicaDLArgumentError
 from clinicadl.utils.iotools.clinica_utils import FileType
 
 logger = getLogger("clinicadl.preprocessing_config")
@@ -21,47 +30,484 @@ class ExtractionConfig(BaseModel):
     """
 
     extract_method: ExtractionMethod
-    file_type: Optional[FileType] = None
+    extract_json: str = f"extract_{int(time())}.json"
+    use_uncropped_image: bool = True
     save_features: bool = False
-    extract_json: Optional[str] = None
     # pydantic config
     model_config = ConfigDict(validate_assignment=True)
 
     @field_validator("extract_json", mode="before")
     def compute_extract_json(cls, v: str):
-        if v is None:
-            return f"extract_{int(time())}.json"
+        if isinstance(v, Path):
+            v = str(v)
         elif not v.endswith(".json"):
-            return f"{v}.json"
-        else:
-            return v
+            v = f"{v}.json"
+        return v
 
 
 class ExtractionImageConfig(ExtractionConfig):
     extract_method: ExtractionMethod = ExtractionMethod.IMAGE
 
+    def extract_images(self, input_img: Path) -> list[Tuple[Path, torch.Tensor]]:
+        """Extract the images
+        This function converts a nifti image to the tensor (.pt) version of the image.
+        The tensor version is saved at the same location as input_img.
+        Args:
+            input_img: path to the NifTi input image.
+        Returns:
+            filename (str): single tensor file saved on disk, at the same location as the input file.
+        """
+
+        image_array = nib.loadsave.load(input_img).get_fdata(dtype="float32")
+        image_tensor = torch.from_numpy(image_array).unsqueeze(0).float()
+        # make sure the tensor type is torch.float32
+        output_file = (
+            Path(input_img.name.replace(Suffix.NII_GZ.value, Suffix.PT.value)),
+            image_tensor.clone(),
+        )
+
+        return [output_file]
+
 
 class ExtractionPatchConfig(ExtractionConfig):
     patch_size: int = 50
     stride_size: int = 50
     extract_method: ExtractionMethod = ExtractionMethod.PATCH
 
+    def num_elem_per_image(self, image: torch.Tensor):
+        # if self.elem_index is not None:
+        #     return 1
+        patches_tensor = self.create_patches(image)
+        self.num_patches = patches_tensor.shape[0]
+        return self.num_patches
+
+    def extract_patches(
+        self,
+        nii_path: Path,
+    ) -> List[Tuple[Path, torch.Tensor]]:
+        """Extracts the patches
+        This function extracts patches from the preprocessed nifti image. The patch
+        size is provided as input, as well as the stride size. If the stride size is
+        smaller than the patch size, consecutive patches overlap. If the stride size
+        is equal to the patch size, there is no overlap. Otherwise, unprocessed
+        zones can exist.
+        Args:
+            nii_path: path to the NifTi input image.
+            self.patch_size: size of a single patch.
+            self.stride_size: size of the stride leading to the next patch.
+        Returns:
+            list of tuples containing the path to the extracted patch
+            and the tensor of the corresponding patch.
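+
+        Example (illustrative, using the default configuration):
+            the number of patches along each spatial axis is
+            floor((dim - patch_size) / stride_size) + 1, so with patch_size=50 and
+            stride_size=50 a hypothetical (1, 169, 208, 179) image tensor yields
+            3 * 4 * 3 = 36 non-overlapping patches.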
+ """ + + image_array = nib.loadsave.load(nii_path).get_fdata(dtype="float32") + image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() + + patches_tensor = self.create_patches(image_tensor) + + patch_list = [] + for patch_index in range(patches_tensor.shape[0]): + patch_tensor = self.extract_patch_tensor( + image_tensor, patch_index, patches_tensor + ) + patch_path = self.extract_patch_path(nii_path, patch_index) + + patch_list.append((patch_path, patch_tensor)) + + return patch_list + + def create_patches(self, image_tensor: torch.Tensor) -> torch.Tensor: + patches_tensor = ( + image_tensor.unfold(1, self.patch_size, self.stride_size) + .unfold(2, self.patch_size, self.stride_size) + .unfold(3, self.patch_size, self.stride_size) + .contiguous() + ) + # the dimension of patches_tensor is [1, patch_num1, patch_num2, patch_num3, self.patch_size1, self.patch_size2, self.patch_size3] + return patches_tensor.view( + -1, self.patch_size, self.patch_size, self.patch_size + ) + + def extract_patch_tensor( + self, + image_tensor: torch.Tensor, + patch_index: int, + patches_tensor: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Extracts a single patch from image_tensor""" + + if patches_tensor is None: + patches_tensor = self.create_patches(image_tensor) + + return patches_tensor[patch_index, ...].unsqueeze_(0).clone() + + def extract_patch_path(self, img_path: Path, patch_index: int) -> Path: + input_img_filename = img_path.name + txt_idx = input_img_filename.rfind("_") + it_filename_prefix = input_img_filename[0:txt_idx] + it_filename_suffix = input_img_filename[txt_idx:] + it_filename_suffix = it_filename_suffix.replace( + Suffix.NII_GZ.value, Suffix.PT.value + ) + + return Path( + f"{it_filename_prefix}_patchsize-{self.patch_size}_stride-{self.stride_size}_patch-{patch_index}{it_filename_suffix}" + ) + class ExtractionSliceConfig(ExtractionConfig): slice_direction: SliceDirection = SliceDirection.SAGITTAL slice_mode: SliceMode = SliceMode.RGB - num_slices: Optional[NonNegativeInt] = None - discarded_slices: Tuple[NonNegativeInt, NonNegativeInt] = (0, 0) + # num_slices: Optional[NonNegativeInt] = None # not sure it is needed + discarded_slices: Union[int, Tuple] = (0,) extract_method: ExtractionMethod = ExtractionMethod.SLICE + @field_validator("slice_direction", mode="before") + def check_slice_direction(cls, v: str): + if isinstance(v, int): + return SliceDirection(str(v)) + + # @field_validator("discarded_slices", mode="before") + # def compute_discarded_slice(cls, v: Union[int, Tuple]) -> Tuple[int, int]: + # return compute_discarded_slices(v) + # DONE in extraction + + def compute_discarded_slices(self) -> Tuple[int, int]: + if isinstance(self.discarded_slices, int): + begin_discard, end_discard = self.discarded_slices, self.discarded_slices + elif len(self.discarded_slices) == 1: + begin_discard, end_discard = ( + self.discarded_slices[0], + self.discarded_slices[0], + ) + elif len(self.discarded_slices) == 2: + begin_discard, end_discard = ( + self.discarded_slices[0], + self.discarded_slices[1], + ) + else: + raise IndexError( + f"Maximum two number of discarded slices can be defined. " + f"You gave discarded slices = {self.discarded_slices}." + ) + return begin_discard, end_discard + + def extract_slices( + self, + nii_path: Path, + ) -> List[Tuple[str, torch.Tensor]]: + """Extracts the slices from three directions + This function extracts slices form the preprocessed nifti image. 
+ + The direction of extraction can be defined either on sagittal direction (0), + coronal direction (1) or axial direction (other). + + The output slices can be stored following two modes: + single (1 channel) or rgb (3 channels, all the same). + + Args: + nii_path: path to the NifTi input image. + slice_direction: along which axis slices are extracted. + slice_mode: 'single' or 'rgb'. + discarded_slices: Number of slices to discard at the beginning and the end of the image. + Will be a tuple of two integers if the number of slices to discard at the beginning + and at the end differ. + Returns: + list of tuples containing the path to the extracted slice + and the tensor of the corresponding slice. + """ + + image_array = nib.loadsave.load(nii_path).get_fdata(dtype="float32") + image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() + + begin_discard, end_discard = self.compute_discarded_slices() + index_list = range( + begin_discard, + image_tensor.shape[int(self.slice_direction.value) + 1] - end_discard, + ) + + slice_list = [] + for slice_index in index_list: + slice_tensor = self.extract_slice_tensor(image_tensor, slice_index) + slice_path = self.extract_slice_path(nii_path, slice_index) + + slice_list.append((slice_path, slice_tensor)) + + return slice_list + + def extract_slice_tensor( + self, + image_tensor: torch.Tensor, + slice_index: int, + ) -> torch.Tensor: + # Allow to select the slice `slice_index` in dimension `slice_direction` + idx_tuple = tuple( + [slice(None)] * (int(self.slice_direction.value) + 1) + + [slice_index] + + [slice(None)] * (2 - int(self.slice_direction.value)) + ) + slice_tensor = image_tensor[idx_tuple] # shape is 1 * W * L + + if self.slice_mode == "rgb": + slice_tensor = torch.cat( + (slice_tensor, slice_tensor, slice_tensor) + ) # shape is 3 * W * L + + return slice_tensor.clone() + + def extract_slice_path( + self, + img_path: Path, + slice_index: int, + ) -> str: + slice_dict = {0: "sag", 1: "cor", 2: "axi"} + input_img_filename = img_path.name + txt_idx = input_img_filename.rfind("_") + it_filename_prefix = input_img_filename[0:txt_idx] + it_filename_suffix = input_img_filename[txt_idx:] + it_filename_suffix = it_filename_suffix.replace( + Suffix.NII_GZ.value, Suffix.PT.value + ) + return ( + f"{it_filename_prefix}_axis-{slice_dict[int(self.slice_direction.value)]}" + f"_channel-{self.slice_mode.value}_slice-{slice_index}{it_filename_suffix}" + ) + class ExtractionROIConfig(ExtractionConfig): roi_list: List[str] = [] - roi_uncrop_output: bool = False + roi_crop_input: bool = True + roi_crop_output: bool = True + roi_template: str = "" + roi_pattern: str = "" + roi_background_value: int = 0 + roi_custom_template: str = "" - roi_custom_pattern: str = "" - roi_custom_suffix: str = "" roi_custom_mask_pattern: str = "" - roi_background_value: int = 0 + roi_custom_suffix: str = "" extract_method: ExtractionMethod = ExtractionMethod.ROI + + def check_with_preprocessing(self, preprocessing: Preprocessing): + if preprocessing == Preprocessing.CUSTOM: + if not self.roi_template: + raise ClinicaDLArgumentError( + "A custom template must be defined when the modality is set to custom." 
+                )
+            self.roi_template = self.roi_custom_template
+            self.roi_mask_pattern = self.roi_custom_mask_pattern
+        else:
+            if preprocessing == Preprocessing.T1_LINEAR:
+                self.roi_template = Template.T1_LINEAR
+                self.roi_mask_pattern = Pattern.T1_LINEAR
+            elif preprocessing == Preprocessing.PET_LINEAR:
+                self.roi_template = Template.PET_LINEAR
+                self.roi_mask_pattern = Pattern.PET_LINEAR
+            elif preprocessing == Preprocessing.FLAIR_LINEAR:
+                self.roi_template = Template.FLAIR_LINEAR
+                self.roi_mask_pattern = Pattern.FLAIR_LINEAR
+
+    def check_mask_list(
+        self,
+        masks_location: Path,
+    ) -> None:
+        if len(self.roi_list) == 0:
+            raise ClinicaDLArgumentError("A list of regions of interest must be given.")
+
+        for roi in self.roi_list:
+            roi_path, desc = self.find_mask_path(masks_location, roi)
+            if roi_path is None:
+                raise FileNotFoundError(
+                    f"The ROI '{roi}' does not correspond to a mask in the CAPS directory. {desc}"
+                )
+            roi_mask = nib.loadsave.load(roi_path).get_fdata()
+            mask_values = set(np.unique(roi_mask))
+            if mask_values != {0, 1}:
+                raise ValueError(
+                    "The ROI masks used should be binary (composed of 0 and 1 only)."
+                )
+
+    def find_mask_path(
+        self,
+        masks_location: Path,
+        roi: str,
+    ) -> Tuple[Union[None, Path], str]:
+        """
+        Finds the mask matching the requested pattern and consistent with the roi_crop_input setting.
+
+        Parameters
+        ----------
+        masks_location: Path
+            Directory containing the masks.
+        roi: str
+            Name of the region.
+
+        Notes
+        -----
+        self.roi_mask_pattern is the pattern which should be found in the filename of
+        the mask, and self.roi_crop_input indicates whether the mask filename should
+        contain the substring '_desc-Crop_'.
+
+        Returns
+        -------
+        the path of the mask, or None if nothing was found.
+        a human-friendly description of the pattern looked for.
+        """
+
+        # Fall back to an empty pattern when no mask pattern is defined
+        mask_pattern = self.roi_mask_pattern or ""
+
+        candidates_pattern = f"*{mask_pattern}*_roi-{roi}_mask.nii*"
+
+        desc = f"The mask should follow the pattern {candidates_pattern}. "
+        candidates = [e for e in masks_location.glob(candidates_pattern)]
+        if self.roi_crop_input is None:
+            candidates2 = candidates
+        elif self.roi_crop_input:
+            candidates2 = [mask for mask in candidates if "_desc-Crop_" in mask.name]
+            desc += "and contain '_desc-Crop_' string."
+        else:
+            candidates2 = [
+                mask for mask in candidates if "_desc-Crop_" not in mask.name
+            ]
+            desc += "and not contain '_desc-Crop_' string."
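+        # Illustration (hypothetical values): with roi_mask_pattern="res-1x1x1" and
+        # roi="hippocampus", candidates_pattern is "*res-1x1x1*_roi-hippocampus_mask.nii*",
+        # and roi_crop_input keeps or discards the candidates containing "_desc-Crop_".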
+ + if len(candidates2) == 0: + return None, desc + else: + return min(candidates2), desc + + def compute_output_pattern(self, mask_path: Path): + """ + Computes the output pattern of the region cropped (without the source file prefix) + Parameters + ---------- + mask_path: Path + Path to the masks + self.roi_crop_output: bool + If True the output is cropped, and the descriptor CropRoi must exist + + Returns + ------- + the output pattern + """ + + mask_filename = mask_path.name + template_id = mask_filename.split("_")[0].split("-")[1] + mask_descriptors = mask_filename.split("_")[1:-2:] + roi_id = mask_filename.split("_")[-2].split("-")[1] + if "desc-Crop" not in mask_descriptors and not self.roi_crop_output: + mask_descriptors = ["desc-CropRoi"] + mask_descriptors + elif "desc-Crop" in mask_descriptors: + mask_descriptors = [ + descriptor + for descriptor in mask_descriptors + if descriptor != "desc-Crop" + ] + if self.roi_crop_output: + mask_descriptors = ["desc-CropRoi"] + mask_descriptors + else: + mask_descriptors = ["desc-CropImage"] + mask_descriptors + + mask_pattern = "_".join(mask_descriptors) + + if mask_pattern == "": + output_pattern = f"space-{template_id}_roi-{roi_id}" + else: + output_pattern = f"space-{template_id}_{mask_pattern}_roi-{roi_id}" + + return output_pattern + + def extract_roi( + self, + nii_path: Path, + masks_location: Path, + ) -> List[Tuple[str, torch.Tensor]]: + """Extracts regions of interest defined by masks + This function extracts regions of interest from preprocessed nifti images. + The regions are defined using binary masks that must be located in the CAPS + at `masks/tpl-