diff --git a/clinicadl/API_test.py b/clinicadl/API/API_test.py
similarity index 70%
rename from clinicadl/API_test.py
rename to clinicadl/API/API_test.py
index a6eb9fa72..6f50708fb 100644
--- a/clinicadl/API_test.py
+++ b/clinicadl/API/API_test.py
@@ -1,6 +1,6 @@
 from pathlib import Path

-import torchio
+import torchio.transforms as transforms

 from clinicadl.dataset.caps_dataset import (
     CapsDatasetPatch,
@@ -9,7 +9,11 @@
 )
 from clinicadl.dataset.caps_reader import CapsReader
 from clinicadl.dataset.concat import ConcatDataset
-from clinicadl.dataset.config.extraction import ExtractionConfig
+from clinicadl.dataset.config.extraction import (
+    ExtractionConfig,
+    ExtractionImageConfig,
+    ExtractionPatchConfig,
+)
 from clinicadl.dataset.config.preprocessing import (
     PreprocessingConfig,
     T1PreprocessingConfig,
@@ -17,7 +21,7 @@
 from clinicadl.experiment_manager.experiment_manager import ExperimentManager
 from clinicadl.losses.config import CrossEntropyLossConfig
 from clinicadl.losses.factory import get_loss_function
-from clinicadl.model.clinicadl_model import ClinicaDLModel
+from clinicadl.model.clinicadl_model import ClinicaDLModel, ClinicaDLModelClassif
 from clinicadl.networks.config import ImplementedNetworks
 from clinicadl.networks.factory import (
     ConvEncoderOptions,
@@ -30,42 +34,50 @@
 from clinicadl.splitter.kfold import KFolder
 from clinicadl.splitter.split import get_single_split, split_tsv
 from clinicadl.trainer.trainer import Trainer
+from clinicadl.transforms.config import TransformsConfig
 from clinicadl.transforms.transforms import Transforms

 # Create the Maps Manager / Read/write manager /
 maps_path = Path("/")
-manager = ExperimentManager(maps_path, overwrite=False)
+manager = ExperimentManager(
+    maps_path, overwrite=False
+)  # to add to the manager: mlflow / profiler / etc.
 caps_directory = Path("caps_directory")  # output of clinica pipelines
-caps_reader = CapsReader(caps_directory, manager=manager)
+caps_reader = CapsReader(caps_directory)

 preprocessing_1 = caps_reader.get_preprocessing("t1-linear")
-extraction_1 = caps_reader.extract_slice(preprocessing=preprocessing_1, arg_slice=2)
+caps_reader.prepare_data(
+    preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2
+)  # doesn't return anything -> just extracts the image tensors and computes some information for each image
+
+extraction_1 = ExtractionImageConfig()
 transforms_1 = Transforms(
-    data_augmentation=[torchio.t1, torchio.t2],
-    image_transforms=[torchio.t1, torchio.t2],
-    object_transforms=[torchio.t1, torchio.t2],
+    extraction=extraction_1,
+    image_augmentation=[transforms.Crop(), transforms.Transform()],
+    image_transforms=[transforms.Blur(), transforms.Ghosting()],
+    object_transforms=[transforms.BiasField(), transforms.Motion()],
 )  # not mandatory

 preprocessing_2 = caps_reader.get_preprocessing("pet-linear")
-extraction_2 = caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch=2)
+caps_reader.prepare_data(preprocessing=preprocessing_2)
+extraction_2 = ExtractionPatchConfig(patch_size=20)
 transforms_2 = Transforms(
-    data_augmentation=[torchio.t2],
-    image_transforms=[torchio.t1],
-    object_transforms=[torchio.t1, torchio.t2],
+    extraction=extraction_2,
+    object_augmentation=[transforms.BiasField()],
+    image_transforms=[transforms.Motion()],
+    object_transforms=[transforms.Crop()],
 )

 sub_ses_tsv = Path("")
 split_dir = split_tsv(sub_ses_tsv)  # -> creates a test.tsv and a train.tsv

 dataset_t1_roi = caps_reader.get_dataset(
-    extraction=extraction_1,
     preprocessing=preprocessing_1,
     sub_ses_tsv=split_dir / "train.tsv",
     transforms=transforms_1,
 )  # do we give a config or an object for transforms?
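+# same subjects and sessions as above, but read from the pet-linear preprocessing and extracted as patches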
dataset_pet_patch = caps_reader.get_dataset( - extraction=extraction_2, preprocessing=preprocessing_2, sub_ses_tsv=split_dir / "train.tsv", transforms=transforms_2, @@ -94,7 +106,21 @@ ) network, _ = get_network_from_config(network_config) optimizer, _ = get_optimizer(network, AdamConfig()) - model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) + model = ClinicaDLModelClassif(network=network, loss=loss, optimizer=optimizer) + + # from config + _, loss_config = get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], + num_outputs=1, + conv_args=ConvEncoderOptions(channels=[3, 2, 2]), + ) + + model = ClinicaDLModelClassif.from_config( + network_config=network_config, + loss_config=loss_config, + optimizer_config=AdamConfig(), + ) trainer.train(model, split) # le trainer va instancier un predictor/valdiator dans le train ou dans le init @@ -113,7 +139,7 @@ ) network, _ = get_network_from_config(network_config) optimizer, _ = get_optimizer(network, AdamConfig()) -model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) +model = ClinicaDLModelClassif(network=network, loss=loss, optimizer=optimizer) trainer.train(model, split) # le trainer va instancier un predictor/valdiator dans le train ou dans le init @@ -121,24 +147,22 @@ # TEST -preprocessing_test: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear") -extraction_test: ExtractionConfig = caps_reader.extract_patch( - preprocessing=preprocessing_2, arg_patch=2 -) +preprocessing_test = caps_reader.get_preprocessing("pet-linear") +caps_reader.prepare_data(preprocessing=preprocessing_2) transforms_test = Transforms( - data_augmentation=[torchio.t2], - image_transforms=[torchio.t1], - object_transforms=[torchio.t1, torchio.t2], + extraction=extraction_2, + object_augmentation=[transforms.BiasField()], + image_transforms=[transforms.Motion()], + object_transforms=[transforms.Crop()], ) dataset_test = caps_reader.get_dataset( - extraction=extraction_test, preprocessing=preprocessing_test, sub_ses_tsv=split_dir / "test.tsv", transforms=transforms_test, ) -predictor = Predictor(manager=manager) +predictor = Predictor(model=model, manager=manager) predictor.predict(dataset_test=dataset_test, split=2) @@ -149,18 +173,25 @@ manager = ExperimentManager(maps_path, overwrite=False) caps_directory = Path("caps_directory") # output of clinica pipelines -caps_reader = CapsReader(caps_directory, manager=manager) +caps_reader = CapsReader(caps_directory) -extraction_1 = caps_reader.extract_image(preprocessing=T1PreprocessingConfig()) +preprocessing_config = caps_reader.prepare_data( + preprocessing=T1PreprocessingConfig(), + data_tsv=Path(""), + n_proc=2, + use_uncropped_images=False, +) transforms_1 = Transforms( - data_augmentation=[torchio.transforms.RandomMotion] + object_augmentation=[transforms.RandomMotion], # default = no transforms + image_transforms=[transforms.Noise], # default = MiniMax + extraction=ExtractionPatchConfig(), # default = Image + object_transforms=[transforms.BiasField], # default = none ) # not mandatory sub_ses_tsv = Path("") split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv dataset_t1_image = caps_reader.get_dataset( - extraction=extraction_1, preprocessing=T1PreprocessingConfig(), sub_ses_tsv=split_dir / "train.tsv", transforms=transforms_1, @@ -198,10 +229,7 @@ # sub_ses_tsv = Path("") # split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv -dataset_t1_image = 
CapsDatasetPatch.from_json( - extraction=extract_json, - sub_ses_tsv=split_dir / "train.tsv", -) +dataset_ = caps_reader.get_dataset_from_json(json_path=Path("")) config_file = Path("config_file") trainer = Trainer.from_json(config_file=config_file, manager=manager) diff --git a/clinicadl/API/complicated_case.py b/clinicadl/API/complicated_case.py new file mode 100644 index 000000000..2d94e9f6c --- /dev/null +++ b/clinicadl/API/complicated_case.py @@ -0,0 +1,128 @@ +from pathlib import Path + +import torchio.transforms as transforms + +from clinicadl.dataset.caps_dataset import ( + CapsDatasetPatch, + CapsDatasetRoi, + CapsDatasetSlice, +) +from clinicadl.dataset.caps_reader import CapsReader +from clinicadl.dataset.concat import ConcatDataset +from clinicadl.dataset.config.extraction import ExtractionConfig +from clinicadl.dataset.config.preprocessing import ( + PreprocessingConfig, + T1PreprocessingConfig, +) +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.losses.config import CrossEntropyLossConfig +from clinicadl.losses.factory import get_loss_function +from clinicadl.model.clinicadl_model import ClinicaDLModel +from clinicadl.networks.config import ImplementedNetworks +from clinicadl.networks.factory import ( + ConvEncoderOptions, + create_network_config, + get_network_from_config, +) +from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig +from clinicadl.optimization.optimizer.factory import get_optimizer +from clinicadl.predictor.predictor import Predictor +from clinicadl.splitter.kfold import KFolder +from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.trainer.trainer import Trainer +from clinicadl.transforms.config import TransformsConfig + +# Create the Maps Manager / Read/write manager / +maps_path = Path("/") +manager = ExperimentManager( + maps_path, overwrite=False +) # a ajouter dans le manager: mlflow/ profiler/ etc ... + +caps_directory = Path("caps_directory") # output of clinica pipelines +caps_reader = CapsReader(caps_directory, manager=manager) + +preprocessing_1 = caps_reader.get_preprocessing("t1-linear") +caps_reader.prepare_data( + preprocessing=preprocessing_1, data_tsv=Path(""), n_proc=2 +) # don't return anything -> just extract the image tensor and compute some information for each images + + +transforms_1 = TransformsConfig( + data_augmentation=[transforms.Crop, transforms.Transform], + image_transforms=[transforms.Blur, transforms.Ghosting], + object_transforms=[transforms.BiasField, transforms.Motion], +) # not mandatory + +preprocessing_2 = caps_reader.get_preprocessing("pet-linear") +extraction_2 = caps_reader.extract_patch(preprocessing=preprocessing_2, arg_patch=2) +transforms_2 = TransformsConfig( + data_augmentation=[torchio.t2], + image_transforms=[torchio.t1], + object_transforms=[torchio.t1, torchio.t2], +) + +sub_ses_tsv = Path("") +split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv + +dataset_t1_roi = caps_reader.get_dataset( + extraction=extraction_1, + preprocessing=preprocessing_1, + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_1, +) # do we give config or object for transforms ? 
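+# second modality: PET patches built from the same train.tsv, concatenated below with the T1 ROI dataset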
+dataset_pet_patch = caps_reader.get_dataset( + extraction=extraction_2, + preprocessing=preprocessing_2, + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_2, +) + +dataset_multi_modality_multi_extract = ConcatDataset( + [dataset_t1_roi, dataset_pet_patch] +) # 2 train.tsv en entrée qu'il faut concat et pareil pour les transforms à faire attention + +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + +# CAS CROSS-VALIDATION +splitter = KFolder( + n_splits=3, caps_dataset=dataset_multi_modality_multi_extract, manager=manager +) + +for split in splitter.split_iterator(split_list=[0, 1]): + # bien définir ce qu'il y a dans l'objet split + + loss, loss_config = get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], + num_outputs=1, + conv_args=ConvEncoderOptions(channels=[3, 2, 2]), + ) + network, _ = get_network_from_config(network_config) + optimizer, _ = get_optimizer(network, AdamConfig()) + model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) + + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init + +# TEST + +preprocessing_test: PreprocessingConfig = caps_reader.get_preprocessing("pet-linear") +extraction_test: ExtractionConfig = caps_reader.extract_patch( + preprocessing=preprocessing_2, arg_patch=2 +) +transforms_test = Transforms( + data_augmentation=[torchio.t2], + image_transforms=[torchio.t1], + object_transforms=[torchio.t1, torchio.t2], +) + +dataset_test = caps_reader.get_dataset( + extraction=extraction_test, + preprocessing=preprocessing_test, + sub_ses_tsv=split_dir / "test.tsv", + transforms=transforms_test, +) + +predictor = Predictor(manager=manager) +predictor.predict(dataset_test=dataset_test, split=2) diff --git a/clinicadl/API/cross_val.py b/clinicadl/API/cross_val.py new file mode 100644 index 000000000..09d9c1aca --- /dev/null +++ b/clinicadl/API/cross_val.py @@ -0,0 +1,67 @@ +from pathlib import Path + +import torchio.transforms as transforms + +from clinicadl.dataset.caps_dataset import ( + CapsDatasetPatch, + CapsDatasetRoi, + CapsDatasetSlice, +) +from clinicadl.dataset.caps_reader import CapsReader +from clinicadl.dataset.concat import ConcatDataset +from clinicadl.dataset.config.extraction import ExtractionConfig +from clinicadl.dataset.config.preprocessing import ( + PreprocessingConfig, + T1PreprocessingConfig, +) +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.losses.config import CrossEntropyLossConfig +from clinicadl.losses.factory import get_loss_function +from clinicadl.model.clinicadl_model import ClinicaDLModel +from clinicadl.networks.config import ImplementedNetworks +from clinicadl.networks.factory import ( + ConvEncoderOptions, + create_network_config, + get_network_from_config, +) +from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig +from clinicadl.optimization.optimizer.factory import get_optimizer +from clinicadl.predictor.predictor import Predictor +from clinicadl.splitter.kfold import KFolder +from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.trainer.trainer import Trainer +from clinicadl.transforms.config import TransformsConfig + +# SIMPLE EXPERIMENT WITH A CAPS ALREADY EXISTING + +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite=False) + +# sub_ses_tsv = Path("") +# split_dir = 
split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv + +dataset_t1_image = CapsDatasetPatch.from_json( + extraction=extract_json, + sub_ses_tsv=split_dir / "train.tsv", +) +config_file = Path("config_file") +trainer = Trainer.from_json(config_file=config_file, manager=manager) + +# CAS CROSS-VALIDATION +splitter = KFolder(n_splits=3, caps_dataset=dataset_t1_image, manager=manager) + +for split in splitter.split_iterator(split_list=[0, 1]): + # bien définir ce qu'il y a dans l'objet split + + loss, loss_config = get_loss_function(CrossEntropyLossConfig()) + network_config = create_network_config(ImplementedNetworks.CNN)( + in_shape=[2, 2, 2], + num_outputs=1, + conv_args=ConvEncoderOptions(channels=[3, 2, 2]), + ) + network, _ = get_network_from_config(network_config) + optimizer, _ = get_optimizer(network, AdamConfig()) + model = ClinicaDLModel(network=network, loss=loss, optimizer=optimizer) + + trainer.train(model, split) + # le trainer va instancier un predictor/valdiator dans le train ou dans le init diff --git a/clinicadl/API/single_split.py b/clinicadl/API/single_split.py new file mode 100644 index 000000000..91d2b5001 --- /dev/null +++ b/clinicadl/API/single_split.py @@ -0,0 +1,69 @@ +from pathlib import Path + +import torchio.transforms as transforms + +from clinicadl.dataset.caps_dataset import ( + CapsDatasetPatch, + CapsDatasetRoi, + CapsDatasetSlice, +) +from clinicadl.dataset.caps_reader import CapsReader +from clinicadl.dataset.concat import ConcatDataset +from clinicadl.dataset.config.extraction import ExtractionConfig, ExtractionPatchConfig +from clinicadl.dataset.config.preprocessing import ( + PreprocessingConfig, + T1PreprocessingConfig, +) +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.losses.config import CrossEntropyLossConfig +from clinicadl.losses.factory import get_loss_function +from clinicadl.model.clinicadl_model import ClinicaDLModel +from clinicadl.networks.config import ImplementedNetworks +from clinicadl.networks.factory import ( + ConvEncoderOptions, + create_network_config, + get_network_from_config, +) +from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig +from clinicadl.optimization.optimizer.factory import get_optimizer +from clinicadl.predictor.predictor import Predictor +from clinicadl.splitter.kfold import KFolder +from clinicadl.splitter.split import get_single_split, split_tsv +from clinicadl.trainer.trainer import Trainer +from clinicadl.transforms.config import TransformsConfig +from clinicadl.transforms.transforms import Transforms +from clinicadl.utils.enum import ExtractionMethod + +# SIMPLE EXPERIMENT + + +maps_path = Path("/") +manager = ExperimentManager(maps_path, overwrite=False) + +caps_directory = Path("caps_directory") # output of clinica pipelines +caps_reader = CapsReader( + caps_directory, manager=manager +) # un peu bizarre de passer un maps_path a cet endroit via le manager pq on veut pas forcmeent faire un entrainement ?? 
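+# run tensor extraction for the t1-linear preprocessing before building the dataset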
+ +preprocessing_t1 = caps_reader.get_preprocessing("t1-linear") +caps_reader.prepare_data( + preprocessing=preprocessing_t1, + data_tsv=Path(""), + n_proc=2, + use_uncropped_images=False, +) +transforms_1 = Transforms( + object_augmentation=[transforms.RandomMotion()], # default = no transforms + image_transforms=[transforms.Noise(0.2, 0.5, 3)], # default = MiniMax + extraction=ExtractionPatchConfig(patch_size=30, stride_size=20), # default = Image + object_transforms=[transforms.Blur((0.4, 0.5, 0.6))], # default = none +) # not mandatory + +sub_ses_tsv = Path("") +split_dir = split_tsv(sub_ses_tsv) # -> creer un test.tsv et un train.tsv + +dataset_t1_image = caps_reader.get_dataset( + preprocessing=preprocessing_t1, + sub_ses_tsv=split_dir / "train.tsv", + transforms=transforms_1, +) # do we give config or ob diff --git a/clinicadl/commandline/pipelines/generate/artifacts/cli.py b/clinicadl/commandline/pipelines/generate/artifacts/cli.py index 68d1ec869..5cf5387fa 100644 --- a/clinicadl/commandline/pipelines/generate/artifacts/cli.py +++ b/clinicadl/commandline/pipelines/generate/artifacts/cli.py @@ -13,8 +13,8 @@ preprocessing, ) from clinicadl.commandline.pipelines.generate.artifacts import options as artifacts +from clinicadl.dataset.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateArtifactsConfig from clinicadl.generate.generate_utils import ( load_and_check_tsv, diff --git a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py index 82c4e5cb3..9fe878578 100644 --- a/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py +++ b/clinicadl/commandline/pipelines/generate/hypometabolic/cli.py @@ -12,8 +12,8 @@ from clinicadl.commandline.pipelines.generate.hypometabolic import ( options as hypometabolic, ) +from clinicadl.dataset.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateHypometabolicConfig from clinicadl.generate.generate_utils import ( load_and_check_tsv, diff --git a/clinicadl/commandline/pipelines/generate/random/cli.py b/clinicadl/commandline/pipelines/generate/random/cli.py index 8ea26a5d0..f14ccb6bd 100644 --- a/clinicadl/commandline/pipelines/generate/random/cli.py +++ b/clinicadl/commandline/pipelines/generate/random/cli.py @@ -14,8 +14,8 @@ preprocessing, ) from clinicadl.commandline.pipelines.generate.random import options as random +from clinicadl.dataset.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateRandomConfig from clinicadl.generate.generate_utils import ( load_and_check_tsv, diff --git a/clinicadl/commandline/pipelines/generate/trivial/cli.py b/clinicadl/commandline/pipelines/generate/trivial/cli.py index d683865f2..3cdc16267 100644 --- a/clinicadl/commandline/pipelines/generate/trivial/cli.py +++ b/clinicadl/commandline/pipelines/generate/trivial/cli.py @@ -13,8 +13,8 @@ preprocessing, ) from clinicadl.commandline.pipelines.generate.trivial import options as trivial +from 
clinicadl.dataset.caps_dataset.caps_dataset_utils import find_file_type from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.caps_dataset_utils import find_file_type from clinicadl.generate.generate_config import GenerateTrivialConfig from clinicadl.generate.generate_utils import ( im_loss_roi_gaussian_distribution, diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py index d162dcf97..17b418249 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_cli.py @@ -8,7 +8,7 @@ preprocessing, ) from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.prepare_data.prepare_data import DeepLearningPrepareData +from clinicadl.dataset.caps_reader.prepare_data import DeepLearningPrepareData from clinicadl.utils.enum import ExtractionMethod diff --git a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py index f4f888a71..a71cc1ca2 100644 --- a/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py +++ b/clinicadl/commandline/pipelines/prepare_data/prepare_data_from_bids_cli.py @@ -8,7 +8,7 @@ preprocessing, ) from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.prepare_data.prepare_data import DeepLearningPrepareData +from clinicadl.dataset.caps_reader.prepare_data import DeepLearningPrepareData from clinicadl.utils.enum import ExtractionMethod diff --git a/clinicadl/commandline/pipelines/train/classification/cli.py b/clinicadl/commandline/pipelines/train/classification/cli.py index 21d57f365..0d20f7821 100644 --- a/clinicadl/commandline/pipelines/train/classification/cli.py +++ b/clinicadl/commandline/pipelines/train/classification/cli.py @@ -1,4 +1,8 @@ +from pathlib import Path + import click +import pandas as pd +import torchio as tio from clinicadl.commandline import arguments from clinicadl.commandline.modules_options import ( @@ -22,9 +26,27 @@ from clinicadl.commandline.pipelines.transfer_learning import ( options as transfer_learning, ) +from clinicadl.dataset.caps_reader import CapsMultiReader, CapsReader +from clinicadl.experiment_manager.experiment_manager import ExperimentManager +from clinicadl.losses.enum import ClassificationLoss +from clinicadl.losses.factory import create_loss_config, get_loss_function +from clinicadl.model.clinicadl_model import ClinicaDLModelClassif +from clinicadl.networks.config import ImplementedNetworks +from clinicadl.networks.factory import ( + ConvEncoderOptions, + create_network_config, + get_network_from_config, +) +from clinicadl.optimization.optimizer.config import AdamConfig, OptimizerConfig +from clinicadl.optimization.optimizer.factory import ( + create_optimizer_config, + get_optimizer, +) +from clinicadl.splitter.kfold import KFolder, Splitter from clinicadl.trainer.config.classification import ClassificationConfig -from clinicadl.trainer.old_trainer import Trainer -from clinicadl.utils.enum import Task +from clinicadl.trainer.trainer import Trainer +from clinicadl.transforms.transforms import Transforms +from clinicadl.utils.enum import ClassificationLoss, ExtractionMethod, Task from clinicadl.utils.iotools.train_utils import merge_cli_and_config_file_options @@ -104,9 +126,44 @@ def cli(**kwargs): 
https://clinicadl.readthedocs.io/en/stable/Train/Introduction/#configuration-file """ - options = merge_cli_and_config_file_options(Task.CLASSIFICATION, **kwargs) + + manager = ExperimentManager( + maps_path=options["output_maps_directory"], overwrite=False + ) + + if options["multi-cohort"]: + caps_reader = CapsMultiReader(caps_directory=options["caps_directory"]) + else: + caps_reader = CapsReader( + caps_directory=options["caps_directory"] + ) # un peu bizarre de passer un maps_path a cet endroit via le manager pq on veut pas forcmeent faire un entrainement ?? + + preprocessing, extraction = caps_reader.get_preprocessing_and_extraction_from_json( + preprocessing_json=options["preprocessing_json"] + ) + transforms = Transforms(extraction=extraction, **options) # not mandatory + dataset = caps_reader.get_dataset( + preprocessing=preprocessing, + transforms=transforms, + ) + + # CAS CROSS-VALIDATION + splitter = Splitter( + n_splits=options["n_splits"], caps_dataset=dataset, manager=manager + ) + config = ClassificationConfig(**options) - trainer = Trainer(config) + trainer = Trainer(config=config, manager=manager) + + for split in splitter.split_iterator(split_list=options["split"]): + loss_config = create_loss_config(options["loss"])(**options) + network_config = create_network_config(options["architecture"])(**options) + optimizer_config = create_optimizer_config(options["optimizer"])(**options) + model = ClinicaDLModelClassif.from_config( + network_config=network_config, + loss_config=loss_config, + optimizer_config=optimizer_config, + ) - trainer.train(split_list=config.split.split, overwrite=True) + trainer.train(model, split) diff --git a/clinicadl/dataset/caps_dataset.py b/clinicadl/dataset/caps_dataset.py index d45dc5aa6..5699e07a0 100644 --- a/clinicadl/dataset/caps_dataset.py +++ b/clinicadl/dataset/caps_dataset.py @@ -5,116 +5,77 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union +import nibabel as nib import numpy as np import pandas as pd import torch from torch.utils.data import Dataset +from torchvision.transforms import ToTensor -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig from clinicadl.dataset.config.extraction import ( + ExtractionConfig, ExtractionImageConfig, ExtractionPatchConfig, ExtractionROIConfig, ExtractionSliceConfig, ) -from clinicadl.dataset.prepare_data.prepare_data_utils import ( - compute_discarded_slices, - extract_patch_path, - extract_patch_tensor, - extract_roi_path, - extract_roi_tensor, - extract_slice_path, - extract_slice_tensor, - find_mask_path, -) -from clinicadl.transforms.config import TransformsConfig +from clinicadl.dataset.config.preprocessing import PreprocessingConfig +from clinicadl.dataset.utils import CapsDatasetOutput +from clinicadl.transforms.transforms import Transforms from clinicadl.utils.enum import ( + STR, + ExtractionMethod, Pattern, Preprocessing, SliceDirection, SliceMode, + SubFolder, + Suffix, Template, ) from clinicadl.utils.exceptions import ( ClinicaDLCAPSError, ClinicaDLTSVError, ) +from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader logger = getLogger("clinicadl") -################################# -# Datasets loaders -################################# class CapsDataset(Dataset): """Abstract class for all derived CapsDatasets.""" def __init__( self, - config: CapsDatasetConfig, - label_presence: bool, - preprocessing_dict: Dict[str, Any], + caps_directory: Path, + data_df: pd.DataFrame, + preprocessing: PreprocessingConfig, + transforms: 
Transforms,
+        index: Optional[int] = None,
     ):
-        self.label_presence = label_presence
-        self.eval_mode = False
-        self.config = config
-        self.preprocessing_dict = preprocessing_dict
-
-        if not hasattr(self, "elem_index"):
-            raise AttributeError(
-                "Child class of CapsDataset must set elem_index attribute."
-            )
-        if not hasattr(self, "mode"):
-            raise AttributeError("Child class of CapsDataset, must set mode attribute.")
-
-        self.df = self.config.data.data_df
+        self.caps_directory = caps_directory
+        self.subjects_directory = caps_directory / STR.SUBJECTS.value
+        self.preprocessing = preprocessing
+        self.transforms = transforms
+        self.extraction = transforms.extraction
+        self.eval_mode = False  # read by __getitem__ before train()/eval() is called
+        self.image_0 = self._get_full_image()
+        self.df = data_df
         mandatory_col = {
-            "participant_id",
-            "session_id",
-            "cohort",
+            STR.PARTICIPANT_ID.value,
+            STR.SESSION_ID.value,
+            STR.COHORT.value,
         }
-        if label_presence and self.config.data.label is not None:
-            mandatory_col.add(self.config.data.label)

         if not mandatory_col.issubset(set(self.df.columns.values)):
             raise ClinicaDLTSVError(
                 f"the data file is not in the correct format."
                 f"Columns should include {mandatory_col}"
             )
-        self.elem_per_image = self.num_elem_per_image()
-        self.size = self[0]["image"].size()
-
-    @property
-    @abc.abstractmethod
-    def elem_index(self):
-        pass
-
-    def label_fn(self, target: Union[str, float, int]) -> Union[float, int, None]:
-        """
-        Returns the label value usable in criterion.
-
-        Args:
-            target: value of the target.
-        Returns:
-            label: value of the label usable in criterion.
-        """
-        # Reconstruction case (no label)
-        if self.config.data.label is None:
-            return None
-        # Regression case (no label code)
-        elif self.config.data.label_code is None:
-            return np.float32([target])
-        # Classification case (label + label_code dict)
-        else:
-            return self.config.data.label_code[str(target)]
-
-    def domain_fn(self, target: Union[str, float, int]) -> Union[float, int]:
-        """
-        Returns the label value usable in criterion.
-
-        """
-        domain_code = {"t1": 0, "flair": 1}
-        return domain_code[str(target)]
+        self.elem_index = index
+        self.elem_per_image = self.extraction.num_elem_per_image(
+            elem_index=self.elem_index, image=self.image_0
+        )
+        self.size = self[0].image.size()

     def __len__(self) -> int:
         return len(self.df) * self.elem_per_image

@@ -130,50 +91,48 @@ def _get_image_path(self, participant: str, session: str, cohort: str) -> Path:
         Returns:
             image_path: path to the tensor containing the whole image.
""" - from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader # Try to find .nii.gz file try: - folder, file_type = self.config.compute_folder_and_file_type() - results = clinicadl_file_reader( [participant], [session], - self.config.data.caps_dict[cohort], - file_type.model_dump(), + self.caps_directory, + self.preprocessing.file_type.model_dump(), ) logger.debug(f"clinicadl_file_reader output: {results}") filepath = Path(results[0][0]) - image_filename = filepath.name.replace(".nii.gz", ".pt") + image_filename = filepath.name.replace(Suffix.PT.value, Suffix.NII_GZ.value) image_dir = ( - self.config.data.caps_dict[cohort] - / "subjects" + self.caps_directory + / STR.SUBJECTS.value / participant / session - / "deeplearning_prepare_data" - / "image_based" - / folder + / STR.DEEP_L_P_DATA.value + / SubFolder.IMAGE.value + / self.preprocessing.compute_folder() ) image_path = image_dir / image_filename # Try to find .pt file except ClinicaDLCAPSError: - folder, file_type = self.config.compute_folder_and_file_type() - file_type.pattern = file_type.pattern.replace(".nii.gz", ".pt") + self.preprocessing.file_type.pattern = ( + self.preprocessing.file_type.pattern.replace( + Suffix.NII_GZ.value, Suffix.PT.value + ) + ) results = clinicadl_file_reader( [participant], [session], - self.config.data.caps_dict[cohort], - file_type.model_dump(), + self.caps_directory, + self.preprocessing.file_type.model_dump(), ) filepath = results[0] image_path = Path(filepath[0]) return image_path - def _get_meta_data( - self, idx: int - ) -> Tuple[str, str, str, Union[float, int, None], int]: + def _get_meta_data(self, idx: int) -> Tuple[str, str, str, int]: """ Gets all meta data necessary to compute the path with _get_image_path @@ -187,26 +146,16 @@ def _get_meta_data( label (str or float or int): value of the label to be used in criterion. """ image_idx = idx // self.elem_per_image - participant = self.df.at[image_idx, "participant_id"] - session = self.df.at[image_idx, "session_id"] - cohort = self.df.at[image_idx, "cohort"] + participant = self.df.at[image_idx, STR.PARTICIPANT_ID.value] + session = self.df.at[image_idx, STR.SESSION_ID.value] + cohort = self.df.at[image_idx, STR.COHORT.value] if self.elem_index is None: elem_idx = idx % self.elem_per_image else: elem_idx = self.elem_index - if self.label_presence and self.config.data.label is not None: - target = self.df.at[image_idx, self.config.data.label] - label = self.label_fn(target) - else: - label = -1 - if "domain" in self.df.columns: - domain = self.df.at[image_idx, "domain"] - domain = self.domain_fn(domain) - else: - domain = "" # TO MODIFY - return participant, session, cohort, elem_idx, label, domain + return participant, session, cohort, elem_idx def _get_full_image(self) -> torch.Tensor: """ @@ -216,33 +165,28 @@ def _get_full_image(self) -> torch.Tensor: Returns: image tensor of the full image first image. 
""" - import nibabel as nib - - from clinicadl.utils.iotools.clinica_utils import clinicadl_file_reader - participant_id = self.df.loc[0, "participant_id"] - session_id = self.df.loc[0, "session_id"] - cohort = self.df.loc[0, "cohort"] + participant_id = self.df.at[0, STR.PARTICIPANT_ID.value] + session_id = self.df.at[0, STR.SESSION_ID.value] + cohort = self.df.at[0, STR.COHORT.value] try: image_path = self._get_image_path(participant_id, session_id, cohort) image = torch.load(image_path, weights_only=True) except IndexError: - file_type = self.config.extraction.file_type results = clinicadl_file_reader( [participant_id], [session_id], - self.config.data.caps_dict[cohort], - file_type.model_dump(), + self.caps_directory, + self.preprocessing.file_type.model_dump(), ) - image_nii = nib.loadsave.load(results[0]) - image_np = image_nii.get_fdata() + image_nii = nib.loadsave.load((results[0])) # type: ignore + image_np = image_nii.get_fdata() # type: ignore image = ToTensor()(image_np) return image - @abc.abstractmethod - def __getitem__(self, idx: int) -> Dict[str, Any]: + def __getitem__(self, idx: int) -> CapsDatasetOutput: """ Gets the sample containing all the information needed for training and testing tasks. @@ -252,566 +196,70 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: dictionary with following items: - "image" (torch.Tensor): the input given to the model, - "label" (int or float): the label used in criterion, - - "participant_id" (str): ID of the participant, - - "session_id" (str): ID of the session, + - PARTICIPANT_ID (str): ID of the participant, + - SESSION_ID (str): ID of the session, - f"{self.mode}_id" (int): number of the element, - "image_path": path to the image loaded in CAPS. """ - pass - - @abc.abstractmethod - def num_elem_per_image(self) -> int: - """Computes the number of elements per image based on the full image.""" - pass - - def eval(self): - """Put the dataset on evaluation mode (data augmentation is not performed).""" - self.eval_mode = True - return self - - def train(self): - """Put the dataset on training mode (data augmentation is performed).""" - self.eval_mode = False - return self - - -class CapsDatasetImage(CapsDataset): - """Dataset of MRI organized in a CAPS folder.""" - - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. 
- - """ - - self.mode = "image" - self.config = config - self.label_presence = label_presence - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return None - - def __getitem__(self, idx): - participant, session, cohort, _, label, domain = self._get_meta_data(idx) + participant, session, cohort, index = self._get_meta_data(idx) image_path = self._get_image_path(participant, session, cohort) image = torch.load(image_path, weights_only=True) - train_trf, trf = self.config.transforms.get_transforms() - - image = trf(image) - if self.config.transforms.train_transformations and not self.eval_mode: - image = train_trf(image) - - sample = { - "image": image, - "label": label, - "participant_id": participant, - "session_id": session, - "image_id": 0, - "image_path": image_path.as_posix(), - "domain": domain, - } - - return sample - - def num_elem_per_image(self): - return 1 + ( + image_trf, + object_trf, + image_augmentation, + object_augmentation, + ) = self.transforms.get_transforms() + image = image_trf(image) -class CapsDatasetPatch(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - patch_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied only on training mode. - """ - self.patch_index = patch_index - self.mode = "patch" - self.config = config - self.label_presence = label_presence - - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.patch_index - - def __getitem__(self, idx): - participant, session, cohort, patch_idx, label, domain = self._get_meta_data( - idx - ) - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.save_features: - patch_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - patch_filename = extract_patch_path( - image_path, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - patch_idx, - ) - patch_tensor = torch.load( - Path(patch_dir).resolve() / patch_filename, weights_only=True - ) + if image_augmentation and not self.eval_mode: + image = image_augmentation(image) - else: - image = torch.load(image_path, weights_only=True) - patch_tensor = extract_patch_tensor( + if not isinstance(self.extraction, ExtractionImageConfig): + tensor = self.transforms.extraction.extract_tensor( image, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - patch_idx, + index, ) + if object_trf: + tensor = object_trf(tensor) - train_trf, trf = self.config.transforms.get_transforms() - patch_tensor = trf(patch_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - patch_tensor = train_trf(patch_tensor) + if object_augmentation and not self.eval_mode: + tensor = object_augmentation(tensor) - sample = { - "image": patch_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "patch_id": patch_idx, - } - - return sample + out = tensor + index = 0 - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - - image = self._get_full_image() - 
- patches_tensor = ( - image.unfold( - 1, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .unfold( - 2, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .unfold( - 3, - self.config.extraction.patch_size, - self.config.extraction.stride_size, - ) - .contiguous() - ) - patches_tensor = patches_tensor.view( - -1, - self.config.extraction.patch_size, - self.config.extraction.patch_size, - self.config.extraction.patch_size, - ) - num_patches = patches_tensor.shape[0] - return num_patches - - -class CapsDatasetRoi(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - roi_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - roi_index: If a value is given the same region will be extracted for each image. - else the dataset will load all the regions possible for one image. - train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - """ - self.roi_index = roi_index - self.mode = "roi" - self.config = config - self.label_presence = label_presence - self.mask_paths, self.mask_arrays = self._get_mask_paths_and_tensors( - self.config.data.caps_directory, preprocessing_dict - ) - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.roi_index - - def __getitem__(self, idx): - participant, session, cohort, roi_idx, label, domain = self._get_meta_data(idx) - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.roi_list is None: - raise NotImplementedError( - "Default regions are not available anymore in ClinicaDL. " - "Please define appropriate masks and give a roi_list." 
- ) - - if self.config.extraction.save_features: - mask_path = self.mask_paths[roi_idx] - roi_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - roi_filename = extract_roi_path( - image_path, mask_path, self.config.extraction.roi_uncrop_output - ) - roi_tensor = torch.load(Path(roi_dir) / roi_filename, weights_only=True) - - else: - image = torch.load(image_path, weights_only=True) - mask_array = self.mask_arrays[roi_idx] - roi_tensor = extract_roi_tensor( - image, mask_array, self.config.extraction.uncropped_roi - ) - - train_trf, trf = self.config.transforms.get_transforms() - - roi_tensor = trf(roi_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - roi_tensor = train_trf(roi_tensor) - - sample = { - "image": roi_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "roi_id": roi_idx, - } - - return sample - - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - if self.config.extraction.roi_list is None: - return 2 else: - return len(self.config.extraction.roi_list) - - def _get_mask_paths_and_tensors( - self, - caps_directory: Path, - preprocessing_dict: Dict[str, Any], - ) -> Tuple[List[str], List]: - """Loads the masks necessary to regions extraction""" - import nibabel as nib - - caps_dict = self.config.data.caps_dict - if len(caps_dict) > 1: - caps_directory = caps_dict[next(iter(caps_dict))] - logger.warning( - f"The equality of masks is not assessed for multi-cohort training. " - f"The masks stored in {caps_directory} will be used." - ) - - try: - preprocessing_ = Preprocessing(preprocessing_dict["preprocessing"]) - except NotImplementedError: - print( - f"Template of preprocessing {preprocessing_dict['preprocessing']} " - f"is not defined." - ) - # Find template name and pattern - if preprocessing_.value == "custom": - template_name = preprocessing_dict["roi_custom_template"] - if template_name is None: - raise ValueError( - "Please provide a name for the template when preprocessing is `custom`." - ) - - pattern = preprocessing_dict["roi_custom_mask_pattern"] - if pattern is None: - raise ValueError( - "Please provide a pattern for the masks when preprocessing is `custom`." - ) - - else: - for template_ in Template: - if preprocessing_.name == template_.name: - template_name = template_ - - for pattern_ in Pattern: - if preprocessing_.name == pattern_.name: - pattern = pattern_ - - mask_location = caps_directory / "masks" / f"tpl-{template_name}" - - mask_paths, mask_arrays = list(), list() - for roi in self.config.extraction.roi_list: - logger.info(f"Find mask for roi {roi}.") - mask_path, desc = find_mask_path(mask_location, roi, pattern, True) - if mask_path is None: - raise FileNotFoundError(desc) - mask_nii = nib.loadsave.load(mask_path) - mask_paths.append(Path(mask_path)) - mask_arrays.append(mask_nii.get_fdata()) - - return mask_paths, mask_arrays - - -class CapsDatasetSlice(CapsDataset): - def __init__( - self, - config: CapsDatasetConfig, - preprocessing_dict: Dict[str, Any], - slice_index: Optional[int] = None, - label_presence: bool = True, - ): - """ - Args: - caps_directory: Directory of all the images. - data_file: Path to the tsv file or DataFrame containing the subject/session list. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - slice_index: If a value is given the same slice will be extracted for each image. - else the dataset will load all the slices possible for one image. 
- train_transformations: Optional transform to be applied only on training mode. - label_presence: If True the diagnosis will be extracted from the given DataFrame. - label: Name of the column in data_df containing the label. - label_code: label code that links the output node number to label value. - all_transformations: Optional transform to be applied during training and evaluation. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - """ - self.slice_index = slice_index - self.mode = "slice" - self.config = config - self.label_presence = label_presence - self.preprocessing_dict = preprocessing_dict - super().__init__( - config=config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - @property - def elem_index(self): - return self.slice_index - - def __getitem__(self, idx): - participant, session, cohort, slice_idx, label, domain = self._get_meta_data( - idx + out = image + + sample = CapsDatasetOutput( + image=out, + # label=label, + participant_id=participant, + session_id=session, + image_id=index, + image_path=image_path, + mode=self.extraction.extract_method, ) - slice_idx = slice_idx + self.config.extraction.discarded_slices[0] - image_path = self._get_image_path(participant, session, cohort) - - if self.config.extraction.save_features: - slice_dir = image_path.parent.as_posix().replace( - "image_based", f"{self.mode}_based" - ) - slice_filename = extract_slice_path( - image_path, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, - slice_idx, - ) - slice_tensor = torch.load( - Path(slice_dir) / slice_filename, weights_only=True - ) - - else: - image_path = self._get_image_path(participant, session, cohort) - image = torch.load(image_path, weights_only=True) - slice_tensor = extract_slice_tensor( - image, - self.config.extraction.slice_direction, - self.config.extraction.slice_mode, - slice_idx, - ) - - train_trf, trf = self.config.transforms.get_transforms() - - slice_tensor = trf(slice_tensor) - - if self.config.transforms.train_transformations and not self.eval_mode: - slice_tensor = train_trf(slice_tensor) - - sample = { - "image": slice_tensor, - "label": label, - "participant_id": participant, - "session_id": session, - "slice_id": slice_idx, - } return sample - def num_elem_per_image(self): - if self.elem_index is not None: - return 1 - - if self.config.extraction.num_slices is not None: - return self.config.extraction.num_slices - - image = self._get_full_image() - return ( - image.size(int(self.config.extraction.slice_direction) + 1) - - self.config.extraction.discarded_slices[0] - - self.config.extraction.discarded_slices[1] - ) - - -def return_dataset( - input_dir: Path, - data_df: pd.DataFrame, - preprocessing_dict: Dict[str, Any], - transforms_config: TransformsConfig, - label: Optional[str] = None, - label_code: Optional[Dict[str, int]] = None, - cnn_index: Optional[int] = None, - label_presence: bool = True, - multi_cohort: bool = False, -) -> CapsDataset: - """ - Return appropriate Dataset according to given options. - Args: - input_dir: path to a directory containing a CAPS structure. - data_df: List subjects, sessions and diagnoses. - preprocessing_dict: preprocessing dict contained in the JSON file of prepare_data. - train_transformations: Optional transform to be applied during training only. - all_transformations: Optional transform to be applied during training and evaluation. - label: Name of the column in data_df containing the label. 
- label_code: label code that links the output node number to label value. - cnn_index: Index of the CNN in a multi-CNN paradigm (optional). - label_presence: If True the diagnosis will be extracted from the given DataFrame. - multi_cohort: If True caps_directory is the path to a TSV file linking cohort names and paths. - - Returns: - the corresponding dataset. - """ - if cnn_index is not None and preprocessing_dict["mode"] == "image": - raise NotImplementedError( - f"Multi-CNN is not implemented for {preprocessing_dict['mode']} mode." - ) - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - preprocessing_type=preprocessing_dict["preprocessing"], - preprocessing=preprocessing_dict["preprocessing"], - extraction=preprocessing_dict["mode"], - caps_directory=input_dir, - data_df=data_df, - label=label, - label_code=label_code, - multi_cohort=multi_cohort, - ) - config.transforms = transforms_config - - if preprocessing_dict["mode"] == "image": - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetImage( - config, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - - elif preprocessing_dict["mode"] == "patch": - assert isinstance(config.extraction, ExtractionPatchConfig) - config.extraction.patch_size = preprocessing_dict["patch_size"] - config.extraction.stride_size = preprocessing_dict["stride_size"] - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetPatch( - config, - patch_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, + def num_elem_per_image(self) -> int: + """Computes the number of elements per image based on the full image.""" + return self.extraction.num_elem_per_image( + elem_index=self.elem_index, image=self.image_0 ) - elif preprocessing_dict["mode"] == "roi": - assert isinstance(config.extraction, ExtractionROIConfig) - config.extraction.roi_list = preprocessing_dict["roi_list"] - config.extraction.roi_uncrop_output = preprocessing_dict["uncropped_roi"] - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetRoi( - config, - roi_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) + def eval(self): + """Put the dataset on evaluation mode (data augmentation is not performed).""" + self.eval_mode = True + return self - elif preprocessing_dict["mode"] == "slice": - assert isinstance(config.extraction, ExtractionSliceConfig) - config.extraction.slice_direction = SliceDirection( - str(preprocessing_dict["slice_direction"]) - ) - config.extraction.slice_mode = SliceMode(preprocessing_dict["slice_mode"]) - config.extraction.discarded_slices = compute_discarded_slices( - preprocessing_dict["discarded_slices"] - ) - config.extraction.num_slices = ( - None - if "num_slices" not in preprocessing_dict - else preprocessing_dict["num_slices"] - ) - config.extraction.save_features = preprocessing_dict["prepare_dl"] - config.preprocessing.use_uncropped_image = preprocessing_dict[ - "use_uncropped_image" - ] - return CapsDatasetSlice( - config, - slice_index=cnn_index, - label_presence=label_presence, - preprocessing_dict=preprocessing_dict, - ) - else: - raise NotImplementedError( - 
f"Mode {preprocessing_dict['mode']} is not implemented." - ) + def train(self): + """Put the dataset on training mode (data augmentation is performed).""" + self.eval_mode = False + return self diff --git a/clinicadl/dataset/caps_dataset_config.py b/clinicadl/dataset/caps_dataset_config.py deleted file mode 100644 index 0eac3ffd3..000000000 --- a/clinicadl/dataset/caps_dataset_config.py +++ /dev/null @@ -1,127 +0,0 @@ -from pathlib import Path -from typing import Optional, Tuple, Union - -from pydantic import BaseModel, ConfigDict - -from clinicadl.dataset.config import extraction -from clinicadl.dataset.config.preprocessing import ( - CustomPreprocessingConfig, - DTIPreprocessingConfig, - FlairPreprocessingConfig, - PETPreprocessingConfig, - PreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.data_config import DataConfig -from clinicadl.dataset.dataloader_config import DataLoaderConfig -from clinicadl.dataset.utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, -) -from clinicadl.transforms.config import TransformsConfig -from clinicadl.utils.enum import ExtractionMethod, Preprocessing -from clinicadl.utils.iotools.clinica_utils import FileType - - -def get_extraction(extract_method: ExtractionMethod): - if extract_method == ExtractionMethod.ROI: - return extraction.ExtractionROIConfig - elif extract_method == ExtractionMethod.SLICE: - return extraction.ExtractionSliceConfig - elif extract_method == ExtractionMethod.IMAGE: - return extraction.ExtractionImageConfig - elif extract_method == ExtractionMethod.PATCH: - return extraction.ExtractionPatchConfig - else: - raise ValueError(f"Preprocessing {extract_method.value} is not implemented.") - - -def get_preprocessing(preprocessing_type: Preprocessing): - if preprocessing_type == Preprocessing.T1_LINEAR: - return T1PreprocessingConfig - elif preprocessing_type == Preprocessing.PET_LINEAR: - return PETPreprocessingConfig - elif preprocessing_type == Preprocessing.FLAIR_LINEAR: - return FlairPreprocessingConfig - elif preprocessing_type == Preprocessing.CUSTOM: - return CustomPreprocessingConfig - elif preprocessing_type == Preprocessing.DWI_DTI: - return DTIPreprocessingConfig - else: - raise ValueError( - f"Preprocessing {preprocessing_type.value} is not implemented." - ) - - -class CapsDatasetConfig(BaseModel): - """Config class for CapsDataset object. - - caps_directory, preprocessing_json, extract_method, preprocessing - are arguments that must be passed by the user. 
- - transforms isn't optional because there is always at least one transform (NanRemoval) - """ - - data: DataConfig - dataloader: DataLoaderConfig - extraction: extraction.ExtractionConfig - preprocessing: PreprocessingConfig - transforms: TransformsConfig - - # pydantic config - model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) - - @classmethod - def from_preprocessing_and_extraction_method( - cls, - preprocessing_type: Union[str, Preprocessing], - extraction: Union[str, ExtractionMethod], - **kwargs, - ): - return cls( - data=DataConfig(**kwargs), - dataloader=DataLoaderConfig(**kwargs), - preprocessing=get_preprocessing(Preprocessing(preprocessing_type))( - **kwargs - ), - extraction=get_extraction(ExtractionMethod(extraction))(**kwargs), - transforms=TransformsConfig(**kwargs), - ) - - def compute_folder_and_file_type( - self, from_bids: Optional[Path] = None - ) -> Tuple[str, FileType]: - preprocessing = self.preprocessing.preprocessing - if from_bids is not None: - if isinstance(self.preprocessing, CustomPreprocessingConfig): - mod_subfolder = Preprocessing.CUSTOM.value - file_type = FileType( - pattern=f"*{self.preprocessing.custom_suffix}", - description="Custom suffix", - ) - else: - mod_subfolder = preprocessing - file_type = bids_nii(self.preprocessing) - - elif preprocessing not in Preprocessing: - raise NotImplementedError( - f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if isinstance(self.preprocessing, T1PreprocessingConfig) or isinstance( - self.preprocessing, FlairPreprocessingConfig - ): - file_type = linear_nii(self.preprocessing) - elif isinstance(self.preprocessing, PETPreprocessingConfig): - file_type = pet_linear_nii(self.preprocessing) - elif isinstance(self.preprocessing, DTIPreprocessingConfig): - file_type = dwi_dti(self.preprocessing) - elif isinstance(self.preprocessing, CustomPreprocessingConfig): - file_type = FileType( - pattern=f"*{self.preprocessing.custom_suffix}", - description="Custom suffix", - ) - return mod_subfolder, file_type diff --git a/clinicadl/dataset/caps_dataset_utils.py b/clinicadl/dataset/caps_dataset_utils.py deleted file mode 100644 index b54ba373d..000000000 --- a/clinicadl/dataset/caps_dataset_utils.py +++ /dev/null @@ -1,193 +0,0 @@ -import json -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig -from clinicadl.dataset.config.preprocessing import ( - CustomPreprocessingConfig, - DTIPreprocessingConfig, - FlairPreprocessingConfig, - PETPreprocessingConfig, - T1PreprocessingConfig, -) -from clinicadl.dataset.utils import ( - bids_nii, - dwi_dti, - linear_nii, - pet_linear_nii, -) -from clinicadl.utils.enum import Preprocessing -from clinicadl.utils.exceptions import ClinicaDLArgumentError -from clinicadl.utils.iotools.clinica_utils import FileType - - -def compute_folder_and_file_type( - config: CapsDatasetConfig, from_bids: Optional[Path] = None -) -> Tuple[str, FileType]: - preprocessing = config.preprocessing.preprocessing - if from_bids is not None: - if isinstance(config.preprocessing, CustomPreprocessingConfig): - mod_subfolder = Preprocessing.CUSTOM.value - file_type = FileType( - pattern=f"*{config.preprocessing.custom_suffix}", - description="Custom suffix", - ) - else: - mod_subfolder = preprocessing - file_type = bids_nii(config.preprocessing) - - elif preprocessing not in Preprocessing: - 
raise NotImplementedError( - f"Extraction of preprocessing {preprocessing} is not implemented from CAPS directory." - ) - else: - mod_subfolder = preprocessing.value.replace("-", "_") - if isinstance(config.preprocessing, T1PreprocessingConfig) or isinstance( - config.preprocessing, FlairPreprocessingConfig - ): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - file_type = pet_linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, DTIPreprocessingConfig): - file_type = dwi_dti(config.preprocessing) - elif isinstance(config.preprocessing, CustomPreprocessingConfig): - file_type = FileType( - pattern=f"*{config.preprocessing.custom_suffix}", - description="Custom suffix", - ) - return mod_subfolder, file_type - - -def find_file_type(config: CapsDatasetConfig) -> FileType: - if isinstance(config.preprocessing, T1PreprocessingConfig): - file_type = linear_nii(config.preprocessing) - elif isinstance(config.preprocessing, PETPreprocessingConfig): - if ( - config.preprocessing.tracer is None - or config.preprocessing.suvr_reference_region is None - ): - raise ClinicaDLArgumentError( - "`tracer` and `suvr_reference_region` must be defined " - "when using `pet-linear` preprocessing." - ) - file_type = pet_linear_nii(config.preprocessing) - else: - raise NotImplementedError( - f"Generation of synthetic data is not implemented for preprocessing {config.preprocessing.preprocessing.value}" - ) - - return file_type - - -def read_json(json_path: Path) -> Dict[str, Any]: - """ - Ensures retro-compatibility between the different versions of ClinicaDL. - - Parameters - ---------- - json_path: Path - path to the JSON file summing the parameters of a MAPS. - - Returns - ------- - A dictionary of training parameters. 
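The deleted compute_folder_and_file_type above pairs each preprocessing with a CAPS sub-folder name (the preprocessing name with "-" replaced by "_") and a FileType glob pattern used to locate images; in the new code this responsibility moves onto the preprocessing configs themselves (see preprocessing.get_filetype() and preprocessing.compute_folder() used by CapsReader later in this patch). A simplified standalone sketch of that resolution, with a plain dataclass standing in for FileType and a hypothetical pattern standing in for what linear_nii() builds:

    from dataclasses import dataclass


    @dataclass
    class FileType:
        """Simplified stand-in for clinicadl.utils.iotools.clinica_utils.FileType."""

        pattern: str
        description: str


    def folder_and_file_type(preprocessing: str, custom_suffix: str = ""):
        # "t1-linear" -> "t1_linear": sub-folder used under deeplearning_prepare_data/
        mod_subfolder = preprocessing.replace("-", "_")

        if preprocessing == "custom":
            file_type = FileType(pattern=f"*{custom_suffix}", description="Custom suffix")
        elif preprocessing == "t1-linear":
            # hypothetical pattern; the real one is built by linear_nii()
            file_type = FileType(
                pattern="*space-MNI152NLin2009cSym_desc-Crop_res-1x1x1_T1w.nii.gz",
                description="T1w image registered to MNI space",
            )
        else:
            raise NotImplementedError(
                f"Preprocessing {preprocessing} is not handled in this sketch."
            )
        return mod_subfolder, file_type


    subfolder, file_type = folder_and_file_type("t1-linear")
    assert subfolder == "t1_linear"  # file_type.pattern is the glob passed to the file reader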
- """ - from clinicadl.utils.iotools.utils import path_decoder - - with json_path.open(mode="r") as f: - parameters = json.load(f, object_hook=path_decoder) - # Types of retro-compatibility - # Change arg name: ex network --> model - # Change arg value: ex for preprocessing: mni --> t1-extensive - # New arg with default hard-coded value --> discarded_slice --> 20 - retro_change_name = { - "model": "architecture", - "multi": "multi_network", - "minmaxnormalization": "normalize", - "num_workers": "n_proc", - "mode": "extract_method", - } - - retro_add = { - "optimizer": "Adam", - "loss": None, - } - - for old_name, new_name in retro_change_name.items(): - if old_name in parameters: - parameters[new_name] = parameters[old_name] - del parameters[old_name] - - for name, value in retro_add.items(): - if name not in parameters: - parameters[name] = value - - if "extract_method" in parameters: - parameters["mode"] = parameters["extract_method"] - # Value changes - if "use_cpu" in parameters: - parameters["gpu"] = not parameters["use_cpu"] - del parameters["use_cpu"] - if "nondeterministic" in parameters: - parameters["deterministic"] = not parameters["nondeterministic"] - del parameters["nondeterministic"] - - # Build preprocessing_dict - if "preprocessing_dict" not in parameters: - parameters["preprocessing_dict"] = {"mode": parameters["mode"]} - preprocessing_options = [ - "preprocessing", - "use_uncropped_image", - "prepare_dl", - "custom_suffix", - "tracer", - "suvr_reference_region", - "patch_size", - "stride_size", - "slice_direction", - "slice_mode", - "discarded_slices", - "roi_list", - "uncropped_roi", - "roi_custom_suffix", - "roi_custom_template", - "roi_custom_mask_pattern", - ] - for preprocessing_var in preprocessing_options: - if preprocessing_var in parameters: - parameters["preprocessing_dict"][preprocessing_var] = parameters[ - preprocessing_var - ] - del parameters[preprocessing_var] - - # Add missing parameters in previous version of extract - if "use_uncropped_image" not in parameters["preprocessing_dict"]: - parameters["preprocessing_dict"]["use_uncropped_image"] = False - - if ( - "prepare_dl" not in parameters["preprocessing_dict"] - and parameters["mode"] != "image" - ): - parameters["preprocessing_dict"]["prepare_dl"] = False - - if ( - parameters["mode"] == "slice" - and "slice_mode" not in parameters["preprocessing_dict"] - ): - parameters["preprocessing_dict"]["slice_mode"] = "rgb" - - if "preprocessing" not in parameters: - parameters["preprocessing"] = parameters["preprocessing_dict"]["preprocessing"] - - from clinicadl.dataset.caps_dataset_config import CapsDatasetConfig - - config = CapsDatasetConfig.from_preprocessing_and_extraction_method( - extraction=parameters["mode"], - preprocessing_type=parameters["preprocessing"], - **parameters, - ) - if "file_type" not in parameters["preprocessing_dict"]: - _, file_type = compute_folder_and_file_type(config) - parameters["preprocessing_dict"]["file_type"] = file_type.model_dump() - - return parameters diff --git a/clinicadl/dataset/caps_reader.py b/clinicadl/dataset/caps_reader.py index 14199616e..36f8020b1 100644 --- a/clinicadl/dataset/caps_reader.py +++ b/clinicadl/dataset/caps_reader.py @@ -1,7 +1,18 @@ +import json +from abc import abstractmethod +from enum import Enum +from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Optional, Union + +import nibabel as nib +import pandas as pd +import torch +from joblib import Parallel, delayed +from torch import save as 
save_tensor from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.dataset.concat import ConcatDataset from clinicadl.dataset.config.extraction import ( ExtractionConfig, ExtractionImageConfig, @@ -9,54 +20,369 @@ ExtractionROIConfig, ExtractionSliceConfig, ) -from clinicadl.dataset.config.preprocessing import PreprocessingConfig +from clinicadl.dataset.config.preprocessing import ( + CustomPreprocessingConfig, + DTIPreprocessingConfig, + FlairPreprocessingConfig, + PETPreprocessingConfig, + PreprocessingConfig, + T1PreprocessingConfig, + T2PreprocessingConfig, +) +from clinicadl.dataset.config.utils import ( + get_preprocessing, + get_preprocessing_and_mode_from_json, +) from clinicadl.experiment_manager.experiment_manager import ExperimentManager -from clinicadl.transforms.config import TransformsConfig +from clinicadl.transforms.transforms import Transforms +from clinicadl.utils.enum import ( + DTIMeasure, + DTISpace, + Preprocessing, + SliceDirection, + SliceMode, + SubFolder, + Suffix, + SUVRReferenceRegions, + Tracer, +) +from clinicadl.utils.exceptions import ClinicaDLArgumentError, ClinicaDLTSVError +from clinicadl.utils.iotools.clinica_utils import ( + check_caps_folder, + clinicadl_file_reader, + container_from_filename, + create_subs_sess_list, + determine_caps_or_bids, + get_subject_session_list, +) +from clinicadl.utils.iotools.utils import path_encoder +logger = getLogger("clinicadl.caps_reader") -class CapsReader: - def __init__(self, caps_directory: Path, manager: ExperimentManager): + +class Reader: + def __init__(self, input_dir: Path, bids: bool) -> None: + self.input_directory = input_dir + self.bids = bids + + def get_preprocessing( + self, preprocessing: Union[str, Preprocessing] + ) -> PreprocessingConfig: """TO COMPLETE""" - pass + + preprocessing_ = Preprocessing(preprocessing) + print(preprocessing_) + subjects, sessions = get_subject_session_list( + input_dir=self.input_directory, is_bids_dir=self.bids + ) + if ( + self.input_directory + / "subjects" + / subjects[0] + / sessions[0] + / (preprocessing_.value).replace("-", "_") + ).is_dir(): + preprocessing_config = get_preprocessing(preprocessing_)() + preprocessing_config.from_bids = self.bids + pattern = preprocessing_config.file_type.pattern + + def get_value(enum, pattern: str): + for value in enum: + if value.value in pattern: + return value + raise ValueError( + f"We can't find a value matching {[e.value for e in enum]} in the pattern {pattern}" + ) + + if isinstance(preprocessing_config, PETPreprocessingConfig): + preprocessing_config.tracer = get_value(Tracer, pattern) + preprocessing_config.suvr_reference_region = get_value( + SUVRReferenceRegions, pattern + ) + + elif isinstance(preprocessing_config, DTIPreprocessingConfig): + preprocessing_config.dti_measure = get_value(DTIMeasure, pattern) + preprocessing_config.dti_space = get_value(DTISpace, pattern) + + elif isinstance(preprocessing_config, CustomPreprocessingConfig): + # TODO: add something to find the custom pattern + pass + else: + raise FileNotFoundError( + f"The preprocessing folder {preprocessing} does not exist." + ) + return preprocessing_config + + def get_preprocessing_and_extraction_from_json(self, preprocessing_json: Path): + if not preprocessing_json.is_file(): + raise FileNotFoundError( + f"The provided preprocessing JSON file {preprocessing_json} does not exist." 
+ ) + + preprocessing, extraction = get_preprocessing_and_mode_from_json( + preprocessing_json + ) + return preprocessing, extraction + + +class CapsReader(Reader): + def __init__( + self, + caps_directory: Path, + # manager: Optional[ExperimentManager], + from_bids: Optional[Path] = None, + ): + """TO COMPLETE""" + + # self.manager = manager + self._get_input_directory(caps_directory, from_bids) + + def create_caps_json(self): + caps_json = self.input_directory / "caps.json" + if caps_json.is_file: + with open(caps_json, "a") as f: + caps_data = json.load(f) + return caps_data + + else: + with open(caps_json, "w") as f: + f.write("tests") + caps_data = json.load(f) + return caps_data + + def _get_input_directory( + self, caps_directory: Path, from_bids: Optional[Path] = None + ): + # Get subject and session list + if from_bids is not None: + try: + self.input_directory = Path(from_bids) + except ClinicaDLArgumentError: + logger.warning("Your BIDS directory doesn't exist.") + logger.debug(f"BIDS directory: {self.input_directory}.") + self.bids = True + else: + self.input_directory = caps_directory + check_caps_folder(self.input_directory) + logger.debug(f"CAPS directory: {self.input_directory}.") + self.bids = False + + def prepare_extraction( + self, preprocessing: PreprocessingConfig, data_tsv: Optional[Path] = None + ): + subjects, sessions = get_subject_session_list( + self.input_directory, data_tsv, self.bids, False, None + ) + logger.debug(f"List of subjects: \n{subjects}.") + logger.debug(f"List of sessions: \n{sessions}.") + + file_type = preprocessing.get_filetype() + + input_files = clinicadl_file_reader( + subjects, sessions, self.input_directory, file_type.model_dump() + )[0] + logger.debug(f"Selected image file name list: {input_files}.") + + return input_files + + def prepare_data( + self, + preprocessing: PreprocessingConfig, + data_tsv: Optional[Path] = None, + n_proc: int = 2, + use_uncropped_images: bool = False, + ): + """TO COMPLETE""" + + input_files = self.prepare_extraction(preprocessing, data_tsv=data_tsv) + + def prepare_image(file: Path): + output_file_dir = ( + self.input_directory + / container_from_filename(file) + / "deeplearning_prepare_data" + / SubFolder.IMAGE.value + / preprocessing.compute_folder(self.bids) + ) + + output_file_dir.mkdir(parents=True, exist_ok=True) + output_file = output_file_dir / file.name.replace( + Suffix.NII_GZ.value, Suffix.PT.value + ) + + logger.debug(f"Processing of {file}.") + image_array = nib.loadsave.load(file).get_fdata(dtype="float32") # type: ignore + + # get some important infos about the image + info_df = pd.DataFrame(columns=["mean", "std", "max", "min"]) + info_df.loc[0] = [ + image_array.mean(), + image_array.std(), + image_array.max(), + image_array.min(), + ] + info_df.to_csv("image_info.tsv", sep="\t") + + # extract and save the image tensor + image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() + save_tensor(image_tensor.clone(), output_file) + logger.debug(f"Output tensor saved at {output_file}") + + Parallel(n_jobs=n_proc)(delayed(prepare_image)(file) for file in input_files) + + def write_output_imgs( + self, + output_mode: list, + file: Path, + subfolder: SubFolder, + preprocessing: PreprocessingConfig, + ): + # Write the extracted tensor on a .pt file + container = container_from_filename(file) + mod_subfolder = preprocessing.compute_folder(self.bids) + + for filename, tensor in output_mode: + output_file_dir = ( + self.input_directory + / container + / "deeplearning_prepare_data" + / 
subfolder.value + / mod_subfolder + ) + output_file_dir.mkdir(parents=True, exist_ok=True) + output_file = output_file_dir / filename + save_tensor(tensor, output_file) + logger.debug(f"Output tensor saved at {output_file}") + + def write_preprocessing( + self, preprocessing: PreprocessingConfig, extraction: ExtractionConfig + ) -> Path: + extract_dir = self.input_directory / "tensor_extraction" + extract_dir.mkdir(parents=True, exist_ok=True) + + json_path = extract_dir / extraction.extract_json + + if json_path.is_file(): + raise FileExistsError( + f"JSON file at {json_path} already exists. " + f"Please choose another name for your preprocessing file." + ) + + preprocessing_dict = preprocessing.model_dump() + preprocessing_dict.update(extraction.model_dump()) + + with json_path.open(mode="w") as json_file: + json.dump(preprocessing_dict, json_file, default=path_encoder) + return json_path + + def get_dataset_from_json( + self, json_path: Path, sub_ses_tsv: Optional[Path] = None + ): + preprocessing, transforms = self.get_preprocessing_and_extraction_from_json( + json_path + ) # we need to add the transforms infos in the caps.json + + return self.get_dataset( + preprocessing=preprocessing, transforms=transforms, sub_ses_tsv=sub_ses_tsv + ) def get_dataset( self, - extraction: ExtractionConfig, preprocessing: PreprocessingConfig, - sub_ses_tsv: Path, - transforms: TransformsConfig, + sub_ses_tsv: Optional[Path] = None, + transforms: Optional[Transforms] = None, ) -> CapsDataset: - return CapsDataset(extraction, preprocessing, sub_ses_tsv, transforms) - - def get_preprocessing(self, preprocessing: str) -> PreprocessingConfig: """TO COMPLETE""" - return PreprocessingConfig() + if sub_ses_tsv is None: + sub_ses_tsv = create_subs_sess_list( + self.input_directory, output_dir=self.input_directory + ) + elif not sub_ses_tsv.is_file(): + raise FileNotFoundError( + f"The provided sub_ses_tsv file {sub_ses_tsv} does not exist." + ) - def extract_slice( - self, preprocessing: PreprocessingConfig, arg_slice: Optional[int] = None - ) -> ExtractionSliceConfig: - """TO COMPLETE""" + data_df = pd.read_csv( + sub_ses_tsv, sep="\t" + ) # create function to check if we have the part and sess columns and to read the csv - return ExtractionSliceConfig() + if transforms is None: + logger.info( + "No transforms was provided. We will use the default transforms. Check the documentation for more information" + ) + transforms = Transforms(extraction=ExtractionImageConfig()) - def extract_patch( - self, preprocessing: PreprocessingConfig, arg_patch: Optional[int] = None - ) -> ExtractionPatchConfig: - """TO COMPLETE""" + # if isinstance(transforms.extraction, ExtractionImageConfig): + # dataset_type = CapsDatasetImage - return ExtractionPatchConfig() + # elif isinstance(transforms.extraction, ExtractionSliceConfig): + # dataset_type = CapsDatasetSlice - def extract_roi( - self, preprocessing: PreprocessingConfig, arg_roi: Optional[int] = None - ) -> ExtractionROIConfig: - """TO COMPLETE""" + # elif isinstance(transforms.extraction, ExtractionPatchConfig): + # dataset_type = CapsDatasetPatch + + # elif isinstance(transforms.extraction, ExtractionROIConfig): + # dataset_type = CapsDatasetRoi + + # else: + # raise NotImplementedError( + # f"Mode {transforms.extraction.extract_method.value} is not implemented." 
+ # ) - return ExtractionROIConfig() + return CapsDataset( + caps_directory=self.input_directory, + preprocessing=preprocessing, + data_df=data_df, + transforms=transforms, + ) - def extract_image( - self, preprocessing: PreprocessingConfig, arg_image: Optional[int] = None - ) -> ExtractionImageConfig: + +class CapsMultiReader(Reader): + def __init__( + self, + caps_directory: Path, + # manager: Optional[ExperimentManager], + from_bids: Optional[Path] = None, + ): """TO COMPLETE""" + self.caps_list = [] + + if not caps_directory.is_file(): + raise FileNotFoundError( + f"The provided caps directory {caps_directory} does not exist. Careful: It must be a tsv file in multi-cohort." + ) + + caps_df = pd.read_csv(caps_directory, sep="\t") + + if not set(("cohort", "path")).issubset(caps_df.columns.values): + raise ClinicaDLTSVError( + "Columns of the TSV file used for CAPS location must include cohort and path" + ) + + for idx in range(len(caps_df)): + # cohort_name = caps_df.at[idx, "cohort"] + cohort_path = Path(caps_df.at[idx, "path"]) + + caps_reader = CapsReader(caps_directory=cohort_path) + self.caps_list.append(caps_reader) + + self.input_directory = caps_reader.input_directory + self.bids = caps_reader.bids + + def get_dataset( + self, + preprocessing: PreprocessingConfig, + sub_ses_tsv: Optional[Path] = None, + transforms: Optional[Transforms] = None, + ) -> ConcatDataset: + dataset_list = [] + for caps_reader in self.caps_list: + dataset = caps_reader.get_dataset( + preprocessing=preprocessing, + sub_ses_tsv=sub_ses_tsv, + transforms=transforms, + ) + dataset_list.append(dataset) - return ExtractionImageConfig() + return ConcatDataset(dataset_list) diff --git a/clinicadl/dataset/concat.py b/clinicadl/dataset/concat.py index f0b420dfe..25199ecfc 100644 --- a/clinicadl/dataset/concat.py +++ b/clinicadl/dataset/concat.py @@ -1,6 +1,132 @@ +# coding: utf8 +# TODO: create a folder for generate/ prepare_data/ data to deal with capsDataset objects ? 
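CapsMultiReader above takes, instead of a CAPS folder, a TSV file with at least the columns cohort and path, builds one CapsReader per cohort, and merges the per-cohort datasets through the ConcatDataset defined just below. A usage sketch with hypothetical paths and TSV content:

    from pathlib import Path

    from clinicadl.dataset.caps_reader import CapsMultiReader

    # caps_cohorts.tsv (tab-separated), hypothetical content:
    # cohort    path
    # ADNI      /data/ADNI_CAPS
    # AIBL      /data/AIBL_CAPS
    multi_reader = CapsMultiReader(caps_directory=Path("caps_cohorts.tsv"))

    preprocessing = multi_reader.get_preprocessing("t1-linear")

    # Returns a ConcatDataset chaining the per-cohort CapsDatasets; ConcatDataset
    # checks that all of them share the same extraction, preprocessing and transforms.
    dataset = multi_reader.get_dataset(preprocessing=preprocessing)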
+import abc +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from pydantic import BaseModel +from torch.utils.data import Dataset + from clinicadl.dataset.caps_dataset import CapsDataset +from clinicadl.dataset.config.extraction import ( + ExtractionConfig, + ExtractionImageConfig, + ExtractionPatchConfig, + ExtractionROIConfig, + ExtractionSliceConfig, +) +from clinicadl.dataset.config.preprocessing import PreprocessingConfig +from clinicadl.dataset.config.utils import ( + get_preprocessing_and_mode_from_json, +) +from clinicadl.dataset.utils import CapsDatasetOutput +from clinicadl.transforms.config import TransformsConfig +from clinicadl.utils.enum import ( + ExtractionMethod, + Pattern, + Preprocessing, + SliceDirection, + SliceMode, + Template, +) +from clinicadl.utils.exceptions import ( + ClinicaDLCAPSError, + ClinicaDLConcatError, + ClinicaDLTSVError, +) +from clinicadl.utils.iotools.clinica_utils import check_caps_folder +from clinicadl.utils.iotools.utils import path_decoder, read_json + +logger = getLogger("clinicadl") class ConcatDataset(CapsDataset): - def __init__(self, list_: list[CapsDataset]): - """TO COMPLETE""" + def __init__(self, datasets: List[CapsDataset]): + self._datasets = datasets + self._len = sum(len(dataset) for dataset in datasets) + self._indexes = [] + + # Calculate distribution of indexes in all datasets + cumulative_index = 0 + for idx, dataset in enumerate(datasets): + next_cumulative_index = cumulative_index + len(dataset) + self._indexes.append((cumulative_index, next_cumulative_index, idx)) + cumulative_index = next_cumulative_index + + logger.debug(f"Datasets summary length: {self._len}") + logger.debug(f"Datasets indexes: {self._indexes}") + + self.caps_dict = self.compute_caps_dict() + self.check_configs() + + self.eval_mode = False + + def __getitem__(self, index: int) -> Optional[CapsDatasetOutput]: + for start, stop, dataset_index in self._indexes: + if start <= index < stop: + dataset = self._datasets[dataset_index] + return dataset[index - start] + + def __len__(self) -> int: + return self._len + + def check_configs(self): + extraction = self._datasets[len(self._datasets) - 1].extraction + preprocessing = self._datasets[len(self._datasets) - 1].preprocessing + transforms = self._datasets[len(self._datasets) - 1].transforms + size = self._datasets[len(self._datasets) - 1].size + elem_per_image = self._datasets[len(self._datasets) - 1].elem_per_image + + for idx in range(len(self._datasets) - 1): + if self._datasets[idx].extraction != extraction: + raise ClinicaDLConcatError( + f"Different extraction modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].extraction}, " + f"Dataset {len(self._datasets)}: {extraction}" + ) + + if self._datasets[idx].preprocessing != preprocessing: + raise ClinicaDLConcatError( + f"Different preprocessing modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].preprocessing}, " + f"Dataset {len(self._datasets)}: {preprocessing}" + ) + + if self._datasets[idx].transforms != transforms: + raise ClinicaDLConcatError( + f"Different transforms modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].transforms}, " + f"Dataset {len(self._datasets)}: {transforms}" + ) + if self._datasets[idx].size != size: + raise ClinicaDLConcatError( + f"Different size modes found in datasets. 
" + f"Dataset {idx+1}: {self._datasets[idx].size}, " + f"Dataset {len(self._datasets)}: {size}" + ) + if self._datasets[idx].elem_per_image != elem_per_image: + raise ClinicaDLConcatError( + f"Different elem_per_image modes found in datasets. " + f"Dataset {idx+1}: {self._datasets[idx].elem_per_image}, " + f"Dataset {len(self._datasets)}: {elem_per_image}" + ) + + self.extraction = extraction + self.preprocessing = preprocessing + self.transforms = transforms + self.size = size + self.elem_per_image = elem_per_image + + def compute_caps_dict(self) -> Dict[str, Path]: + caps_dict = dict() + for idx in range(len(self._datasets)): + cohort = idx + caps_path = self._datasets[idx].caps_dict["caps_directory"] + check_caps_folder(caps_path) + caps_dict[cohort] = caps_path + + return caps_dict diff --git a/clinicadl/dataset/config/data.py b/clinicadl/dataset/config/data.py new file mode 100644 index 000000000..7ec3e3481 --- /dev/null +++ b/clinicadl/dataset/config/data.py @@ -0,0 +1,74 @@ +from logging import getLogger +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import pandas as pd +from pydantic import BaseModel, ConfigDict, field_validator + +from clinicadl.dataset.data_utils import load_data_test +from clinicadl.utils.exceptions import ( + ClinicaDLArgumentError, + ClinicaDLTSVError, +) + +logger = getLogger("clinicadl.data_config") + + +class DataConfig(BaseModel): # TODO : put in data module + """Config class to specify the data. + + caps_directory and preprocessing_json are arguments + that must be passed by the user. + """ + + caps_directory: Optional[Path] = None + baseline: bool = False + mask_path: Optional[Path] = None + data_tsv: Optional[Path] = None + n_subjects: int = 300 + # pydantic config + model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True) + + @field_validator("diagnoses", mode="before") + def validator_diagnoses(cls, v): + """Transforms a list to a tuple.""" + if isinstance(v, list): + return tuple(v) + return v # TODO : check if columns are in tsv + + def create_groupe_df(self): + group_df = None + if self.data_tsv is not None and self.data_tsv.is_file(): + group_df = load_data_test( + self.data_tsv, + multi_cohort=self.multi_cohort, + ) + return group_df + + def is_given_label_code(self, _label: str, _label_code: Union[str, Dict[str, int]]): + return ( + self.label is not None + and self.label != "" + and self.label != _label + and _label_code == "default" + ) + + def check_label(self, _label: str): + if not self.label: + self.label = _label + + @field_validator("data_tsv", mode="before") + @classmethod + def check_data_tsv(cls, v) -> Path: + if v is not None: + if not isinstance(v, Path): + v = Path(v) + if not v.is_file(): + raise ClinicaDLTSVError( + "The participants_list you gave is not a file. Please give an existing file." + ) + if v.stat().st_size == 0: + raise ClinicaDLTSVError( + "The participants_list you gave is empty. Please give a non-empty file." 
+ ) + return v diff --git a/clinicadl/dataset/config/extraction.py b/clinicadl/dataset/config/extraction.py index f3619590f..96626c33f 100644 --- a/clinicadl/dataset/config/extraction.py +++ b/clinicadl/dataset/config/extraction.py @@ -1,15 +1,25 @@ +from abc import ABC, abstractmethod from logging import getLogger +from pathlib import Path from time import time -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union +import nibabel as nib +import numpy as np +import torch from pydantic import BaseModel, ConfigDict, field_validator from pydantic.types import NonNegativeInt from clinicadl.utils.enum import ( ExtractionMethod, + Pattern, + Preprocessing, SliceDirection, SliceMode, + Suffix, + Template, ) +from clinicadl.utils.exceptions import ClinicaDLArgumentError from clinicadl.utils.iotools.clinica_utils import FileType logger = getLogger("clinicadl.preprocessing_config") @@ -21,47 +31,590 @@ class ExtractionConfig(BaseModel): """ extract_method: ExtractionMethod - file_type: Optional[FileType] = None + extract_json: str = f"extract_{int(time())}.json" + use_uncropped_image: bool = True save_features: bool = False - extract_json: Optional[str] = None # pydantic config model_config = ConfigDict(validate_assignment=True) @field_validator("extract_json", mode="before") def compute_extract_json(cls, v: str): - if v is None: - return f"extract_{int(time())}.json" + if isinstance(v, Path): + v = str(v) elif not v.endswith(".json"): - return f"{v}.json" - else: - return v + v = f"{v}.json" + return v + + def extract_image(self, input_img: Path) -> torch.Tensor: + image_array = nib.loadsave.load(input_img).get_fdata(dtype="float32") # type: ignore + image_tensor = torch.from_numpy(image_array).unsqueeze(0).float() + return image_tensor + + @abstractmethod + def extract_tensor( + self, + image_tensor: torch.Tensor, + index: int, + object_tensors: Optional[torch.Tensor] = None, + ): + pass + + @abstractmethod + def extract_path(self, image_path, index): + pass + + @abstractmethod + def extract(self, nii_path: Path): + pass + + @abstractmethod + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + pass class ExtractionImageConfig(ExtractionConfig): extract_method: ExtractionMethod = ExtractionMethod.IMAGE + def extract(self, nii_path: Path) -> list[Tuple[Path, torch.Tensor]]: + """Extract the images + This function convert nifti image to tensor (.pt) version of the image. + Tensor version is saved at the same location than input_img. + Args: + input_img: path to the NifTi input image. + Returns: + filename (str): single tensor file saved on the disk. Same location than input file. 
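The base ExtractionConfig above exposes extract_image(), which loads a NIfTI volume with nibabel and returns it as a float32 tensor with a leading channel dimension; each concrete config then implements extract() to return a list of (filename, tensor) pairs. A self-contained sketch of that conversion on a synthetic volume (the file name is made up for illustration):

    from pathlib import Path

    import nibabel as nib
    import numpy as np
    import torch


    def extract_image(input_img: Path) -> torch.Tensor:
        # Standalone version of ExtractionConfig.extract_image():
        # NIfTI volume -> (1, X, Y, Z) float32 tensor.
        image_array = nib.load(input_img).get_fdata(dtype="float32")
        return torch.from_numpy(image_array).unsqueeze(0).float()


    # Round-trip on a synthetic volume to show the expected shapes.
    volume = np.random.rand(8, 8, 8).astype("float32")
    nib.save(nib.Nifti1Image(volume, affine=np.eye(4)), "sub-01_ses-M000_T1w.nii.gz")

    tensor = extract_image(Path("sub-01_ses-M000_T1w.nii.gz"))
    assert tensor.shape == (1, 8, 8, 8)

    # ExtractionImageConfig.extract() then returns a single (filename, tensor) pair,
    # with the .nii.gz suffix of the input replaced by .pt.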
+ """ + + image_tensor = self.extract_image(nii_path) + + # make sure the tensor type is torch.float32 + output_file = ( + Path(nii_path.name.replace(Suffix.NII_GZ.value, Suffix.PT.value)), + image_tensor.clone(), + ) + + return [output_file] + + def extract_tensor( + self, + image_tensor: torch.Tensor, + index: int, + object_tensors: Optional[torch.Tensor] = None, + ): + return image_tensor + + def extract_path(self, image_path, index): + return image_path + + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + return 1 + class ExtractionPatchConfig(ExtractionConfig): patch_size: int = 50 stride_size: int = 50 extract_method: ExtractionMethod = ExtractionMethod.PATCH + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + if elem_index is not None: + return 1 + + patches_tensor = self.create_patches(image) + num_patches = patches_tensor.shape[0] + return num_patches + + def extract( + self, + nii_path: Path, + ) -> List[Tuple[Path, torch.Tensor]]: + """Extracts the patches + This function extracts patches form the preprocessed nifti image. Patch size + if provided as input and also the stride size. If stride size is smaller + than the patch size an overlap exist between consecutive patches. If stride + size is equal to path size there is no overlap. Otherwise, unprocessed + zones can exits. + Args: + nii_path: path to the NifTi input image. + self.patch_size: size of a single patch. + self.stride_size: size of the stride leading to next patch. + Returns: + list of tuples containing the path to the extracted patch + and the tensor of the corresponding patch. + """ + + image_tensor = self.extract_image(nii_path) + patches_tensor = self.create_patches(image_tensor) + + patch_list = [] + for patch_index in range(patches_tensor.shape[0]): + patch_tensor = self.extract_tensor( + image_tensor, patch_index, patches_tensor + ) + patch_path = self.extract_path(nii_path, patch_index) + + patch_list.append((patch_path, patch_tensor)) + + return patch_list + + def create_patches(self, image_tensor: torch.Tensor) -> torch.Tensor: + patches_tensor = ( + image_tensor.unfold(1, self.patch_size, self.stride_size) + .unfold(2, self.patch_size, self.stride_size) + .unfold(3, self.patch_size, self.stride_size) + .contiguous() + ) + # the dimension of patches_tensor is [1, patch_num1, patch_num2, patch_num3, self.patch_size1, self.patch_size2, self.patch_size3] + return patches_tensor.view( + -1, self.patch_size, self.patch_size, self.patch_size + ) + + def extract_tensor( + self, + image_tensor: torch.Tensor, + patch_index: int, + patches_tensor: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Extracts a single patch from image_tensor""" + + patches_tensor = self.create_patches(image_tensor) + + return patches_tensor[patch_index, ...].unsqueeze_(0).clone() + + def extract_path(self, img_path: Path, patch_index: int) -> Path: + input_img_filename = img_path.name + txt_idx = input_img_filename.rfind("_") + it_filename_prefix = input_img_filename[0:txt_idx] + it_filename_suffix = input_img_filename[txt_idx:] + it_filename_suffix = it_filename_suffix.replace( + Suffix.NII_GZ.value, Suffix.PT.value + ) + + return Path( + f"{it_filename_prefix}_patchsize-{self.patch_size}_stride-{self.stride_size}_patch-{patch_index}{it_filename_suffix}" + ) + class ExtractionSliceConfig(ExtractionConfig): slice_direction: SliceDirection = SliceDirection.SAGITTAL slice_mode: SliceMode = SliceMode.RGB - num_slices: Optional[NonNegativeInt] = None - 
discarded_slices: Tuple[NonNegativeInt, NonNegativeInt] = (0, 0) + # num_slices: Optional[NonNegativeInt] = None # not sure it is needed + discarded_slices: Tuple = (0,) extract_method: ExtractionMethod = ExtractionMethod.SLICE + @field_validator("slice_direction", mode="before") + def check_slice_direction(cls, v: str): + if isinstance(v, int): + return SliceDirection(str(v)) + + @field_validator("discarded_slices", mode="before") + def compute_discarded_slice(cls, v: Union[int, Tuple]) -> Tuple[int, int]: + if isinstance(v, int): + begin_discard, end_discard = v, v + elif len(v) == 1: + begin_discard, end_discard = ( + v[0], + v[0], + ) + elif len(v) == 2: + begin_discard, end_discard = ( + v[0], + v[1], + ) + else: + raise IndexError( + f"Maximum two number of discarded slices can be defined. " + f"You gave discarded slices = {v}." + ) + return (begin_discard, end_discard) + + # DONE in extraction + + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + if elem_index is not None: + return 1 + + # if self.num_slices is not None: + # return self.num_slices + + return ( + image.size(int(self.slice_direction) + 1) + - self.discarded_slices[0] + - self.discarded_slices[1] + ) + + def compute_discarded_slices(self) -> Tuple[int, int]: + if isinstance(self.discarded_slices, int): + begin_discard, end_discard = self.discarded_slices, self.discarded_slices + elif len(self.discarded_slices) == 1: + begin_discard, end_discard = ( + self.discarded_slices[0], + self.discarded_slices[0], + ) + elif len(self.discarded_slices) == 2: + begin_discard, end_discard = ( + self.discarded_slices[0], + self.discarded_slices[1], + ) + else: + raise IndexError( + f"Maximum two number of discarded slices can be defined. " + f"You gave discarded slices = {self.discarded_slices}." + ) + return begin_discard, end_discard + + def extract( + self, + nii_path: Path, + ) -> List[Tuple[str, torch.Tensor]]: + """Extracts the slices from three directions + This function extracts slices form the preprocessed nifti image. + + The direction of extraction can be defined either on sagittal direction (0), + coronal direction (1) or axial direction (other). + + The output slices can be stored following two modes: + single (1 channel) or rgb (3 channels, all the same). + + Args: + nii_path: path to the NifTi input image. + slice_direction: along which axis slices are extracted. + slice_mode: 'single' or 'rgb'. + discarded_slices: Number of slices to discard at the beginning and the end of the image. + Will be a tuple of two integers if the number of slices to discard at the beginning + and at the end differ. + Returns: + list of tuples containing the path to the extracted slice + and the tensor of the corresponding slice. 
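ExtractionPatchConfig and ExtractionSliceConfig above both derive the number of elements per image directly from the image tensor: patches come from three chained unfold calls, slices from the axis length minus the discarded slices. A numeric sketch on an arbitrary shape (plain integers stand in for the SliceDirection enum):

    import torch

    # Synthetic (C, X, Y, Z) image tensor, as produced by extract_image(); shape is arbitrary.
    image = torch.rand(1, 169, 208, 179)

    # Patches (ExtractionPatchConfig.create_patches)
    patch_size, stride_size = 50, 50
    patches = (
        image.unfold(1, patch_size, stride_size)
        .unfold(2, patch_size, stride_size)
        .unfold(3, patch_size, stride_size)
        .contiguous()
        .view(-1, patch_size, patch_size, patch_size)
    )
    # floor((169-50)/50)+1 = 3, floor((208-50)/50)+1 = 4, floor((179-50)/50)+1 = 3 -> 36 patches
    assert patches.shape == (36, 50, 50, 50)

    # Slices (ExtractionSliceConfig)
    slice_direction = 0          # 0 = sagittal, 1 = coronal, 2 = axial
    discarded_slices = (20, 20)  # slices dropped at the beginning / end of the axis

    # num_elem_per_image(): axis length minus the discarded slices.
    n_slices = image.size(slice_direction + 1) - sum(discarded_slices)
    assert n_slices == 129  # 169 - 20 - 20

    # extract_tensor(): select one slice along the chosen axis, then stack it in "rgb" mode.
    slice_index = 5 + discarded_slices[0]
    idx_tuple = tuple(
        [slice(None)] * (slice_direction + 1)
        + [slice_index]
        + [slice(None)] * (2 - slice_direction)
    )
    rgb_slice = torch.cat((image[idx_tuple],) * 3)  # (1, 208, 179) -> (3, 208, 179)
    assert rgb_slice.shape == (3, 208, 179)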
+ """ + + image_tensor = self.extract_image(nii_path) + + begin_discard, end_discard = self.compute_discarded_slices() + index_list = range( + begin_discard, + image_tensor.shape[int(self.slice_direction.value) + 1] - end_discard, + ) + + slice_list = [] + for slice_index in index_list: + slice_tensor = self.extract_tensor(image_tensor, slice_index) + slice_path = self.extract_path(nii_path, slice_index) + + slice_list.append((slice_path, slice_tensor)) + + return slice_list + + def extract_tensor( + self, + image_tensor: torch.Tensor, + slice_index: int, + ) -> torch.Tensor: + # Allow to select the slice `slice_index` in dimension `slice_direction` + slice_index = slice_index + self.discarded_slices[0] + + idx_tuple = tuple( + [slice(None)] * (int(self.slice_direction.value) + 1) + + [slice_index] + + [slice(None)] * (2 - int(self.slice_direction.value)) + ) + slice_tensor = image_tensor[idx_tuple] # shape is 1 * W * L + + if self.slice_mode == "rgb": + slice_tensor = torch.cat( + (slice_tensor, slice_tensor, slice_tensor) + ) # shape is 3 * W * L + + return slice_tensor.clone() + + def extract_path( + self, + img_path: Path, + slice_index: int, + ) -> str: + slice_dict = {0: "sag", 1: "cor", 2: "axi"} + input_img_filename = img_path.name + txt_idx = input_img_filename.rfind("_") + it_filename_prefix = input_img_filename[0:txt_idx] + it_filename_suffix = input_img_filename[txt_idx:] + it_filename_suffix = it_filename_suffix.replace( + Suffix.NII_GZ.value, Suffix.PT.value + ) + return ( + f"{it_filename_prefix}_axis-{slice_dict[int(self.slice_direction.value)]}" + f"_channel-{self.slice_mode.value}_slice-{slice_index}{it_filename_suffix}" + ) + class ExtractionROIConfig(ExtractionConfig): roi_list: List[str] = [] - roi_uncrop_output: bool = False + roi_crop_input: bool = True + roi_crop_output: bool = True + roi_template: str = "" + roi_mask_pattern: str = "" + roi_mask_location: Path + roi_custom_template: str = "" - roi_custom_pattern: str = "" - roi_custom_suffix: str = "" roi_custom_mask_pattern: str = "" - roi_background_value: int = 0 extract_method: ExtractionMethod = ExtractionMethod.ROI + + @field_validator("roi_list", mode="before") + def check_roi_list(self, v): + if v is None: + raise NotImplementedError( + "Default regions are not available anymore in ClinicaDL. " + "Please define appropriate masks and give a roi_list." + ) + + def num_elem_per_image(self, image: torch.Tensor, elem_index: Optional[int] = None): + if elem_index is not None: + return 1 + else: + return len(self.roi_list) + + def check_with_preprocessing(self, preprocessing: Preprocessing): + if preprocessing == Preprocessing.CUSTOM: + if not self.roi_template: + raise ClinicaDLArgumentError( + "A custom template must be defined when the modality is set to custom." 
+ ) + self.roi_template = self.roi_custom_template + self.roi_mask_pattern = self.roi_custom_mask_pattern + else: + if preprocessing == Preprocessing.T1_LINEAR: + self.roi_template = Template.T1_LINEAR + self.roi_mask_pattern = Pattern.T1_LINEAR + elif preprocessing == Preprocessing.PET_LINEAR: + self.roi_template = Template.PET_LINEAR + self.roi_mask_pattern = Pattern.PET_LINEAR + elif preprocessing == Preprocessing.FLAIR_LINEAR: + self.roi_template = Template.FLAIR_LINEAR + self.roi_mask_pattern = Pattern.FLAIR_LINEAR + + def check_mask_list( + self, + masks_location: Path, + ) -> None: + if len(self.roi_list) == 0: + raise ClinicaDLArgumentError("A list of regions of interest must be given.") + + for roi in self.roi_list: + roi_path, desc = self.find_mask_path(masks_location, roi) + if roi_path is None: + raise FileNotFoundError( + f"The ROI '{roi}' does not correspond to a mask in the CAPS directory. {desc}" + ) + roi_mask = nib.loadsave.load(roi_path).get_fdata() # type: ignore + mask_values = set(np.unique(roi_mask)) + if mask_values != {0, 1}: + raise ValueError( + "The ROI masks used should be binary (composed of 0 and 1 only)." + ) + + def find_mask_path( + self, + masks_location: Path, + roi: str, + ) -> Tuple[Union[None, Path], str]: + """ + Finds masks corresponding to the pattern asked and containing the adequate self.roi_crop_input description + + Parameters + ---------- + masks_location: Path + Directory containing the masks. + roi: str + Name of the region. + mask_pattern: str + Pattern which should be found in the filename of the mask. + self.roi_crop_input: bool + If True the original image should contain the substring 'desc-Crop'. + + Returns + ------- + path of the mask or None if nothing was found. + a human-friendly description of the pattern looked for. + """ + + # Check that pattern begins and ends with _ to avoid mixing keys + if self.roi_mask_pattern is None: + mask_pattern = "" + + candidates_pattern = f"*{mask_pattern}*_roi-{roi}_mask.nii*" + + desc = f"The mask should follow the pattern {candidates_pattern}. " + candidates = [e for e in masks_location.glob(candidates_pattern)] + if self.roi_crop_input is None: + # pass + candidates2 = candidates + elif self.roi_crop_input: + candidates2 = [mask for mask in candidates if "_desc-Crop_" in mask.name] + desc += "and contain '_desc-Crop_' string." + else: + candidates2 = [ + mask for mask in candidates if "_desc-Crop_" not in mask.name + ] + desc += "and not contain '_desc-Crop_' string." 
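find_mask_path() above builds a glob pattern from the mask pattern and the ROI name, then keeps or rejects candidate masks according to whether their file name contains "_desc-Crop_". A self-contained illustration of that filtering, with hypothetical mask file names and template:

    from pathlib import Path

    # Hypothetical mask file names as they would sit under <caps>/masks/tpl-<template>/.
    masks = [
        Path("tpl-MNI152NLin2009cSym_desc-Crop_roi-hippocampus_mask.nii.gz"),
        Path("tpl-MNI152NLin2009cSym_roi-hippocampus_mask.nii.gz"),
    ]

    roi = "hippocampus"
    mask_pattern = "tpl-MNI152NLin2009cSym"
    roi_crop_input = True

    # Same glob pattern as in find_mask_path(): *<mask_pattern>*_roi-<roi>_mask.nii*
    candidates_pattern = f"*{mask_pattern}*_roi-{roi}_mask.nii*"
    candidates = [m for m in masks if m.match(candidates_pattern)]

    # Keep only cropped masks when the input image is cropped (desc-Crop in the name).
    candidates2 = [m for m in candidates if "_desc-Crop_" in m.name]
    assert candidates2[0].name.startswith("tpl-MNI152NLin2009cSym_desc-Crop_")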
+ + if len(candidates2) == 0: + return None, desc + else: + return min(candidates2), desc + + def compute_output_pattern(self, mask_path: Path): + """ + Computes the output pattern of the region cropped (without the source file prefix) + Parameters + ---------- + mask_path: Path + Path to the masks + self.roi_crop_output: bool + If True the output is cropped, and the descriptor CropRoi must exist + + Returns + ------- + the output pattern + """ + + mask_filename = mask_path.name + template_id = mask_filename.split("_")[0].split("-")[1] + mask_descriptors = mask_filename.split("_")[1:-2:] + roi_id = mask_filename.split("_")[-2].split("-")[1] + if "desc-Crop" not in mask_descriptors and not self.roi_crop_output: + mask_descriptors = ["desc-CropRoi"] + mask_descriptors + elif "desc-Crop" in mask_descriptors: + mask_descriptors = [ + descriptor + for descriptor in mask_descriptors + if descriptor != "desc-Crop" + ] + if self.roi_crop_output: + mask_descriptors = ["desc-CropRoi"] + mask_descriptors + else: + mask_descriptors = ["desc-CropImage"] + mask_descriptors + + mask_pattern = "_".join(mask_descriptors) + + if mask_pattern == "": + output_pattern = f"space-{template_id}_roi-{roi_id}" + else: + output_pattern = f"space-{template_id}_{mask_pattern}_roi-{roi_id}" + + return output_pattern + + def extract( + self, + nii_path: Path, + ) -> List[Tuple[str, torch.Tensor]]: + """Extracts regions of interest defined by masks + This function extracts regions of interest from preprocessed nifti images. + The regions are defined using binary masks that must be located in the CAPS + at `masks/tpl-