From 12fcad09cadbc50318330ab008bc99d92d7377b2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Fri, 24 Oct 2025 12:32:43 +0200 Subject: [PATCH 01/99] update: add pydantic to deps --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b104d91c6..5ebe1c368 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ "looseversion==1.3.0; python_version>='3.12'", "junifer_data==1.3.0", "structlog>=25.0.0,<26.0.0", + "pydantic>=2.11.4", ] dynamic = ["version"] From 9029f8a3d6ef1df918dd204c0aa6ebb01a02607a Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 27 Oct 2025 16:19:26 +0100 Subject: [PATCH 02/99] update: introduce DataType enum --- junifer/datagrabber/__init__.pyi | 3 ++- junifer/datagrabber/base.py | 19 ++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/junifer/datagrabber/__init__.pyi b/junifer/datagrabber/__init__.pyi index 51fcd6f49..371e32dc5 100644 --- a/junifer/datagrabber/__init__.pyi +++ b/junifer/datagrabber/__init__.pyi @@ -15,10 +15,11 @@ __all__ = [ "OptionalTypeSchema", "PatternValidationMixin", "register_data_type", + "DataType", ] # These 4 need to be in this order, otherwise it is a circular import -from .base import BaseDataGrabber +from .base import BaseDataGrabber, DataType from .datalad_base import DataladDataGrabber from .pattern import PatternDataGrabber from .pattern_datalad import PatternDataladDataGrabber diff --git a/junifer/datagrabber/base.py b/junifer/datagrabber/base.py index f2707ec80..2b5c23eb9 100644 --- a/junifer/datagrabber/base.py +++ b/junifer/datagrabber/base.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterator +from enum import Enum from pathlib import Path from typing import Union @@ -15,9 +16,25 @@ from ..utils import logger, raise_error -__all__ = ["BaseDataGrabber"] +__all__ = ["BaseDataGrabber", "DataType"] +class DataType(str, Enum): + """Accepted data type.""" + + T1w = "T1w" + T2w = "T2w" + BOLD = "BOLD" + Warp = "Warp" + VBM_GM = "VBM_GM" + VBM_WM = "VBM_WM" + VBM_CSF = "VBM_CSF" + FALFF = "fALFF" + GCOR = "GCOR" + LCOR = "LCOR" + DWI = "DWI" + FreeSurfer = "FreeSurfer" + class BaseDataGrabber(ABC, UpdateMetaMixin): """Abstract base class for DataGrabber. 
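A short usage sketch (illustrative only, not part of the series): the enums introduced here mix in ``str``, so members compare equal to the raw string literals used across the existing codebase, which is what makes them drop-in replacements. This assumes only the ``DataType`` members defined in PATCH 02 above; the ``mappings`` dictionary is hypothetical:

    from junifer.datagrabber import DataType

    assert DataType.BOLD == "BOLD"          # str mix-in: equal to the plain string
    assert DataType.FALFF.value == "fALFF"  # member name and value may differ
    # hypothetical mapping keyed by the enum, read back with the same member
    mappings = {DataType.BOLD: {"aggregation": "timeseries"}}
    assert mappings[DataType.BOLD]["aggregation"] == "timeseries"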
From 8b997e44b82f9e04e75099979a680070b4eeb386 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 27 Oct 2025 16:42:13 +0100 Subject: [PATCH 03/99] update: introduce StorageType enum --- junifer/storage/__init__.pyi | 3 ++- junifer/storage/base.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/junifer/storage/__init__.pyi b/junifer/storage/__init__.pyi index b4a687e7b..eb533c099 100644 --- a/junifer/storage/__init__.pyi +++ b/junifer/storage/__init__.pyi @@ -3,9 +3,10 @@ __all__ = [ "PandasBaseFeatureStorage", "SQLiteFeatureStorage", "HDF5FeatureStorage", + "StorageType", ] -from .base import BaseFeatureStorage from .pandas_base import PandasBaseFeatureStorage from .sqlite import SQLiteFeatureStorage +from .base import BaseFeatureStorage, StorageType from .hdf5 import HDF5FeatureStorage diff --git a/junifer/storage/base.py b/junifer/storage/base.py index 13f4448b7..aac7f1fff 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence +from enum import Enum from pathlib import Path from typing import Any, ClassVar, Optional, Union @@ -16,7 +17,17 @@ from .utils import process_meta -__all__ = ["BaseFeatureStorage"] +__all__ = ["BaseFeatureStorage", "StorageType"] + + +class StorageType(str, Enum): + """Accepted storage type.""" + + Vector = "vector" + Matrix = "matrix" + Timeseries = "timeseries" + Timeseries2D = "timeseries_2d" + ScalarTable = "scalar_table" class BaseFeatureStorage(ABC): From 6a836d94c5d11ba2f043834d69780f83696d003f Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 13:25:33 +0100 Subject: [PATCH 04/99] update: adapt DataType and StorageType for markers --- junifer/markers/brainprint.py | 12 +++--- junifer/markers/complexity/complexity_base.py | 7 +++- junifer/markers/ets_rss.py | 6 ++- .../functional_connectivity_base.py | 6 ++- junifer/markers/parcel_aggregation.py | 38 ++++++++++--------- junifer/markers/sphere_aggregation.py | 38 ++++++++++--------- .../markers/temporal_snr/temporal_snr_base.py | 7 +++- 7 files changed, 65 insertions(+), 49 deletions(-) diff --git a/junifer/markers/brainprint.py b/junifer/markers/brainprint.py index 815c3f820..877fd6aff 100644 --- a/junifer/markers/brainprint.py +++ b/junifer/markers/brainprint.py @@ -16,12 +16,14 @@ import numpy.typing as npt from ..api.decorators import register_marker +from ..datagrabber import DataType from ..external.BrainPrint.brainprint.brainprint import ( compute_asymmetry, compute_brainprint, ) from ..external.BrainPrint.brainprint.surfaces import surf_to_vtk from ..pipeline import WorkDirManager +from ..storage import StorageType from ..typing import Dependencies, ExternalDependencies, MarkerInOutMappings from ..utils import logger, run_ext_cmd from .base import BaseMarker @@ -81,11 +83,11 @@ class BrainPrint(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"lapy", "numpy"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "FreeSurfer": { - "eigenvalues": "scalar_table", - "areas": "vector", - "volumes": "vector", - "distances": "vector", + DataType.FreeSurfer: { + "eigenvalues": StorageType.ScalarTable, + "areas": StorageType.Vector, + "volumes": StorageType.Vector, + "distances": StorageType.Vector, } } diff --git a/junifer/markers/complexity/complexity_base.py b/junifer/markers/complexity/complexity_base.py index de5864d36..da30bcfc6 100644 --- a/junifer/markers/complexity/complexity_base.py +++ 
b/junifer/markers/complexity/complexity_base.py @@ -1,6 +1,7 @@ """Provide base class for complexity.""" # Authors: Amir Omidvarnia +# Synchon Mandal # License: AGPL from abc import abstractmethod @@ -12,6 +13,8 @@ Union, ) +from ...datagrabber import DataType +from ...storage import StorageType from ...typing import Dependencies, MarkerInOutMappings from ...utils import raise_error from ..base import BaseMarker @@ -52,8 +55,8 @@ class ComplexityBase(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn", "neurokit2"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "BOLD": { - "complexity": "vector", + DataType.BOLD: { + "complexity": StorageType.Vector, }, } diff --git a/junifer/markers/ets_rss.py b/junifer/markers/ets_rss.py index 6031d3b57..aa816b6c7 100644 --- a/junifer/markers/ets_rss.py +++ b/junifer/markers/ets_rss.py @@ -11,6 +11,8 @@ import numpy as np from ..api.decorators import register_marker +from ..datagrabber import DataType +from ..storage import StorageType from ..typing import Dependencies, MarkerInOutMappings from ..utils import logger from .base import BaseMarker @@ -49,8 +51,8 @@ class RSSETSMarker(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "BOLD": { - "rss_ets": "timeseries", + DataType.BOLD: { + "rss_ets": StorageType.Timeseries, }, } diff --git a/junifer/markers/functional_connectivity/functional_connectivity_base.py b/junifer/markers/functional_connectivity/functional_connectivity_base.py index e35bf23d3..1a317ad42 100644 --- a/junifer/markers/functional_connectivity/functional_connectivity_base.py +++ b/junifer/markers/functional_connectivity/functional_connectivity_base.py @@ -8,7 +8,9 @@ from sklearn.covariance import EmpiricalCovariance, LedoitWolf +from ...datagrabber import DataType from ...external.nilearn import JuniferConnectivityMeasure +from ...storage import StorageType from ...typing import Dependencies, MarkerInOutMappings from ...utils import raise_error from ..base import BaseMarker @@ -54,8 +56,8 @@ class FunctionalConnectivityBase(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn", "scikit-learn"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "BOLD": { - "functional_connectivity": "matrix", + DataType.BOLD: { + "functional_connectivity": StorageType.Matrix, }, } diff --git a/junifer/markers/parcel_aggregation.py b/junifer/markers/parcel_aggregation.py index ca9cd3911..92fbf7efa 100644 --- a/junifer/markers/parcel_aggregation.py +++ b/junifer/markers/parcel_aggregation.py @@ -12,7 +12,9 @@ from ..api.decorators import register_marker from ..data import get_data +from ..datagrabber import DataType from ..stats import get_aggfunc_by_name +from ..storage import StorageType from ..typing import Dependencies, MarkerInOutMappings from ..utils import logger, raise_error, warn_with_log from .base import BaseMarker @@ -65,32 +67,32 @@ class ParcelAggregation(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn", "numpy"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "T1w": { - "aggregation": "vector", + DataType.T1w: { + "aggregation": StorageType.Vector, }, - "T2w": { - "aggregation": "vector", + DataType.T2w: { + "aggregation": StorageType.Vector, }, - "BOLD": { - "aggregation": "timeseries", + DataType.BOLD: { + "aggregation": StorageType.Timeseries, }, - "VBM_GM": { - "aggregation": "vector", + DataType.VBM_GM: { + "aggregation": StorageType.Vector, }, - "VBM_WM": { - "aggregation": "vector", + 
DataType.VBM_WM: { + "aggregation": StorageType.Vector, }, - "VBM_CSF": { - "aggregation": "vector", + DataType.VBM_CSF: { + "aggregation": StorageType.Vector, }, - "fALFF": { - "aggregation": "vector", + DataType.FALFF: { + "aggregation": StorageType.Vector, }, - "GCOR": { - "aggregation": "vector", + DataType.GCOR: { + "aggregation": StorageType.Vector, }, - "LCOR": { - "aggregation": "vector", + DataType.LCOR: { + "aggregation": StorageType.Vector, }, } diff --git a/junifer/markers/sphere_aggregation.py b/junifer/markers/sphere_aggregation.py index 04e75db1e..b225fab2e 100644 --- a/junifer/markers/sphere_aggregation.py +++ b/junifer/markers/sphere_aggregation.py @@ -8,8 +8,10 @@ from ..api.decorators import register_marker from ..data import get_data +from ..datagrabber import DataType from ..external.nilearn import JuniferNiftiSpheresMasker from ..stats import get_aggfunc_by_name +from ..storage import StorageType from ..typing import Dependencies, MarkerInOutMappings from ..utils import logger, raise_error, warn_with_log from .base import BaseMarker @@ -70,32 +72,32 @@ class SphereAggregation(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn", "numpy"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "T1w": { - "aggregation": "vector", + DataType.T1w: { + "aggregation": StorageType.Vector, }, - "T2w": { - "aggregation": "vector", + DataType.T2w: { + "aggregation": StorageType.Vector, }, - "BOLD": { - "aggregation": "timeseries", + DataType.BOLD: { + "aggregation": StorageType.Timeseries, }, - "VBM_GM": { - "aggregation": "vector", + DataType.VBM_GM: { + "aggregation": StorageType.Vector, }, - "VBM_WM": { - "aggregation": "vector", + DataType.VBM_WM: { + "aggregation": StorageType.Vector, }, - "VBM_CSF": { - "aggregation": "vector", + DataType.VBM_CSF: { + "aggregation": StorageType.Vector, }, - "fALFF": { - "aggregation": "vector", + DataType.FALFF: { + "aggregation": StorageType.Vector, }, - "GCOR": { - "aggregation": "vector", + DataType.GCOR: { + "aggregation": StorageType.Vector, }, - "LCOR": { - "aggregation": "vector", + DataType.LCOR: { + "aggregation": StorageType.Vector, }, } diff --git a/junifer/markers/temporal_snr/temporal_snr_base.py b/junifer/markers/temporal_snr/temporal_snr_base.py index c0f946c12..b4f3c3aa9 100644 --- a/junifer/markers/temporal_snr/temporal_snr_base.py +++ b/junifer/markers/temporal_snr/temporal_snr_base.py @@ -1,6 +1,7 @@ """Provide abstract base class for temporal signal-to-noise ratio (tSNR).""" # Authors: Leonard Sasse +# Synchon Mandal # License: AGPL from abc import abstractmethod @@ -8,6 +9,8 @@ from nilearn import image as nimg +from ...datagrabber import DataType +from ...storage import StorageType from ...typing import Dependencies, MarkerInOutMappings from ...utils import raise_error from ..base import BaseMarker @@ -40,8 +43,8 @@ class TemporalSNRBase(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "BOLD": { - "tsnr": "vector", + DataType.BOLD: { + "tsnr": StorageType.Vector, }, } From 2de4aebe3d7fffca65a16638f17fb2cf75c11d0c Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 13:38:35 +0100 Subject: [PATCH 05/99] update: introduce ExtDep enum --- junifer/pipeline/__init__.pyi | 2 ++ junifer/pipeline/utils.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/junifer/pipeline/__init__.pyi b/junifer/pipeline/__init__.pyi index f2c45a38d..aa973e18f 100644 --- a/junifer/pipeline/__init__.pyi +++ 
b/junifer/pipeline/__init__.pyi @@ -3,6 +3,7 @@ __all__ = [ "AssetLoaderDispatcher", "BaseDataDumpAsset", "DataObjectDumper", + "ExtDep", "PipelineComponentRegistry", "PipelineStepMixin", "UpdateMetaMixin", @@ -19,5 +20,6 @@ from ._data_object_dumper import ( from .pipeline_component_registry import PipelineComponentRegistry from .pipeline_step_mixin import PipelineStepMixin from .update_meta_mixin import UpdateMetaMixin +from .utils import ExtDep from .workdir_manager import WorkDirManager from .marker_collection import MarkerCollection diff --git a/junifer/pipeline/utils.py b/junifer/pipeline/utils.py index 66dd433d7..c943ae53a 100644 --- a/junifer/pipeline/utils.py +++ b/junifer/pipeline/utils.py @@ -5,12 +5,22 @@ # License: AGPL import subprocess +from enum import Enum from typing import Any, Optional from junifer.utils.logging import raise_error, warn_with_log -__all__ = ["check_ext_dependencies"] +__all__ = ["ExtDep", "check_ext_dependencies"] + + +class ExtDep(str, Enum): + """Accepted external dependencies.""" + + AFNI = "afni" + FSL = "fsl" + ANTs = "ants" + FreeSurfer = "freesurfer" def check_ext_dependencies( From c523739f16b907b3b9301c2c33ccf7e94fa59439 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 14:34:42 +0100 Subject: [PATCH 06/99] update: adapt ExtDep for preprocessors --- junifer/preprocess/smoothing/_afni_smoothing.py | 4 ++-- junifer/preprocess/smoothing/_fsl_smoothing.py | 4 ++-- junifer/preprocess/warping/_ants_warper.py | 4 ++-- junifer/preprocess/warping/_fsl_warper.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/junifer/preprocess/smoothing/_afni_smoothing.py b/junifer/preprocess/smoothing/_afni_smoothing.py index 53c14292f..e7c847b3f 100644 --- a/junifer/preprocess/smoothing/_afni_smoothing.py +++ b/junifer/preprocess/smoothing/_afni_smoothing.py @@ -11,7 +11,7 @@ import nibabel as nib -from ...pipeline import WorkDirManager +from ...pipeline import ExtDep, WorkDirManager from ...typing import Dependencies, ExternalDependencies from ...utils import logger, run_ext_cmd @@ -28,7 +28,7 @@ class AFNISmoothing: _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "afni", + "name": ExtDep.AFNI, "commands": ["3dBlurToFWHM"], }, ] diff --git a/junifer/preprocess/smoothing/_fsl_smoothing.py b/junifer/preprocess/smoothing/_fsl_smoothing.py index 88d21b3a1..a50e33cb8 100644 --- a/junifer/preprocess/smoothing/_fsl_smoothing.py +++ b/junifer/preprocess/smoothing/_fsl_smoothing.py @@ -10,7 +10,7 @@ import nibabel as nib -from ...pipeline import WorkDirManager +from ...pipeline import ExtDep, WorkDirManager from ...typing import Dependencies, ExternalDependencies from ...utils import logger, run_ext_cmd @@ -27,7 +27,7 @@ class FSLSmoothing: _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "fsl", + "name": ExtDep.FSL, "commands": ["susan"], }, ] diff --git a/junifer/preprocess/warping/_ants_warper.py b/junifer/preprocess/warping/_ants_warper.py index 7a94432b1..3b083c6d8 100644 --- a/junifer/preprocess/warping/_ants_warper.py +++ b/junifer/preprocess/warping/_ants_warper.py @@ -12,7 +12,7 @@ import numpy as np from ...data import get_template, get_xfm -from ...pipeline import WorkDirManager +from ...pipeline import ExtDep, WorkDirManager from ...typing import Dependencies, ExternalDependencies from ...utils import logger, raise_error, run_ext_cmd @@ -30,7 +30,7 @@ class ANTsWarper: _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "ants", + "name": ExtDep.ANTs, "commands": 
["ResampleImage", "antsApplyTransforms"], }, ] diff --git a/junifer/preprocess/warping/_fsl_warper.py b/junifer/preprocess/warping/_fsl_warper.py index 51178b946..f7483b263 100644 --- a/junifer/preprocess/warping/_fsl_warper.py +++ b/junifer/preprocess/warping/_fsl_warper.py @@ -11,7 +11,7 @@ import nibabel as nib import numpy as np -from ...pipeline import WorkDirManager +from ...pipeline import ExtDep, WorkDirManager from ...typing import Dependencies, ExternalDependencies from ...utils import logger, raise_error, run_ext_cmd @@ -29,7 +29,7 @@ class FSLWarper: _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "fsl", + "name": ExtDep.FSL, "commands": ["flirt", "applywarp"], }, ] From f416a3848eae5dcadb2b468ff68933564dc2c0bb Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 14:43:12 +0100 Subject: [PATCH 07/99] update: adapt ExtDep for markers --- junifer/markers/brainprint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/junifer/markers/brainprint.py b/junifer/markers/brainprint.py index 877fd6aff..76a0f5c6a 100644 --- a/junifer/markers/brainprint.py +++ b/junifer/markers/brainprint.py @@ -22,7 +22,7 @@ compute_brainprint, ) from ..external.BrainPrint.brainprint.surfaces import surf_to_vtk -from ..pipeline import WorkDirManager +from ..pipeline import ExtDep, WorkDirManager from ..storage import StorageType from ..typing import Dependencies, ExternalDependencies, MarkerInOutMappings from ..utils import logger, run_ext_cmd @@ -70,7 +70,7 @@ class BrainPrint(BaseMarker): _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "freesurfer", + "name": ExtDep.FreeSurfer, "commands": [ "mri_binarize", "mri_pretess", From c8c669f2c354eee5e15449247a3b9e69175ce4a1 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 17:06:32 +0100 Subject: [PATCH 08/99] update: improve typing annotation for MarkerInOutMappings --- junifer/typing/_typing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/junifer/typing/_typing.py b/junifer/typing/_typing.py index df5fd646a..b4fede74b 100644 --- a/junifer/typing/_typing.py +++ b/junifer/typing/_typing.py @@ -12,11 +12,11 @@ if TYPE_CHECKING: from ..data import BasePipelineDataRegistry - from ..datagrabber import BaseDataGrabber + from ..datagrabber import BaseDataGrabber, DataType from ..datareader import DefaultDataReader from ..markers import BaseMarker from ..preprocess import BasePreprocessor - from ..storage import BaseFeatureStorage + from ..storage import BaseFeatureStorage, StorageType __all__ = [ @@ -62,9 +62,9 @@ Sequence[PipelineComponent], ], ] +MarkerInOutMappings = dict["DataType", dict[str, "StorageType"]] ] ExternalDependencies = Sequence[MutableMapping[str, Union[str, Sequence[str]]]] -MarkerInOutMappings = MutableMapping[str, MutableMapping[str, str]] DataGrabberPatterns = dict[str, Union[dict[str, str], list[dict[str, str]]]] ConfigVal = Union[bool, int, float, str] Element = Union[str, tuple[str, ...]] From 5fcf65f50827d6a799f1d431f8be2329651f6192 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 17:09:38 +0100 Subject: [PATCH 09/99] update: adapt DataType and StorageType for remaining markers --- junifer/markers/falff/falff_base.py | 8 ++-- ...ossparcellation_functional_connectivity.py | 6 ++- junifer/markers/maps_aggregation.py | 38 ++++++++++--------- junifer/markers/reho/reho_base.py | 6 ++- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/junifer/markers/falff/falff_base.py 
b/junifer/markers/falff/falff_base.py index 0960c0385..d743deea4 100644 --- a/junifer/markers/falff/falff_base.py +++ b/junifer/markers/falff/falff_base.py @@ -14,6 +14,8 @@ Optional, ) +from ...datagrabber import DataType +from ...storage import StorageType from ...typing import ConditionalDependencies, MarkerInOutMappings from ...utils.logging import logger, raise_error from ..base import BaseMarker @@ -80,9 +82,9 @@ class ALFFBase(BaseMarker): ] _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "BOLD": { - "alff": "vector", - "falff": "vector", + DataType.BOLD: { + "alff": StorageType.Vector, + "falff": StorageType.Vector, }, } diff --git a/junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py b/junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py index 97479cee7..4f7515953 100644 --- a/junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +++ b/junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py @@ -9,6 +9,8 @@ import pandas as pd from ...api.decorators import register_marker +from ...datagrabber import DataType +from ...storage import StorageType from ...typing import Dependencies, MarkerInOutMappings from ...utils import logger, raise_error from ..base import BaseMarker @@ -53,8 +55,8 @@ class CrossParcellationFC(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "BOLD": { - "functional_connectivity": "matrix", + DataType.BOLD: { + "functional_connectivity": StorageType.Matrix, }, } diff --git a/junifer/markers/maps_aggregation.py b/junifer/markers/maps_aggregation.py index 2d231f67a..540c6428b 100644 --- a/junifer/markers/maps_aggregation.py +++ b/junifer/markers/maps_aggregation.py @@ -9,7 +9,9 @@ from ..api.decorators import register_marker from ..data import get_data +from ..datagrabber import DataType from ..stats import get_aggfunc_by_name +from ..storage import StorageType from ..typing import Dependencies, MarkerInOutMappings from ..utils import logger, raise_error, warn_with_log from .base import BaseMarker @@ -56,32 +58,32 @@ class MapsAggregation(BaseMarker): _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn", "numpy"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "T1w": { - "aggregation": "vector", + DataType.T1w: { + "aggregation": StorageType.Vector, }, - "T2w": { - "aggregation": "vector", + DataType.T2w: { + "aggregation": StorageType.Vector, }, - "BOLD": { - "aggregation": "timeseries", + DataType.BOLD: { + "aggregation": StorageType.Timeseries, }, - "VBM_GM": { - "aggregation": "vector", + DataType.VBM_GM: { + "aggregation": StorageType.Vector, }, - "VBM_WM": { - "aggregation": "vector", + DataType.VBM_WM: { + "aggregation": StorageType.Vector, }, - "VBM_CSF": { - "aggregation": "vector", + DataType.VBM_CSF: { + "aggregation": StorageType.Vector, }, - "fALFF": { - "aggregation": "vector", + DataType.FALFF: { + "aggregation": StorageType.Vector, }, - "GCOR": { - "aggregation": "vector", + DataType.GCOR: { + "aggregation": StorageType.Vector, }, - "LCOR": { - "aggregation": "vector", + DataType.LCOR: { + "aggregation": StorageType.Vector, }, } diff --git a/junifer/markers/reho/reho_base.py b/junifer/markers/reho/reho_base.py index cb0e363ed..0e0f0ebab 100644 --- a/junifer/markers/reho/reho_base.py +++ b/junifer/markers/reho/reho_base.py @@ -11,6 +11,8 @@ Optional, ) +from ...datagrabber import DataType +from ...storage import StorageType from 
...typing import ConditionalDependencies, MarkerInOutMappings from ...utils import logger, raise_error from ..base import BaseMarker @@ -58,8 +60,8 @@ class ReHoBase(BaseMarker): ] _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "BOLD": { - "reho": "vector", + DataType.BOLD: { + "reho": StorageType.Vector, }, } From 0e8d1a614e375ac0d9d19ba0e8e72f67ec628790 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 17:50:38 +0100 Subject: [PATCH 10/99] update: introduce ConfoundsFormat enum --- junifer/datagrabber/__init__.pyi | 3 ++- junifer/datagrabber/pattern.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/junifer/datagrabber/__init__.pyi b/junifer/datagrabber/__init__.pyi index 371e32dc5..9ac3524d6 100644 --- a/junifer/datagrabber/__init__.pyi +++ b/junifer/datagrabber/__init__.pyi @@ -16,12 +16,13 @@ __all__ = [ "PatternValidationMixin", "register_data_type", "DataType", + "ConfoundsFormat", ] # These 4 need to be in this order, otherwise it is a circular import from .base import BaseDataGrabber, DataType from .datalad_base import DataladDataGrabber -from .pattern import PatternDataGrabber +from .pattern import PatternDataGrabber, ConfoundsFormat from .pattern_datalad import PatternDataladDataGrabber from .aomic import DataladAOMICID1000, DataladAOMICPIOP1, DataladAOMICPIOP2 diff --git a/junifer/datagrabber/pattern.py b/junifer/datagrabber/pattern.py index 7b8173cbf..86041b68c 100644 --- a/junifer/datagrabber/pattern.py +++ b/junifer/datagrabber/pattern.py @@ -7,6 +7,7 @@ import re from copy import deepcopy +from enum import Enum from pathlib import Path from typing import Optional, Union @@ -19,11 +20,16 @@ from .pattern_validation_mixin import PatternValidationMixin -__all__ = ["PatternDataGrabber"] +__all__ = ["ConfoundsFormat", "PatternDataGrabber"] # Accepted formats for confounds specification _CONFOUNDS_FORMATS = ("fmriprep", "adhoc") +class ConfoundsFormat(str, Enum): + """Accepted confounds format.""" + + FMRIPrep = "fmriprep" + AdHoc = "adhoc" @register_datagrabber From 0a2e76430c68a4d098d6e9e23adc178495f71177 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 18:00:07 +0100 Subject: [PATCH 11/99] update: introduce MatrixKind enum --- junifer/storage/__init__.pyi | 3 ++- junifer/storage/base.py | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/junifer/storage/__init__.pyi b/junifer/storage/__init__.pyi index eb533c099..06ae79c91 100644 --- a/junifer/storage/__init__.pyi +++ b/junifer/storage/__init__.pyi @@ -4,9 +4,10 @@ __all__ = [ "SQLiteFeatureStorage", "HDF5FeatureStorage", "StorageType", + "MatrixKind", ] from .pandas_base import PandasBaseFeatureStorage from .sqlite import SQLiteFeatureStorage -from .base import BaseFeatureStorage, StorageType +from .base import BaseFeatureStorage, MatrixKind, StorageType from .hdf5 import HDF5FeatureStorage diff --git a/junifer/storage/base.py b/junifer/storage/base.py index aac7f1fff..51eff7b94 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -17,7 +17,7 @@ from .utils import process_meta -__all__ = ["BaseFeatureStorage", "StorageType"] +__all__ = ["BaseFeatureStorage", "MatrixKind", "StorageType"] class StorageType(str, Enum): @@ -30,6 +30,14 @@ class StorageType(str, Enum): ScalarTable = "scalar_table" +class MatrixKind(str, Enum): + """Accepted matrix kind value.""" + + UpperTriangle = "triu" + LowerTriangle = "tril" + Full = "full" + + class BaseFeatureStorage(ABC): """Abstract base class for feature 
storage. @@ -269,14 +277,8 @@ def store_matrix( The column labels (default None). row_names : list-like of str, optional The row labels (default None). - matrix_kind : str, optional - The kind of matrix: - - * ``triu`` : store upper triangular only - * ``tril`` : store lower triangular - * ``full`` : full matrix - - (default "full"). + matrix_kind : MatrixKind, optional + The matrix kind (default MatrixKind.Full). diagonal : bool, optional Whether to store the diagonal. If ``matrix_kind = full``, setting this to False will raise an error (default True). From 532df417684cd53c638e99d187c1580c29ea2ce2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 28 Oct 2025 18:20:22 +0100 Subject: [PATCH 12/99] update: adapt MatrixKind --- .../functional_connectivity_base.py | 8 +++--- junifer/storage/base.py | 6 ++--- junifer/storage/hdf5.py | 18 +++++-------- junifer/storage/sqlite.py | 23 ++++++++-------- junifer/storage/utils.py | 27 ++++++++----------- 5 files changed, 36 insertions(+), 46 deletions(-) diff --git a/junifer/markers/functional_connectivity/functional_connectivity_base.py b/junifer/markers/functional_connectivity/functional_connectivity_base.py index 1a317ad42..6ab818401 100644 --- a/junifer/markers/functional_connectivity/functional_connectivity_base.py +++ b/junifer/markers/functional_connectivity/functional_connectivity_base.py @@ -10,7 +10,7 @@ from ...datagrabber import DataType from ...external.nilearn import JuniferConnectivityMeasure -from ...storage import StorageType +from ...storage import MatrixKind, StorageType from ...typing import Dependencies, MarkerInOutMappings from ...utils import raise_error from ..base import BaseMarker @@ -123,7 +123,7 @@ def compute( - ``data`` : functional connectivity matrix as ``numpy.ndarray`` - ``row_names`` : ROI labels as list of str - ``col_names`` : ROI labels as list of str - - ``matrix_kind`` : the kind of matrix (tril, triu or full) + - ``matrix_kind`` : :obj:`.junifer.storage.MatrixKind` """ # Perform necessary aggregation @@ -154,7 +154,9 @@ def compute( "col_names": labels, # xi correlation coefficient is not symmetric "matrix_kind": ( - "full" if self.conn_method == "xi correlation" else "tril" + MatrixKind.Full + if self.conn_method == "xi correlation" + else MatrixKind.LowerTriangle ), }, } diff --git a/junifer/storage/base.py b/junifer/storage/base.py index 51eff7b94..4a687612c 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -260,7 +260,7 @@ def store_matrix( data: np.ndarray, col_names: Optional[Sequence[str]] = None, row_names: Optional[Sequence[str]] = None, - matrix_kind: str = "full", + matrix_kind: MatrixKind = MatrixKind.Full, diagonal: bool = True, ) -> None: """Store matrix. @@ -280,8 +280,8 @@ def store_matrix( matrix_kind : MatrixKind, optional The matrix kind (default MatrixKind.Full). diagonal : bool, optional - Whether to store the diagonal. If ``matrix_kind = full``, setting - this to False will raise an error (default True). + Whether to store the diagonal. If ``matrix_kind=MatrixKind.Full``, + setting this to False will raise an error (default True). 
""" raise_error( diff --git a/junifer/storage/hdf5.py b/junifer/storage/hdf5.py index 13ade4af5..936beea97 100644 --- a/junifer/storage/hdf5.py +++ b/junifer/storage/hdf5.py @@ -22,7 +22,7 @@ write_hdf5, ) from ..utils import logger, raise_error -from .base import BaseFeatureStorage +from .base import BaseFeatureStorage, MatrixKind from .utils import ( element_to_prefix, matrix_to_vector, @@ -818,7 +818,7 @@ def store_matrix( data: np.ndarray, col_names: Optional[Sequence[str]] = None, row_names: Optional[Sequence[str]] = None, - matrix_kind: str = "full", + matrix_kind: MatrixKind = MatrixKind.Full, diagonal: bool = True, row_header_col_name: str = "ROI", ) -> None: @@ -839,16 +839,10 @@ def store_matrix( The column labels (default None). row_names : list-like of str, optional The row labels (default None). - matrix_kind : str, optional - The kind of matrix: - - * ``triu`` : store upper triangular only - * ``tril`` : store lower triangular - * ``full`` : full matrix - - (default "full"). + matrix_kind : MatrixKind, optional + The matrix kind (default MatrixKind.Full). diagonal : bool, optional - Whether to store the diagonal. If ``matrix_kind`` is "full", + Whether to store the diagonal. If ``matrix_kind=MatrixKind.Full``, setting this to False will raise an error (default True). row_header_col_name : str, optional The column name for the row header column (default "ROI"). @@ -879,7 +873,7 @@ def store_matrix( ) # Store self._store_data( - kind="matrix", + kind=StorageType.Matrix, meta_md5=meta_md5, element=[element], # convert to list data=data[:, :, np.newaxis], # convert to 3D diff --git a/junifer/storage/sqlite.py b/junifer/storage/sqlite.py index 63a8ced91..e73dd565a 100644 --- a/junifer/storage/sqlite.py +++ b/junifer/storage/sqlite.py @@ -19,7 +19,12 @@ from ..api.decorators import register_storage from ..utils import logger, raise_error, warn_with_log from .pandas_base import PandasBaseFeatureStorage -from .utils import element_to_prefix, matrix_to_vector, store_matrix_checks +from .utils import ( + MatrixKind, + element_to_prefix, + matrix_to_vector, + store_matrix_checks, +) if TYPE_CHECKING: @@ -427,7 +432,7 @@ def store_matrix( data: np.ndarray, col_names: Optional[Sequence[str]] = None, row_names: Optional[Sequence[str]] = None, - matrix_kind: str = "full", + matrix_kind: MatrixKind = MatrixKind.Full, diagonal: bool = True, ) -> None: """Store matrix. @@ -446,17 +451,11 @@ def store_matrix( The column labels (default None). row_names : list-like of str, optional The row labels (optional None). - matrix_kind : str, optional - The kind of matrix: - - * ``triu`` : store upper triangular only - * ``tril`` : store lower triangular - * ``full`` : full matrix - - (default "full"). + matrix_kind : MatrixKind, optional + The matrix kind (default MatrixKind.Full). diagonal : bool, optional - Whether to store the diagonal. If ``matrix_kind = full``, setting - this to False will raise an error (default True). + Whether to store the diagonal. If ``matrix_kind=MatrixKind.Full``, + setting this to False will raise an error (default True). 
""" # Row data validation diff --git a/junifer/storage/utils.py b/junifer/storage/utils.py index 0ab8be595..84431d33a 100644 --- a/junifer/storage/utils.py +++ b/junifer/storage/utils.py @@ -8,12 +8,17 @@ import json from collections.abc import Sequence from importlib.metadata import PackageNotFoundError, version +from typing import TYPE_CHECKING import numpy as np from ..utils.logging import logger, raise_error +if TYPE_CHECKING: + from .base import MatrixKind + + __all__ = [ "element_to_prefix", "get_dependency_version", @@ -174,15 +179,10 @@ def store_matrix_checks( Parameters ---------- - matrix_kind : {"triu", "tril", "full"} - The kind of matrix: - - * ``triu`` : store upper triangular only - * ``tril`` : store lower triangular - * ``full`` : full matrix - + matrix_kind : MatrixKind, optional + The matrix kind. diagonal : bool - Whether to store the diagonal. If ``matrix_kind`` is "full", + Whether to store the diagonal. If ``matrix_kind=MatrixKind.Full``, setting this to False will raise an error. data_shape : tuple of int and int The shape of the matrix data to store. @@ -280,7 +280,7 @@ def matrix_to_vector( data: np.ndarray, col_names: Sequence[str], row_names: Sequence[str], - matrix_kind: str, + matrix_kind: "MatrixKind", diagonal: bool, ) -> tuple[np.ndarray, list[str]]: """Convert matrix to vector based on parameters. @@ -293,13 +293,8 @@ def matrix_to_vector( The column labels. row_names : list-like of str The row labels. - matrix_kind : str - The kind of matrix: - - * ``triu`` : store upper triangular only - * ``tril`` : store lower triangular - * ``full`` : full matrix - + matrix_kind : MatrixKind + The matrix kind. diagonal : bool Whether to store the diagonal. From bc8c1327562c5e74efc5b547e4442167ada84014 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 29 Oct 2025 11:40:48 +0100 Subject: [PATCH 13/99] update: introduce Upsert enum --- junifer/storage/__init__.pyi | 3 ++- junifer/storage/sqlite.py | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/junifer/storage/__init__.pyi b/junifer/storage/__init__.pyi index 06ae79c91..b68a6db81 100644 --- a/junifer/storage/__init__.pyi +++ b/junifer/storage/__init__.pyi @@ -5,9 +5,10 @@ __all__ = [ "HDF5FeatureStorage", "StorageType", "MatrixKind", + "Upsert", ] from .pandas_base import PandasBaseFeatureStorage -from .sqlite import SQLiteFeatureStorage from .base import BaseFeatureStorage, MatrixKind, StorageType from .hdf5 import HDF5FeatureStorage +from .sqlite import SQLiteFeatureStorage, Upsert diff --git a/junifer/storage/sqlite.py b/junifer/storage/sqlite.py index e73dd565a..6c8e2d17b 100644 --- a/junifer/storage/sqlite.py +++ b/junifer/storage/sqlite.py @@ -7,6 +7,7 @@ import json from collections.abc import Sequence from pathlib import Path +from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np @@ -31,7 +32,14 @@ from sqlalchemy.engine import Engine -__all__ = ["SQLiteFeatureStorage"] +__all__ = ["SQLiteFeatureStorage", "Upsert"] + + +class Upsert(str, Enum): + """Accepted upsert value.""" + + Update = "update" + Ignore = "ignore" @register_storage From 715599e884bd3e882df51e66d5ef76c0324469cd Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 29 Oct 2025 11:42:50 +0100 Subject: [PATCH 14/99] update: adapt Upsert enum --- junifer/storage/sqlite.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/junifer/storage/sqlite.py b/junifer/storage/sqlite.py index 6c8e2d17b..edf1204d4 100644 --- a/junifer/storage/sqlite.py +++ 
b/junifer/storage/sqlite.py @@ -57,9 +57,10 @@ class SQLiteFeatureStorage(PandasBaseFeatureStorage): ``uri`` and store all the elements in the same file. This behaviour is only suitable for non-parallel executions. SQLite does not support concurrency (default True). - upsert : {"ignore", "update"}, optional - Upsert mode. If "ignore" is used, the existing elements are ignored. - If "update", the existing elements are updated (default "update"). + upsert : Upsert, optional + Upsert mode. If ``Upsert.Ignore`` is used, the existing elements are + ignored. If ``Upsert.Update``, the existing elements are updated + (default Upsert.Update). See Also -------- @@ -72,7 +73,7 @@ def __init__( self, uri: Union[str, Path], single_output: bool = True, - upsert: str = "update", + upsert: Upsert = Upsert.Update, ) -> None: # Check and set upsert argument value if upsert not in ["update", "ignore"]: From 5fdc283a5c79b2754bf600f849bbb3f73f5f251d Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 29 Oct 2025 11:43:33 +0100 Subject: [PATCH 15/99] chore: reorder storage imports --- junifer/storage/__init__.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/junifer/storage/__init__.pyi b/junifer/storage/__init__.pyi index b68a6db81..20e8f0524 100644 --- a/junifer/storage/__init__.pyi +++ b/junifer/storage/__init__.pyi @@ -1,14 +1,14 @@ __all__ = [ "BaseFeatureStorage", + "MatrixKind", + "StorageType", + "HDF5FeatureStorage", "PandasBaseFeatureStorage", "SQLiteFeatureStorage", - "HDF5FeatureStorage", - "StorageType", - "MatrixKind", "Upsert", ] -from .pandas_base import PandasBaseFeatureStorage from .base import BaseFeatureStorage, MatrixKind, StorageType from .hdf5 import HDF5FeatureStorage +from .pandas_base import PandasBaseFeatureStorage from .sqlite import SQLiteFeatureStorage, Upsert From ceaa06f77d8f4a41f9eef24b48e5fd1e400683c2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 10 Nov 2025 18:31:00 +0100 Subject: [PATCH 16/99] update: adapt necessary enums in storage --- junifer/storage/base.py | 10 +++++----- junifer/storage/hdf5.py | 34 +++++++++++++++++----------------- junifer/storage/pandas_base.py | 10 +++++----- junifer/storage/sqlite.py | 12 ++++++------ junifer/storage/utils.py | 6 +++--- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/junifer/storage/base.py b/junifer/storage/base.py index 4a687612c..64398e143 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -58,7 +58,7 @@ class BaseFeatureStorage(ABC): """ - _STORAGE_TYPES: ClassVar[Sequence[str]] + _STORAGE_TYPES: ClassVar[Sequence[StorageType]] def __init__( self, @@ -208,12 +208,12 @@ def store_metadata(self, meta_md5: str, element: dict, meta: dict) -> None: klass=NotImplementedError, ) # pragma: no cover - def store(self, kind: str, **kwargs) -> None: + def store(self, kind: StorageType, **kwargs) -> None: """Store extracted features data. Parameters ---------- - kind : {"matrix", "timeseries", "vector", "scalar_table"} + kind : :enum:`.StorageType` The storage kind. **kwargs The keyword arguments. @@ -277,8 +277,8 @@ def store_matrix( The column labels (default None). row_names : list-like of str, optional The row labels (default None). - matrix_kind : MatrixKind, optional - The matrix kind (default MatrixKind.Full). + matrix_kind : :enum:`.MatrixKind`, optional + The matrix kind (default ``MatrixKind.Full``). diagonal : bool, optional Whether to store the diagonal. If ``matrix_kind=MatrixKind.Full``, setting this to False will raise an error (default True). 
diff --git a/junifer/storage/hdf5.py index 936beea97..fa44468f1 100644 --- a/junifer/storage/hdf5.py +++ b/junifer/storage/hdf5.py @@ -37,7 +37,7 @@ def _create_chunk( chunk_data: list[np.ndarray], - kind: str, + kind: StorageType, element_count: int, chunk_size: int, i_chunk: int, @@ -48,7 +48,7 @@ ---------- chunk_data : list of numpy.ndarray The data to be chunked. - kind : str + kind : :enum:`.StorageType` The kind of data to be chunked. element_count : int The total number of elements. @@ -135,12 +135,12 @@ """ - _STORAGE_TYPES: ClassVar[Sequence[str]] = [ - "vector", - "timeseries", - "matrix", - "scalar_table", - "timeseries_2d", + _STORAGE_TYPES: ClassVar[Sequence[StorageType]] = [ + StorageType.Vector, + StorageType.Timeseries, + StorageType.Matrix, + StorageType.ScalarTable, + StorageType.Timeseries2D, ] def __init__( @@ -670,7 +670,7 @@ def _store_data( self, - kind: str, + kind: StorageType, meta_md5: str, element: list[dict[str, str]], data: np.ndarray, @@ -685,7 +685,7 @@ Parameters ---------- - kind : {"matrix", "vector", "timeseries", "scalar_table"} + kind : :enum:`.StorageType` The storage kind. meta_md5 : str The metadata MD5 hash. @@ -839,8 +839,8 @@ The column labels (default None). row_names : list-like of str, optional The row labels (default None). - matrix_kind : MatrixKind, optional - The matrix kind (default MatrixKind.Full). + matrix_kind : :enum:`.MatrixKind`, optional + The matrix kind (default ``MatrixKind.Full``). diagonal : bool, optional Whether to store the diagonal. If ``matrix_kind=MatrixKind.Full``, setting this to False will raise an error (default True). 
@@ -920,7 +920,7 @@ processed_data = data.ravel() self._store_data( - kind="vector", + kind=StorageType.Vector, meta_md5=meta_md5, element=[element], # convert to list data=processed_data[:, np.newaxis], # convert to 2D @@ -949,7 +949,7 @@ """ self._store_data( - kind="timeseries", + kind=StorageType.Timeseries, meta_md5=meta_md5, element=[element], # convert to list data=[data], # convert to list @@ -987,7 +987,7 @@ col_names_len=len(col_names) if col_names is not None else 0, ) self._store_data( - kind="timeseries_2d", + kind=StorageType.Timeseries2D, meta_md5=meta_md5, element=[element], # convert to list data=[data], # convert to list @@ -1023,7 +1023,7 @@ """ self._store_data( - kind="scalar_table", + kind=StorageType.ScalarTable, meta_md5=meta_md5, element=[element], # convert to list data=[data], # convert to list diff --git a/junifer/storage/pandas_base.py b/junifer/storage/pandas_base.py index 3a37952a7..69d2bad5f 100644 --- a/junifer/storage/pandas_base.py +++ b/junifer/storage/pandas_base.py @@ -13,7 +13,7 @@ import pandas as pd from ..utils import raise_error -from .base import BaseFeatureStorage +from .base import BaseFeatureStorage, StorageType __all__ = ["PandasBaseFeatureStorage"] @@ -38,10 +38,10 @@ """ - _STORAGE_TYPES: ClassVar[Sequence[str]] = [ - "vector", - "timeseries", - "matrix", + _STORAGE_TYPES: ClassVar[Sequence[StorageType]] = [ + StorageType.Vector, + StorageType.Timeseries, + StorageType.Matrix, ] def __init__( diff --git a/junifer/storage/sqlite.py b/junifer/storage/sqlite.py index edf1204d4..6c2b93084 100644 --- a/junifer/storage/sqlite.py +++ b/junifer/storage/sqlite.py @@ -57,10 +57,10 @@ class SQLiteFeatureStorage(PandasBaseFeatureStorage): ``uri`` and store all the elements in the same file. This behaviour is only suitable for non-parallel executions. SQLite does not support concurrency (default True). - upsert : Upsert, optional - Upsert mode. If ``Upsert.Ignore`` is used, the existing elements are - ignored. If ``Upsert.Update``, the existing elements are updated - (default Upsert.Update). + upsert : :enum:`.Upsert`, optional + Upsert mode. If ``Upsert.Ignore`` is used, the existing elements are + ignored. If ``Upsert.Update``, the existing elements are updated + (default ``Upsert.Update``). See Also -------- @@ -460,8 +460,8 @@ The column labels (default None). row_names : list-like of str, optional The row labels (optional None). - matrix_kind : MatrixKind, optional - The matrix kind (default MatrixKind.Full). + matrix_kind : :enum:`.MatrixKind`, optional + The matrix kind (default ``MatrixKind.Full``). diagonal : bool, optional Whether to store the diagonal. If ``matrix_kind=MatrixKind.Full``, setting this to False will raise an error (default True). 
@@ -534,7 +534,7 @@ def collect(self) -> None: ) logger.info(f"Collecting data from {self.uri.parent}/*{self.uri.name}") # Create new instance - out_storage = SQLiteFeatureStorage(uri=self.uri, upsert="ignore") + out_storage = SQLiteFeatureStorage(uri=self.uri, upsert=Upsert.Ignore) # Glob files files = self.uri.parent.glob(f"*{self.uri.name}") for elem in tqdm(files, desc="file"): diff --git a/junifer/storage/utils.py b/junifer/storage/utils.py index 84431d33a..e8a093345 100644 --- a/junifer/storage/utils.py +++ b/junifer/storage/utils.py @@ -169,7 +169,7 @@ def store_matrix_checks( - matrix_kind: str, + matrix_kind: "MatrixKind", diagonal: bool, data_shape: tuple[int, int], row_names_len: int, @@ -179,7 +179,7 @@ Parameters ---------- - matrix_kind : MatrixKind + matrix_kind : :enum:`.MatrixKind` The matrix kind. diagonal : bool Whether to store the diagonal. If ``matrix_kind=MatrixKind.Full``, @@ -293,7 +293,7 @@ The column labels. row_names : list-like of str The row labels. - matrix_kind : MatrixKind + matrix_kind : :enum:`.MatrixKind` The matrix kind. diagonal : bool Whether to store the diagonal. From ee6d1c49faa2edfb430682b6a9e1f399622b8af4 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 10 Nov 2025 18:35:15 +0100 Subject: [PATCH 17/99] refactor: make storage interface and impls pydantic models --- junifer/storage/base.py | 21 ++++++++---------- junifer/storage/hdf5.py | 40 +++++++--------------------------- junifer/storage/pandas_base.py | 8 ------- junifer/storage/sqlite.py | 21 +----------------- junifer/storage/utils.py | 12 ++++------ 5 files changed, 22 insertions(+), 80 deletions(-) diff --git a/junifer/storage/base.py b/junifer/storage/base.py index 64398e143..e1336034c 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd +from pydantic import BaseModel, ConfigDict from ..utils import logger, raise_error from .utils import process_meta @@ -38,7 +39,7 @@ class MatrixKind(str, Enum): Full = "full" -class BaseFeatureStorage(ABC): +class BaseFeatureStorage(BaseModel, ABC): """Abstract base class for feature storage. For every storage, one needs to provide a concrete @@ -46,7 +47,7 @@ Parameters ---------- - uri : str or pathlib.Path + uri : pathlib.Path The path to the storage. single_output : bool, optional Whether to have single output (default True). @@ -60,21 +61,18 @@ _STORAGE_TYPES: ClassVar[Sequence[StorageType]] - def __init__( - self, - uri: Union[str, Path], - single_output: bool = True, - ) -> None: + model_config = ConfigDict(frozen=True, use_enum_values=True) + + uri: Path + single_output: bool = True + + def model_post_init(self, context: Any): # noqa: D102 # Check for missing storage types attribute if not hasattr(self, "_STORAGE_TYPES"): raise_error( msg="Missing `_STORAGE_TYPES` for the storage", klass=AttributeError, ) - # Convert str to Path - if not isinstance(uri, Path): - uri = Path(uri) - self.uri = uri # Create parent directories if not present if not self.uri.parent.exists(): logger.info( "does not exist, creating now" ) self.uri.parent.mkdir(parents=True, exist_ok=True) - self.single_output = single_output def get_valid_inputs(self) -> list[str]: """Get valid storage types for input."
diff --git a/junifer/storage/hdf5.py b/junifer/storage/hdf5.py index fa44468f1..f0dcd4657 100644 --- a/junifer/storage/hdf5.py +++ b/junifer/storage/hdf5.py @@ -7,10 +7,11 @@ from collections import defaultdict from collections.abc import Sequence from pathlib import Path -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union import numpy as np import pandas as pd +from pydantic import PositiveInt from tqdm import tqdm from ..api.decorators import register_storage @@ -62,12 +63,6 @@ def _create_chunk( ChunkedArray or ChunkedList The chunked array or list. - Raises - ------ - ValueError - If `kind` is not one of ['vector', 'matrix', 'timeseries', - 'scalar_table']. - """ if kind in ["vector", "matrix"]: features_data = np.concatenate(chunk_data, axis=-1) @@ -93,12 +88,6 @@ def _create_chunk( size=element_count, offset=i_chunk * chunk_size, ) - else: - raise_error( - f"Invalid kind: {kind}. " - "Must be one of ['vector', 'matrix', 'timeseries', " - "'timeseries_2d', 'scalar_table']." - ) return out @@ -108,7 +97,7 @@ class HDF5FeatureStorage(BaseFeatureStorage): Parameters ---------- - uri : str or pathlib.Path + uri : pathlib.Path The path to the file to be used. single_output : bool, optional If False, will create one HDF5 file per element. The name @@ -124,7 +113,7 @@ class HDF5FeatureStorage(BaseFeatureStorage): force_float32 : bool, optional Whether to force casting of numpy.ndarray values to float32 if float64 values are found (default True). - chunk_size : int, optional + chunk_size : positive int, optional The chunk size to use when collecting data from element files in :meth:`.collect`. If the file count is smaller than the value, the minimum is used (default 100). @@ -143,23 +132,10 @@ class HDF5FeatureStorage(BaseFeatureStorage): StorageType.Timeseries2D, ] - def __init__( - self, - uri: Union[str, Path], - single_output: bool = True, - overwrite: Union[bool, str] = "update", - compression: int = 7, - force_float32: bool = True, - chunk_size: int = 100, - ) -> None: - self.overwrite = overwrite - self.compression = compression - self.force_float32 = force_float32 - self.chunk_size = chunk_size - super().__init__( - uri=uri, - single_output=single_output, - ) + overwrite: Union[bool, str] = "update" + compression: Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] = 7 + force_float32: bool = True + chunk_size: PositiveInt = 100 def _fetch_correct_uri_for_io(self, element: Optional[dict]) -> str: """Return proper URI for I/O based on ``element``. diff --git a/junifer/storage/pandas_base.py b/junifer/storage/pandas_base.py index 69d2bad5f..9af075c46 100644 --- a/junifer/storage/pandas_base.py +++ b/junifer/storage/pandas_base.py @@ -6,7 +6,6 @@ import json from collections.abc import Sequence -from pathlib import Path from typing import ClassVar, Optional, Union import numpy as np @@ -44,13 +43,6 @@ class PandasBaseFeatureStorage(BaseFeatureStorage): StorageType.Matrix, ] - def __init__( - self, - uri: Union[str, Path], - single_output: bool = True, - ) -> None: - super().__init__(uri=uri, single_output=single_output) - def _meta_row(self, meta: dict, meta_md5: str) -> pd.DataFrame: """Convert the metadata to a pandas DataFrame. 
diff --git a/junifer/storage/sqlite.py b/junifer/storage/sqlite.py index 6c2b93084..1ccbacc73 100644 --- a/junifer/storage/sqlite.py +++ b/junifer/storage/sqlite.py @@ -6,7 +6,6 @@ import json from collections.abc import Sequence -from pathlib import Path from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union @@ -69,25 +68,7 @@ class SQLiteFeatureStorage(PandasBaseFeatureStorage): """ - def __init__( - self, - uri: Union[str, Path], - single_output: bool = True, - upsert: Upsert = Upsert.Update, - ) -> None: - # Check and set upsert argument value - if upsert not in ["update", "ignore"]: - raise_error( - msg=( - "Invalid choice for `upsert`. " - "Must be either 'update' or 'ignore'." - ) - ) - self.upsert = upsert - super().__init__( - uri=uri, - single_output=single_output, - ) + upsert: Upsert = Upsert.Update def get_engine(self, element: Optional[dict] = None) -> "Engine": """Get engine. diff --git a/junifer/storage/utils.py b/junifer/storage/utils.py index e8a093345..c069c5233 100644 --- a/junifer/storage/utils.py +++ b/junifer/storage/utils.py @@ -194,16 +194,12 @@ def store_matrix_checks( Raises ------ ValueError - If the matrix kind is invalid - If the diagonal is False and the matrix kind is "full" - If the matrix kind is "triu" or "tril" and the matrix is not square - If the number of row names does not match the number of rows - If the number of column names does not match the number of columns + If ``diagonal=False`` and ``matrix_kind="full"`` or + if ``matrix_kind`` is "triu" or "tril" and the matrix is not square or + if the number of row names does not match the number of rows + If the number of column names does not match the number of columns. """ - # Matrix kind validation - if matrix_kind not in ("triu", "tril", "full"): - raise_error(msg=f"Invalid kind {matrix_kind}", klass=ValueError) # Diagonal validation if diagonal is False and matrix_kind not in ["triu", "tril"]: raise_error( From 8c8cea48bc1f68651d40413893bf86f0940731e1 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 10 Nov 2025 18:39:07 +0100 Subject: [PATCH 18/99] chore: improve docstrings in storage --- junifer/storage/hdf5.py | 11 +---------- junifer/storage/sqlite.py | 2 -- junifer/storage/utils.py | 8 ++++---- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/junifer/storage/hdf5.py b/junifer/storage/hdf5.py index f0dcd4657..dd2314d8a 100644 --- a/junifer/storage/hdf5.py +++ b/junifer/storage/hdf5.py @@ -252,7 +252,7 @@ def _read_data( ---------- md5 : str The MD5 used as the HDF5 group name. - element : dict, optional + element : dict or None, optional The element as dictionary (default None). Returns @@ -823,15 +823,6 @@ def store_matrix( row_header_col_name : str, optional The column name for the row header column (default "ROI"). - Raises - ------ - ValueError - If invalid ``matrix_kind`` is provided, ``diagonal = False`` - for ``matrix_kind = "full"``, non-square data is provided - for ``matrix_kind = {"triu", "tril"}``, length of ``row_names`` - do not match data row count, or length of ``col_names`` do not - match data column count. - """ # Row data validation if row_names is None: diff --git a/junifer/storage/sqlite.py b/junifer/storage/sqlite.py index 1ccbacc73..291f596d0 100644 --- a/junifer/storage/sqlite.py +++ b/junifer/storage/sqlite.py @@ -435,8 +435,6 @@ def store_matrix( The element as a dictionary. data : numpy.ndarray The matrix data to store. - meta : dict - The metadata as a dictionary. 
col_names : list-like of str, optional The column labels (default None). row_names : list-like of str, optional diff --git a/junifer/storage/utils.py b/junifer/storage/utils.py index c069c5233..0ea795f61 100644 --- a/junifer/storage/utils.py +++ b/junifer/storage/utils.py @@ -111,7 +111,7 @@ def process_meta(meta: dict) -> tuple[str, dict, dict]: Raises ------ ValueError - If ``meta`` is None or if it does not contain the key "element". + If ``meta=None`` or if it does not contain the key "element". """ if meta is None: @@ -246,9 +246,9 @@ def store_timeseries_2d_checks( Raises ------ ValueError - If the data is not a 3D array (timepoints, rows, columns) - If the number of row names does not match the number of rows - If the number of column names does not match the number of columns + If the data is not a 3D array (timepoints, rows, columns) or + if the number of row names does not match the number of rows or + if the number of column names does not match the number of columns. """ # Data validation From 88deabd27eba3835049e2a079c64f95a65def14e Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 11 Nov 2025 13:00:55 +0100 Subject: [PATCH 19/99] update: introduce Confounds and Strategy types --- junifer/preprocess/__init__.pyi | 4 +- junifer/preprocess/confounds/__init__.pyi | 12 ++++- .../confounds/fmriprep_confound_remover.py | 49 ++++++++++++------- 3 files changed, 44 insertions(+), 21 deletions(-) diff --git a/junifer/preprocess/__init__.pyi b/junifer/preprocess/__init__.pyi index 29649c62d..d1ec22028 100644 --- a/junifer/preprocess/__init__.pyi +++ b/junifer/preprocess/__init__.pyi @@ -1,6 +1,8 @@ __all__ = [ "BasePreprocessor", "fMRIPrepConfoundRemover", + "Confounds", + "Strategy", "SpaceWarper", "Smoothing", "TemporalSlicer", @@ -8,8 +10,8 @@ __all__ = [ ] from .base import BasePreprocessor -from .confounds import fMRIPrepConfoundRemover from .warping import SpaceWarper from .smoothing import Smoothing +from .confounds import fMRIPrepConfoundRemover, Confounds, Strategy from ._temporal_slicer import TemporalSlicer from ._temporal_filter import TemporalFilter diff --git a/junifer/preprocess/confounds/__init__.pyi b/junifer/preprocess/confounds/__init__.pyi index 55d9d4ee2..0d948a5cb 100644 --- a/junifer/preprocess/confounds/__init__.pyi +++ b/junifer/preprocess/confounds/__init__.pyi @@ -1,3 +1,11 @@ -__all__ = ["fMRIPrepConfoundRemover"] +__all__ = [ + "fMRIPrepConfoundRemover", + "Confounds", + "Strategy", +] -from .fmriprep_confound_remover import fMRIPrepConfoundRemover +from .fmriprep_confound_remover import ( + fMRIPrepConfoundRemover, + Confounds, + Strategy, +) diff --git a/junifer/preprocess/confounds/fmriprep_confound_remover.py b/junifer/preprocess/confounds/fmriprep_confound_remover.py index 753c21a38..9645ef5eb 100644 --- a/junifer/preprocess/confounds/fmriprep_confound_remover.py +++ b/junifer/preprocess/confounds/fmriprep_confound_remover.py @@ -6,10 +6,12 @@ # License: AGPL from collections.abc import Sequence +from enum import Enum from typing import ( Any, ClassVar, Optional, + TypedDict, Union, ) @@ -29,7 +31,7 @@ from ..base import BasePreprocessor -__all__ = ["fMRIPrepConfoundRemover"] +__all__ = ["Confounds", "Strategy", "fMRIPrepConfoundRemover"] FMRIPREP_BASICS = { @@ -100,6 +102,31 @@ FMRIPREP_VALID_NAMES.append("framewise_displacement") +class Confounds(str, Enum): + """Accepted confounds. 
+ + * ``Basic`` : only the confounding time series + * ``Power2`` : signal + quadratic term + * ``Derivatives`` : signal + derivatives + * ``Full`` : signal + deriv. + quadratic terms + power2 + + """ + + Basic = "basic" + Power2 = "power2" + Derivatives = "derivatives" + Full = "full" + + +class Strategy(TypedDict, total=False): + """Accepted confound removal strategy.""" + + motion: Confounds + wm_csf: Confounds + global_signal: Confounds + scrubbing: bool + + @register_preprocessor class fMRIPrepConfoundRemover(BasePreprocessor): """Class for confound removal using fMRIPrep confounds format. @@ -111,27 +138,13 @@ class fMRIPrepConfoundRemover(BasePreprocessor): Parameters ---------- - strategy : dict, optional + strategy : :class:`.Strategy` or None, optional The strategy to use for each component. If None, will use the *full* strategy for all components except ``"scrubbing"`` which will be set to False (default None). The keys of the dictionary should correspond to names of noise - components to include: - - * ``motion`` - * ``wm_csf`` - * ``global_signal`` - * ``scrubbing`` - - The values of dictionary should correspond to types of confounds - extracted from each signal: - - * ``basic`` : only the confounding time series - * ``power2`` : signal + quadratic term - * ``derivatives`` : signal + derivatives - * ``full`` : signal + deriv. + quadratic terms + power2 deriv. - - except ``scrubbing`` which needs to be bool. + components (Strategy) to include and the values should correspond to + types of confounds (Confounds) extracted from each signal. spike : float, optional If None, no spike regressor is added. If spike is a float, it will add a spike regressor for every point at which framewise displacement From ce7ef3a6083d657a942270b8d10762ddb1495f79 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 11 Nov 2025 13:04:30 +0100 Subject: [PATCH 20/99] update: introduce and adapt SmoothingImpl enum --- junifer/preprocess/__init__.pyi | 3 +- junifer/preprocess/smoothing/__init__.pyi | 4 +-- junifer/preprocess/smoothing/smoothing.py | 39 +++++++++++++---------- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/junifer/preprocess/__init__.pyi b/junifer/preprocess/__init__.pyi index d1ec22028..dfddc43d4 100644 --- a/junifer/preprocess/__init__.pyi +++ b/junifer/preprocess/__init__.pyi @@ -5,13 +5,14 @@ __all__ = [ "Strategy", "SpaceWarper", "Smoothing", + "SmoothingImpl", "TemporalSlicer", "TemporalFilter", ] from .base import BasePreprocessor from .warping import SpaceWarper -from .smoothing import Smoothing from .confounds import fMRIPrepConfoundRemover, Confounds, Strategy +from .smoothing import Smoothing, SmoothingImpl from ._temporal_slicer import TemporalSlicer from ._temporal_filter import TemporalFilter diff --git a/junifer/preprocess/smoothing/__init__.pyi b/junifer/preprocess/smoothing/__init__.pyi index 3c8764812..2a41ba69f 100644 --- a/junifer/preprocess/smoothing/__init__.pyi +++ b/junifer/preprocess/smoothing/__init__.pyi @@ -1,3 +1,3 @@ -__all__ = ["Smoothing"] +__all__ = ["Smoothing", "SmoothingImpl"] -from .smoothing import Smoothing +from .smoothing import Smoothing, SmoothingImpl diff --git a/junifer/preprocess/smoothing/smoothing.py b/junifer/preprocess/smoothing/smoothing.py index 99160b5c2..4efd5032f 100644 --- a/junifer/preprocess/smoothing/smoothing.py +++ b/junifer/preprocess/smoothing/smoothing.py @@ -5,6 +5,7 @@ from collections.abc import Sequence from typing import Any, ClassVar, Optional, Union +from enum import Enum from ...api.decorators 
import register_preprocessor from ...typing import ConditionalDependencies @@ -15,7 +16,21 @@ from ._nilearn_smoothing import NilearnSmoothing -__all__ = ["Smoothing"] +__all__ = ["Smoothing", "SmoothingImpl"] + + +class SmoothingImpl(str, Enum): +    """Accepted smoothing implementations. + +    * ``nilearn`` : :func:`nilearn.image.smooth_img` +    * ``afni`` : AFNI's ``3dBlurToFWHM`` +    * ``fsl`` : FSL SUSAN's ``susan`` + +    """ + +    nilearn = "nilearn" +    afni = "afni" +    fsl = "fsl" @register_preprocessor @@ -24,18 +39,10 @@ class Smoothing(BasePreprocessor): Parameters ---------- -    using : {"nilearn", "afni", "fsl"} -        Implementation to use for smoothing: - -        * "nilearn" : Use :func:`nilearn.image.smooth_img` -        * "afni" : Use AFNI's ``3dBlurToFWHM`` -        * "fsl" : Use FSL SUSAN's ``susan`` - -    on : {"T1w", "T2w", "BOLD"} or list of the options -        The data type to apply smoothing to. +    using : :enum:`.SmoothingImpl` smoothing_params : dict, optional Extra parameters for smoothing as a dictionary (default None). -        If ``using="nilearn"``, then the valid keys are: +        If ``using=SmoothingImpl.nilearn``, then the valid keys are: *  ``fwhm`` : scalar, ``numpy.ndarray``, tuple or list of scalar, \ "fast" or None @@ -52,13 +59,13 @@ - If None, no filtering is performed (useful when just removal of non-finite values is needed). -        else if ``using="afni"``, then the valid keys are: +        else if ``using=SmoothingImpl.afni``, then the valid keys are: * ``fwhm`` : int or float Smooth until the value. AFNI estimates the smoothing and then applies smoothing to reach ``fwhm``. -        else if ``using="fsl"``, then the valid keys are: +        else if ``using=SmoothingImpl.fsl``, then the valid keys are: * ``brightness_threshold`` : float Threshold to discriminate between noise and the underlying image. 
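For orientation, a minimal usage sketch of the enum-based API introduced above (not part of the patch; the FWHM value is illustrative only). Because ``SmoothingImpl`` subclasses ``str``, existing string-based configs keep validating:

    from junifer.preprocess import Smoothing, SmoothingImpl

    # The str mixin keeps plain strings equal to enum members,
    # so YAML pipelines that say "nilearn" still work.
    assert SmoothingImpl.nilearn == "nilearn"

    smoother = Smoothing(
        using=SmoothingImpl.nilearn,
        on="BOLD",
        smoothing_params={"fwhm": 6},  # illustrative value
    )
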
@@ -71,16 +78,16 @@ class Smoothing(BasePreprocessor): _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { - "using": "nilearn", "depends_on": NilearnSmoothing, + "using": SmoothingImpl.nilearn, }, { - "using": "afni", "depends_on": AFNISmoothing, + "using": SmoothingImpl.afni, }, { - "using": "fsl", "depends_on": FSLSmoothing, + "using": SmoothingImpl.fsl, }, ] _VALID_DATA_TYPES: ClassVar[Sequence[str]] = ["T1w", "T2w", "BOLD"] From dc07c41b39da8f5f5554f6e4e1475671614a770a Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 11 Nov 2025 13:10:17 +0100 Subject: [PATCH 21/99] update: introduce and adapt SpaceWarpingImpl enum --- junifer/preprocess/__init__.pyi | 3 ++- junifer/preprocess/warping/__init__.pyi | 4 +-- junifer/preprocess/warping/space_warper.py | 31 ++++++++++++++-------- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/junifer/preprocess/__init__.pyi b/junifer/preprocess/__init__.pyi index dfddc43d4..821364f4a 100644 --- a/junifer/preprocess/__init__.pyi +++ b/junifer/preprocess/__init__.pyi @@ -4,6 +4,7 @@ __all__ = [ "Confounds", "Strategy", "SpaceWarper", + "SpaceWarpingImpl", "Smoothing", "SmoothingImpl", "TemporalSlicer", @@ -11,8 +12,8 @@ __all__ = [ ] from .base import BasePreprocessor -from .warping import SpaceWarper from .confounds import fMRIPrepConfoundRemover, Confounds, Strategy +from .warping import SpaceWarper, SpaceWarpingImpl from .smoothing import Smoothing, SmoothingImpl from ._temporal_slicer import TemporalSlicer from ._temporal_filter import TemporalFilter diff --git a/junifer/preprocess/warping/__init__.pyi b/junifer/preprocess/warping/__init__.pyi index b9b2db117..ef9484480 100644 --- a/junifer/preprocess/warping/__init__.pyi +++ b/junifer/preprocess/warping/__init__.pyi @@ -1,3 +1,3 @@ -__all__ = ["SpaceWarper"] +__all__ = ["SpaceWarper", "SpaceWarpingImpl"] -from .space_warper import SpaceWarper +from .space_warper import SpaceWarper, SpaceWarpingImpl diff --git a/junifer/preprocess/warping/space_warper.py b/junifer/preprocess/warping/space_warper.py index fd7ebc848..b8f9f6f2a 100644 --- a/junifer/preprocess/warping/space_warper.py +++ b/junifer/preprocess/warping/space_warper.py @@ -5,6 +5,7 @@ from collections.abc import Sequence from typing import Any, ClassVar, Optional, Union +from enum import Enum from templateflow import api as tflow @@ -16,7 +17,21 @@ from ._fsl_warper import FSLWarper -__all__ = ["SpaceWarper"] +__all__ = ["SpaceWarper", "SpaceWarpingImpl"] + + +class SpaceWarpingImpl(str, Enum): + """Accepted space warping implementations. + + * ``fsl`` : FSL's ``applywarp`` + * ``ants`` : ANTs' ``antsApplyTransforms`` + * ``auto`` : Auto-select tool when ``reference="T1w"`` + + """ + + fsl = "fsl" + ants = "ants" + auto = "auto" @register_preprocessor @@ -25,13 +40,7 @@ class SpaceWarper(BasePreprocessor): Parameters ---------- - using : {"fsl", "ants", "auto"} - Implementation to use for warping: - - * "fsl" : Use FSL's ``applywarp`` - * "ants" : Use ANTs' ``antsApplyTransforms`` - * "auto" : Auto-select tool when ``reference="T1w"`` - + using : :enum:`.SpaceWarpingImpl` reference : str The data type to use as reference for warping, can be either a data type like ``"T1w"`` or a template space like ``"MNI152NLin2009cAsym"``. 
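A similar sketch for the warper (the template space and data type are chosen for illustration, not taken from the patch):

    from junifer.preprocess import SpaceWarper, SpaceWarpingImpl

    # Warp BOLD to a named TemplateFlow space using ANTs.
    warper = SpaceWarper(
        using=SpaceWarpingImpl.ants,
        reference="MNI152NLin2009cAsym",
        on=["BOLD"],
    )
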
@@ -51,15 +60,15 @@ class SpaceWarper(BasePreprocessor): _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { - "using": "fsl", "depends_on": FSLWarper, + "using": SpaceWarpingImpl.fsl, }, { - "using": "ants", "depends_on": ANTsWarper, + "using": SpaceWarpingImpl.ants, }, { - "using": "auto", + "using": SpaceWarpingImpl.auto, "depends_on": [FSLWarper, ANTsWarper], }, ] From 78a0856c4401ca3efa63d8b33f56bfea49d599ef Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 11 Nov 2025 13:16:40 +0100 Subject: [PATCH 22/99] update: adapt necessary enums in preprocess --- junifer/preprocess/_temporal_filter.py | 3 +- junifer/preprocess/_temporal_slicer.py | 3 +- junifer/preprocess/base.py | 17 ++++++----- .../confounds/fmriprep_confound_remover.py | 3 +- junifer/preprocess/smoothing/smoothing.py | 9 +++++- junifer/preprocess/warping/space_warper.py | 28 ++++++++++--------- 6 files changed, 39 insertions(+), 24 deletions(-) diff --git a/junifer/preprocess/_temporal_filter.py b/junifer/preprocess/_temporal_filter.py index 17270c682..ba097cc39 100644 --- a/junifer/preprocess/_temporal_filter.py +++ b/junifer/preprocess/_temporal_filter.py @@ -19,6 +19,7 @@ from ..api.decorators import register_preprocessor from ..data import get_data +from ..datagrabber import DataType from ..pipeline import WorkDirManager from ..typing import Dependencies from ..utils import logger @@ -57,7 +58,6 @@ class TemporalFilter(BasePreprocessor): """ _DEPENDENCIES: ClassVar[Dependencies] = {"numpy", "nilearn"} - _VALID_DATA_TYPES: ClassVar[Sequence[str]] = ["BOLD"] def __init__( self, @@ -77,6 +77,7 @@ def __init__( self.masks = masks super().__init__() + _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [DataType.BOLD] def _validate_data( self, diff --git a/junifer/preprocess/_temporal_slicer.py b/junifer/preprocess/_temporal_slicer.py index 90dfcc922..f6cd11944 100644 --- a/junifer/preprocess/_temporal_slicer.py +++ b/junifer/preprocess/_temporal_slicer.py @@ -10,6 +10,7 @@ import nilearn.image as nimg from ..api.decorators import register_preprocessor +from ..datagrabber import DataType from ..pipeline import WorkDirManager from ..typing import Dependencies from ..utils import logger, raise_error @@ -46,7 +47,7 @@ class TemporalSlicer(BasePreprocessor): """ _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn"} - _VALID_DATA_TYPES: ClassVar[Sequence[str]] = ["BOLD"] + _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [DataType.BOLD] def __init__( self, diff --git a/junifer/preprocess/base.py b/junifer/preprocess/base.py index 4cc51e5fb..a45a12e9a 100644 --- a/junifer/preprocess/base.py +++ b/junifer/preprocess/base.py @@ -8,6 +8,7 @@ from collections.abc import Sequence from typing import Any, ClassVar, Optional, Union +from ..datagrabber import DataType from ..pipeline import PipelineStepMixin, UpdateMetaMixin from ..utils import logger, raise_error @@ -23,12 +24,14 @@ class BasePreprocessor(ABC, PipelineStepMixin, UpdateMetaMixin): Parameters ---------- - on : str or list of str or None, optional - The data type(s) to apply the preprocessor on. If None, - will work on all available data types (default None). - required_data_types : str or list of str, optional - The data types needed for computation. If None, - will be equal to ``on`` (default None). + on : list of :enum:`.DataType` or None, optional + The data type(s) to apply the preprocessor on. + If None, will work on all available data types. + Check :enum:`.DataType` for valid values (default None). 
+    required_data_types : list of :enum:`.DataType` or None, optional +        The data type(s) needed for computation. +        If None, will be equal to ``on``. +        Check :enum:`.DataType` for valid values (default None). Raises ------ @@ -39,7 +42,7 @@ """ -    _VALID_DATA_TYPES: ClassVar[Sequence[str]] +    _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] def __init__( self, diff --git a/junifer/preprocess/confounds/fmriprep_confound_remover.py b/junifer/preprocess/confounds/fmriprep_confound_remover.py index 9645ef5eb..3fb19e26a 100644 --- a/junifer/preprocess/confounds/fmriprep_confound_remover.py +++ b/junifer/preprocess/confounds/fmriprep_confound_remover.py @@ -25,6 +25,7 @@ from ...api.decorators import register_preprocessor from ...data import get_data +from ...datagrabber import DataType from ...pipeline import WorkDirManager from ...typing import Dependencies from ...utils import logger, raise_error @@ -189,7 +190,7 @@ class fMRIPrepConfoundRemover(BasePreprocessor): """ _DEPENDENCIES: ClassVar[Dependencies] = {"numpy", "nilearn"} -    _VALID_DATA_TYPES: ClassVar[Sequence[str]] = ["BOLD"] +    _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [DataType.BOLD] def __init__( self, diff --git a/junifer/preprocess/smoothing/smoothing.py b/junifer/preprocess/smoothing/smoothing.py index 4efd5032f..93f55c051 100644 --- a/junifer/preprocess/smoothing/smoothing.py +++ b/junifer/preprocess/smoothing/smoothing.py @@ -8,6 +8,7 @@ from ...api.decorators import register_preprocessor +from ...datagrabber import DataType from ...typing import ConditionalDependencies from ...utils import logger, raise_error from ..base import BasePreprocessor @@ -40,6 +41,8 @@ class Smoothing(BasePreprocessor): Parameters ---------- using : :enum:`.SmoothingImpl` +    on : list of {``DataType.T1w``, ``DataType.T2w``, ``DataType.BOLD``} +        The data type(s) to apply smoothing to. smoothing_params : dict, optional Extra parameters for smoothing as a dictionary (default None). If ``using=SmoothingImpl.nilearn``, then the valid keys are: @@ -90,7 +93,11 @@ "using": SmoothingImpl.fsl, }, ] -    _VALID_DATA_TYPES: ClassVar[Sequence[str]] = ["T1w", "T2w", "BOLD"] +    _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [ +        DataType.T1w, +        DataType.T2w, +        DataType.BOLD, +    ] def __init__( self, diff --git a/junifer/preprocess/warping/space_warper.py b/junifer/preprocess/warping/space_warper.py index fd7ebc848..b8f9f6f2a 100644 --- a/junifer/preprocess/warping/space_warper.py +++ b/junifer/preprocess/warping/space_warper.py @@ -10,6 +10,7 @@ from templateflow import api as tflow from ...api.decorators import register_preprocessor +from ...datagrabber import DataType from ...typing import ConditionalDependencies from ...utils import logger, raise_error from ..base import BasePreprocessor @@ -46,15 +47,16 @@ class SpaceWarper(BasePreprocessor): type like ``"T1w"`` or a template space like ``"MNI152NLin2009cAsym"``. Use ``"T1w"`` for native space warping and named templates for template space warping. -    on : {"T1w", "T2w", "BOLD", "VBM_GM", "VBM_WM", "VBM_CSF", "fALFF", \ -        "GCOR", "LCOR"} or list of the options -        The data type to warp. Raises ------ ValueError If ``using`` is invalid or if ``reference`` is invalid. 
+ on : list of {``DataType.T1w``, ``DataType.T2w``, ``DataType.BOLD``, \ + ``DataType.VBM_GM``, ``DataType.VBM_WM``, ``DataType.VBM_CSF``, \ + ``DataType.FALFF``, ``DataType.GCOR``, ``DataType.LCOR``} + The data type(s) to warp. """ @@ -72,16 +74,16 @@ class SpaceWarper(BasePreprocessor): "depends_on": [FSLWarper, ANTsWarper], }, ] - _VALID_DATA_TYPES: ClassVar[Sequence[str]] = [ - "T1w", - "T2w", - "BOLD", - "VBM_GM", - "VBM_WM", - "VBM_CSF", - "fALFF", - "GCOR", - "LCOR", + _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [ + DataType.T1w, + DataType.T2w, + DataType.BOLD, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.VBM_CSF, + DataType.FALFF, + DataType.GCOR, + DataType.LCOR, ] def __init__( From a833c675dbf3028eba8278aa8c4538195a025f90 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 11 Nov 2025 13:21:20 +0100 Subject: [PATCH 23/99] refactor: make preprocess interface and impls pydantic models --- junifer/preprocess/_temporal_filter.py | 28 ++--- junifer/preprocess/_temporal_slicer.py | 30 ++--- junifer/preprocess/base.py | 105 +++++++++++------- .../confounds/fmriprep_confound_remover.py | 97 ++++------------ junifer/preprocess/smoothing/smoothing.py | 33 ++---- junifer/preprocess/warping/space_warper.py | 63 +++++------ 6 files changed, 136 insertions(+), 220 deletions(-) diff --git a/junifer/preprocess/_temporal_filter.py b/junifer/preprocess/_temporal_filter.py index ba097cc39..41c66ed59 100644 --- a/junifer/preprocess/_temporal_filter.py +++ b/junifer/preprocess/_temporal_filter.py @@ -50,7 +50,7 @@ class TemporalFilter(BasePreprocessor): t_r : float, optional Repetition time, in second (sampling period). If None, it will use t_r from nifti header (default None). - masks : str, dict or list of dict or str, optional + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). 
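The refactor above trades hand-written ``__init__`` bodies for declared pydantic fields, so bad arguments fail at construction time instead of deep inside the pipeline. A minimal sketch of the pattern under that assumption (class and field names are illustrative, not junifer API):

    from typing import Optional

    from pydantic import BaseModel, ValidationError


    class Example(BaseModel):
        detrend: bool = True
        low_pass: Optional[float] = None


    Example(low_pass=0.1)  # fields are validated and assigned automatically

    try:
        Example(low_pass="not-a-float")
    except ValidationError as err:
        print(err)  # pydantic reports the offending field
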
@@ -58,27 +58,15 @@ class TemporalFilter(BasePreprocessor): """ _DEPENDENCIES: ClassVar[Dependencies] = {"numpy", "nilearn"} - - def __init__( - self, - detrend: bool = True, - standardize: bool = True, - low_pass: Optional[float] = None, - high_pass: Optional[float] = None, - t_r: Optional[float] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - ) -> None: - """Initialize the class.""" - self.detrend = detrend - self.standardize = standardize - self.low_pass = low_pass - self.high_pass = high_pass - self.t_r = t_r - self.masks = masks - - super().__init__() _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [DataType.BOLD] + detrend: bool = True + standardize: bool = True + low_pass: Optional[float] = None + high_pass: Optional[float] = None + t_r: Optional[float] = None + masks: Optional[list[Union[dict, str]]] = None + def _validate_data( self, input: dict[str, Any], diff --git a/junifer/preprocess/_temporal_slicer.py b/junifer/preprocess/_temporal_slicer.py index f6cd11944..70cd84c85 100644 --- a/junifer/preprocess/_temporal_slicer.py +++ b/junifer/preprocess/_temporal_slicer.py @@ -4,10 +4,11 @@ # License: AGPL from collections.abc import Sequence -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar, Literal, Optional, Union import nibabel as nib import nilearn.image as nimg +from pydantic import PositiveFloat from ..api.decorators import register_preprocessor from ..datagrabber import DataType @@ -26,7 +27,7 @@ class TemporalSlicer(BasePreprocessor): Parameters ---------- - start : float + start : ``zero`` or positive float Starting time point, in second. stop : float or None Ending time point, in second. If None, stops at the last time point. @@ -39,32 +40,15 @@ class TemporalSlicer(BasePreprocessor): Repetition time, in second (sampling period). If None, it will use t_r from nifti header (default None). - Raises - ------ - ValueError - If ``start`` is negative. - """ _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn"} _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [DataType.BOLD] - def __init__( - self, - start: float, - stop: Optional[float], - duration: Optional[float] = None, - t_r: Optional[float] = None, - ) -> None: - """Initialize the class.""" - if start < 0: - raise_error("`start` cannot be negative") - else: - self.start = start - self.stop = stop - self.duration = duration - self.t_r = t_r - super().__init__() + start: Union[Literal[0], PositiveFloat] + stop: Optional[float] + duration: Optional[float] = None + t_r: Optional[float] = None def preprocess( self, diff --git a/junifer/preprocess/base.py b/junifer/preprocess/base.py index a45a12e9a..59fcde36e 100644 --- a/junifer/preprocess/base.py +++ b/junifer/preprocess/base.py @@ -6,7 +6,9 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Optional + +from pydantic import BaseModel, ConfigDict from ..datagrabber import DataType from ..pipeline import PipelineStepMixin, UpdateMetaMixin @@ -16,7 +18,7 @@ __all__ = ["BasePreprocessor"] -class BasePreprocessor(ABC, PipelineStepMixin, UpdateMetaMixin): +class BasePreprocessor(BaseModel, ABC, PipelineStepMixin, UpdateMetaMixin): """Abstract base class for preprocessor. For every preprocessor, one needs to provide a concrete @@ -33,6 +35,10 @@ class BasePreprocessor(ABC, PipelineStepMixin, UpdateMetaMixin): If None, will be equal to ``on``. Check :enum:`.DataType` for valid values (default None). 
+ Attributes + ---------- + valid_inputs + Raises ------ AttributeError @@ -44,38 +50,62 @@ class BasePreprocessor(ABC, PipelineStepMixin, UpdateMetaMixin): _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] - def __init__( - self, - on: Optional[Union[list[str], str]] = None, - required_data_types: Optional[Union[list[str], str]] = None, - ) -> None: - """Initialize the class.""" + model_config = ConfigDict(use_enum_values=True) + + on: Optional[list[DataType]] = None + required_data_types: Optional[list[DataType]] = None + + def model_post_init(self, context: Any): # noqa: D102 # Check for missing data types attributes if not hasattr(self, "_VALID_DATA_TYPES"): raise_error( msg="Missing `_VALID_DATA_TYPES` for the preprocessor", klass=AttributeError, ) + # Run extra validation for preprocessors and fail early if needed + self.validate_preprocessor_params() # Use all data types if not provided - if on is None: - on = self.get_valid_inputs() - # Convert data types to list - if not isinstance(on, list): - on = [on] - # Check if required inputs are found - if any(x not in self.get_valid_inputs() for x in on): - name = self.__class__.__name__ - wrong_on = [x for x in on if x not in self.get_valid_inputs()] - raise_error(f"{name} cannot be computed on {wrong_on}") - self._on = on + if self.on is None: + self.on = self.valid_inputs + else: + # Convert to correct data type + self.on = [DataType(t) for t in self.on] + # Check if required input data types are provided + if any(x not in self.valid_inputs for x in self.on): + wrong_on = [ + x.value for x in self.on if x not in self.valid_inputs + ] + raise_error( + f"{self.__class__.__name__} cannot be computed on " + f"{wrong_on}" + ) # Set required data types for validation - if required_data_types is None: - self._required_data_types = on + if self.required_data_types is None: + self.required_data_types = self.on else: - # Convert data types to list - if not isinstance(required_data_types, list): - required_data_types = [required_data_types] - self._required_data_types = required_data_types + # Convert to correct data type + self.required_data_types = [ + DataType(t) for t in self.required_data_types + ] + + @property + def valid_inputs(self) -> list[DataType]: + """Valid data types to operate on. + + Returns + ------- + list of DataType + The list of data types that can be used as input for this marker. + + """ + return [DataType(x) for x in self._VALID_DATA_TYPES] + + def validate_preprocessor_params(self) -> None: + """Run extra logical validation for preprocessor params. + + Subclasses can override to provide validation. + """ + pass def validate_input(self, input: list[str]) -> list[str]: """Validate input. @@ -98,25 +128,14 @@ def validate_input(self, input: list[str]) -> list[str]: If the input does not have the required data. """ - if any(x not in input for x in self._required_data_types): + if any(x not in input for x in self.required_data_types): raise_error( - "Input does not have the required data." - f"\t Input: {input}" - f"\t Required (all of): {self._required_data_types}" + "Input does not have the required data.\n" + f"\t Input: {input}\n" + f"\t Required (all of): " + f"{[t.value for t in self.required_data_types]}" ) - return [x for x in self._on if x in input] - - def get_valid_inputs(self) -> list[str]: - """Get valid data types for input. - - Returns - ------- - list of str - The list of data types that can be used as input for this - preprocessor. 
- -        """ -        return list(self._VALID_DATA_TYPES) +        return [x for x in self.on if x in input] @abstractmethod def preprocess( @@ -167,7 +186,7 @@ def _fit_transform( # Copy input to not modify the original out = input.copy() # For each data type, run preprocessing -        for type_ in self._on: +        for type_ in self.on: # Check if data type is available if type_ in input.keys(): logger.info(f"Preprocessing {type_}") diff --git a/junifer/preprocess/confounds/fmriprep_confound_remover.py b/junifer/preprocess/confounds/fmriprep_confound_remover.py index 3fb19e26a..05da29e93 100644 --- a/junifer/preprocess/confounds/fmriprep_confound_remover.py +++ b/junifer/preprocess/confounds/fmriprep_confound_remover.py @@ -182,7 +182,7 @@ class fMRIPrepConfoundRemover(BasePreprocessor): t_r : float, optional Repetition time, in second (sampling period). If None, it will use t_r from nifti header (default None). -    masks : str, dict or list of dict or str, optional +    masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks <using_masks>` for more details. If None, will not apply any mask (default None). @@ -190,84 +190,29 @@ class fMRIPrepConfoundRemover(BasePreprocessor): """ _DEPENDENCIES: ClassVar[Dependencies] = {"numpy", "nilearn"} _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [DataType.BOLD] -    def __init__( -        self, -        strategy: Optional[dict[str, Union[str, bool]]] = None, -        spike: Optional[float] = None, -        scrub: Optional[int] = None, -        fd_threshold: Optional[float] = None, -        std_dvars_threshold: Optional[float] = None, -        detrend: bool = True, -        standardize: bool = True, -        low_pass: Optional[float] = None, -        high_pass: Optional[float] = None, -        t_r: Optional[float] = None, -        masks: Union[str, dict, list[Union[dict, str]], None] = None, -    ) -> None: -        """Initialize the class.""" -        if strategy is None: -            strategy = { -                "motion": "full", -                "wm_csf": "full", -                "global_signal": "full", -                "scrubbing": False, -            } -        self.strategy = strategy -        self.spike = spike -        self.scrub = scrub -        self.fd_threshold = fd_threshold -        self.std_dvars_threshold = std_dvars_threshold -        self.detrend = detrend -        self.standardize = standardize -        self.low_pass = low_pass -        self.high_pass = high_pass -        self.t_r = t_r -        self.masks = masks - -        self._valid_components = [ -            "motion", -            "wm_csf", -            "global_signal", -            "scrubbing", -        ] -        self._valid_confounds = ["basic", "power2", "derivatives", "full"] - -        if any(not isinstance(k, str) for k in strategy.keys()): -            raise_error("Strategy keys must be strings", ValueError) - -        if any( -            not isinstance(v, str) -            for k, v in strategy.items() -            if k != "scrubbing" -        ): -            raise_error("Strategy values must be strings", ValueError) - -        if any(x not in self._valid_components for x in strategy.keys()): -            raise_error( -                msg=f"Invalid component names {list(strategy.keys())}. "
" - f"Valid components are {self._valid_components}.\n" - f"If any of them is a valid parameter in " - "nilearn.interfaces.fmriprep.load_confounds we may " - "include it in the future", - klass=ValueError, - ) - - if any( - v not in self._valid_confounds - for k, v in strategy.items() - if k != "scrubbing" - ): - raise_error( - msg=f"Invalid confound types {list(strategy.values())}. " - f"Valid confound types are {self._valid_confounds}.\n" - f"If any of them is a valid parameter in " - "nilearn.interfaces.fmriprep.load_confounds we may " - "include it in the future", - klass=ValueError, - ) - super().__init__() def _map_adhoc_to_fmriprep(self, input: dict[str, Any]) -> None: """Map the adhoc format to the fmpriprep format spec. diff --git a/junifer/preprocess/smoothing/smoothing.py b/junifer/preprocess/smoothing/smoothing.py index 93f55c051..f09307afd 100644 --- a/junifer/preprocess/smoothing/smoothing.py +++ b/junifer/preprocess/smoothing/smoothing.py @@ -4,13 +4,13 @@ # License: AGPL from collections.abc import Sequence -from typing import Any, ClassVar, Optional, Union from enum import Enum +from typing import Any, ClassVar, Literal, Optional from ...api.decorators import register_preprocessor from ...datagrabber import DataType from ...typing import ConditionalDependencies -from ...utils import logger, raise_error +from ...utils import logger from ..base import BasePreprocessor from ._afni_smoothing import AFNISmoothing from ._fsl_smoothing import FSLSmoothing @@ -81,16 +81,16 @@ class Smoothing(BasePreprocessor): _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { - "depends_on": NilearnSmoothing, "using": SmoothingImpl.nilearn, + "depends_on": [NilearnSmoothing], }, { - "depends_on": AFNISmoothing, "using": SmoothingImpl.afni, + "depends_on": [AFNISmoothing], }, { - "depends_on": FSLSmoothing, "using": SmoothingImpl.fsl, + "depends_on": [FSLSmoothing], }, ] _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [ @@ -99,24 +99,15 @@ class Smoothing(BasePreprocessor): DataType.BOLD, ] - def __init__( - self, - using: str, - on: Union[list[str], str], - smoothing_params: Optional[dict] = None, - ) -> None: - """Initialize the class.""" - # Validate `using` parameter - valid_using = [dep["using"] for dep in self._CONDITIONAL_DEPENDENCIES] - if using not in valid_using: - raise_error( - f"Invalid value for `using`, should be one of: {valid_using}" - ) - self.using = using + using: SmoothingImpl + on: list[Literal[DataType.T1w, DataType.T2w, DataType.BOLD]] + smoothing_params: Optional[dict] = None + + def validate_preprocessor_params(self) -> None: + """Run extra logical validation for preprocessor.""" self.smoothing_params = ( - smoothing_params if smoothing_params is not None else {} + self.smoothing_params if self.smoothing_params is not None else {} ) - super().__init__(on=on) def preprocess( self, diff --git a/junifer/preprocess/warping/space_warper.py b/junifer/preprocess/warping/space_warper.py index 6ff416e98..436e41b2a 100644 --- a/junifer/preprocess/warping/space_warper.py +++ b/junifer/preprocess/warping/space_warper.py @@ -4,8 +4,8 @@ # License: AGPL from collections.abc import Sequence -from typing import Any, ClassVar, Optional, Union from enum import Enum +from typing import Any, ClassVar, Literal, Optional from templateflow import api as tflow @@ -47,12 +47,6 @@ class SpaceWarper(BasePreprocessor): type like ``"T1w"`` or a template space like ``"MNI152NLin2009cAsym"``. Use ``"T1w"`` for native space warping and named templates for template space warping. 
- - Raises - ------ - ValueError - If ``using`` is invalid or - if ``reference`` is invalid. on : list of {``DataType.T1w``, ``DataType.T2w``, ``DataType.BOLD``, \ ``DataType.VBM_GM``, ``DataType.VBM_WM``, ``DataType.VBM_CSF``, \ ``DataType.FALFF``, ``DataType.GCOR``, ``DataType.LCOR``} @@ -62,12 +56,12 @@ class SpaceWarper(BasePreprocessor): _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { - "depends_on": FSLWarper, "using": SpaceWarpingImpl.fsl, + "depends_on": [FSLWarper], }, { - "depends_on": ANTsWarper, "using": SpaceWarpingImpl.ants, + "depends_on": [ANTsWarper], }, { "using": SpaceWarpingImpl.auto, @@ -86,35 +80,30 @@ class SpaceWarper(BasePreprocessor): DataType.LCOR, ] - def __init__( - self, using: str, reference: str, on: Union[list[str], str] - ) -> None: - """Initialize the class.""" - # Validate `using` parameter - valid_using = [dep["using"] for dep in self._CONDITIONAL_DEPENDENCIES] - if using not in valid_using: - raise_error( - f"Invalid value for `using`, should be one of: {valid_using}" - ) - self.using = using - self.reference = reference - # Set required data types based on reference and - # initialize superclass - if self.reference == "T1w": # pragma: no cover - required_data_types = [self.reference, "Warp"] - # Listify on - if not isinstance(on, list): - on = [on] - # Extend required data types - required_data_types.extend(on) + using: SpaceWarpingImpl + reference: str + on: list[ + Literal[ + DataType.T1w, + DataType.T2w, + DataType.BOLD, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.VBM_CSF, + DataType.FALFF, + DataType.GCOR, + DataType.LCOR, + ] + ] - super().__init__( - on=on, - required_data_types=required_data_types, - ) - elif self.reference in tflow.templates(): - super().__init__(on=on) - else: + def validate_preprocessor_params(self) -> None: + """Run extra logical validation for preprocessor.""" + # Set required data types based on reference + if self.reference == "T1w": # pragma: no cover + # Update required data types + self.required_data_types = [DataType.T1w, DataType.Warp] + self.required_data_types.extend(self.on) + elif self.reference not in tflow.templates(): raise_error(f"Unknown reference: {self.reference}") def preprocess( # noqa: C901 From a37c82c1391fe63725376f1f22602070d65b6e04 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 11 Nov 2025 13:21:47 +0100 Subject: [PATCH 24/99] chore: improve docstring for BasePreprocessor --- junifer/preprocess/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/preprocess/base.py b/junifer/preprocess/base.py index 59fcde36e..020a69825 100644 --- a/junifer/preprocess/base.py +++ b/junifer/preprocess/base.py @@ -42,7 +42,7 @@ class BasePreprocessor(BaseModel, ABC, PipelineStepMixin, UpdateMetaMixin): Raises ------ AttributeError - If the preprocessor does not have `_VALID_DATA_TYPES` attribute. + If the preprocessor does not have ``_VALID_DATA_TYPES`` attribute. ValueError If required input data type(s) is(are) not found. 
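With the preprocessors now pydantic models, a confound-removal strategy can be spelled with the ``Strategy``/``Confounds`` types from earlier in the series; a hedged example (the particular mix of confound levels is illustrative, not from the patches):

    from junifer.preprocess import Confounds, fMRIPrepConfoundRemover

    remover = fMRIPrepConfoundRemover(
        strategy={
            "motion": Confounds.Full,
            "wm_csf": Confounds.Basic,
            "global_signal": Confounds.Derivatives,
            "scrubbing": False,  # must be a bool, not a Confounds member
        },
    )
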
From 59c587fb950d120105b193020b80a46daf75ce1f Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 11 Nov 2025 16:09:43 +0100 Subject: [PATCH 25/99] chore: add package guard for fMRIPrepConfoundRemover --- junifer/preprocess/confounds/fmriprep_confound_remover.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/junifer/preprocess/confounds/fmriprep_confound_remover.py b/junifer/preprocess/confounds/fmriprep_confound_remover.py index 05da29e93..8ae721dfc 100644 --- a/junifer/preprocess/confounds/fmriprep_confound_remover.py +++ b/junifer/preprocess/confounds/fmriprep_confound_remover.py @@ -5,6 +5,14 @@ # Synchon Mandal # License: AGPL +import sys + + +if sys.version_info < (3, 12): # pragma: no cover + from typing_extensions import TypedDict +else: + from typing import TypedDict + from collections.abc import Sequence from enum import Enum from typing import ( From b734caa965f033417b97d6f7a9000c6a3692c9c4 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 12 Nov 2025 14:29:38 +0100 Subject: [PATCH 26/99] update: introduce and adapt ReHoImpl --- junifer/markers/__init__.pyi | 3 ++- junifer/markers/reho/__init__.pyi | 8 +++++++- junifer/markers/reho/reho_base.py | 27 ++++++++++++++++++--------- junifer/markers/reho/reho_maps.py | 11 +++-------- junifer/markers/reho/reho_parcels.py | 11 +++-------- junifer/markers/reho/reho_spheres.py | 11 +++-------- 6 files changed, 36 insertions(+), 35 deletions(-) diff --git a/junifer/markers/__init__.pyi b/junifer/markers/__init__.pyi index 34c35c7d0..f715a92fd 100644 --- a/junifer/markers/__init__.pyi +++ b/junifer/markers/__init__.pyi @@ -11,6 +11,7 @@ __all__ = [ "EdgeCentricFCMaps", "EdgeCentricFCParcels", "EdgeCentricFCSpheres", + "ReHoImpl", "ReHoMaps", "ReHoParcels", "ReHoSpheres", @@ -37,8 +38,8 @@ from .functional_connectivity import ( EdgeCentricFCParcels, EdgeCentricFCSpheres, ) -from .reho import ReHoMaps, ReHoParcels, ReHoSpheres from .falff import ALFFMaps, ALFFParcels, ALFFSpheres +from .reho import ReHoImpl, ReHoMaps, ReHoParcels, ReHoSpheres from .temporal_snr import ( TemporalSNRMaps, TemporalSNRParcels, diff --git a/junifer/markers/reho/__init__.pyi b/junifer/markers/reho/__init__.pyi index 82d880733..8ec3b25b7 100644 --- a/junifer/markers/reho/__init__.pyi +++ b/junifer/markers/reho/__init__.pyi @@ -1,5 +1,11 @@ -__all__ = ["ReHoMaps", "ReHoParcels", "ReHoSpheres"] +__all__ = [ + "ReHoImpl", + "ReHoMaps", + "ReHoParcels", + "ReHoSpheres", +] +from .reho_base import ReHoImpl from .reho_maps import ReHoMaps from .reho_parcels import ReHoParcels from .reho_spheres import ReHoSpheres diff --git a/junifer/markers/reho/reho_base.py b/junifer/markers/reho/reho_base.py index 0e0f0ebab..d5c9f4ec5 100644 --- a/junifer/markers/reho/reho_base.py +++ b/junifer/markers/reho/reho_base.py @@ -3,6 +3,7 @@ # Authors: Synchon Mandal # License: AGPL +from enum import Enum from pathlib import Path from typing import ( TYPE_CHECKING, @@ -23,7 +24,20 @@ if TYPE_CHECKING: from nibabel.nifti1 import Nifti1Image -__all__ = ["ReHoBase"] + +__all__ = ["ReHoBase", "ReHoImpl"] + + +class ReHoImpl(str, Enum): + """Accepted ReHo implementations. 
+ +    * ``junifer`` : ``junifer``'s ReHo +    * ``afni`` : AFNI's ``3dReHo`` + +    """ + +    junifer = "junifer" +    afni = "afni" class ReHoBase(BaseMarker): @@ -31,12 +45,7 @@ class ReHoBase(BaseMarker): Parameters ---------- -    using : {"junifer", "afni"} -        Implementation to use for computing ReHo: - -        * "junifer" : Use ``junifer``'s own ReHo implementation -        * "afni" : Use AFNI's ``3dReHo`` - +    using : :enum:`.ReHoImpl` name : str, optional The name of the marker. If None, it will use the class name (default None). @@ -50,12 +59,12 @@ class ReHoBase(BaseMarker): _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { -            "using": "afni", "depends_on": AFNIReHo, +            "using": ReHoImpl.afni, }, { -            "using": "junifer", "depends_on": JuniferReHo, +            "using": ReHoImpl.junifer, }, ] diff --git a/junifer/markers/reho/reho_maps.py b/junifer/markers/reho/reho_maps.py index 50ed38f7e..a9822b6dd 100644 --- a/junifer/markers/reho/reho_maps.py +++ b/junifer/markers/reho/reho_maps.py @@ -25,15 +25,11 @@ class ReHoMaps(ReHoBase): maps : str The name of the map(s) to use. See :func:`.list_data` for options. -    using : {"junifer", "afni"} -        Implementation to use for computing ReHo: - -        * "junifer" : Use ``junifer``'s own ReHo implementation -        * "afni" : Use AFNI's ``3dReHo`` -    reho_params : dict, optional +    using : :enum:`.ReHoImpl` +    reho_params : dict, optional Extra parameters for computing ReHo map as a dictionary (default None). -        If ``using="afni"``, then the valid keys are: +        If ``using=ReHoImpl.afni``, then the valid keys are: * ``nneigh`` : {7, 19, 27}, optional (default 27) Number of voxels in the neighbourhood, inclusive. Can be: @@ -63,7 +59,7 @@ class ReHoMaps(ReHoBase): The number of voxels for +/- z-axis of cuboidal volumes (default None). -        else if ``using="junifer"``, then the valid keys are: +        else if ``using=ReHoImpl.junifer``, then the valid keys are: * ``nneigh`` : {7, 19, 27, 125}, optional (default 27) Number of voxels in the neighbourhood, inclusive. Can be: diff --git a/junifer/markers/reho/reho_parcels.py b/junifer/markers/reho/reho_parcels.py index f78a84251..2cce611e2 100644 --- a/junifer/markers/reho/reho_parcels.py +++ b/junifer/markers/reho/reho_parcels.py @@ -25,15 +25,11 @@ class ReHoParcels(ReHoBase): parcellation : str or list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. -    using : {"junifer", "afni"} -        Implementation to use for computing ReHo: - -        * "junifer" : Use ``junifer``'s own ReHo implementation -        * "afni" : Use AFNI's ``3dReHo`` -    reho_params : dict, optional +    using : :enum:`.ReHoImpl` +    reho_params : dict, optional Extra parameters for computing ReHo map as a dictionary (default None). -        If ``using="afni"``, then the valid keys are: +        If ``using=ReHoImpl.afni``, then the valid keys are: * ``nneigh`` : {7, 19, 27}, optional (default 27) Number of voxels in the neighbourhood, inclusive. Can be: @@ -63,7 +59,7 @@ class ReHoParcels(ReHoBase): The number of voxels for +/- z-axis of cuboidal volumes (default None). -        else if ``using="junifer"``, then the valid keys are: +        else if ``using=ReHoImpl.junifer``, then the valid keys are: * ``nneigh`` : {7, 19, 27, 125}, optional (default 27) Number of voxels in the neighbourhood, inclusive. Can be: diff --git a/junifer/markers/reho/reho_spheres.py b/junifer/markers/reho/reho_spheres.py index 0d6afb38d..649293887 100644 --- a/junifer/markers/reho/reho_spheres.py +++ b/junifer/markers/reho/reho_spheres.py @@ -25,16 +25,11 @@ class ReHoSpheres(ReHoBase): coords : str The name of the coordinates list to use. See :func:`.list_data` for options. 
-    using : {"junifer", "afni"} -        Implementation to use for computing ReHo: - -        * "junifer" : Use ``junifer``'s own ReHo implementation -        * "afni" : Use AFNI's ``3dReHo`` - +    using : :enum:`.ReHoImpl` radius : float, optional The radius of the sphere in millimeters. If None, the signal will be extracted from a single voxel. See :class:`nilearn.maskers.NiftiSpheresMasker` for more information (default None). allow_overlap : bool, optional Whether to allow overlapping spheres. If False, an error is raised if the spheres overlap (default is False). if available (default None). reho_params : dict, optional Extra parameters for computing ReHo map as a dictionary (default None). -        If ``using="afni"``, then the valid keys are: +        If ``using=ReHoImpl.afni``, then the valid keys are: * ``nneigh`` : {7, 19, 27}, optional (default 27) Number of voxels in the neighbourhood, inclusive. Can be: @@ -74,7 +69,7 @@ class ReHoSpheres(ReHoBase): The number of voxels for +/- z-axis of cuboidal volumes (default None). -        else if ``using="junifer"``, then the valid keys are: +        else if ``using=ReHoImpl.junifer``, then the valid keys are: * ``nneigh`` : {7, 19, 27, 125}, optional (default 27) Number of voxels in the neighbourhood, inclusive. Can be: From 1adbecd974abb08e6deab7f89d33c111581be320 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 12 Nov 2025 14:40:40 +0100 Subject: [PATCH 27/99] update: introduce and adapt ALFFImpl --- junifer/markers/__init__.pyi | 3 ++- junifer/markers/falff/__init__.pyi | 8 +++++++- junifer/markers/falff/falff_base.py | 26 +++++++++++++++++--------- junifer/markers/falff/falff_maps.py | 7 +------ junifer/markers/falff/falff_parcels.py | 7 +------ junifer/markers/falff/falff_spheres.py | 7 +------ 6 files changed, 29 insertions(+), 29 deletions(-) diff --git a/junifer/markers/__init__.pyi b/junifer/markers/__init__.pyi index f715a92fd..edcd0686b 100644 --- a/junifer/markers/__init__.pyi +++ b/junifer/markers/__init__.pyi @@ -15,6 +15,7 @@ __all__ = [ "ReHoMaps", "ReHoParcels", "ReHoSpheres", +    "ALFFImpl", "ALFFMaps", "ALFFParcels", "ALFFSpheres", @@ -38,8 +39,8 @@ from .functional_connectivity import ( EdgeCentricFCParcels, EdgeCentricFCSpheres, ) -from .falff import ALFFMaps, ALFFParcels, ALFFSpheres from .reho import ReHoImpl, ReHoMaps, ReHoParcels, ReHoSpheres +from .falff import ALFFImpl, ALFFMaps, ALFFParcels, ALFFSpheres from .temporal_snr import ( TemporalSNRMaps, TemporalSNRParcels, diff --git a/junifer/markers/falff/__init__.pyi b/junifer/markers/falff/__init__.pyi index 1b003e374..ba54274f6 100644 --- a/junifer/markers/falff/__init__.pyi +++ b/junifer/markers/falff/__init__.pyi @@ -1,5 +1,11 @@ -__all__ = ["ALFFMaps", "ALFFParcels", "ALFFSpheres"] +__all__ = [ +    "ALFFImpl", +    "ALFFMaps", +    "ALFFParcels", +    "ALFFSpheres", +] +from .falff_base import ALFFImpl from .falff_maps import ALFFMaps from .falff_parcels import ALFFParcels from .falff_spheres import ALFFSpheres diff --git a/junifer/markers/falff/falff_base.py b/junifer/markers/falff/falff_base.py index d743deea4..66fcb36e6 100644 --- a/junifer/markers/falff/falff_base.py +++ b/junifer/markers/falff/falff_base.py @@ -6,6 +6,7 @@ # Synchon Mandal # License: AGPL +from enum import Enum from pathlib import Path from typing import ( TYPE_CHECKING, @@ -27,7 +28,19 @@ from nibabel.nifti1 import Nifti1Image -__all__ = ["ALFFBase"] +__all__ = ["ALFFBase", "ALFFImpl"] + + +class ALFFImpl(str, Enum): +    """Accepted ALFF implementations. 
+ + * ``junifer`` : ``junifer``'s ALFF + * ``afni`` : AFNI's ``3dRSFC`` + + """ + + junifer = "junifer" + afni = "afni" class ALFFBase(BaseMarker): @@ -39,12 +52,7 @@ class ALFFBase(BaseMarker): Highpass cutoff frequency. lowpass : positive float Lowpass cutoff frequency. - using : {"junifer", "afni"} - Implementation to use for computing ALFF: - - * "junifer" : Use ``junifer``'s own ALFF implementation - * "afni" : Use AFNI's ``3dRSFC`` - + using : :enum:`.ALFFImpl` tr : positive float, optional The Repetition Time of the BOLD data. If None, will extract the TR from NIfTI header (default None). @@ -72,12 +80,12 @@ class ALFFBase(BaseMarker): _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { - "using": "afni", "depends_on": AFNIALFF, + "using": ALFFImpl.afni, }, { - "using": "junifer", "depends_on": JuniferALFF, + "using": ALFFImpl.junifer, }, ] diff --git a/junifer/markers/falff/falff_maps.py b/junifer/markers/falff/falff_maps.py index 1d6434044..f41e1520d 100644 --- a/junifer/markers/falff/falff_maps.py +++ b/junifer/markers/falff/falff_maps.py @@ -23,12 +23,7 @@ class ALFFMaps(ALFFBase): maps : str The name of the map(s) to use. See :func:`.list_data` for options. - using : {"junifer", "afni"} - Implementation to use for computing ALFF: - - * "junifer" : Use ``junifer``'s own ALFF implementation - * "afni" : Use AFNI's ``3dRSFC`` - + using : :enum:`.ALFFImpl` highpass : positive float, optional The highpass cutoff frequency for the bandpass filter. If 0, it will not apply a highpass filter (default 0.01). diff --git a/junifer/markers/falff/falff_parcels.py b/junifer/markers/falff/falff_parcels.py index 7fbaf896b..ed808c8d7 100644 --- a/junifer/markers/falff/falff_parcels.py +++ b/junifer/markers/falff/falff_parcels.py @@ -26,12 +26,7 @@ class ALFFParcels(ALFFBase): parcellation : str or list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. - using : {"junifer", "afni"} - Implementation to use for computing ALFF: - - * "junifer" : Use ``junifer``'s own ALFF implementation - * "afni" : Use AFNI's ``3dRSFC`` - + using : :enum:`.ALFFImpl` highpass : positive float, optional The highpass cutoff frequency for the bandpass filter. If 0, it will not apply a highpass filter (default 0.01). diff --git a/junifer/markers/falff/falff_spheres.py b/junifer/markers/falff/falff_spheres.py index faee35b57..8ec9c1f98 100644 --- a/junifer/markers/falff/falff_spheres.py +++ b/junifer/markers/falff/falff_spheres.py @@ -26,16 +26,11 @@ class ALFFSpheres(ALFFBase): coords : str The name of the coordinates list to use. See :func:`.list_data` for options. - using : {"junifer", "afni"} - Implementation to use for computing ALFF: - - * "junifer" : Use ``junifer``'s own ALFF implementation - * "afni" : Use AFNI's ``3dRSFC`` - radius : float, optional The radius of the sphere in mm. If None, the signal will be extracted from a single voxel. See :class:`nilearn.maskers.NiftiSpheresMasker` for more information (default None). + using : :enum:`.ALFFImpl` allow_overlap : bool, optional Whether to allow overlapping spheres. If False, an error is raised if the spheres overlap (default is False). 
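A usage sketch for the new marker enums (the parcellation name and band limits are placeholders, not taken from the patches):

    from junifer.markers import ALFFImpl, ALFFParcels, ReHoImpl, ReHoParcels

    alff = ALFFParcels(
        parcellation="Schaefer100x17",  # placeholder name
        using=ALFFImpl.junifer,
        highpass=0.01,
        lowpass=0.1,
    )
    reho = ReHoParcels(
        parcellation="Schaefer100x17",  # placeholder name
        using=ReHoImpl.afni,
    )
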
From fa7f1f701e69baaf6931e5232f3c4ea320005ee5 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 12 Nov 2025 14:50:41 +0100 Subject: [PATCH 28/99] update: adapt ExtDep for remaining markers --- junifer/markers/falff/_afni_falff.py | 4 ++-- junifer/markers/reho/_afni_reho.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/junifer/markers/falff/_afni_falff.py b/junifer/markers/falff/_afni_falff.py index 3c9fb52cb..5eb217017 100644 --- a/junifer/markers/falff/_afni_falff.py +++ b/junifer/markers/falff/_afni_falff.py @@ -13,7 +13,7 @@ import nibabel as nib -from ...pipeline import WorkDirManager +from ...pipeline import ExtDep, WorkDirManager from ...typing import ExternalDependencies from ...utils import logger, run_ext_cmd from ...utils.singleton import Singleton @@ -36,7 +36,7 @@ class AFNIALFF(metaclass=Singleton): _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "afni", + "name": ExtDep.AFNI, "commands": ["3dRSFC", "3dAFNItoNIFTI"], }, ] diff --git a/junifer/markers/reho/_afni_reho.py b/junifer/markers/reho/_afni_reho.py index 47f3a2341..298958851 100644 --- a/junifer/markers/reho/_afni_reho.py +++ b/junifer/markers/reho/_afni_reho.py @@ -13,7 +13,7 @@ import nibabel as nib -from ...pipeline import WorkDirManager +from ...pipeline import ExtDep, WorkDirManager from ...typing import ExternalDependencies from ...utils import logger, run_ext_cmd from ...utils.singleton import Singleton @@ -36,7 +36,7 @@ class AFNIReHo(metaclass=Singleton): _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "afni", + "name": ExtDep.AFNI, "commands": ["3dReHo", "3dAFNItoNIFTI"], }, ] From 06073d75ef7327117699101b9ae5d17a98106c83 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 12 Nov 2025 19:14:22 +0100 Subject: [PATCH 29/99] update: adapt necessary enums in markers --- junifer/markers/base.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/junifer/markers/base.py b/junifer/markers/base.py index 16636fc6a..be56d0e44 100644 --- a/junifer/markers/base.py +++ b/junifer/markers/base.py @@ -8,7 +8,9 @@ from copy import deepcopy from typing import Any, ClassVar, Optional, Union +from ..datagrabber import DataType from ..pipeline import PipelineStepMixin, UpdateMetaMixin +from ..storage import StorageType from ..typing import MarkerInOutMappings, StorageLike from ..utils import logger, raise_error @@ -24,9 +26,10 @@ class BaseMarker(ABC, PipelineStepMixin, UpdateMetaMixin): Parameters ---------- - on : str or list of str or None, optional - The data type to apply the marker on. If None, - will work on all available data types (default None). + on : list of :enum:`.DataType` or None, optional + The data type(s) to apply the marker on. + If None, will work on all available data types. + Check :enum:`.DataType` for valid values (default None). name : str, optional The name of the marker. If None, will use the class name as the name of the marker (default None). @@ -107,19 +110,21 @@ def get_valid_inputs(self) -> list[str]: """ return list(self._MARKER_INOUT_MAPPINGS.keys()) - def storage_type(self, input_type: str, output_feature: str) -> str: - """Get storage type for a feature. + def storage_type( + self, input_type: DataType, output_feature: str + ) -> StorageType: + """Get :enum:`.StorageType` for a feature. Parameters ---------- - input_type : str + input_type : :enum:`.DataType` The data type input to the marker. output_feature : str The feature output of the marker. 
Returns ------- - str + :enum:`.StorageType` The storage type output of the marker. """ @@ -155,7 +160,7 @@ def compute(self, input: dict, extra_input: Optional[dict] = None) -> dict: def store( self, - type_: str, + data_type: DataType, feature: str, out: dict[str, Any], storage: StorageLike, @@ -164,7 +169,7 @@ def store( Parameters ---------- - type_ : str + data_type : :enum:`.DataType` The data type to store. feature : str The feature to store. From 2586ce3d2516503c3bda27afa54580787d65bf31 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 12 Nov 2025 19:24:21 +0100 Subject: [PATCH 30/99] refactor: make marker interface and impls pydantic models --- junifer/markers/base.py | 139 ++++++++++-------- junifer/markers/brainprint.py | 41 ++---- junifer/markers/complexity/complexity_base.py | 49 +++--- junifer/markers/complexity/hurst_exponent.py | 58 +++----- .../complexity/multiscale_entropy_auc.py | 59 ++++---- junifer/markers/complexity/perm_entropy.py | 59 +++----- junifer/markers/complexity/range_entropy.py | 58 +++----- .../markers/complexity/range_entropy_auc.py | 58 +++----- junifer/markers/complexity/sample_entropy.py | 59 +++----- .../complexity/weighted_perm_entropy.py | 59 +++----- junifer/markers/ets_rss.py | 42 +++--- junifer/markers/falff/falff_base.py | 76 +++++----- junifer/markers/falff/falff_maps.py | 46 ++---- junifer/markers/falff/falff_parcels.py | 64 +++----- junifer/markers/falff/falff_spheres.py | 83 ++++------- ...ossparcellation_functional_connectivity.py | 51 +++---- .../edge_functional_connectivity_maps.py | 36 ++--- .../edge_functional_connectivity_parcels.py | 51 +++---- .../edge_functional_connectivity_spheres.py | 67 +++------ .../functional_connectivity_base.py | 56 ++++--- .../functional_connectivity_maps.py | 36 ++--- .../functional_connectivity_parcels.py | 51 +++---- .../functional_connectivity_spheres.py | 68 +++------ junifer/markers/maps_aggregation.py | 70 +++++---- junifer/markers/parcel_aggregation.py | 95 ++++++------ junifer/markers/reho/reho_base.py | 87 ++++++++--- junifer/markers/reho/reho_maps.py | 30 ++-- junifer/markers/reho/reho_parcels.py | 47 +++--- junifer/markers/reho/reho_spheres.py | 67 ++++----- junifer/markers/sphere_aggregation.py | 107 +++++++------- .../markers/temporal_snr/temporal_snr_base.py | 36 ++--- .../markers/temporal_snr/temporal_snr_maps.py | 26 ++-- .../temporal_snr/temporal_snr_parcels.py | 44 +++--- .../temporal_snr/temporal_snr_spheres.py | 67 ++++----- 34 files changed, 887 insertions(+), 1155 deletions(-) diff --git a/junifer/markers/base.py b/junifer/markers/base.py index be56d0e44..45e2b2949 100644 --- a/junifer/markers/base.py +++ b/junifer/markers/base.py @@ -8,6 +8,8 @@ from copy import deepcopy from typing import Any, ClassVar, Optional, Union +from pydantic import BaseModel, ConfigDict + from ..datagrabber import DataType from ..pipeline import PipelineStepMixin, UpdateMetaMixin from ..storage import StorageType @@ -18,7 +20,7 @@ __all__ = ["BaseMarker"] -class BaseMarker(ABC, PipelineStepMixin, UpdateMetaMixin): +class BaseMarker(BaseModel, ABC, PipelineStepMixin, UpdateMetaMixin): """Abstract base class for marker. For every marker, one needs to provide a concrete @@ -30,9 +32,13 @@ class BaseMarker(ABC, PipelineStepMixin, UpdateMetaMixin): The data type(s) to apply the marker on. If None, will work on all available data types. Check :enum:`.DataType` for valid values (default None). - name : str, optional - The name of the marker. 
If None, will use the class name as the
-        name of the marker (default None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).
+
+    Attributes
+    ----------
+    valid_inputs

     Raises
     ------
@@ -45,30 +51,56 @@ class BaseMarker(ABC, PipelineStepMixin, UpdateMetaMixin):

     _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings]

-    def __init__(
-        self,
-        on: Optional[Union[list[str], str]] = None,
-        name: Optional[str] = None,
-    ) -> None:
+    model_config = ConfigDict(use_enum_values=True)
+
+    on: Optional[list[DataType]] = None
+    name: Optional[str] = None
+
+    def model_post_init(self, context: Any) -> None:  # noqa: D102
         # Check for missing mapping attribute
         if not hasattr(self, "_MARKER_INOUT_MAPPINGS"):
             raise_error(
                 msg=("Missing `_MARKER_INOUT_MAPPINGS` for the marker"),
                 klass=AttributeError,
             )
+        # Run extra validation for markers and fail early if needed
+        self.validate_marker_params()
         # Use all data types if not provided
-        if on is None:
-            on = self.get_valid_inputs()
-        # Convert data types to list
-        if not isinstance(on, list):
-            on = [on]
+        if self.on is None:
+            self.on = self.valid_inputs
+        else:
+            # Convert to correct data type
+            self.on = [DataType(t) for t in self.on]
+            # Check that the provided data types are valid inputs
+            if any(x not in self.valid_inputs for x in self.on):
+                wrong_on = [
+                    x.value for x in self.on if x not in self.valid_inputs
+                ]
+                raise_error(
+                    f"{self.__class__.__name__} cannot be computed on "
+                    f"{wrong_on}"
+                )
         # Set default name if not provided
-        self.name = self.__class__.__name__ if name is None else name
-        # Check if required inputs are found
-        if any(x not in self.get_valid_inputs() for x in on):
-            wrong_on = [x for x in on if x not in self.get_valid_inputs()]
-            raise_error(f"{self.name} cannot be computed on {wrong_on}")
-        self._on = on
+        self.name = self.__class__.__name__ if self.name is None else self.name
+
+    @property
+    def valid_inputs(self) -> list[DataType]:
+        """Valid data types to operate on.
+
+        Returns
+        -------
+        list of :enum:`.DataType`
+            The list of data types that can be used as input for this marker.
+
+        """
+        return [DataType(x) for x in self._MARKER_INOUT_MAPPINGS.keys()]
+
+    def validate_marker_params(self) -> None:
+        """Run extra logical validation for marker.
+
+        Subclasses can override to provide validation.
+        """
+        pass

     def validate_input(self, input: list[str]) -> list[str]:
         """Validate input.
@@ -91,24 +123,13 @@ def validate_input(self, input: list[str]) -> list[str]:
            If the input does not have the required data.

         """
-        if not any(x in input for x in self._on):
+        if not any(x in input for x in self.on):
             raise_error(
-                "Input does not have the required data."
-                f"\t Input: {input}"
-                f"\t Required (any of): {self._on}"
+                "Input does not have the required data.\n"
+                f"\t Input: {input}\n"
+                f"\t Required (any of): {[t.value for t in self.on]}"
             )
-        return [x for x in self._on if x in input]
-
-    def get_valid_inputs(self) -> list[str]:
-        """Get valid data types for input.
-
-        Returns
-        -------
-        list of str
-            The list of data types that can be used as input for this marker.
-
-        """
-        return list(self._MARKER_INOUT_MAPPINGS.keys())
+        return [x.value for x in self.on if x in input]

     def storage_type(
         self, input_type: DataType, output_feature: str
@@ -162,7 +183,7 @@ def store(
         self,
         data_type: DataType,
         feature: str,
-        out: dict[str, Any],
+        output: dict[str, Any],
         storage: StorageLike,
     ) -> None:
         """Store.
@@ -173,15 +194,15 @@ def store(

         Parameters
         ----------
         data_type : :enum:`.DataType`
             The data type to store.
         feature : str
             The feature to store.
-        out : dict
+        output : dict
             The computed result as a dictionary to store.
         storage : storage-like
             The storage class, for example, SQLiteFeatureStorage.

         """
-        output_type_ = self.storage_type(type_, feature)
-        logger.debug(f"Storing {output_type_} in {storage}")
-        storage.store(kind=output_type_, **out)
+        s_type = self.storage_type(data_type, feature)
+        logger.debug(f"Storing {s_type} in {storage}")
+        storage.store(kind=s_type, **output)

     def _fit_transform(
         self,
@@ -200,58 +221,56 @@ def _fit_transform(

         Returns
         -------
         dict
-            The processed output as a dictionary. If `storage` is provided,
+            The processed output as a dictionary. If ``storage`` is provided,
             empty dictionary is returned.

         """
         out = {}
-        for type_ in self._on:
-            if type_ in input.keys():
-                logger.info(f"Computing {type_}")
+        for t in self.on:
+            if t in input.keys():
+                logger.info(f"Computing {t}")
                 # Get data dict for data type
-                t_input = input[type_]
+                t_input = input[t]
                 # Pass the other data types as extra input, removing
                 # the current type
                 extra_input = input.copy()
-                extra_input.pop(type_)
+                extra_input.pop(t)
                 logger.debug(
                     f"Extra data type for feature extraction: "
                     f"{extra_input.keys()}"
                 )
                 # Copy metadata
                 t_meta = t_input["meta"].copy()
-                t_meta["type"] = type_
+                t_meta["type"] = t.value
                 # Compute marker
                 t_out = self.compute(input=t_input, extra_input=extra_input)
                 # Initialize empty dictionary if no storage object is provided
                 if storage is None:
-                    out[type_] = {}
+                    out[t] = {}
                 # Store individual features
-                for feature_name, feature_data in t_out.items():
+                for f_name, f_data in t_out.items():
                     # Make deep copy of the feature data for manipulation
-                    feature_data_copy = deepcopy(feature_data)
+                    f_data_copy = deepcopy(f_data)
                     # Make deep copy of metadata and add to feature data
-                    feature_data_copy["meta"] = deepcopy(t_meta)
+                    f_data_copy["meta"] = deepcopy(t_meta)
                     # Update metadata for the feature,
                     # feature data is not manipulated, only meta
-                    self.update_meta(feature_data_copy, "marker")
+                    self.update_meta(f_data_copy, "marker")
                     # Update marker feature's metadata name
-                    feature_data_copy["meta"]["marker"]["name"] += (
-                        f"_{feature_name}"
-                    )
+                    f_data_copy["meta"]["marker"]["name"] += f"_{f_name}"

                     if storage is not None:
                         logger.info(f"Storing in {storage}")
                         self.store(
-                            type_=type_,
-                            feature=feature_name,
-                            out=feature_data_copy,
+                            data_type=t,
+                            feature=f_name,
+                            output=f_data_copy,
                             storage=storage,
                         )
                     else:
                         logger.info(
                             "No storage specified, returning dictionary"
                         )
-                        out[type_][feature_name] = feature_data_copy
+                        out[t][f_name] = f_data_copy

         return out
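
To see what this refactor means in practice, here is a minimal sketch of a marker subclass under the new interface. It is not part of the patch: the class, its ``scale`` field, and the ``mean`` feature are invented, and it assumes the usual re-exports (``junifer.markers.BaseMarker``, ``junifer.datagrabber.DataType``, ``junifer.storage.StorageType``).

    from typing import Any, ClassVar, Optional

    from junifer.datagrabber import DataType
    from junifer.markers import BaseMarker
    from junifer.storage import StorageType


    class ToyMeanMarker(BaseMarker):
        """Hypothetical marker exposing one vector feature on BOLD."""

        # Required mapping; model_post_init() raises AttributeError without it.
        _MARKER_INOUT_MAPPINGS: ClassVar[dict] = {
            DataType.BOLD: {"mean": StorageType.Vector},
        }

        # Parameters are declared as validated pydantic fields instead of
        # __init__ arguments; no assignment boilerplate is needed any more.
        scale: float = 1.0

        def compute(
            self, input: dict[str, Any], extra_input: Optional[dict] = None
        ) -> dict:
            # Keys must match the feature names declared in the mapping above.
            return {"mean": {"data": ...}}


    marker = ToyMeanMarker(scale=2.0)  # ok; `on` defaults to valid_inputs
    try:
        ToyMeanMarker(on=[DataType.T1w])  # not a valid input for this marker
    except ValueError as err:
        print(err)

Note that ``on`` is checked against ``valid_inputs`` (derived from ``_MARKER_INOUT_MAPPINGS``) inside ``model_post_init``, so a wrong data type now fails at construction time rather than at compute time.
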
""" @@ -91,27 +92,17 @@ class BrainPrint(BaseMarker): } } - def __init__( - self, - num: int = 50, - skip_cortex=False, - keep_eigenvectors: bool = False, - norm: str = "none", - reweight: bool = False, - asymmetry: bool = False, - asymmetry_distance: str = "euc", - use_cholmod: bool = False, - name: Optional[str] = None, - ) -> None: - self.num = num - self.skip_cortex = skip_cortex - self.keep_eigenvectors = keep_eigenvectors - self.norm = norm - self.reweight = reweight - self.asymmetry = asymmetry - self.asymmetry_distance = asymmetry_distance - self.use_cholmod = use_cholmod - super().__init__(name=name, on="FreeSurfer") + num: PositiveInt = 50 + skip_cortex: bool = False + keep_eigenvectors: bool = False + norm: str = "none" + reweight: bool = False + asymmetry: bool = False + asymmetry_distance: str = "euc" + use_cholmod: bool = False + + _tempdir = Path() + _element_tempdir = Path() def _create_aseg_surface( self, @@ -355,7 +346,7 @@ def compute( - ``col_names`` : surface labels as list of str - ``row_names`` : eigenvalue count labels as list of str - ``row_header_col_name`` : "eigenvalue" - () + * ``areas`` : dictionary with the following keys: - ``data`` : areas as ``np.ndarray`` diff --git a/junifer/markers/complexity/complexity_base.py b/junifer/markers/complexity/complexity_base.py index da30bcfc6..90a08ffe9 100644 --- a/junifer/markers/complexity/complexity_base.py +++ b/junifer/markers/complexity/complexity_base.py @@ -9,6 +9,7 @@ TYPE_CHECKING, Any, ClassVar, + Literal, Optional, Union, ) @@ -29,26 +30,27 @@ class ComplexityBase(BaseMarker): - """Base class for complexity computation. + """Abstract base class for complexity computation. Parameters ---------- - parcellation : str or list of str - The name(s) of the parcellation(s). Check valid options by calling - :func:`junifer.data.parcellations.list_parcellations`. + parcellation : list of str + The name(s) of the parcellation(s) to use. + See :func:`.list_data` for options. agg_method : str, optional - The method to perform aggregation using. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, it will use the class name - (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
""" @@ -60,19 +62,12 @@ class ComplexityBase(BaseMarker): }, } - def __init__( - self, - parcellation: Union[str, list[str]], - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.parcellation = parcellation - self.agg_method = agg_method - self.agg_method_params = agg_method_params - self.masks = masks - super().__init__(on="BOLD", name=name) + parcellation: list[str] + agg_method: str = "mean" + agg_method_params: Optional[dict] = None + masks: Optional[list[Union[dict, str]]] = None + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 + @abstractmethod def compute_complexity( @@ -120,7 +115,7 @@ def compute( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input=input, extra_input=extra_input) # Compute complexity measure return { diff --git a/junifer/markers/complexity/hurst_exponent.py b/junifer/markers/complexity/hurst_exponent.py index 98311374e..4869dc12d 100644 --- a/junifer/markers/complexity/hurst_exponent.py +++ b/junifer/markers/complexity/hurst_exponent.py @@ -2,14 +2,16 @@ # Authors: Amir Omidvarnia # Leonard Sasse +# Synchon Mandal # License: AGPL -from typing import Optional, Union +from typing import Literal, Optional import neurokit2 as nk import numpy as np from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger, warn_with_log from .complexity_base import ComplexityBase @@ -23,26 +25,27 @@ class HurstExponent(ComplexityBase): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. agg_method : str, optional - The method to perform aggregation using. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - params : dict, optional - Parameters to pass to the Hurst exponent calculation function. For more - information, check out ``junifer.markers.utils._hurst_exponent``. - If None, value is set to {"method": "dfa"} (default None). - name : str, optional - The name of the marker. If None, it will use the class name - (default None). + params : dict or None, optional + The parameters to pass to the Hurst exponent calculation function. + See ``junifer.markers.utils._hurst_exponent`` for more information. + If None, value is set to ``{"method": "dfa"}`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
Warnings -------- @@ -53,26 +56,13 @@ class HurstExponent(ComplexityBase): """ - def __init__( - self, - parcellation: Union[str, list[str]], - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - params: Optional[dict] = None, - name: Optional[str] = None, - ) -> None: - super().__init__( - parcellation=parcellation, - agg_method=agg_method, - agg_method_params=agg_method_params, - masks=masks, - name=name, - ) - if params is None: + params: Optional[dict] = None + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 + + def validate_marker_params(self) -> None: + """Run extra logical validation for marker.""" + if self.params is None: self.params = {"method": "dfa"} - else: - self.params = params def compute_complexity( self, diff --git a/junifer/markers/complexity/multiscale_entropy_auc.py b/junifer/markers/complexity/multiscale_entropy_auc.py index bd156ad35..2a8b20ef3 100644 --- a/junifer/markers/complexity/multiscale_entropy_auc.py +++ b/junifer/markers/complexity/multiscale_entropy_auc.py @@ -2,14 +2,16 @@ # Authors: Amir Omidvarnia # Leonard Sasse +# Synchon Mandal # License: AGPL -from typing import Optional, Union +from typing import Literal, Optional import neurokit2 as nk import numpy as np from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger, warn_with_log from .complexity_base import ComplexityBase @@ -23,27 +25,29 @@ class MultiscaleEntropyAUC(ComplexityBase): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. agg_method : str, optional - The method to perform aggregation using. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - params : dict, optional - Parameters to pass to the AUC of multiscale entropy calculation - function. For more information, check out - ``junifer.markers.utils._multiscale_entropy_auc``. If None, value - is set to {"m": 2, "tol": 0.5, "scale": 10} (default None). - name : str, optional - The name of the marker. If None, it will use the class name + params : dict or None, optional + The parameters to pass to the AUC of multiscale entropy calculation + function. See + ``junifer.markers.utils._multiscale_entropy_auc`` for more information. + If None, value is set to ``{"m": 2, "tol": 0.5, "scale": 10}`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
Warnings
     --------

...

     """

-    def __init__(
-        self,
-        parcellation: Union[str, list[str]],
-        agg_method: str = "mean",
-        agg_method_params: Optional[dict] = None,
-        masks: Union[str, dict, list[Union[dict, str]], None] = None,
-        params: Optional[dict] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        super().__init__(
-            parcellation=parcellation,
-            agg_method=agg_method,
-            agg_method_params=agg_method_params,
-            masks=masks,
-            name=name,
-        )
-        if params is None:
+    params: Optional[dict] = None
+    on: list[Literal[DataType.BOLD]] = [DataType.BOLD]  # noqa: RUF012
+
+    def validate_marker_params(self) -> None:
+        """Run extra logical validation for marker."""
+        if self.params is None:
             self.params = {"m": 2, "tol": 0.5, "scale": 10}
-        else:
-            self.params = params

     def compute_complexity(
         self,
diff --git a/junifer/markers/complexity/perm_entropy.py b/junifer/markers/complexity/perm_entropy.py
index e6e984d0a..fcd51b5ce 100644
--- a/junifer/markers/complexity/perm_entropy.py
+++ b/junifer/markers/complexity/perm_entropy.py
@@ -2,14 +2,16 @@

 # Authors: Amir Omidvarnia
 #          Leonard Sasse
+#          Synchon Mandal
 # License: AGPL

-from typing import Optional, Union
+from typing import Literal, Optional

 import neurokit2 as nk
 import numpy as np

 from ...api.decorators import register_marker
+from ...datagrabber import DataType
 from ...utils import logger, warn_with_log
 from .complexity_base import ComplexityBase


@@ -23,27 +25,27 @@ class PermEntropy(ComplexityBase):

     Parameters
     ----------
-    parcellation : str or list of str
+    parcellation : list of str
         The name(s) of the parcellation(s) to use.
         See :func:`.list_data` for options.
     agg_method : str, optional
-        The method to perform aggregation using. Check valid options in
-        :func:`junifer.stats.get_aggfunc_by_name` (default "mean").
-    agg_method_params : dict, optional
-        Parameters to pass to the aggregation function. Check valid options in
-        :func:`junifer.stats.get_aggfunc_by_name` (default None).
-    masks : str, dict or list of dict or str, optional
+        The aggregation function to use.
+        See :func:`.get_aggfunc_by_name` for options
+        (default "mean").
+    agg_method_params : dict or None, optional
+        The parameters to pass to the aggregation function.
+        See :func:`.get_aggfunc_by_name` for options (default None).
+    masks : list of dict or str, or None, optional
         The specification of the masks to apply to regions before extracting
         signals. Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    params : dict, optional
-        Parameters to pass to the permutation entropy calculation function.
-        For more information, check out
-        ``junifer.markers.utils._perm_entropy``. If None, value is set to
-        {"m": 2, "delay": 1} (default None).
-    name : str, optional
-        The name of the marker. If None, it will use the class name
-        (default None).
+    params : dict or None, optional
+        The parameters to pass to the permutation entropy calculation function.
+        See ``junifer.markers.utils._perm_entropy`` for more information.
+        If None, value is set to ``{"m": 4, "delay": 1}`` (default None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).
Warnings -------- @@ -54,26 +56,13 @@ class PermEntropy(ComplexityBase): """ - def __init__( - self, - parcellation: Union[str, list[str]], - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - params: Optional[dict] = None, - name: Optional[str] = None, - ) -> None: - super().__init__( - parcellation=parcellation, - agg_method=agg_method, - agg_method_params=agg_method_params, - masks=masks, - name=name, - ) - if params is None: + params: Optional[dict] = None + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 + + def validate_marker_params(self) -> None: + """Run extra logical validation for marker.""" + if self.params is None: self.params = {"m": 4, "delay": 1} - else: - self.params = params def compute_complexity( self, diff --git a/junifer/markers/complexity/range_entropy.py b/junifer/markers/complexity/range_entropy.py index e67154543..a938516ff 100644 --- a/junifer/markers/complexity/range_entropy.py +++ b/junifer/markers/complexity/range_entropy.py @@ -2,14 +2,16 @@ # Authors: Amir Omidvarnia # Leonard Sasse +# Synchon Mandal # License: AGPL -from typing import Optional, Union +from typing import Literal, Optional import neurokit2 as nk import numpy as np from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger, warn_with_log from .complexity_base import ComplexityBase @@ -23,27 +25,28 @@ class RangeEntropy(ComplexityBase): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. agg_method : str, optional - The method to perform aggregation using. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - params : dict, optional - Parameters to pass to the range entropy calculation function. For more - information, check out ``junifer.markers.utils._range_entropy``. - If None, value is set to {"m": 2, "tol": 0.5, "delay": 1} - (default None). - name : str, optional - The name of the marker. If None, it will use the class name + params : dict or None, optional + The parameters to pass to the range entropy calculation function. + See ``junifer.markers.utils._range_entropy`` for more information. + If None, value is set to ``{"m": 2, "tol": 0.5, "delay": 1}`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
Warnings -------- @@ -54,26 +57,13 @@ class RangeEntropy(ComplexityBase): """ - def __init__( - self, - parcellation: Union[str, list[str]], - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - params: Optional[dict] = None, - name: Optional[str] = None, - ) -> None: - super().__init__( - parcellation=parcellation, - agg_method=agg_method, - agg_method_params=agg_method_params, - masks=masks, - name=name, - ) - if params is None: + params: Optional[dict] = None + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 + + def validate_marker_params(self) -> None: + """Run extra logical validation for marker.""" + if self.params is None: self.params = {"m": 2, "tol": 0.5, "delay": 1} - else: - self.params = params def compute_complexity( self, diff --git a/junifer/markers/complexity/range_entropy_auc.py b/junifer/markers/complexity/range_entropy_auc.py index bbd84e8ad..8ae29be79 100644 --- a/junifer/markers/complexity/range_entropy_auc.py +++ b/junifer/markers/complexity/range_entropy_auc.py @@ -2,14 +2,16 @@ # Authors: Amir Omidvarnia # Leonard Sasse +# Synchon Mandal # License: AGPL -from typing import Optional, Union +from typing import Literal, Optional import neurokit2 as nk import numpy as np from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger, warn_with_log from .complexity_base import ComplexityBase @@ -23,27 +25,28 @@ class RangeEntropyAUC(ComplexityBase): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. agg_method : str, optional - The method to perform aggregation using. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`junifer.stats.get_aggfunc_by_name` (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - params : dict, optional - Parameters to pass to the range entropy calculation function. For more - information, check out ``junifer.markers.utils._range_entropy``. - If None, value is set to {"m": 2, "delay": 1, "n_r": 10} - (default None). - name : str, optional - The name of the marker. If None, it will use the class name + params : dict or None, optional + The parameters to pass to the range entropy calculation function. + See ``junifer.markers.utils._range_entropy`` for more information. + If None, value is set to ``{"m": 2, "delay": 1, "n_r": 10}`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
Warnings
     --------

...

     """

-    def __init__(
-        self,
-        parcellation: Union[str, list[str]],
-        agg_method: str = "mean",
-        agg_method_params: Optional[dict] = None,
-        masks: Union[str, dict, list[Union[dict, str]], None] = None,
-        params: Optional[dict] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        super().__init__(
-            parcellation=parcellation,
-            agg_method=agg_method,
-            agg_method_params=agg_method_params,
-            masks=masks,
-            name=name,
-        )
-        if params is None:
+    params: Optional[dict] = None
+    on: list[Literal[DataType.BOLD]] = [DataType.BOLD]  # noqa: RUF012
+
+    def validate_marker_params(self) -> None:
+        """Run extra logical validation for marker."""
+        if self.params is None:
             self.params = {"m": 2, "delay": 1, "n_r": 10}
-        else:
-            self.params = params

     def compute_complexity(
         self,
diff --git a/junifer/markers/complexity/sample_entropy.py b/junifer/markers/complexity/sample_entropy.py
index a8199b114..b187e25d1 100644
--- a/junifer/markers/complexity/sample_entropy.py
+++ b/junifer/markers/complexity/sample_entropy.py
@@ -2,14 +2,16 @@

 # Authors: Amir Omidvarnia
 #          Leonard Sasse
+#          Synchon Mandal
 # License: AGPL

-from typing import Optional, Union
+from typing import Literal, Optional

 import neurokit2 as nk
 import numpy as np

 from ...api.decorators import register_marker
+from ...datagrabber import DataType
 from ...utils import logger, warn_with_log
 from .complexity_base import ComplexityBase


@@ -23,28 +25,28 @@ class SampleEntropy(ComplexityBase):

     Parameters
     ----------
-    parcellation : str or list of str
+    parcellation : list of str
         The name(s) of the parcellation(s) to use.
         See :func:`.list_data` for options.
     agg_method : str, optional
-        The method to perform aggregation using. Check valid options in
-        :func:`junifer.stats.get_aggfunc_by_name` (default "mean").
-    agg_method_params : dict, optional
-        Parameters to pass to the aggregation function. Check valid options in
-        :func:`junifer.stats.get_aggfunc_by_name` (default None).
-    masks : str, dict or list of dict or str, optional
+        The aggregation function to use.
+        See :func:`.get_aggfunc_by_name` for options
+        (default "mean").
+    agg_method_params : dict or None, optional
+        The parameters to pass to the aggregation function.
+        See :func:`.get_aggfunc_by_name` for options (default None).
+    masks : list of dict or str, or None, optional
         The specification of the masks to apply to regions before extracting
         signals. Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    params : dict, optional
-        Parameters to pass to the sample entropy calculation function.
-        For more information, check out
-        ``junifer.markers.utils._sample_entropy``.
-        If None, value is set to
-        {"m": 2, "delay": 1, "tol": 0.5} (default None).
-    name : str, optional
-        The name of the marker. If None, it will use the class name
+    params : dict or None, optional
+        The parameters to pass to the sample entropy calculation function.
+        See ``junifer.markers.utils._sample_entropy`` for more information.
+        If None, value is set to ``{"m": 4, "delay": 1, "tol": 0.5}``
         (default None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).

 Warnings
     --------

...

     """

-    def __init__(
-        self,
-        parcellation: Union[str, list[str]],
-        agg_method: str = "mean",
-        agg_method_params: Optional[dict] = None,
-        masks: Union[str, dict, list[Union[dict, str]], None] = None,
-        params: Optional[dict] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        super().__init__(
-            parcellation=parcellation,
-            agg_method=agg_method,
-            agg_method_params=agg_method_params,
-            masks=masks,
-            name=name,
-        )
-        if params is None:
+    params: Optional[dict] = None
+    on: list[Literal[DataType.BOLD]] = [DataType.BOLD]  # noqa: RUF012
+
+    def validate_marker_params(self) -> None:
+        """Run extra logical validation for marker."""
+        if self.params is None:
             self.params = {"m": 4, "delay": 1, "tol": 0.5}
-        else:
-            self.params = params

     def compute_complexity(
         self,
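
Since the documented ``params`` defaults now live in ``validate_marker_params()`` rather than ``__init__``, they are filled in during ``model_post_init``, i.e. at construction time. A short sketch of the resulting behaviour (the parcellation name is a placeholder; it is not resolved against the data registry at construction):

    from junifer.markers.complexity import SampleEntropy

    # `params` left as None is replaced by the class default on construction.
    marker = SampleEntropy(parcellation=["SomeParcellation"])
    assert marker.params == {"m": 4, "delay": 1, "tol": 0.5}

    # Explicit values pass through pydantic field validation unchanged.
    marker = SampleEntropy(
        parcellation=["SomeParcellation"],
        params={"m": 2, "delay": 1, "tol": 0.5},
    )
    assert marker.params["m"] == 2
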
diff --git a/junifer/markers/complexity/weighted_perm_entropy.py b/junifer/markers/complexity/weighted_perm_entropy.py
index bc709ea54..4ed3aedf0 100644
--- a/junifer/markers/complexity/weighted_perm_entropy.py
+++ b/junifer/markers/complexity/weighted_perm_entropy.py
@@ -2,14 +2,16 @@

 # Authors: Amir Omidvarnia
 #          Leonard Sasse
+#          Synchon Mandal
 # License: AGPL

-from typing import Optional, Union
+from typing import Literal, Optional

 import neurokit2 as nk
 import numpy as np

 from ...api.decorators import register_marker
+from ...datagrabber import DataType
 from ...utils import logger, warn_with_log
 from .complexity_base import ComplexityBase


@@ -23,28 +25,28 @@ class WeightedPermEntropy(ComplexityBase):

     Parameters
     ----------
-    parcellation : str or list of str
+    parcellation : list of str
         The name(s) of the parcellation(s) to use.
         See :func:`.list_data` for options.
     agg_method : str, optional
-        The method to perform aggregation using. Check valid options in
-        :func:`junifer.stats.get_aggfunc_by_name` (default "mean").
-    agg_method_params : dict, optional
-        Parameters to pass to the aggregation function. Check valid options in
-        :func:`junifer.stats.get_aggfunc_by_name` (default None).
-    masks : str, dict or list of dict or str, optional
+        The aggregation function to use.
+        See :func:`.get_aggfunc_by_name` for options
+        (default "mean").
+    agg_method_params : dict or None, optional
+        The parameters to pass to the aggregation function.
+        See :func:`.get_aggfunc_by_name` for options (default None).
+    masks : list of dict or str, or None, optional
         The specification of the masks to apply to regions before extracting
         signals. Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    params : dict, optional
-        Parameters to pass to the weighted permutation entropy calculation
-        function.
-        For more information, check out
-        ``junifer.markers.utils._weighted_perm_entropy``. If None, value
-        is set to {"m": 2, "delay": 1} (default None).
-    name : str, optional
-        The name of the marker. If None, it will use the class name
+    params : dict or None, optional
+        The parameters to pass to the weighted permutation entropy calculation
+        function. See ``junifer.markers.utils._weighted_perm_entropy`` for more
+        information. If None, value is set to ``{"m": 4, "delay": 1}``
         (default None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).

 Warnings
     --------

...

     """

-    def __init__(
-        self,
-        parcellation: Union[str, list[str]],
-        agg_method: str = "mean",
-        agg_method_params: Optional[dict] = None,
-        masks: Union[str, dict, list[Union[dict, str]], None] = None,
-        params: Optional[dict] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        super().__init__(
-            parcellation=parcellation,
-            agg_method=agg_method,
-            agg_method_params=agg_method_params,
-            masks=masks,
-            name=name,
-        )
-        if params is None:
+    params: Optional[dict] = None
+    on: list[Literal[DataType.BOLD]] = [DataType.BOLD]  # noqa: RUF012
+
+    def validate_marker_params(self) -> None:
+        """Run extra logical validation for marker."""
+        if self.params is None:
             self.params = {"m": 4, "delay": 1}
-        else:
-            self.params = params

     def compute_complexity(
         self,
diff --git a/junifer/markers/ets_rss.py b/junifer/markers/ets_rss.py
index aa816b6c7..d084554fb 100644
--- a/junifer/markers/ets_rss.py
+++ b/junifer/markers/ets_rss.py
@@ -6,7 +6,7 @@
 #          Synchon Mandal
 # License: AGPL

-from typing import Any, ClassVar, Optional, Union
+from typing import Any, ClassVar, Literal, Optional, Union

 import numpy as np

@@ -29,22 +29,23 @@ class RSSETSMarker(BaseMarker):

     Parameters
     ----------
-    parcellation : str or list of str
+    parcellation : list of str
         The name(s) of the parcellation(s) to use.
         See :func:`.list_data` for options.
     agg_method : str, optional
-        The method to perform aggregation using. Check valid options in
-        :func:`.get_aggfunc_by_name` (default "mean").
-    agg_method_params : dict, optional
-        Parameters to pass to the aggregation function. Check valid options in
-        :func:`.get_aggfunc_by_name` (default None).
-    masks : str, dict or list of dict or str, optional
+        The aggregation function to use.
+        See :func:`.get_aggfunc_by_name` for options
+        (default "mean").
+    agg_method_params : dict or None, optional
+        The parameters to pass to the aggregation function.
+        See :func:`.get_aggfunc_by_name` for options (default None).
+    masks : list of dict or str, or None, optional
         The specification of the masks to apply to regions before extracting
         signals. Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    name : str, optional
-        The name of the marker. If None, will use the class name (default
-        None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).
""" @@ -56,19 +57,11 @@ class RSSETSMarker(BaseMarker): }, } - def __init__( - self, - parcellation: Union[str, list[str]], - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.parcellation = parcellation - self.agg_method = agg_method - self.agg_method_params = agg_method_params - self.masks = masks - super().__init__(name=name) + parcellation: list[str] + agg_method: str = "mean" + agg_method_params: Optional[dict] = None + masks: Optional[list[Union[dict, str]]] = None + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def compute( self, @@ -112,6 +105,7 @@ def compute( parcellation=self.parcellation, method=self.agg_method, method_params=self.agg_method_params, + on=[DataType.BOLD], masks=self.masks, ).compute(input=input, extra_input=extra_input) # Compute edgewise timeseries diff --git a/junifer/markers/falff/falff_base.py b/junifer/markers/falff/falff_base.py index 66fcb36e6..0e6a0c8d7 100644 --- a/junifer/markers/falff/falff_base.py +++ b/junifer/markers/falff/falff_base.py @@ -12,13 +12,17 @@ TYPE_CHECKING, Any, ClassVar, + Literal, Optional, + Union, ) +from pydantic import PositiveFloat + from ...datagrabber import DataType from ...storage import StorageType from ...typing import ConditionalDependencies, MarkerInOutMappings -from ...utils.logging import logger, raise_error +from ...utils.logging import logger from ..base import BaseMarker from ._afni_falff import AFNIALFF from ._junifer_falff import JuniferALFF @@ -48,17 +52,28 @@ class ALFFBase(BaseMarker): Parameters ---------- - highpass : positive float - Highpass cutoff frequency. - lowpass : positive float - Lowpass cutoff frequency. using : :enum:`.ALFFImpl` + highpass : positive float, optional + Highpass cutoff frequency (default 0.01). + lowpass : positive float, optional + Lowpass cutoff frequency (default 0.1). tr : positive float, optional - The Repetition Time of the BOLD data. If None, will extract - the TR from NIfTI header (default None). - name : str, optional - The name of the marker. If None, it will use the class name - (default None). + The repetition time of the BOLD data. + If None, will extract the TR from NIfTI header (default None). + agg_method : str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional + The specification of the masks to apply to regions before extracting + signals. Check :ref:`Using Masks ` for more details. + If None, will not apply any mask (default None). + name : str or None, optional + The name of the marker. + If None, it will use the class name (default None). Notes ----- @@ -68,14 +83,6 @@ class ALFFBase(BaseMarker): reported that some preprocessed data might not have the correct ``tr`` in the NIfTI header. - Raises - ------ - ValueError - If ``highpass`` is not positive or zero or - if ``lowpass`` is not positive or - if ``highpass`` is higher than ``lowpass`` or - if ``using`` is invalid. 
- """ _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ @@ -96,31 +103,14 @@ class ALFFBase(BaseMarker): }, } - def __init__( - self, - highpass: float, - lowpass: float, - using: str, - tr: Optional[float] = None, - name: Optional[str] = None, - ) -> None: - if highpass < 0: - raise_error("Highpass must be positive or 0") - if lowpass <= 0: - raise_error("Lowpass must be positive") - if highpass >= lowpass: - raise_error("Highpass must be lower than lowpass") - self.highpass = highpass - self.lowpass = lowpass - # Validate `using` parameter - valid_using = [dep["using"] for dep in self._CONDITIONAL_DEPENDENCIES] - if using not in valid_using: - raise_error( - f"Invalid value for `using`, should be one of: {valid_using}" - ) - self.using = using - self.tr = tr - super().__init__(on="BOLD", name=name) + using: ALFFImpl + highpass: PositiveFloat = 0.01 + lowpass: PositiveFloat = 0.1 + tr: Optional[PositiveFloat] = None + agg_method: str = "mean" + agg_method_params: Optional[dict] = None + masks: Optional[list[Union[dict, str]]] = None + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def _compute( self, diff --git a/junifer/markers/falff/falff_maps.py b/junifer/markers/falff/falff_maps.py index f41e1520d..6bd03fd98 100644 --- a/junifer/markers/falff/falff_maps.py +++ b/junifer/markers/falff/falff_maps.py @@ -3,9 +3,10 @@ # Authors: Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger from ..maps_aggregation import MapsAggregation from .falff_base import ALFFBase @@ -25,20 +26,19 @@ class ALFFMaps(ALFFBase): See :func:`.list_data` for options. using : :enum:`.ALFFImpl` highpass : positive float, optional - The highpass cutoff frequency for the bandpass filter. If 0, - it will not apply a highpass filter (default 0.01). + Highpass cutoff frequency (default 0.01). lowpass : positive float, optional - The lowpass cutoff frequency for the bandpass filter (default 0.1). + Lowpass cutoff frequency (default 0.1). tr : positive float, optional - The Repetition Time of the BOLD data. If None, will extract - the TR from NIfTI header (default None). - masks : str, dict or list of dict or str, optional + The repetition time of the BOLD data. + If None, will extract the TR from NIfTI header (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use the class name (default - None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
Notes ----- @@ -54,26 +54,8 @@ class ALFFMaps(ALFFBase): """ - def __init__( - self, - maps: str, - using: str, - highpass: float = 0.01, - lowpass: float = 0.1, - tr: Optional[float] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - # Superclass init first to validate `using` parameter - super().__init__( - highpass=highpass, - lowpass=lowpass, - using=using, - tr=tr, - name=name, - ) - self.maps = maps - self.masks = masks + maps: str + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def compute( self, @@ -127,7 +109,7 @@ def compute( **MapsAggregation( maps=self.maps, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute( input=aggregation_alff_input, extra_input=extra_input, @@ -137,7 +119,7 @@ def compute( **MapsAggregation( maps=self.maps, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute( input=aggregation_falff_input, extra_input=extra_input, diff --git a/junifer/markers/falff/falff_parcels.py b/junifer/markers/falff/falff_parcels.py index ed808c8d7..a8c1e8d10 100644 --- a/junifer/markers/falff/falff_parcels.py +++ b/junifer/markers/falff/falff_parcels.py @@ -6,9 +6,10 @@ # Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger from ..parcel_aggregation import ParcelAggregation from .falff_base import ALFFBase @@ -23,31 +24,30 @@ class ALFFParcels(ALFFBase): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. using : :enum:`.ALFFImpl` + agg_method : str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for valid options (default None). highpass : positive float, optional - The highpass cutoff frequency for the bandpass filter. If 0, - it will not apply a highpass filter (default 0.01). + Highpass cutoff frequency (default 0.01). lowpass : positive float, optional - The lowpass cutoff frequency for the bandpass filter (default 0.1). + Lowpass cutoff frequency (default 0.1). tr : positive float, optional - The Repetition Time of the BOLD data. If None, will extract - the TR from NIfTI header (default None). - agg_method : str, optional - The method to perform aggregation using. Check valid options in - :func:`.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`.get_aggfunc_by_name` (default None). - masks : str, dict or list of dict or str, optional + The repetition time of the BOLD data. + If None, will extract the TR from NIfTI header (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use the class name (default - None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
Notes ----- @@ -63,30 +63,8 @@ class ALFFParcels(ALFFBase): """ - def __init__( - self, - parcellation: Union[str, list[str]], - using: str, - highpass: float = 0.01, - lowpass: float = 0.1, - tr: Optional[float] = None, - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - # Superclass init first to validate `using` parameter - super().__init__( - highpass=highpass, - lowpass=lowpass, - using=using, - tr=tr, - name=name, - ) - self.parcellation = parcellation - self.agg_method = agg_method - self.agg_method_params = agg_method_params - self.masks = masks + parcellation: list[str] + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def compute( self, @@ -142,7 +120,7 @@ def compute( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute( input=aggregation_alff_input, extra_input=extra_input, @@ -154,7 +132,7 @@ def compute( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute( input=aggregation_falff_input, extra_input=extra_input, diff --git a/junifer/markers/falff/falff_spheres.py b/junifer/markers/falff/falff_spheres.py index 8ec9c1f98..b2386341d 100644 --- a/junifer/markers/falff/falff_spheres.py +++ b/junifer/markers/falff/falff_spheres.py @@ -6,9 +6,12 @@ # Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union + +from pydantic import PositiveFloat from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger from ..sphere_aggregation import SphereAggregation from .falff_base import ALFFBase @@ -26,35 +29,35 @@ class ALFFSpheres(ALFFBase): coords : str The name of the coordinates list to use. See :func:`.list_data` for options. - radius : float, optional - The radius of the sphere in mm. If None, the signal will be extracted - from a single voxel. See :class:`nilearn.maskers.NiftiSpheresMasker` - for more information (default None). using : :enum:`.ALFFImpl` + radius : ``zero`` or positive float or None, optional + The radius of the sphere in millimetres. + If None, the signal will be extracted from a single voxel. + See :class:`.JuniferNiftiSpheresMasker` for more information + (default None). allow_overlap : bool, optional - Whether to allow overlapping spheres. If False, an error is raised if - the spheres overlap (default is False). + Whether to allow overlapping spheres. + If False, an error is raised if the spheres overlap (default False). + agg_method : str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for valid options (default None). highpass : positive float, optional - The highpass cutoff frequency for the bandpass filter. If 0, - it will not apply a highpass filter (default 0.01). + Highpass cutoff frequency (default 0.01). lowpass : positive float, optional - The lowpass cutoff frequency for the bandpass filter (default 0.1). + Lowpass cutoff frequency (default 0.1). tr : positive float, optional - The Repetition Time of the BOLD data. If None, will extract - the TR from NIfTI header (default None). - agg_method : str, optional - The method to perform aggregation using. 
Check valid options in - :func:`.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`.get_aggfunc_by_name`. - masks : str, dict or list of dict or str, optional + The repetition time of the BOLD data. + If None, will extract the TR from NIfTI header (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use the class name (default - None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). Notes ----- @@ -70,34 +73,10 @@ class ALFFSpheres(ALFFBase): """ - def __init__( - self, - coords: str, - using: str, - radius: Optional[float] = None, - allow_overlap: bool = False, - highpass: float = 0.01, - lowpass: float = 0.1, - tr: Optional[float] = None, - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - # Superclass init first to validate `using` parameter - super().__init__( - highpass=highpass, - lowpass=lowpass, - using=using, - tr=tr, - name=name, - ) - self.coords = coords - self.radius = radius - self.allow_overlap = allow_overlap - self.agg_method = agg_method - self.agg_method_params = agg_method_params - self.masks = masks + coords: str + radius: Optional[Union[Literal[0], PositiveFloat]] = None + allow_overlap: bool = False + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def compute( self, @@ -155,7 +134,7 @@ def compute( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute( input=aggregation_alff_input, extra_input=extra_input, @@ -169,7 +148,7 @@ def compute( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute( input=aggregation_falff_input, extra_input=extra_input, diff --git a/junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py b/junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py index 4f7515953..ead33fe5b 100644 --- a/junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +++ b/junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py @@ -4,7 +4,7 @@ # Kaustubh R. Patil # License: AGPL -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union import pandas as pd @@ -32,23 +32,23 @@ class CrossParcellationFC(BaseMarker): parcellation_two : str The name of the second parcellation. agg_method : str, optional - The method to perform aggregation using. + The aggregation function to use. See :func:`.get_aggfunc_by_name` for options (default "mean"). agg_method_params : dict, optional - Parameters to pass to the aggregation function. + The parameters to pass to the aggregation function. See :func:`.get_aggfunc_by_name` for options (default None). corr_method : str, optional Any method that can be passed to :meth:`pandas.DataFrame.corr` (default "pearson"). - masks : str, dict or list of dict or str, optional + masks : list of dict or str, optional The specification of the masks to apply to regions before extracting signals. 
Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    name : str, optional
-        The name of the marker. If None, will use
-        ``BOLD_CrossParcellationFC`` (default None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).

     """

@@ -60,27 +60,20 @@ class CrossParcellationFC(BaseMarker):
         },
     }

-    def __init__(
-        self,
-        parcellation_one: str,
-        parcellation_two: str,
-        agg_method: str = "mean",
-        agg_method_params: Optional[dict] = None,
-        corr_method: str = "pearson",
-        masks: Union[str, dict, list[Union[dict, str]], None] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        if parcellation_one == parcellation_two:
+    parcellation_one: str
+    parcellation_two: str
+    agg_method: str = "mean"
+    agg_method_params: Optional[dict] = None
+    corr_method: str = "pearson"
+    masks: Optional[list[Union[dict, str]]] = None
+    on: list[Literal[DataType.BOLD]] = [DataType.BOLD]  # noqa: RUF012
+
+    def validate_marker_params(self) -> None:
+        """Run extra logical validation for marker."""
+        if self.parcellation_one == self.parcellation_two:
             raise_error(
                 "The two parcellations must be different.",
             )
-        self.parcellation_one = parcellation_one
-        self.parcellation_two = parcellation_two
-        self.agg_method = agg_method
-        self.agg_method_params = agg_method_params
-        self.corr_method = corr_method
-        self.masks = masks
-        super().__init__(on=["BOLD"], name=name)

     def compute(
         self,
@@ -125,18 +118,18 @@ def compute(
         )
         # Perform aggregation using two parcellations
         aggregation_parcellation_one = ParcelAggregation(
-            parcellation=self.parcellation_one,
+            parcellation=[self.parcellation_one],
             method=self.agg_method,
             method_params=self.agg_method_params,
             masks=self.masks,
-            on="BOLD",
+            on=[DataType.BOLD],
         ).compute(input, extra_input=extra_input)
         aggregation_parcellation_two = ParcelAggregation(
-            parcellation=self.parcellation_two,
+            parcellation=[self.parcellation_two],
             method=self.agg_method,
             method_params=self.agg_method_params,
             masks=self.masks,
-            on="BOLD",
+            on=[DataType.BOLD],
         ).compute(input, extra_input=extra_input)

         return {
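
Because ``validate_marker_params()`` runs from ``model_post_init``, a logically invalid configuration like the one guarded against above now fails when the marker object is built, not later inside ``compute``. A minimal sketch (the parcellation name is a placeholder, and it assumes ``CrossParcellationFC`` stays re-exported from ``junifer.markers``):

    from junifer.markers import CrossParcellationFC

    try:
        CrossParcellationFC(
            parcellation_one="ParcellationA",
            parcellation_two="ParcellationA",  # same as above -> error
        )
    except ValueError as err:
        print(err)  # "The two parcellations must be different."
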
diff --git a/junifer/markers/functional_connectivity/edge_functional_connectivity_maps.py b/junifer/markers/functional_connectivity/edge_functional_connectivity_maps.py
index 9f833a203..196b4267f 100644
--- a/junifer/markers/functional_connectivity/edge_functional_connectivity_maps.py
+++ b/junifer/markers/functional_connectivity/edge_functional_connectivity_maps.py
@@ -3,9 +3,10 @@
 # Authors: Synchon Mandal
 # License: AGPL

-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional

 from ...api.decorators import register_marker
+from ...datagrabber import DataType
 from ..maps_aggregation import MapsAggregation
 from ..utils import _ets
 from .functional_connectivity_base import FunctionalConnectivityBase
@@ -24,23 +25,23 @@ class EdgeCentricFCMaps(FunctionalConnectivityBase):
         The name of the map(s) to use.
         See :func:`.list_data` for options.
     conn_method : str, optional
-        The method to perform connectivity measure using.
+        The connectivity measure to use.
         See :class:`.JuniferConnectivityMeasure` for options
         (default "correlation").
-    conn_method_params : dict, optional
-        Parameters to pass to :class:`.JuniferConnectivityMeasure`.
+    conn_method_params : dict or None, optional
+        The parameters to pass to :class:`.JuniferConnectivityMeasure`.
         If None, ``{"empirical": True}`` will be used, which would mean
         :class:`sklearn.covariance.EmpiricalCovariance` is used to compute
         covariance. If usage of :class:`sklearn.covariance.LedoitWolf` is desired,
         ``{"empirical": False}`` should be passed (default None).
-    masks : str, dict or list of dict or str, optional
+    masks : list of dict or str, or None, optional
         The specification of the masks to apply to regions before extracting
         signals. Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    name : str, optional
-        The name of the marker. If None, will use
-        ``BOLD_EdgeCentricFCParcels`` (default None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).

     References
     ----------
     .. [1] Jo et al. (2021)
            Subject identification using
            edge-centric functional connectivity
            doi: 10.1016/j.neuroimage.2021.118204

     """

-    def __init__(
-        self,
-        maps: str,
-        conn_method: str = "correlation",
-        conn_method_params: Optional[dict] = None,
-        masks: Union[str, dict, list[Union[dict, str]], None] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        self.maps = maps
-        super().__init__(
-            conn_method=conn_method,
-            conn_method_params=conn_method_params,
-            masks=masks,
-            name=name,
-        )
+    maps: str
+    on: list[Literal[DataType.BOLD]] = [DataType.BOLD]  # noqa: RUF012

     def aggregate(
         self, input: dict[str, Any], extra_input: Optional[dict] = None
@@ -99,7 +87,7 @@ def aggregate(
         aggregation = MapsAggregation(
             maps=self.maps,
             masks=self.masks,
-            on="BOLD",
+            on=[DataType.BOLD],
         ).compute(input, extra_input=extra_input)
         # Compute edgewise timeseries
         ets, edge_names = _ets(
diff --git a/junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py b/junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py
index 1599ec5b0..e7e913de5 100644
--- a/junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py
+++ b/junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py
@@ -4,9 +4,10 @@
 #          Synchon Mandal
 # License: AGPL

-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional

 from ...api.decorators import register_marker
+from ...datagrabber import DataType
 from ..parcel_aggregation import ParcelAggregation
 from ..utils import _ets
 from .functional_connectivity_base import FunctionalConnectivityBase
@@ -21,35 +22,34 @@ class EdgeCentricFCParcels(FunctionalConnectivityBase):

     Parameters
     ----------
-    parcellation : str or list of str
+    parcellation : list of str
         The name(s) of the parcellation(s) to use.
         See :func:`.list_data` for options.
     agg_method : str, optional
-        The method to perform aggregation using.
+        The aggregation function to use.
         See :func:`.get_aggfunc_by_name` for options
         (default "mean").
-    agg_method_params : dict, optional
-        Parameters to pass to the aggregation function.
-        See :func:`.get_aggfunc_by_name` for options
-        (default None).
+    agg_method_params : dict or None, optional
+        The parameters to pass to the aggregation function.
+        See :func:`.get_aggfunc_by_name` for options (default None).
     conn_method : str, optional
-        The method to perform connectivity measure using.
+        The connectivity measure to use.
         See :class:`.JuniferConnectivityMeasure` for options
         (default "correlation").
-    conn_method_params : dict, optional
-        Parameters to pass to :class:`.JuniferConnectivityMeasure`.
+    conn_method_params : dict or None, optional
+        The parameters to pass to :class:`.JuniferConnectivityMeasure`.
         If None, ``{"empirical": True}`` will be used, which would mean
         :class:`sklearn.covariance.EmpiricalCovariance` is used to compute
         covariance.
If usage of :class:`sklearn.covariance.LedoitWolf` is desired, ``{"empirical": False}`` should be passed (default None). - masks : str, dict or list of dict or str, optional + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use - ``BOLD_EdgeCentricFCParcels`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). References ---------- @@ -59,25 +59,8 @@ class EdgeCentricFCParcels(FunctionalConnectivityBase): """ - def __init__( - self, - parcellation: Union[str, list[str]], - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - conn_method: str = "correlation", - conn_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.parcellation = parcellation - super().__init__( - agg_method=agg_method, - agg_method_params=agg_method_params, - conn_method=conn_method, - conn_method_params=conn_method_params, - masks=masks, - name=name, - ) + parcellation: list[str] + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def aggregate( self, input: dict[str, Any], extra_input: Optional[dict] = None @@ -114,7 +97,7 @@ def aggregate( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input, extra_input=extra_input) # Compute edgewise timeseries ets, edge_names = _ets( diff --git a/junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py b/junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py index 67595be7e..e0e99e6a3 100644 --- a/junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py +++ b/junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py @@ -4,11 +4,14 @@ # Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union + +from pydantic import PositiveFloat from ...api.decorators import register_marker +from ...datagrabber import DataType from ..sphere_aggregation import SphereAggregation -from ..utils import _ets, raise_error +from ..utils import _ets from .functional_connectivity_base import FunctionalConnectivityBase @@ -24,40 +27,39 @@ class EdgeCentricFCSpheres(FunctionalConnectivityBase): coords : str The name of the coordinates list to use. See :func:`.list_data` for options. - radius : positive float, optional - The radius of the sphere around each coordinates in millimetres. + radius : ``zero`` or positive float or None, optional + The radius of the sphere in millimetres. If None, the signal will be extracted from a single voxel. See :class:`.JuniferNiftiSpheresMasker` for more information (default None). allow_overlap : bool, optional - Whether to allow overlapping spheres. If False, an error is raised if - the spheres overlap (default False). + Whether to allow overlapping spheres. + If False, an error is raised if the spheres overlap (default False). agg_method : str, optional - The method to perform aggregation using. + The aggregation function to use. See :func:`.get_aggfunc_by_name` for options (default "mean"). agg_method_params : dict, optional - Parameters to pass to the aggregation function. 
- See :func:`.get_aggfunc_by_name` for options - (default None). + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). conn_method : str, optional - The method to perform connectivity measure using. + The connectivity measure to use. See :class:`.JuniferConnectivityMeasure` for options (default "correlation"). - conn_method_params : dict, optional - Parameters to pass to :class:`.JuniferConnectivityMeasure`. + conn_method_params : dict or None, optional + The parameters to pass to :class:`.JuniferConnectivityMeasure`. If None, ``{"empirical": True}`` will be used, which would mean :class:`sklearn.covariance.EmpiricalCovariance` is used to compute covariance. If usage of :class:`sklearn.covariance.LedoitWolf` is desired, ``{"empirical": False}`` should be passed (default None). - masks : str, dict or list of dict or str, optional + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use - ``BOLD_EdgeCentricFCSpheres`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). References ---------- @@ -67,31 +69,10 @@ class EdgeCentricFCSpheres(FunctionalConnectivityBase): """ - def __init__( - self, - coords: str, - radius: Optional[float] = None, - allow_overlap: bool = False, - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - conn_method: str = "correlation", - conn_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.coords = coords - self.radius = radius - self.allow_overlap = allow_overlap - if radius is None or radius <= 0: - raise_error(f"radius should be > 0: provided {radius}") - super().__init__( - agg_method=agg_method, - agg_method_params=agg_method_params, - conn_method=conn_method, - conn_method_params=conn_method_params, - masks=masks, - name=name, - ) + coords: str + radius: Optional[Union[Literal[0], PositiveFloat]] = None + allow_overlap: bool = False + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def aggregate( self, input: dict[str, Any], extra_input: Optional[dict] = None @@ -130,7 +111,7 @@ def aggregate( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input, extra_input=extra_input) # Compute edgewise timeseries ets, edge_names = _ets( diff --git a/junifer/markers/functional_connectivity/functional_connectivity_base.py b/junifer/markers/functional_connectivity/functional_connectivity_base.py index 6ab818401..4bcc4b4b3 100644 --- a/junifer/markers/functional_connectivity/functional_connectivity_base.py +++ b/junifer/markers/functional_connectivity/functional_connectivity_base.py @@ -4,7 +4,7 @@ # License: AGPL from abc import abstractmethod -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union from sklearn.covariance import EmpiricalCovariance, LedoitWolf @@ -25,31 +25,30 @@ class FunctionalConnectivityBase(BaseMarker): Parameters ---------- agg_method : str, optional - The method to perform aggregation using. - Check valid options in :func:`.get_aggfunc_by_name` + The aggregation function to use. 
+ See :func:`.get_aggfunc_by_name` for options (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. - Check valid options in :func:`.get_aggfunc_by_name` - (default None). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). conn_method : str, optional - The method to perform connectivity measure using. - Check valid options in :class:`.JuniferConnectivityMeasure` + The connectivity measure to use. + See :class:`.JuniferConnectivityMeasure` for more information (default "correlation"). - conn_method_params : dict, optional - Parameters to pass to :class:`.JuniferConnectivityMeasure`. + conn_method_params : dict or None, optional + The parameters to pass to :class:`.JuniferConnectivityMeasure`. If None, ``{"empirical": True}`` will be used, which would mean :class:`sklearn.covariance.EmpiricalCovariance` is used to compute covariance. If usage of :class:`sklearn.covariance.LedoitWolf` is desired, ``{"empirical": False}`` should be passed (default None). - masks : str, dict or list of dict or str, optional + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use ``BOLD_`` - (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). """ @@ -61,25 +60,20 @@ class FunctionalConnectivityBase(BaseMarker): }, } - def __init__( - self, - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - conn_method: str = "correlation", - conn_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.agg_method = agg_method - self.agg_method_params = agg_method_params - self.conn_method = conn_method - self.conn_method_params = conn_method_params or {} + agg_method: str = "mean" + agg_method_params: Optional[dict] = None + conn_method: str = "correlation" + conn_method_params: Optional[dict] = None + masks: Optional[list[Union[dict, str]]] = None + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 + + def validate_marker_params(self) -> None: + """Run extra logical validation for marker.""" + self.conn_method_params = self.conn_method_params or {} # Reverse of nilearn behavior self.conn_method_params["empirical"] = self.conn_method_params.get( "empirical", True ) - self.masks = masks - super().__init__(on="BOLD", name=name) @abstractmethod def aggregate( @@ -123,7 +117,7 @@ def compute( - ``data`` : functional connectivity matrix as ``numpy.ndarray`` - ``row_names`` : ROI labels as list of str - ``col_names`` : ROI labels as list of str - - ``matrix_kind`` : :obj:`.junifer.storage.MatrixKind` + - ``matrix_kind`` : :enum:`.MatrixKind` """ # Perform necessary aggregation diff --git a/junifer/markers/functional_connectivity/functional_connectivity_maps.py b/junifer/markers/functional_connectivity/functional_connectivity_maps.py index 90889869b..c9ebfcdb0 100644 --- a/junifer/markers/functional_connectivity/functional_connectivity_maps.py +++ b/junifer/markers/functional_connectivity/functional_connectivity_maps.py @@ -3,9 +3,10 @@ # Authors: Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, 
Literal, Optional from ...api.decorators import register_marker +from ...datagrabber import DataType from ..maps_aggregation import MapsAggregation from .functional_connectivity_base import FunctionalConnectivityBase @@ -23,41 +24,28 @@ class FunctionalConnectivityMaps(FunctionalConnectivityBase): The name of the map(s) to use. See :func:`.list_data` for options. conn_method : str, optional - The method to perform connectivity measure using. + The connectivity measure to use. See :class:`.JuniferConnectivityMeasure` for options (default "correlation"). - conn_method_params : dict, optional - Parameters to pass to :class:`.JuniferConnectivityMeasure`. + conn_method_params : dict or None, optional + The parameters to pass to :class:`.JuniferConnectivityMeasure`. If None, ``{"empirical": True}`` will be used, which would mean :class:`sklearn.covariance.EmpiricalCovariance` is used to compute covariance. If usage of :class:`sklearn.covariance.LedoitWolf` is desired, ``{"empirical": False}`` should be passed (default None). - masks : str, dict or list of dict or str, optional + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use - ``BOLD_FunctionalConnectivityMaps`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). """ - def __init__( - self, - maps: str, - conn_method: str = "correlation", - conn_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.maps = maps - super().__init__( - conn_method=conn_method, - conn_method_params=conn_method_params, - masks=masks, - name=name, - ) + maps: str + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def aggregate( self, input: dict[str, Any], extra_input: Optional[dict] = None @@ -91,5 +79,5 @@ def aggregate( return MapsAggregation( maps=self.maps, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input=input, extra_input=extra_input) diff --git a/junifer/markers/functional_connectivity/functional_connectivity_parcels.py b/junifer/markers/functional_connectivity/functional_connectivity_parcels.py index db79293ba..d2de7849c 100644 --- a/junifer/markers/functional_connectivity/functional_connectivity_parcels.py +++ b/junifer/markers/functional_connectivity/functional_connectivity_parcels.py @@ -5,9 +5,10 @@ # Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional from ...api.decorators import register_marker +from ...datagrabber import DataType from ..parcel_aggregation import ParcelAggregation from .functional_connectivity_base import FunctionalConnectivityBase @@ -21,57 +22,39 @@ class FunctionalConnectivityParcels(FunctionalConnectivityBase): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. agg_method : str, optional - The method to perform aggregation using. + The aggregation function to use. See :func:`.get_aggfunc_by_name` for options (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. - See :func:`.get_aggfunc_by_name` for options - (default None). 
+ agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). conn_method : str, optional - The method to perform connectivity measure using. + The connectivity measure to use. See :class:`.JuniferConnectivityMeasure` for options (default "correlation"). - conn_method_params : dict, optional - Parameters to pass to :class:`.JuniferConnectivityMeasure`. + conn_method_params : dict or None, optional + The parameters to pass to :class:`.JuniferConnectivityMeasure`. If None, ``{"empirical": True}`` will be used, which would mean :class:`sklearn.covariance.EmpiricalCovariance` is used to compute covariance. If usage of :class:`sklearn.covariance.LedoitWolf` is desired, ``{"empirical": False}`` should be passed (default None). - masks : str, dict or list of dict or str, optional + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use - ``BOLD_FunctionalConnectivityParcels`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). """ - def __init__( - self, - parcellation: Union[str, list[str]], - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - conn_method: str = "correlation", - conn_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.parcellation = parcellation - super().__init__( - agg_method=agg_method, - agg_method_params=agg_method_params, - conn_method=conn_method, - conn_method_params=conn_method_params, - masks=masks, - name=name, - ) + parcellation: list[str] + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def aggregate( self, input: dict[str, Any], extra_input: Optional[dict] = None @@ -107,5 +90,5 @@ def aggregate( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input=input, extra_input=extra_input) diff --git a/junifer/markers/functional_connectivity/functional_connectivity_spheres.py b/junifer/markers/functional_connectivity/functional_connectivity_spheres.py index 6f976c0b3..ae38e821d 100644 --- a/junifer/markers/functional_connectivity/functional_connectivity_spheres.py +++ b/junifer/markers/functional_connectivity/functional_connectivity_spheres.py @@ -5,11 +5,13 @@ # Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union + +from pydantic import PositiveFloat from ...api.decorators import register_marker +from ...datagrabber import DataType from ..sphere_aggregation import SphereAggregation -from ..utils import raise_error from .functional_connectivity_base import FunctionalConnectivityBase @@ -25,68 +27,46 @@ class FunctionalConnectivitySpheres(FunctionalConnectivityBase): coords : str The name of the coordinates list to use. See :func:`.list_data` for options. - radius : positive float, optional - The radius of the sphere around each coordinates in millimetres. + radius : ``zero`` or positive float or None, optional + The radius of the sphere in millimetres. If None, the signal will be extracted from a single voxel. See :class:`.JuniferNiftiSpheresMasker` for more information (default None). 
allow_overlap : bool, optional - Whether to allow overlapping spheres. If False, an error is raised if - the spheres overlap (default False). + Whether to allow overlapping spheres. + If False, an error is raised if the spheres overlap (default False). agg_method : str, optional - The method to perform aggregation using. + The aggregation function to use. See :func:`.get_aggfunc_by_name` for options (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. - See :func:`.get_aggfunc_by_name` for options - (default None). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). conn_method : str, optional - The method to perform connectivity measure using. + The connectivity measure to use. See :class:`.JuniferConnectivityMeasure` for options (default "correlation"). - conn_method_params : dict, optional - Parameters to pass to :class:`.JuniferConnectivityMeasure`. + conn_method_params : dict or None, optional + The parameters to pass to :class:`.JuniferConnectivityMeasure`. If None, ``{"empirical": True}`` will be used, which would mean :class:`sklearn.covariance.EmpiricalCovariance` is used to compute covariance. If usage of :class:`sklearn.covariance.LedoitWolf` is desired, ``{"empirical": False}`` should be passed (default None). - masks : str, dict or list of dict or str, optional + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use - ``BOLD_FunctionalConnectivitySpheres`` (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
""" - def __init__( - self, - coords: str, - radius: Optional[float] = None, - allow_overlap: bool = False, - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - conn_method: str = "correlation", - conn_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.coords = coords - self.radius = radius - self.allow_overlap = allow_overlap - if radius is None or radius <= 0: - raise_error(f"radius should be > 0: provided {radius}") - super().__init__( - agg_method=agg_method, - agg_method_params=agg_method_params, - conn_method=conn_method, - conn_method_params=conn_method_params, - masks=masks, - name=name, - ) + coords: str + radius: Optional[Union[Literal[0], PositiveFloat]] = None + allow_overlap: bool = False + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def aggregate( self, input: dict[str, Any], extra_input: Optional[dict] = None @@ -124,5 +104,5 @@ def aggregate( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input=input, extra_input=extra_input) diff --git a/junifer/markers/maps_aggregation.py b/junifer/markers/maps_aggregation.py index 540c6428b..21fd39518 100644 --- a/junifer/markers/maps_aggregation.py +++ b/junifer/markers/maps_aggregation.py @@ -3,7 +3,7 @@ # Authors: Synchon Mandal # License: AGPL -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union from nilearn.maskers import NiftiMapsMasker @@ -29,23 +29,26 @@ class MapsAggregation(BaseMarker): maps : str The name of the map(s) to use. See :func:`.list_data` for options. - time_method : str, optional - The method to use to aggregate the time series over the time points, - after aggregation (only applicable to BOLD data). If None, + time_method : str or None, optional + The aggregation function to use for time series after applying + :term:`method` (only applicable to BOLD data). If None, it will not operate on the time dimension (default None). - time_method_params : dict, optional + time_method_params : dict or None, optional The parameters to pass to the time aggregation method (default None). - masks : str, dict or list of dict or str, optional + on : list of {``DataType.T1w``, ``DataType.T2w``, ``DataType.BOLD``, \ + ``DataType.VBM_GM``, ``DataType.VBM_WM``, ``DataType.VBM_CSF``, \ + ``DataType.FALFF``, ``DataType.GCOR``, ``DataType.LCOR``} or None, \ + optional + The data type(s) to apply the marker on. + If None, will work on all available data. + Check :enum:`.DataType` for valid values (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - on : {"T1w", "T2w", "BOLD", "VBM_GM", "VBM_WM", "VBM_CSF", "fALFF", \ - "GCOR", "LCOR"} or list of the options, optional - The data types to apply the marker to. If None, will work on all - available data (default None). - name : str, optional - The name of the marker. If None, will use the class name (default - None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
Raises ------ @@ -87,32 +90,37 @@ class MapsAggregation(BaseMarker): }, } - def __init__( - self, - maps: str, - time_method: Optional[str] = None, - time_method_params: Optional[dict[str, Any]] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - on: Union[list[str], str, None] = None, - name: Optional[str] = None, - ) -> None: - self.maps = maps - self.masks = masks - super().__init__(on=on, name=name) - - # Verify after super init so self._on is set - if "BOLD" not in self._on and time_method is not None: + on: list[ + Literal[ + DataType.T1w, + DataType.T2w, + DataType.BOLD, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.VBM_CSF, + DataType.FALFF, + DataType.GCOR, + DataType.LCOR, + ] + ] + maps: str + time_method: Optional[str] = None + time_method_params: Optional[dict[str, Any]] = None + masks: Optional[list[Union[dict, str]]] = None + + def validate_marker_params(self) -> None: + """Run extra logical validation for marker.""" + # self.on is set already + if "BOLD" not in self.on and self.time_method is not None: raise_error( "`time_method` can only be used with BOLD data. " "Please remove `time_method` parameter." ) - if time_method is None and time_method_params is not None: + if self.time_method is None and self.time_method_params is not None: raise_error( "`time_method_params` can only be used with `time_method`. " "Please remove `time_method_params` parameter." ) - self.time_method = time_method - self.time_method_params = time_method_params or {} def compute( self, input: dict[str, Any], extra_input: Optional[dict] = None diff --git a/junifer/markers/parcel_aggregation.py b/junifer/markers/parcel_aggregation.py index 92fbf7efa..a0cf9c56a 100644 --- a/junifer/markers/parcel_aggregation.py +++ b/junifer/markers/parcel_aggregation.py @@ -4,7 +4,7 @@ # Synchon Mandal # License: AGPL -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union import numpy as np from nilearn.image import math_img @@ -29,32 +29,36 @@ class ParcelAggregation(BaseMarker): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. - method : str - The method to perform aggregation using. Check valid options in - :func:`.get_aggfunc_by_name`. - method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`.get_aggfunc_by_name`. - time_method : str, optional - The method to use to aggregate the time series over the time points, - after applying :term:`method` (only applicable to BOLD data). If None, + method : str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + time_method : str or None, optional + The aggregation function to use for time series after applying + :term:`method` (only applicable to BOLD data). If None, it will not operate on the time dimension (default None). - time_method_params : dict, optional - The parameters to pass to the time aggregation method (default None). - masks : str, dict or list of dict or str, optional + time_method_params : dict or None, optional + The parameters to pass to the time aggregation function (default None). 
+ on : list of {``DataType.T1w``, ``DataType.T2w``, ``DataType.BOLD``, \ + ``DataType.VBM_GM``, ``DataType.VBM_WM``, ``DataType.VBM_CSF``, \ + ``DataType.FALFF``, ``DataType.GCOR``, ``DataType.LCOR``} or None, \ + optional + The data type(s) to apply the marker on. + If None, will work on all available data. + Check :enum:`.DataType` for valid values (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - on : {"T1w", "T2w", "BOLD", "VBM_GM", "VBM_WM", "VBM_CSF", "fALFF", \ - "GCOR", "LCOR"} or list of the options, optional - The data types to apply the marker to. If None, will work on all - available data (default None). - name : str, optional - The name of the marker. If None, will use the class name (default - None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). Raises ------ @@ -96,38 +100,39 @@ class ParcelAggregation(BaseMarker): }, } - def __init__( - self, - parcellation: Union[str, list[str]], - method: str, - method_params: Optional[dict[str, Any]] = None, - time_method: Optional[str] = None, - time_method_params: Optional[dict[str, Any]] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - on: Union[list[str], str, None] = None, - name: Optional[str] = None, - ) -> None: - if not isinstance(parcellation, list): - parcellation = [parcellation] - self.parcellation = parcellation - self.method = method - self.method_params = method_params or {} - self.masks = masks - super().__init__(on=on, name=name) - - # Verify after super init so self._on is set - if "BOLD" not in self._on and time_method is not None: + on: list[ + Literal[ + DataType.T1w, + DataType.T2w, + DataType.BOLD, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.VBM_CSF, + DataType.FALFF, + DataType.GCOR, + DataType.LCOR, + ] + ] + parcellation: list[str] + method: str = "mean" + method_params: Optional[dict[str, Any]] = None + time_method: Optional[str] = None + time_method_params: Optional[dict[str, Any]] = None + masks: Optional[list[Union[dict, str]]] = None + + def validate_marker_params(self) -> None: + """Run extra logical validation for marker.""" + # self.on is set already + if "BOLD" not in self.on and self.time_method is not None: raise_error( "`time_method` can only be used with BOLD data. " "Please remove `time_method` parameter." ) - if time_method is None and time_method_params is not None: + if self.time_method is None and self.time_method_params is not None: raise_error( "`time_method_params` can only be used with `time_method`. " "Please remove `time_method_params` parameter." 
) - self.time_method = time_method - self.time_method_params = time_method_params or {} def compute( self, input: dict[str, Any], extra_input: Optional[dict] = None diff --git a/junifer/markers/reho/reho_base.py b/junifer/markers/reho/reho_base.py index d5c9f4ec5..6b591f376 100644 --- a/junifer/markers/reho/reho_base.py +++ b/junifer/markers/reho/reho_base.py @@ -9,13 +9,15 @@ TYPE_CHECKING, Any, ClassVar, + Literal, Optional, + Union, ) from ...datagrabber import DataType from ...storage import StorageType from ...typing import ConditionalDependencies, MarkerInOutMappings -from ...utils import logger, raise_error +from ...utils import logger from ..base import BaseMarker from ._afni_reho import AFNIReHo from ._junifer_reho import JuniferReHo @@ -46,14 +48,62 @@ class ReHoBase(BaseMarker): Parameters ---------- using : :enum:`.ReHoImpl` - name : str, optional - The name of the marker. If None, it will use the class name - (default None). - - Raises - ------ - ValueError - If ``using`` is invalid. + reho_params : dict or None, optional + Extra parameters for computing ReHo map as a dictionary (default None). + If ``using=ReHoImpl.afni``, then the valid keys are: + + * ``nneigh`` : {7, 19, 27}, optional (default 27) + Number of voxels in the neighbourhood, inclusive. Can be: + + - 7 : for facewise neighbours only + - 19 : for face- and edge-wise neighbours + - 27 : for face-, edge-, and node-wise neighbors + + * ``neigh_rad`` : positive float, optional + The radius of a desired neighbourhood (default None). + * ``neigh_x`` : positive float, optional + The semi-radius for x-axis of ellipsoidal volumes (default None). + * ``neigh_y`` : positive float, optional + The semi-radius for y-axis of ellipsoidal volumes (default None). + * ``neigh_z`` : positive float, optional + The semi-radius for z-axis of ellipsoidal volumes (default None). + * ``box_rad`` : positive int, optional + The number of voxels outward in a given cardinal direction for a + cubic box centered on a given voxel (default None). + * ``box_x`` : positive int, optional + The number of voxels for +/- x-axis of cuboidal volumes + (default None). + * ``box_y`` : positive int, optional + The number of voxels for +/- y-axis of cuboidal volumes + (default None). + * ``box_z`` : positive int, optional + The number of voxels for +/- z-axis of cuboidal volumes + (default None). + + else if ``using=ReHoImpl.junifer``, then the valid keys are: + + * ``nneigh`` : {7, 19, 27, 125}, optional (default 27) + Number of voxels in the neighbourhood, inclusive. Can be: + + * 7 : for facewise neighbours only + * 19 : for face- and edge-wise neighbours + * 27 : for face-, edge-, and node-wise neighbors + * 125 : for 5x5 cuboidal volume + + agg_method : str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional + The specification of the masks to apply to regions before extracting + signals. Check :ref:`Using Masks ` for more details. + If None, will not apply any mask (default None). + name : str or None, optional + The name of the marker. + If None, it will use the class name (default None). 
""" @@ -74,19 +124,12 @@ class ReHoBase(BaseMarker): }, } - def __init__( - self, - using: str, - name: Optional[str] = None, - ) -> None: - # Validate `using` parameter - valid_using = [dep["using"] for dep in self._CONDITIONAL_DEPENDENCIES] - if using not in valid_using: - raise_error( - f"Invalid value for `using`, should be one of: {valid_using}" - ) - self.using = using - super().__init__(on="BOLD", name=name) + using: ReHoImpl + reho_params: Optional[dict] = None + agg_method: str = "mean" + agg_method_params: Optional[dict] = None + masks: Optional[list[Union[dict, str]]] = None + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def _compute( self, diff --git a/junifer/markers/reho/reho_maps.py b/junifer/markers/reho/reho_maps.py index a9822b6dd..7ce9b6ffa 100644 --- a/junifer/markers/reho/reho_maps.py +++ b/junifer/markers/reho/reho_maps.py @@ -3,11 +3,12 @@ # Authors: Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional import numpy as np from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger from ..maps_aggregation import MapsAggregation from .reho_base import ReHoBase @@ -25,8 +26,8 @@ class ReHoMaps(ReHoBase): maps : str The name of the map(s) to use. See :func:`.list_data` for options. - reho_params : dict, optional using : :enum:`.ReHoImpl` + reho_params : dict or None, optional Extra parameters for computing ReHo map as a dictionary (default None). If ``using=ReHoImpl.afni``, then the valid keys are: @@ -68,29 +69,18 @@ class ReHoMaps(ReHoBase): * 27 : for face-, edge-, and node-wise neighbors * 125 : for 5x5 cuboidal volume - masks : str, dict or list of dict or str, optional + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, it will use the class name - (default None). + name : str or None, optional + The name of the marker. + If None, it will use the class name (default None). 
""" - def __init__( - self, - maps: str, - using: str, - reho_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - # Superclass init first to validate `using` parameter - super().__init__(using=using, name=name) - self.maps = maps - self.reho_params = reho_params - self.masks = masks + maps: str + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def compute( self, @@ -139,7 +129,7 @@ def compute( maps_aggregation = MapsAggregation( maps=self.maps, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute( input=aggregation_input, extra_input=extra_input, diff --git a/junifer/markers/reho/reho_parcels.py b/junifer/markers/reho/reho_parcels.py index 2cce611e2..d13b6eaad 100644 --- a/junifer/markers/reho/reho_parcels.py +++ b/junifer/markers/reho/reho_parcels.py @@ -3,11 +3,12 @@ # Authors: Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional import numpy as np from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger from ..parcel_aggregation import ParcelAggregation from .reho_base import ReHoBase @@ -22,11 +23,11 @@ class ReHoParcels(ReHoBase): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. - reho_params : dict, optional using : :enum:`.ReHoImpl` + reho_params : dict or None, optional Extra parameters for computing ReHo map as a dictionary (default None). If ``using=ReHoImpl.afni``, then the valid keys are: @@ -69,38 +70,24 @@ class ReHoParcels(ReHoBase): * 125 : for 5x5 cuboidal volume agg_method : str, optional - The method to perform aggregation using. Check valid options in - :func:`.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`.get_aggfunc_by_name` (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, it will use the class name - (default None). + name : str or None, optional + The name of the marker. + If None, it will use the class name (default None). 
""" - def __init__( - self, - parcellation: Union[str, list[str]], - using: str, - reho_params: Optional[dict] = None, - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - # Superclass init first to validate `using` parameter - super().__init__(using=using, name=name) - self.parcellation = parcellation - self.reho_params = reho_params - self.agg_method = agg_method - self.agg_method_params = agg_method_params - self.masks = masks + parcellation: list[str] + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def compute( self, @@ -151,7 +138,7 @@ def compute( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute( input=aggregation_input, extra_input=extra_input, diff --git a/junifer/markers/reho/reho_spheres.py b/junifer/markers/reho/reho_spheres.py index 649293887..83aa4e54e 100644 --- a/junifer/markers/reho/reho_spheres.py +++ b/junifer/markers/reho/reho_spheres.py @@ -3,11 +3,13 @@ # Authors: Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union import numpy as np +from pydantic import PositiveFloat from ...api.decorators import register_marker +from ...datagrabber import DataType from ...utils import logger from ..sphere_aggregation import SphereAggregation from .reho_base import ReHoBase @@ -25,19 +27,16 @@ class ReHoSpheres(ReHoBase): coords : str The name of the coordinates list to use. See :func:`.list_data` for options. - radius : float, optional - The radius of the sphere in millimeters. If None, the signal will be - extracted from a single voxel. See - :class:`nilearn.maskers.NiftiSpheresMasker` for more information using : :enum:`.ReHoImpl` + radius : ``zero`` or positive float or None, optional + The radius of the sphere in millimetres. + If None, the signal will be extracted from a single voxel. + See :class:`.JuniferNiftiSpheresMasker` for more information (default None). allow_overlap : bool, optional - Whether to allow overlapping spheres. If False, an error is raised if - the spheres overlap (default is False). - use_afni : bool, optional - Whether to use AFNI for computing. If None, will use AFNI only - if available (default None). - reho_params : dict, optional + Whether to allow overlapping spheres. + If False, an error is raised if the spheres overlap (default False). + reho_params : dict or None, optional Extra parameters for computing ReHo map as a dictionary (default None). If ``using=ReHoImpl.afni``, then the valid keys are: @@ -80,42 +79,26 @@ class ReHoSpheres(ReHoBase): * 125 : for 5x5 cuboidal volume agg_method : str, optional - The aggregation method to use. - See :func:`.get_aggfunc_by_name` for more information - (default None). - agg_method_params : dict, optional - The parameters to pass to the aggregation method (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). 
- name : str, optional - The name of the marker. If None, it will use the class name - (default None). + name : str or None, optional + The name of the marker. + If None, it will use the class name (default None). """ - def __init__( - self, - coords: str, - using: str, - radius: Optional[float] = None, - allow_overlap: bool = False, - reho_params: Optional[dict] = None, - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - # Superclass init first to validate `using` parameter - super().__init__(using=using, name=name) - self.coords = coords - self.radius = radius - self.allow_overlap = allow_overlap - self.reho_params = reho_params - self.agg_method = agg_method - self.agg_method_params = agg_method_params - self.masks = masks + coords: str + radius: Optional[Union[Literal[0], PositiveFloat]] = None + allow_overlap: bool = False + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def compute( self, @@ -168,7 +151,7 @@ def compute( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input=aggregation_input, extra_input=extra_input) return { diff --git a/junifer/markers/sphere_aggregation.py b/junifer/markers/sphere_aggregation.py index b225fab2e..485292126 100644 --- a/junifer/markers/sphere_aggregation.py +++ b/junifer/markers/sphere_aggregation.py @@ -4,7 +4,9 @@ # Synchon Mandal # License: AGPL -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union + +from pydantic import PositiveFloat from ..api.decorators import register_marker from ..data import get_data @@ -29,37 +31,41 @@ class SphereAggregation(BaseMarker): coords : str The name of the coordinates list to use. See :func:`.list_data` for options. - radius : float, optional - The radius of the sphere in millimeters. If None, the signal will be - extracted from a single voxel. See - :class:`nilearn.maskers.NiftiSpheresMasker` for more information + radius : ``zero`` or positive float or None, optional + The radius of the sphere in millimeters. + If None, the signal will be extracted from a single voxel. + See :class:`.JuniferNiftiSpheresMasker` for more information (default None). allow_overlap : bool, optional - Whether to allow overlapping spheres. If False, an error is raised if - the spheres overlap (default is False). + Whether to allow overlapping spheres. + If False, an error is raised if the spheres overlap (default False). method : str, optional - The aggregation method to use. - See :func:`.get_aggfunc_by_name` for more information + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options (default "mean"). - method_params : dict, optional - The parameters to pass to the aggregation method (default None). - time_method : str, optional - The method to use to aggregate the time series over the time points, - after applying :term:`method` (only applicable to BOLD data). If None, + method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + time_method : str or None, optional + The aggregation function to use for time series after applying + :term:`method` (only applicable to BOLD data). If None, it will not operate on the time dimension (default None). 
-    time_method_params : dict, optional
-        The parameters to pass to the time aggregation method (default None).
-    masks : str, dict or list of dict or str, optional
+    time_method_params : dict or None, optional
+        The parameters to pass to the time aggregation function (default None).
+    on : list of {``DataType.T1w``, ``DataType.T2w``, ``DataType.BOLD``, \
+        ``DataType.VBM_GM``, ``DataType.VBM_WM``, ``DataType.VBM_CSF``, \
+        ``DataType.FALFF``, ``DataType.GCOR``, ``DataType.LCOR``} or None, \
+        optional
+        The data type(s) to apply the marker on.
+        If None, will work on all available data.
+        Check :enum:`.DataType` for valid values (default None).
+    masks : list of dict or str, or None, optional
         The specification of the masks to apply to regions before extracting
         signals. Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    on : {"T1w", "T2w", "BOLD", "VBM_GM", "VBM_WM", "VBM_CSF", "fALFF", \
-        "GCOR", "LCOR"} or list of the options, optional
-        The data types to apply the marker to. If None, will work on all
-        available data (default None).
-    name : str, optional
-        The name of the marker. By default, it will use KIND_SphereAggregation
-        where KIND is the kind of data it was applied to (default None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).

     Raises
     ------
@@ -101,40 +107,41 @@ class SphereAggregation(BaseMarker):
         },
     }

-    def __init__(
-        self,
-        coords: str,
-        radius: Optional[float] = None,
-        allow_overlap: bool = False,
-        method: str = "mean",
-        method_params: Optional[dict[str, Any]] = None,
-        time_method: Optional[str] = None,
-        time_method_params: Optional[dict[str, Any]] = None,
-        masks: Union[str, dict, list[Union[dict, str]], None] = None,
-        on: Union[list[str], str, None] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        self.coords = coords
-        self.radius = radius
-        self.allow_overlap = allow_overlap
-        self.method = method
-        self.method_params = method_params or {}
-        self.masks = masks
-        super().__init__(on=on, name=name)
-
-        # Verify after super init so self._on is set
-        if "BOLD" not in self._on and time_method is not None:
+    on: list[
+        Literal[
+            DataType.T1w,
+            DataType.T2w,
+            DataType.BOLD,
+            DataType.VBM_GM,
+            DataType.VBM_WM,
+            DataType.VBM_CSF,
+            DataType.FALFF,
+            DataType.GCOR,
+            DataType.LCOR,
+        ]
+    ]
+    coords: str
+    radius: Optional[Union[Literal[0], PositiveFloat]] = None
+    allow_overlap: bool = False
+    method: str = "mean"
+    method_params: Optional[dict[str, Any]] = None
+    time_method: Optional[str] = None
+    time_method_params: Optional[dict[str, Any]] = None
+    masks: Optional[list[Union[dict, str]]] = None
+
+    def validate_marker_params(self) -> None:
+        """Run extra logical validation for marker."""
+        # self.on is set already
+        if "BOLD" not in self.on and self.time_method is not None:
             raise_error(
                 "`time_method` can only be used with BOLD data. "
                 "Please remove `time_method` parameter."
             )
-        if time_method is None and time_method_params is not None:
+        if self.time_method is None and self.time_method_params is not None:
             raise_error(
                 "`time_method_params` can only be used with `time_method`. "
                 "Please remove `time_method_params` parameter."
             )
-        self.time_method = time_method
-        self.time_method_params = time_method_params or {}

     def compute(
         self,
diff --git a/junifer/markers/temporal_snr/temporal_snr_base.py b/junifer/markers/temporal_snr/temporal_snr_base.py
index b4f3c3aa9..2be806be0 100644
--- a/junifer/markers/temporal_snr/temporal_snr_base.py
+++ b/junifer/markers/temporal_snr/temporal_snr_base.py
@@ -5,7 +5,7 @@
 # License: AGPL

 from abc import abstractmethod
-from typing import Any, ClassVar, Optional, Union
+from typing import Any, ClassVar, Literal, Optional, Union

 from nilearn import image as nimg

@@ -25,18 +25,19 @@ class TemporalSNRBase(BaseMarker):

     Parameters
     ----------
     agg_method : str, optional
-        The method to perform aggregation using. Check valid options in
-        :func:`.get_aggfunc_by_name` (default "mean").
-    agg_method_params : dict, optional
-        Parameters to pass to the aggregation function. Check valid options in
-        :func:`.get_aggfunc_by_name` (default None).
-    masks : str, dict or list of dict or str, optional
+        The aggregation function to use.
+        See :func:`.get_aggfunc_by_name` for options
+        (default "mean").
+    agg_method_params : dict or None, optional
+        The parameters to pass to the aggregation function.
+        See :func:`.get_aggfunc_by_name` for options (default None).
+    masks : list of dict or str, or None, optional
         The specification of the masks to apply to regions before extracting
         signals. Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    name : str, optional
-        The name of the marker. If None, will use the class name (default
-        None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).

     """

@@ -48,17 +49,10 @@ class TemporalSNRBase(BaseMarker):
         },
     }

-    def __init__(
-        self,
-        agg_method: str = "mean",
-        agg_method_params: Optional[dict] = None,
-        masks: Union[str, dict, list[Union[dict, str]], None] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        self.agg_method = agg_method
-        self.agg_method_params = agg_method_params
-        self.masks = masks
-        super().__init__(on="BOLD", name=name)
+    agg_method: str = "mean"
+    agg_method_params: Optional[dict] = None
+    masks: Optional[list[Union[dict, str]]] = None
+    on: list[Literal[DataType.BOLD]] = [DataType.BOLD]  # noqa: RUF012

     @abstractmethod
     def aggregate(
diff --git a/junifer/markers/temporal_snr/temporal_snr_maps.py b/junifer/markers/temporal_snr/temporal_snr_maps.py
index e5408c50e..08b932f6f 100644
--- a/junifer/markers/temporal_snr/temporal_snr_maps.py
+++ b/junifer/markers/temporal_snr/temporal_snr_maps.py
@@ -3,9 +3,10 @@
 # Authors: Synchon Mandal
 # License: AGPL

-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional

 from ...api.decorators import register_marker
+from ...datagrabber import DataType
 from ..maps_aggregation import MapsAggregation
 from .temporal_snr_base import TemporalSNRBase

@@ -22,27 +23,18 @@ class TemporalSNRMaps(TemporalSNRBase):
     maps : str
         The name of the map(s) to use.
         See :func:`.list_data` for options.
-    masks : str, dict or list of dict or str, optional
+    masks : list of dict or str, or None, optional
         The specification of the masks to apply to regions before extracting
         signals. Check :ref:`Using Masks ` for more details.
         If None, will not apply any mask (default None).
-    name : str, optional
-        The name of the marker. If None, will use the class name (default
-        None).
+    name : str or None, optional
+        The name of the marker.
+        If None, will use the class name (default None).
""" - def __init__( - self, - maps: str, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.maps = maps - super().__init__( - masks=masks, - name=name, - ) + maps: str + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def aggregate( self, input: dict[str, Any], extra_input: Optional[dict] = None @@ -76,5 +68,5 @@ def aggregate( return MapsAggregation( maps=self.maps, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input=input, extra_input=extra_input) diff --git a/junifer/markers/temporal_snr/temporal_snr_parcels.py b/junifer/markers/temporal_snr/temporal_snr_parcels.py index ef7dac8bf..7c3b29d6e 100644 --- a/junifer/markers/temporal_snr/temporal_snr_parcels.py +++ b/junifer/markers/temporal_snr/temporal_snr_parcels.py @@ -1,11 +1,13 @@ """Provide class for temporal SNR using parcels.""" # Authors: Leonard Sasse +# Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional from ...api.decorators import register_marker +from ...datagrabber import DataType from ..parcel_aggregation import ParcelAggregation from .temporal_snr_base import TemporalSNRBase @@ -19,40 +21,28 @@ class TemporalSNRParcels(TemporalSNRBase): Parameters ---------- - parcellation : str or list of str + parcellation : list of str The name(s) of the parcellation(s) to use. See :func:`.list_data` for options. agg_method : str, optional - The method to perform aggregation using. Check valid options in - :func:`.get_aggfunc_by_name` (default "mean"). - agg_method_params : dict, optional - Parameters to pass to the aggregation function. Check valid options in - :func:`.get_aggfunc_by_name` (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. If None, will use the class name (default - None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
""" - def __init__( - self, - parcellation: Union[str, list[str]], - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.parcellation = parcellation - super().__init__( - agg_method=agg_method, - agg_method_params=agg_method_params, - masks=masks, - name=name, - ) + parcellation: list[str] + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def aggregate( self, input: dict[str, Any], extra_input: Optional[dict] = None @@ -88,5 +78,5 @@ def aggregate( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input=input, extra_input=extra_input) diff --git a/junifer/markers/temporal_snr/temporal_snr_spheres.py b/junifer/markers/temporal_snr/temporal_snr_spheres.py index f24038c41..486ffb0bd 100644 --- a/junifer/markers/temporal_snr/temporal_snr_spheres.py +++ b/junifer/markers/temporal_snr/temporal_snr_spheres.py @@ -1,13 +1,16 @@ """Provide class for temporal SNR using spheres.""" # Authors: Leonard Sasse +# Synchon Mandal # License: AGPL -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union + +from pydantic import PositiveFloat from ...api.decorators import register_marker +from ...datagrabber import DataType from ..sphere_aggregation import SphereAggregation -from ..utils import raise_error from .temporal_snr_base import TemporalSNRBase @@ -23,51 +26,35 @@ class TemporalSNRSpheres(TemporalSNRBase): coords : str The name of the coordinates list to use. See :func:`.list_data` for options. - radius : float, optional - The radius of the sphere in mm. If None, the signal will be extracted - from a single voxel. See :class:`nilearn.maskers.NiftiSpheresMasker` - for more information (default None). + radius : ``zero`` or positive float or None, optional + The radius of the sphere in millimeters. + If None, the signal will be extracted from a single voxel. + See :class:`.JuniferNiftiSpheresMasker` for more information + (default None). allow_overlap : bool, optional - Whether to allow overlapping spheres. If False, an error is raised if - the spheres overlap (default is False). + Whether to allow overlapping spheres. + If False, an error is raised if the spheres overlap (default is False). agg_method : str, optional - The aggregation method to use. - See :func:`.get_aggfunc_by_name` for more information - (default None). - agg_method_params : dict, optional - The parameters to pass to the aggregation method (default None). - masks : str, dict or list of dict or str, optional + The aggregation function to use. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict or None, optional + The parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options (default None). + masks : list of dict or str, or None, optional The specification of the masks to apply to regions before extracting signals. Check :ref:`Using Masks ` for more details. If None, will not apply any mask (default None). - name : str, optional - The name of the marker. By default, it will use - KIND_FunctionalConnectivitySpheres where KIND is the kind of data it - was applied to (default None). + name : str or None, optional + The name of the marker. + If None, will use the class name (default None). 
""" - def __init__( - self, - coords: str, - radius: Optional[float] = None, - allow_overlap: bool = False, - agg_method: str = "mean", - agg_method_params: Optional[dict] = None, - masks: Union[str, dict, list[Union[dict, str]], None] = None, - name: Optional[str] = None, - ) -> None: - self.coords = coords - self.radius = radius - self.allow_overlap = allow_overlap - if radius is None or radius <= 0: - raise_error(f"radius should be > 0: provided {radius}") - super().__init__( - agg_method=agg_method, - agg_method_params=agg_method_params, - masks=masks, - name=name, - ) + coords: str + radius: Optional[Union[Literal[0], PositiveFloat]] = None + allow_overlap: bool = False + on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 def aggregate( self, input: dict[str, Any], extra_input: Optional[dict] = None @@ -105,5 +92,5 @@ def aggregate( method=self.agg_method, method_params=self.agg_method_params, masks=self.masks, - on="BOLD", + on=[DataType.BOLD], ).compute(input=input, extra_input=extra_input) From 3eae988e8e6b9b8340e8471289dca8b1c088bcb3 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Wed, 12 Nov 2025 19:25:25 +0100 Subject: [PATCH 31/99] refactor: make datareader pydantic model --- junifer/datareader/default.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/junifer/datareader/default.py b/junifer/datareader/default.py index 8859a0faf..608747514 100644 --- a/junifer/datareader/default.py +++ b/junifer/datareader/default.py @@ -9,6 +9,7 @@ import nibabel as nib import pandas as pd +from pydantic import BaseModel from ..api.decorators import register_datareader from ..pipeline import PipelineStepMixin, UpdateMetaMixin @@ -34,7 +35,7 @@ @register_datareader -class DefaultDataReader(PipelineStepMixin, UpdateMetaMixin): +class DefaultDataReader(BaseModel, PipelineStepMixin, UpdateMetaMixin): """Concrete implementation for common data reading.""" def validate_input(self, input: list[str]) -> list[str]: From d459d547038827b7f8a18c0ca0296b78e987e478 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 12:57:56 +0100 Subject: [PATCH 32/99] chore: improve types --- junifer/typing/_typing.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/junifer/typing/_typing.py b/junifer/typing/_typing.py index b4fede74b..a53b3181a 100644 --- a/junifer/typing/_typing.py +++ b/junifer/typing/_typing.py @@ -3,9 +3,10 @@ # Authors: Synchon Mandal # License: AGPL -from collections.abc import MutableMapping, Sequence +from collections.abc import Sequence from typing import ( TYPE_CHECKING, + TypedDict, Union, ) @@ -15,6 +16,7 @@ from ..datagrabber import BaseDataGrabber, DataType from ..datareader import DefaultDataReader from ..markers import BaseMarker + from ..pipeline import BaseDataDumpAsset, ExtDep from ..preprocess import BasePreprocessor from ..storage import BaseFeatureStorage, StorageType @@ -38,7 +40,17 @@ ] -DataDumpAssetLike = type["DataDumpAssetLike"] +class ExternalDependency(TypedDict): + name: "ExtDep" + commands: list[str] + + +class ConditionalDependency(TypedDict): + using: object + depends_on: list[object] + + +DataDumpAssetLike = type["BaseDataDumpAsset"] DataRegistryLike = type["BasePipelineDataRegistry"] DataGrabberLike = type["BaseDataGrabber"] PreprocessorLike = type["BasePreprocessor"] @@ -52,20 +64,17 @@ "StorageLike", ] Dependencies = set[str] -ConditionalDependencies = Sequence[ - MutableMapping[ - str, - Union[ - str, - PipelineComponent, - Sequence[str], - 
Sequence[PipelineComponent], - ], - ] +ConditionalDependencies = list[ConditionalDependency] +ExternalDependencies = list[ExternalDependency] MarkerInOutMappings = dict["DataType", dict[str, "StorageType"]] +DataGrabberPatterns = dict[ + str, + Union[ + dict[str, Union[str, dict[str, str], list[dict[str, str]]]], + list[dict[str, str]], + ], ] -ExternalDependencies = Sequence[MutableMapping[str, Union[str, Sequence[str]]]] -DataGrabberPatterns = dict[str, Union[dict[str, str], list[dict[str, str]]]] + ConfigVal = Union[bool, int, float, str] Element = Union[str, tuple[str, ...]] Elements = Sequence[Element] From a394074497c70ae498978f83bdc19041788a47ba Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 12:59:16 +0100 Subject: [PATCH 33/99] refactor: make pipeline.utils.check_ext_dependencies validate via pydantic --- junifer/pipeline/utils.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/junifer/pipeline/utils.py b/junifer/pipeline/utils.py index c943ae53a..b16fde01c 100644 --- a/junifer/pipeline/utils.py +++ b/junifer/pipeline/utils.py @@ -8,7 +8,9 @@ from enum import Enum from typing import Any, Optional -from junifer.utils.logging import raise_error, warn_with_log +from pydantic import validate_call + +from ..utils.logging import raise_error, warn_with_log __all__ = ["ExtDep", "check_ext_dependencies"] @@ -23,14 +25,15 @@ class ExtDep(str, Enum): FreeSurfer = "freesurfer" +@validate_call def check_ext_dependencies( - name: str, optional: bool = False, **kwargs: Any + name: ExtDep, optional: bool = False, **kwargs: Any ) -> bool: """Check if external dependency `name` is found if mandatory. Parameters ---------- - name : str + name : :enum:`.ExtDep` The name of the dependency. optional : bool, optional Whether the dependency is optional (default False). @@ -44,18 +47,10 @@ def check_ext_dependencies( Raises ------ - ValueError - If ``name`` is invalid. RuntimeError If ``name`` is mandatory and is not found. 
""" - valid_ext_dependencies = ("afni", "fsl", "ants", "freesurfer") - if name not in valid_ext_dependencies: - raise_error( - "Invalid value for `name`, should be one of: " - f"{valid_ext_dependencies}" - ) # Check for afni if name == "afni": found = _check_afni(**kwargs) From 7b675bca3837c8d1ea869fb96e03ed1a825251af Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 13:02:50 +0100 Subject: [PATCH 34/99] update: improve conditional deps declaration for markers --- junifer/markers/falff/falff_base.py | 4 ++-- junifer/markers/reho/reho_base.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/junifer/markers/falff/falff_base.py b/junifer/markers/falff/falff_base.py index 0e6a0c8d7..41ad0c4b6 100644 --- a/junifer/markers/falff/falff_base.py +++ b/junifer/markers/falff/falff_base.py @@ -87,12 +87,12 @@ class ALFFBase(BaseMarker): _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { - "depends_on": AFNIALFF, "using": ALFFImpl.afni, + "depends_on": [AFNIALFF], }, { - "depends_on": JuniferALFF, "using": ALFFImpl.junifer, + "depends_on": [JuniferALFF], }, ] diff --git a/junifer/markers/reho/reho_base.py b/junifer/markers/reho/reho_base.py index 6b591f376..76d696b88 100644 --- a/junifer/markers/reho/reho_base.py +++ b/junifer/markers/reho/reho_base.py @@ -109,12 +109,12 @@ class ReHoBase(BaseMarker): _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { - "depends_on": AFNIReHo, "using": ReHoImpl.afni, + "depends_on": [AFNIReHo], }, { - "depends_on": JuniferReHo, "using": ReHoImpl.junifer, + "depends_on": [JuniferReHo], }, ] From 75a1c91f5f5900d748e059b99ff938bd0531bb8e Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 13:08:05 +0100 Subject: [PATCH 35/99] update: introduce AOMICSpace and AOMICTask --- junifer/datagrabber/__init__.pyi | 10 +++++++++- junifer/datagrabber/aomic/__init__.pyi | 9 ++++++++- junifer/datagrabber/aomic/_types.py | 22 ++++++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 junifer/datagrabber/aomic/_types.py diff --git a/junifer/datagrabber/__init__.pyi b/junifer/datagrabber/__init__.pyi index 9ac3524d6..1380510ee 100644 --- a/junifer/datagrabber/__init__.pyi +++ b/junifer/datagrabber/__init__.pyi @@ -3,6 +3,8 @@ __all__ = [ "DataladDataGrabber", "PatternDataGrabber", "PatternDataladDataGrabber", + "AOMICSpace", + "AOMICTask", "DataladAOMICID1000", "DataladAOMICPIOP1", "DataladAOMICPIOP2", @@ -25,8 +27,14 @@ from .datalad_base import DataladDataGrabber from .pattern import PatternDataGrabber, ConfoundsFormat from .pattern_datalad import PatternDataladDataGrabber -from .aomic import DataladAOMICID1000, DataladAOMICPIOP1, DataladAOMICPIOP2 from .hcp1200 import HCP1200, DataladHCP1200 +from .aomic import ( + AOMICSpace, + AOMICTask, + DataladAOMICID1000, + DataladAOMICPIOP1, + DataladAOMICPIOP2, +) from .multiple import MultipleDataGrabber from .dmcc13_benchmark import DMCC13Benchmark diff --git a/junifer/datagrabber/aomic/__init__.pyi b/junifer/datagrabber/aomic/__init__.pyi index 011cb067e..aa279d1c1 100644 --- a/junifer/datagrabber/aomic/__init__.pyi +++ b/junifer/datagrabber/aomic/__init__.pyi @@ -1,5 +1,12 @@ -__all__ = ["DataladAOMICID1000", "DataladAOMICPIOP1", "DataladAOMICPIOP2"] +__all__ = [ + "AOMICSpace", + "AOMICTask", + "DataladAOMICID1000", + "DataladAOMICPIOP1", + "DataladAOMICPIOP2", +] +from ._types import AOMICSpace, AOMICTask from .id1000 import DataladAOMICID1000 from .piop1 import DataladAOMICPIOP1 from .piop2 import DataladAOMICPIOP2 
diff --git a/junifer/datagrabber/aomic/_types.py b/junifer/datagrabber/aomic/_types.py new file mode 100644 index 000000000..39a955dbe --- /dev/null +++ b/junifer/datagrabber/aomic/_types.py @@ -0,0 +1,22 @@ +"""Provide common types for AOMIC DataGrabbers.""" + +from enum import Enum + + +class AOMICSpace(str, Enum): + """Accepted spaces for AOMIC.""" + + Native = "native" + MNI152NLin2009cAsym = "MNI152NLin2009cAsym" + + +class AOMICTask(str, Enum): + """Accepted tasks for AOMIC.""" + + RestingState = "restingstate" + Anticipation = "anticipation" + EmoMatching = "emomatching" + Faces = "faces" + Gstroop = "gstroop" + WorkingMemory = "workingmemory" + StopSignal = "stopsignal" From c3dc983cbdc49c860d52688c4a0622688e46ead2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 13:09:32 +0100 Subject: [PATCH 36/99] update: introduce HCP1200Task and HCP1200PhaseEncoding --- junifer/datagrabber/__init__.pyi | 9 ++++++++- junifer/datagrabber/hcp1200/__init__.pyi | 9 +++++++-- junifer/datagrabber/hcp1200/hcp1200.py | 24 +++++++++++++++++++++++- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/junifer/datagrabber/__init__.pyi b/junifer/datagrabber/__init__.pyi index 1380510ee..78f299838 100644 --- a/junifer/datagrabber/__init__.pyi +++ b/junifer/datagrabber/__init__.pyi @@ -9,6 +9,8 @@ __all__ = [ "DataladAOMICPIOP1", "DataladAOMICPIOP2", "HCP1200", + "HCP1200Task", + "HCP1200PhaseEncoding", "DataladHCP1200", "MultipleDataGrabber", "DMCC13Benchmark", @@ -27,7 +29,6 @@ from .datalad_base import DataladDataGrabber from .pattern import PatternDataGrabber, ConfoundsFormat from .pattern_datalad import PatternDataladDataGrabber -from .hcp1200 import HCP1200, DataladHCP1200 from .aomic import ( AOMICSpace, AOMICTask, @@ -35,6 +36,12 @@ from .aomic import ( DataladAOMICPIOP1, DataladAOMICPIOP2, ) +from .hcp1200 import ( + HCP1200, + HCP1200Task, + HCP1200PhaseEncoding, + DataladHCP1200, +) from .multiple import MultipleDataGrabber from .dmcc13_benchmark import DMCC13Benchmark diff --git a/junifer/datagrabber/hcp1200/__init__.pyi b/junifer/datagrabber/hcp1200/__init__.pyi index a29441de1..9426b4c48 100644 --- a/junifer/datagrabber/hcp1200/__init__.pyi +++ b/junifer/datagrabber/hcp1200/__init__.pyi @@ -1,4 +1,9 @@ -__all__ = ["HCP1200", "DataladHCP1200"] +__all__ = [ + "HCP1200", + "HCP1200Task", + "HCP1200PhaseEncoding", + "DataladHCP1200", +] -from .hcp1200 import HCP1200 +from .hcp1200 import HCP1200, HCP1200Task, HCP1200PhaseEncoding from .datalad_hcp1200 import DataladHCP1200 diff --git a/junifer/datagrabber/hcp1200/hcp1200.py b/junifer/datagrabber/hcp1200/hcp1200.py index 5187d4e37..049623d43 100644 --- a/junifer/datagrabber/hcp1200/hcp1200.py +++ b/junifer/datagrabber/hcp1200/hcp1200.py @@ -5,6 +5,7 @@ # Synchon Mandal # License: AGPL +from enum import Enum from itertools import product from pathlib import Path from typing import Union @@ -14,7 +15,28 @@ from ..pattern import PatternDataGrabber -__all__ = ["HCP1200"] +__all__ = ["HCP1200", "HCP1200PhaseEncoding", "HCP1200Task"] + + +class HCP1200Task(str, Enum): + """Accepted HCP1200 tasks.""" + + REST1 = "REST1" + REST2 = "REST2" + SOCIAL = "SOCIAL" + WM = "WM" + RELATIONAL = "RELATIONAL" + EMOTION = "EMOTION" + LANGUAGE = "LANGUAGE" + GAMBLING = "GAMBLING" + MOTOR = "MOTOR" + + +class HCP1200PhaseEncoding(str, Enum): + """Accepted HCP1200 phase encoding directions.""" + + LR = "LR" + RL = "RL" @register_datagrabber From 34738f9e2db877e61ff7eeb816b3833c7b2fe5cc Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 
Nov 2025 13:12:08 +0100 Subject: [PATCH 37/99] update: introduce DMCCSession, DMCCTask, DMCCPhaseEncoding and DMCCRun --- junifer/datagrabber/__init__.pyi | 12 +++++++- junifer/datagrabber/dmcc13_benchmark.py | 41 ++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/junifer/datagrabber/__init__.pyi b/junifer/datagrabber/__init__.pyi index 78f299838..6864a2d21 100644 --- a/junifer/datagrabber/__init__.pyi +++ b/junifer/datagrabber/__init__.pyi @@ -14,6 +14,10 @@ __all__ = [ "DataladHCP1200", "MultipleDataGrabber", "DMCC13Benchmark", + "DMCCSession", + "DMCCTask", + "DMCCPhaseEncoding", + "DMCCRun", "DataTypeManager", "DataTypeSchema", "OptionalTypeSchema", @@ -43,7 +47,13 @@ from .hcp1200 import ( DataladHCP1200, ) from .multiple import MultipleDataGrabber -from .dmcc13_benchmark import DMCC13Benchmark +from .dmcc13_benchmark import ( + DMCC13Benchmark, + DMCCSession, + DMCCTask, + DMCCPhaseEncoding, + DMCCRun, +) from .pattern_validation_mixin import ( DataTypeManager, diff --git a/junifer/datagrabber/dmcc13_benchmark.py b/junifer/datagrabber/dmcc13_benchmark.py index b69b04c14..345a265fd 100644 --- a/junifer/datagrabber/dmcc13_benchmark.py +++ b/junifer/datagrabber/dmcc13_benchmark.py @@ -3,6 +3,7 @@ # Authors: Synchon Mandal # License: AGPL +from enum import Enum from itertools import product from pathlib import Path from typing import Union @@ -12,7 +13,45 @@ from .pattern_datalad import PatternDataladDataGrabber -__all__ = ["DMCC13Benchmark"] +__all__ = [ + "DMCC13Benchmark", + "DMCCPhaseEncoding", + "DMCCRun", + "DMCCSession", + "DMCCTask", +] + + +class DMCCSession(str, Enum): + """Accepted DMCC sessions.""" + + Wave1Bas = "ses-wave1bas" + Wave1Pro = "ses-wave1pro" + Wave1Rea = "ses-wave1rea" + + +class DMCCTask(str, Enum): + """Accepted DMCC tasks.""" + + Rest = "Rest" + Axcpt = "Axcpt" + Cuedts = "Cuedts" + Stern = "Stern" + Stroop = "Stroop" + + +class DMCCPhaseEncoding(str, Enum): + """Accepted DMCC phase encoding directions.""" + + AP = "AP" + PA = "PA" + + +class DMCCRun(str, Enum): + """Accepted DMCC runs.""" + + One = "1" + Two = "2" @register_datagrabber From d84d987c5fbcfcc93c8edbdaa251fb73055fb28e Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 13:16:36 +0100 Subject: [PATCH 38/99] refactor: make datagrabber interface and impls pydantic models --- junifer/datagrabber/aomic/id1000.py | 379 ++++++++-------- junifer/datagrabber/aomic/piop1.py | 429 +++++++++--------- junifer/datagrabber/aomic/piop2.py | 422 ++++++++--------- junifer/datagrabber/base.py | 91 ++-- junifer/datagrabber/datalad_base.py | 159 ++++--- junifer/datagrabber/dmcc13_benchmark.py | 313 ++++++------- .../datagrabber/hcp1200/datalad_hcp1200.py | 71 ++- junifer/datagrabber/hcp1200/hcp1200.py | 200 ++++---- junifer/datagrabber/multiple.py | 54 ++- junifer/datagrabber/pattern.py | 200 ++------ junifer/datagrabber/pattern_datalad.py | 148 +----- 11 files changed, 1092 insertions(+), 1374 deletions(-) diff --git a/junifer/datagrabber/aomic/id1000.py b/junifer/datagrabber/aomic/id1000.py index 10b50f39e..2a8acc6b1 100644 --- a/junifer/datagrabber/aomic/id1000.py +++ b/junifer/datagrabber/aomic/id1000.py @@ -7,12 +7,16 @@ # Synchon Mandal # License: AGPL -from pathlib import Path -from typing import Union +from typing import Literal + +from pydantic import HttpUrl from ...api.decorators import register_datagrabber -from ...utils import raise_error +from ...typing import DataGrabberPatterns +from ..base import DataType +from ..pattern import ConfoundsFormat 
from ..pattern_datalad import PatternDataladDataGrabber
+from ._types import AOMICSpace
 
 
 __all__ = ["DataladAOMICID1000"]
@@ -24,212 +28,219 @@ class DataladAOMICID1000(PatternDataladDataGrabber):
 
     Parameters
     ----------
-    datadir : str or Path or None, optional
-        The directory where the datalad dataset will be cloned. If None,
-        the datalad dataset will be cloned into a temporary directory
-        (default None).
-    types: {"BOLD", "T1w", "VBM_CSF", "VBM_GM", "VBM_WM", "DWI", \
-        "FreeSurfer"} or list of the options, optional
-        AOMIC data types. If None, all available data types are selected.
-        (default None).
-    space : {"native", "MNI152NLin2009cAsym"}, optional
-        The space to use for the data (default "MNI152NLin2009cAsym").
-
-    Raises
-    ------
-    ValueError
-        If invalid value is passed for:
-            * ``space``
+    types : list of {``DataType.BOLD``, ``DataType.T1w``, \
+        ``DataType.VBM_CSF``, ``DataType.VBM_GM``, ``DataType.VBM_WM``, \
+        ``DataType.DWI``, ``DataType.FreeSurfer``, ``DataType.Warp``}, \
+        optional
+        The data type(s) to grab.
+    datadir : pathlib.Path, optional
+        The path where the datalad dataset will be cloned.
+        If not specified, the datalad dataset will be cloned into a temporary
+        directory.
+    space : :enum:`.AOMICSpace`, optional
+        AOMIC space (default ``AOMICSpace.MNI152NLin2009cAsym``).
 
     """
 
-    def __init__(
-        self,
-        datadir: Union[str, Path, None] = None,
-        types: Union[str, list[str], None] = None,
-        space: str = "MNI152NLin2009cAsym",
-    ) -> None:
-        valid_spaces = ["native", "MNI152NLin2009cAsym"]
-        if space not in ["native", "MNI152NLin2009cAsym"]:
-            raise_error(
-                f"Invalid space {space}. Must be one of {valid_spaces}"
-            )
-
-        # Descriptor for space in `anat`
-        sp_anat_desc = (
-            "" if space == "native" else "space-MNI152NLin2009cAsym_"
-        )
-        # Descriptor for space in `func`
-        sp_func_desc = (
-            "space-T1w_" if space == "native" else "space-MNI152NLin2009cAsym_"
-        )
-        # The patterns
-        patterns = {
-            "BOLD": {
+    uri: HttpUrl = HttpUrl("https://github.com/OpenNeuroDatasets/ds003097.git")
+    types: list[
+        Literal[
+            DataType.BOLD,
+            DataType.T1w,
+            DataType.VBM_CSF,
+            DataType.VBM_GM,
+            DataType.VBM_WM,
+            DataType.DWI,
+            DataType.FreeSurfer,
+            DataType.Warp,
+        ]
+    ] = [  # noqa: RUF012
+        DataType.BOLD,
+        DataType.T1w,
+        DataType.VBM_CSF,
+        DataType.VBM_GM,
+        DataType.VBM_WM,
+        DataType.DWI,
+        DataType.FreeSurfer,
+        DataType.Warp,
+    ]
+    space: AOMICSpace = AOMICSpace.MNI152NLin2009cAsym
+    patterns: DataGrabberPatterns = {  # noqa: RUF012
+        "BOLD": {
+            "pattern": (
+                "derivatives/fmriprep/{subject}/func/"
+                "{subject}_task-moviewatching_"
+                "{sp_func_desc}"
+                "desc-preproc_bold.nii.gz"
+            ),
+            "mask": {
                 "pattern": (
                     "derivatives/fmriprep/{subject}/func/"
                     "{subject}_task-moviewatching_"
-                    f"{sp_func_desc}"
-                    "desc-preproc_bold.nii.gz"
+                    "{sp_func_desc}"
+                    "desc-brain_mask.nii.gz"
                 ),
-                "space": space,
-                "mask": {
-                    "pattern": (
-                        "derivatives/fmriprep/{subject}/func/"
-                        "{subject}_task-moviewatching_"
-                        f"{sp_func_desc}"
-                        "desc-brain_mask.nii.gz"
-                    ),
-                    "space": space,
-                },
-                "confounds": {
-                    "pattern": (
-                        "derivatives/fmriprep/{subject}/func/"
-                        "{subject}_task-moviewatching_"
-                        "desc-confounds_regressors.tsv"
-                    ),
-                    "format": "fmriprep",
-                },
-                "reference": {
-                    "pattern": (
-                        "derivatives/fmriprep/{subject}/func/"
-                        "{subject}_task-moviewatching_"
-                        f"{sp_func_desc}"
-                        "boldref.nii.gz"
-                    ),
-                },
             },
-            "T1w": {
+            "confounds": {
                 "pattern": (
-                    "derivatives/fmriprep/{subject}/anat/"
-                    "{subject}_"
-                    f"{sp_anat_desc}"
-                    "desc-preproc_T1w.nii.gz"
+                    "derivatives/fmriprep/{subject}/func/"
"{subject}_task-moviewatching_" + "desc-confounds_regressors.tsv" ), - "space": space, - "mask": { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "desc-brain_mask.nii.gz" - ), - "space": space, - }, + "format": "fmriprep", }, - "VBM_CSF": { + "reference": { "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "label-CSF_probseg.nii.gz" + "derivatives/fmriprep/{subject}/func/" + "{subject}_task-moviewatching_" + "{sp_func_desc}" + "boldref.nii.gz" ), - "space": space, }, - "VBM_GM": { + }, + "T1w": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "desc-preproc_T1w.nii.gz" + ), + "mask": { "pattern": ( "derivatives/fmriprep/{subject}/anat/" "{subject}_" - f"{sp_anat_desc}" - "label-GM_probseg.nii.gz" + "{sp_anat_desc}" + "desc-brain_mask.nii.gz" ), - "space": space, }, - "VBM_WM": { + }, + "VBM_CSF": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-CSF_probseg.nii.gz" + ), + }, + "VBM_GM": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-GM_probseg.nii.gz" + ), + }, + "VBM_WM": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-WM_probseg.nii.gz" + ), + }, + "DWI": { + "pattern": ( + "derivatives/dwipreproc/{subject}/dwi/" + "{subject}_desc-preproc_dwi.nii.gz" + ), + }, + "FreeSurfer": { + "pattern": "derivatives/freesurfer/[!f]{subject}/mri/T1.mg[z]", + "aseg": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/mri/aseg.mg[z]" + ) + }, + "norm": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/mri/norm.mg[z]" + ) + }, + "lh_white": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/lh.whit[e]" + ) + }, + "rh_white": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/rh.whit[e]" + ) + }, + "lh_pial": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/lh.pia[l]" + ) + }, + "rh_pial": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/rh.pia[l]" + ) + }, + }, + "Warp": [ + { "pattern": ( "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "label-WM_probseg.nii.gz" + "{subject}_from-MNI152NLin2009cAsym_to-T1w_" + "mode-image_xfm.h5" ), - "space": space, + "src": "MNI152NLin2009cAsym", + "dst": "native", + "warper": "ants", }, - "DWI": { + { "pattern": ( - "derivatives/dwipreproc/{subject}/dwi/" - "{subject}_desc-preproc_dwi.nii.gz" + "derivatives/fmriprep/{subject}/anat/" + "{subject}_from-T1w_to-MNI152NLin2009cAsym_" + "mode-image_xfm.h5" ), + "src": "native", + "dst": "MNI152NLin2009cAsym", + "warper": "ants", }, - "FreeSurfer": { - "pattern": "derivatives/freesurfer/[!f]{subject}/mri/T1.mg[z]", - "aseg": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/mri/aseg.mg[z]" - ) - }, - "norm": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/mri/norm.mg[z]" - ) - }, - "lh_white": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/lh.whit[e]" - ) - }, - "rh_white": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/rh.whit[e]" - ) - }, - "lh_pial": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/lh.pia[l]" - ) - }, - "rh_pial": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/rh.pia[l]" - ) - }, - }, - "Warp": [ - { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_from-MNI152NLin2009cAsym_to-T1w_" - "mode-image_xfm.h5" - ), - "src": "MNI152NLin2009cAsym", - 
"dst": "native", - "warper": "ants", - }, - { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_from-T1w_to-MNI152NLin2009cAsym_" - "mode-image_xfm.h5" - ), - "src": "native", - "dst": "MNI152NLin2009cAsym", - "warper": "ants", - }, - ], - } - if space == "native": - patterns["BOLD"]["prewarp_space"] = "MNI152NLin2009cAsym" + ], + } + replacements: list[str] = ["subject"] # noqa: RUF012 + confounds_format: ConfoundsFormat = ConfoundsFormat.FMRIPrep - else: - patterns["BOLD"]["prewarp_space"] = "native" - - # Use native T1w assets - self.space = space - - # Set default types - if types is None: - types = list(patterns.keys()) - # Convert single type into list - else: - if not isinstance(types, list): - types = [types] - # The replacements - replacements = ["subject"] - uri = "https://github.com/OpenNeuroDatasets/ds003097.git" - super().__init__( - types=types, - datadir=datadir, - uri=uri, - patterns=patterns, - replacements=replacements, - confounds_format="fmriprep", + def validate_datagrabber_params(self) -> None: + """Run extra logical validation for datagrabber.""" + # Descriptor for space in `anat` + sp_anat_desc = ( + "" if self.space == "native" else "space-MNI152NLin2009cAsym_" ) + # Descriptor for space in `func` + sp_func_desc = ( + "space-T1w_" + if self.space == "native" + else "space-MNI152NLin2009cAsym_" + ) + self.patterns["BOLD"]["pattern"] = self.patterns["BOLD"][ + "pattern" + ].replace("{sp_func_desc}", sp_func_desc) + self.patterns["BOLD"]["mask"]["pattern"] = self.patterns["BOLD"][ + "mask" + ]["pattern"].replace("{sp_func_desc}", sp_func_desc) + self.patterns["BOLD"]["reference"]["pattern"] = self.patterns["BOLD"][ + "reference" + ]["pattern"].replace("{sp_func_desc}", sp_func_desc) + self.patterns["T1w"]["pattern"] = self.patterns["T1w"][ + "pattern" + ].replace("{sp_anat_desc}", sp_anat_desc) + self.patterns["T1w"]["mask"]["pattern"] = self.patterns["T1w"]["mask"][ + "pattern" + ].replace("{sp_anat_desc}", sp_anat_desc) + for t in ["VBM_CSF", "VBM_GM", "VBM_WM"]: + self.patterns[t]["pattern"] = self.patterns[t]["pattern"].replace( + "{sp_anat_desc}", sp_anat_desc + ) + for t in ["BOLD", "T1w"]: + self.patterns[t]["space"] = self.space + self.patterns[t]["mask"]["space"] = self.space + for t in ["VBM_CSF", "VBM_GM", "VBM_WM"]: + self.patterns[t]["space"] = self.space + if self.space == "native": + self.patterns["BOLD"]["prewarp_space"] = "MNI152NLin2009cAsym" + else: + self.patterns["BOLD"]["prewarp_space"] = "native" + super().validate_datagrabber_params() diff --git a/junifer/datagrabber/aomic/piop1.py b/junifer/datagrabber/aomic/piop1.py index 2e1bf8f58..889c2e61e 100644 --- a/junifer/datagrabber/aomic/piop1.py +++ b/junifer/datagrabber/aomic/piop1.py @@ -8,12 +8,16 @@ # License: AGPL from itertools import product -from pathlib import Path -from typing import Union +from typing import Literal + +from pydantic import HttpUrl from ...api.decorators import register_datagrabber -from ...utils import raise_error +from ...typing import DataGrabberPatterns +from ..base import DataType +from ..pattern import ConfoundsFormat from ..pattern_datalad import PatternDataladDataGrabber +from ._types import AOMICSpace, AOMICTask __all__ = ["DataladAOMICPIOP1"] @@ -25,246 +29,247 @@ class DataladAOMICPIOP1(PatternDataladDataGrabber): Parameters ---------- - datadir : str or Path or None, optional - The directory where the datalad dataset will be cloned. If None, - the datalad dataset will be cloned into a temporary directory - (default None). 
-    types: {"BOLD", "T1w", "VBM_CSF", "VBM_GM", "VBM_WM", "DWI", \
-        "FreeSurfer"} or list of the options, optional
-        AOMIC data types. If None, all available data types are selected.
-        (default None).
-    tasks : {"restingstate", "anticipation", "emomatching", "faces", \
-        "gstroop", "workingmemory"} or list of the options, optional
-        AOMIC PIOP1 task sessions. If None, all available task sessions are
-        selected (default None).
-    space : {"native", "MNI152NLin2009cAsym"}, optional
-        The space to use for the data (default "MNI152NLin2009cAsym").
-
-    Raises
-    ------
-    ValueError
-        If invalid value is passed for:
-            * ``tasks``
-            * ``space``
+    types : list of {``DataType.BOLD``, ``DataType.T1w``, \
+        ``DataType.VBM_CSF``, ``DataType.VBM_GM``, ``DataType.VBM_WM``, \
+        ``DataType.DWI``, ``DataType.FreeSurfer``, ``DataType.Warp``}, \
+        optional
+        The data type(s) to grab.
+    datadir : pathlib.Path, optional
+        The path where the datalad dataset will be cloned.
+        If not specified, the datalad dataset will be cloned into a temporary
+        directory.
+    tasks : list of {``AOMICTask.RestingState``, ``AOMICTask.Anticipation``, \
+        ``AOMICTask.EmoMatching``, ``AOMICTask.Faces``, \
+        ``AOMICTask.Gstroop``, ``AOMICTask.WorkingMemory``}, optional
+        AOMIC PIOP1 task sessions.
+        By default, all available task sessions are selected.
+    space : :enum:`.AOMICSpace`, optional
+        AOMIC space (default ``AOMICSpace.MNI152NLin2009cAsym``).
 
     """
 
-    def __init__(
-        self,
-        datadir: Union[str, Path, None] = None,
-        types: Union[str, list[str], None] = None,
-        tasks: Union[str, list[str], None] = None,
-        space: str = "MNI152NLin2009cAsym",
-    ) -> None:
-        valid_spaces = ["native", "MNI152NLin2009cAsym"]
-        if space not in ["native", "MNI152NLin2009cAsym"]:
-            raise_error(
-                f"Invalid space {space}. Must be one of {valid_spaces}"
-            )
-        # Declare all tasks
-        all_tasks = [
-            "restingstate",
-            "anticipation",
-            "emomatching",
-            "faces",
-            "gstroop",
-            "workingmemory",
+    uri: HttpUrl = HttpUrl("https://github.com/OpenNeuroDatasets/ds002785")
+    types: list[
+        Literal[
+            DataType.BOLD,
+            DataType.T1w,
+            DataType.VBM_CSF,
+            DataType.VBM_GM,
+            DataType.VBM_WM,
+            DataType.DWI,
+            DataType.FreeSurfer,
+            DataType.Warp,
         ]
-        # Set default tasks
-        if tasks is None:
-            tasks = all_tasks
-        else:
-            # Convert single task into list
-            if isinstance(tasks, str):
-                tasks = [tasks]
-            # Verify valid tasks
-            for t in tasks:
-                if t not in all_tasks:
-                    raise_error(
-                        f"{t} is not a valid task in the AOMIC PIOP1 dataset!"
- ) - self.tasks = tasks - # Descriptor for space in `anat` - sp_anat_desc = ( - "" if space == "native" else "space-MNI152NLin2009cAsym_" - ) - # Descriptor for space in `func` - sp_func_desc = ( - "space-T1w_" if space == "native" else "space-MNI152NLin2009cAsym_" - ) - # The patterns - patterns = { - "BOLD": { + ] = [ # noqa: RUF012 + DataType.BOLD, + DataType.T1w, + DataType.VBM_CSF, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.DWI, + DataType.FreeSurfer, + DataType.Warp, + ] + tasks: list[ + Literal[ + AOMICTask.RestingState, + AOMICTask.Anticipation, + AOMICTask.EmoMatching, + AOMICTask.Faces, + AOMICTask.Gstroop, + AOMICTask.WorkingMemory, + ] + ] = [ # noqa: RUF012 + AOMICTask.RestingState, + AOMICTask.Anticipation, + AOMICTask.EmoMatching, + AOMICTask.Faces, + AOMICTask.Gstroop, + AOMICTask.WorkingMemory, + ] + space: AOMICSpace = AOMICSpace.MNI152NLin2009cAsym + patterns: DataGrabberPatterns = { # noqa: RUF012 + "BOLD": { + "pattern": ( + "derivatives/fmriprep/{subject}/func/" + "{subject}_task-{task}_" + "{sp_func_desc}" + "desc-preproc_bold.nii.gz" + ), + "mask": { "pattern": ( "derivatives/fmriprep/{subject}/func/" "{subject}_task-{task}_" - f"{sp_func_desc}" - "desc-preproc_bold.nii.gz" + "{sp_func_desc}" + "desc-brain_mask.nii.gz" ), - "space": space, - "mask": { - "pattern": ( - "derivatives/fmriprep/{subject}/func/" - "{subject}_task-{task}_" - f"{sp_func_desc}" - "desc-brain_mask.nii.gz" - ), - "space": space, - }, - "confounds": { - "pattern": ( - "derivatives/fmriprep/{subject}/func/" - "{subject}_task-{task}_" - "desc-confounds_regressors.tsv" - ), - "format": "fmriprep", - }, - "reference": { - "pattern": ( - "derivatives/fmriprep/{subject}/func/" - "{subject}_task-{task}_" - f"{sp_func_desc}" - "boldref.nii.gz" - ), - }, }, - "T1w": { + "confounds": { "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "desc-preproc_T1w.nii.gz" + "derivatives/fmriprep/{subject}/func/" + "{subject}_task-{task}_" + "desc-confounds_regressors.tsv" ), - "space": space, - "mask": { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "desc-brain_mask.nii.gz" - ), - "space": space, - }, + "format": "fmriprep", }, - "VBM_CSF": { + "reference": { "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "label-CSF_probseg.nii.gz" + "derivatives/fmriprep/{subject}/func/" + "{subject}_task-{task}_" + "{sp_func_desc}" + "boldref.nii.gz" ), - "space": space, }, - "VBM_GM": { + }, + "T1w": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "desc-preproc_T1w.nii.gz" + ), + "mask": { "pattern": ( "derivatives/fmriprep/{subject}/anat/" "{subject}_" - f"{sp_anat_desc}" - "label-GM_probseg.nii.gz" + "{sp_anat_desc}" + "desc-brain_mask.nii.gz" ), - "space": space, }, - "VBM_WM": { + }, + "VBM_CSF": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-CSF_probseg.nii.gz" + ), + }, + "VBM_GM": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-GM_probseg.nii.gz" + ), + }, + "VBM_WM": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-WM_probseg.nii.gz" + ), + }, + "DWI": { + "pattern": ( + "derivatives/dwipreproc/{subject}/dwi/" + "{subject}_desc-preproc_dwi.nii.gz" + ), + }, + "FreeSurfer": { + "pattern": "derivatives/freesurfer/[!f]{subject}/mri/T1.mg[z]", + "aseg": { + "pattern": ( + 
"derivatives/freesurfer/[!f]{subject}/mri/aseg.mg[z]" + ) + }, + "norm": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/mri/norm.mg[z]" + ) + }, + "lh_white": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/lh.whit[e]" + ) + }, + "rh_white": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/rh.whit[e]" + ) + }, + "lh_pial": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/lh.pia[l]" + ) + }, + "rh_pial": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/rh.pia[l]" + ) + }, + }, + "Warp": [ + { "pattern": ( "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "label-WM_probseg.nii.gz" + "{subject}_from-MNI152NLin2009cAsym_to-T1w_" + "mode-image_xfm.h5" ), - "space": space, + "src": "MNI152NLin2009cAsym", + "dst": "native", + "warper": "ants", }, - "DWI": { + { "pattern": ( - "derivatives/dwipreproc/{subject}/dwi/" - "{subject}_desc-preproc_dwi.nii.gz" + "derivatives/fmriprep/{subject}/anat/" + "{subject}_from-T1w_to-MNI152NLin2009cAsym_" + "mode-image_xfm.h5" ), + "src": "native", + "dst": "MNI152NLin2009cAsym", + "warper": "ants", }, - "FreeSurfer": { - "pattern": "derivatives/freesurfer/[!f]{subject}/mri/T1.mg[z]", - "aseg": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/mri/aseg.mg[z]" - ) - }, - "norm": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/mri/norm.mg[z]" - ) - }, - "lh_white": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/lh.whit[e]" - ) - }, - "rh_white": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/rh.whit[e]" - ) - }, - "lh_pial": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/lh.pia[l]" - ) - }, - "rh_pial": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/rh.pia[l]" - ) - }, - }, - "Warp": [ - { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_from-MNI152NLin2009cAsym_to-T1w_" - "mode-image_xfm.h5" - ), - "src": "MNI152NLin2009cAsym", - "dst": "native", - "warper": "ants", - }, - { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_from-T1w_to-MNI152NLin2009cAsym_" - "mode-image_xfm.h5" - ), - "src": "native", - "dst": "MNI152NLin2009cAsym", - "warper": "ants", - }, - ], - } + ], + } + replacements: list[str] = ["subject", "task"] # noqa: RUF012 + confounds_format: ConfoundsFormat = ConfoundsFormat.FMRIPrep - if space == "native": - patterns["BOLD"]["prewarp_space"] = "MNI152NLin2009cAsym" - else: - patterns["BOLD"]["prewarp_space"] = "native" - - # Use native T1w assets - self.space = space - - # Set default types - if types is None: - types = list(patterns.keys()) - # Convert single type into list - else: - if not isinstance(types, list): - types = [types] - # The replacements - replacements = ["subject", "task"] - uri = "https://github.com/OpenNeuroDatasets/ds002785" - super().__init__( - types=types, - datadir=datadir, - uri=uri, - patterns=patterns, - replacements=replacements, - confounds_format="fmriprep", + def validate_datagrabber_params(self) -> None: + """Run extra logical validation for datagrabber.""" + # Descriptor for space in `anat` + sp_anat_desc = ( + "" if self.space == "native" else "space-MNI152NLin2009cAsym_" ) + # Descriptor for space in `func` + sp_func_desc = ( + "space-T1w_" + if self.space == "native" + else "space-MNI152NLin2009cAsym_" + ) + self.patterns["BOLD"]["pattern"] = self.patterns["BOLD"][ + "pattern" + ].replace("{sp_func_desc}", sp_func_desc) + self.patterns["BOLD"]["mask"]["pattern"] = self.patterns["BOLD"][ + "mask" + 
]["pattern"].replace("{sp_func_desc}", sp_func_desc) + self.patterns["BOLD"]["reference"]["pattern"] = self.patterns["BOLD"][ + "reference" + ]["pattern"].replace("{sp_func_desc}", sp_func_desc) + self.patterns["T1w"]["pattern"] = self.patterns["T1w"][ + "pattern" + ].replace("{sp_anat_desc}", sp_anat_desc) + self.patterns["T1w"]["mask"]["pattern"] = self.patterns["T1w"]["mask"][ + "pattern" + ].replace("{sp_anat_desc}", sp_anat_desc) + for t in ["VBM_CSF", "VBM_GM", "VBM_WM"]: + self.patterns[t]["pattern"] = self.patterns[t]["pattern"].replace( + "{sp_anat_desc}", sp_anat_desc + ) + for t in ["BOLD", "T1w"]: + self.patterns[t]["space"] = self.space + self.patterns[t]["mask"]["space"] = self.space + for t in ["VBM_CSF", "VBM_GM", "VBM_WM"]: + self.patterns[t]["space"] = self.space + if self.space == "native": + self.patterns["BOLD"]["prewarp_space"] = "MNI152NLin2009cAsym" + else: + self.patterns["BOLD"]["prewarp_space"] = "native" + super().validate_datagrabber_params() def get_item(self, subject: str, task: str) -> dict: - """Index one element in the dataset. + """Get the specified item from the dataset. Parameters ---------- diff --git a/junifer/datagrabber/aomic/piop2.py b/junifer/datagrabber/aomic/piop2.py index 1203efed2..41e5a86e0 100644 --- a/junifer/datagrabber/aomic/piop2.py +++ b/junifer/datagrabber/aomic/piop2.py @@ -8,12 +8,16 @@ # License: AGPL from itertools import product -from pathlib import Path -from typing import Union +from typing import Literal + +from pydantic import HttpUrl from ...api.decorators import register_datagrabber -from ...utils import raise_error +from ...typing import DataGrabberPatterns +from ..base import DataType +from ..pattern import ConfoundsFormat from ..pattern_datalad import PatternDataladDataGrabber +from ._types import AOMICSpace, AOMICTask __all__ = ["DataladAOMICPIOP2"] @@ -25,241 +29,239 @@ class DataladAOMICPIOP2(PatternDataladDataGrabber): Parameters ---------- - datadir : str or Path or None, optional - The directory where the datalad dataset will be cloned. If None, - the datalad dataset will be cloned into a temporary directory - (default None). - types: {"BOLD", "T1w", "VBM_CSF", "VBM_GM", "VBM_WM", "DWI", \ - "FreeSurfer"} or list of the options, optional - AOMIC data types. If None, all available data types are selected. - (default None). - tasks : {"restingstate", "stopsignal", "workingmemory", "emomatching"} or \ - list of the options, optional - AOMIC PIOP2 task sessions. If None, all available task sessions are - selected (default None). - space : {"native", "MNI152NLin2009cAsym"}, optional - The space to use for the data (default "MNI152NLin2009cAsym"). - - Raises - ------ - ValueError - If invalid value is passed for: - * ``tasks`` - * ``space`` + types : list of {``DataType.BOLD``, ``DataType.T1w``, \ + ``DataType.VBM_CSF``, ``DataType.VBM_GM``, ``DataType.VBM_WM``, \ + ``DataType.DWI``, ``DataType.FreeSurfer``, ``DataType.Warp``}, \ + optional + The data type(s) to grab. + datadir : pathlib.Path, optional + That path where the datalad dataset will be cloned. + If not specified, the datalad dataset will be cloned into a temporary + directory. + tasks : list of {``AOMICTask.RestingState``, ``AOMICTask.StopSignal``, \ + ``AOMICTask.WorkingMemory``, ``AOMICTask.EmoMatching``}, optional + AOMIC PIOP2 task sessions. + By default, all available task sessions are selected. + space : :enum:`.AOMICSpace`, optional + AOMIC space (default ``AOMICSpace.MNI152NLin2009cAsym``). 
""" - def __init__( - self, - datadir: Union[str, Path, None] = None, - types: Union[str, list[str], None] = None, - tasks: Union[str, list[str], None] = None, - space: str = "MNI152NLin2009cAsym", - ) -> None: - valid_spaces = ["native", "MNI152NLin2009cAsym"] - if space not in ["native", "MNI152NLin2009cAsym"]: - raise_error( - f"Invalid space {space}. Must be one of {valid_spaces}" - ) - # Declare all tasks - all_tasks = [ - "restingstate", - "stopsignal", - "workingmemory", - "emomatching", + uri: HttpUrl = HttpUrl("https://github.com/OpenNeuroDatasets/ds002790") + types: list[ + Literal[ + DataType.BOLD, + DataType.T1w, + DataType.VBM_CSF, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.DWI, + DataType.FreeSurfer, + DataType.Warp, ] - # Set default tasks - if tasks is None: - tasks = all_tasks - else: - # Convert single task into list - if isinstance(tasks, str): - tasks = [tasks] - # Verify valid tasks - for t in tasks: - if t not in all_tasks: - raise_error( - f"{t} is not a valid task in the AOMIC PIOP2 dataset!" - ) - self.tasks = tasks - # Descriptor for space in `anat` - sp_anat_desc = ( - "" if space == "native" else "space-MNI152NLin2009cAsym_" - ) - # Descriptor for space in `func` - sp_func_desc = ( - "space-T1w_" if space == "native" else "space-MNI152NLin2009cAsym_" - ) - # The patterns - patterns = { - "BOLD": { + ] = [ # noqa: RUF012 + DataType.BOLD, + DataType.T1w, + DataType.VBM_CSF, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.DWI, + DataType.FreeSurfer, + DataType.Warp, + ] + tasks: list[ + Literal[ + AOMICTask.RestingState, + AOMICTask.StopSignal, + AOMICTask.EmoMatching, + AOMICTask.WorkingMemory, + ] + ] = [ # noqa: RUF012 + AOMICTask.RestingState, + AOMICTask.StopSignal, + AOMICTask.EmoMatching, + AOMICTask.WorkingMemory, + ] + space: AOMICSpace = AOMICSpace.MNI152NLin2009cAsym + patterns: DataGrabberPatterns = { # noqa: RUF012 + "BOLD": { + "pattern": ( + "derivatives/fmriprep/{subject}/func/" + "{subject}_task-{task}_" + "{sp_func_desc}" + "desc-preproc_bold.nii.gz" + ), + "mask": { "pattern": ( "derivatives/fmriprep/{subject}/func/" "{subject}_task-{task}_" - f"{sp_func_desc}" - "desc-preproc_bold.nii.gz" + "{sp_func_desc}" + "desc-brain_mask.nii.gz" ), - "space": space, - "mask": { - "pattern": ( - "derivatives/fmriprep/{subject}/func/" - "{subject}_task-{task}_" - f"{sp_func_desc}" - "desc-brain_mask.nii.gz" - ), - "space": space, - }, - "confounds": { - "pattern": ( - "derivatives/fmriprep/{subject}/func/" - "{subject}_task-{task}_" - "desc-confounds_regressors.tsv" - ), - "format": "fmriprep", - }, - "reference": { - "pattern": ( - "derivatives/fmriprep/{subject}/func/" - "{subject}_task-{task}_" - f"{sp_func_desc}" - "boldref.nii.gz" - ), - }, }, - "T1w": { + "confounds": { "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "desc-preproc_T1w.nii.gz" + "derivatives/fmriprep/{subject}/func/" + "{subject}_task-{task}_" + "desc-confounds_regressors.tsv" ), - "space": space, - "mask": { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "desc-brain_mask.nii.gz" - ), - "space": space, - }, + "format": "fmriprep", }, - "VBM_CSF": { + "reference": { "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "label-CSF_probseg.nii.gz" + "derivatives/fmriprep/{subject}/func/" + "{subject}_task-{task}_" + "{sp_func_desc}" + "boldref.nii.gz" ), - "space": space, }, - "VBM_GM": { + }, + "T1w": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" 
+ "{subject}_" + "{sp_anat_desc}" + "desc-preproc_T1w.nii.gz" + ), + "mask": { "pattern": ( "derivatives/fmriprep/{subject}/anat/" "{subject}_" - f"{sp_anat_desc}" - "label-GM_probseg.nii.gz" + "{sp_anat_desc}" + "desc-brain_mask.nii.gz" ), - "space": space, }, - "VBM_WM": { + }, + "VBM_CSF": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-CSF_probseg.nii.gz" + ), + }, + "VBM_GM": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-GM_probseg.nii.gz" + ), + }, + "VBM_WM": { + "pattern": ( + "derivatives/fmriprep/{subject}/anat/" + "{subject}_" + "{sp_anat_desc}" + "label-WM_probseg.nii.gz" + ), + }, + "DWI": { + "pattern": ( + "derivatives/dwipreproc/{subject}/dwi/" + "{subject}_desc-preproc_dwi.nii.gz" + ), + }, + "FreeSurfer": { + "pattern": "derivatives/freesurfer/[!f]{subject}/mri/T1.mg[z]", + "aseg": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/mri/aseg.mg[z]" + ) + }, + "norm": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/mri/norm.mg[z]" + ) + }, + "lh_white": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/lh.whit[e]" + ) + }, + "rh_white": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/rh.whit[e]" + ) + }, + "lh_pial": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/lh.pia[l]" + ) + }, + "rh_pial": { + "pattern": ( + "derivatives/freesurfer/[!f]{subject}/surf/rh.pia[l]" + ) + }, + }, + "Warp": [ + { "pattern": ( "derivatives/fmriprep/{subject}/anat/" - "{subject}_" - f"{sp_anat_desc}" - "label-WM_probseg.nii.gz" + "{subject}_from-MNI152NLin2009cAsym_to-T1w_" + "mode-image_xfm.h5" ), - "space": space, + "src": "MNI152NLin2009cAsym", + "dst": "native", + "warper": "ants", }, - "DWI": { + { "pattern": ( - "derivatives/dwipreproc/{subject}/dwi/" - "{subject}_desc-preproc_dwi.nii.gz" + "derivatives/fmriprep/{subject}/anat/" + "{subject}_from-T1w_to-MNI152NLin2009cAsym_" + "mode-image_xfm.h5" ), + "src": "native", + "dst": "MNI152NLin2009cAsym", + "warper": "ants", }, - "FreeSurfer": { - "pattern": "derivatives/freesurfer/[!f]{subject}/mri/T1.mg[z]", - "aseg": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/mri/aseg.mg[z]" - ) - }, - "norm": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/mri/norm.mg[z]" - ) - }, - "lh_white": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/lh.whit[e]" - ) - }, - "rh_white": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/rh.whit[e]" - ) - }, - "lh_pial": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/lh.pia[l]" - ) - }, - "rh_pial": { - "pattern": ( - "derivatives/freesurfer/[!f]{subject}/surf/rh.pia[l]" - ) - }, - }, - "Warp": [ - { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_from-MNI152NLin2009cAsym_to-T1w_" - "mode-image_xfm.h5" - ), - "src": "MNI152NLin2009cAsym", - "dst": "native", - "warper": "ants", - }, - { - "pattern": ( - "derivatives/fmriprep/{subject}/anat/" - "{subject}_from-T1w_to-MNI152NLin2009cAsym_" - "mode-image_xfm.h5" - ), - "src": "native", - "dst": "MNI152NLin2009cAsym", - "warper": "ants", - }, - ], - } + ], + } + replacements: list[str] = ["subject", "task"] # noqa: RUF012 + confounds_format: ConfoundsFormat = ConfoundsFormat.FMRIPrep - if space == "native": - patterns["BOLD"]["prewarp_space"] = "MNI152NLin2009cAsym" - else: - patterns["BOLD"]["prewarp_space"] = "native" - - # Use native T1w assets - self.space = space - - # Set default types - if types is None: - types = 
list(patterns.keys()) - # Convert single type into list - else: - if not isinstance(types, list): - types = [types] - # The replacements - replacements = ["subject", "task"] - uri = "https://github.com/OpenNeuroDatasets/ds002790" - super().__init__( - types=types, - datadir=datadir, - uri=uri, - patterns=patterns, - replacements=replacements, - confounds_format="fmriprep", + def validate_datagrabber_params(self) -> None: + """Run extra logical validation for datagrabber.""" + # Descriptor for space in `anat` + sp_anat_desc = ( + "" if self.space == "native" else "space-MNI152NLin2009cAsym_" ) + # Descriptor for space in `func` + sp_func_desc = ( + "space-T1w_" + if self.space == "native" + else "space-MNI152NLin2009cAsym_" + ) + self.patterns["BOLD"]["pattern"] = self.patterns["BOLD"][ + "pattern" + ].replace("{sp_func_desc}", sp_func_desc) + self.patterns["BOLD"]["mask"]["pattern"] = self.patterns["BOLD"][ + "mask" + ]["pattern"].replace("{sp_func_desc}", sp_func_desc) + self.patterns["BOLD"]["reference"]["pattern"] = self.patterns["BOLD"][ + "reference" + ]["pattern"].replace("{sp_func_desc}", sp_func_desc) + self.patterns["T1w"]["pattern"] = self.patterns["T1w"][ + "pattern" + ].replace("{sp_anat_desc}", sp_anat_desc) + self.patterns["T1w"]["mask"]["pattern"] = self.patterns["T1w"]["mask"][ + "pattern" + ].replace("{sp_anat_desc}", sp_anat_desc) + for t in ["VBM_CSF", "VBM_GM", "VBM_WM"]: + self.patterns[t]["pattern"] = self.patterns[t]["pattern"].replace( + "{sp_anat_desc}", sp_anat_desc + ) + for t in ["BOLD", "T1w"]: + self.patterns[t]["space"] = self.space + self.patterns[t]["mask"]["space"] = self.space + for t in ["VBM_CSF", "VBM_GM", "VBM_WM"]: + self.patterns[t]["space"] = self.space + if self.space == "native": + self.patterns["BOLD"]["prewarp_space"] = "MNI152NLin2009cAsym" + else: + self.patterns["BOLD"]["prewarp_space"] = "native" + super().validate_datagrabber_params() def get_elements(self) -> list: """Implement fetching list of elements in the dataset. @@ -278,7 +280,7 @@ def get_elements(self) -> list: return elems def get_item(self, subject: str, task: str) -> dict: - """Index one element in the dataset. + """Get the specified item from the dataset. Parameters ---------- diff --git a/junifer/datagrabber/base.py b/junifer/datagrabber/base.py index 2b5c23eb9..16a563afd 100644 --- a/junifer/datagrabber/base.py +++ b/junifer/datagrabber/base.py @@ -7,9 +7,11 @@ from abc import ABC, abstractmethod from collections.abc import Iterator -from enum import Enum from pathlib import Path -from typing import Union +from typing import Any + +from aenum import Enum +from pydantic import BaseModel, ConfigDict, Field from ..pipeline import UpdateMetaMixin from ..typing import Element, Elements @@ -35,46 +37,44 @@ class DataType(str, Enum): DWI = "DWI" FreeSurfer = "FreeSurfer" -class BaseDataGrabber(ABC, UpdateMetaMixin): - """Abstract base class for DataGrabber. - For every interface that is required, one needs to provide a concrete +class BaseDataGrabber(BaseModel, ABC, UpdateMetaMixin): + """Abstract base class for data fetcher. + + For every datagrabber, one needs to provide a concrete implementation of this abstract class. Parameters ---------- - types : list of str - The types of data to be grabbed. - datadir : str or pathlib.Path - The directory where the data is / will be stored. - - Raises - ------ - TypeError - If ``types`` is not a list or if the values are not string. + types : list of :enum:`.DataType` + The data type(s) to grab. 
+    datadir : pathlib.Path
+        The path where the data is or will be stored.
 
     """
 
-    def __init__(self, types: list[str], datadir: Union[str, Path]) -> None:
-        # Validate types
-        if not isinstance(types, list):
-            raise_error(msg="`types` must be a list", klass=TypeError)
-        if any(not isinstance(x, str) for x in types):
-            raise_error(
-                msg="`types` must be a list of strings", klass=TypeError
-            )
-        self.types = types
-
-        # Convert str to Path
-        if not isinstance(datadir, Path):
-            datadir = Path(datadir)
-        self._datadir = datadir
+    model_config = ConfigDict(use_enum_values=True)
+
+    types: list[DataType] = Field(frozen=True)
+    datadir: Path
 
+    def model_post_init(self, context: Any) -> None:  # noqa: D102
         logger.debug("Initializing BaseDataGrabber")
-        logger.debug(f"\t_datadir = {datadir}")
-        logger.debug(f"\ttypes = {types}")
+        logger.debug(f"\tdatadir = {self.datadir}")
+        logger.debug(f"\ttypes = {self.types}")
+        # Run extra validation for datagrabbers and fail early if needed
+        self.validate_datagrabber_params()
+        # Convert to correct data type
+        # self.types = [DataType(t) for t in self.types]
+
+    def validate_datagrabber_params(self) -> None:
+        """Run extra logical validation for datagrabber.
+
+        Subclasses can override to provide validation.
+
+        """
+        pass
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> Iterator[Element]:
         """Enable iterable support.
 
         Yields
@@ -90,7 +90,7 @@ def __getitem__(self, element: Element) -> dict[str, dict]:
 
         Parameters
         ----------
-        element : str or tuple of str
+        element : `Element`
             The element to be indexed.
 
         Returns
@@ -100,10 +100,14 @@
             specified element.
 
         """
+        # Convert element to tuple if not already and extract enum values if
+        # present
+        if not isinstance(element, tuple):
+            element = (element,)
+        element = tuple(
+            i.value if isinstance(i, Enum) else i for i in element
+        )
         logger.info(f"Getting element {element}")
-        # Convert element to tuple if not already
-        if not isinstance(element, tuple):
-            element = (element,)
         # Zip through element keys and actual values to construct element
         # access dictionary
         named_element: dict = dict(zip(self.get_element_keys(), element))
@@ -136,29 +140,30 @@ def get_types(self) -> list[str]:
 
         Returns
         -------
         list of str
-            The types of data to be grabbed.
+            The data type(s) to grab.
 
         """
-        return self.types.copy()
+        return [x.value if isinstance(x, Enum) else x for x in self.types]
 
     @property
-    def datadir(self) -> Path:
-        """Get data directory path.
+    def fulldir(self) -> Path:
+        """Get complete data directory path.
 
         Returns
         -------
         pathlib.Path
-            Path to the data directory. Can be overridden by subclasses.
+            Complete path to the data directory.
+            Can be overridden by subclasses.
 
         """
-        return self._datadir
+        return self.datadir
 
     def filter(self, selection: Elements) -> Iterator:
        """Filter elements to be grabbed.
 
         Parameters
         ----------
-        selection : list
+        selection : `Elements`
            The list of partial or complete element selectors to filter using.
 
         Yields
@@ -173,7 +178,7 @@ def filter_func(element: Element) -> bool:
 
             Parameters
             ----------
-            element : str or tuple of str
+            element : `Element`
                The element to be filtered.
Returns

diff --git a/junifer/datagrabber/datalad_base.py b/junifer/datagrabber/datalad_base.py
index f7eb996dd..8e9f2656c 100644
--- a/junifer/datagrabber/datalad_base.py
+++ b/junifer/datagrabber/datalad_base.py
@@ -9,12 +9,13 @@
 import os
 import tempfile
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, NoReturn, Optional
 
 import datalad
 import datalad.api as dl
 from datalad.support.exceptions import IncompleteResultsError
 from datalad.support.gitrepo import GitRepo
+from pydantic import Field, HttpUrl, field_validator
 
 from ..pipeline import WorkDirManager
 from ..typing import Element
@@ -25,6 +26,23 @@
 __all__ = ["DataladDataGrabber"]
 
 
+def _create_datadir() -> Path:
+    """Create a temporary directory for datalad dataset."""
+    datadir = WorkDirManager().get_tempdir(prefix="datalad")
+    logger.info(
+        "Created a temporary directory for datalad dataset at: "
+        f"{datadir.resolve()!s}"
+    )
+    return datadir
+
+
+def _remove_datadir(datadir: Path) -> None:
+    """Remove temporary directory if it exists."""
+    if datadir.exists():
+        logger.debug(f"Removing temporary directory at: {datadir.resolve()!s}")
+        WorkDirManager().delete_tempdir(datadir)
+
+
 class DataladDataGrabber(BaseDataGrabber):
     """Abstract base class for datalad-based data fetching.
 
@@ -32,17 +50,15 @@ class DataladDataGrabber(BaseDataGrabber):
 
     Parameters
     ----------
-    rootdir : str or pathlib.Path, optional
+    uri : pydantic.HttpUrl
+        URI of the datalad sibling.
+    rootdir : pathlib.Path, optional
         The path within the datalad dataset to the root directory
-        (default ".").
-    datadir : str or pathlib.Path or None, optional
-        That directory where the datalad dataset will be cloned. If None,
-        the datalad dataset will be cloned into a temporary directory
-        (default None).
-    uri : str or None, optional
-        URI of the datalad sibling (default None).
-    **kwargs
-        Keyword arguments passed to superclass.
+        (default Path(".")).
+    datadir : pathlib.Path, optional
+        The path where the datalad dataset will be cloned.
+        If not specified, the datalad dataset will be cloned into a temporary
+        directory.
Methods ------- @@ -69,22 +85,41 @@ class DataladDataGrabber(BaseDataGrabber): """ - def __init__( - self, - rootdir: Union[str, Path] = ".", - datadir: Union[str, Path, None] = None, - uri: Optional[str] = None, - **kwargs, - ): - if datadir is None: - logger.info("`datadir` is None, creating a temporary directory") - # Create temporary directory - tmpdir = WorkDirManager().get_tempdir(prefix="datalad") - self._tmpdir = tmpdir - datadir = tmpdir / "datadir" - datadir.mkdir(parents=True, exist_ok=False) - logger.info(f"`datadir` set to {datadir}") - cache_dir = tmpdir / ".datalad_cache" + uri: HttpUrl = Field(frozen=True) + rootdir: Path = Field(frozen=True, default=Path(".")) + datadir: Path = Field(default_factory=lambda: _create_datadir()) + _repodir: Path = Path(".") + # Flag to indicate if the dataset was cloned before and it might be + # dirty + datalad_dirty: bool = False + datalad_commit_id: Optional[str] = None + datalad_id: Optional[str] = None + _dataset: Optional[dl.Dataset] = None + _got_files: list[str] = [] # noqa: RUF012 + _was_cloned: bool = False + + @field_validator("datalad_dirty", mode="before") + @classmethod + def disable_tag(cls, value: Any) -> NoReturn: + """Disable setting datalad_dirty directly.""" + raise_error( + msg="datalad_dirty cannot be set directly", + klass=ValueError, + ) + + def validate_datagrabber_params(self) -> None: + """Run extra logical validation for datagrabber.""" + logger.debug("Initializing DataladDataGrabber") + logger.debug(f"\turi = {self.uri}") + logger.debug(f"\trootdir = {self.rootdir}") + if self.datadir.stem.startswith("datalad"): + self._repodir = self.datadir / "dataset" + self._repodir.mkdir(parents=True, exist_ok=False) + logger.info( + "Datalad dataset installation path set to: " + f"{self._repodir.resolve()!s}" + ) + cache_dir = self.datadir / ".datalad_cache" sockets_dir = cache_dir / "sockets" locks_dir = cache_dir / "locks" sockets_dir.mkdir(parents=True, exist_ok=False) @@ -106,43 +141,27 @@ def __init__( "Datalad locks set to " f"{datalad.cfg.get('datalad.locations.locks')}" ) - atexit.register(self._rmtmpdir) - # TODO: uri can be converted to a positional argument - if uri is None: - raise_error("`uri` must be provided") - - super().__init__(datadir=datadir, **kwargs) - logger.debug("Initializing DataladDataGrabber") - logger.debug(f"\turi = {uri}") - logger.debug(f"\t_rootdir = {rootdir}") - self.uri = uri - self._rootdir = rootdir - # Flag to indicate if the dataset was cloned before and it might be - # dirty - self.datalad_dirty = False + atexit.register(_remove_datadir, self.datadir) + else: + self._repodir = self.datadir + super().validate_datagrabber_params() def __del__(self) -> None: """Destructor.""" - if hasattr(self, "_tmpdir"): - self._rmtmpdir() - - def _rmtmpdir(self) -> None: - """Remove temporary directory if it exists.""" - if self._tmpdir.exists(): - logger.debug("Removing temporary directory") - WorkDirManager().delete_tempdir(self._tmpdir) + if self.datadir.stem.startswith("datalad"): + _remove_datadir(self.datadir) @property - def datadir(self) -> Path: - """Get data directory path. + def fulldir(self) -> Path: + """Get complete data directory path. Returns ------- pathlib.Path - Path to the data directory. + Complete path to the data directory. """ - return super().datadir / self._rootdir + return self._repodir / self.rootdir def _get_dataset_id_remote(self) -> tuple[str, bool]: """Get the dataset ID from the remote. 
@@ -165,8 +184,10 @@ def _get_dataset_id_remote(self) -> tuple[str, bool]:
         with tempfile.TemporaryDirectory() as tmpdir:
             if not config.get("datagrabber.skipidcheck", False):
                 logger.debug(f"Querying {self.uri} for dataset ID")
-                repo = GitRepo.clone(
-                    self.uri, path=tmpdir, clone_options=["-n", "--depth=1"]
+                repo: GitRepo = GitRepo.clone(
+                    str(self.uri),
+                    path=tmpdir,
+                    clone_options=["-n", "--depth=1"],
                 )
                 repo.checkout(name=".datalad/config", options=["HEAD"])
                 remote_id = repo.config.get("datalad.dataset.id", None)
@@ -179,6 +200,7 @@ def _get_dataset_id_remote(self) -> tuple[str, bool]:
                 is_dirty = False
             else:
                 logger.debug("Skipping dataset ID check")
+                # self._dataset should already be set at this point
                 remote_id = self._dataset.id
                 is_dirty = False
         logger.debug(
@@ -252,7 +274,7 @@ def _dataset_get(self, out: dict) -> dict:
         return out
 
     def install(self) -> None:
-        """Install the datalad dataset into the ``datadir``.
+        """Install the datalad dataset.
 
         Raises
         ------
@@ -262,12 +284,10 @@ def install(self) -> None:
             If there is a datalad-related problem while cloning dataset.
 
         """
-        isinstalled = dl.Dataset(self._datadir).is_installed()
-        if isinstalled:
+        is_installed = dl.Dataset(self._repodir).is_installed()
+        if is_installed:
             logger.debug("Dataset already installed")
-            self._got_files = []
-            self._dataset: dl.Dataset = dl.Dataset(self._datadir)
-
+            self._dataset = dl.Dataset(self._repodir)
             # Check if dataset is already installed with a different ID
             remote_id, is_dirty = self._get_dataset_id_remote()
             if remote_id != self._dataset.id:
@@ -275,7 +295,6 @@ def install(self) -> None:
                     "Dataset already installed but with a different "
                     f"ID: {self._dataset.id} (local) != {remote_id} (remote)"
                 )
-
             # Conditional reporting on dataset dirtiness
             self.datalad_dirty = is_dirty
             if self.datalad_dirty:
@@ -287,18 +306,18 @@ def install(self) -> None:
                 logger.debug(f"Dataset (id: {self._dataset.id}) is clean")
 
         else:
-            logger.debug(f"Installing dataset {self.uri} to {self._datadir}")
+            logger.debug(f"Installing dataset {self.uri} to {self._repodir}")
             try:
-                self._dataset: dl.Dataset = dl.clone(  # type: ignore
-                    self.uri, self._datadir, result_renderer="disabled"
+                self._dataset = dl.clone(
+                    str(self.uri), self._repodir, result_renderer="disabled"
                 )
             except IncompleteResultsError as e:
                 raise_error(f"Failed to clone dataset: {e.failed}")
             logger.debug("Dataset installed")
-        self._was_cloned = not isinstalled
-
-        self.datalad_commit_id = self._dataset.repo.get_hexsha(  # type: ignore
-            self._dataset.repo.get_corresponding_branch()  # type: ignore
+        self._was_cloned = not is_installed
+        # Dataset should be set already
+        self.datalad_commit_id = self._dataset.repo.get_hexsha(
+            self._dataset.repo.get_corresponding_branch()
         )
         self.datalad_id = self._dataset.id
 
@@ -321,7 +340,7 @@ def __getitem__(self, element: Element) -> dict:
 
         Parameters
         ----------
-        element : str or tuple of str
+        element : `Element`
             The element to be indexed. If one string is provided, it is
             assumed to be a tuple with only one item. 
If a tuple is provided, each item in the tuple is the value for the replacement string
diff --git a/junifer/datagrabber/dmcc13_benchmark.py b/junifer/datagrabber/dmcc13_benchmark.py
index 345a265fd..6af9e6c84 100644
--- a/junifer/datagrabber/dmcc13_benchmark.py
+++ b/junifer/datagrabber/dmcc13_benchmark.py
@@ -5,11 +5,14 @@
 
 from enum import Enum
 from itertools import product
-from pathlib import Path
-from typing import Union
+from typing import Literal
+
+from pydantic import HttpUrl
 
 from ..api.decorators import register_datagrabber
-from ..utils import raise_error
+from ..typing import DataGrabberPatterns
+from .base import DataType
+from .pattern import ConfoundsFormat
 from .pattern_datalad import PatternDataladDataGrabber
 
 
@@ -60,191 +63,145 @@ class DMCC13Benchmark(PatternDataladDataGrabber):
 
     Parameters
     ----------
-    datadir : str or Path or None, optional
-        The directory where the datalad dataset will be cloned. If None,
-        the datalad dataset will be cloned into a temporary directory
-        (default None).
-    types: {"BOLD", "T1w", "VBM_CSF", "VBM_GM", "VBM_WM"} or \
-        list of the options, optional
-        DMCC data types. If None, all available data types are selected.
-        (default None).
-    sessions: {"ses-wave1bas", "ses-wave1pro", "ses-wave1rea"} or \
-        list of the options, optional
-        DMCC sessions. If None, all available sessions are selected
-        (default None).
-    tasks: {"Rest", "Axcpt", "Cuedts", "Stern", "Stroop"} or \
-        list of the options, optional
-        DMCC task sessions. If None, all available task sessions are selected
-        (default None).
-    phase_encodings : {"AP", "PA"} or list of the options, optional
-        DMCC phase encoding directions. If None, all available phase encodings
-        are selected (default None).
-    runs : {"1", "2"} or list of the options, optional
-        DMCC runs. If None, all available runs are selected (default None).
+    types : list of {``DataType.BOLD``, ``DataType.T1w``, \
+        ``DataType.VBM_CSF``, ``DataType.VBM_GM``, ``DataType.VBM_WM``, \
+        ``DataType.Warp``}, optional
+        The data type(s) to grab.
+    datadir : pathlib.Path, optional
+        The path where the datalad dataset will be cloned.
+        If not specified, the datalad dataset will be cloned into a temporary
+        directory.
+    sessions : list of :enum:`.DMCCSession`, optional
+        DMCC sessions.
+        By default, all available sessions are selected.
+    tasks : list of :enum:`.DMCCTask`, optional
+        DMCC tasks.
+        By default, all available tasks are selected.
+    phase_encodings : list of :enum:`.DMCCPhaseEncoding`, optional
+        DMCC phase encoding directions.
+        By default, all available phase encodings are selected.
+    runs : list of :enum:`.DMCCRun`, optional
+        DMCC runs.
+        By default, all available runs are selected.
     native_t1w : bool, optional
         Whether to use T1w in native space (default False).
- Raises - ------ - ValueError - If invalid value is passed for: - * ``sessions`` - * ``tasks`` - * ``phase_encodings`` - * ``runs`` - """ - def __init__( - self, - datadir: Union[str, Path, None] = None, - types: Union[str, list[str], None] = None, - sessions: Union[str, list[str], None] = None, - tasks: Union[str, list[str], None] = None, - phase_encodings: Union[str, list[str], None] = None, - runs: Union[str, list[str], None] = None, - native_t1w: bool = False, - ) -> None: - # Declare all sessions - all_sessions = [ - "ses-wave1bas", - "ses-wave1pro", - "ses-wave1rea", + uri: HttpUrl = HttpUrl("https://github.com/OpenNeuroDatasets/ds003452.git") + types: list[ + Literal[ + DataType.BOLD, + DataType.T1w, + DataType.VBM_CSF, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.Warp, ] - # Set default sessions - if sessions is None: - sessions = all_sessions - else: - # Convert single session into list - if isinstance(sessions, str): - sessions = [sessions] - # Verify valid sessions - for s in sessions: - if s not in all_sessions: - raise_error( - f"{s} is not a valid session in the DMCC dataset" - ) - self.sessions = sessions - # Declare all tasks - all_tasks = [ - "Rest", - "Axcpt", - "Cuedts", - "Stern", - "Stroop", - ] - # Set default tasks - if tasks is None: - tasks = all_tasks - else: - # Convert single task into list - if isinstance(tasks, str): - tasks = [tasks] - # Verify valid tasks - for t in tasks: - if t not in all_tasks: - raise_error(f"{t} is not a valid task in the DMCC dataset") - self.tasks = tasks - # Declare all phase encodings - all_phase_encodings = ["AP", "PA"] - # Set default phase encodings - if phase_encodings is None: - phase_encodings = all_phase_encodings - else: - # Convert single phase encoding into list - if isinstance(phase_encodings, str): - phase_encodings = [phase_encodings] - # Verify valid phase encodings - for p in phase_encodings: - if p not in all_phase_encodings: - raise_error( - f"{p} is not a valid phase encoding in the DMCC " - "dataset" - ) - self.phase_encodings = phase_encodings - # Declare all runs - all_runs = ["1", "2"] - # Set default runs - if runs is None: - runs = all_runs - else: - # Convert single run into list - if isinstance(runs, str): - runs = [runs] - # Verify valid runs - for r in runs: - if r not in all_runs: - raise_error(f"{r} is not a valid run in the DMCC dataset") - self.runs = runs - # The patterns - patterns = { - "BOLD": { + ] = [ # noqa: RUF012 + DataType.BOLD, + DataType.T1w, + DataType.VBM_CSF, + DataType.VBM_GM, + DataType.VBM_WM, + ] + sessions: list[DMCCSession] = [ # noqa: RUF012 + DMCCSession.Wave1Bas, + DMCCSession.Wave1Pro, + DMCCSession.Wave1Rea, + ] + tasks: list[DMCCTask] = [ # noqa: RUF012 + DMCCTask.Rest, + DMCCTask.Axcpt, + DMCCTask.Cuedts, + DMCCTask.Stern, + DMCCTask.Stroop, + ] + phase_encodings: list[DMCCPhaseEncoding] = [ # noqa: RUF012 + DMCCPhaseEncoding.AP, + DMCCPhaseEncoding.PA, + ] + runs: list[DMCCRun] = [ # noqa: RUF012 + DMCCRun.One, + DMCCRun.Two, + ] + native_t1w: bool = False + patterns: DataGrabberPatterns = { # noqa: RUF012 + "BOLD": { + "pattern": ( + "derivatives/fmriprep-1.3.2/{subject}/{session}/" + "func/{subject}_{session}_task-{task}_acq-mb4" + "{phase_encoding}_run-{run}_" + "space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + "mask": { "pattern": ( "derivatives/fmriprep-1.3.2/{subject}/{session}/" "func/{subject}_{session}_task-{task}_acq-mb4" "{phase_encoding}_run-{run}_" - "space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz" - ), - 
"space": "MNI152NLin2009cAsym", - "mask": { - "pattern": ( - "derivatives/fmriprep-1.3.2/{subject}/{session}/" - "func/{subject}_{session}_task-{task}_acq-mb4" - "{phase_encoding}_run-{run}_" - "space-MNI152NLin2009cAsym_desc-brain_mask.nii.gz" - ), - "space": "MNI152NLin2009cAsym", - }, - "confounds": { - "pattern": ( - "derivatives/fmriprep-1.3.2/{subject}/{session}/" - "func/{subject}_{session}_task-{task}_acq-mb4" - "{phase_encoding}_run-{run}_desc-confounds_regressors.tsv" - ), - "format": "fmriprep", - }, - }, - "T1w": { - "pattern": ( - "derivatives/fmriprep-1.3.2/{subject}/anat/" - "{subject}_space-MNI152NLin2009cAsym_desc-preproc_T1w.nii.gz" + "space-MNI152NLin2009cAsym_desc-brain_mask.nii.gz" ), "space": "MNI152NLin2009cAsym", - "mask": { - "pattern": ( - "derivatives/fmriprep-1.3.2/{subject}/anat/" - "{subject}_space-MNI152NLin2009cAsym_desc-brain_mask.nii.gz" - ), - "space": "MNI152NLin2009cAsym", - }, }, - "VBM_CSF": { + "confounds": { "pattern": ( - "derivatives/fmriprep-1.3.2/{subject}/anat/" - "{subject}_space-MNI152NLin2009cAsym_label-CSF_probseg.nii.gz" - ), - "space": "MNI152NLin2009cAsym", - }, - "VBM_GM": { - "pattern": ( - "derivatives/fmriprep-1.3.2/{subject}/anat/" - "{subject}_space-MNI152NLin2009cAsym_label-GM_probseg.nii.gz" + "derivatives/fmriprep-1.3.2/{subject}/{session}/" + "func/{subject}_{session}_task-{task}_acq-mb4" + "{phase_encoding}_run-{run}_desc-confounds_regressors.tsv" ), - "space": "MNI152NLin2009cAsym", + "format": "fmriprep", }, - "VBM_WM": { + }, + "T1w": { + "pattern": ( + "derivatives/fmriprep-1.3.2/{subject}/anat/" + "{subject}_space-MNI152NLin2009cAsym_desc-preproc_T1w.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + "mask": { "pattern": ( "derivatives/fmriprep-1.3.2/{subject}/anat/" - "{subject}_space-MNI152NLin2009cAsym_label-WM_probseg.nii.gz" + "{subject}_space-MNI152NLin2009cAsym_desc-brain_mask.nii.gz" ), "space": "MNI152NLin2009cAsym", }, - } - # Use native T1w assets - self.native_t1w = False - if native_t1w: - self.native_t1w = True - patterns.update( + }, + "VBM_CSF": { + "pattern": ( + "derivatives/fmriprep-1.3.2/{subject}/anat/" + "{subject}_space-MNI152NLin2009cAsym_label-CSF_probseg.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + }, + "VBM_GM": { + "pattern": ( + "derivatives/fmriprep-1.3.2/{subject}/anat/" + "{subject}_space-MNI152NLin2009cAsym_label-GM_probseg.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + }, + "VBM_WM": { + "pattern": ( + "derivatives/fmriprep-1.3.2/{subject}/anat/" + "{subject}_space-MNI152NLin2009cAsym_label-WM_probseg.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + }, + } + replacements: list[str] = [ # noqa: RUF012 + "subject", + "session", + "task", + "phase_encoding", + "run", + ] + confounds_format: ConfoundsFormat = ConfoundsFormat.FMRIPrep + + def validate_datagrabber_params(self) -> None: + """Run extra logical validation for datagrabber.""" + if self.native_t1w: + self.patterns.update( { "T1w": { "pattern": ( @@ -284,24 +241,8 @@ def __init__( ], } ) - # Set default types - if types is None: - types = list(patterns.keys()) - # Convert single type into list - else: - if not isinstance(types, list): - types = [types] - # The replacements - replacements = ["subject", "session", "task", "phase_encoding", "run"] - uri = "https://github.com/OpenNeuroDatasets/ds003452.git" - super().__init__( - types=types, - datadir=datadir, - uri=uri, - patterns=patterns, - replacements=replacements, - confounds_format="fmriprep", - ) + self.types.append(DataType.Warp) + super().validate_datagrabber_params() def 
get_item(
         self,
@@ -311,7 +252,7 @@ def get_item(
         phase_encoding: str,
         run: str,
     ) -> dict:
-        """Index one element in the dataset.
+        """Get the specified item from the dataset.
 
         Parameters
         ----------
diff --git a/junifer/datagrabber/hcp1200/datalad_hcp1200.py b/junifer/datagrabber/hcp1200/datalad_hcp1200.py
index 50df47f17..0f6189f0d 100644
--- a/junifer/datagrabber/hcp1200/datalad_hcp1200.py
+++ b/junifer/datagrabber/hcp1200/datalad_hcp1200.py
@@ -6,11 +6,13 @@
 # License: AGPL
 
 from pathlib import Path
-from typing import Union
+from typing import Literal
 
-from junifer.datagrabber.datalad_base import DataladDataGrabber
+from pydantic import HttpUrl
 
 from ...api.decorators import register_datagrabber
+from ..base import DataType
+from ..datalad_base import DataladDataGrabber
 from .hcp1200 import HCP1200
 
 
@@ -23,50 +25,37 @@ class DataladHCP1200(DataladDataGrabber, HCP1200):
 
     Parameters
     ----------
-    datadir : str or Path or None, optional
-        The directory where the datalad dataset will be cloned. If None,
-        the datalad dataset will be cloned into a temporary directory
-        (default None).
-    tasks : {"REST1", "REST2", "SOCIAL", "WM", "RELATIONAL", "EMOTION", \
-        "LANGUAGE", "GAMBLING", "MOTOR"} or list of the options or None \
-        , optional
-        HCP task sessions. If None, all available task sessions are selected
-        (default None).
-    phase_encodings : {"LR", "RL"} or list of the options or None, optional
-        HCP phase encoding directions. If None, both will be used
-        (default None).
+    types : list of {``DataType.BOLD``, ``DataType.T1w``, ``DataType.Warp``}, \
+        optional
+        The data type(s) to grab.
+    datadir : pathlib.Path, optional
+        The path where the datalad dataset will be cloned.
+        If not specified, the datalad dataset will be cloned into a temporary
+        directory.
+    tasks : list of :enum:`.HCP1200Task`, optional
+        HCP task sessions.
+        By default, all available task sessions are selected.
+    phase_encodings : list of :enum:`.HCP1200PhaseEncoding`, optional
+        HCP phase encoding directions.
+        By default, all are used.
     ica_fix : bool, optional
         Whether to retrieve data that was processed with ICA+FIX.
-        Only "REST1" and "REST2" tasks are available with ICA+FIX (default
-        False).
-
-    Raises
-    ------
-    ValueError
-        If invalid value is passed for ``tasks`` or ``phase_encodings``.
+        Only ``HCP1200Task.REST1`` and ``HCP1200Task.REST2`` tasks
+        are available with ICA+FIX
+        (default False).
 
     """
 
-    def __init__(
-        self,
-        datadir: Union[str, Path, None] = None,
-        tasks: Union[str, list[str], None] = None,
-        phase_encodings: Union[str, list[str], None] = None,
-        ica_fix: bool = False,
-    ) -> None:
-        uri = (
-            "https://github.com/datalad-datasets/"
-            "human-connectome-project-openaccess.git"
-        )
-        rootdir = "HCP1200"
-        super().__init__(
-            datadir=datadir,
-            tasks=tasks,
-            phase_encodings=phase_encodings,
-            uri=uri,
-            rootdir=rootdir,
-            ica_fix=ica_fix,
-        )
+    uri: HttpUrl = HttpUrl(
+        "https://github.com/datalad-datasets/"
+        "human-connectome-project-openaccess.git"
+    )
+    types: list[Literal[DataType.BOLD, DataType.T1w, DataType.Warp]] = [  # noqa: RUF012
+        DataType.BOLD,
+        DataType.T1w,
+        DataType.Warp,
+    ]
+    rootdir: Path = Path("HCP1200")
 
     # Needed here as HCP1200's subjects are sub-datasets, so will not be
     # found when elements are checked.
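With the pydantic rework, datagrabber parameters become validated model fields instead of `__init__` arguments. A usage sketch under that assumption; the enum members come from the exports exercised in the tests further below, and the exact call is illustrative rather than part of this patch:

    from junifer.datagrabber import (
        DataladHCP1200,
        HCP1200PhaseEncoding,
        HCP1200Task,
    )

    # Fields are validated at construction time; ica_fix=True is only
    # accepted together with the REST tasks (enforced by HCP1200's
    # validate_datagrabber_params below).
    dg = DataladHCP1200(
        tasks=[HCP1200Task.REST1],
        phase_encodings=[HCP1200PhaseEncoding.LR],
        ica_fix=True,
    )
    with dg:  # the dataset is installed on context entry
        elements = dg.get_elements()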
diff --git a/junifer/datagrabber/hcp1200/hcp1200.py b/junifer/datagrabber/hcp1200/hcp1200.py index 049623d43..722fa1f2a 100644 --- a/junifer/datagrabber/hcp1200/hcp1200.py +++ b/junifer/datagrabber/hcp1200/hcp1200.py @@ -7,11 +7,12 @@ from enum import Enum from itertools import product -from pathlib import Path -from typing import Union +from typing import Literal from ...api.decorators import register_datagrabber +from ...typing import DataGrabberPatterns from ...utils import raise_error +from ..base import DataType from ..pattern import PatternDataGrabber @@ -45,135 +46,96 @@ class HCP1200(PatternDataGrabber): Parameters ---------- - datadir : str or Path, optional - The directory where the data is / will be stored. - tasks : {"REST1", "REST2", "SOCIAL", "WM", "RELATIONAL", "EMOTION", \ - "LANGUAGE", "GAMBLING", "MOTOR"} or list of the options or None \ - , optional - HCP task sessions. If None, all available task sessions are selected - (default None). - phase_encodings : {"LR", "RL"} or list of the options or None, optional - HCP phase encoding directions. If None, both will be used - (default None). + types : list of {``DataType.BOLD``, ``DataType.T1w``, ``DataType.Warp``}, \ + optional + The data type(s) to grab. + datadir : pathlib.Path + The path where the data is stored. + tasks : list of :enum:`.HCP1200Task`, optional + HCP task sessions. + By default, all available task sessions are selected. + phase_encodings : list of :enum:`.HCP1200PhaseEncoding`, optional + HCP phase encoding directions. + By default, all are used. ica_fix : bool, optional Whether to retrieve data that was processed with ICA+FIX. - Only "REST1" and "REST2" tasks are available with ICA+FIX (default - False). - - Raises - ------ - ValueError - If invalid value is passed for ``tasks`` or ``phase_encodings``. + Only ``HCP1200Task.REST1`` and ``HCP1200Task.REST2`` tasks + are available with ICA+FIX + (default False). """ - def __init__( - self, - datadir: Union[str, Path], - tasks: Union[str, list[str], None] = None, - phase_encodings: Union[str, list[str], None] = None, - ica_fix: bool = False, - ) -> None: - # All tasks - all_tasks = [ - "REST1", - "REST2", - "SOCIAL", - "WM", - "RELATIONAL", - "EMOTION", - "LANGUAGE", - "GAMBLING", - "MOTOR", - ] - # Set default tasks - if tasks is None: - self.tasks: list[str] = all_tasks - # Convert single task into list - else: - if not isinstance(tasks, list): - tasks = [tasks] - # Check for invalid task(s) - for task in tasks: - if task not in all_tasks: - raise_error( - f"'{task}' is not a valid HCP-YA fMRI task input. " - f"Valid task values can be any or all of {all_tasks}." - ) - self.tasks: list[str] = tasks - - # All phase encodings - all_phase_encodings = ["LR", "RL"] - # Set phase encodings - if phase_encodings is None: - phase_encodings = all_phase_encodings - # Convert single phase encoding into list - if isinstance(phase_encodings, str): - phase_encodings = [phase_encodings] - # Check for invalid phase encoding(s) - for pe in phase_encodings: - if pe not in all_phase_encodings: - raise_error( - f"'{pe}' is not a valid HCP-YA phase encoding. " - "Valid phase encoding can be any or all of " - f"{all_phase_encodings}." 
- ) - self.phase_encodings = phase_encodings + types: list[Literal[DataType.BOLD, DataType.T1w, DataType.Warp]] = [ # noqa: RUF012 + DataType.BOLD, + DataType.T1w, + DataType.Warp, + ] + tasks: list[HCP1200Task] = [ # noqa: RUF012 + HCP1200Task.REST1, + HCP1200Task.REST2, + HCP1200Task.SOCIAL, + HCP1200Task.WM, + HCP1200Task.RELATIONAL, + HCP1200Task.EMOTION, + HCP1200Task.LANGUAGE, + HCP1200Task.GAMBLING, + HCP1200Task.MOTOR, + ] + phase_encodings: list[HCP1200PhaseEncoding] = [ # noqa: RUF012 + HCP1200PhaseEncoding.RL, + HCP1200PhaseEncoding.LR, + ] + ica_fix: bool = False + patterns: DataGrabberPatterns = { # noqa: RUF012 + "BOLD": { + "pattern": ( + "{subject}/MNINonLinear/Results/" + "{task}_{phase_encoding}/" + "{task}_{phase_encoding}" + "{suffix}.nii.gz" + ), + "space": "MNI152NLin6Asym", + }, + "T1w": { + "pattern": "{subject}/T1w/T1w_acpc_dc_restore.nii.gz", + "space": "native", + }, + "Warp": [ + { + "pattern": ( + "{subject}/MNINonLinear/xfms/standard2acpc_dc.nii.gz" + ), + "src": "MNI152NLin6Asym", + "dst": "native", + "warper": "fsl", + }, + { + "pattern": ( + "{subject}/MNINonLinear/xfms/acpc_dc2standard.nii.gz" + ), + "src": "native", + "dst": "MNI152NLin6Asym", + "warper": "fsl", + }, + ], + } + replacements: list[str] = ["subject", "task", "phase_encoding"] # noqa: RUF012 - if ica_fix: + def validate_datagrabber_params(self) -> None: + """Run extra logical validation for datagrabber.""" + if self.ica_fix: if not all(task in ["REST1", "REST2"] for task in self.tasks): raise_error( "ICA+FIX is only available for 'REST1' and 'REST2' tasks." ) - suffix = "_hp2000_clean" if ica_fix else "" - - # The types of data - types = ["BOLD", "T1w", "Warp"] - # The patterns - patterns = { - "BOLD": { - "pattern": ( - "{subject}/MNINonLinear/Results/" - "{task}_{phase_encoding}/" - "{task}_{phase_encoding}" - f"{suffix}.nii.gz" - ), - "space": "MNI152NLin6Asym", - }, - "T1w": { - "pattern": "{subject}/T1w/T1w_acpc_dc_restore.nii.gz", - "space": "native", - }, - "Warp": [ - { - "pattern": ( - "{subject}/MNINonLinear/xfms/standard2acpc_dc.nii.gz" - ), - "src": "MNI152NLin6Asym", - "dst": "native", - "warper": "fsl", - }, - { - "pattern": ( - "{subject}/MNINonLinear/xfms/acpc_dc2standard.nii.gz" - ), - "src": "native", - "dst": "MNI152NLin6Asym", - "warper": "fsl", - }, - ], - } - # The replacements - replacements = ["subject", "task", "phase_encoding"] - super().__init__( - types=types, - datadir=datadir, - patterns=patterns, - replacements=replacements, - ) + suffix = "_hp2000_clean" if self.ica_fix else "" + self.patterns["BOLD"]["pattern"] = self.patterns["BOLD"][ + "pattern" + ].replace("{suffix}", suffix) + super().validate_datagrabber_params() def get_item(self, subject: str, task: str, phase_encoding: str) -> dict: - """Implement single element indexing in the database. + """Get the specified item from the dataset. 
Parameters ---------- diff --git a/junifer/datagrabber/multiple.py b/junifer/datagrabber/multiple.py index 5f9f3c9d8..101ff2651 100644 --- a/junifer/datagrabber/multiple.py +++ b/junifer/datagrabber/multiple.py @@ -5,12 +5,17 @@ # Synchon Mandal # License: AGPL +from pathlib import Path from typing import Union +from pydantic import ConfigDict + from ..api.decorators import register_datagrabber -from ..typing import DataGrabberLike +from ..typing import DataGrabberLike, Element from ..utils import deep_update, raise_error -from .base import BaseDataGrabber +from .base import BaseDataGrabber, DataType +from .pattern import PatternDataGrabber +from .pattern_datalad import PatternDataladDataGrabber __all__ = ["MultipleDataGrabber"] @@ -38,22 +43,35 @@ class MultipleDataGrabber(BaseDataGrabber): """ - def __init__(self, datagrabbers: list[DataGrabberLike], **kwargs) -> None: + model_config = ConfigDict(extra="allow") + + datagrabbers: list[ + Union[ + DataGrabberLike, + PatternDataGrabber, + PatternDataladDataGrabber, + ] + ] + types: list[DataType] = [] # noqa: RUF012 + datadir: Path = Path(".") + + def validate_datagrabber_params(self) -> None: + """Run extra logical validation for datagrabber.""" # Check datagrabbers consistency # Check for same element keys - first_keys = datagrabbers[0].get_element_keys() - for dg in datagrabbers[1:]: + first_keys = self.datagrabbers[0].get_element_keys() + for dg in self.datagrabbers[1:]: if dg.get_element_keys() != first_keys: raise_error( msg="DataGrabbers have different element keys", klass=RuntimeError, ) # Check for no overlapping types (and nested data types) - types = [x for dg in datagrabbers for x in dg.get_types()] + types = [x for dg in self.datagrabbers for x in dg.get_types()] if len(types) != len(set(types)): - if all(hasattr(dg, "patterns") for dg in datagrabbers): - first_patterns = datagrabbers[0].patterns - for dg in datagrabbers[1:]: + if all(hasattr(dg, "patterns") for dg in self.datagrabbers): + first_patterns = self.datagrabbers[0].patterns + for dg in self.datagrabbers[1:]: for data_type in set(types): dtype_pattern = dg.patterns.get(data_type) if dtype_pattern is None: @@ -77,14 +95,13 @@ def __init__(self, datagrabbers: list[DataGrabberLike], **kwargs) -> None: msg="DataGrabbers have overlapping types", klass=RuntimeError, ) - self._datagrabbers = datagrabbers - def __getitem__(self, element: Union[str, tuple]) -> dict: + def __getitem__(self, element: Element) -> dict: """Implement indexing. Parameters ---------- - element : str or tuple + element : `Element` The element to be indexed. If one string is provided, it is assumed to be a tuple with only one item. 
If a tuple is provided, each item in the tuple is the value for the replacement string @@ -100,7 +117,7 @@ def __getitem__(self, element: Union[str, tuple]) -> dict: out = {} metas = [] - for dg in self._datagrabbers: + for dg in self.datagrabbers: t_out = dg[element] deep_update(out, t_out) # Now get the meta for this datagrabber @@ -121,16 +138,15 @@ def __getitem__(self, element: Union[str, tuple]) -> dict: def __enter__(self) -> "MultipleDataGrabber": """Implement context entry.""" - for dg in self._datagrabbers: + for dg in self.datagrabbers: dg.__enter__() return self def __exit__(self, exc_type, exc_value, exc_traceback) -> None: """Implement context exit.""" - for dg in self._datagrabbers: + for dg in self.datagrabbers: dg.__exit__(exc_type, exc_value, exc_traceback) - # TODO: return type should be List[List[str]], but base type is List[str] def get_types(self) -> list[str]: """Get types. @@ -140,7 +156,7 @@ def get_types(self) -> list[str]: The types of data to be grabbed. """ - types = [x for dg in self._datagrabbers for x in dg.get_types()] + types = [x for dg in self.datagrabbers for x in dg.get_types()] return types def get_element_keys(self) -> list[str]: @@ -155,7 +171,7 @@ def get_element_keys(self) -> list[str]: The element keys. """ - return self._datagrabbers[0].get_element_keys() + return self.datagrabbers[0].get_element_keys() def get_elements(self) -> list: """Get elements. @@ -169,7 +185,7 @@ def get_elements(self) -> list: related DataGrabbers. """ - all_elements = [dg.get_elements() for dg in self._datagrabbers] + all_elements = [dg.get_elements() for dg in self.datagrabbers] elements = set(all_elements[0]) for s in all_elements[1:]: elements.intersection_update(s) diff --git a/junifer/datagrabber/pattern.py b/junifer/datagrabber/pattern.py index 86041b68c..4489e8a3f 100644 --- a/junifer/datagrabber/pattern.py +++ b/junifer/datagrabber/pattern.py @@ -9,9 +9,10 @@ from copy import deepcopy from enum import Enum from pathlib import Path -from typing import Optional, Union +from typing import Optional import numpy as np +from pydantic import Field from ..api.decorators import register_datagrabber from ..typing import DataGrabberPatterns, Elements @@ -23,8 +24,6 @@ __all__ = ["ConfoundsFormat", "PatternDataGrabber"] -# Accepted formats for confounds specification -_CONFOUNDS_FORMATS = ("fmriprep", "adhoc") class ConfoundsFormat(str, Enum): """Accepted confounds format.""" @@ -40,125 +39,16 @@ class PatternDataGrabber(BaseDataGrabber, PatternValidationMixin): Parameters ---------- - types : list of str - The types of data to be grabbed. - patterns : dict - Data type patterns as a dictionary. It has the following schema: - - * ``"T1w"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "space"], - "optional": { - "mask": { - "mandatory": ["pattern", "space"], - "optional": [] - } - } - } - - * ``"T2w"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "space"], - "optional": { - "mask": { - "mandatory": ["pattern", "space"], - "optional": [] - } - } - } - - * ``"BOLD"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "space"], - "optional": { - "mask": { - "mandatory": ["pattern", "space"], - "optional": [] - } - "confounds": { - "mandatory": ["pattern", "format"], - "optional": [] - } - } - } - - * ``"Warp"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "src", "dst", "warper"], - "optional": [] - } - - * ``"VBM_GM"`` : - - .. 
code-block:: none
-
-            {
-              "mandatory": ["pattern", "space"],
-              "optional": []
-            }
-
-    * ``"VBM_WM"`` :
-
-        .. code-block:: none
-
-            {
-              "mandatory": ["pattern", "space"],
-              "optional": []
-            }
-
-        Basically, for each data type, one needs to provide ``mandatory`` keys
-        and can choose to also provide ``optional`` keys. The value for each
-        key is a string. So, one needs to provide necessary data types as a
-        dictionary, for example:
-
-        .. code-block:: none
-
-            {
-                "BOLD": {
-                    "pattern": "...",
-                    "space": "...",
-                },
-                "T1w": {
-                    "pattern": "...",
-                    "space": "...",
-                },
-            }
-
-        except ``Warp``, which needs to be a list of dictionaries as there can
-        be multiple spaces to warp (for example, with fMRIPrep):
-
-        .. code-block:: none
-
-            {
-                "Warp": [
-                    {
-                        "pattern": "...",
-                        "src": "...",
-                        "dst": "...",
-                        "warper": "...",
-                    },
-                ],
-            }
-
-        taken from :class:`.HCP1200`.
-    replacements : str or list of str
-        Replacements in the ``pattern`` key of each data type. The value needs
-        to be a list of all possible replacements.
-    datadir : str or pathlib.Path
-        The directory where the data is / will be stored.
-    confounds_format : {"fmriprep", "adhoc"} or None, optional
+    types : list of :enum:`.DataType`
+        The data type(s) to grab.
+    datadir : pathlib.Path
+        The path where the data is stored.
+    patterns : ``DataGrabberPatterns``
+        The datagrabber patterns. Check :class:`.DataTypeSchema` for the
+        schema.
+    replacements : list of str
+        All possible replacements in ``patterns.<data type>.pattern``.
+    confounds_format : :enum:`.ConfoundsFormat` or None, optional
         The format of the confounds for the dataset (default None).
     partial_pattern_ok : bool, optional
         Whether to raise error if partial pattern for a data type is found.
@@ -168,52 +58,30 @@ class PatternDataGrabber(BaseDataGrabber, PatternValidationMixin):
         powerful when used with :class:`.MultipleDataGrabber`
         (default True).
 
-    Raises
-    ------
-    ValueError
-        If ``confounds_format`` is invalid.
+    Attributes
+    ----------
+    skip_file_check
 
     """
 
-    def __init__(
-        self,
-        types: list[str],
-        patterns: DataGrabberPatterns,
-        replacements: Union[list[str], str],
-        datadir: Union[str, Path],
-        confounds_format: Optional[str] = None,
-        partial_pattern_ok: bool = False,
-    ) -> None:
-        # Convert replacements to list if not already
-        if not isinstance(replacements, list):
-            replacements = [replacements]
+    patterns: DataGrabberPatterns = Field(frozen=True)
+    replacements: list[str] = Field(frozen=True)
+    confounds_format: Optional[ConfoundsFormat] = Field(None, frozen=True)
+    partial_pattern_ok: bool = Field(False, frozen=True)
+
+    def validate_datagrabber_params(self) -> None:
+        """Run extra logical validation for datagrabber."""
         # Validate patterns
         self.validate_patterns(
-            types=types,
-            replacements=replacements,
-            patterns=patterns,
-            partial_pattern_ok=partial_pattern_ok,
+            types=self.types,
+            replacements=self.replacements,
+            patterns=self.patterns,
+            partial_pattern_ok=self.partial_pattern_ok,
         )
-        self.replacements = replacements
-        self.patterns = patterns
-        self.partial_pattern_ok = partial_pattern_ok
-
-        # Validate confounds format
-        if (
-            confounds_format is not None
-            and confounds_format not in _CONFOUNDS_FORMATS
-        ):
-            raise_error(
-                "Invalid value for `confounds_format`, should be one of "
-                f"{_CONFOUNDS_FORMATS}."
- ) - self.confounds_format = confounds_format - - super().__init__(types=types, datadir=datadir) logger.debug("Initializing PatternDataGrabber") - logger.debug(f"\tpatterns = {patterns}") - logger.debug(f"\treplacements = {replacements}") - logger.debug(f"\tconfounds_format = {confounds_format}") + logger.debug(f"\tpatterns = {self.patterns}") + logger.debug(f"\treplacements = {self.replacements}") + logger.debug(f"\tconfounds_format = {self.confounds_format}") @property def skip_file_check(self) -> bool: @@ -331,7 +199,7 @@ def _get_path_from_patterns( resolved_pattern = self._replace_patterns_glob(element, pattern) # Resolve path for wildcard if "*" in resolved_pattern: - t_matches = list(self.datadir.absolute().glob(resolved_pattern)) + t_matches = list(self.fulldir.absolute().glob(resolved_pattern)) # Multiple matches if len(t_matches) > 1: raise_error( @@ -347,7 +215,7 @@ def _get_path_from_patterns( ) path = t_matches[0] else: - path = self.datadir / resolved_pattern + path = self.fulldir / resolved_pattern if not self.skip_file_check: if not path.exists() and not path.is_symlink(): raise_error( @@ -374,7 +242,7 @@ def get_element_keys(self) -> list[str]: return self.replacements def get_item(self, **element: dict) -> dict[str, dict]: - """Implement single element indexing for the datagrabber. + """Get the specified item from the dataset. This method constructs a real path to the requested item's data, by replacing the ``patterns`` with actual values passed via ``**element``. @@ -508,8 +376,8 @@ def get_elements(self) -> Elements: glob_pattern, t_replacements, ) = self._replace_patterns_regex(pattern) - for fname in self.datadir.glob(glob_pattern): - suffix = fname.relative_to(self.datadir).as_posix() + for fname in self.fulldir.glob(glob_pattern): + suffix = fname.relative_to(self.fulldir).as_posix() m = re.match(re_pattern, suffix) if m is not None: # Find the groups of replacements present in the diff --git a/junifer/datagrabber/pattern_datalad.py b/junifer/datagrabber/pattern_datalad.py index 112c1955d..0aef737bb 100644 --- a/junifer/datagrabber/pattern_datalad.py +++ b/junifer/datagrabber/pattern_datalad.py @@ -5,6 +5,8 @@ # Synchon Mandal # License: AGPL +from pydantic import ConfigDict + from ..api.decorators import register_datagrabber from ..utils import logger from .datalad_base import DataladDataGrabber @@ -23,122 +25,24 @@ class PatternDataladDataGrabber(DataladDataGrabber, PatternDataGrabber): Parameters ---------- - types : list of str - The types of data to be grabbed. - patterns : dict - Data type patterns as a dictionary. It has the following schema: - - * ``"T1w"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "space"], - "optional": { - "mask": { - "mandatory": ["pattern", "space"], - "optional": [] - } - } - } - - * ``"T2w"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "space"], - "optional": { - "mask": { - "mandatory": ["pattern", "space"], - "optional": [] - } - } - } - - * ``"BOLD"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "space"], - "optional": { - "mask": { - "mandatory": ["pattern", "space"], - "optional": [] - } - "confounds": { - "mandatory": ["pattern", "format"], - "optional": [] - } - } - } - - * ``"Warp"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "src", "dst"], - "optional": [] - } - - * ``"VBM_GM"`` : - - .. code-block:: none - - { - "mandatory": ["pattern", "space"], - "optional": [] - } - - * ``"VBM_WM"`` : - - .. 
code-block:: none
-
-            {
-              "mandatory": ["pattern", "space"],
-              "optional": []
-            }
-
-        Basically, for each data type, one needs to provide ``mandatory`` keys
-        and can choose to also provide ``optional`` keys. The value for each
-        key is a string. So, one needs to provide necessary data types as a
-        dictionary, for example:
-
-        .. code-block:: none
-
-            {
-                "BOLD": {
-                    "pattern": "...",
-                    "space": "...",
-                },
-                "T1w": {
-                    "pattern": "...",
-                    "space": "...",
-                },
-                "Warp": {
-                    "pattern": "...",
-                    "src": "...",
-                    "dst": "...",
-                }
-            }
-
-        taken from :class:`.HCP1200`.
-    replacements : str or list of str
-        Replacements in the ``pattern`` key of each data type. The value needs
-        to be a list of all possible replacements.
-    confounds_format : {"fmriprep", "adhoc"} or None, optional
-        The format of the confounds for the dataset (default None).
-    datadir : str or pathlib.Path or None, optional
-        That directory where the datalad dataset will be cloned. If None,
-        the datalad dataset will be cloned into a temporary directory
-        (default None).
-    rootdir : str or pathlib.Path, optional
+    uri : pydantic.HttpUrl
+        URI of the datalad sibling.
+    types : list of :enum:`.DataType`
+        The data type(s) to grab.
+    patterns : ``DataGrabberPatterns``
+        The datagrabber patterns. Check :class:`.DataTypeSchema` for the schema.
+    replacements : list of str
+        All possible replacements in ``patterns.<data type>.pattern``.
+    rootdir : pathlib.Path, optional
         The path within the datalad dataset to the root directory
-        (default ".").
-    uri : str or None, optional
-        URI of the datalad sibling (default None).
+        (default Path(".")).
+    confounds_format : :enum:`.ConfoundsFormat` or None, optional
+        The format of the confounds for the dataset (default None).
+    datadir : pathlib.Path, optional
+        The path where the datalad dataset will be cloned.
+        If not specified, the datalad dataset will be cloned into a temporary
+        directory.
+
 
     See Also
     --------
@@ -149,15 +53,11 @@ class PatternDataladDataGrabber(DataladDataGrabber, PatternDataGrabber):
 
     """
 
-    def __init__(
-        self,
-        **kwargs,
-    ) -> None:
-        # TODO(synchon): needs to be reworked, DataladDataGrabber needs to be
-        # a mixin to avoid multiple inheritance wherever possible.
+ model_config = ConfigDict(extra="allow") + def validate_datagrabber_params(self) -> None: + """Run extra logical validation for datagrabber.""" + super().validate_datagrabber_params() logger.debug("Initializing PatternDataladDataGrabber") - for key, val in kwargs.items(): + for key, val in self.__pydantic_extra__.items(): logger.debug(f"\t{key} = {val}") - - super().__init__(**kwargs) From 67a93a6ab7a2aa4dafc2c88669cea2b70f692060 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 13:18:56 +0100 Subject: [PATCH 39/99] refactor: rename PipelineStepMixin validate to validate_component to avoid pydantic conflict --- junifer/pipeline/marker_collection.py | 6 +++--- junifer/pipeline/pipeline_step_mixin.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/junifer/pipeline/marker_collection.py b/junifer/pipeline/marker_collection.py index 9e25b40e8..9cbf9e0dd 100644 --- a/junifer/pipeline/marker_collection.py +++ b/junifer/pipeline/marker_collection.py @@ -162,7 +162,7 @@ def validate(self, datagrabber: DataGrabberLike) -> None: logger.info(f"DataGrabber output type: {t_data}") logger.info("Validating Data Reader:") - t_data = self._datareader.validate(t_data) + t_data = self._datareader.validate_component(t_data) logger.info(f"Data Reader output type: {t_data}") if self._preprocessors is not None: @@ -175,7 +175,7 @@ def validate(self, datagrabber: DataGrabberLike) -> None: old_t_data = t_data.copy() logger.info(f"Preprocessor input type: {t_data}") # Validate preprocessor - new_t_data = preprocessor.validate(old_t_data) + new_t_data = preprocessor.validate_component(old_t_data) # Set new data types t_data = list(set(old_t_data) | set(new_t_data)) logger.info(f"Preprocessor output type: {t_data}") @@ -183,7 +183,7 @@ def validate(self, datagrabber: DataGrabberLike) -> None: for marker in self._markers: logger.info(f"Validating Marker: {marker.name}") # Validate marker - m_data = marker.validate(input=t_data) + m_data = marker.validate_component(input=t_data) logger.info(f"Marker output type: {m_data}") # Check storage for the marker if self._storage is not None: diff --git a/junifer/pipeline/pipeline_step_mixin.py b/junifer/pipeline/pipeline_step_mixin.py index 926e486e5..836caf541 100644 --- a/junifer/pipeline/pipeline_step_mixin.py +++ b/junifer/pipeline/pipeline_step_mixin.py @@ -75,8 +75,8 @@ def _fit_transform( klass=NotImplementedError, ) # pragma: no cover - def validate(self, input: list[str]) -> list[str]: - """Validate the the pipeline step. + def validate_component(self, input: list[str]) -> list[str]: + """Validate the pipeline component. Parameters ---------- @@ -228,5 +228,6 @@ def fit_transform( The processed output of the pipeline step. 
""" - self.validate(input=list(input.keys())) + # Needs to be validated if called directly via API + self.validate_component(input=list(input.keys())) return self._fit_transform(input=input, **kwargs) From dffa60ebc9a29569809cff20e1942cd11a030bed Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 13:21:09 +0100 Subject: [PATCH 40/99] chore: remove BaseFeatureStorage.get_valid_inputs --- junifer/storage/base.py | 12 ------------ junifer/storage/tests/test_hdf5.py | 12 ------------ 2 files changed, 24 deletions(-) diff --git a/junifer/storage/base.py b/junifer/storage/base.py index e1336034c..986ded3a4 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -81,18 +81,6 @@ def model_post_init(self, context: Any): # noqa: D102 ) self.uri.parent.mkdir(parents=True, exist_ok=True) - def get_valid_inputs(self) -> list[str]: - """Get valid storage types for input. - - Returns - ------- - list of str - The list of storage types that can be used as input for this - storage. - - """ - return list(self._STORAGE_TYPES) - def validate(self, input_: list[str]) -> None: """Validate the input to the pipeline step. diff --git a/junifer/storage/tests/test_hdf5.py b/junifer/storage/tests/test_hdf5.py index 9cb4c0d76..82032b189 100644 --- a/junifer/storage/tests/test_hdf5.py +++ b/junifer/storage/tests/test_hdf5.py @@ -21,18 +21,6 @@ ) -def test_get_valid_inputs() -> None: - """Test valid inputs.""" - storage = HDF5FeatureStorage(uri="/tmp") - assert set(storage.get_valid_inputs()) == { - "matrix", - "vector", - "timeseries", - "scalar_table", - "timeseries_2d", - } - - def test_single_output(tmp_path: Path) -> None: """Test single output setup. From d808b44eaf2200e805612d0347fe8ab1207b8507 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 13:22:06 +0100 Subject: [PATCH 41/99] update: rename BaseFeatureStroage.validate to validate_input to make intent clear --- junifer/pipeline/marker_collection.py | 2 +- junifer/storage/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/junifer/pipeline/marker_collection.py b/junifer/pipeline/marker_collection.py index 9cbf9e0dd..73bd83858 100644 --- a/junifer/pipeline/marker_collection.py +++ b/junifer/pipeline/marker_collection.py @@ -189,4 +189,4 @@ def validate(self, datagrabber: DataGrabberLike) -> None: if self._storage is not None: logger.info(f"Validating storage for {marker.name}") # Validate storage - self._storage.validate(input_=m_data) + self._storage.validate_input(input_=m_data) diff --git a/junifer/storage/base.py b/junifer/storage/base.py index 986ded3a4..b6d28888e 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -81,7 +81,7 @@ def model_post_init(self, context: Any): # noqa: D102 ) self.uri.parent.mkdir(parents=True, exist_ok=True) - def validate(self, input_: list[str]) -> None: + def validate_input(self, input_: list[str]) -> None: """Validate the input to the pipeline step. 
Parameters From e32e6354704648886df155c71950c23442819ac2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 13:25:44 +0100 Subject: [PATCH 42/99] chore: improve docstrings --- junifer/preprocess/base.py | 4 ++-- junifer/storage/base.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/junifer/preprocess/base.py b/junifer/preprocess/base.py index 020a69825..40ac2015c 100644 --- a/junifer/preprocess/base.py +++ b/junifer/preprocess/base.py @@ -94,14 +94,14 @@ def valid_inputs(self) -> list[DataType]: Returns ------- - list of DataType + list of :enum:`.DataType` The list of data types that can be used as input for this marker. """ return [DataType(x) for x in self._VALID_DATA_TYPES] def validate_preprocessor_params(self) -> None: - """Run extra logical validation for preprocessor params. + """Run extra logical validation for preprocessor. Subclasses can override to provide validation. """ diff --git a/junifer/storage/base.py b/junifer/storage/base.py index b6d28888e..6c55ec55b 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -55,7 +55,7 @@ class BaseFeatureStorage(BaseModel, ABC): Raises ------ AttributeError - If the storage does not have `_STORAGE_TYPES` attribute. + If the storage does not have ``_STORAGE_TYPES`` attribute. """ From 0d79a6dd4c432645bc9092f83fe272de39ef49b2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 17:14:35 +0100 Subject: [PATCH 43/99] update: reorder marker params validation in BaseMarker --- junifer/markers/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/junifer/markers/base.py b/junifer/markers/base.py index 45e2b2949..ae121315d 100644 --- a/junifer/markers/base.py +++ b/junifer/markers/base.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar, Optional from pydantic import BaseModel, ConfigDict @@ -63,8 +63,6 @@ def model_post_init(self, context: Any): # noqa: D102 msg=("Missing `_MARKER_INOUT_MAPPINGS` for the marker"), klass=AttributeError, ) - # Run extra validation for markers and fail early if needed - self.validate_marker_params() # Use all data types if not provided if self.on is None: self.on = self.valid_inputs @@ -80,6 +78,8 @@ def model_post_init(self, context: Any): # noqa: D102 f"{self.__class__.__name__} cannot be computed on " f"{wrong_on}" ) + # Run extra validation for markers and fail early if needed + self.validate_marker_params() # Set default name if not provided self.name = self.__class__.__name__ if self.name is None else self.name From daf72c32f787c088d5dd5c5704a26694109c1aa1 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 17:15:31 +0100 Subject: [PATCH 44/99] fix: make on param for *Aggregation markers None by default --- junifer/markers/maps_aggregation.py | 28 ++++++++++++++------------- junifer/markers/parcel_aggregation.py | 28 ++++++++++++++------------- junifer/markers/sphere_aggregation.py | 28 ++++++++++++++------------- 3 files changed, 45 insertions(+), 39 deletions(-) diff --git a/junifer/markers/maps_aggregation.py b/junifer/markers/maps_aggregation.py index 21fd39518..173b7a68c 100644 --- a/junifer/markers/maps_aggregation.py +++ b/junifer/markers/maps_aggregation.py @@ -90,23 +90,25 @@ class MapsAggregation(BaseMarker): }, } - on: list[ - Literal[ - DataType.T1w, - DataType.T2w, - DataType.BOLD, - DataType.VBM_GM, - DataType.VBM_WM, - DataType.VBM_CSF, - DataType.FALFF, - 
DataType.GCOR, - DataType.LCOR, - ] - ] maps: str time_method: Optional[str] = None time_method_params: Optional[dict[str, Any]] = None masks: Optional[list[Union[dict, str]]] = None + on: Optional[ + list[ + Literal[ + DataType.T1w, + DataType.T2w, + DataType.BOLD, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.VBM_CSF, + DataType.FALFF, + DataType.GCOR, + DataType.LCOR, + ] + ] + ] = None def validate_marker_params(self) -> None: """Run extra logical validation for marker.""" diff --git a/junifer/markers/parcel_aggregation.py b/junifer/markers/parcel_aggregation.py index a0cf9c56a..35966f40f 100644 --- a/junifer/markers/parcel_aggregation.py +++ b/junifer/markers/parcel_aggregation.py @@ -100,25 +100,27 @@ class ParcelAggregation(BaseMarker): }, } - on: list[ - Literal[ - DataType.T1w, - DataType.T2w, - DataType.BOLD, - DataType.VBM_GM, - DataType.VBM_WM, - DataType.VBM_CSF, - DataType.FALFF, - DataType.GCOR, - DataType.LCOR, - ] - ] parcellation: list[str] method: str = "mean" method_params: Optional[dict[str, Any]] = None time_method: Optional[str] = None time_method_params: Optional[dict[str, Any]] = None masks: Optional[list[Union[dict, str]]] = None + on: Optional[ + list[ + Literal[ + DataType.T1w, + DataType.T2w, + DataType.BOLD, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.VBM_CSF, + DataType.FALFF, + DataType.GCOR, + DataType.LCOR, + ] + ] + ] = None def validate_marker_params(self) -> None: """Run extra logical validation for marker.""" diff --git a/junifer/markers/sphere_aggregation.py b/junifer/markers/sphere_aggregation.py index 485292126..b4190d031 100644 --- a/junifer/markers/sphere_aggregation.py +++ b/junifer/markers/sphere_aggregation.py @@ -107,19 +107,6 @@ class SphereAggregation(BaseMarker): }, } - on: list[ - Literal[ - DataType.T1w, - DataType.T2w, - DataType.BOLD, - DataType.VBM_GM, - DataType.VBM_WM, - DataType.VBM_CSF, - DataType.FALFF, - DataType.GCOR, - DataType.LCOR, - ] - ] coords: str radius: Optional[Union[Literal[0], PositiveFloat]] = None allow_overlap: bool = False @@ -128,6 +115,21 @@ class SphereAggregation(BaseMarker): time_method: Optional[str] = None time_method_params: Optional[dict[str, Any]] = None masks: Optional[list[Union[dict, str]]] = None + on: Optional[ + list[ + Literal[ + DataType.T1w, + DataType.T2w, + DataType.BOLD, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.VBM_CSF, + DataType.FALFF, + DataType.GCOR, + DataType.LCOR, + ] + ] + ] = None def validate_marker_params(self) -> None: """Run extra logical validation for marker.""" From 213f3ca3d52fc6068750c191864aa448e0e6a965 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 17:17:17 +0100 Subject: [PATCH 45/99] update: make onthefly lazy during import --- junifer/onthefly/__init__.py | 5 ++--- junifer/onthefly/__init__.pyi | 4 ++++ junifer/onthefly/py.typed | 0 3 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 junifer/onthefly/__init__.pyi create mode 100644 junifer/onthefly/py.typed diff --git a/junifer/onthefly/__init__.py b/junifer/onthefly/__init__.py index c96f8568c..c8416fcad 100644 --- a/junifer/onthefly/__init__.py +++ b/junifer/onthefly/__init__.py @@ -3,8 +3,7 @@ # Authors: Synchon Mandal # License: AGPL -from .read_transform import read_transform -from . 
import _brainprint as brainprint +import lazy_loader as lazy -__all__ = ["read_transform", "brainprint"] +__getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__) diff --git a/junifer/onthefly/__init__.pyi b/junifer/onthefly/__init__.pyi new file mode 100644 index 000000000..f932de2ce --- /dev/null +++ b/junifer/onthefly/__init__.pyi @@ -0,0 +1,4 @@ +__all__ = ["read_transform", "brainprint"] + +from .read_transform import read_transform +from . import _brainprint as brainprint diff --git a/junifer/onthefly/py.typed b/junifer/onthefly/py.typed new file mode 100644 index 000000000..e69de29bb From 915d1a57773af3c245d9d83b503304c428b4778a Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Thu, 13 Nov 2025 17:19:15 +0100 Subject: [PATCH 46/99] update: fix tests for preprocess --- .../tests/test_fmriprep_confound_remover.py | 69 +++++++-------- .../smoothing/tests/test_smoothing.py | 26 +++--- .../preprocess/tests/test_preprocess_base.py | 8 +- .../preprocess/tests/test_temporal_filter.py | 10 +-- .../preprocess/tests/test_temporal_slicer.py | 8 -- .../warping/tests/test_space_warper.py | 84 ++++++++++--------- 6 files changed, 96 insertions(+), 109 deletions(-) diff --git a/junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py b/junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py index d8c8aa3da..bd7eb1253 100644 --- a/junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py +++ b/junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py @@ -14,7 +14,7 @@ from pandas.testing import assert_frame_equal from junifer.datareader import DefaultDataReader -from junifer.preprocess.confounds import fMRIPrepConfoundRemover +from junifer.preprocess import Confounds, fMRIPrepConfoundRemover from junifer.testing import get_testing_data from junifer.testing.datagrabbers import ( OasisVBMTestingDataGrabber, @@ -23,22 +23,6 @@ ) -def test_fMRIPrepConfoundRemover_init() -> None: - """Test fMRIPrepConfoundRemover init.""" - - with pytest.raises(ValueError, match=r"keys must be strings"): - fMRIPrepConfoundRemover(strategy={1: "full"}) # type: ignore - - with pytest.raises(ValueError, match=r"values must be strings"): - fMRIPrepConfoundRemover(strategy={"motion": 1}) # type: ignore - - with pytest.raises(ValueError, match=r"component names"): - fMRIPrepConfoundRemover(strategy={"wrong": "full"}) - - with pytest.raises(ValueError, match=r"confound types"): - fMRIPrepConfoundRemover(strategy={"motion": "wrong"}) - - @pytest.mark.parametrize( "input_", [ @@ -59,10 +43,10 @@ def test_fMRIPrepConfoundRemover_validate_input(input_: list[str]) -> None: confound_remover.validate_input(input_) -def test_fMRIPrepConfoundRemover_get_valid_inputs() -> None: - """Test fMRIPrepConfoundRemover get_valid_inputs.""" +def test_fMRIPrepConfoundRemover_valid_inputs() -> None: + """Test fMRIPrepConfoundRemover valid_inputs.""" confound_remover = fMRIPrepConfoundRemover() - assert confound_remover.get_valid_inputs() == ["BOLD"] + assert confound_remover.valid_inputs == ["BOLD"] def test_fMRIPrepConfoundRemover__map_adhoc_to_fmriprep() -> None: @@ -98,7 +82,9 @@ def test_fMRIPrepConfoundRemover__process_fmriprep_spec() -> None: """Test fMRIPrepConfoundRemover fmriprep spec processing.""" # Test one strategy, full, no spike - confound_remover = fMRIPrepConfoundRemover(strategy={"wm_csf": "full"}) + confound_remover = fMRIPrepConfoundRemover( + strategy={"wm_csf": Confounds.Full} + ) var_names = [ "csf", @@ -144,7 +130,7 @@ def 
test_fMRIPrepConfoundRemover__process_fmriprep_spec() -> None: # Same strategy, with spike, only basics are present confound_remover = fMRIPrepConfoundRemover( - strategy={"wm_csf": "full"}, spike=0.2 + strategy={"wm_csf": Confounds.Full}, spike=0.2 ) var_names = ["csf", "white_matter"] @@ -171,7 +157,10 @@ def test_fMRIPrepConfoundRemover__process_fmriprep_spec() -> None: # Two component strategy, mixed confounds, no spike confound_remover = fMRIPrepConfoundRemover( - strategy={"wm_csf": "power2", "global_signal": "derivatives"} + strategy={ + "wm_csf": Confounds.Power2, + "global_signal": Confounds.Derivatives, + } ) var_names = ["csf", "white_matter", "global_signal"] @@ -196,7 +185,7 @@ def test_fMRIPrepConfoundRemover__process_fmriprep_spec() -> None: # Test for wrong columns/strategy pairs confound_remover = fMRIPrepConfoundRemover( - strategy={"wm_csf": "full"}, spike=0.2 + strategy={"wm_csf": Confounds.Full}, spike=0.2 ) var_names = ["csf"] confounds_df = pd.DataFrame( @@ -219,7 +208,9 @@ def test_fMRIPrepConfoundRemover__process_fmriprep_spec() -> None: def test_fMRIPrepConfoundRemover__pick_confounds_adhoc() -> None: """Test fMRIPrepConfoundRemover pick confounds on adhoc confounds.""" - confound_remover = fMRIPrepConfoundRemover(strategy={"wm_csf": "full"}) + confound_remover = fMRIPrepConfoundRemover( + strategy={"wm_csf": Confounds.Full} + ) # Use non fmriprep variable names adhoc_names = [f"var{i}" for i in range(2)] adhoc_df = pd.DataFrame(np.random.randn(10, 2), columns=adhoc_names) @@ -252,7 +243,7 @@ def test_fMRIPrepConfoundRemover__pick_confounds_adhoc() -> None: def test_fMRIPRepConfoundRemover__pick_confounds_fmriprep() -> None: """Test fMRIPrepConfoundRemover pick confounds on fmriprep confounds.""" confound_remover = fMRIPrepConfoundRemover( - strategy={"wm_csf": "full"}, spike=0.2 + strategy={"wm_csf": Confounds.Full}, spike=0.2 ) fmriprep_all_vars = [ "csf", @@ -285,7 +276,9 @@ def test_fMRIPRepConfoundRemover__pick_confounds_fmriprep() -> None: def test_fMRIPRepConfoundRemover__pick_confounds_fmriprep_compute() -> None: """Test if fmriprep returns the same derivatives/power2 as we compute.""" - confound_remover = fMRIPrepConfoundRemover(strategy={"wm_csf": "full"}) + confound_remover = fMRIPrepConfoundRemover( + strategy={"wm_csf": Confounds.Full} + ) fmriprep_all_vars = [ "csf", "white_matter", @@ -341,7 +334,9 @@ def test_fMRIPrepConfoundRemover__get_scrub_regressors_errors( def test_fMRIPrepConfoundRemover__validate_data() -> None: """Test fMRIPrepConfoundRemover validate data.""" - confound_remover = fMRIPrepConfoundRemover(strategy={"wm_csf": "full"}) + confound_remover = fMRIPrepConfoundRemover( + strategy={"wm_csf": Confounds.Full} + ) # Check correct data type with OasisVBMTestingDataGrabber() as dg: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) @@ -530,7 +525,7 @@ def test_fMRIPrepConfoundRemover_fit_transform_masks() -> None: """Test fMRIPrepConfoundRemover with all confounds present.""" # All strategies full, no spike confound_remover = fMRIPrepConfoundRemover( - masks={"compute_brain_mask": {"threshold": 0.2}} + masks=[{"compute_brain_mask": {"threshold": 0.2}}] ) with PartlyCloudyTestingDataGrabber(reduce_confounds=False) as dg: @@ -573,13 +568,13 @@ def test_fMRIPrepConfoundRemover_fit_transform_masks() -> None: assert t_meta["low_pass"] is None assert t_meta["high_pass"] is None assert t_meta["t_r"] is None - assert isinstance(t_meta["masks"], dict) + assert isinstance(t_meta["masks"], list) assert t_meta["masks"] is not None 
assert len(t_meta["masks"]) == 1 - assert "compute_brain_mask" in t_meta["masks"] - assert len(t_meta["masks"]["compute_brain_mask"]) == 1 - assert "threshold" in t_meta["masks"]["compute_brain_mask"] - assert t_meta["masks"]["compute_brain_mask"]["threshold"] == 0.2 + assert "compute_brain_mask" in t_meta["masks"][0] + assert len(t_meta["masks"][0]["compute_brain_mask"]) == 1 + assert "threshold" in t_meta["masks"][0]["compute_brain_mask"] + assert t_meta["masks"][0]["compute_brain_mask"]["threshold"] == 0.2 assert "mask" in output["BOLD"] @@ -592,9 +587,9 @@ def test_fMRIPrepConfoundRemover_scrubbing() -> None: """Test fMRIPrepConfoundRemover with scrubbing.""" confound_remover = fMRIPrepConfoundRemover( strategy={ - "motion": "full", - "wm_csf": "full", - "global_signal": "full", + "motion": Confounds.Full, + "wm_csf": Confounds.Full, + "global_signal": Confounds.Full, "scrubbing": True, }, ) diff --git a/junifer/preprocess/smoothing/tests/test_smoothing.py b/junifer/preprocess/smoothing/tests/test_smoothing.py index cd50fce86..622a04a55 100644 --- a/junifer/preprocess/smoothing/tests/test_smoothing.py +++ b/junifer/preprocess/smoothing/tests/test_smoothing.py @@ -7,20 +7,20 @@ from junifer.datareader import DefaultDataReader from junifer.pipeline.utils import _check_afni, _check_fsl -from junifer.preprocess import Smoothing +from junifer.preprocess import Smoothing, SmoothingImpl from junifer.testing.datagrabbers import SPMAuditoryTestingDataGrabber @pytest.mark.parametrize( "data_type", - ["T1w", "BOLD"], + [["T1w"], ["BOLD"]], ) -def test_Smoothing_nilearn(data_type: str) -> None: +def test_Smoothing_nilearn(data_type: list[str]) -> None: """Test Smoothing using nilearn. Parameters ---------- - data_type : str + data_type : list of str The parametrized data type. """ @@ -29,7 +29,7 @@ def test_Smoothing_nilearn(data_type: str) -> None: element_data = DefaultDataReader().fit_transform(dg["sub001"]) # Preprocess data output = Smoothing( - using="nilearn", + using=SmoothingImpl.nilearn, on=data_type, smoothing_params={"fwhm": "fast"}, ).fit_transform(element_data) @@ -39,17 +39,17 @@ def test_Smoothing_nilearn(data_type: str) -> None: @pytest.mark.parametrize( "data_type", - ["T1w", "BOLD"], + [["T1w"], ["BOLD"]], ) @pytest.mark.skipif( _check_afni() is False, reason="requires AFNI to be in PATH" ) -def test_Smoothing_afni(data_type: str) -> None: +def test_Smoothing_afni(data_type: list[str]) -> None: """Test Smoothing using AFNI. Parameters ---------- - data_type : str + data_type : list of tr The parametrized data type. """ @@ -58,7 +58,7 @@ def test_Smoothing_afni(data_type: str) -> None: element_data = DefaultDataReader().fit_transform(dg["sub001"]) # Preprocess data output = Smoothing( - using="afni", + using=SmoothingImpl.afni, on=data_type, smoothing_params={"fwhm": 3}, ).fit_transform(element_data) @@ -68,15 +68,15 @@ def test_Smoothing_afni(data_type: str) -> None: @pytest.mark.parametrize( "data_type", - ["T1w", "BOLD"], + [["T1w"], ["BOLD"]], ) @pytest.mark.skipif(_check_fsl() is False, reason="requires FSL to be in PATH") -def test_Smoothing_fsl(data_type: str) -> None: +def test_Smoothing_fsl(data_type: list[str]) -> None: """Test Smoothing using FSL. Parameters ---------- - data_type : str + data_type : list of str The parametrized data type. 
""" @@ -85,7 +85,7 @@ def test_Smoothing_fsl(data_type: str) -> None: element_data = DefaultDataReader().fit_transform(dg["sub001"]) # Preprocess data output = Smoothing( - using="fsl", + using=SmoothingImpl.fsl, on=data_type, smoothing_params={"brightness_threshold": 10.0, "fwhm": 3.0}, ).fit_transform(element_data) diff --git a/junifer/preprocess/tests/test_preprocess_base.py b/junifer/preprocess/tests/test_preprocess_base.py index 24322ffc6..6849adf85 100644 --- a/junifer/preprocess/tests/test_preprocess_base.py +++ b/junifer/preprocess/tests/test_preprocess_base.py @@ -15,7 +15,7 @@ def test_base_preprocessor_abstractness() -> None: """Test BasePreprocessor is abstract base class.""" with pytest.raises(TypeError, match=r"abstract"): - BasePreprocessor(on=["BOLD"]) # type: ignore + BasePreprocessor(on=["BOLD"]) def test_base_preprocessor_subclassing() -> None: @@ -25,9 +25,7 @@ def test_base_preprocessor_subclassing() -> None: class MyBasePreprocessor(BasePreprocessor): _VALID_DATA_TYPES: ClassVar[Sequence[str]] = ["BOLD", "T1w"] - def __init__(self, on): - self.parameter = 1 - super().__init__(on=on) + parameter: int = 1 def preprocess(self, input, extra_input=None): input["data"] = f"modified_{input['data']}" @@ -37,7 +35,7 @@ def preprocess(self, input, extra_input=None): MyBasePreprocessor(on=["BOLD", "T2w"]) with pytest.raises(ValueError, match=r"cannot be computed on \['T2w'\]"): - MyBasePreprocessor(on="T2w") + MyBasePreprocessor(on=["T2w"]) # Create input for marker input_ = { diff --git a/junifer/preprocess/tests/test_temporal_filter.py b/junifer/preprocess/tests/test_temporal_filter.py index 7efb64a7d..631c23ec0 100644 --- a/junifer/preprocess/tests/test_temporal_filter.py +++ b/junifer/preprocess/tests/test_temporal_filter.py @@ -29,7 +29,7 @@ 0.1, None, None, - "compute_brain_mask", + ["compute_brain_mask"], ], [ True, @@ -37,7 +37,7 @@ None, 0.08, None, - "compute_background_mask", + ["compute_background_mask"], ], [ False, @@ -53,7 +53,7 @@ 0.1, 0.08, 2, - "compute_brain_mask", + ["compute_brain_mask"], ], ), ) @@ -63,7 +63,7 @@ def test_TemporalFilter( low_pass: Optional[float], high_pass: Optional[float], t_r: Optional[float], - masks: Optional[str], + masks: Optional[list[str]], ) -> None: """Test TemporalFilter. @@ -79,7 +79,7 @@ def test_TemporalFilter( The parametrized high pass value. t_r : float or None The parametrized repetition time. - masks : str or None + masks : list of str or None The parametrized mask. 
""" diff --git a/junifer/preprocess/tests/test_temporal_slicer.py b/junifer/preprocess/tests/test_temporal_slicer.py index ed2be4ef4..d5add4acb 100644 --- a/junifer/preprocess/tests/test_temporal_slicer.py +++ b/junifer/preprocess/tests/test_temporal_slicer.py @@ -75,14 +75,6 @@ pytest.raises(RuntimeError, match="`stop` should be None"), ], [10.0, None, 30.0, None, 30, nullcontext()], - [ - -1.0, - None, - None, - None, - 84, - pytest.raises(ValueError, match="`start` cannot be negative"), - ], [ 0.0, 500.0, diff --git a/junifer/preprocess/warping/tests/test_space_warper.py b/junifer/preprocess/warping/tests/test_space_warper.py index d301a091a..bcadd8374 100644 --- a/junifer/preprocess/warping/tests/test_space_warper.py +++ b/junifer/preprocess/warping/tests/test_space_warper.py @@ -8,10 +8,20 @@ import pytest from numpy.testing import assert_array_equal, assert_raises -from junifer.datagrabber import DataladHCP1200, DMCC13Benchmark +from junifer.datagrabber import ( + DataladHCP1200, + DataType, + DMCC13Benchmark, + DMCCPhaseEncoding, + DMCCRun, + DMCCSession, + DMCCTask, + HCP1200PhaseEncoding, + HCP1200Task, +) from junifer.datareader import DefaultDataReader from junifer.pipeline.utils import _check_ants, _check_fsl -from junifer.preprocess import SpaceWarper +from junifer.preprocess import SpaceWarper, SpaceWarpingImpl from junifer.testing.datagrabbers import PartlyCloudyTestingDataGrabber from junifer.typing import DataGrabberLike @@ -19,14 +29,12 @@ @pytest.mark.parametrize( "using, reference, error_type, error_msg", [ - ("jam", "T1w", ValueError, "`using`"), - ("ants", "juice", ValueError, "reference"), - ("ants", "MNI152NLin2009cAsym", RuntimeError, "remove"), - ("fsl", "MNI152NLin2009cAsym", RuntimeError, "ANTs"), + (SpaceWarpingImpl.ants, "MNI152NLin2009cAsym", RuntimeError, "remove"), + (SpaceWarpingImpl.fsl, "MNI152NLin2009cAsym", RuntimeError, "ANTs"), ], ) def test_SpaceWarper_errors( - using: str, + using: SpaceWarpingImpl, reference: str, error_type: type[Exception], error_msg: str, @@ -35,7 +43,7 @@ def test_SpaceWarper_errors( Parameters ---------- - using : str + using : SpaceWarpingImpl The parametrized implementation method. reference : str The parametrized reference to use. 
@@ -46,14 +54,12 @@ def test_SpaceWarper_errors( """ with PartlyCloudyTestingDataGrabber() as dg: - # Read data element_data = DefaultDataReader().fit_transform(dg["sub-01"]) - # Preprocess data with pytest.raises(error_type, match=error_msg): SpaceWarper( using=using, reference=reference, - on="BOLD", + on=[DataType.BOLD], ).preprocess( input=element_data["BOLD"], extra_input=element_data, @@ -65,24 +71,24 @@ def test_SpaceWarper_errors( [ [ DMCC13Benchmark( - types=["BOLD", "T1w", "Warp"], - sessions=["ses-wave1bas"], - tasks=["Rest"], - phase_encodings=["AP"], - runs=["1"], + types=[DataType.BOLD, DataType.T1w, DataType.Warp], + sessions=[DMCCSession.Wave1Bas], + tasks=[DMCCTask.Rest], + phase_encodings=[DMCCPhaseEncoding.AP], + runs=[DMCCRun.One], native_t1w=True, ), ("sub-f9057kp", "ses-wave1bas", "Rest", "AP", "1"), - "ants", + SpaceWarpingImpl.ants, ], [ DataladHCP1200( - tasks=["REST1"], - phase_encodings=["LR"], + tasks=[HCP1200Task.REST1], + phase_encodings=[HCP1200PhaseEncoding.LR], ica_fix=True, ), ("100206", "REST1", "LR"), - "fsl", + SpaceWarpingImpl.fsl, ], ], ) @@ -95,7 +101,9 @@ def test_SpaceWarper_errors( reason="only for juseless", ) def test_SpaceWarper_native( - datagrabber: DataGrabberLike, element: tuple[str, ...], using: str + datagrabber: DataGrabberLike, + element: tuple[str, ...], + using: SpaceWarpingImpl, ) -> None: """Test SpaceWarper for native space warping. @@ -105,23 +113,20 @@ def test_SpaceWarper_native( The parametrized DataGrabber objects. element : tuple of str The parametrized elements. - using : str + using : SpaceWarpingImpl The parametrized implementation method. """ with datagrabber as dg: - # Read data element_data = DefaultDataReader().fit_transform(dg[element]) - # Preprocess data output = SpaceWarper( using=using, reference="T1w", - on="BOLD", + on=[DataType.BOLD], ).preprocess( input=element_data["BOLD"], extra_input=element_data, ) - # Check assert isinstance(output, dict) @@ -130,11 +135,11 @@ def test_SpaceWarper_native( [ [ DMCC13Benchmark( - types=["T1w"], - sessions=["ses-wave1bas"], - tasks=["Rest"], - phase_encodings=["AP"], - runs=["1"], + types=[DataType.T1w], + sessions=[DMCCSession.Wave1Bas], + tasks=[DMCCTask.Rest], + phase_encodings=[DMCCPhaseEncoding.AP], + runs=[DMCCRun.One], native_t1w=False, ), ("sub-f9057kp", "ses-wave1bas", "Rest", "AP", "1"), @@ -142,11 +147,11 @@ def test_SpaceWarper_native( ], [ DMCC13Benchmark( - types=["T1w"], - sessions=["ses-wave1bas"], - tasks=["Rest"], - phase_encodings=["AP"], - runs=["1"], + types=[DataType.T1w], + sessions=[DMCCSession.Wave1Bas], + tasks=[DMCCTask.Rest], + phase_encodings=[DMCCPhaseEncoding.AP], + runs=[DMCCRun.One], native_t1w=False, ), ("sub-f9057kp", "ses-wave1bas", "Rest", "AP", "1"), @@ -175,19 +180,16 @@ def test_SpaceWarper_multi_mni( """ with datagrabber as dg: - # Read data element_data = DefaultDataReader().fit_transform(dg[element]) pre_xfm_data = element_data["T1w"]["data"].get_fdata().copy() - # Preprocess data output = SpaceWarper( - using="ants", + using=SpaceWarpingImpl.ants, reference=space, - on=["T1w"], + on=[DataType.T1w], ).preprocess( input=element_data["T1w"], extra_input=element_data, ) - # Checks assert isinstance(output, dict) assert output["space"] == space with assert_raises(AssertionError): From 26a08ce1d3150d30ae52b4bcfbcb283ee64eecdc Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Fri, 14 Nov 2025 11:08:43 +0100 Subject: [PATCH 47/99] update: fix tests for markers --- .../complexity/tests/test_hurst_exponent.py | 5 +- 
.../tests/test_multiscale_entropy_auc.py | 5 +- .../complexity/tests/test_perm_entropy.py | 5 +- .../complexity/tests/test_range_entropy.py | 5 +- .../tests/test_range_entropy_auc.py | 5 +- .../complexity/tests/test_sample_entropy.py | 5 +- .../tests/test_weighted_perm_entropy.py | 5 +- .../markers/falff/tests/test_falff_maps.py | 22 +-- .../markers/falff/tests/test_falff_parcels.py | 23 ++-- .../markers/falff/tests/test_falff_spheres.py | 21 +-- ...ossparcellation_functional_connectivity.py | 12 +- .../test_edge_functional_connectivity_maps.py | 4 +- ...st_edge_functional_connectivity_parcels.py | 9 +- ...st_edge_functional_connectivity_spheres.py | 7 +- .../test_functional_connectivity_maps.py | 4 +- .../test_functional_connectivity_parcels.py | 11 +- .../test_functional_connectivity_spheres.py | 15 +-- junifer/markers/reho/tests/test_reho_maps.py | 18 +-- .../markers/reho/tests/test_reho_parcels.py | 23 ++-- .../markers/reho/tests/test_reho_spheres.py | 19 +-- .../tests/test_temporal_snr_maps.py | 7 +- .../tests/test_temporal_snr_parcels.py | 10 +- .../tests/test_temporal_snr_spheres.py | 14 +- junifer/markers/tests/test_brainprint.py | 19 +-- junifer/markers/tests/test_ets_rss.py | 6 +- .../markers/tests/test_maps_aggregation.py | 70 +++++----- junifer/markers/tests/test_markers_base.py | 12 +- .../markers/tests/test_parcel_aggregation.py | 125 +++++++++--------- .../markers/tests/test_sphere_aggregation.py | 75 ++++++----- 29 files changed, 290 insertions(+), 271 deletions(-) diff --git a/junifer/markers/complexity/tests/test_hurst_exponent.py b/junifer/markers/complexity/tests/test_hurst_exponent.py index 4e6063517..a68e2a943 100644 --- a/junifer/markers/complexity/tests/test_hurst_exponent.py +++ b/junifer/markers/complexity/tests/test_hurst_exponent.py @@ -12,6 +12,7 @@ pytest.importorskip("neurokit2") +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.complexity import HurstExponent from junifer.pipeline.utils import _check_ants @@ -22,7 +23,7 @@ # Set parcellation -PARCELLATION = "Schaefer100x17" +PARCELLATION = ["Schaefer100x17"] @pytest.mark.skipif( @@ -46,7 +47,7 @@ def test_compute() -> None: def test_storage_type() -> None: """Test HurstExponent storage_type.""" assert "vector" == HurstExponent(parcellation=PARCELLATION).storage_type( - input_type="BOLD", output_feature="complexity" + input_type=DataType.BOLD, output_feature="complexity" ) diff --git a/junifer/markers/complexity/tests/test_multiscale_entropy_auc.py b/junifer/markers/complexity/tests/test_multiscale_entropy_auc.py index de65c5a82..1fffbab1c 100644 --- a/junifer/markers/complexity/tests/test_multiscale_entropy_auc.py +++ b/junifer/markers/complexity/tests/test_multiscale_entropy_auc.py @@ -11,6 +11,7 @@ pytest.importorskip("neurokit2") +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.complexity import MultiscaleEntropyAUC from junifer.pipeline.utils import _check_ants @@ -21,7 +22,7 @@ # Set parcellation -PARCELLATION = "Schaefer100x17" +PARCELLATION = ["Schaefer100x17"] @pytest.mark.skipif( @@ -46,7 +47,7 @@ def test_storage_type() -> None: """Test MultiscaleEntropyAUC storage_type.""" assert "vector" == MultiscaleEntropyAUC( parcellation=PARCELLATION - ).storage_type(input_type="BOLD", output_feature="complexity") + ).storage_type(input_type=DataType.BOLD, output_feature="complexity") @pytest.mark.skipif( diff --git a/junifer/markers/complexity/tests/test_perm_entropy.py 
b/junifer/markers/complexity/tests/test_perm_entropy.py index ac0db77b2..551e8ebba 100644 --- a/junifer/markers/complexity/tests/test_perm_entropy.py +++ b/junifer/markers/complexity/tests/test_perm_entropy.py @@ -11,6 +11,7 @@ pytest.importorskip("neurokit2") +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.complexity import PermEntropy from junifer.pipeline.utils import _check_ants @@ -21,7 +22,7 @@ # Set parcellation -PARCELLATION = "Schaefer100x17" +PARCELLATION = ["Schaefer100x17"] @pytest.mark.skipif( @@ -45,7 +46,7 @@ def test_compute() -> None: def test_storage_type() -> None: """Test PermEntropy storage_type.""" assert "vector" == PermEntropy(parcellation=PARCELLATION).storage_type( - input_type="BOLD", output_feature="complexity" + input_type=DataType.BOLD, output_feature="complexity" ) diff --git a/junifer/markers/complexity/tests/test_range_entropy.py b/junifer/markers/complexity/tests/test_range_entropy.py index ba477fffe..c226df729 100644 --- a/junifer/markers/complexity/tests/test_range_entropy.py +++ b/junifer/markers/complexity/tests/test_range_entropy.py @@ -12,6 +12,7 @@ pytest.importorskip("neurokit2") +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.complexity import RangeEntropy from junifer.pipeline.utils import _check_ants @@ -22,7 +23,7 @@ # Set parcellation -PARCELLATION = "Schaefer100x17" +PARCELLATION = ["Schaefer100x17"] @pytest.mark.skipif( @@ -46,7 +47,7 @@ def test_compute() -> None: def test_storage_type() -> None: """Test RangeEntropy storage_type.""" assert "vector" == RangeEntropy(parcellation=PARCELLATION).storage_type( - input_type="BOLD", output_feature="complexity" + input_type=DataType.BOLD, output_feature="complexity" ) diff --git a/junifer/markers/complexity/tests/test_range_entropy_auc.py b/junifer/markers/complexity/tests/test_range_entropy_auc.py index 4cfac4720..5465379b0 100644 --- a/junifer/markers/complexity/tests/test_range_entropy_auc.py +++ b/junifer/markers/complexity/tests/test_range_entropy_auc.py @@ -12,6 +12,7 @@ pytest.importorskip("neurokit2") +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.complexity import RangeEntropyAUC from junifer.pipeline.utils import _check_ants @@ -22,7 +23,7 @@ # Set parcellation -PARCELLATION = "Schaefer100x17" +PARCELLATION = ["Schaefer100x17"] @pytest.mark.skipif( @@ -46,7 +47,7 @@ def test_compute() -> None: def test_storage_type() -> None: """Test RangeEntropyAUC storage_type.""" assert "vector" == RangeEntropyAUC(parcellation=PARCELLATION).storage_type( - input_type="BOLD", output_feature="complexity" + input_type=DataType.BOLD, output_feature="complexity" ) diff --git a/junifer/markers/complexity/tests/test_sample_entropy.py b/junifer/markers/complexity/tests/test_sample_entropy.py index fc79cad8c..f54f30ad7 100644 --- a/junifer/markers/complexity/tests/test_sample_entropy.py +++ b/junifer/markers/complexity/tests/test_sample_entropy.py @@ -11,6 +11,7 @@ pytest.importorskip("neurokit2") +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.complexity import SampleEntropy from junifer.pipeline.utils import _check_ants @@ -21,7 +22,7 @@ # Set parcellation -PARCELLATION = "Schaefer100x17" +PARCELLATION = ["Schaefer100x17"] @pytest.mark.skipif( @@ -45,7 +46,7 @@ def test_compute() -> None: def test_storage_type() -> None: """Test SampleEntropy 
storage_type.""" assert "vector" == SampleEntropy(parcellation=PARCELLATION).storage_type( - input_type="BOLD", output_feature="complexity" + input_type=DataType.BOLD, output_feature="complexity" ) diff --git a/junifer/markers/complexity/tests/test_weighted_perm_entropy.py b/junifer/markers/complexity/tests/test_weighted_perm_entropy.py index 485df7f4e..56f56a3da 100644 --- a/junifer/markers/complexity/tests/test_weighted_perm_entropy.py +++ b/junifer/markers/complexity/tests/test_weighted_perm_entropy.py @@ -11,6 +11,7 @@ pytest.importorskip("neurokit2") +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.complexity import WeightedPermEntropy from junifer.pipeline.utils import _check_ants @@ -21,7 +22,7 @@ # Set parcellation -PARCELLATION = "Schaefer100x17" +PARCELLATION = ["Schaefer100x17"] @pytest.mark.skipif( @@ -46,7 +47,7 @@ def test_storage_type() -> None: """Test WeightedPermEntropy storage_type.""" assert "vector" == WeightedPermEntropy( parcellation=PARCELLATION - ).storage_type(input_type="BOLD", output_feature="complexity") + ).storage_type(input_type=DataType.BOLD, output_feature="complexity") @pytest.mark.skipif( diff --git a/junifer/markers/falff/tests/test_falff_maps.py b/junifer/markers/falff/tests/test_falff_maps.py index 12e2c6b9b..42be8d0c2 100644 --- a/junifer/markers/falff/tests/test_falff_maps.py +++ b/junifer/markers/falff/tests/test_falff_maps.py @@ -7,11 +7,11 @@ from pathlib import Path import pytest -import scipy as sp +import scipy.stats as sps -from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datagrabber import DataType, PatternDataladDataGrabber from junifer.datareader import DefaultDataReader -from junifer.markers import ALFFMaps +from junifer.markers import ALFFImpl, ALFFMaps from junifer.pipeline import WorkDirManager from junifer.pipeline.utils import _check_afni from junifer.storage import HDF5FeatureStorage @@ -38,8 +38,8 @@ def test_ALFFMaps_storage_type(feature: str) -> None: """ assert "vector" == ALFFMaps( maps=MAPS, - using="junifer", - ).storage_type(input_type="BOLD", output_feature=feature) + using=ALFFImpl.junifer, + ).storage_type(input_type=DataType.BOLD, output_feature=feature) def test_ALFFMaps( @@ -70,12 +70,12 @@ def test_ALFFMaps( # Initialize marker marker = ALFFMaps( maps=MAPS, - using="junifer", + using=ALFFImpl.junifer, ) # Check correct output for name in ["alff", "falff"]: assert "vector" == marker.storage_type( - input_type="BOLD", output_feature=name + input_type=DataType.BOLD, output_feature=name ) # Fit transform marker on data @@ -99,7 +99,7 @@ def test_ALFFMaps( # Reset log capture caplog.clear() # Initialize storage - storage = HDF5FeatureStorage(tmp_path / "falff_maps.hdf5") + storage = HDF5FeatureStorage(uri=tmp_path / "falff_maps.hdf5") # Fit transform marker on data with storage marker.fit_transform( input=element_data, @@ -141,7 +141,7 @@ def test_ALFFMaps_comparison( # Initialize marker junifer_marker = ALFFMaps( maps=MAPS, - using="junifer", + using=ALFFImpl.junifer, ) # Fit transform marker on data junifer_output = junifer_marker.fit_transform(element_data) @@ -151,7 +151,7 @@ def test_ALFFMaps_comparison( # Initialize marker afni_marker = ALFFMaps( maps=MAPS, - using="afni", + using=ALFFImpl.afni, ) # Fit transform marker on data afni_output = afni_marker.fit_transform(element_data) @@ -160,7 +160,7 @@ def test_ALFFMaps_comparison( for feature in afni_output_bold.keys(): # Check for Pearson correlation coefficient - r, _ = 
sp.stats.pearsonr( + r, _ = sps.pearsonr( junifer_output_bold[feature]["data"][0], afni_output_bold[feature]["data"][0], ) diff --git a/junifer/markers/falff/tests/test_falff_parcels.py b/junifer/markers/falff/tests/test_falff_parcels.py index c9a55b756..5897c8b67 100644 --- a/junifer/markers/falff/tests/test_falff_parcels.py +++ b/junifer/markers/falff/tests/test_falff_parcels.py @@ -8,17 +8,18 @@ from pathlib import Path import pytest -import scipy as sp +import scipy.stats as sps +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader -from junifer.markers.falff import ALFFParcels +from junifer.markers import ALFFImpl, ALFFParcels from junifer.pipeline import WorkDirManager from junifer.pipeline.utils import _check_afni from junifer.storage import SQLiteFeatureStorage from junifer.testing.datagrabbers import PartlyCloudyTestingDataGrabber -PARCELLATION = "TianxS1x3TxMNInonlinear2009cAsym" +PARCELLATION = ["TianxS1x3TxMNInonlinear2009cAsym"] @pytest.mark.parametrize( @@ -39,8 +40,8 @@ def test_ALFFParcels_storage_type(feature: str) -> None: """ assert "vector" == ALFFParcels( parcellation=PARCELLATION, - using="junifer", - ).storage_type(input_type="BOLD", output_feature=feature) + using=ALFFImpl.junifer, + ).storage_type(input_type=DataType.BOLD, output_feature=feature) def test_ALFFParcels(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: @@ -63,7 +64,7 @@ def test_ALFFParcels(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: # Initialize marker marker = ALFFParcels( parcellation=PARCELLATION, - using="junifer", + using=ALFFImpl.junifer, ) # Fit transform marker on data output = marker.fit_transform(element_data) @@ -86,7 +87,9 @@ def test_ALFFParcels(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: # Reset log capture caplog.clear() # Initialize storage - storage = SQLiteFeatureStorage(tmp_path / "falff_parcels.sqlite") + storage = SQLiteFeatureStorage( + uri=tmp_path / "falff_parcels.sqlite" + ) # Fit transform marker on data with storage marker.fit_transform( input=element_data, @@ -116,7 +119,7 @@ def test_ALFFParcels_comparison(tmp_path: Path) -> None: # Initialize marker junifer_marker = ALFFParcels( parcellation=PARCELLATION, - using="junifer", + using=ALFFImpl.junifer, ) # Fit transform marker on data junifer_output = junifer_marker.fit_transform(element_data) @@ -126,7 +129,7 @@ def test_ALFFParcels_comparison(tmp_path: Path) -> None: # Initialize marker afni_marker = ALFFParcels( parcellation=PARCELLATION, - using="afni", + using=ALFFImpl.afni, ) # Fit transform marker on data afni_output = afni_marker.fit_transform(element_data) @@ -135,7 +138,7 @@ def test_ALFFParcels_comparison(tmp_path: Path) -> None: for feature in afni_output_bold.keys(): # Check for Pearson correlation coefficient - r, _ = sp.stats.pearsonr( + r, _ = sps.pearsonr( junifer_output_bold[feature]["data"][0], afni_output_bold[feature]["data"][0], ) diff --git a/junifer/markers/falff/tests/test_falff_spheres.py b/junifer/markers/falff/tests/test_falff_spheres.py index 4a254448e..c6183c9ba 100644 --- a/junifer/markers/falff/tests/test_falff_spheres.py +++ b/junifer/markers/falff/tests/test_falff_spheres.py @@ -8,10 +8,11 @@ from pathlib import Path import pytest -import scipy as sp +import scipy.stats as sps +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader -from junifer.markers.falff import ALFFSpheres +from junifer.markers import ALFFImpl, ALFFSpheres from junifer.pipeline import WorkDirManager from 
junifer.pipeline.utils import _check_afni from junifer.storage import SQLiteFeatureStorage @@ -39,8 +40,8 @@ def test_ALFFSpheres_storage_type(feature: str) -> None: """ assert "vector" == ALFFSpheres( coords=COORDINATES, - using="junifer", - ).storage_type(input_type="BOLD", output_feature=feature) + using=ALFFImpl.junifer, + ).storage_type(input_type=DataType.BOLD, output_feature=feature) def test_ALFFSpheres(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: @@ -63,7 +64,7 @@ def test_ALFFSpheres(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: # Initialize marker marker = ALFFSpheres( coords=COORDINATES, - using="junifer", + using=ALFFImpl.junifer, radius=5.0, ) # Fit transform marker on data @@ -89,7 +90,9 @@ def test_ALFFSpheres(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: # Reset log capture caplog.clear() # Initialize storage - storage = SQLiteFeatureStorage(tmp_path / "falff_spheres.sqlite") + storage = SQLiteFeatureStorage( + uri=tmp_path / "falff_spheres.sqlite" + ) # Fit transform marker on data with storage marker.fit_transform( input=element_data, @@ -119,7 +122,7 @@ def test_ALFFSpheres_comparison(tmp_path: Path) -> None: # Initialize marker junifer_marker = ALFFSpheres( coords=COORDINATES, - using="junifer", + using=ALFFImpl.junifer, radius=5.0, ) # Fit transform marker on data @@ -130,7 +133,7 @@ def test_ALFFSpheres_comparison(tmp_path: Path) -> None: # Initialize marker afni_marker = ALFFSpheres( coords=COORDINATES, - using="afni", + using=ALFFImpl.afni, radius=5.0, ) # Fit transform marker on data @@ -140,7 +143,7 @@ def test_ALFFSpheres_comparison(tmp_path: Path) -> None: for feature in afni_output_bold.keys(): # Check for Pearson correlation coefficient - r, _ = sp.stats.pearsonr( + r, _ = sps.pearsonr( junifer_output_bold[feature]["data"][0], afni_output_bold[feature]["data"][0], ) diff --git a/junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py b/junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py index 19584a731..c86f3e671 100644 --- a/junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py +++ b/junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py @@ -9,11 +9,12 @@ import pytest +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader -from junifer.markers.functional_connectivity import CrossParcellationFC +from junifer.markers import CrossParcellationFC from junifer.pipeline import WorkDirManager from junifer.pipeline.utils import _check_ants -from junifer.storage import SQLiteFeatureStorage +from junifer.storage import SQLiteFeatureStorage, Upsert from junifer.testing.datagrabbers import SPMAuditoryTestingDataGrabber @@ -35,7 +36,9 @@ def test_storage_type() -> None: """Test CrossParcellationFC storage_type.""" assert "matrix" == CrossParcellationFC( parcellation_one=parcellation_one, parcellation_two=parcellation_two - ).storage_type(input_type="BOLD", output_feature="functional_connectivity") + ).storage_type( + input_type=DataType.BOLD, output_feature="functional_connectivity" + ) @pytest.mark.skipif( @@ -87,7 +90,8 @@ def test_store(tmp_path: Path) -> None: corr_method="spearman", ) storage = SQLiteFeatureStorage( - uri=tmp_path / "test_crossparcellation.sqlite", upsert="ignore" + uri=tmp_path / "test_crossparcellation.sqlite", + upsert=Upsert.Ignore, ) # Fit transform marker on data with storage 
crossparcellation.fit_transform(input=element_data, storage=storage) diff --git a/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_maps.py b/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_maps.py index 84b763925..ffa70f338 100644 --- a/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_maps.py +++ b/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_maps.py @@ -7,7 +7,7 @@ import pytest -from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datagrabber import DataType, PatternDataladDataGrabber from junifer.datareader import DefaultDataReader from junifer.markers import EdgeCentricFCMaps from junifer.storage import HDF5FeatureStorage @@ -48,7 +48,7 @@ def test_EdgeCentricFCMaps( ) # Check correct output assert "matrix" == marker.storage_type( - input_type="BOLD", output_feature="functional_connectivity" + input_type=DataType.BOLD, output_feature="functional_connectivity" ) # Fit-transform the data diff --git a/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py b/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py index e6083b960..779b736fe 100644 --- a/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +++ b/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py @@ -8,9 +8,10 @@ import pytest +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.functional_connectivity import EdgeCentricFCParcels -from junifer.storage import SQLiteFeatureStorage +from junifer.storage import SQLiteFeatureStorage, Upsert from junifer.testing.datagrabbers import PartlyCloudyTestingDataGrabber @@ -40,13 +41,13 @@ def test_EdgeCentricFCParcels( element_data = DefaultDataReader().fit_transform(dg["sub-01"]) # Setup marker marker = EdgeCentricFCParcels( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], conn_method="correlation", conn_method_params=conn_method_params, ) # Check correct output assert "matrix" == marker.storage_type( - input_type="BOLD", output_feature="functional_connectivity" + input_type=DataType.BOLD, output_feature="functional_connectivity" ) # Fit-transform the data @@ -64,7 +65,7 @@ def test_EdgeCentricFCParcels( # Store storage = SQLiteFeatureStorage( - uri=tmp_path / "test_edge_fc_parcels.sqlite", upsert="ignore" + uri=tmp_path / "test_edge_fc_parcels.sqlite", upsert=Upsert.Ignore ) marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() diff --git a/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py b/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py index db9f76b24..9cabc0b62 100644 --- a/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +++ b/junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py @@ -8,9 +8,10 @@ import pytest +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.functional_connectivity import EdgeCentricFCSpheres -from junifer.storage import SQLiteFeatureStorage +from junifer.storage import SQLiteFeatureStorage, Upsert from junifer.testing.datagrabbers import SPMAuditoryTestingDataGrabber @@ -47,7 +48,7 @@ def 
test_EdgeCentricFCSpheres( ) # Check correct output assert "matrix" == marker.storage_type( - input_type="BOLD", output_feature="functional_connectivity" + input_type=DataType.BOLD, output_feature="functional_connectivity" ) # Fit-transform the data @@ -66,7 +67,7 @@ def test_EdgeCentricFCSpheres( # Store storage = SQLiteFeatureStorage( - uri=tmp_path / "test_edge_fc_spheres.sqlite", upsert="ignore" + uri=tmp_path / "test_edge_fc_spheres.sqlite", upsert=Upsert.Ignore ) marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() diff --git a/junifer/markers/functional_connectivity/tests/test_functional_connectivity_maps.py b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_maps.py index 281ad3210..e9dfd2808 100644 --- a/junifer/markers/functional_connectivity/tests/test_functional_connectivity_maps.py +++ b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_maps.py @@ -13,7 +13,7 @@ from sklearn.covariance import EmpiricalCovariance, LedoitWolf from junifer.data import MapsRegistry -from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datagrabber import DataType, PatternDataladDataGrabber from junifer.datareader import DefaultDataReader from junifer.markers import FunctionalConnectivityMaps from junifer.storage import HDF5FeatureStorage @@ -61,7 +61,7 @@ def test_FunctionalConnectivityMaps( ) # Check correct output assert "matrix" == marker.storage_type( - input_type="BOLD", output_feature="functional_connectivity" + input_type=DataType.BOLD, output_feature="functional_connectivity" ) # Fit-transform the data diff --git a/junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py index 8f495b460..584582031 100644 --- a/junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +++ b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py @@ -15,11 +15,12 @@ from sklearn.covariance import EmpiricalCovariance, LedoitWolf from junifer.data import ParcellationRegistry +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader -from junifer.markers.functional_connectivity import ( +from junifer.markers import ( FunctionalConnectivityParcels, ) -from junifer.storage import SQLiteFeatureStorage +from junifer.storage import SQLiteFeatureStorage, Upsert from junifer.testing.datagrabbers import PartlyCloudyTestingDataGrabber @@ -56,13 +57,13 @@ def test_FunctionalConnectivityParcels( element_data = DefaultDataReader().fit_transform(dg["sub-01"]) # Setup marker marker = FunctionalConnectivityParcels( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], conn_method="correlation", conn_method_params=conn_method_params, ) # Check correct output assert "matrix" == marker.storage_type( - input_type="BOLD", output_feature="functional_connectivity" + input_type=DataType.BOLD, output_feature="functional_connectivity" ) # Fit-transform the data @@ -102,7 +103,7 @@ def test_FunctionalConnectivityParcels( # Store storage = SQLiteFeatureStorage( - uri=tmp_path / "test_fc_parcels.sqlite", upsert="ignore" + uri=tmp_path / "test_fc_parcels.sqlite", upsert=Upsert.Ignore ) marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() diff --git a/junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py 
b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py index f80ea1f3e..1ba98ae95 100644 --- a/junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +++ b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py @@ -16,11 +16,12 @@ from sklearn.covariance import EmpiricalCovariance, LedoitWolf from junifer.data import CoordinatesRegistry +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.functional_connectivity import ( FunctionalConnectivitySpheres, ) -from junifer.storage import SQLiteFeatureStorage +from junifer.storage import SQLiteFeatureStorage, Upsert from junifer.testing.datagrabbers import SPMAuditoryTestingDataGrabber @@ -64,7 +65,7 @@ def test_FunctionalConnectivitySpheres( ) # Check correct output assert "matrix" == marker.storage_type( - input_type="BOLD", output_feature="functional_connectivity" + input_type=DataType.BOLD, output_feature="functional_connectivity" ) # Fit-transform the data @@ -103,7 +104,7 @@ def test_FunctionalConnectivitySpheres( # Store storage = SQLiteFeatureStorage( - uri=tmp_path / "test_fc_spheres.sqlite", upsert="ignore" + uri=tmp_path / "test_fc_spheres.sqlite", upsert=Upsert.Ignore ) marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() @@ -112,11 +113,3 @@ def test_FunctionalConnectivitySpheres( == "BOLD_FunctionalConnectivitySpheres_functional_connectivity" for x in features.values() ) - - -def test_FunctionalConnectivitySpheres_error() -> None: - """Test FunctionalConnectivitySpheres errors.""" - with pytest.raises(ValueError, match="radius should be > 0"): - FunctionalConnectivitySpheres( - coords="DMNBuckner", radius=-0.1, conn_method="correlation" - ) diff --git a/junifer/markers/reho/tests/test_reho_maps.py b/junifer/markers/reho/tests/test_reho_maps.py index 3dac74678..4cbb2ed5c 100644 --- a/junifer/markers/reho/tests/test_reho_maps.py +++ b/junifer/markers/reho/tests/test_reho_maps.py @@ -7,11 +7,11 @@ from pathlib import Path import pytest -import scipy as sp +import scipy.stats as sps -from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datagrabber import DataType, PatternDataladDataGrabber from junifer.datareader import DefaultDataReader -from junifer.markers import ReHoMaps +from junifer.markers import ReHoImpl, ReHoMaps from junifer.pipeline import WorkDirManager from junifer.pipeline.utils import _check_afni from junifer.storage import HDF5FeatureStorage @@ -45,11 +45,11 @@ def test_ReHoMaps( # Initialize marker marker = ReHoMaps( maps="Smith_rsn_10", - using="junifer", + using=ReHoImpl.junifer, ) # Check correct output assert "vector" == marker.storage_type( - input_type="BOLD", output_feature="reho" + input_type=DataType.BOLD, output_feature="reho" ) # Fit transform marker on data @@ -75,7 +75,7 @@ def test_ReHoMaps( # Reset log capture caplog.clear() # Initialize storage - storage = HDF5FeatureStorage(tmp_path / "reho_maps.hdf5") + storage = HDF5FeatureStorage(uri=tmp_path / "reho_maps.hdf5") # Fit transform marker on data with storage marker.fit_transform( input=element_data, @@ -114,21 +114,21 @@ def test_ReHoMaps_comparison( element_data = DefaultDataReader().fit_transform(element) # Initialize marker - junifer_marker = ReHoMaps(maps="Smith_rsn_10", using="junifer") + junifer_marker = ReHoMaps(maps="Smith_rsn_10", using=ReHoImpl.junifer) # Fit transform marker on data junifer_output = 
junifer_marker.fit_transform(element_data) # Get BOLD output junifer_output_bold = junifer_output["BOLD"]["reho"] # Initialize marker - afni_marker = ReHoMaps(maps="Smith_rsn_10", using="afni") + afni_marker = ReHoMaps(maps="Smith_rsn_10", using=ReHoImpl.afni) # Fit transform marker on data afni_output = afni_marker.fit_transform(element_data) # Get BOLD output afni_output_bold = afni_output["BOLD"]["reho"] # Check for Pearson correlation coefficient - r, _ = sp.stats.pearsonr( + r, _ = sps.pearsonr( junifer_output_bold["data"].flatten(), afni_output_bold["data"].flatten(), ) diff --git a/junifer/markers/reho/tests/test_reho_parcels.py b/junifer/markers/reho/tests/test_reho_parcels.py index db0bcf245..484b419ea 100644 --- a/junifer/markers/reho/tests/test_reho_parcels.py +++ b/junifer/markers/reho/tests/test_reho_parcels.py @@ -7,10 +7,11 @@ from pathlib import Path import pytest -import scipy as sp +import scipy.stats as sps +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader -from junifer.markers import ReHoParcels +from junifer.markers import ReHoImpl, ReHoParcels from junifer.pipeline import WorkDirManager from junifer.pipeline.utils import _check_afni, _check_ants from junifer.storage import SQLiteFeatureStorage @@ -39,12 +40,12 @@ def test_ReHoParcels(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: # Initialize marker marker = ReHoParcels( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", - using="junifer", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], + using=ReHoImpl.junifer, ) # Check correct output assert "vector" == marker.storage_type( - input_type="BOLD", output_feature="reho" + input_type=DataType.BOLD, output_feature="reho" ) # Fit transform marker on data @@ -70,7 +71,9 @@ def test_ReHoParcels(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: # Reset log capture caplog.clear() # Initialize storage - storage = SQLiteFeatureStorage(tmp_path / "reho_parcels.sqlite") + storage = SQLiteFeatureStorage( + uri=tmp_path / "reho_parcels.sqlite" + ) # Fit transform marker on data with storage marker.fit_transform( input=element_data, @@ -102,7 +105,7 @@ def test_ReHoParcels_comparison(tmp_path: Path) -> None: # Initialize marker junifer_marker = ReHoParcels( - parcellation="Schaefer100x7", using="junifer" + parcellation=["Schaefer100x7"], using=ReHoImpl.junifer ) # Fit transform marker on data junifer_output = junifer_marker.fit_transform(element_data) @@ -110,14 +113,16 @@ def test_ReHoParcels_comparison(tmp_path: Path) -> None: junifer_output_bold = junifer_output["BOLD"]["reho"] # Initialize marker - afni_marker = ReHoParcels(parcellation="Schaefer100x7", using="afni") + afni_marker = ReHoParcels( + parcellation=["Schaefer100x7"], using=ReHoImpl.afni + ) # Fit transform marker on data afni_output = afni_marker.fit_transform(element_data) # Get BOLD output afni_output_bold = afni_output["BOLD"]["reho"] # Check for Pearson correlation coefficient - r, _ = sp.stats.pearsonr( + r, _ = sps.pearsonr( junifer_output_bold["data"].flatten(), afni_output_bold["data"].flatten(), ) diff --git a/junifer/markers/reho/tests/test_reho_spheres.py b/junifer/markers/reho/tests/test_reho_spheres.py index aeee29c68..166641df5 100644 --- a/junifer/markers/reho/tests/test_reho_spheres.py +++ b/junifer/markers/reho/tests/test_reho_spheres.py @@ -7,10 +7,11 @@ from pathlib import Path import pytest -import scipy as sp +import scipy.stats as sps +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader -from 
junifer.markers import ReHoSpheres +from junifer.markers import ReHoImpl, ReHoSpheres from junifer.pipeline import WorkDirManager from junifer.pipeline.utils import _check_afni from junifer.storage import SQLiteFeatureStorage @@ -38,11 +39,11 @@ def test_ReHoSpheres(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: WorkDirManager().workdir = tmp_path # Initialize marker marker = ReHoSpheres( - coords=COORDINATES, using="junifer", radius=10.0 + coords=COORDINATES, using=ReHoImpl.junifer, radius=10.0 ) # Check correct output assert "vector" == marker.storage_type( - input_type="BOLD", output_feature="reho" + input_type=DataType.BOLD, output_feature="reho" ) # Fit transform marker on data @@ -68,7 +69,9 @@ def test_ReHoSpheres(caplog: pytest.LogCaptureFixture, tmp_path: Path) -> None: # Reset log capture caplog.clear() # Initialize storage - storage = SQLiteFeatureStorage(tmp_path / "reho_spheres.sqlite") + storage = SQLiteFeatureStorage( + uri=tmp_path / "reho_spheres.sqlite" + ) # Fit transform marker on data with storage marker.fit_transform( input=element_data, @@ -98,7 +101,7 @@ def test_ReHoSpheres_comparison(tmp_path: Path) -> None: # Initialize marker junifer_marker = ReHoSpheres( coords=COORDINATES, - using="junifer", + using=ReHoImpl.junifer, radius=10.0, ) # Fit transform marker on data @@ -109,7 +112,7 @@ def test_ReHoSpheres_comparison(tmp_path: Path) -> None: # Initialize marker afni_marker = ReHoSpheres( coords=COORDINATES, - using="afni", + using=ReHoImpl.afni, radius=10.0, ) # Fit transform marker on data @@ -118,7 +121,7 @@ def test_ReHoSpheres_comparison(tmp_path: Path) -> None: afni_output_bold = afni_output["BOLD"]["reho"] # Check for Pearson correlation coefficient - r, _ = sp.stats.pearsonr( + r, _ = sps.pearsonr( junifer_output_bold["data"].flatten(), afni_output_bold["data"].flatten(), ) diff --git a/junifer/markers/temporal_snr/tests/test_temporal_snr_maps.py b/junifer/markers/temporal_snr/tests/test_temporal_snr_maps.py index 320467421..a4612e102 100644 --- a/junifer/markers/temporal_snr/tests/test_temporal_snr_maps.py +++ b/junifer/markers/temporal_snr/tests/test_temporal_snr_maps.py @@ -5,7 +5,7 @@ from pathlib import Path -from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datagrabber import DataType, PatternDataladDataGrabber from junifer.datareader import DefaultDataReader from junifer.markers import TemporalSNRMaps from junifer.storage import HDF5FeatureStorage @@ -28,9 +28,8 @@ def test_TemporalSNRMaps_computation( marker = TemporalSNRMaps(maps="Smith_rsn_10") # Check correct output assert "vector" == marker.storage_type( - input_type="BOLD", output_feature="tsnr" + input_type=DataType.BOLD, output_feature="tsnr" ) - # Fit-transform the data tsnr_parcels = marker.fit_transform(element_data) tsnr_parcels_bold = tsnr_parcels["BOLD"]["tsnr"] @@ -59,7 +58,7 @@ def test_TemporalSNRMaps_storage( element_data = DefaultDataReader().fit_transform(element) marker = TemporalSNRMaps(maps="Smith_rsn_10") # Store - storage = HDF5FeatureStorage(tmp_path / "test_tsnr_maps.hdf5") + storage = HDF5FeatureStorage(uri=tmp_path / "test_tsnr_maps.hdf5") marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() assert any( diff --git a/junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py b/junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py index 8cb535b05..9d7dd6cd6 100644 --- a/junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py +++ 
b/junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py @@ -6,6 +6,7 @@ from pathlib import Path +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.temporal_snr import TemporalSNRParcels from junifer.storage import HDF5FeatureStorage @@ -17,13 +18,12 @@ def test_TemporalSNRParcels_computation() -> None: with PartlyCloudyTestingDataGrabber() as dg: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) marker = TemporalSNRParcels( - parcellation="TianxS1x3TxMNInonlinear2009cAsym" + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"] ) # Check correct output assert "vector" == marker.storage_type( - input_type="BOLD", output_feature="tsnr" + input_type=DataType.BOLD, output_feature="tsnr" ) - # Fit-transform the data tsnr_parcels = marker.fit_transform(element_data) tsnr_parcels_bold = tsnr_parcels["BOLD"]["tsnr"] @@ -46,10 +46,10 @@ def test_TemporalSNRParcels_storage(tmp_path: Path) -> None: with PartlyCloudyTestingDataGrabber() as dg: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) marker = TemporalSNRParcels( - parcellation="TianxS1x3TxMNInonlinear2009cAsym" + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"] ) # Store - storage = HDF5FeatureStorage(tmp_path / "test_tsnr_parcels.hdf5") + storage = HDF5FeatureStorage(uri=tmp_path / "test_tsnr_parcels.hdf5") marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() assert any( diff --git a/junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py b/junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py index 81e871c6c..b5a102991 100644 --- a/junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py +++ b/junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py @@ -6,8 +6,7 @@ from pathlib import Path -import pytest - +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.temporal_snr import TemporalSNRSpheres from junifer.storage import HDF5FeatureStorage @@ -21,9 +20,8 @@ def test_TemporalSNRSpheres_computation() -> None: marker = TemporalSNRSpheres(coords="DMNBuckner", radius=5.0) # Check correct output assert "vector" == marker.storage_type( - input_type="BOLD", output_feature="tsnr" + input_type=DataType.BOLD, output_feature="tsnr" ) - # Fit-transform the data tsnr_spheres = marker.fit_transform(element_data) tsnr_spheres_bold = tsnr_spheres["BOLD"]["tsnr"] @@ -47,16 +45,10 @@ def test_TemporalSNRSpheres_storage(tmp_path: Path) -> None: element_data = DefaultDataReader().fit_transform(dg["sub001"]) marker = TemporalSNRSpheres(coords="DMNBuckner", radius=5.0) # Store - storage = HDF5FeatureStorage(tmp_path / "test_tsnr_spheres.hdf5") + storage = HDF5FeatureStorage(uri=tmp_path / "test_tsnr_spheres.hdf5") marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() assert any( x["name"] == "BOLD_TemporalSNRSpheres_tsnr" for x in features.values() ) - - -def test_TemporalSNRSpheres_error() -> None: - """Test TemporalSNRSpheres errors.""" - with pytest.raises(ValueError, match="radius should be > 0"): - TemporalSNRSpheres(coords="DMNBuckner", radius=-0.1) diff --git a/junifer/markers/tests/test_brainprint.py b/junifer/markers/tests/test_brainprint.py index 65ef387a2..28ab40d19 100644 --- a/junifer/markers/tests/test_brainprint.py +++ b/junifer/markers/tests/test_brainprint.py @@ -7,34 +7,35 @@ import pytest -from junifer.datagrabber import DataladAOMICID1000 +from junifer.datagrabber import 
DataladAOMICID1000, DataType from junifer.datareader import DefaultDataReader from junifer.markers import BrainPrint from junifer.pipeline.utils import _check_freesurfer +from junifer.storage import StorageType @pytest.mark.parametrize( "feature, storage_type", [ - ("eigenvalues", "scalar_table"), - ("areas", "vector"), - ("volumes", "vector"), - ("distances", "vector"), + ("eigenvalues", StorageType.ScalarTable), + ("areas", StorageType.Vector), + ("volumes", StorageType.Vector), + ("distances", StorageType.Vector), ], ) -def test_storage_type(feature: str, storage_type: str) -> None: +def test_storage_type(feature: str, storage_type: StorageType) -> None: """Test BrainPrint storage_type. Parameters ---------- feature : str The parametrized feature name. - storage_type : str + storage_type : StorageType The parametrized storage type. """ assert storage_type == BrainPrint().storage_type( - input_type="FreeSurfer", output_feature=feature + input_type=DataType.FreeSurfer, output_feature=feature ) @@ -47,7 +48,7 @@ def test_storage_type(feature: str, storage_type: str) -> None: ) def test_compute() -> None: """Test BrainPrint compute().""" - with DataladAOMICID1000(types="FreeSurfer") as dg: + with DataladAOMICID1000(types=[DataType.FreeSurfer]) as dg: # Fetch element element = dg["sub-0001"] # Fetch element data diff --git a/junifer/markers/tests/test_ets_rss.py b/junifer/markers/tests/test_ets_rss.py index 5f478cc60..070a5a43b 100644 --- a/junifer/markers/tests/test_ets_rss.py +++ b/junifer/markers/tests/test_ets_rss.py @@ -18,7 +18,7 @@ # Set parcellation -PARCELLATION = "TianxS1x3TxMNInonlinear2009cAsym" +PARCELLATION = ["TianxS1x3TxMNInonlinear2009cAsym"] def test_compute() -> None: @@ -33,7 +33,7 @@ def test_compute() -> None: # Compare with nilearn # Load testing parcellation test_parcellation, _ = ParcellationRegistry().get( - parcellations=[PARCELLATION], + parcellations=PARCELLATION, target_data=element_data["BOLD"], ) # Extract timeseries @@ -65,7 +65,7 @@ def test_store(tmp_path: Path) -> None: # Get element data element_data = DefaultDataReader().fit_transform(dg["sub-01"]) # Create storage - storage = SQLiteFeatureStorage(tmp_path / "test_rss_ets.sqlite") + storage = SQLiteFeatureStorage(uri=tmp_path / "test_rss_ets.sqlite") # Compute the RSSETSMarker and store _ = RSSETSMarker(parcellation=PARCELLATION).fit_transform( input=element_data, storage=storage diff --git a/junifer/markers/tests/test_maps_aggregation.py b/junifer/markers/tests/test_maps_aggregation.py index 10babfa30..85937970c 100644 --- a/junifer/markers/tests/test_maps_aggregation.py +++ b/junifer/markers/tests/test_maps_aggregation.py @@ -12,68 +12,68 @@ from junifer.data import MapsRegistry, MaskRegistry from junifer.data.masks import compute_brain_mask -from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datagrabber import DataType, PatternDataladDataGrabber from junifer.datareader import DefaultDataReader from junifer.markers import MapsAggregation -from junifer.storage import HDF5FeatureStorage +from junifer.storage import HDF5FeatureStorage, StorageType @pytest.mark.parametrize( "input_type, storage_type", [ ( - "T1w", - "vector", + DataType.T1w, + StorageType.Vector, ), ( - "T2w", - "vector", + DataType.T2w, + StorageType.Vector, ), ( - "BOLD", - "timeseries", + DataType.BOLD, + StorageType.Timeseries, ), ( - "VBM_GM", - "vector", + DataType.VBM_GM, + StorageType.Vector, ), ( - "VBM_WM", - "vector", + DataType.VBM_WM, + StorageType.Vector, ), ( - "VBM_CSF", - "vector", + DataType.VBM_CSF, 
+ StorageType.Vector, ), ( - "fALFF", - "vector", + DataType.FALFF, + StorageType.Vector, ), ( - "GCOR", - "vector", + DataType.GCOR, + StorageType.Vector, ), ( - "LCOR", - "vector", + DataType.LCOR, + StorageType.Vector, ), ], ) def test_MapsAggregation_input_output( - input_type: str, storage_type: str + input_type: DataType, storage_type: StorageType ) -> None: """Test MapsAggregation input and output types. Parameters ---------- - input_type : str + input_type : DataType The parametrized input type. - storage_type : str + storage_type : StorageType The parametrized storage type. """ assert storage_type == MapsAggregation( - maps="Smith_rsn_10", on=input_type + maps="Smith_rsn_10", on=[input_type] ).storage_type(input_type=input_type, output_feature="aggregation") @@ -184,7 +184,7 @@ def test_MapsAggregation_storage( ) marker = MapsAggregation( maps="Smith_rsn_10", - on="BOLD", + on=[DataType.BOLD], ) element_data["BOLD"]["data"] = element_data["BOLD"]["data"].slicer[ ..., 0:1 @@ -205,7 +205,7 @@ def test_MapsAggregation_storage( ) marker = MapsAggregation( maps="Smith_rsn_10", - on="BOLD", + on=[DataType.BOLD], ) marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() @@ -238,8 +238,8 @@ def test_MapsAggregation_3D_mask( ] marker = MapsAggregation( maps="Smith_rsn_10", - on="BOLD", - masks="compute_brain_mask", + on=[DataType.BOLD], + masks=["compute_brain_mask"], ) maps_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -306,8 +306,8 @@ def test_MapsAggregation_3D_mask_computed( # Use the MapsAggregation object marker = MapsAggregation( maps="Smith_rsn_10", - masks={"compute_brain_mask": {"threshold": 0.2}}, - on="BOLD", + masks=[{"compute_brain_mask": {"threshold": 0.2}}], + on=[DataType.BOLD], ) maps_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -342,7 +342,7 @@ def test_MapsAggregation_4D_agg_time( marker = MapsAggregation( maps="Smith_rsn_10", time_method="mean", - on="BOLD", + on=[DataType.BOLD], ) maps_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -375,7 +375,7 @@ def test_MapsAggregation_4D_agg_time( maps="Smith_rsn_10", time_method="select", time_method_params={"pick": [0]}, - on="BOLD", + on=[DataType.BOLD], ) maps_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -395,7 +395,7 @@ def test_MapsAggregation_errors() -> None: maps="Smith_rsn_10", time_method="select", time_method_params={"pick": [0]}, - on="VBM_GM", + on=[DataType.VBM_GM], ) with pytest.raises( @@ -404,7 +404,7 @@ def test_MapsAggregation_errors() -> None: MapsAggregation( maps="Smith_rsn_10", time_method_params={"pick": [0]}, - on="VBM_GM", + on=[DataType.VBM_GM], ) @@ -422,7 +422,7 @@ def test_MapsAggregation_warning( maps="Smith_rsn_10", time_method="select", time_method_params={"pick": [0]}, - on="BOLD", + on=[DataType.BOLD], ) element_data["BOLD"]["data"] = element_data["BOLD"]["data"].slicer[ ..., 0:1 diff --git a/junifer/markers/tests/test_markers_base.py b/junifer/markers/tests/test_markers_base.py index b513bc003..79f2b72fa 100644 --- a/junifer/markers/tests/test_markers_base.py +++ b/junifer/markers/tests/test_markers_base.py @@ -6,7 +6,9 @@ import pytest -from junifer.markers.base import BaseMarker +from junifer.datagrabber import DataType +from junifer.markers import BaseMarker +from junifer.storage import StorageType def test_base_marker_abstractness() -> None: @@ -21,14 +23,12 @@ def test_base_marker_subclassing() -> None: # Create concrete class 
class MyBaseMarker(BaseMarker): _MARKER_INOUT_MAPPINGS = { # noqa: RUF012 - "BOLD": { - "feat_1": "timeseries", + DataType.BOLD: { + "feat_1": StorageType.Timeseries, }, } - def __init__(self, on, name=None) -> None: - self.parameter = 1 - super().__init__(on, name) + parameter: int = 1 def compute(self, input, extra_input): return { diff --git a/junifer/markers/tests/test_parcel_aggregation.py b/junifer/markers/tests/test_parcel_aggregation.py index a0ff8c378..2e4d71dd3 100644 --- a/junifer/markers/tests/test_parcel_aggregation.py +++ b/junifer/markers/tests/test_parcel_aggregation.py @@ -17,9 +17,10 @@ from scipy.stats import trim_mean from junifer.data import MaskRegistry, ParcellationRegistry +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.parcel_aggregation import ParcelAggregation -from junifer.storage import SQLiteFeatureStorage +from junifer.storage import SQLiteFeatureStorage, StorageType, Upsert from junifer.testing.datagrabbers import PartlyCloudyTestingDataGrabber @@ -27,58 +28,58 @@ "input_type, storage_type", [ ( - "T1w", - "vector", + DataType.T1w, + StorageType.Vector, ), ( - "T2w", - "vector", + DataType.T2w, + StorageType.Vector, ), ( - "BOLD", - "timeseries", + DataType.BOLD, + StorageType.Timeseries, ), ( - "VBM_GM", - "vector", + DataType.VBM_GM, + StorageType.Vector, ), ( - "VBM_WM", - "vector", + DataType.VBM_WM, + StorageType.Vector, ), ( - "VBM_CSF", - "vector", + DataType.VBM_CSF, + StorageType.Vector, ), ( - "fALFF", - "vector", + DataType.FALFF, + StorageType.Vector, ), ( - "GCOR", - "vector", + DataType.GCOR, + StorageType.Vector, ), ( - "LCOR", - "vector", + DataType.LCOR, + StorageType.Vector, ), ], ) def test_ParcelAggregation_input_output( - input_type: str, storage_type: str + input_type: DataType, storage_type: StorageType ) -> None: """Test ParcelAggregation input and output types. Parameters ---------- - input_type : str + input_type : DataType The parametrized input type. - storage_type : str + storage_type : StorageType The parametrized storage type. 
""" assert storage_type == ParcelAggregation( - parcellation="Schaefer100x7", method="mean", on=input_type + parcellation=["Schaefer100x7"], method="mean", on=[input_type] ).storage_type(input_type=input_type, output_feature="aggregation") @@ -88,9 +89,9 @@ def test_ParcelAggregation_3D() -> None: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) # Create ParcelAggregation object marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", - on="BOLD", + on=[DataType.BOLD], ) element_data["BOLD"]["data"] = element_data["BOLD"]["data"].slicer[ ..., 0:1 @@ -155,9 +156,9 @@ def test_ParcelAggregation_3D() -> None: # Create ParcelAggregation object marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="std", - on="BOLD", + on=[DataType.BOLD], ) parcel_agg_std_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -179,10 +180,10 @@ def test_ParcelAggregation_3D() -> None: # Create ParcelAggregation object marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="trim_mean", method_params={"proportiontocut": 0.1}, - on="BOLD", + on=[DataType.BOLD], ) parcel_agg_trim_mean_bold_data = marker.fit_transform(element_data)[ "BOLD" @@ -198,7 +199,7 @@ def test_ParcelAggregation_4D(): element_data = DefaultDataReader().fit_transform(dg["sub-01"]) # Create ParcelAggregation object marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", method="mean" + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean" ) parcel_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -238,12 +239,13 @@ def test_ParcelAggregation_storage(tmp_path: Path) -> None: with PartlyCloudyTestingDataGrabber() as dg: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) storage = SQLiteFeatureStorage( - uri=tmp_path / "test_parcel_storage_3D.sqlite", upsert="ignore" + uri=tmp_path / "test_parcel_storage_3D.sqlite", + upsert=Upsert.Ignore, ) marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", - on="BOLD", + on=[DataType.BOLD], ) element_data["BOLD"]["data"] = element_data["BOLD"]["data"].slicer[ ..., 0:1 @@ -259,12 +261,13 @@ def test_ParcelAggregation_storage(tmp_path: Path) -> None: with PartlyCloudyTestingDataGrabber() as dg: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) storage = SQLiteFeatureStorage( - uri=tmp_path / "test_parcel_storage_4D.sqlite", upsert="ignore" + uri=tmp_path / "test_parcel_storage_4D.sqlite", + upsert=Upsert.Ignore, ) marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", - on="BOLD", + on=[DataType.BOLD], ) marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() @@ -280,11 +283,11 @@ def test_ParcelAggregation_3D_mask() -> None: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) # Create ParcelAggregation object marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", name="tian_mean", - on="BOLD", - masks="compute_brain_mask", + on=[DataType.BOLD], + masks=["compute_brain_mask"], ) element_data["BOLD"]["data"] = 
element_data["BOLD"]["data"].slicer[ ..., 0:1 @@ -358,11 +361,11 @@ def test_ParcelAggregation_3D_mask_computed() -> None: # Use the ParcelAggregation object marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", - masks={"compute_brain_mask": {"threshold": 0.2}}, + masks=[{"compute_brain_mask": {"threshold": 0.2}}], name="tian_mean", - on="BOLD", + on=[DataType.BOLD], ) parcel_agg_mean_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -441,10 +444,10 @@ def test_ParcelAggregation_3D_multiple_non_overlapping(tmp_path: Path) -> None: # Use the ParcelAggregation object on the original parcellation marker_original = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", name="tian_mean", - on="BOLD", + on=[DataType.BOLD], ) orig_mean = marker_original.fit_transform(element_data)["BOLD"][ "aggregation" @@ -462,7 +465,7 @@ def test_ParcelAggregation_3D_multiple_non_overlapping(tmp_path: Path) -> None: ], method="mean", name="tian_mean", - on="BOLD", + on=[DataType.BOLD], ) # No warnings should be raised @@ -546,10 +549,10 @@ def test_ParcelAggregation_3D_multiple_overlapping(tmp_path: Path) -> None: # Use the ParcelAggregation object on the original parcellation marker_original = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", name="tian_mean", - on="BOLD", + on=[DataType.BOLD], ) orig_mean = marker_original.fit_transform(element_data)["BOLD"][ "aggregation" @@ -567,7 +570,7 @@ def test_ParcelAggregation_3D_multiple_overlapping(tmp_path: Path) -> None: ], method="mean", name="tian_mean", - on="BOLD", + on=[DataType.BOLD], ) # Warning should be raised with pytest.warns(RuntimeWarning, match="overlapping voxels"): @@ -655,10 +658,10 @@ def test_ParcelAggregation_3D_multiple_duplicated_labels( # Use the ParcelAggregation object on the original parcellation marker_original = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", name="tian_mean", - on="BOLD", + on=[DataType.BOLD], ) orig_mean = marker_original.fit_transform(element_data)["BOLD"][ "aggregation" @@ -676,7 +679,7 @@ def test_ParcelAggregation_3D_multiple_duplicated_labels( ], method="mean", name="tian_mean", - on="BOLD", + on=[DataType.BOLD], ) # Warning should be raised @@ -709,10 +712,10 @@ def test_ParcelAggregation_4D_agg_time(): element_data = DefaultDataReader().fit_transform(dg["sub-01"]) # Create ParcelAggregation object marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", time_method="mean", - on="BOLD", + on=[DataType.BOLD], ) parcel_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -744,11 +747,11 @@ def test_ParcelAggregation_4D_agg_time(): # Test picking first time point nifti_labels_masked_bold_pick_0 = nifti_labels_masked_bold[:1, :] marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", time_method="select", time_method_params={"pick": [0]}, - on="BOLD", + on=[DataType.BOLD], ) parcel_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -767,21 +770,21 @@ def test_ParcelAggregation_errors() -> None: """Test errors for 
ParcelAggregation.""" with pytest.raises(ValueError, match="can only be used with BOLD data"): ParcelAggregation( - parcellation="Schaefer100x7", + parcellation=["Schaefer100x7"], method="mean", time_method="select", time_method_params={"pick": [0]}, - on="VBM_GM", + on=[DataType.VBM_GM], ) with pytest.raises( ValueError, match="can only be used with `time_method`" ): ParcelAggregation( - parcellation="Schaefer100x7", + parcellation=["Schaefer100x7"], method="mean", time_method_params={"pick": [0]}, - on="VBM_GM", + on=[DataType.VBM_GM], ) @@ -793,11 +796,11 @@ def test_ParcelAggregation_warning() -> None: RuntimeWarning, match="No time dimension to aggregate" ): marker = ParcelAggregation( - parcellation="TianxS1x3TxMNInonlinear2009cAsym", + parcellation=["TianxS1x3TxMNInonlinear2009cAsym"], method="mean", time_method="select", time_method_params={"pick": [0]}, - on="BOLD", + on=[DataType.BOLD], ) element_data["BOLD"]["data"] = element_data["BOLD"]["data"].slicer[ ..., 0:1 diff --git a/junifer/markers/tests/test_sphere_aggregation.py b/junifer/markers/tests/test_sphere_aggregation.py index 2f41efe79..f965c8245 100644 --- a/junifer/markers/tests/test_sphere_aggregation.py +++ b/junifer/markers/tests/test_sphere_aggregation.py @@ -11,9 +11,10 @@ from numpy.testing import assert_array_equal from junifer.data import CoordinatesRegistry, MaskRegistry +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers.sphere_aggregation import SphereAggregation -from junifer.storage import SQLiteFeatureStorage +from junifer.storage import SQLiteFeatureStorage, StorageType, Upsert from junifer.testing.datagrabbers import ( OasisVBMTestingDataGrabber, SPMAuditoryTestingDataGrabber, @@ -29,60 +30,60 @@ "input_type, storage_type", [ ( - "T1w", - "vector", + DataType.T1w, + StorageType.Vector, ), ( - "T2w", - "vector", + DataType.T2w, + StorageType.Vector, ), ( - "BOLD", - "timeseries", + DataType.BOLD, + StorageType.Timeseries, ), ( - "VBM_GM", - "vector", + DataType.VBM_GM, + StorageType.Vector, ), ( - "VBM_WM", - "vector", + DataType.VBM_WM, + StorageType.Vector, ), ( - "VBM_CSF", - "vector", + DataType.VBM_CSF, + StorageType.Vector, ), ( - "fALFF", - "vector", + DataType.FALFF, + StorageType.Vector, ), ( - "GCOR", - "vector", + DataType.GCOR, + StorageType.Vector, ), ( - "LCOR", - "vector", + DataType.LCOR, + StorageType.Vector, ), ], ) def test_SphereAggregation_input_output( - input_type: str, storage_type: str + input_type: DataType, storage_type: StorageType ) -> None: """Test SphereAggregation input and output types. Parameters ---------- - input_type : str + input_type : DataType The parametrized input type. - storage_type : str + storage_type : StorageType The parametrized storage type. 
""" assert storage_type == SphereAggregation( coords="DMNBuckner", method="mean", - on=input_type, + on=[input_type], ).storage_type(input_type=input_type, output_feature="aggregation") @@ -92,7 +93,7 @@ def test_SphereAggregation_3D() -> None: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) # Create SphereAggregation object marker = SphereAggregation( - coords=COORDS, method="mean", radius=RADIUS, on="VBM_GM" + coords=COORDS, method="mean", radius=RADIUS, on=[DataType.VBM_GM] ) sphere_agg_vbm_gm_data = marker.fit_transform(element_data)["VBM_GM"][ "aggregation" @@ -124,7 +125,7 @@ def test_SphereAggregation_4D() -> None: element_data = DefaultDataReader().fit_transform(dg["sub001"]) # Create SphereAggregation object marker = SphereAggregation( - coords=COORDS, method="mean", radius=RADIUS, on="BOLD" + coords=COORDS, method="mean", radius=RADIUS, on=[DataType.BOLD] ) sphere_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -163,10 +164,11 @@ def test_SphereAggregation_storage(tmp_path: Path) -> None: with OasisVBMTestingDataGrabber() as dg: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) storage = SQLiteFeatureStorage( - uri=tmp_path / "test_sphere_storage_3D.sqlite", upsert="ignore" + uri=tmp_path / "test_sphere_storage_3D.sqlite", + upsert=Upsert.Ignore, ) marker = SphereAggregation( - coords=COORDS, method="mean", radius=RADIUS, on="VBM_GM" + coords=COORDS, method="mean", radius=RADIUS, on=[DataType.VBM_GM] ) marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() @@ -179,10 +181,11 @@ def test_SphereAggregation_storage(tmp_path: Path) -> None: with SPMAuditoryTestingDataGrabber() as dg: element_data = DefaultDataReader().fit_transform(dg["sub001"]) storage = SQLiteFeatureStorage( - uri=tmp_path / "test_sphere_storage_4D.sqlite", upsert="ignore" + uri=tmp_path / "test_sphere_storage_4D.sqlite", + upsert=Upsert.Ignore, ) marker = SphereAggregation( - coords=COORDS, method="mean", radius=RADIUS, on="BOLD" + coords=COORDS, method="mean", radius=RADIUS, on=[DataType.BOLD] ) marker.fit_transform(input=element_data, storage=storage) features = storage.list_features() @@ -201,8 +204,8 @@ def test_SphereAggregation_3D_mask() -> None: coords=COORDS, method="mean", radius=RADIUS, - on="VBM_GM", - masks="compute_brain_mask", + on=[DataType.VBM_GM], + masks=["compute_brain_mask"], ) sphere_agg_vbm_gm_data = marker.fit_transform(element_data)["VBM_GM"][ "aggregation" @@ -245,7 +248,7 @@ def test_SphereAggregation_4D_agg_time() -> None: method="mean", radius=RADIUS, time_method="mean", - on="BOLD", + on=[DataType.BOLD], ) sphere_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -281,7 +284,7 @@ def test_SphereAggregation_4D_agg_time() -> None: radius=RADIUS, time_method="select", time_method_params={"pick": [0]}, - on="BOLD", + on=[DataType.BOLD], ) sphere_agg_bold_data = marker.fit_transform(element_data)["BOLD"][ "aggregation" @@ -305,7 +308,7 @@ def test_SphereAggregation_errors() -> None: radius=RADIUS, time_method="pick", time_method_params={"pick": [0]}, - on="VBM_GM", + on=[DataType.VBM_GM], ) with pytest.raises( @@ -316,7 +319,7 @@ def test_SphereAggregation_errors() -> None: method="mean", radius=RADIUS, time_method_params={"pick": [0]}, - on="VBM_GM", + on=[DataType.VBM_GM], ) @@ -333,7 +336,7 @@ def test_SphereAggregation_warning() -> None: radius=RADIUS, time_method="select", time_method_params={"pick": [0]}, - on="BOLD", + on=[DataType.BOLD], ) 
element_data["BOLD"]["data"] = element_data["BOLD"]["data"].slicer[ ..., 0:1 From ec9cd4bd57eacc6f742d5474ad4cd488673cf446 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Fri, 14 Nov 2025 11:16:12 +0100 Subject: [PATCH 48/99] update: fix tests for onthefly --- junifer/onthefly/tests/test_read_transform.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/junifer/onthefly/tests/test_read_transform.py b/junifer/onthefly/tests/test_read_transform.py index 0a4d9bfb2..2dedafdb1 100644 --- a/junifer/onthefly/tests/test_read_transform.py +++ b/junifer/onthefly/tests/test_read_transform.py @@ -10,7 +10,7 @@ import pytest from junifer.onthefly import read_transform -from junifer.storage.hdf5 import HDF5FeatureStorage +from junifer.storage import HDF5FeatureStorage, StorageType @pytest.fixture @@ -23,9 +23,9 @@ def vector_storage(tmp_path: Path) -> HDF5FeatureStorage: The path to the test directory. """ - storage = HDF5FeatureStorage(tmp_path / "vector_store.hdf5") + storage = HDF5FeatureStorage(uri=tmp_path / "vector_store.hdf5") storage.store( - kind="vector", + kind=StorageType.Vector, meta={ "element": {"subject": "test"}, "dependencies": [], @@ -48,9 +48,9 @@ def matrix_storage(tmp_path: Path) -> HDF5FeatureStorage: The path to the test directory. """ - storage = HDF5FeatureStorage(tmp_path / "matrix_store.hdf5") + storage = HDF5FeatureStorage(uri=tmp_path / "matrix_store.hdf5") storage.store( - kind="matrix", + kind=StorageType.Matrix, meta={ "element": {"subject": "test"}, "dependencies": [], @@ -74,13 +74,13 @@ def matrix_storage_with_nan(tmp_path: Path) -> HDF5FeatureStorage: The path to the test directory. """ - storage = HDF5FeatureStorage(tmp_path / "matrix_store_nan.hdf5") + storage = HDF5FeatureStorage(uri=tmp_path / "matrix_store_nan.hdf5") data = np.arange(36).reshape(3, 3, 4).astype(float) data[1, 1, 2] = np.nan data[1, 2, 2] = np.nan for i in range(4): storage.store( - kind="matrix", + kind=StorageType.Matrix, meta={ "element": {"subject": f"test{i + 1}"}, "dependencies": [], From 585ffb62211900577a9fc338c3c444d6b367e5bf Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 09:21:45 +0100 Subject: [PATCH 49/99] update: fix tests for storage --- junifer/storage/tests/test_storage_base.py | 31 +++++++++------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/junifer/storage/tests/test_storage_base.py b/junifer/storage/tests/test_storage_base.py index 2fbe1b3b0..f30bdc180 100644 --- a/junifer/storage/tests/test_storage_base.py +++ b/junifer/storage/tests/test_storage_base.py @@ -5,17 +5,18 @@ # License: AGPL from collections.abc import Sequence +from pathlib import Path from typing import ClassVar import pytest -from junifer.storage.base import BaseFeatureStorage +from junifer.storage import BaseFeatureStorage, StorageType def test_BaseFeatureStorage_abstractness() -> None: """Test BaseFeatureStorage is abstract base class.""" with pytest.raises(TypeError, match=r"abstract"): - BaseFeatureStorage(uri="/tmp") + BaseFeatureStorage(uri=Path(".")) def test_BaseFeatureStorage() -> None: @@ -25,19 +26,13 @@ def test_BaseFeatureStorage() -> None: class MyFeatureStorage(BaseFeatureStorage): """Implement concrete class.""" - _STORAGE_TYPES: ClassVar[Sequence[str]] = [ - "matrix", - "vector", - "timeseries", - "timeseries_2d", + _STORAGE_TYPES: ClassVar[Sequence[StorageType]] = [ + StorageType.Matrix, + StorageType.Vector, + StorageType.Timeseries, + StorageType.Timeseries2D, ] - def __init__(self, uri, single_output=True): - 
super().__init__(
-                uri=uri,
-                single_output=single_output,
-            )
-
         def list_features(self):
             super().list_features()
@@ -53,17 +48,17 @@ def read_df(self, feature_name=None, feature_md5=None):
                 feature_md5=feature_md5,
             )

-        def store_metadata(self, meta_md5, meta, element):
-            super().store_metadata(meta_md5, meta, element)
+        def store_metadata(self, meta_md5, element, meta):
+            super().store_metadata(meta_md5, element, meta)

         def collect(self):
-            return super().collect()
+            super().collect()

     # Check single_output is False
-    st = MyFeatureStorage(uri="/tmp", single_output=False)
+    st = MyFeatureStorage(uri=Path("."), single_output=False)
     assert st.single_output is False
     # Check single_output is True
-    st = MyFeatureStorage(uri="/tmp", single_output=True)
+    st = MyFeatureStorage(uri=Path("."), single_output=True)
     assert st.single_output is True

     # Check validate with valid argument
@@ -109,4 +104,4 @@ def collect(self):
     with pytest.raises(ValueError):
         st.store(kind="lego", meta=meta)

-    assert str(st.uri) == "/tmp"
+    assert str(st.uri) == "."

From 3cd5ebd51350723b3c139d03ebb36d25f450ce0d Mon Sep 17 00:00:00 2001
From: Synchon Mandal
Date: Mon, 17 Nov 2025 09:23:07 +0100
Subject: [PATCH 50/99] chore: update tests for PipelineStepMixin

---
 junifer/pipeline/tests/test_pipeline_step_mixin.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/junifer/pipeline/tests/test_pipeline_step_mixin.py b/junifer/pipeline/tests/test_pipeline_step_mixin.py
index d9b6c208b..1b3b63084 100644
--- a/junifer/pipeline/tests/test_pipeline_step_mixin.py
+++ b/junifer/pipeline/tests/test_pipeline_step_mixin.py
@@ -9,7 +9,7 @@

 import pytest

-from junifer.pipeline.pipeline_step_mixin import PipelineStepMixin
+from junifer.pipeline import ExtDep, PipelineStepMixin
 from junifer.pipeline.utils import _check_afni
 from junifer.typing import (
     ConditionalDependencies,
@@ -64,7 +64,9 @@ def test_PipelineStepMixin_correct_ext_dependencies() -> None:
     class CorrectMixer(PipelineStepMixin):
         """Test class for validation."""

-        _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [{"name": "afni"}]
+        _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [
+            {"name": ExtDep.AFNI}
+        ]

         def validate_input(self, input: list[str]) -> list[str]:
             return input
@@ -86,7 +88,7 @@ class CorrectMixer(PipelineStepMixin):
         """Test class for validation."""

         _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [
-            {"name": "afni", "commands": ["3dReHo"]}
+            {"name": ExtDep.AFNI, "commands": ["3dReHo"]}
         ]

         def validate_input(self, input: list[str]) -> list[str]:
@@ -111,7 +113,7 @@ class CorrectMixer(PipelineStepMixin):
         """Test class for validation."""

         _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [
-            {"name": "afni", "commands": ["3d"]}
+            {"name": ExtDep.AFNI, "commands": ["3d"]}
         ]

         def validate_input(self, input: list[str]) -> list[str]:
@@ -208,7 +210,9 @@ def test_PipelineStepMixin_correct_conditional_ext_dependencies() -> None:
     """Test fit-transform with correct conditional external dependencies."""

     class ExternalDependency:
-        _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [{"name": "afni"}]
+        _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [
+            {"name": ExtDep.AFNI}
+        ]

     class CorrectMixer(PipelineStepMixin):
         """Test class for validation."""

From 32965605890275041957135eb806893d563b4948 Mon Sep 17 00:00:00 2001
From: Synchon Mandal
Date: Mon, 17 Nov 2025 10:32:26 +0100
Subject: [PATCH 51/99] update: add aenum as dependency and use it for DataType

---
 junifer/datagrabber/base.py | 4 ++--
 pyproject.toml | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git 
a/junifer/datagrabber/base.py b/junifer/datagrabber/base.py index 16a563afd..0669ecfcd 100644 --- a/junifer/datagrabber/base.py +++ b/junifer/datagrabber/base.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any -from aenum import Enum +from aenum import Enum as AEnum from pydantic import BaseModel, ConfigDict, Field from ..pipeline import UpdateMetaMixin @@ -21,7 +21,7 @@ __all__ = ["BaseDataGrabber", "DataType"] -class DataType(str, Enum): +class DataType(str, AEnum): """Accepted data type.""" T1w = "T1w" diff --git a/pyproject.toml b/pyproject.toml index 5ebe1c368..544ae4223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "junifer_data==1.3.0", "structlog>=25.0.0,<26.0.0", "pydantic>=2.11.4", + "aenum>=3.1.0,<3.2.0", ] dynamic = ["version"] From 98537c210f45700dce2e231f6dbc1ee4ac0c989e Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:33:04 +0100 Subject: [PATCH 52/99] update: improve type handling for BaseDataGrabber.get_types --- junifer/datagrabber/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/junifer/datagrabber/base.py b/junifer/datagrabber/base.py index 0669ecfcd..029a50358 100644 --- a/junifer/datagrabber/base.py +++ b/junifer/datagrabber/base.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterator +from enum import Enum from pathlib import Path from typing import Any @@ -143,7 +144,7 @@ def get_types(self) -> list[str]: The data type(s) to grab. """ - return [x.value for x in self.types] + return [x.value if isinstance(x, Enum) else x for x in self.types] @property def fulldir(self) -> Path: From e7ae2bb00071db66f1a575da4757924505fc10f2 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:38:18 +0100 Subject: [PATCH 53/99] update: fix tests for datagrabber --- .../datagrabber/aomic/tests/test_id1000.py | 107 ++++---- junifer/datagrabber/aomic/tests/test_piop1.py | 168 +++++-------- junifer/datagrabber/aomic/tests/test_piop2.py | 142 ++++------- .../datagrabber/hcp1200/tests/test_hcp1200.py | 231 ++++++------------ .../datagrabber/tests/test_datalad_base.py | 39 ++- .../tests/test_dmcc13_benchmark.py | 194 ++++++--------- junifer/datagrabber/tests/test_multiple.py | 41 ++-- junifer/datagrabber/tests/test_pattern.py | 23 +- .../datagrabber/tests/test_pattern_datalad.py | 174 ++++++------- 9 files changed, 406 insertions(+), 713 deletions(-) diff --git a/junifer/datagrabber/aomic/tests/test_id1000.py b/junifer/datagrabber/aomic/tests/test_id1000.py index 2e850fc81..0160688e5 100644 --- a/junifer/datagrabber/aomic/tests/test_id1000.py +++ b/junifer/datagrabber/aomic/tests/test_id1000.py @@ -7,34 +7,35 @@ # Synchon Mandal # License: AGPL -from typing import Optional, Union +from typing import Optional import pytest +from pydantic import HttpUrl from junifer.datagrabber.aomic.id1000 import DataladAOMICID1000 -URI = "https://gin.g-node.org/juaml/datalad-example-aomic1000" +URI = HttpUrl("https://gin.g-node.org/juaml/datalad-example-aomic1000") @pytest.mark.parametrize( "type_, nested_types, space", [ - ("BOLD", ["confounds", "mask", "reference"], "MNI152NLin2009cAsym"), - ("BOLD", ["confounds", "mask", "reference"], "native"), - ("T1w", ["mask"], "MNI152NLin2009cAsym"), - ("T1w", ["mask"], "native"), - ("VBM_CSF", None, "MNI152NLin2009cAsym"), - ("VBM_CSF", None, "native"), - ("VBM_GM", None, "MNI152NLin2009cAsym"), - ("VBM_GM", None, "native"), - ("VBM_WM", None, "MNI152NLin2009cAsym"), - ("DWI", None, "MNI152NLin2009cAsym"), - 
("FreeSurfer", None, "MNI152NLin2009cAsym"), + (["BOLD"], ["confounds", "mask", "reference"], "MNI152NLin2009cAsym"), + (["BOLD"], ["confounds", "mask", "reference"], "native"), + (["T1w"], ["mask"], "MNI152NLin2009cAsym"), + (["T1w"], ["mask"], "native"), + (["VBM_CSF"], None, "MNI152NLin2009cAsym"), + (["VBM_CSF"], None, "native"), + (["VBM_GM"], None, "MNI152NLin2009cAsym"), + (["VBM_GM"], None, "native"), + (["VBM_WM"], None, "MNI152NLin2009cAsym"), + (["DWI"], None, "MNI152NLin2009cAsym"), + (["FreeSurfer"], None, "MNI152NLin2009cAsym"), ], ) def test_DataladAOMICID1000( - type_: str, + type_: list[str], nested_types: Optional[list[str]], space: str, ) -> None: @@ -42,7 +43,7 @@ def test_DataladAOMICID1000( Parameters ---------- - type_ : str + type_ : list of str The parametrized type. nested_types : list of str or None The parametrized nested types. @@ -50,43 +51,38 @@ def test_DataladAOMICID1000( The parametrized space. """ - dg = DataladAOMICID1000(types=type_, space=space) - # Set URI to Gin - dg.uri = URI - + dg = DataladAOMICID1000(uri=URI, types=type_, space=space) with dg: - # Get all elements all_elements = dg.get_elements() - # Get test element test_element = all_elements[0] - # Get test element data out = dg[test_element] # Assert data type - assert type_ in out - assert out[type_]["path"].exists() - assert out[type_]["path"].is_file() - # Asserts data type metadata - assert "meta" in out[type_] - meta = out[type_]["meta"] - assert "element" in meta - assert "subject" in meta["element"] - assert test_element == meta["element"]["subject"] - # Assert nested data type if not None - if nested_types is not None: - for nested_type in nested_types: - assert out[type_][nested_type]["path"].exists() - assert out[type_][nested_type]["path"].is_file() + for t in type_: + assert t in out + assert out[t]["path"].exists() + assert out[t]["path"].is_file() + # Asserts data type metadata + assert "meta" in out[t] + meta = out[t]["meta"] + assert "element" in meta + assert "subject" in meta["element"] + assert test_element == meta["element"]["subject"] + # Assert nested data type if not None + if nested_types is not None: + for nested_type in nested_types: + assert out[t][nested_type]["path"].exists() + assert out[t][nested_type]["path"].is_file() @pytest.mark.parametrize( "types", [ - "BOLD", - "T1w", - "VBM_CSF", - "VBM_GM", - "VBM_WM", - "DWI", + ["BOLD"], + ["T1w"], + ["VBM_CSF"], + ["VBM_GM"], + ["VBM_WM"], + ["DWI"], ["BOLD", "VBM_CSF"], ["T1w", "VBM_CSF"], ["VBM_GM", "VBM_WM"], @@ -94,38 +90,21 @@ def test_DataladAOMICID1000( ], ) def test_DataladAOMICID1000_partial_data_access( - types: Union[str, list[str]], + types: list[str], ) -> None: """Test DataladAOMICID1000 DataGrabber partial data access. Parameters ---------- - types : str or list of str + types : list of str The parametrized types. 
""" - dg = DataladAOMICID1000(types=types) - # Set URI to Gin - dg.uri = URI - + dg = DataladAOMICID1000(uri=URI, types=types) with dg: - # Get all elements all_elements = dg.get_elements() - # Get test element test_element = all_elements[0] - # Get test element data out = dg[test_element] # Assert data type - if isinstance(types, list): - for type_ in types: - assert type_ in out - else: - assert types in out - - -def test_DataladAOMICID1000_incorrect_data_type() -> None: - """Test DataladAOMICID1000 DataGrabber incorrect data type.""" - with pytest.raises( - ValueError, match="`patterns` must contain all `types`" - ): - _ = DataladAOMICID1000(types="Scooby-Doo") + for t in types: + assert t in out diff --git a/junifer/datagrabber/aomic/tests/test_piop1.py b/junifer/datagrabber/aomic/tests/test_piop1.py index dc33a0adc..3326666ea 100644 --- a/junifer/datagrabber/aomic/tests/test_piop1.py +++ b/junifer/datagrabber/aomic/tests/test_piop1.py @@ -7,143 +7,126 @@ # Synchon Mandal # License: AGPL -from typing import Optional, Union +from typing import Optional import pytest +from pydantic import HttpUrl from junifer.datagrabber import DataladAOMICPIOP1 -URI = "https://gin.g-node.org/juaml/datalad-example-aomicpiop1" +URI = HttpUrl("https://gin.g-node.org/juaml/datalad-example-aomicpiop1") @pytest.mark.parametrize( "type_, nested_types, tasks, space", [ ( - "BOLD", - ["confounds", "mask", "reference"], - None, - "MNI152NLin2009cAsym", - ), - ("BOLD", ["confounds", "mask", "reference"], None, "native"), - ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["anticipation"], "MNI152NLin2009cAsym", ), ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["emomatching", "faces"], "MNI152NLin2009cAsym", ), ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["restingstate"], "MNI152NLin2009cAsym", ), ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["workingmemory", "gstroop"], "MNI152NLin2009cAsym", ), ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["anticipation", "faces", "restingstate"], "MNI152NLin2009cAsym", ), - ("T1w", ["mask"], None, "MNI152NLin2009cAsym"), - ("T1w", ["mask"], None, "native"), - ("VBM_CSF", None, None, "MNI152NLin2009cAsym"), - ("VBM_CSF", None, None, "native"), - ("VBM_GM", None, None, "MNI152NLin2009cAsym"), - ("VBM_GM", None, None, "native"), - ("VBM_WM", None, None, "MNI152NLin2009cAsym"), - ("VBM_WM", None, None, "native"), - ("DWI", None, None, "MNI152NLin2009cAsym"), - ("FreeSurfer", None, None, "MNI152NLin2009cAsym"), + (["T1w"], ["mask"], ["restingstate"], "MNI152NLin2009cAsym"), + (["T1w"], ["mask"], ["restingstate"], "native"), + (["VBM_CSF"], None, ["restingstate"], "MNI152NLin2009cAsym"), + (["VBM_CSF"], None, ["restingstate"], "native"), + (["VBM_GM"], None, ["restingstate"], "MNI152NLin2009cAsym"), + (["VBM_GM"], None, ["restingstate"], "native"), + (["VBM_WM"], None, ["restingstate"], "MNI152NLin2009cAsym"), + (["VBM_WM"], None, ["restingstate"], "native"), + (["DWI"], None, ["restingstate"], "MNI152NLin2009cAsym"), + (["FreeSurfer"], None, ["restingstate"], "MNI152NLin2009cAsym"), ], ) def test_DataladAOMICPIOP1( - type_: str, + type_: list[str], nested_types: Optional[list[str]], - tasks: Optional[list[str]], + tasks: list[str], space: str, ) -> None: """Test DataladAOMICPIOP1 DataGrabber. Parameters ---------- - type_ : str + type_ : list of str The parametrized type. nested_types : list of str or None The parametrized nested types. 
- tasks : list of str or None + tasks : list of str The parametrized task values. space: str The parametrized space. """ - dg = DataladAOMICPIOP1(types=type_, tasks=tasks, space=space) - # Set URI to Gin - dg.uri = URI - + dg = DataladAOMICPIOP1(uri=URI, types=type_, tasks=tasks, space=space) with dg: - # Get all elements - all_elements = dg.get_elements() - # Get test element - test_element = all_elements[0] - # Get test element data - out = dg[test_element] - # Get all elements all_elements = dg.get_elements() - # Get test element test_element = all_elements[0] - # Get test element data out = dg[test_element] # Assert data type - assert type_ in out - # Check task name if BOLD - if type_ == "BOLD" and tasks is not None: - # Depending on task 'acquisition is different' - task_acqs = { - "anticipation": "seq", - "emomatching": "seq", - "faces": "mb3", - "gstroop": "seq", - "restingstate": "mb3", - "workingmemory": "seq", - } - assert task_acqs[test_element[1]] in out[type_]["path"].name - assert out[type_]["path"].exists() - assert out[type_]["path"].is_file() - # Asserts data type metadata - assert "meta" in out[type_] - meta = out[type_]["meta"] - assert "element" in meta - assert "subject" in meta["element"] - assert test_element[0] == meta["element"]["subject"] - # Assert nested data type if not None - if nested_types is not None: - for nested_type in nested_types: - assert out[type_][nested_type]["path"].exists() - assert out[type_][nested_type]["path"].is_file() + for t in type_: + assert t in out + # Check task name if BOLD + if t == "BOLD": + # Depending on task 'acquisition is different' + task_acqs = { + "anticipation": "seq", + "emomatching": "seq", + "faces": "mb3", + "gstroop": "seq", + "restingstate": "mb3", + "workingmemory": "seq", + } + assert task_acqs[test_element[1]] in out[t]["path"].name + assert out[t]["path"].exists() + assert out[t]["path"].is_file() + # Asserts data type metadata + assert "meta" in out[t] + meta = out[t]["meta"] + assert "element" in meta + assert "subject" in meta["element"] + assert test_element[0] == meta["element"]["subject"] + # Assert nested data type if not None + if nested_types is not None: + for nested_type in nested_types: + assert out[t][nested_type]["path"].exists() + assert out[t][nested_type]["path"].is_file() @pytest.mark.parametrize( "types", [ - "BOLD", - "T1w", - "VBM_CSF", - "VBM_GM", - "VBM_WM", - "DWI", + ["BOLD"], + ["T1w"], + ["VBM_CSF"], + ["VBM_GM"], + ["VBM_WM"], + ["DWI"], ["BOLD", "VBM_CSF"], ["T1w", "VBM_CSF"], ["VBM_GM", "VBM_WM"], @@ -151,50 +134,21 @@ def test_DataladAOMICPIOP1( ], ) def test_DataladAOMICPIOP1_partial_data_access( - types: Union[str, list[str]], + types: list[str], ) -> None: """Test DataladAOMICPIOP1 DataGrabber partial data access. Parameters ---------- - types : str or list of str + types : list of str The parametrized types. 
""" - dg = DataladAOMICPIOP1(types=types) - # Set URI to Gin - dg.uri = URI - + dg = DataladAOMICPIOP1(uri=URI, types=types) with dg: - # Get all elements all_elements = dg.get_elements() - # Get test element test_element = all_elements[0] - # Get test element data out = dg[test_element] # Assert data type - if isinstance(types, list): - for type_ in types: - assert type_ in out - else: - assert types in out - - -def test_DataladAOMICPIOP1_incorrect_data_type() -> None: - """Test DataladAOMICPIOP1 DataGrabber incorrect data type.""" - with pytest.raises( - ValueError, match="`patterns` must contain all `types`" - ): - _ = DataladAOMICPIOP1(types="Ceres") - - -def test_DataladAOMICPIOP1_invalid_tasks(): - """Test DataladAOMICIDPIOP1 DataGrabber invalid tasks.""" - with pytest.raises( - ValueError, - match=( - "thisisnotarealtask is not a valid task in " - "the AOMIC PIOP1 dataset!" - ), - ): - DataladAOMICPIOP1(tasks="thisisnotarealtask") + for t in types: + assert t in out diff --git a/junifer/datagrabber/aomic/tests/test_piop2.py b/junifer/datagrabber/aomic/tests/test_piop2.py index dc5454511..38158a2b0 100644 --- a/junifer/datagrabber/aomic/tests/test_piop2.py +++ b/junifer/datagrabber/aomic/tests/test_piop2.py @@ -7,122 +7,111 @@ # Synchon Mandal # License: AGPL -from typing import Optional, Union +from typing import Optional import pytest +from pydantic import HttpUrl from junifer.datagrabber import DataladAOMICPIOP2 -URI = "https://gin.g-node.org/juaml/datalad-example-aomicpiop2" +URI = HttpUrl("https://gin.g-node.org/juaml/datalad-example-aomicpiop2") @pytest.mark.parametrize( "type_, nested_types, tasks, space", [ ( - "BOLD", - ["confounds", "mask", "reference"], - None, - "MNI152NLin2009cAsym", - ), - ("BOLD", ["confounds", "mask", "reference"], None, "native"), - ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["restingstate"], "MNI152NLin2009cAsym", ), ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["restingstate", "stopsignal"], "MNI152NLin2009cAsym", ), ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["workingmemory", "stopsignal"], "MNI152NLin2009cAsym", ), ( - "BOLD", + ["BOLD"], ["confounds", "mask", "reference"], ["workingmemory"], "MNI152NLin2009cAsym", ), - ("T1w", ["mask"], None, "MNI152NLin2009cAsym"), - ("T1w", ["mask"], None, "native"), - ("VBM_CSF", None, None, "MNI152NLin2009cAsym"), - ("VBM_CSF", None, None, "native"), - ("VBM_GM", None, None, "MNI152NLin2009cAsym"), - ("VBM_GM", None, None, "native"), - ("VBM_WM", None, None, "MNI152NLin2009cAsym"), - ("VBM_WM", None, None, "native"), - ("DWI", None, None, "MNI152NLin2009cAsym"), - ("FreeSurfer", None, None, "MNI152NLin2009cAsym"), + (["T1w"], ["mask"], ["restingstate"], "MNI152NLin2009cAsym"), + (["T1w"], ["mask"], ["restingstate"], "native"), + (["VBM_CSF"], None, ["restingstate"], "MNI152NLin2009cAsym"), + (["VBM_CSF"], None, ["restingstate"], "native"), + (["VBM_GM"], None, ["restingstate"], "MNI152NLin2009cAsym"), + (["VBM_GM"], None, ["restingstate"], "native"), + (["VBM_WM"], None, ["restingstate"], "MNI152NLin2009cAsym"), + (["VBM_WM"], None, ["restingstate"], "native"), + (["DWI"], None, ["restingstate"], "MNI152NLin2009cAsym"), + (["FreeSurfer"], None, ["restingstate"], "MNI152NLin2009cAsym"), ], ) def test_DataladAOMICPIOP2( - type_: str, + type_: list[str], nested_types: Optional[list[str]], - tasks: Optional[list[str]], + tasks: list[str], space: str, ) -> None: """Test DataladAOMICPIOP2 DataGrabber. 
Parameters ---------- - type_ : str + type_ : list of str The parametrized type. nested_types : list of str or None The parametrized nested types. - tasks : list of str or None + tasks : list of str The parametrized task values. space: str The parametrized space. """ - dg = DataladAOMICPIOP2(types=type_, tasks=tasks, space=space) - # Set URI to Gin - dg.uri = URI - + dg = DataladAOMICPIOP2(uri=URI, types=type_, tasks=tasks, space=space) with dg: - # Get all elements all_elements = dg.get_elements() - # Get test element test_element = all_elements[0] - # Get test element data out = dg[test_element] # Assert data type - assert type_ in out - # Check task name if BOLD - if type_ == "BOLD" and tasks is not None: - assert test_element[1] in out[type_]["path"].name - assert out[type_]["path"].exists() - assert out[type_]["path"].is_file() - # Asserts data type metadata - assert "meta" in out[type_] - meta = out[type_]["meta"] - assert "element" in meta - assert "subject" in meta["element"] - assert test_element[0] == meta["element"]["subject"] - # Assert nested data type if not None - if nested_types is not None: - for nested_type in nested_types: - assert out[type_][nested_type]["path"].exists() - assert out[type_][nested_type]["path"].is_file() + for t in type_: + assert t in out + # Check task name if BOLD + if t == "BOLD": + assert test_element[1] in out[t]["path"].name + assert out[t]["path"].exists() + assert out[t]["path"].is_file() + # Asserts data type metadata + assert "meta" in out[t] + meta = out[t]["meta"] + assert "element" in meta + assert "subject" in meta["element"] + assert test_element[0] == meta["element"]["subject"] + # Assert nested data type if not None + if nested_types is not None: + for nested_type in nested_types: + assert out[t][nested_type]["path"].exists() + assert out[t][nested_type]["path"].is_file() @pytest.mark.parametrize( "types", [ - "BOLD", - "T1w", - "VBM_CSF", - "VBM_GM", - "VBM_WM", - "DWI", + ["BOLD"], + ["T1w"], + ["VBM_CSF"], + ["VBM_GM"], + ["VBM_WM"], + ["DWI"], ["BOLD", "VBM_CSF"], ["T1w", "VBM_CSF"], ["VBM_GM", "VBM_WM"], @@ -130,50 +119,21 @@ def test_DataladAOMICPIOP2( ], ) def test_DataladAOMICPIOP2_partial_data_access( - types: Union[str, list[str]], + types: list[str], ) -> None: """Test DataladAOMICPIOP2 DataGrabber partial data access. Parameters ---------- - types : str or list of str + types : list of str The parametrized types. """ - dg = DataladAOMICPIOP2(types=types) - # Set URI to Gin - dg.uri = URI - + dg = DataladAOMICPIOP2(uri=URI, types=types) with dg: - # Get all elements all_elements = dg.get_elements() - # Get test element test_element = all_elements[0] - # Get test element data out = dg[test_element] # Assert data type - if isinstance(types, list): - for type_ in types: - assert type_ in out - else: - assert types in out - - -def test_DataladAOMICPIOP2_incorrect_data_type() -> None: - """Test DataladAOMICPIOP2 DataGrabber incorrect data type.""" - with pytest.raises( - ValueError, match="`patterns` must contain all `types`" - ): - _ = DataladAOMICPIOP2(types="Vesta") - - -def test_DataladAOMICPIOP2_invalid_tasks(): - """Test DataladAOMICIDPIOP2 DataGrabber invalid tasks.""" - with pytest.raises( - ValueError, - match=( - "thisisnotarealtask is not a valid task in " - "the AOMIC PIOP2 dataset!" 
- ), - ): - DataladAOMICPIOP2(tasks="thisisnotarealtask") + for t_ in types: + assert t_ in out diff --git a/junifer/datagrabber/hcp1200/tests/test_hcp1200.py b/junifer/datagrabber/hcp1200/tests/test_hcp1200.py index 2ecb8c276..011545f63 100644 --- a/junifer/datagrabber/hcp1200/tests/test_hcp1200.py +++ b/junifer/datagrabber/hcp1200/tests/test_hcp1200.py @@ -7,65 +7,63 @@ import tempfile from collections.abc import Iterable from pathlib import Path -from typing import Optional import pytest +from pydantic import HttpUrl from junifer.datagrabber import HCP1200, DataladHCP1200 -from junifer.utils import configure_logging - - -URI = "https://gin.g-node.org/juaml/datalad-example-hcp1200" +from junifer.utils import config, configure_logging @pytest.fixture(scope="module") def hcpdg() -> Iterable[DataladHCP1200]: """Return a HCP1200 DataGrabber.""" tmpdir = Path(tempfile.gettempdir()) - dg = DataladHCP1200(datadir=tmpdir / "datadir") - # Set URI to Gin - dg.uri = URI - # Set correct root directory - dg._rootdir = "." + config.set(key="datagrabber.skipidcheck", val=True) + dg = DataladHCP1200( + uri=HttpUrl("https://gin.g-node.org/juaml/datalad-example-hcp1200"), + datadir=tmpdir / "hcp1200_test", + rootdir=Path("."), + ) with dg: for t_elem in dg.get_elements(): dg[t_elem] yield dg - shutil.rmtree(tmpdir / "datadir", ignore_errors=True) + config.set(key="datagrabber.skipidcheck", val=False) + shutil.rmtree(tmpdir / "hcp1200_test", ignore_errors=True) @pytest.mark.parametrize( "tasks, phase_encodings, ica_fix, expected_path_name", [ - (None, None, False, "rfMRI_REST1_LR.nii.gz"), - ("REST1", "LR", False, "rfMRI_REST1_LR.nii.gz"), - ("REST1", "RL", False, "rfMRI_REST1_RL.nii.gz"), - ("REST2", "LR", False, "rfMRI_REST2_LR.nii.gz"), - ("REST2", "RL", False, "rfMRI_REST2_RL.nii.gz"), - ("SOCIAL", "LR", False, "tfMRI_SOCIAL_LR.nii.gz"), - ("SOCIAL", "RL", False, "tfMRI_SOCIAL_RL.nii.gz"), - ("WM", "LR", False, "tfMRI_WM_LR.nii.gz"), - ("WM", "RL", False, "tfMRI_WM_RL.nii.gz"), - ("RELATIONAL", "LR", False, "tfMRI_RELATIONAL_LR.nii.gz"), - ("RELATIONAL", "RL", False, "tfMRI_RELATIONAL_RL.nii.gz"), - ("EMOTION", "LR", False, "tfMRI_EMOTION_LR.nii.gz"), - ("EMOTION", "RL", False, "tfMRI_EMOTION_RL.nii.gz"), - ("LANGUAGE", "LR", False, "tfMRI_LANGUAGE_LR.nii.gz"), - ("LANGUAGE", "RL", False, "tfMRI_LANGUAGE_RL.nii.gz"), - ("GAMBLING", "LR", False, "tfMRI_GAMBLING_LR.nii.gz"), - ("GAMBLING", "RL", False, "tfMRI_GAMBLING_RL.nii.gz"), - ("MOTOR", "LR", False, "tfMRI_MOTOR_LR.nii.gz"), - ("MOTOR", "RL", False, "tfMRI_MOTOR_RL.nii.gz"), - ("REST1", "LR", True, "rfMRI_REST1_LR_hp2000_clean.nii.gz"), - ("REST1", "RL", True, "rfMRI_REST1_RL_hp2000_clean.nii.gz"), - ("REST2", "LR", True, "rfMRI_REST2_LR_hp2000_clean.nii.gz"), - ("REST2", "RL", True, "rfMRI_REST2_RL_hp2000_clean.nii.gz"), + (["REST1"], ["LR"], False, "rfMRI_REST1_LR.nii.gz"), + (["REST1"], ["RL"], False, "rfMRI_REST1_RL.nii.gz"), + (["REST2"], ["LR"], False, "rfMRI_REST2_LR.nii.gz"), + (["REST2"], ["RL"], False, "rfMRI_REST2_RL.nii.gz"), + (["SOCIAL"], ["LR"], False, "tfMRI_SOCIAL_LR.nii.gz"), + (["SOCIAL"], ["RL"], False, "tfMRI_SOCIAL_RL.nii.gz"), + (["WM"], ["LR"], False, "tfMRI_WM_LR.nii.gz"), + (["WM"], ["RL"], False, "tfMRI_WM_RL.nii.gz"), + (["RELATIONAL"], ["LR"], False, "tfMRI_RELATIONAL_LR.nii.gz"), + (["RELATIONAL"], ["RL"], False, "tfMRI_RELATIONAL_RL.nii.gz"), + (["EMOTION"], ["LR"], False, "tfMRI_EMOTION_LR.nii.gz"), + (["EMOTION"], ["RL"], False, "tfMRI_EMOTION_RL.nii.gz"), + (["LANGUAGE"], ["LR"], False, "tfMRI_LANGUAGE_LR.nii.gz"), 
+ (["LANGUAGE"], ["RL"], False, "tfMRI_LANGUAGE_RL.nii.gz"), + (["GAMBLING"], ["LR"], False, "tfMRI_GAMBLING_LR.nii.gz"), + (["GAMBLING"], ["RL"], False, "tfMRI_GAMBLING_RL.nii.gz"), + (["MOTOR"], ["LR"], False, "tfMRI_MOTOR_LR.nii.gz"), + (["MOTOR"], ["RL"], False, "tfMRI_MOTOR_RL.nii.gz"), + (["REST1"], ["LR"], True, "rfMRI_REST1_LR_hp2000_clean.nii.gz"), + (["REST1"], ["RL"], True, "rfMRI_REST1_RL_hp2000_clean.nii.gz"), + (["REST2"], ["LR"], True, "rfMRI_REST2_LR_hp2000_clean.nii.gz"), + (["REST2"], ["RL"], True, "rfMRI_REST2_RL_hp2000_clean.nii.gz"), ], ) def test_HCP1200( hcpdg: DataladHCP1200, - tasks: Optional[str], - phase_encodings: Optional[str], + tasks: list[str], + phase_encodings: list[str], ica_fix: bool, expected_path_name: str, ) -> None: @@ -76,9 +74,9 @@ def test_HCP1200( hcpdg : DataladHCP1200 The Datalad version of the DataGrabber with the first subject already cloned. - tasks : str + tasks : list of str The parametrized tasks. - phase_encodings : str + phase_encodings : list of str The parametrized phase encodings. ica_fix : bool The parametrized ICA-FIX flag. @@ -118,30 +116,30 @@ def test_HCP1200( @pytest.mark.parametrize( "tasks, phase_encodings", [ - ("REST1", "LR"), - ("REST1", "RL"), - ("REST2", "LR"), - ("REST2", "RL"), - ("SOCIAL", "LR"), - ("SOCIAL", "RL"), - ("WM", "LR"), - ("WM", "RL"), - ("RELATIONAL", "LR"), - ("RELATIONAL", "RL"), - ("EMOTION", "LR"), - ("EMOTION", "RL"), - ("LANGUAGE", "LR"), - ("LANGUAGE", "RL"), - ("GAMBLING", "LR"), - ("GAMBLING", "RL"), - ("MOTOR", "LR"), - ("MOTOR", "RL"), + (["REST1"], ["LR"]), + (["REST1"], ["RL"]), + (["REST2"], ["LR"]), + (["REST2"], ["RL"]), + (["SOCIAL"], ["LR"]), + (["SOCIAL"], ["RL"]), + (["WM"], ["LR"]), + (["WM"], ["RL"]), + (["RELATIONAL"], ["LR"]), + (["RELATIONAL"], ["RL"]), + (["EMOTION"], ["LR"]), + (["EMOTION"], ["RL"]), + (["LANGUAGE"], ["LR"]), + (["LANGUAGE"], ["RL"]), + (["GAMBLING"], ["LR"]), + (["GAMBLING"], ["RL"]), + (["MOTOR"], ["LR"]), + (["MOTOR"], ["RL"]), ], ) def test_HCP1200_single_access( hcpdg: DataladHCP1200, - tasks: Optional[str], - phase_encodings: Optional[str], + tasks: list[str], + phase_encodings: list[str], ) -> None: """Test HCP1200 DataGrabber single access. @@ -150,9 +148,9 @@ def test_HCP1200_single_access( hcpdg : DataladHCP1200 The Datalad version of the DataGrabber with the first subject already cloned. - tasks : str + tasks : list of str The parametrized tasks. - phase_encodings : str + phase_encodings : list of str The parametrized phase encodings. """ @@ -167,21 +165,12 @@ def test_HCP1200_single_access( all_elements = dg.get_elements() # Check only specified task and phase encoding are found for element in all_elements: - assert element[1] == tasks - assert element[2] == phase_encodings + assert element[1] == tasks[0] + assert element[2] == phase_encodings[0] -@pytest.mark.parametrize( - "tasks, phase_encodings", - [ - (["REST1", "REST2"], ["LR", "RL"]), - (["REST1", "REST2"], None), - ], -) def test_HCP1200_multi_access( hcpdg: DataladHCP1200, - tasks: Optional[str], - phase_encodings: Optional[str], ) -> None: """Test HCP1200 DataGrabber multiple access. @@ -190,17 +179,13 @@ def test_HCP1200_multi_access( hcpdg : DataladHCP1200 The Datalad version of the DataGrabber with the first subject already cloned. - tasks : str - The parametrized tasks. - phase_encodings : str - The parametrized phase encodings. 
""" configure_logging(level="DEBUG") dg = HCP1200( datadir=hcpdg.datadir, - tasks=tasks, - phase_encodings=phase_encodings, + tasks=["REST1", "REST2"], + phase_encodings=["LR", "RL"], ) with dg: # Get all elements @@ -226,7 +211,7 @@ def test_HCP1200_multi_access_task_simple( configure_logging(level="DEBUG") dg = HCP1200( datadir=hcpdg.datadir, - tasks="REST1", + tasks=["REST1"], phase_encodings=["LR", "RL"], ) with dg: @@ -254,7 +239,7 @@ def test_HCP1200_multi_access_phase_simple( dg = HCP1200( datadir=hcpdg.datadir, tasks=["REST1", "REST2"], - phase_encodings="LR", + phase_encodings=["LR"], ) with dg: # Get all elements @@ -265,70 +250,6 @@ def test_HCP1200_multi_access_phase_simple( assert element[2] == "LR" -@pytest.mark.parametrize( - "tasks, phase_encodings", - [ - ("FOO", ["LR", "RL"]), - ("FOO", "RL"), - (["FOO", "BAR"], ["LR", "RL"]), - (["FOO", "BAR"], "LR"), - ], -) -def test_HCP1200_incorrect_access_task( - tasks: Optional[str], - phase_encodings: Optional[str], -) -> None: - """Test HCP1200 DataGrabber incorrect access for task. - - Parameters - ---------- - tasks : str - The parametrized tasks. - phase_encodings : str - The parametrized phase encodings. - - """ - configure_logging(level="DEBUG") - with pytest.raises(ValueError, match="not a valid HCP-YA fMRI task input"): - _ = HCP1200( - datadir=".", - tasks=tasks, - phase_encodings=phase_encodings, - ) - - -@pytest.mark.parametrize( - "tasks, phase_encodings", - [ - ("REST1", ["FOO", "BAR"]), - ("REST1", "FOO"), - (["REST1", "REST2"], ["FOO", "BAR"]), - (["REST1", "REST2"], "BAR"), - ], -) -def test_HCP1200_incorrect_access_phase( - tasks: Optional[str], - phase_encodings: Optional[str], -) -> None: - """Test HCP1200 DataGrabber incorrect access for phase. - - Parameters - ---------- - tasks : str - The parametrized tasks. - phase_encodings : str - The parametrized phase encodings. - - """ - configure_logging(level="DEBUG") - with pytest.raises(ValueError, match="not a valid HCP-YA phase encoding"): - _ = HCP1200( - datadir=".", - tasks=tasks, - phase_encodings=phase_encodings, - ) - - def test_HCP1200_elements( hcpdg: DataladHCP1200, ) -> None: @@ -344,8 +265,8 @@ def test_HCP1200_elements( configure_logging(level="DEBUG") dg = HCP1200( datadir=hcpdg.datadir, - tasks="REST1", - phase_encodings="LR", + tasks=["REST1"], + phase_encodings=["LR"], ) with dg: # Get all elements @@ -364,23 +285,23 @@ def test_HCP1200_elements( @pytest.mark.parametrize( "tasks, ica_fix", [ - ("SOCIAL", True), - ("WM", True), - ("RELATIONAL", True), - ("EMOTION", True), - ("LANGUAGE", True), - ("GAMBLING", True), - ("MOTOR", True), + (["SOCIAL"], True), + (["WM"], True), + (["RELATIONAL"], True), + (["EMOTION"], True), + (["LANGUAGE"], True), + (["GAMBLING"], True), + (["MOTOR"], True), ], ) def test_HCP1200_incorrect_access_icafix( - tasks: Optional[str], ica_fix: bool + tasks: list[str], ica_fix: bool ) -> None: """Test HCP1200 DataGrabber incorrect access for icafix. Parameters ---------- - tasks : str + tasks : list of str The parametrized tasks. ica_fix : bool The parametrized ICA-FIX flag. 
@@ -389,7 +310,7 @@ def test_HCP1200_incorrect_access_icafix( configure_logging(level="DEBUG") with pytest.raises(ValueError, match="is only available for"): _ = HCP1200( - datadir=".", + datadir=Path("."), tasks=tasks, ica_fix=ica_fix, ) diff --git a/junifer/datagrabber/tests/test_datalad_base.py b/junifer/datagrabber/tests/test_datalad_base.py index f3e5e4661..0c5756a58 100644 --- a/junifer/datagrabber/tests/test_datalad_base.py +++ b/junifer/datagrabber/tests/test_datalad_base.py @@ -9,7 +9,7 @@ import datalad.api as dl import pytest -from junifer.datagrabber import DataladDataGrabber +from junifer.datagrabber import DataladDataGrabber, DataType from junifer.utils import config @@ -38,23 +38,18 @@ def concrete_datagrabber() -> type[DataladDataGrabber]: """ - class MyDataGrabber(DataladDataGrabber): # type: ignore - def __init__(self, datadir, uri): - super().__init__( - datadir=datadir, - rootdir="example_bids", - uri=uri, - types=["T1w", "BOLD"], - ) + class MyDataGrabber(DataladDataGrabber): + types: list[DataType] = [DataType.T1w, DataType.BOLD] # noqa: RUF012 + rootdir: Path = Path("example_bids") def get_item(self, subject): out = { "T1w": { - "path": self.datadir + "path": self.fulldir / f"{subject}/anat/{subject}_T1w.nii.gz" }, "BOLD": { - "path": self.datadir + "path": self.fulldir / f"{subject}/func/{subject}_task-rest_bold.nii.gz" }, } @@ -91,7 +86,7 @@ def test_DataladDataGrabber_install_errors( # Files are not there assert datadir.exists() is False # Clone dataset - dl.clone(uri, datadir) # type: ignore + dl.clone(uri, datadir) dg = concrete_datagrabber(datadir=datadir, uri=uri2) with pytest.raises(ValueError, match=r"different ID"): with dg: @@ -160,7 +155,6 @@ def test_DataladDataGrabber_clone_cleanup( assert "datagrabber" in meta assert "datalad_dirty" in meta["datagrabber"] assert meta["datagrabber"]["datalad_dirty"] is False - assert hasattr(dg, "_got_files") is False assert datadir.exists() is True assert elem1_bold.is_file() is True assert elem1_bold.is_symlink() is True @@ -185,8 +179,8 @@ def test_DataladDataGrabber_clone_create_cleanup( # Clone whole dataset uri = _testing_dataset["example_bids"]["uri"] - with concrete_datagrabber(datadir=None, uri=uri) as dg: - datadir = dg._tmpdir / "datadir" + with concrete_datagrabber(uri=uri) as dg: + datadir = dg._repodir elem1_bold = ( datadir / "example_bids/sub-01/func/sub-01_task-rest_bold.nii.gz" ) @@ -206,7 +200,6 @@ def test_DataladDataGrabber_clone_create_cleanup( assert "datagrabber" in meta assert "datalad_dirty" in meta["datagrabber"] assert meta["datagrabber"]["datalad_dirty"] is False - assert hasattr(dg, "_got_files") is False assert datadir.exists() is True assert elem1_bold.is_file() is True assert elem1_bold.is_symlink() is True @@ -246,7 +239,7 @@ def test_DataladDataGrabber_previously_cloned( assert elem1_t1w.exists() is False # Clone dataset - dl.clone(uri, datadir, result_renderer="disabled") # type: ignore + dl.clone(uri, datadir, result_renderer="disabled") # Files are there, but are empty symbolic links assert datadir.exists() is True @@ -316,7 +309,7 @@ def test_DataladDataGrabber_previously_cloned_and_get( assert elem1_t1w.exists() is False # Clone dataset - dl.clone(uri, datadir, result_renderer="disabled") # type: ignore + dl.clone(uri, datadir, result_renderer="disabled") # Files are there, but are empty symbolic links assert datadir.exists() is True @@ -325,9 +318,7 @@ def test_DataladDataGrabber_previously_cloned_and_get( assert elem1_t1w.is_symlink() is True assert elem1_t1w.is_file() is False 
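The MyDataGrabber fixture in the datalad-base tests below now declares its
configuration as pydantic-style class attributes instead of an __init__. A
minimal sketch of that declarative pattern, reusing only names from the
fixture (get_item is elided here; a usable grabber must still implement it,
as the fixture does):

    from pathlib import Path

    from junifer.datagrabber import DataladDataGrabber, DataType

    class MyDataGrabber(DataladDataGrabber):
        # Field declarations replace the removed __init__; these are
        # plain model fields with defaults.
        types: list[DataType] = [DataType.T1w, DataType.BOLD]  # noqa: RUF012
        rootdir: Path = Path("example_bids")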
- dl.get( # type: ignore - elem1_t1w, dataset=datadir, result_renderer="disabled" - ) + dl.get(elem1_t1w, dataset=datadir, result_renderer="disabled") assert elem1_bold.is_symlink() is True assert elem1_bold.is_file() is False @@ -399,7 +390,7 @@ def test_DataladDataGrabber_previously_cloned_and_get_dirty( assert elem1_t1w.exists() is False # Clone dataset - dl.clone(uri, datadir, result_renderer="disabled") # type: ignore + dl.clone(uri, datadir, result_renderer="disabled") # Files are there, but are empty symbolic links assert datadir.exists() is True @@ -408,9 +399,7 @@ def test_DataladDataGrabber_previously_cloned_and_get_dirty( assert elem1_t1w.is_symlink() is True assert elem1_t1w.is_file() is False - dl.get( # type: ignore - elem1_t1w, dataset=datadir, result_renderer="disabled" - ) + dl.get(elem1_t1w, dataset=datadir, result_renderer="disabled") assert elem1_bold.is_symlink() is True assert elem1_bold.is_file() is False diff --git a/junifer/datagrabber/tests/test_dmcc13_benchmark.py b/junifer/datagrabber/tests/test_dmcc13_benchmark.py index d3e124c90..594aab089 100644 --- a/junifer/datagrabber/tests/test_dmcc13_benchmark.py +++ b/junifer/datagrabber/tests/test_dmcc13_benchmark.py @@ -3,104 +3,98 @@ # Authors: Synchon Mandal # License: AGPL -from typing import Optional, Union - import pytest +from pydantic import HttpUrl -from junifer.datagrabber import DMCC13Benchmark +from junifer.datagrabber import DataType, DMCC13Benchmark -URI = "https://gin.g-node.org/synchon/datalad-example-dmcc13-benchmark" +URI = HttpUrl( + "https://gin.g-node.org/synchon/datalad-example-dmcc13-benchmark" +) @pytest.mark.parametrize( "sessions, tasks, phase_encodings, runs, native_t1w", [ - (None, None, None, None, False), - ("ses-wave1bas", "Rest", "AP", "1", False), - ("ses-wave1bas", "Axcpt", "AP", "1", False), - ("ses-wave1bas", "Cuedts", "AP", "1", False), - ("ses-wave1bas", "Stern", "AP", "1", False), - ("ses-wave1bas", "Stroop", "AP", "1", False), - ("ses-wave1bas", "Rest", "PA", "2", False), - ("ses-wave1bas", "Axcpt", "PA", "2", False), - ("ses-wave1bas", "Cuedts", "PA", "2", False), - ("ses-wave1bas", "Stern", "PA", "2", False), - ("ses-wave1bas", "Stroop", "PA", "2", False), - ("ses-wave1bas", "Rest", "AP", "1", True), - ("ses-wave1bas", "Axcpt", "AP", "1", True), - ("ses-wave1bas", "Cuedts", "AP", "1", True), - ("ses-wave1bas", "Stern", "AP", "1", True), - ("ses-wave1bas", "Stroop", "AP", "1", True), - ("ses-wave1bas", "Rest", "PA", "2", True), - ("ses-wave1bas", "Axcpt", "PA", "2", True), - ("ses-wave1bas", "Cuedts", "PA", "2", True), - ("ses-wave1bas", "Stern", "PA", "2", True), - ("ses-wave1bas", "Stroop", "PA", "2", True), - ("ses-wave1pro", "Rest", "AP", "1", False), - ("ses-wave1pro", "Rest", "PA", "2", False), - ("ses-wave1pro", "Rest", "AP", "1", True), - ("ses-wave1pro", "Rest", "PA", "2", True), - ("ses-wave1rea", "Rest", "AP", "1", False), - ("ses-wave1rea", "Rest", "PA", "2", False), - ("ses-wave1rea", "Rest", "AP", "1", True), - ("ses-wave1rea", "Rest", "PA", "2", True), + (["ses-wave1bas"], ["Rest"], ["AP"], ["1"], False), + (["ses-wave1bas"], ["Axcpt"], ["AP"], ["1"], False), + (["ses-wave1bas"], ["Cuedts"], ["AP"], ["1"], False), + (["ses-wave1bas"], ["Stern"], ["AP"], ["1"], False), + (["ses-wave1bas"], ["Stroop"], ["AP"], ["1"], False), + (["ses-wave1bas"], ["Rest"], ["PA"], ["2"], False), + (["ses-wave1bas"], ["Axcpt"], ["PA"], ["2"], False), + (["ses-wave1bas"], ["Cuedts"], ["PA"], ["2"], False), + (["ses-wave1bas"], ["Stern"], ["PA"], ["2"], False), + (["ses-wave1bas"], 
["Stroop"], ["PA"], ["2"], False), + (["ses-wave1bas"], ["Rest"], ["AP"], ["1"], True), + (["ses-wave1bas"], ["Axcpt"], ["AP"], ["1"], True), + (["ses-wave1bas"], ["Cuedts"], ["AP"], ["1"], True), + (["ses-wave1bas"], ["Stern"], ["AP"], ["1"], True), + (["ses-wave1bas"], ["Stroop"], ["AP"], ["1"], True), + (["ses-wave1bas"], ["Rest"], ["PA"], ["2"], True), + (["ses-wave1bas"], ["Axcpt"], ["PA"], ["2"], True), + (["ses-wave1bas"], ["Cuedts"], ["PA"], ["2"], True), + (["ses-wave1bas"], ["Stern"], ["PA"], ["2"], True), + (["ses-wave1bas"], ["Stroop"], ["PA"], ["2"], True), + (["ses-wave1pro"], ["Rest"], ["AP"], ["1"], False), + (["ses-wave1pro"], ["Rest"], ["PA"], ["2"], False), + (["ses-wave1pro"], ["Rest"], ["AP"], ["1"], True), + (["ses-wave1pro"], ["Rest"], ["PA"], ["2"], True), + (["ses-wave1rea"], ["Rest"], ["AP"], ["1"], False), + (["ses-wave1rea"], ["Rest"], ["PA"], ["2"], False), + (["ses-wave1rea"], ["Rest"], ["AP"], ["1"], True), + (["ses-wave1rea"], ["Rest"], ["PA"], ["2"], True), ], ) def test_DMCC13Benchmark( - sessions: Optional[str], - tasks: Optional[str], - phase_encodings: Optional[str], - runs: Optional[str], + sessions: list[str], + tasks: list[str], + phase_encodings: list[str], + runs: list[str], native_t1w: bool, ) -> None: """Test DMCC13Benchmark DataGrabber. Parameters ---------- - sessions : str or None + sessions : list of str The parametrized session values. - tasks : str or None + tasks : list of str The parametrized task values. - phase_encodings : str or None + phase_encodings : list of str The parametrized phase encoding values. - runs : str or None + runs : list of str The parametrized run values. native_t1w : bool The parametrized values for fetching native T1w. """ dg = DMCC13Benchmark( + uri=URI, sessions=sessions, tasks=tasks, phase_encodings=phase_encodings, runs=runs, native_t1w=native_t1w, ) - # Set URI to Gin - dg.uri = URI - with dg: - # Get all elements all_elements = dg.get_elements() - # Get test element test_element = all_elements[0] - # Get test element's access values _, ses, task, phase, run = test_element - # Access data out = dg[("sub-01", ses, task, phase, run)] # Available data types data_types = [ - "BOLD", - "VBM_CSF", - "VBM_GM", - "VBM_WM", - "T1w", + DataType.BOLD, + DataType.VBM_CSF, + DataType.VBM_GM, + DataType.VBM_WM, + DataType.T1w, ] # Add Warp if native T1w is accessed if native_t1w: - data_types.append("Warp") + data_types.append(DataType.Warp) # Data type file name formats data_file_names = [ @@ -131,9 +125,9 @@ def test_DMCC13Benchmark( for data_type, data_file_name in zip(data_types, data_file_names): # Assert data type - assert data_type in out + assert data_type in out.keys() # Conditional for Warp - if data_type == "Warp": + if data_type is DataType.Warp: for idx, fname in enumerate(data_file_name): # Assert data file path exists assert out[data_type][idx]["path"].exists() @@ -199,16 +193,16 @@ def test_DMCC13Benchmark( @pytest.mark.parametrize( "types, native_t1w", [ - ("BOLD", True), - ("BOLD", False), - ("T1w", True), - ("T1w", False), - ("VBM_CSF", True), - ("VBM_CSF", False), - ("VBM_GM", True), - ("VBM_GM", False), - ("VBM_WM", True), - ("VBM_WM", False), + (["BOLD"], True), + (["BOLD"], False), + (["T1w"], True), + (["T1w"], False), + (["VBM_CSF"], True), + (["VBM_CSF"], False), + (["VBM_GM"], True), + (["VBM_GM"], False), + (["VBM_WM"], True), + (["VBM_WM"], False), (["BOLD", "VBM_CSF"], True), (["BOLD", "VBM_CSF"], False), (["T1w", "VBM_CSF"], True), @@ -218,79 +212,29 @@ def test_DMCC13Benchmark( ], ) def 
test_DMCC13Benchmark_partial_data_access( - types: Union[str, list[str]], + types: list[str], native_t1w: bool, ) -> None: """Test DMCC13Benchmark DataGrabber partial data access. Parameters ---------- - types : str or list of str + types : list of str The parametrized types. native_t1w : bool The parametrized values for fetching native T1w. """ - dg = DMCC13Benchmark(types=types, native_t1w=native_t1w) - # Set URI to Gin - dg.uri = URI - + dg = DMCC13Benchmark( + uri=URI, + types=types, + native_t1w=native_t1w, + ) with dg: - # Get all elements all_elements = dg.get_elements() - # Get test element test_element = all_elements[0] - # Get test element's access values _, ses, task, phase, run = test_element - # Access data out = dg[("sub-01", ses, task, phase, run)] # Assert data type - if isinstance(types, list): - for type_ in types: - assert type_ in out - else: - assert types in out - - -def test_DMCC13Benchmark_incorrect_data_type() -> None: - """Test DMCC13Benchmark DataGrabber incorrect data type.""" - with pytest.raises( - ValueError, match="`patterns` must contain all `types`" - ): - _ = DMCC13Benchmark(types="Orcus") - - -def test_DMCC13Benchmark_invalid_sessions(): - """Test DMCC13Benchmark DataGrabber invalid sessions.""" - with pytest.raises( - ValueError, - match=("phonyses is not a valid session in the DMCC dataset"), - ): - DMCC13Benchmark(sessions="phonyses") - - -def test_DMCC13Benchmark_invalid_tasks(): - """Test DMCC13Benchmark DataGrabber invalid tasks.""" - with pytest.raises( - ValueError, - match=("thisisnotarealtask is not a valid task in the DMCC dataset"), - ): - DMCC13Benchmark(tasks="thisisnotarealtask") - - -def test_DMCC13Benchmark_phase_encodings(): - """Test DMCC13Benchmark DataGrabber invalid phase encodings.""" - with pytest.raises( - ValueError, - match=("moonphase is not a valid phase encoding in the DMCC dataset"), - ): - DMCC13Benchmark(phase_encodings="moonphase") - - -def test_DMCC13Benchmark_runs(): - """Test DMCC13Benchmark DataGrabber invalid runs.""" - with pytest.raises( - ValueError, - match=("cerebralrun is not a valid run in the DMCC dataset"), - ): - DMCC13Benchmark(runs="cerebralrun") + for type_ in types: + assert type_ in out diff --git a/junifer/datagrabber/tests/test_multiple.py b/junifer/datagrabber/tests/test_multiple.py index 78afd2889..56fa66927 100644 --- a/junifer/datagrabber/tests/test_multiple.py +++ b/junifer/datagrabber/tests/test_multiple.py @@ -3,7 +3,10 @@ # Authors: Federico Raimondo # License: AGPL +from pathlib import Path + import pytest +from pydantic import HttpUrl from junifer.datagrabber import MultipleDataGrabber, PatternDataladDataGrabber @@ -22,8 +25,8 @@ def test_MultipleDataGrabber() -> None: """Test MultipleDataGrabber.""" - repo_uri = _testing_dataset["example_bids_ses"]["uri"] - rootdir = "example_bids_ses" + repo_uri = HttpUrl(_testing_dataset["example_bids_ses"]["uri"]) + rootdir = Path("example_bids_ses") replacements = ["subject", "session"] dg1 = PatternDataladDataGrabber( @@ -93,7 +96,7 @@ def test_MultipleDataGrabber() -> None: replacements=replacements, ) - dg = MultipleDataGrabber([dg1, dg2]) + dg = MultipleDataGrabber(datagrabbers=[dg1, dg2]) types = dg.get_types() assert "T1w" in types @@ -129,12 +132,12 @@ def test_MultipleDataGrabber() -> None: def test_MultipleDataGrabber_no_intersection() -> None: """Test MultipleDataGrabber without intersection (0 elements).""" - rootdir = "example_bids_ses" + rootdir = Path("example_bids_ses") replacements = ["subject", "session"] dg1 = 
PatternDataladDataGrabber( rootdir=rootdir, - uri=_testing_dataset["example_bids"]["uri"], + uri=HttpUrl(_testing_dataset["example_bids"]["uri"]), types=["T1w", "Warp"], patterns={ "T1w": { @@ -171,7 +174,7 @@ def test_MultipleDataGrabber_no_intersection() -> None: dg2 = PatternDataladDataGrabber( rootdir=rootdir, - uri=_testing_dataset["example_bids_ses"]["uri"], + uri=HttpUrl(_testing_dataset["example_bids_ses"]["uri"]), types=["BOLD"], patterns={ "BOLD": { @@ -185,7 +188,7 @@ def test_MultipleDataGrabber_no_intersection() -> None: replacements=replacements, ) - dg = MultipleDataGrabber([dg1, dg2]) + dg = MultipleDataGrabber(datagrabbers=[dg1, dg2]) expected_subs = set() with dg: subs = list(dg) @@ -195,8 +198,8 @@ def test_MultipleDataGrabber_no_intersection() -> None: def test_MultipleDataGrabber_get_item() -> None: """Test MultipleDataGrabber get_item() error.""" dg1 = PatternDataladDataGrabber( - rootdir="example_bids_ses", - uri=_testing_dataset["example_bids"]["uri"], + rootdir=Path("example_bids_ses"), + uri=HttpUrl(_testing_dataset["example_bids"]["uri"]), types=["T1w"], patterns={ "T1w": { @@ -209,18 +212,18 @@ def test_MultipleDataGrabber_get_item() -> None: replacements=["subject", "session"], ) - dg = MultipleDataGrabber([dg1]) + dg = MultipleDataGrabber(datagrabbers=[dg1]) with pytest.raises(NotImplementedError): - dg.get_item(subject="sub-01") # type: ignore + dg.get_item(subject="sub-01") def test_MultipleDataGrabber_validation() -> None: """Test MultipleDataGrabber init validation.""" - rootdir = "example_bids_ses" + rootdir = Path("example_bids_ses") dg1 = PatternDataladDataGrabber( rootdir=rootdir, - uri=_testing_dataset["example_bids"]["uri"], + uri=HttpUrl(_testing_dataset["example_bids"]["uri"]), types=["T1w"], patterns={ "T1w": { @@ -235,7 +238,7 @@ def test_MultipleDataGrabber_validation() -> None: dg2 = PatternDataladDataGrabber( rootdir=rootdir, - uri=_testing_dataset["example_bids_ses"]["uri"], + uri=HttpUrl(_testing_dataset["example_bids_ses"]["uri"]), types=["BOLD"], patterns={ "BOLD": { @@ -247,16 +250,16 @@ def test_MultipleDataGrabber_validation() -> None: ) with pytest.raises(RuntimeError, match="have different element keys"): - MultipleDataGrabber([dg1, dg2]) + MultipleDataGrabber(datagrabbers=[dg1, dg2]) with pytest.raises(RuntimeError, match="have overlapping mandatory"): - MultipleDataGrabber([dg1, dg1]) + MultipleDataGrabber(datagrabbers=[dg1, dg1]) def test_MultipleDataGrabber_partial_pattern() -> None: """Test MultipleDataGrabber partial pattern.""" - repo_uri = _testing_dataset["example_bids_ses"]["uri"] - rootdir = "example_bids_ses" + repo_uri = HttpUrl(_testing_dataset["example_bids_ses"]["uri"]) + rootdir = Path("example_bids_ses") replacements = ["subject", "session"] dg1 = PatternDataladDataGrabber( @@ -295,7 +298,7 @@ def test_MultipleDataGrabber_partial_pattern() -> None: partial_pattern_ok=True, ) - dg = MultipleDataGrabber([dg1, dg2]) + dg = MultipleDataGrabber(datagrabbers=[dg1, dg2]) types = dg.get_types() assert "BOLD" in types diff --git a/junifer/datagrabber/tests/test_pattern.py b/junifer/datagrabber/tests/test_pattern.py index 5cbb135bf..801f8982b 100644 --- a/junifer/datagrabber/tests/test_pattern.py +++ b/junifer/datagrabber/tests/test_pattern.py @@ -117,7 +117,7 @@ def test_PatternDataGrabber(tmp_path: Path) -> None: """ datagrabber_first = PatternDataGrabber( - datadir="/tmp/data", + datadir=Path("/tmp/data"), types=["BOLD", "T1w"], patterns={ "BOLD": { @@ -129,7 +129,7 @@ def test_PatternDataGrabber(tmp_path: Path) -> None: 
"space": "native", }, }, - replacements="subject", + replacements=["subject"], ) assert datagrabber_first.datadir == Path("/tmp/data") assert set(datagrabber_first.types) == {"T1w", "BOLD"} @@ -279,22 +279,3 @@ def test_PatternDataGrabber_unix_path_expansion(tmp_path: Path) -> None: # Check paths are found assert set(out["FreeSurfer"].keys()) == {"path", "aseg", "meta"} assert list(out["FreeSurfer"]["aseg"].keys()) == ["path"] - - -def test_PatternDataGrabber_confounds_format_error_on_init() -> None: - """Test PatterDataGrabber confounds format error on initialisation.""" - with pytest.raises( - ValueError, match="Invalid value for `confounds_format`" - ): - PatternDataGrabber( - types=["BOLD"], - patterns={ - "BOLD": { - "pattern": "func/{subject}.nii", - "space": "MNI152NLin6Asym", - }, - }, - replacements=["subject"], - datadir="/tmp", - confounds_format="foobar", - ) diff --git a/junifer/datagrabber/tests/test_pattern_datalad.py b/junifer/datagrabber/tests/test_pattern_datalad.py index c4529b1d8..f7839c1ae 100644 --- a/junifer/datagrabber/tests/test_pattern_datalad.py +++ b/junifer/datagrabber/tests/test_pattern_datalad.py @@ -7,9 +7,9 @@ from pathlib import Path -import pytest +from pydantic import HttpUrl -from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datagrabber import DataType, PatternDataladDataGrabber _testing_dataset = { @@ -26,45 +26,25 @@ } -def test_bids_PatternDataladDataGrabber_missing_uri() -> None: - """Test check of missing URI in PatternDataladDataGrabber.""" - with pytest.raises(ValueError, match=r"`uri` must be provided"): - PatternDataladDataGrabber( - datadir=None, - types=[], - patterns={}, - replacements=[], - ) - - def test_bids_PatternDataladDataGrabber() -> None: """Test subject-based BIDS PatternDataladDataGrabber.""" - # Define types - types = ["T1w", "BOLD"] - # Define patterns - patterns = { - "T1w": { - "pattern": "{subject}/anat/{subject}_T1w.nii.gz", - "space": "MNI152NLin6Asym", - }, - "BOLD": { - "pattern": "{subject}/func/{subject}_task-rest_bold.nii.gz", - "space": "MNI152NLin6Asym", - }, - } - # Define replacements - replacements = ["subject"] - - repo_uri = _testing_dataset["example_bids"]["uri"] - rootdir = "example_bids" repo_commit = _testing_dataset["example_bids"]["commit"] - + repo_uri = _testing_dataset["example_bids"]["uri"] with PatternDataladDataGrabber( - rootdir=rootdir, - uri=repo_uri, - types=types, - patterns=patterns, - replacements=replacements, + uri=HttpUrl(repo_uri), + types=[DataType.T1w, DataType.BOLD], + patterns={ + "T1w": { + "pattern": "{subject}/anat/{subject}_T1w.nii.gz", + "space": "MNI152NLin6Asym", + }, + "BOLD": { + "pattern": "{subject}/func/{subject}_task-rest_bold.nii.gz", + "space": "MNI152NLin6Asym", + }, + }, + replacements=["subject"], + rootdir=Path("example_bids"), ) as dg: subs = list(dg) expected_subs = [f"sub-{i:02d}" for i in range(1, 10)] @@ -74,11 +54,11 @@ def test_bids_PatternDataladDataGrabber() -> None: t_sub = dg[elem] assert "path" in t_sub["T1w"] assert t_sub["T1w"]["path"] == ( - dg.datadir / f"{elem}/anat/{elem}_T1w.nii.gz" + dg.fulldir / f"{elem}/anat/{elem}_T1w.nii.gz" ) assert "path" in t_sub["BOLD"] assert t_sub["BOLD"]["path"] == ( - dg.datadir / f"{elem}/func/{elem}_task-rest_bold.nii.gz" + dg.fulldir / f"{elem}/func/{elem}_task-rest_bold.nii.gz" ) assert "meta" in t_sub["BOLD"] @@ -88,7 +68,7 @@ def test_bids_PatternDataladDataGrabber() -> None: assert "class" in dg_meta assert dg_meta["class"] == "PatternDataladDataGrabber" assert "uri" in dg_meta - assert 
dg_meta["uri"] == repo_uri + assert str(dg_meta["uri"]) == repo_uri assert "datalad_commit_id" in dg_meta assert dg_meta["datalad_commit_id"] == repo_commit @@ -98,80 +78,64 @@ def test_bids_PatternDataladDataGrabber() -> None: def test_bids_PatternDataladDataGrabber_datadir() -> None: """Test PatternDataladDataGrabber with a datadir set to a relative path.""" - # Define patterns - patterns = { - "T1w": { - "pattern": "{subject}/anat/{subject}_T*w.nii.gz", - "space": "MNI152NLin6Asym", - }, - "BOLD": { - "pattern": "{subject}/func/{subject}_task-rest_*.nii.gz", - "space": "MNI152NLin6Asym", - }, - } - # Define datadir - datadir = "dataset" # use string and not absolute path + datadir = Path("dataset") # use string and not absolute path with PatternDataladDataGrabber( - uri=_testing_dataset["example_bids"]["uri"], - types=["T1w", "BOLD"], - patterns=patterns, - datadir=datadir, - rootdir="example_bids", + uri=HttpUrl(_testing_dataset["example_bids"]["uri"]), + types=[DataType.T1w, DataType.BOLD], + patterns={ + "T1w": { + "pattern": "{subject}/anat/{subject}_T*w.nii.gz", + "space": "MNI152NLin6Asym", + }, + "BOLD": { + "pattern": "{subject}/func/{subject}_task-rest_*.nii.gz", + "space": "MNI152NLin6Asym", + }, + }, replacements=["subject"], + datadir=datadir, + rootdir=Path("example_bids"), ) as dg: - assert dg.datadir == Path(datadir) / "example_bids" + assert dg.fulldir == Path(datadir) / "example_bids" for elem in dg: t_sub = dg[elem] assert "path" in t_sub["T1w"] assert t_sub["T1w"]["path"] == ( - dg.datadir / f"{elem}/anat/{elem}_T1w.nii.gz" + dg.fulldir / f"{elem}/anat/{elem}_T1w.nii.gz" ) assert "path" in t_sub["BOLD"] assert t_sub["BOLD"]["path"] == ( - dg.datadir / f"{elem}/func/{elem}_task-rest_bold.nii.gz" + dg.fulldir / f"{elem}/func/{elem}_task-rest_bold.nii.gz" ) def test_bids_PatternDataladDataGrabber_session(): """Test a subject and session-based BIDS PatternDataladDataGrabber.""" - types = ["T1w", "BOLD"] - patterns = { - "T1w": { - "pattern": ( - "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz" - ), - "space": "MNI152NLin6Asym", - }, - "BOLD": { - "pattern": ( - "{subject}/{session}/func/" - "{subject}_{session}_task-rest_bold.nii.gz" - ), - "space": "MNI152NLin6Asym", - }, - } - replacements = ["subject", "session"] - - # Check error - with pytest.raises(ValueError, match=r"`uri` must be provided"): - PatternDataladDataGrabber( - datadir=None, - types=types, - patterns=patterns, - replacements=replacements, - ) - # Set parameters - repo_uri = _testing_dataset["example_bids_ses"]["uri"] - rootdir = "example_bids_ses" - + repo_uri = HttpUrl(_testing_dataset["example_bids_ses"]["uri"]) + rootdir = Path("example_bids_ses") + replacements = ["subject", "session"] # With T1W and bold, only 2 sessions are available with PatternDataladDataGrabber( - rootdir=rootdir, uri=repo_uri, - types=types, - patterns=patterns, + types=[DataType.T1w, DataType.BOLD], + patterns={ + "T1w": { + "pattern": ( + "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz" + ), + "space": "MNI152NLin6Asym", + }, + "BOLD": { + "pattern": ( + "{subject}/{session}/func/" + "{subject}_{session}_task-rest_bold.nii.gz" + ), + "space": "MNI152NLin6Asym", + }, + }, replacements=replacements, + rootdir=rootdir, ) as dg: subs = list(dg.get_elements()) expected_subs = [ @@ -182,21 +146,19 @@ def test_bids_PatternDataladDataGrabber_session(): assert set(subs) == set(expected_subs) # Test with a different T1w only, it should have 3 sessions - types = ["T1w"] - patterns = { - "T1w": { - "pattern": ( - 
"{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz" - ), - "space": "MNI152NLin6Asym", - }, - } with PatternDataladDataGrabber( - rootdir=rootdir, uri=repo_uri, - types=types, - patterns=patterns, + types=[DataType.T1w], + patterns={ + "T1w": { + "pattern": ( + "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz" + ), + "space": "MNI152NLin6Asym", + }, + }, replacements=replacements, + rootdir=rootdir, ) as dg: subs = list(dg) expected_subs = [ From ebde81f5a46af229c4f8c49a30a04ac2d3ee8126 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:40:16 +0100 Subject: [PATCH 54/99] update: allow DataType enum to be extended --- junifer/datagrabber/pattern_validation_mixin.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/junifer/datagrabber/pattern_validation_mixin.py b/junifer/datagrabber/pattern_validation_mixin.py index 77ba2d6e7..612347563 100644 --- a/junifer/datagrabber/pattern_validation_mixin.py +++ b/junifer/datagrabber/pattern_validation_mixin.py @@ -6,6 +6,9 @@ from collections.abc import Iterator, MutableMapping from typing import TypedDict +from aenum import extend_enum + +from ..datagrabber import DataType from ..typing import DataGrabberPatterns from ..utils import logger, raise_error, warn_with_log @@ -196,6 +199,7 @@ def register_data_type(name: str, schema: DataTypeSchema) -> None: """ DataTypeManager()[name] = schema + extend_enum(DataType, name, name) class PatternValidationMixin: From 0331743a2f30e4d4407b9a8a7a4576a53aec3234 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:41:17 +0100 Subject: [PATCH 55/99] chore: improve docstrings for PatternValidationMixin --- junifer/datagrabber/pattern_validation_mixin.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/junifer/datagrabber/pattern_validation_mixin.py b/junifer/datagrabber/pattern_validation_mixin.py index 612347563..f9ec97de6 100644 --- a/junifer/datagrabber/pattern_validation_mixin.py +++ b/junifer/datagrabber/pattern_validation_mixin.py @@ -194,7 +194,7 @@ def register_data_type(name: str, schema: DataTypeSchema) -> None: ---------- name : str The data type name. - schema : DataTypeSchema + schema : ``DataTypeSchema`` The data type schema. """ @@ -238,8 +238,8 @@ def _validate_replacements( ---------- replacements : list of str The replacements to validate. - patterns : dict - The patterns to validate replacements against. + patterns : ``DataGrabberPatterns`` + The patterns to validate ``replacements`` against. partial_pattern_ok : bool Whether to raise error if partial pattern for a data type is found. @@ -403,11 +403,11 @@ def validate_patterns( Parameters ---------- - types : list of str - The data types to check patterns of. + types : list of :enum:`.DataType` + The data type(s) to check patterns of. replacements : list of str - The replacements to be replaced in the patterns. - patterns : dict + The replacements to be replaced in the ``patterns``. + patterns : ``DataGrabberPatterns`` The patterns to validate. partial_pattern_ok : bool, optional Whether to raise error if partial pattern for a data type is found. 
From 27c479d50fb370e794e4d1d68ab72abbea6d35a5 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:41:51 +0100 Subject: [PATCH 56/99] chore: fix test for TemporalSNRBase --- junifer/markers/temporal_snr/tests/test_temporal_snr_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/junifer/markers/temporal_snr/tests/test_temporal_snr_base.py b/junifer/markers/temporal_snr/tests/test_temporal_snr_base.py index 396dd6300..ac9617ce6 100644 --- a/junifer/markers/temporal_snr/tests/test_temporal_snr_base.py +++ b/junifer/markers/temporal_snr/tests/test_temporal_snr_base.py @@ -6,10 +6,10 @@ import pytest # done to keep line length 79 -import junifer.markers.temporal_snr as tsnr +from junifer.markers.temporal_snr import temporal_snr_base as tsb def test_base_temporal_snr_marker_abstractness() -> None: """Test TemporalSNRBase is an abstract base class.""" with pytest.raises(TypeError, match="abstract"): - tsnr.temporal_snr_base.TemporalSNRBase() + tsb.TemporalSNRBase() From 686a5f2b9155a6e9b09ce18fcefab3e8b46e6006 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:42:11 +0100 Subject: [PATCH 57/99] chore: reorder import for pipeline --- junifer/pipeline/__init__.pyi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/junifer/pipeline/__init__.pyi b/junifer/pipeline/__init__.pyi index aa973e18f..8265c2a01 100644 --- a/junifer/pipeline/__init__.pyi +++ b/junifer/pipeline/__init__.pyi @@ -4,11 +4,11 @@ __all__ = [ "BaseDataDumpAsset", "DataObjectDumper", "ExtDep", + "MarkerCollection", "PipelineComponentRegistry", "PipelineStepMixin", "UpdateMetaMixin", "WorkDirManager", - "MarkerCollection", ] from ._data_object_dumper import ( @@ -17,9 +17,9 @@ from ._data_object_dumper import ( BaseDataDumpAsset, DataObjectDumper, ) +from .marker_collection import MarkerCollection from .pipeline_component_registry import PipelineComponentRegistry from .pipeline_step_mixin import PipelineStepMixin from .update_meta_mixin import UpdateMetaMixin from .utils import ExtDep from .workdir_manager import WorkDirManager -from .marker_collection import MarkerCollection From 169a3d908047c30435fa90b47a86b86d612dce33 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:43:09 +0100 Subject: [PATCH 58/99] chore: lint --- junifer/markers/complexity/complexity_base.py | 1 - junifer/storage/hdf5.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/junifer/markers/complexity/complexity_base.py b/junifer/markers/complexity/complexity_base.py index 90a08ffe9..f1fbb528f 100644 --- a/junifer/markers/complexity/complexity_base.py +++ b/junifer/markers/complexity/complexity_base.py @@ -68,7 +68,6 @@ class ComplexityBase(BaseMarker): masks: Optional[list[Union[dict, str]]] = None on: list[Literal[DataType.BOLD]] = [DataType.BOLD] # noqa: RUF012 - @abstractmethod def compute_complexity( self, diff --git a/junifer/storage/hdf5.py b/junifer/storage/hdf5.py index dd2314d8a..70514c343 100644 --- a/junifer/storage/hdf5.py +++ b/junifer/storage/hdf5.py @@ -359,7 +359,7 @@ def read( if feature_md5: logger.debug( f"Validating feature MD5 '{feature_md5}' in metadata " - f"for: {self.uri.resolve()} ..." # type: ignore + f"for: {self.uri.resolve()} ..." 
) # Validate MD5 if feature_md5 in metadata: @@ -1057,7 +1057,7 @@ def collect(self) -> None: # Run loop to collect data per feature per file logger.info( - f"Collecting data from {self.uri.parent}/*_{self.uri.name}" # type: ignore + f"Collecting data from {self.uri.parent}/*_{self.uri.name}" ) logger.info(f"Will collect {len(elements_per_feature_md5)} features.") From fe687e2b5c7dbf4fec1e6c5c1b919b0ee19bdbb6 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:43:24 +0100 Subject: [PATCH 59/99] update: adapt maps_datagrabber fixture --- junifer/conftest.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/junifer/conftest.py b/junifer/conftest.py index 28e9dcec1..a7fcbdbdb 100644 --- a/junifer/conftest.py +++ b/junifer/conftest.py @@ -7,8 +7,9 @@ from pathlib import Path import pytest +from pydantic import HttpUrl -from junifer.datagrabber import PatternDataladDataGrabber +from junifer.datagrabber import DataType, PatternDataladDataGrabber from junifer.utils.singleton import Singleton @@ -40,8 +41,8 @@ def maps_datagrabber(tmp_path: Path) -> PatternDataladDataGrabber: """ dg = PatternDataladDataGrabber( - uri="https://github.com/OpenNeuroDatasets/ds005226.git", - types=["BOLD"], + uri=HttpUrl("https://github.com/OpenNeuroDatasets/ds005226.git"), + types=[DataType.BOLD], patterns={ "BOLD": { "pattern": ( From 44430cdf6789b8fd7c04bd242c86dc48f09307a3 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:44:08 +0100 Subject: [PATCH 60/99] update: remove redundant tests in PatternValidationMixin --- .../datagrabber/pattern_validation_mixin.py | 40 ------------------- 1 file changed, 40 deletions(-) diff --git a/junifer/datagrabber/pattern_validation_mixin.py b/junifer/datagrabber/pattern_validation_mixin.py index f9ec97de6..729574284 100644 --- a/junifer/datagrabber/pattern_validation_mixin.py +++ b/junifer/datagrabber/pattern_validation_mixin.py @@ -205,27 +205,6 @@ def register_data_type(name: str, schema: DataTypeSchema) -> None: class PatternValidationMixin: """Mixin class for pattern validation.""" - def _validate_types(self, types: list[str]) -> None: - """Validate the types. - - Parameters - ---------- - types : list of str - The data types to validate. - - Raises - ------ - TypeError - If ``types`` is not a list or if the values are not string. - - """ - if not isinstance(types, list): - raise_error(msg="`types` must be a list", klass=TypeError) - if any(not isinstance(x, str) for x in types): - raise_error( - msg="`types` must be a list of strings", klass=TypeError - ) - def _validate_replacements( self, replacements: list[str], @@ -245,8 +224,6 @@ def _validate_replacements( Raises ------ - TypeError - If ``replacements`` is not a list or if the values are not string. ValueError If a value in ``replacements`` is not part of a data type pattern and ``partial_pattern_ok=False`` or @@ -260,15 +237,6 @@ def _validate_replacements( and ``partial_pattern_ok=True``. """ - if not isinstance(replacements, list): - raise_error(msg="`replacements` must be a list.", klass=TypeError) - - if any(not isinstance(x, str) for x in replacements): - raise_error( - msg="`replacements` must be a list of strings", - klass=TypeError, - ) - # Make a list of all patterns recursively all_patterns = [] for dtype_val in patterns.values(): @@ -416,8 +384,6 @@ def validate_patterns( Raises ------ - TypeError - If ``patterns`` is not a dictionary. 
ValueError If length of ``types`` and ``patterns`` are different or if ``patterns`` is missing entries from ``types`` or @@ -425,12 +391,6 @@ def validate_patterns( if data type pattern key contains '*' as value. """ - # Validate types - self._validate_types(types=types) - - # Validate patterns - if not isinstance(patterns, dict): - raise_error(msg="`patterns` must be a dict", klass=TypeError) # Unequal length of objects if len(types) > len(patterns): raise_error( From e005afdbecff9e90a75c38fb8bbd63c288295331 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:44:51 +0100 Subject: [PATCH 61/99] chore: improve tests for PatternValidationMixin --- .../tests/test_pattern_validation_mixin.py | 44 ------------------- 1 file changed, 44 deletions(-) diff --git a/junifer/datagrabber/tests/test_pattern_validation_mixin.py b/junifer/datagrabber/tests/test_pattern_validation_mixin.py index 83da5f5a5..94d6b06a2 100644 --- a/junifer/datagrabber/tests/test_pattern_validation_mixin.py +++ b/junifer/datagrabber/tests/test_pattern_validation_mixin.py @@ -118,26 +118,6 @@ def test_register_data_type() -> None: @pytest.mark.parametrize( "types, replacements, patterns, expect", [ - ( - "wrong", - [], - {}, - pytest.raises(TypeError, match="`types` must be a list"), - ), - ( - [1], - [], - {}, - pytest.raises( - TypeError, match="`types` must be a list of strings" - ), - ), - ( - ["BOLD"], - [], - "wrong", - pytest.raises(TypeError, match="`patterns` must be a dict"), - ), ( ["T1w", "BOLD"], "", @@ -205,30 +185,6 @@ def test_register_data_type() -> None: }, pytest.raises(ValueError, match="following a replacement"), ), - ( - ["T1w"], - "wrong", - { - "T1w": { - "pattern": "{subject}/anat/{subject}_T1w.nii", - "space": "native", - }, - }, - pytest.raises(TypeError, match="`replacements` must be a list"), - ), - ( - ["T1w"], - [1], - { - "T1w": { - "pattern": "{subject}/anat/{subject}_T1w.nii", - "space": "native", - }, - }, - pytest.raises( - TypeError, match="`replacements` must be a list of strings" - ), - ), ( ["T1w", "BOLD"], ["subject", "session"], From 79372be0f5e851df0336cbbf6fcd5f65aa0d5497 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:46:36 +0100 Subject: [PATCH 62/99] update: improve data type registration test --- junifer/datagrabber/tests/test_pattern_validation_mixin.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/junifer/datagrabber/tests/test_pattern_validation_mixin.py b/junifer/datagrabber/tests/test_pattern_validation_mixin.py index 94d6b06a2..e78533faf 100644 --- a/junifer/datagrabber/tests/test_pattern_validation_mixin.py +++ b/junifer/datagrabber/tests/test_pattern_validation_mixin.py @@ -9,7 +9,8 @@ import pytest -from junifer.datagrabber.pattern_validation_mixin import ( +from junifer.datagrabber import ( + DataType, DataTypeManager, DataTypeSchema, PatternValidationMixin, @@ -111,8 +112,10 @@ def test_register_data_type() -> None: ) assert "dtype" in DataTypeManager() + assert "dtype" in list(DataType) _ = DataTypeManager().pop("dtype") - assert "dumb" not in DataTypeManager() + assert "dtype" not in DataTypeManager() + assert "dtype" in list(DataType) @pytest.mark.parametrize( From ada119f76419405ebac1d1ae973bf24b2f7b6d58 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:47:13 +0100 Subject: [PATCH 63/99] chore: improve type annotation and docstring for PatternValidationMixin and test --- junifer/datagrabber/pattern_validation_mixin.py | 2 +- 
junifer/datagrabber/tests/test_pattern_validation_mixin.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/junifer/datagrabber/pattern_validation_mixin.py b/junifer/datagrabber/pattern_validation_mixin.py
index 729574284..4a108d946 100644
--- a/junifer/datagrabber/pattern_validation_mixin.py
+++ b/junifer/datagrabber/pattern_validation_mixin.py
@@ -362,7 +362,7 @@ def _identify_stray_keys(
 
     def validate_patterns(
         self,
-        types: list[str],
+        types: list[DataType],
         replacements: list[str],
         patterns: DataGrabberPatterns,
         partial_pattern_ok: bool = False,
diff --git a/junifer/datagrabber/tests/test_pattern_validation_mixin.py b/junifer/datagrabber/tests/test_pattern_validation_mixin.py
index e78533faf..ddba6d580 100644
--- a/junifer/datagrabber/tests/test_pattern_validation_mixin.py
+++ b/junifer/datagrabber/tests/test_pattern_validation_mixin.py
@@ -81,7 +81,7 @@ def test_dtype_mgr(dtype: DataTypeSchema) -> None:
 
     Parameters
     ----------
-    dtype : DataTypeSchema
+    dtype : ``DataTypeSchema``
         The parametrized schema.
 
     """

From a6f703ec58447c6daf85f27ac275c732dbfc9565 Mon Sep 17 00:00:00 2001
From: Synchon Mandal
Date: Mon, 17 Nov 2025 10:49:59 +0100
Subject: [PATCH 64/99] chore: adapt juseless datagrabbers

---
 .../juseless/datagrabbers/aomic_id1000_vbm.py |  42 ++--
 .../juseless/datagrabbers/camcan_vbm.py       |  47 ++--
 .../configs/juseless/datagrabbers/ixi_vbm.py  |  84 +++-----
 junifer/configs/juseless/datagrabbers/ucla.py | 201 +++++++++---------
 .../configs/juseless/datagrabbers/ukb_vbm.py  |  44 ++--
 5 files changed, 185 insertions(+), 233 deletions(-)

diff --git a/junifer/configs/juseless/datagrabbers/aomic_id1000_vbm.py b/junifer/configs/juseless/datagrabbers/aomic_id1000_vbm.py
index 3bf1ad098..2d00f8683 100644
--- a/junifer/configs/juseless/datagrabbers/aomic_id1000_vbm.py
+++ b/junifer/configs/juseless/datagrabbers/aomic_id1000_vbm.py
@@ -4,11 +4,13 @@
 # Synchon Mandal
 # License: AGPL
 
-from pathlib import Path
-from typing import Union
+from typing import Literal
+
+from pydantic import HttpUrl
 
 from ....api.decorators import register_datagrabber
-from ....datagrabber import PatternDataladDataGrabber
+from ....datagrabber import DataType, PatternDataladDataGrabber
+from ....typing import DataGrabberPatterns
 
 
 __all__ = ["JuselessDataladAOMICID1000VBM"]
 
@@ -22,27 +24,19 @@ class JuselessDataladAOMICID1000VBM(PatternDataladDataGrabber):
 
     Parameters
     ----------
-    datadir : str or pathlib.Path or None, optional
-        The directory where the datalad dataset will be cloned. If None,
-        the datalad dataset will be cloned into a temporary directory
-        (default None).
+    datadir : pathlib.Path, optional
+        The path where the datalad dataset will be cloned.
+        If not specified, the datalad dataset will be cloned into a temporary
+        directory.
 
""" - def __init__(self, datadir: Union[str, Path, None] = None) -> None: - uri = "https://gin.g-node.org/felixh/ds003097_ReproVBM" - types = ["VBM_GM"] - replacements = ["subject"] - patterns = { - "VBM_GM": { - "pattern": ("{subject}/mri/mwp1{subject}_run-2_T1w.nii.gz"), - "space": "IXI549Space", - }, - } - super().__init__( - types=types, - datadir=datadir, - uri=uri, - replacements=replacements, - patterns=patterns, - ) + uri: HttpUrl = HttpUrl("https://gin.g-node.org/felixh/ds003097_ReproVBM") + types: list[Literal[DataType.VBM_GM]] = [DataType.VBM_GM] # noqa: RUF012 + patterns: DataGrabberPatterns = { # noqa: RUF012 + "VBM_GM": { + "pattern": ("{subject}/mri/mwp1{subject}_run-2_T1w.nii.gz"), + "space": "IXI549Space", + }, + } + replacements: list[str] = ["subject"] # noqa: RUF012 diff --git a/junifer/configs/juseless/datagrabbers/camcan_vbm.py b/junifer/configs/juseless/datagrabbers/camcan_vbm.py index 782574b38..10c62277c 100644 --- a/junifer/configs/juseless/datagrabbers/camcan_vbm.py +++ b/junifer/configs/juseless/datagrabbers/camcan_vbm.py @@ -5,11 +5,13 @@ # Synchon Mandal # License: AGPL -from pathlib import Path -from typing import Union +from typing import Literal + +from pydantic import HttpUrl from ....api.decorators import register_datagrabber -from ....datagrabber import PatternDataladDataGrabber +from ....datagrabber import DataType, PatternDataladDataGrabber +from ....typing import DataGrabberPatterns __all__ = ["JuselessDataladCamCANVBM"] @@ -23,30 +25,21 @@ class JuselessDataladCamCANVBM(PatternDataladDataGrabber): Parameters ---------- - datadir : str or pathlib.Path or None, optional - The directory where the datalad dataset will be cloned. If None, - the datalad dataset will be cloned into a temporary directory - (default None). + datadir : pathlib.Path, optional + That path where the datalad dataset will be cloned. + If not specified, the datalad dataset will be cloned into a temporary + directory. 
""" - def __init__(self, datadir: Union[str, Path, None] = None) -> None: - uri = ( - "ria+http://cat_12.5.ds.inm7.de" - "#a139b26a-8406-11ea-8f94-a0369f287950" - ) - types = ["VBM_GM"] - replacements = ["subject"] - patterns = { - "VBM_GM": { - "pattern": "{subject}/mri/m0wp1{subject}.nii.gz", - "space": "IXI549Space", - }, - } - super().__init__( - types=types, - datadir=datadir, - uri=uri, - replacements=replacements, - patterns=patterns, - ) + uri: HttpUrl = ( + "ria+http://cat_12.5.ds.inm7.de#a139b26a-8406-11ea-8f94-a0369f287950" + ) + types: list[Literal[DataType.VBM_GM]] = [DataType.VBM_GM] # noqa: RUF012 + patterns: DataGrabberPatterns = { # noqa: RUF012 + "VBM_GM": { + "pattern": "{subject}/mri/m0wp1{subject}.nii.gz", + "space": "IXI549Space", + }, + } + replacements: list[str] = ["subject"] # noqa: RUF012 diff --git a/junifer/configs/juseless/datagrabbers/ixi_vbm.py b/junifer/configs/juseless/datagrabbers/ixi_vbm.py index 713cde1fd..3aa63bf98 100644 --- a/junifer/configs/juseless/datagrabbers/ixi_vbm.py +++ b/junifer/configs/juseless/datagrabbers/ixi_vbm.py @@ -5,17 +5,27 @@ # Synchon Mandal # License: AGPL -from pathlib import Path -from typing import Union +from enum import Enum +from typing import Literal + +from pydantic import HttpUrl from ....api.decorators import register_datagrabber -from ....datagrabber import PatternDataladDataGrabber -from ....utils import raise_error +from ....datagrabber import DataType, PatternDataladDataGrabber +from ....typing import DataGrabberPatterns __all__ = ["JuselessDataladIXIVBM"] +class IXISite(str, Enum): + """Accepted IXI sites.""" + + Guys = "Guys" + HH = "HH" + IOP = "IOP" + + @register_datagrabber class JuselessDataladIXIVBM(PatternDataladDataGrabber): """Concrete implementation for Juseless IXI VBM data fetching. @@ -24,53 +34,25 @@ class JuselessDataladIXIVBM(PatternDataladDataGrabber): Parameters ---------- - datadir : str or pathlib.Path or None, optional - The directory where the datalad dataset will be cloned. If None, - the datalad dataset will be cloned into a temporary directory - (default None). - sites : {"Guys", "HH", "IOP"} or list of the options or None, optional - Which sites to access data from. If None, all available sites are - selected (default None). + datadir : pathlib.Path, optional + That path where the datalad dataset will be cloned. + If not specified, the datalad dataset will be cloned into a temporary + directory. + sites : list of :obj:`IXISite`, optional + IXI sites. + By default, all available sites are selected. """ - def __init__( - self, - datadir: Union[str, Path, None] = None, - sites: Union[str, list[str], None] = None, - ) -> None: - uri = ( - "ria+http://cat_12.5.ds.inm7.de" - "#b7107c52-8408-11ea-89c6-a0369f287950" - ) - types = ["VBM_GM"] - replacements = ["site", "subject"] - patterns = { - "VBM_GM": { - "pattern": ("{site}/{subject}/mri/m0wp1{subject}.nii.gz"), - "space": "IXI549Space", - }, - } - - # validate and/or transform 'site' input - all_sites = ["HH", "Guys", "IOP"] - if sites is None: - sites = all_sites - - if isinstance(sites, str): - sites = [sites] - - for s in sites: - if s not in all_sites: - raise_error( - f"{s} not a valid site in IXI VBM dataset!" 
-                    f"Available sites are {all_sites}"
-                )
-        self.sites = sites
-        super().__init__(
-            types=types,
-            datadir=datadir,
-            uri=uri,
-            replacements=replacements,
-            patterns=patterns,
-        )
+    uri: HttpUrl = (
+        "ria+http://cat_12.5.ds.inm7.de#b7107c52-8408-11ea-89c6-a0369f287950"
+    )
+    types: list[Literal[DataType.VBM_GM]] = [DataType.VBM_GM]  # noqa: RUF012
+    sites: list[IXISite] = [IXISite.Guys, IXISite.HH, IXISite.IOP]  # noqa: RUF012
+    patterns: DataGrabberPatterns = {  # noqa: RUF012
+        "VBM_GM": {
+            "pattern": ("{site}/{subject}/mri/m0wp1{subject}.nii.gz"),
+            "space": "IXI549Space",
+        },
+    }
+    replacements: list[str] = ["site", "subject"]  # noqa: RUF012

diff --git a/junifer/configs/juseless/datagrabbers/ucla.py b/junifer/configs/juseless/datagrabbers/ucla.py
index abf7ccb99..886beacc6 100644
--- a/junifer/configs/juseless/datagrabbers/ucla.py
+++ b/junifer/configs/juseless/datagrabbers/ucla.py
@@ -4,17 +4,31 @@
 # Leonard Sasse
 # License: AGPL
 
+from enum import Enum
 from pathlib import Path
-from typing import Union
+from typing import Literal
 
 from ....api.decorators import register_datagrabber
-from ....datagrabber import PatternDataGrabber
-from ....utils import raise_error
+from ....datagrabber import ConfoundsFormat, DataType, PatternDataGrabber
+from ....typing import DataGrabberPatterns
 
 
 __all__ = ["JuselessUCLA"]
 
 
+class UCLATask(str, Enum):
+    """Accepted UCLA tasks."""
+
+    REST = "rest"
+    BART = "bart"
+    BHT = "bht"
+    PAMENC = "pamenc"
+    PAMRET = "pamret"
+    SCAP = "scap"
+    TASKSWITCH = "taskswitch"
+    STOPSIGNAL = "stopsignal"
+
+
 @register_datagrabber
 class JuselessUCLA(PatternDataGrabber):
     """Concrete implementation for Juseless UCLA data fetching.
@@ -23,118 +37,93 @@ class JuselessUCLA(PatternDataGrabber):
 
     Parameters
    ----------
-    datadir : str or Path, optional
+    datadir : Path, optional
         The directory where the dataset is stored.
         (default "/data/project/psychosis_thalamus/data/fmriprep").
-    types: {"BOLD", "T1w", "VBM_CSF", "VBM_GM", "VBM_WM"} or \
-        list of the options, optional
-        UCLA data types. If None, all available data types are selected.
-        (default None).
-    tasks : {"rest", "bart", "bht", "pamenc", "pamret", \
-        "scap", "taskswitch", "stopsignal"} or \
-        list of the options or None, optional
-        UCLA task sessions. If None, all available task sessions are
-        selected (default None).
+    types: list of {``DataType.BOLD``, ``DataType.T1w``, \
+        ``DataType.VBM_CSF``, ``DataType.VBM_GM``, ``DataType.VBM_WM``}, \
+        optional
+        The data type(s) to grab.
+    tasks : list of ``UCLATask``, optional
+        UCLA task sessions.
+        By default, all available tasks are selected.
 
""" - def __init__( - self, - datadir: Union[ - str, Path - ] = "/data/project/psychosis_thalamus/data/fmriprep", - types: Union[str, list[str], None] = None, - tasks: Union[str, list[str], None] = None, - ) -> None: - # Declare all tasks - all_tasks = [ - "rest", - "bart", - "bht", - "pamenc", - "pamret", - "scap", - "taskswitch", - "stopsignal", + # the commented out uri leads to new open neuro dataset which does + # NOT have preprocessed data + # uri = "https://github.com/OpenNeuroDatasets/ds000030.git" + datadir: Path = Path("/data/project/psychosis_thalamus/data/fmriprep") + types: list[ + Literal[ + DataType.BOLD, + DataType.T1w, + DataType.VBM_CSF, + DataType.VBM_GM, + DataType.VBM_WM, ] - # Set default tasks - if tasks is None: - tasks = all_tasks - else: - # Convert single task into list - if isinstance(tasks, str): - tasks = [tasks] - # Verify valid tasks - for t in tasks: - if t not in all_tasks: - raise_error( - f"{t} is not a valid task in the UCLA dataset!" - ) - self.tasks = tasks - # The patterns - patterns = { - "BOLD": { - "pattern": ( - "{subject}/func/{subject}_task-{task}_bold_space-" - "MNI152NLin2009cAsym_preproc.nii.gz" - ), - "space": "MNI152NLin2009cAsym", - "confounds": { - "pattern": ( - "{subject}/func/{subject}_" - "task-{task}_bold_confounds.tsv" - ), - "space": "fmriprep", - }, - }, - "T1w": { - "pattern": ( - "{subject}/anat/{subject}_" - "T1w_space-MNI152NLin2009cAsym_preproc.nii.gz" - ), - "space": "MNI152NLin2009cAsym", - }, - "VBM_CSF": { - "pattern": ( - "{subject}/anat/{subject}_T1w_space-" - "MNI152NLin2009cAsym_class-CSF_probtissue.nii.gz" - ), - "space": "MNI152NLin2009cAsym", - }, - "VBM_GM": { - "pattern": ( - "{subject}/anat/{subject}_T1w_space-" - "MNI152NLin2009cAsym_class-GM_probtissue.nii.gz" - ), - "space": "MNI152NLin2009cAsym", - }, - "VBM_WM": { + ] = [ # noqa: RUF012 + DataType.BOLD, + DataType.T1w, + DataType.VBM_CSF, + DataType.VBM_GM, + DataType.VBM_WM, + ] + tasks: list[UCLATask] = [ # noqa: RUF012 + UCLATask.REST, + UCLATask.BART, + UCLATask.BHT, + UCLATask.PAMENC, + UCLATask.PAMRET, + UCLATask.SCAP, + UCLATask.TASKSWITCH, + UCLATask.STOPSIGNAL, + ] + patterns: DataGrabberPatterns = { # noqa: RUF012 + "BOLD": { + "pattern": ( + "{subject}/func/{subject}_task-{task}_bold_space-" + "MNI152NLin2009cAsym_preproc.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + "confounds": { "pattern": ( - "{subject}/anat/{subject}_T1w_space" - "-MNI152NLin2009cAsym_class-WM_probtissue.nii.gz" + "{subject}/func/{subject}_task-{task}_bold_confounds.tsv" ), - "space": "MNI152NLin2009cAsym", + "space": "fmriprep", }, - } - # Set default types - if types is None: - types = list(patterns.keys()) - # Convert single type into list - else: - if not isinstance(types, list): - types = [types] - # The replacements - replacements = ["subject", "task"] - # the commented out uri leads to new open neuro dataset which does - # NOT have preprocessed data - # uri = "https://github.com/OpenNeuroDatasets/ds000030.git" - super().__init__( - types=types, - datadir=datadir, - patterns=patterns, - replacements=replacements, - confounds_format="fmriprep", - ) + }, + "T1w": { + "pattern": ( + "{subject}/anat/{subject}_" + "T1w_space-MNI152NLin2009cAsym_preproc.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + }, + "VBM_CSF": { + "pattern": ( + "{subject}/anat/{subject}_T1w_space-" + "MNI152NLin2009cAsym_class-CSF_probtissue.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + }, + "VBM_GM": { + "pattern": ( + "{subject}/anat/{subject}_T1w_space-" + 
"MNI152NLin2009cAsym_class-GM_probtissue.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + }, + "VBM_WM": { + "pattern": ( + "{subject}/anat/{subject}_T1w_space" + "-MNI152NLin2009cAsym_class-WM_probtissue.nii.gz" + ), + "space": "MNI152NLin2009cAsym", + }, + } + replacements: list[str] = ["subject", "task"] # noqa: RUF012 + confounds_format: ConfoundsFormat = ConfoundsFormat.FMRIPrep def get_elements(self) -> list: """Implement fetching list of elements in the dataset. diff --git a/junifer/configs/juseless/datagrabbers/ukb_vbm.py b/junifer/configs/juseless/datagrabbers/ukb_vbm.py index cea8ac655..9dee937b3 100644 --- a/junifer/configs/juseless/datagrabbers/ukb_vbm.py +++ b/junifer/configs/juseless/datagrabbers/ukb_vbm.py @@ -6,10 +6,13 @@ # License: AGPL from pathlib import Path -from typing import Union +from typing import Literal + +from pydantic import HttpUrl from ....api.decorators import register_datagrabber -from ....datagrabber import PatternDataladDataGrabber +from ....datagrabber import DataType, PatternDataladDataGrabber +from ....typing import DataGrabberPatterns __all__ = ["JuselessDataladUKBVBM"] @@ -23,29 +26,20 @@ class JuselessDataladUKBVBM(PatternDataladDataGrabber): Parameters ---------- - datadir : str or pathlib.Path or None, optional - The directory where the datalad dataset will be cloned. If None, - the datalad dataset will be cloned into a temporary directory - (default None). + datadir : pathlib.Path, optional + That path where the datalad dataset will be cloned. + If not specified, the datalad dataset will be cloned into a temporary + directory. """ - def __init__(self, datadir: Union[str, Path, None] = None) -> None: - uri = "ria+http://ukb.ds.inm7.de#~cat_m0wp1" - rootdir = "m0wp1" - types = ["VBM_GM"] - replacements = ["subject", "session"] - patterns = { - "VBM_GM": { - "pattern": "m0wp1{subject}_ses-{session}_T1w.nii.gz", - "space": "IXI549Space", - }, - } - super().__init__( - types=types, - datadir=datadir, - uri=uri, - rootdir=rootdir, - replacements=replacements, - patterns=patterns, - ) + uri: HttpUrl = "ria+http://ukb.ds.inm7.de#~cat_m0wp1" + rootdir: Path = Path("m0wp1") + types: list[Literal[DataType.VBM_GM]] = [DataType.VBM_GM] # noqa: RUF012 + patterns: DataGrabberPatterns = { # noqa: RUF012 + "VBM_GM": { + "pattern": "m0wp1{subject}_ses-{session}_T1w.nii.gz", + "space": "IXI549Space", + }, + } + replacements: list[str] = ["subject", "session"] # noqa: RUF012 From 442b387d37b78454f4911b9ec737fbb9ad8f74c3 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:52:29 +0100 Subject: [PATCH 65/99] chore: adapt validate_input for BaseFeatureStorage test --- junifer/storage/tests/test_storage_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/junifer/storage/tests/test_storage_base.py b/junifer/storage/tests/test_storage_base.py index f30bdc180..bdf4a6f5c 100644 --- a/junifer/storage/tests/test_storage_base.py +++ b/junifer/storage/tests/test_storage_base.py @@ -62,10 +62,10 @@ def collect(self): assert st.single_output is True # Check validate with valid argument - st.validate(input_=["matrix"]) + st.validate_input(input_=["matrix"]) # Check validate with invalid argument with pytest.raises(ValueError): - st.validate(input_=["duck"]) + st.validate_input(input_=["duck"]) with pytest.raises(NotImplementedError): st.list_features() From 8e75f534617af5f3e2ec3f6706a0afe0e6ed2a22 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:52:56 +0100 Subject: [PATCH 66/99] chore: adapt validate_component 
for DefaultDataReader test --- junifer/datareader/tests/test_default_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/datareader/tests/test_default_reader.py b/junifer/datareader/tests/test_default_reader.py index 663a89bb4..838dc25e1 100644 --- a/junifer/datareader/tests/test_default_reader.py +++ b/junifer/datareader/tests/test_default_reader.py @@ -30,7 +30,7 @@ def test_DefaultDataReader_validation(type_) -> None: """ reader = DefaultDataReader() assert reader.validate_input(type_) == type_ - assert reader.validate(type_) == type_ + assert reader.validate_component(type_) == type_ def test_DefaultDataReader_meta() -> None: From 724f7ba484605fa2e6f6de7d5b1a1b42a098cbcf Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:54:19 +0100 Subject: [PATCH 67/99] update: adapt testing datagrabbers --- junifer/testing/datagrabbers.py | 68 ++++++++++++++++----------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/junifer/testing/datagrabbers.py b/junifer/testing/datagrabbers.py index 63703a9ab..4c45e7ca7 100644 --- a/junifer/testing/datagrabbers.py +++ b/junifer/testing/datagrabbers.py @@ -5,12 +5,14 @@ # License: AGPL import tempfile +from enum import Enum from pathlib import Path +from typing import Any import nibabel as nib from nilearn import datasets, image -from ..datagrabber.base import BaseDataGrabber +from ..datagrabber import BaseDataGrabber, DataType __all__ = [ @@ -27,12 +29,9 @@ class OasisVBMTestingDataGrabber(BaseDataGrabber): """ - def __init__(self) -> None: - # Create temporary directory - datadir = tempfile.mkdtemp() - # Define types - types = ["VBM_GM"] - super().__init__(types=types, datadir=datadir) + types: list[DataType] = [DataType.VBM_GM] # noqa: RUF012 + datadir: Path = Path(tempfile.mkdtemp()) + _dataset: Any = None def get_element_keys(self) -> list[str]: """Get element keys. @@ -98,12 +97,8 @@ class SPMAuditoryTestingDataGrabber(BaseDataGrabber): """ - def __init__(self) -> None: - # Create temporary directory - datadir = tempfile.mkdtemp() - # Define types - types = ["BOLD", "T1w"] # TODO: Check that they are T1w - super().__init__(types=types, datadir=datadir) + types: list[DataType] = [DataType.BOLD, DataType.T1w] # noqa: RUF012 + datadir: Path = Path(tempfile.mkdtemp()) def get_element_keys(self) -> list[str]: """Get element keys. @@ -143,8 +138,8 @@ def get_item(self, subject: str) -> dict[str, dict]: """ out = {} nilearn_data = datasets.fetch_spm_auditory(subject_id=subject) - fmri_img = image.concat_imgs(nilearn_data.func) # type: ignore - anat_img = image.concat_imgs(nilearn_data.anat) # type: ignore + fmri_img = image.concat_imgs(nilearn_data.func) + anat_img = image.concat_imgs(nilearn_data.anat) fmri_fname = self.datadir / f"{subject}_bold.nii.gz" anat_fname = self.datadir / f"{subject}_T1w.nii.gz" @@ -155,6 +150,20 @@ def get_item(self, subject: str) -> dict[str, dict]: return out +class PartlyCloudyAgeGroup(str, Enum): + """Age group to fetch. + + * ``Adult`` : fetch adults only (n=33, ages 18-39) + * ``Child`` : fetch children only (n=122, ages 3-12) + * ``Both`` : fetch full sample (n=155) + + """ + + Adult = "adult" + Child = "child" + Both = "both" + + class PartlyCloudyTestingDataGrabber(BaseDataGrabber): """DataGrabber for Partly Cloudy dataset. @@ -169,28 +178,15 @@ class PartlyCloudyTestingDataGrabber(BaseDataGrabber): purpose of having realistic examples. Depending on your research question, other confounds might be more appropriate. 
If False, returns all :term:`fMRIPrep` confounds (default True). - - age_group : {"adult", "child", "both"}, optional - Which age group to fetch: - - * ``adult`` : fetch adults only (n=33, ages 18-39) - * ``child`` : fetch children only (n=122, ages 3-12) - * ``both`` : fetch full sample (n=155) - - (default "both") + age_group : `PartlyCloudyAgeGroup`, optional + Age group to fetch (default PartlyCloudyAgeGroup.Both). """ - def __init__( - self, reduce_confounds: bool = True, age_group: str = "both" - ) -> None: - """Initialize the class.""" - datadir = tempfile.mkdtemp() - # Define types - types = ["BOLD"] - self.reduce_confounds = reduce_confounds - self.age_group = age_group - super().__init__(types=types, datadir=datadir) + types: list[DataType] = [DataType.BOLD] # noqa: RUF012 + datadir: Path = Path(tempfile.mkdtemp()) + reduce_confounds: bool = True + age_group: PartlyCloudyAgeGroup = PartlyCloudyAgeGroup.Both def __enter__(self) -> "PartlyCloudyTestingDataGrabber": """Implement context entry. @@ -203,7 +199,9 @@ def __enter__(self) -> "PartlyCloudyTestingDataGrabber": self._dataset = datasets.fetch_development_fmri( n_subjects=10, reduce_confounds=self.reduce_confounds, - age_group=self.age_group, + age_group=self.age_group.value + if isinstance(self.age_group, Enum) + else self.age_group, ) return self From 3a8944d2ef9dbf11cec2087d8b74e9210356a71f Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 10:54:35 +0100 Subject: [PATCH 68/99] chore: update tests for testing datagrabbers --- .../testing/tests/test_testing_registry.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/junifer/testing/tests/test_testing_registry.py b/junifer/testing/tests/test_testing_registry.py index 1760dd149..cf04faa88 100644 --- a/junifer/testing/tests/test_testing_registry.py +++ b/junifer/testing/tests/test_testing_registry.py @@ -5,12 +5,35 @@ # License: AGPL from junifer.pipeline import PipelineComponentRegistry +from junifer.testing.datagrabbers import ( + OasisVBMTestingDataGrabber, + PartlyCloudyTestingDataGrabber, + SPMAuditoryTestingDataGrabber, +) def test_testing_registry() -> None: """Test testing registry.""" + for dg in [ + OasisVBMTestingDataGrabber, + SPMAuditoryTestingDataGrabber, + PartlyCloudyTestingDataGrabber, + ]: + PipelineComponentRegistry().register( + step="datagrabber", + klass=dg, + ) assert { "OasisVBMTestingDataGrabber", "SPMAuditoryTestingDataGrabber", "PartlyCloudyTestingDataGrabber", }.issubset(set(PipelineComponentRegistry().step_components("datagrabber"))) + for dg in [ + OasisVBMTestingDataGrabber, + SPMAuditoryTestingDataGrabber, + PartlyCloudyTestingDataGrabber, + ]: + PipelineComponentRegistry().deregister( + step="datagrabber", + klass=dg, + ) From 7fabbf0051cbfc6b31c5af22b0bcbc1711e066d7 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 11:09:19 +0100 Subject: [PATCH 69/99] update: improve tests for SQLiteFeatureStorage --- junifer/storage/tests/test_sqlite.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/junifer/storage/tests/test_sqlite.py b/junifer/storage/tests/test_sqlite.py index ce205c08c..6500a709f 100644 --- a/junifer/storage/tests/test_sqlite.py +++ b/junifer/storage/tests/test_sqlite.py @@ -14,7 +14,7 @@ from pandas.testing import assert_frame_equal from sqlalchemy import create_engine -from junifer.storage.sqlite import SQLiteFeatureStorage +from junifer.storage import MatrixKind, SQLiteFeatureStorage from junifer.storage.utils import 
element_to_prefix, process_meta @@ -503,16 +503,6 @@ def test_store_matrix(tmp_path: Path) -> None: read_df = storage.read_df(feature_md5=feature_md5) assert list(read_df.columns) == stored_names - with pytest.raises(ValueError, match="Invalid kind"): - storage.store_matrix( - meta_md5=meta_md5, - element=element_to_store, - data=data, - row_names=row_names, - col_names=col_names, - matrix_kind="wrong", - ) - with pytest.raises(ValueError, match="non-square"): storage.store_matrix( meta_md5=meta_md5, @@ -520,7 +510,7 @@ def test_store_matrix(tmp_path: Path) -> None: data=data, row_names=row_names, col_names=col_names, - matrix_kind="triu", + matrix_kind=MatrixKind.UpperTriangle, ) with pytest.raises(ValueError, match="cannot be False"): @@ -530,7 +520,7 @@ def test_store_matrix(tmp_path: Path) -> None: data=data, row_names=row_names, col_names=col_names, - matrix_kind="full", + matrix_kind=MatrixKind.Full, diagonal=False, ) @@ -550,7 +540,7 @@ def test_store_matrix(tmp_path: Path) -> None: data=data, row_names=row_names, col_names=col_names, - matrix_kind="triu", + matrix_kind=MatrixKind.UpperTriangle, ) stored_names = [ @@ -584,7 +574,7 @@ def test_store_matrix(tmp_path: Path) -> None: data=data, row_names=row_names, col_names=col_names, - matrix_kind="triu", + matrix_kind=MatrixKind.UpperTriangle, diagonal=False, ) @@ -619,7 +609,7 @@ def test_store_matrix(tmp_path: Path) -> None: data=data, row_names=row_names, col_names=col_names, - matrix_kind="tril", + matrix_kind=MatrixKind.LowerTriangle, ) stored_names = [ @@ -653,7 +643,7 @@ def test_store_matrix(tmp_path: Path) -> None: data=data, row_names=row_names, col_names=col_names, - matrix_kind="tril", + matrix_kind=MatrixKind.LowerTriangle, diagonal=False, ) stored_names = [ From e29162411a83b8d2e4f8a8cf5dcfe0d2670e55aa Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 11:31:12 +0100 Subject: [PATCH 70/99] update: move MatrixKind and fix tests for storage.utils --- junifer/storage/__init__.pyi | 3 ++- junifer/storage/_types.py | 17 +++++++++++++++++ junifer/storage/base.py | 11 ++--------- junifer/storage/tests/test_utils.py | 10 ---------- junifer/storage/utils.py | 21 +++++++++++++-------- 5 files changed, 34 insertions(+), 28 deletions(-) create mode 100644 junifer/storage/_types.py diff --git a/junifer/storage/__init__.pyi b/junifer/storage/__init__.pyi index 20e8f0524..8eddae24f 100644 --- a/junifer/storage/__init__.pyi +++ b/junifer/storage/__init__.pyi @@ -8,7 +8,8 @@ __all__ = [ "Upsert", ] -from .base import BaseFeatureStorage, MatrixKind, StorageType +from ._types import MatrixKind +from .base import BaseFeatureStorage, StorageType from .hdf5 import HDF5FeatureStorage from .pandas_base import PandasBaseFeatureStorage from .sqlite import SQLiteFeatureStorage, Upsert diff --git a/junifer/storage/_types.py b/junifer/storage/_types.py new file mode 100644 index 000000000..c3dee376d --- /dev/null +++ b/junifer/storage/_types.py @@ -0,0 +1,17 @@ +"""Provide common types for feature storage.""" + +# Authors: Synchon Mandal +# License: AGPL + +from enum import Enum + + +__all__ = ["MatrixKind"] + + +class MatrixKind(str, Enum): + """Accepted matrix kind value.""" + + UpperTriangle = "triu" + LowerTriangle = "tril" + Full = "full" diff --git a/junifer/storage/base.py b/junifer/storage/base.py index 6c55ec55b..42add2f7b 100644 --- a/junifer/storage/base.py +++ b/junifer/storage/base.py @@ -15,10 +15,11 @@ from pydantic import BaseModel, ConfigDict from ..utils import logger, raise_error +from ._types import MatrixKind 
from .utils import process_meta -__all__ = ["BaseFeatureStorage", "MatrixKind", "StorageType"] +__all__ = ["BaseFeatureStorage", "StorageType"] class StorageType(str, Enum): @@ -31,14 +32,6 @@ class StorageType(str, Enum): ScalarTable = "scalar_table" -class MatrixKind(str, Enum): - """Accepted matrix kind value.""" - - UpperTriangle = "triu" - LowerTriangle = "tril" - Full = "full" - - class BaseFeatureStorage(BaseModel, ABC): """Abstract base class for feature storage. diff --git a/junifer/storage/tests/test_utils.py b/junifer/storage/tests/test_utils.py index a9555fb4b..876e300ee 100644 --- a/junifer/storage/tests/test_utils.py +++ b/junifer/storage/tests/test_utils.py @@ -259,16 +259,6 @@ def test_element_to_prefix_invalid_type() -> None: @pytest.mark.parametrize( "params, err_msg", [ - ( - { - "matrix_kind": "half", - "diagonal": True, - "data_shape": (1, 1), - "row_names_len": 1, - "col_names_len": 1, - }, - "Invalid kind", - ), ( { "matrix_kind": "full", diff --git a/junifer/storage/utils.py b/junifer/storage/utils.py index 0ea795f61..b5d6f219d 100644 --- a/junifer/storage/utils.py +++ b/junifer/storage/utils.py @@ -8,15 +8,12 @@ import json from collections.abc import Sequence from importlib.metadata import PackageNotFoundError, version -from typing import TYPE_CHECKING import numpy as np +from pydantic import validate_call from ..utils.logging import logger, raise_error - - -if TYPE_CHECKING: - from .base import MatrixKind +from ._types import MatrixKind __all__ = [ @@ -168,8 +165,9 @@ def element_to_prefix(element: dict) -> str: return f"{prefix}_" +@validate_call def store_matrix_checks( - matrix_kind: "MatrixKind", + matrix_kind: MatrixKind, diagonal: bool, data_shape: tuple[int, int], row_names_len: int, @@ -276,7 +274,7 @@ def matrix_to_vector( data: np.ndarray, col_names: Sequence[str], row_names: Sequence[str], - matrix_kind: "MatrixKind", + matrix_kind: MatrixKind, diagonal: bool, ) -> tuple[np.ndarray, list[str]]: """Convert matrix to vector based on parameters. @@ -301,6 +299,11 @@ def matrix_to_vector( list of str The column labels. + Raises + ------ + ValueError + If the matrix kind is invalid. 
+ """ # Prepare data indexing based on matrix kind if matrix_kind == "triu": @@ -309,11 +312,13 @@ def matrix_to_vector( elif matrix_kind == "tril": k = 0 if diagonal is True else -1 data_idx = np.tril_indices(data.shape[0], k=k) - else: # full + elif matrix_kind == "full": # full data_idx = ( np.repeat(np.arange(data.shape[0]), data.shape[1]), np.tile(np.arange(data.shape[1]), data.shape[0]), ) + else: + raise_error(f"Invalid matrix kind: {matrix_kind}") # Subset data as 1D flat_data = data[data_idx] # Generate flat 1D row X column names From d66c757d19a310b0019cb4bf1fb7e2f890184321 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 11:31:46 +0100 Subject: [PATCH 71/99] chore: update types for juseless datagrabbers --- junifer/configs/juseless/datagrabbers/__init__.pyi | 3 ++- junifer/configs/juseless/datagrabbers/ixi_vbm.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/junifer/configs/juseless/datagrabbers/__init__.pyi b/junifer/configs/juseless/datagrabbers/__init__.pyi index 4026e7450..28efa1a9c 100644 --- a/junifer/configs/juseless/datagrabbers/__init__.pyi +++ b/junifer/configs/juseless/datagrabbers/__init__.pyi @@ -2,12 +2,13 @@ __all__ = [ "JuselessDataladAOMICID1000VBM", "JuselessDataladCamCANVBM", "JuselessDataladIXIVBM", + "IXISite", "JuselessUCLA", "JuselessDataladUKBVBM", ] from .aomic_id1000_vbm import JuselessDataladAOMICID1000VBM from .camcan_vbm import JuselessDataladCamCANVBM -from .ixi_vbm import JuselessDataladIXIVBM +from .ixi_vbm import JuselessDataladIXIVBM, IXISite from .ucla import JuselessUCLA from .ukb_vbm import JuselessDataladUKBVBM diff --git a/junifer/configs/juseless/datagrabbers/ixi_vbm.py b/junifer/configs/juseless/datagrabbers/ixi_vbm.py index 3aa63bf98..34680a786 100644 --- a/junifer/configs/juseless/datagrabbers/ixi_vbm.py +++ b/junifer/configs/juseless/datagrabbers/ixi_vbm.py @@ -15,7 +15,7 @@ from ....typing import DataGrabberPatterns -__all__ = ["JuselessDataladIXIVBM"] +__all__ = ["IXISite", "JuselessDataladIXIVBM"] class IXISite(str, Enum): From a433443a91390f1c87be911f93fd581ad4d3dc73 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 11:32:53 +0100 Subject: [PATCH 72/99] chore: update type for HDF5FeatureStorage.store_matrix --- junifer/storage/hdf5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/storage/hdf5.py b/junifer/storage/hdf5.py index 70514c343..242c2213a 100644 --- a/junifer/storage/hdf5.py +++ b/junifer/storage/hdf5.py @@ -790,7 +790,7 @@ def _store_data( def store_matrix( self, meta_md5: str, - element: dict[str, str], + element: dict, data: np.ndarray, col_names: Optional[Sequence[str]] = None, row_names: Optional[Sequence[str]] = None, From 66264411ae171c05701e2bf03ee50fa921632c88 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 11:33:36 +0100 Subject: [PATCH 73/99] update: use pydantic model dump for UpdateMetaMixin --- junifer/pipeline/update_meta_mixin.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/junifer/pipeline/update_meta_mixin.py b/junifer/pipeline/update_meta_mixin.py index 29b905e9e..30d4359f9 100644 --- a/junifer/pipeline/update_meta_mixin.py +++ b/junifer/pipeline/update_meta_mixin.py @@ -32,10 +32,8 @@ def update_meta( t_meta = {} # Set class name for the step t_meta["class"] = self.__class__.__name__ - # Add object variables to metadata if name doesn't start with "_" - for k, v in vars(self).items(): - if not k.startswith("_"): - t_meta[k] = v + # Dump model + 
t_meta.update(self.model_dump(mode="json")) # Conditional for list dtype vals like Warp if not isinstance(input, list): input = [input] From d3156eaba8da7caf6d127d960813d2eeea440b2a Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 11:38:29 +0100 Subject: [PATCH 74/99] fix: remove datadir from datagrabber meta when hashing for storage --- junifer/storage/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/junifer/storage/utils.py b/junifer/storage/utils.py index b5d6f219d..12346c8c2 100644 --- a/junifer/storage/utils.py +++ b/junifer/storage/utils.py @@ -78,6 +78,9 @@ def _meta_hash(meta: dict) -> str: meta["dependencies"] = { dep: get_dependency_version(dep) for dep in meta["dependencies"] } + # Remove datadir from datagrabber + if meta.get("datagrabber"): + _ = meta.get("datagrabber").pop("datadir", None) meta_md5 = hashlib.md5( json.dumps(meta, sort_keys=True).encode("utf-8") ).hexdigest() From a66b6d543afa2f0ef2572daed0c9015be10cde9f Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 13:52:57 +0100 Subject: [PATCH 75/99] update: add new enums for queue_context --- junifer/api/queue_context/__init__.pyi | 12 ++++-- .../queue_context/queue_context_adapter.py | 43 ++++++++++++++++++- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/junifer/api/queue_context/__init__.pyi b/junifer/api/queue_context/__init__.pyi index 95a7e692e..1026c80ca 100644 --- a/junifer/api/queue_context/__init__.pyi +++ b/junifer/api/queue_context/__init__.pyi @@ -1,5 +1,11 @@ -__all__ = ["QueueContextAdapter", "HTCondorAdapter", "GnuParallelLocalAdapter"] +__all__ = [ + "QueueContextAdapter", + "EnvKind", + "EnvShell", + "QueueContextEnv", + "HTCondorAdapter", + "GnuParallelLocalAdapter", +] -from .queue_context_adapter import QueueContextAdapter -from .htcondor_adapter import HTCondorAdapter +from .queue_context_adapter import QueueContextAdapter, EnvKind, EnvShell, QueueContextEnv from .gnu_parallel_local_adapter import GnuParallelLocalAdapter diff --git a/junifer/api/queue_context/queue_context_adapter.py b/junifer/api/queue_context/queue_context_adapter.py index fe8315b4c..d7058aa66 100644 --- a/junifer/api/queue_context/queue_context_adapter.py +++ b/junifer/api/queue_context/queue_context_adapter.py @@ -3,15 +3,49 @@ # Authors: Synchon Mandal # License: AGPL +import sys + + +if sys.version_info < (3, 12): # pragma: no cover + from typing_extensions import TypedDict +else: + from typing import TypedDict + from abc import ABC, abstractmethod +from enum import Enum + +from pydantic import BaseModel, ConfigDict from ...utils import raise_error -__all__ = ["QueueContextAdapter"] +__all__ = ["EnvKind", "EnvShell", "QueueContextAdapter", "QueueContextEnv"] + + +class EnvKind(str, Enum): + """Accepted Python environment kind.""" + + Venv = "venv" + Conda = "conda" + Local = "local" -class QueueContextAdapter(ABC): +class EnvShell(str, Enum): + """Accepted environment shell.""" + + Bash = "bash" + Zsh = "zsh" + + +class QueueContextEnv(TypedDict, total=False): + """Accepted environment configuration for queue context.""" + + name: str + kind: EnvKind + shell: EnvShell + + +class QueueContextAdapter(BaseModel, ABC): """Abstract base class for queue context adapter. 
For every interface that is required, one needs to provide a concrete @@ -19,6 +53,11 @@ class QueueContextAdapter(ABC): """ + model_config = ConfigDict( + use_enum_values=True, + extra="allow", + ) + @abstractmethod def pre_run(self) -> str: """Return pre-run commands.""" From 6f29d6f56f9951859972e40bf0ba5d7a6107317b Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 13:54:38 +0100 Subject: [PATCH 76/99] update: make GnuParallelLocalAdapter pydantic model --- .../gnu_parallel_local_adapter.py | 184 +++++++----------- .../tests/test_gnu_parallel_local_adapter.py | 30 +-- 2 files changed, 73 insertions(+), 141 deletions(-) diff --git a/junifer/api/queue_context/gnu_parallel_local_adapter.py b/junifer/api/queue_context/gnu_parallel_local_adapter.py index fbb7fadf7..4867922ab 100644 --- a/junifer/api/queue_context/gnu_parallel_local_adapter.py +++ b/junifer/api/queue_context/gnu_parallel_local_adapter.py @@ -6,11 +6,16 @@ import shutil import textwrap from pathlib import Path -from typing import Optional +from typing import Any, Optional from ...typing import Elements -from ...utils import logger, make_executable, raise_error, run_ext_cmd -from .queue_context_adapter import QueueContextAdapter +from ...utils import logger, make_executable, run_ext_cmd +from .queue_context_adapter import ( + EnvKind, + EnvShell, + QueueContextAdapter, + QueueContextEnv, +) __all__ = ["GnuParallelLocalAdapter"] @@ -27,14 +32,14 @@ class GnuParallelLocalAdapter(QueueContextAdapter): The path to the job directory. yaml_config_path : pathlib.Path The path to the YAML config file. - elements : list of str or tuple + elements : Elements Element(s) to process. Will be used to index the DataGrabber. - pre_run : str or None, optional + pre_run_cmds : str or None, optional Extra shell commands to source before the run (default None). - pre_collect : str or None, optional - Extra bash commands to source before the collect (default None). - env : dict, optional - The Python environment configuration. If None, will run without a + pre_collect_cmds : str or None, optional + Extra shell commands to source before the collect (default None). + env : :class:`.QueueContextEnv` or None, optional + The environment configuration. If None, will run without a virtual environment of any kind (default None). verbose : str, optional The level of verbosity (default "info"). @@ -44,12 +49,6 @@ class GnuParallelLocalAdapter(QueueContextAdapter): submit : bool, optional Whether to submit the jobs (default False). - Raises - ------ - ValueError - If ``env.kind`` is invalid or - if ``env.shell`` is invalid. 
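The field definitions that follow replace the hand-rolled ``_check_env``; a rough sketch of the resulting behaviour, with placeholder paths and elements:

    from pathlib import Path

    from pydantic import ValidationError

    from junifer.api.queue_context import GnuParallelLocalAdapter

    # Valid: the plain strings coerce to EnvKind.Conda / EnvShell.Bash.
    adapter = GnuParallelLocalAdapter(
        job_name="example",
        job_dir=Path("."),
        yaml_config_path=Path("config.yaml"),
        elements=["sub01"],
        env={"kind": "conda", "name": "junifer", "shell": "bash"},
    )

    # An invalid kind now fails at model construction with a pydantic
    # ValidationError instead of the manually raised ValueError.
    try:
        GnuParallelLocalAdapter(
            job_name="example",
            job_dir=Path("."),
            yaml_config_path=Path("config.yaml"),
            elements=["sub01"],
            env={"kind": "jambalaya"},
        )
    except ValidationError as err:
        print(err)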
- See Also -------- QueueContextAdapter : @@ -59,87 +58,42 @@ class GnuParallelLocalAdapter(QueueContextAdapter): """ - def __init__( - self, - job_name: str, - job_dir: Path, - yaml_config_path: Path, - elements: Elements, - pre_run: Optional[str] = None, - pre_collect: Optional[str] = None, - env: Optional[dict[str, str]] = None, - verbose: str = "info", - verbose_datalad: Optional[str] = None, - submit: bool = False, - ) -> None: - """Initialize the class.""" - self._job_name = job_name - self._job_dir = job_dir - self._yaml_config_path = yaml_config_path - self._elements = elements - self._pre_run = pre_run - self._pre_collect = pre_collect - self._check_env(env) - self._verbose = verbose - self._verbose_datalad = verbose_datalad - self._submit = submit - - self._log_dir = self._job_dir / "logs" - self._pre_run_path = self._job_dir / "pre_run.sh" - self._pre_collect_path = self._job_dir / "pre_collect.sh" - self._run_path = self._job_dir / f"run_{self._job_name}.sh" - self._collect_path = self._job_dir / f"collect_{self._job_name}.sh" - self._run_joblog_path = self._job_dir / f"run_{self._job_name}_joblog" - self._elements_file_path = self._job_dir / "elements" - - def _check_env(self, env: Optional[dict[str, str]]) -> None: - """Check value of env parameter on init. - - Parameters - ---------- - env : dict or None - The value of env parameter. - - Raises - ------ - ValueError - If ``env.kind`` is invalid. - - """ - # Set env related variables - if env is None: - env = {"kind": "local"} - # Check env kind - valid_env_kinds = ["conda", "venv", "local"] - if env["kind"] not in valid_env_kinds: - raise_error( - f"Invalid value for `env.kind`: {env['kind']}, " - f"must be one of {valid_env_kinds}" + job_name: str + job_dir: Path + yaml_config_path: Path + elements: Elements + pre_run_cmds: Optional[str] = None + pre_collect_cmds: Optional[str] = None + env: Optional[QueueContextEnv] = None + verbose: str = "info" + verbose_datalad: Optional[str] = None + submit: bool = False + + def model_post_init(self, context: Any): # noqa: D102 + if self.env is None: + self.env = QueueContextEnv( + kind=EnvKind.Local, shell=EnvShell.Bash, name="" ) + if self.env["kind"] == EnvKind.Local: + # No virtual environment + self._executable = "junifer" + self._arguments = "" else: - # Check shell - shell = env.get("shell", "bash") - valid_shells = ["bash", "zsh"] - if shell not in valid_shells: - raise_error( - f"Invalid value for `env.shell`: {shell}, " - f"must be one of {valid_shells}" - ) - self._shell = shell - # Set variables - if env["kind"] == "local": - # No virtual environment - self._executable = "junifer" - self._arguments = "" - else: - self._executable = f"run_{env['kind']}.{self._shell}" - self._arguments = f"{env['name']} junifer" - self._exec_path = self._job_dir / self._executable - - def elements(self) -> str: + self._executable = f"run_{self.env['kind']}.{self.env['shell']}" + self._arguments = f"{self.env['name']} junifer" + self._exec_path = self.job_dir / self._executable + self._log_dir = self.job_dir / "logs" + self._pre_run_path = self.job_dir / "pre_run.sh" + self._pre_collect_path = self.job_dir / "pre_collect.sh" + self._run_path = self.job_dir / f"run_{self.job_name}.sh" + self._collect_path = self.job_dir / f"collect_{self.job_name}.sh" + self._run_joblog_path = self.job_dir / f"run_{self.job_name}_joblog" + self._elements_file_path = self.job_dir / "elements" + + def elements_to_run(self) -> str: """Return elements to run.""" elements_to_run = [] - for element in self._elements: 
+ for element in self.elements: # Stringify elements if tuple for operation str_element = ( ",".join(element) if isinstance(element, tuple) else element @@ -151,23 +105,23 @@ def elements(self) -> str: def pre_run(self) -> str: """Return pre-run commands.""" fixed = ( - f"#!/usr/bin/env {self._shell}\n\n" + f"#!/usr/bin/env {self.env['shell']}\n\n" "# This script is auto-generated by junifer.\n\n" "# Force datalad to run in non-interactive mode\n" "DATALAD_UI_INTERACTIVE=false\n" ) - var = self._pre_run or "" + var = self.pre_run_cmds or "" return fixed + "\n" + var def run(self) -> str: """Return run commands.""" - verbose_args = f"--verbose {self._verbose}" - if self._verbose_datalad: + verbose_args = f"--verbose {self.verbose}" + if self.verbose_datalad: verbose_args = ( - f"{verbose_args} --verbose-datalad {self._verbose_datalad}" + f"{verbose_args} --verbose-datalad {self.verbose_datalad}" ) return ( - f"#!/usr/bin/env {self._shell}\n\n" + f"#!/usr/bin/env {self.env['shell']}\n\n" "# This script is auto-generated by junifer.\n\n" "# Run pre_run.sh\n" f"sh {self._pre_run_path.resolve()!s}\n\n" @@ -177,9 +131,9 @@ def run(self) -> str: "--delay 60 " # wait 1 min before next job is spawned f"--results {self._log_dir} " f"--arg-file {self._elements_file_path.resolve()!s} " - f"{self._job_dir.resolve()!s}/{self._executable} " + f"{self.job_dir.resolve()!s}/{self._executable} " f"{self._arguments} run " - f"{self._yaml_config_path.resolve()!s} " + f"{self.yaml_config_path.resolve()!s} " f"{verbose_args} " f"--element" ) @@ -187,28 +141,28 @@ def run(self) -> str: def pre_collect(self) -> str: """Return pre-collect commands.""" fixed = ( - f"#!/usr/bin/env {self._shell}\n\n" + f"#!/usr/bin/env {self.env['shell']}\n\n" "# This script is auto-generated by junifer.\n" ) - var = self._pre_collect or "" + var = self.pre_collect_cmds or "" return fixed + "\n" + var def collect(self) -> str: """Return collect commands.""" - verbose_args = f"--verbose {self._verbose}" - if self._verbose_datalad: + verbose_args = f"--verbose {self.verbose}" + if self.verbose_datalad: verbose_args = ( - f"{verbose_args} --verbose-datalad {self._verbose_datalad}" + f"{verbose_args} --verbose-datalad {self.verbose_datalad}" ) return ( - f"#!/usr/bin/env {self._shell}\n\n" + f"#!/usr/bin/env {self.env['shell']}\n\n" "# This script is auto-generated by junifer.\n\n" "# Run pre_collect.sh\n" f"sh {self._pre_collect_path.resolve()!s}\n\n" "# Run `junifer collect`\n" - f"{self._job_dir.resolve()!s}/{self._executable} " + f"{self.job_dir.resolve()!s}/{self._executable} " f"{self._arguments} collect " - f"{self._yaml_config_path.resolve()!s} " + f"{self.yaml_config_path.resolve()!s} " f"{verbose_args}" ) @@ -231,17 +185,19 @@ def prepare(self) -> None: f"{self._elements_file_path.resolve()!s}" ) self._elements_file_path.touch() - self._elements_file_path.write_text(textwrap.dedent(self.elements())) + self._elements_file_path.write_text( + textwrap.dedent(self.elements_to_run()) + ) # Create pre run logger.info( - f"Writing {self._pre_run_path.name} to {self._job_dir.resolve()!s}" + f"Writing {self._pre_run_path.name} to {self.job_dir.resolve()!s}" ) self._pre_run_path.touch() self._pre_run_path.write_text(textwrap.dedent(self.pre_run())) make_executable(self._pre_run_path) # Create run logger.info( - f"Writing {self._run_path.name} to {self._job_dir.resolve()!s}" + f"Writing {self._run_path.name} to {self.job_dir.resolve()!s}" ) self._run_path.touch() self._run_path.write_text(textwrap.dedent(self.run())) @@ -249,14 +205,14 @@ 
def prepare(self) -> None: # Create pre collect logger.info( f"Writing {self._pre_collect_path.name} to " - f"{self._job_dir.resolve()!s}" + f"{self.job_dir.resolve()!s}" ) self._pre_collect_path.touch() self._pre_collect_path.write_text(textwrap.dedent(self.pre_collect())) make_executable(self._pre_collect_path) # Create collect logger.info( - f"Writing {self._collect_path.name} to {self._job_dir.resolve()!s}" + f"Writing {self._collect_path.name} to {self.job_dir.resolve()!s}" ) self._collect_path.touch() self._collect_path.write_text(textwrap.dedent(self.collect())) @@ -264,7 +220,7 @@ def prepare(self) -> None: # Submit if required run_cmd = f"sh {self._run_path.resolve()!s}" collect_cmd = f"sh {self._collect_path.resolve()!s}" - if self._submit: + if self.submit: logger.info( "Shell scripts created, the following will be run:\n" f"{run_cmd}\n" diff --git a/junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py b/junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py index 17dca2805..e712b7363 100644 --- a/junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py +++ b/junifer/api/queue_context/tests/test_gnu_parallel_local_adapter.py @@ -12,30 +12,6 @@ from junifer.api.queue_context import GnuParallelLocalAdapter -def test_GnuParallelLocalAdapter_env_kind_error() -> None: - """Test error for invalid env kind.""" - with pytest.raises(ValueError, match=r"Invalid value for `env.kind`"): - GnuParallelLocalAdapter( - job_name="check_env_kind", - job_dir=Path("."), - yaml_config_path=Path("."), - elements=["sub01"], - env={"kind": "jambalaya"}, - ) - - -def test_GnuParallelLocalAdapter_env_shell_error() -> None: - """Test error for invalid env shell.""" - with pytest.raises(ValueError, match=r"Invalid value for `env.shell`"): - GnuParallelLocalAdapter( - job_name="check_env_shell", - job_dir=Path("."), - yaml_config_path=Path("."), - elements=["sub01"], - env={"kind": "conda", "shell": "fish"}, - ) - - @pytest.mark.parametrize( "elements, expected_text", [ @@ -63,7 +39,7 @@ def test_GnuParallelLocalAdapter_elements( yaml_config_path=Path("."), elements=elements, ) - assert expected_text in adapter.elements() + assert expected_text in adapter.elements_to_run() @pytest.mark.parametrize( @@ -98,7 +74,7 @@ def test_GnuParallelLocalAdapter_pre_run( yaml_config_path=Path("."), elements=["sub01"], env={"kind": "conda", "name": "junifer", "shell": shell}, - pre_run=pre_run, + pre_run_cmds=pre_run, ) assert shell in adapter.pre_run() assert expected_text in adapter.pre_run() @@ -136,7 +112,7 @@ def test_GnuParallelLocalAdapter_pre_collect( yaml_config_path=Path("."), elements=["sub01"], env={"kind": "venv", "name": "junifer", "shell": shell}, - pre_collect=pre_collect, + pre_collect_cmds=pre_collect, ) assert shell in adapter.pre_collect() assert expected_text in adapter.pre_collect() From bb93f7083ac715498fe7e3a8ec2090db2e47b657 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 13:56:53 +0100 Subject: [PATCH 77/99] update: make HTCondorAdapter pydantic model --- junifer/api/queue_context/__init__.pyi | 2 + junifer/api/queue_context/htcondor_adapter.py | 276 +++++++----------- .../tests/test_htcondor_adapter.py | 44 +-- 3 files changed, 109 insertions(+), 213 deletions(-) diff --git a/junifer/api/queue_context/__init__.pyi b/junifer/api/queue_context/__init__.pyi index 1026c80ca..9dc63ae80 100644 --- a/junifer/api/queue_context/__init__.pyi +++ b/junifer/api/queue_context/__init__.pyi @@ -4,8 +4,10 @@ __all__ = [ "EnvShell", "QueueContextEnv", 
"HTCondorAdapter", + "HTCondorCollect", "GnuParallelLocalAdapter", ] from .queue_context_adapter import QueueContextAdapter, EnvKind, EnvShell, QueueContextEnv +from .htcondor_adapter import HTCondorAdapter, HTCondorCollect from .gnu_parallel_local_adapter import GnuParallelLocalAdapter diff --git a/junifer/api/queue_context/htcondor_adapter.py b/junifer/api/queue_context/htcondor_adapter.py index 6715b6dd9..ac68b0d8f 100644 --- a/junifer/api/queue_context/htcondor_adapter.py +++ b/junifer/api/queue_context/htcondor_adapter.py @@ -5,15 +5,37 @@ import shutil import textwrap +from enum import Enum from pathlib import Path -from typing import Optional +from typing import Any, Optional from ...typing import Elements -from ...utils import logger, make_executable, raise_error, run_ext_cmd -from .queue_context_adapter import QueueContextAdapter +from ...utils import logger, make_executable, run_ext_cmd +from .queue_context_adapter import ( + EnvKind, + EnvShell, + QueueContextAdapter, + QueueContextEnv, +) -__all__ = ["HTCondorAdapter"] +__all__ = ["HTCondorAdapter", "HTCondorCollect"] + + +class HTCondorCollect(str, Enum): + """Accepted HTCondor collect commands. + + * ``"yes"``: Submit "collect" task and run even if some of the jobs + fail. + * ``"on_success_only"``: Submit "collect" task and run only if all jobs + succeed. + * ``"no"``: Do not submit "collect" task. + + """ + + Yes = "yes" + No = "no" + OnSuccessOnly = "on_success_only" class HTCondorAdapter(QueueContextAdapter): @@ -27,14 +49,14 @@ class HTCondorAdapter(QueueContextAdapter): The path to the job directory. yaml_config_path : pathlib.Path The path to the YAML config file. - elements : list of str or tuple + elements : Elements Element(s) to process. Will be used to index the DataGrabber. - pre_run : str or None, optional - Extra bash commands to source before the run (default None). - pre_collect : str or None, optional - Extra bash commands to source before the collect (default None). - env : dict, optional - The Python environment configuration. If None, will run without a + pre_run_cmds : str or None, optional + Extra shell commands to source before the run (default None). + pre_collect_cmds : str or None, optional + Extra shell commands to source before the collect (default None). + env : :class:`.QueueContextEnv` or None, optional + The environment configuration. If None, will run without a virtual environment of any kind (default None). verbose : str, optional The level of verbosity (default "info"). @@ -49,25 +71,12 @@ class HTCondorAdapter(QueueContextAdapter): The size of disk (HDD or SSD) to use (default "1G"). extra_preamble : str or None, optional Extra commands to pass to HTCondor (default None). - collect : {"yes", "on_success_only", "no"}, optional + collect_task : :class:`.HTCondorCollect`, optional Whether to submit "collect" task for junifer (default "yes"). - Valid options are: - - * "yes": Submit "collect" task and run even if some of the jobs - fail. - * "on_success_only": Submit "collect" task and run only if all jobs - succeed. - * "no": Do not submit "collect" task. - submit : bool, optional Whether to submit the jobs. In any case, .dag files will be created for submission (default False). - Raises - ------ - ValueError - If ``collect`` is invalid or if ``env`` is invalid. 
- See Also -------- QueueContextAdapter : @@ -77,144 +86,65 @@ class HTCondorAdapter(QueueContextAdapter): """ - def __init__( - self, - job_name: str, - job_dir: Path, - yaml_config_path: Path, - elements: Elements, - pre_run: Optional[str] = None, - pre_collect: Optional[str] = None, - env: Optional[dict[str, str]] = None, - verbose: str = "info", - verbose_datalad: Optional[str] = None, - cpus: int = 1, - mem: str = "8G", - disk: str = "1G", - extra_preamble: Optional[str] = None, - collect: str = "yes", - submit: bool = False, - ) -> None: - """Initialize the class.""" - self._job_name = job_name - self._job_dir = job_dir - self._yaml_config_path = yaml_config_path - self._elements = elements - self._pre_run = pre_run - self._pre_collect = pre_collect - self._check_env(env) - self._verbose = verbose - self._verbose_datalad = verbose_datalad - self._cpus = cpus - self._mem = mem - self._disk = disk - self._extra_preamble = extra_preamble - self._collect = self._check_collect(collect) - self._submit = submit - - self._log_dir = self._job_dir / "logs" - self._pre_run_path = self._job_dir / "pre_run.sh" - self._pre_collect_path = self._job_dir / "pre_collect.sh" - self._submit_run_path = self._job_dir / f"run_{self._job_name}.submit" - self._submit_collect_path = ( - self._job_dir / f"collect_{self._job_name}.submit" - ) - self._dag_path = self._job_dir / f"{self._job_name}.dag" - - def _check_env(self, env: Optional[dict[str, str]]) -> None: - """Check value of env parameter on init. - - Parameters - ---------- - env : dict or None - The value of env parameter. - - Raises - ------ - ValueError - If ``env.kind`` is invalid or - if ``env.shell`` is invalid. - - """ - # Set env related variables - if env is None: - env = {"kind": "local"} - # Check env kind - valid_env_kinds = ["conda", "venv", "local"] - if env["kind"] not in valid_env_kinds: - raise_error( - f"Invalid value for `env.kind`: {env['kind']}, " - f"must be one of {valid_env_kinds}" - ) - else: - # Check shell - shell = env.get("shell", "bash") - valid_shells = ["bash", "zsh"] - if shell not in valid_shells: - raise_error( - f"Invalid value for `env.shell`: {shell}, " - f"must be one of {valid_shells}" - ) - self._shell = shell - # Set variables - if env["kind"] == "local": - # No virtual environment - self._executable = "junifer" - self._arguments = "" - else: - self._executable = f"run_{env['kind']}.{self._shell}" - self._arguments = f"{env['name']} junifer" - self._exec_path = self._job_dir / self._executable - - def _check_collect(self, collect: str) -> str: - """Check value of collect parameter on init. - - Parameters - ---------- - collect : str - The value of collect parameter. - - Returns - ------- - str - The checked value of collect parameter. - - Raises - ------ - ValueError - If ``collect`` is invalid. 
- - """ - valid_options = ["yes", "no", "on_success_only"] - if collect not in valid_options: - raise_error( - f"Invalid value for `collect`: {collect}, " - f"must be one of {valid_options}" + job_name: str + job_dir: Path + yaml_config_path: Path + elements: Elements + pre_run_cmds: Optional[str] = None + pre_collect_cmds: Optional[str] = None + env: Optional[QueueContextEnv] = None + verbose: str = "info" + verbose_datalad: Optional[str] = None + cpus: int = 1 + mem: str = "8G" + disk: str = "1G" + extra_preamble: Optional[str] = None + collect_task: HTCondorCollect = HTCondorCollect.Yes + submit: bool = False + + def model_post_init(self, context: Any): # noqa: D102 + if self.env is None: + self.env = QueueContextEnv( + kind=EnvKind.Local, shell=EnvShell.Bash, name="" ) + if self.env["kind"] == EnvKind.Local: + # No virtual environment + self._executable = "junifer" + self._arguments = "" else: - return collect + self._executable = f"run_{self.env['kind']}.{self.env['shell']}" + self._arguments = f"{self.env['name']} junifer" + self._exec_path = self.job_dir / self._executable + self._log_dir = self.job_dir / "logs" + self._pre_run_path = self.job_dir / "pre_run.sh" + self._pre_collect_path = self.job_dir / "pre_collect.sh" + self._submit_run_path = self.job_dir / f"run_{self.job_name}.submit" + self._submit_collect_path = ( + self.job_dir / f"collect_{self.job_name}.submit" + ) + self._dag_path = self.job_dir / f"{self.job_name}.dag" def pre_run(self) -> str: """Return pre-run commands.""" fixed = ( - f"#!/usr/bin/env {self._shell}\n\n" + f"#!/usr/bin/env {self.env['shell']}\n\n" "# This script is auto-generated by junifer.\n\n" "# Force datalad to run in non-interactive mode\n" "DATALAD_UI_INTERACTIVE=false\n" ) - var = self._pre_run or "" + var = self.pre_run_cmds or "" return fixed + "\n" + var def run(self) -> str: """Return run commands.""" - verbose_args = f"--verbose {self._verbose} " - if self._verbose_datalad is not None: + verbose_args = f"--verbose {self.verbose} " + if self.verbose_datalad is not None: verbose_args = ( - f"{verbose_args} --verbose-datalad {self._verbose_datalad} " + f"{verbose_args} --verbose-datalad {self.verbose_datalad} " ) junifer_run_args = ( "run " - f"{self._yaml_config_path.resolve()!s} " + f"{self.yaml_config_path.resolve()!s} " f"{verbose_args}" "--element $(element)" ) @@ -227,11 +157,11 @@ def run(self) -> str: "universe = vanilla\n" "getenv = True\n\n" "# Resources\n" - f"request_cpus = {self._cpus}\n" - f"request_memory = {self._mem}\n" - f"request_disk = {self._disk}\n\n" + f"request_cpus = {self.cpus}\n" + f"request_memory = {self.mem}\n" + f"request_disk = {self.disk}\n\n" "# Executable\n" - f"initial_dir = {self._job_dir.resolve()!s}\n" + f"initial_dir = {self.job_dir.resolve()!s}\n" f"executable = $(initial_dir)/{self._executable}\n" f"transfer_executable = False\n\n" f"arguments = {self._arguments} {junifer_run_args}\n\n" @@ -240,31 +170,31 @@ def run(self) -> str: f"output = {log_dir_prefix}.out\n" f"error = {log_dir_prefix}.err\n" ) - var = self._extra_preamble or "" + var = self.extra_preamble or "" return fixed + "\n" + var + "\n" + "queue" def pre_collect(self) -> str: """Return pre-collect commands.""" fixed = ( - f"#!/usr/bin/env {self._shell}\n\n" + f"#!/usr/bin/env {self.env['shell']}\n\n" "# This script is auto-generated by junifer.\n" ) - var = self._pre_collect or "" + var = self.pre_collect_cmds or "" # Add commands if collect="yes" - if self._collect == "yes": + if self.collect_task == "yes": var += 'if [ "${1}" == "4" ]; 
then\n exit 1\nfi\n' return fixed + "\n" + var def collect(self) -> str: """Return collect commands.""" - verbose_args = f"--verbose {self._verbose} " - if self._verbose_datalad is not None: + verbose_args = f"--verbose {self.verbose} " + if self.verbose_datalad is not None: verbose_args = ( - f"{verbose_args} --verbose-datalad {self._verbose_datalad} " + f"{verbose_args} --verbose-datalad {self.verbose_datalad} " ) junifer_collect_args = ( - f"collect {self._yaml_config_path.resolve()!s} {verbose_args}" + f"collect {self.yaml_config_path.resolve()!s} {verbose_args}" ) log_dir_prefix = f"{self._log_dir.resolve()!s}/junifer_collect" fixed = ( @@ -273,11 +203,11 @@ def collect(self) -> str: "universe = vanilla\n" "getenv = True\n\n" "# Resources\n" - f"request_cpus = {self._cpus}\n" - f"request_memory = {self._mem}\n" - f"request_disk = {self._disk}\n\n" + f"request_cpus = {self.cpus}\n" + f"request_memory = {self.mem}\n" + f"request_disk = {self.disk}\n\n" "# Executable\n" - f"initial_dir = {self._job_dir.resolve()!s}\n" + f"initial_dir = {self.job_dir.resolve()!s}\n" f"executable = $(initial_dir)/{self._executable}\n" "transfer_executable = False\n\n" f"arguments = {self._arguments} {junifer_collect_args}\n\n" @@ -286,13 +216,13 @@ def collect(self) -> str: f"output = {log_dir_prefix}.out\n" f"error = {log_dir_prefix}.err\n" ) - var = self._extra_preamble or "" + var = self.extra_preamble or "" return fixed + "\n" + var + "\n" + "queue" def dag(self) -> str: """Return HTCondor DAG commands.""" fixed = "" - for idx, element in enumerate(self._elements): + for idx, element in enumerate(self.elements): # Stringify elements if tuple for operation str_element = ( ",".join(element) if isinstance(element, tuple) else element @@ -307,15 +237,15 @@ def dag(self) -> str: f'log_element="{log_element}"\n\n' # double quoted ) var = "" - if self._collect == "yes": + if self.collect_task == "yes": var += ( f"FINAL collect {self._submit_collect_path}\n" f"SCRIPT PRE collect {self._pre_collect_path.as_posix()} " "$DAG_STATUS\n" ) - elif self._collect == "on_success_only": + elif self.collect_task == "on_success_only": var += f"JOB collect {self._submit_collect_path}\nPARENT " - for idx, _ in enumerate(self._elements): + for idx, _ in enumerate(self.elements): var += f"run{idx} " var += "CHILD collect\n" @@ -326,7 +256,7 @@ def prepare(self) -> None: logger.info("Creating HTCondor job") # Create logs logger.info( - f"Creating logs directory under {self._job_dir.resolve()!s}" + f"Creating logs directory under {self.job_dir.resolve()!s}" ) self._log_dir.mkdir(exist_ok=True, parents=True) # Copy executable if not local @@ -341,7 +271,7 @@ def prepare(self) -> None: make_executable(self._exec_path) # Create pre run logger.info( - f"Writing {self._pre_run_path.name} to {self._job_dir.resolve()!s}" + f"Writing {self._pre_run_path.name} to {self.job_dir.resolve()!s}" ) self._pre_run_path.touch() self._pre_run_path.write_text(textwrap.dedent(self.pre_run())) @@ -349,14 +279,14 @@ def prepare(self) -> None: # Create run logger.debug( f"Writing {self._submit_run_path.name} to " - f"{self._job_dir.resolve()!s}" + f"{self.job_dir.resolve()!s}" ) self._submit_run_path.touch() self._submit_run_path.write_text(textwrap.dedent(self.run())) # Create pre collect logger.info( f"Writing {self._pre_collect_path.name} to " - f"{self._job_dir.resolve()!s}" + f"{self.job_dir.resolve()!s}" ) self._pre_collect_path.touch() self._pre_collect_path.write_text(textwrap.dedent(self.pre_collect())) @@ -364,13 +294,13 @@ def prepare(self) 
-> None: # Create collect logger.debug( f"Writing {self._submit_collect_path.name} to " - f"{self._job_dir.resolve()!s}" + f"{self.job_dir.resolve()!s}" ) self._submit_collect_path.touch() self._submit_collect_path.write_text(textwrap.dedent(self.collect())) # Create DAG logger.debug( - f"Writing {self._dag_path.name} to {self._job_dir.resolve()!s}" + f"Writing {self._dag_path.name} to {self.job_dir.resolve()!s}" ) self._dag_path.touch() self._dag_path.write_text(textwrap.dedent(self.dag())) @@ -380,7 +310,7 @@ def prepare(self) -> None: "-include_env HOME", f"{self._dag_path.resolve()!s}", ] - if self._submit: + if self.submit: run_ext_cmd(name="condor_submit_dag", cmd=condor_submit_dag_cmd) else: logger.info( diff --git a/junifer/api/queue_context/tests/test_htcondor_adapter.py b/junifer/api/queue_context/tests/test_htcondor_adapter.py index 4322813df..f554d517e 100644 --- a/junifer/api/queue_context/tests/test_htcondor_adapter.py +++ b/junifer/api/queue_context/tests/test_htcondor_adapter.py @@ -12,42 +12,6 @@ from junifer.api.queue_context import HTCondorAdapter -def test_HTCondorAdapter_env_kind_error() -> None: - """Test error for invalid env kind.""" - with pytest.raises(ValueError, match=r"Invalid value for `env.kind`"): - HTCondorAdapter( - job_name="check_env_kind", - job_dir=Path("."), - yaml_config_path=Path("."), - elements=["sub01"], - env={"kind": "jambalaya"}, - ) - - -def test_HTCondorAdapter_env_shell_error() -> None: - """Test error for invalid env shell.""" - with pytest.raises(ValueError, match=r"Invalid value for `env.shell`"): - HTCondorAdapter( - job_name="check_env_shell", - job_dir=Path("."), - yaml_config_path=Path("."), - elements=["sub01"], - env={"kind": "conda", "shell": "fish"}, - ) - - -def test_HTCondorAdapter_collect_error() -> None: - """Test error for invalid collect option.""" - with pytest.raises(ValueError, match=r"Invalid value for `collect`"): - HTCondorAdapter( - job_name="check_collect", - job_dir=Path("."), - yaml_config_path=Path("."), - elements=["sub01"], - collect="off", - ) - - @pytest.mark.parametrize( "pre_run, expected_text, shell", [ @@ -80,7 +44,7 @@ def test_HTCondorAdapter_pre_run( yaml_config_path=Path("."), elements=["sub01"], env={"kind": "conda", "name": "junifer", "shell": shell}, - pre_run=pre_run, + pre_run_cmds=pre_run, ) assert shell in adapter.pre_run() assert expected_text in adapter.pre_run() @@ -125,8 +89,8 @@ def test_HTCondorAdapter_pre_collect( yaml_config_path=Path("."), elements=["sub01"], env={"kind": "venv", "name": "junifer", "shell": shell}, - pre_collect=pre_collect, - collect=collect, + pre_collect_cmds=pre_collect, + collect_task=collect, ) assert shell in adapter.pre_collect() assert expected_text in adapter.pre_collect() @@ -200,7 +164,7 @@ def test_HTCondor_dag( job_dir=Path("."), yaml_config_path=Path("."), elements=elements, - collect=collect, + collect_task=collect, ) assert expected_text in adapter.dag() From e7d070362fb198cd3b4b7cb758ebbffccfd4e7a5 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 13:57:21 +0100 Subject: [PATCH 78/99] chore: update docstring for QueueContextAdapter --- junifer/api/queue_context/queue_context_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/api/queue_context/queue_context_adapter.py b/junifer/api/queue_context/queue_context_adapter.py index d7058aa66..71dcb1a7b 100644 --- a/junifer/api/queue_context/queue_context_adapter.py +++ b/junifer/api/queue_context/queue_context_adapter.py @@ -48,7 +48,7 @@ class 
QueueContextEnv(TypedDict, total=False):
 class QueueContextAdapter(BaseModel, ABC):
     """Abstract base class for queue context adapter.
 
-    For every interface that is required, one needs to provide a concrete
+    For every queue context, one needs to provide a concrete
     implementation of this abstract class.
 
     """

From 88b6bcdbc1d67e314a10d377c5c8cdb8fdda7771 Mon Sep 17 00:00:00 2001
From: Synchon Mandal
Date: Mon, 17 Nov 2025 13:58:20 +0100
Subject: [PATCH 79/99] chore: make configs import lazy

---
 junifer/configs/__init__.py | 5 +++++
 junifer/configs/__init__.pyi | 0
 junifer/configs/py.typed | 0
 3 files changed, 5 insertions(+)
 create mode 100644 junifer/configs/__init__.pyi
 create mode 100644 junifer/configs/py.typed

diff --git a/junifer/configs/__init__.py b/junifer/configs/__init__.py
index 3109f2b5b..e2b5f6564 100644
--- a/junifer/configs/__init__.py
+++ b/junifer/configs/__init__.py
@@ -2,3 +2,8 @@
 
 # Authors: Federico Raimondo
 # License: AGPL
+
+import lazy_loader as lazy
+
+
+__getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__)
diff --git a/junifer/configs/__init__.pyi b/junifer/configs/__init__.pyi
new file mode 100644
index 000000000..e69de29bb
diff --git a/junifer/configs/py.typed b/junifer/configs/py.typed
new file mode 100644
index 000000000..e69de29bb

From 9286894df81d18b41c5c3d53efa0db713c345b93 Mon Sep 17 00:00:00 2001
From: Synchon Mandal
Date: Mon, 17 Nov 2025 13:58:56 +0100
Subject: [PATCH 80/99] chore: make external import lazy

---
 junifer/external/__init__.py | 5 +++++
 junifer/external/__init__.pyi | 0
 junifer/external/py.typed | 0
 3 files changed, 5 insertions(+)
 create mode 100644 junifer/external/__init__.pyi
 create mode 100644 junifer/external/py.typed

diff --git a/junifer/external/__init__.py b/junifer/external/__init__.py
index c7a5b28f1..5fcc463eb 100644
--- a/junifer/external/__init__.py
+++ b/junifer/external/__init__.py
@@ -2,3 +2,8 @@
 
 # Authors: Synchon Mandal
 # License: AGPL
+
+import lazy_loader as lazy
+
+
+__getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__)
diff --git a/junifer/external/__init__.pyi b/junifer/external/__init__.pyi
new file mode 100644
index 000000000..e69de29bb
diff --git a/junifer/external/py.typed b/junifer/external/py.typed
new file mode 100644
index 000000000..e69de29bb

From 9286894df81d18b41c5c3d53efa0db713c345b93 Mon Sep 17 00:00:00 2001
From: Synchon Mandal
Date: Mon, 17 Nov 2025 13:59:15 +0100
Subject: [PATCH 81/99] chore: add py.typed to typing

---
 junifer/typing/py.typed | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 junifer/typing/py.typed

diff --git a/junifer/typing/py.typed b/junifer/typing/py.typed
new file mode 100644
index 000000000..e69de29bb

From fe5fd2b91cfaff85def463560fd49924c3b07d59 Mon Sep 17 00:00:00 2001
From: Synchon Mandal
Date: Mon, 17 Nov 2025 13:59:41 +0100
Subject: [PATCH 82/99] chore: clean up tests in configs

---
 .../juseless/datagrabbers/tests/test_ixi_vbm.py | 7 -------
 .../juseless/datagrabbers/tests/test_ucla.py | 16 ----------------
 2 files changed, 23 deletions(-)

diff --git a/junifer/configs/juseless/datagrabbers/tests/test_ixi_vbm.py b/junifer/configs/juseless/datagrabbers/tests/test_ixi_vbm.py
index 11aa3b05e..95bbb0375 100644
--- a/junifer/configs/juseless/datagrabbers/tests/test_ixi_vbm.py
+++ b/junifer/configs/juseless/datagrabbers/tests/test_ixi_vbm.py
@@ -31,10 +31,3 @@ def test_JuselessDataladIXIVBM() -> None:
         out["VBM_GM"]["path"].name == f"m0wp1sub-{test_element[1]}.nii.gz"
     )
     assert out["VBM_GM"]["path"].exists()
-
-
-def 
test_JuselessDataladIXIVBM_invalid_site() -> None: - """Test JuselessDataladIXIVBM with invalid site.""" - with pytest.raises(ValueError, match="notavalidsite not a valid site"): - with JuselessDataladIXIVBM(sites="notavalidsite"): - pass diff --git a/junifer/configs/juseless/datagrabbers/tests/test_ucla.py b/junifer/configs/juseless/datagrabbers/tests/test_ucla.py index f6befdd5c..18dee8d24 100644 --- a/junifer/configs/juseless/datagrabbers/tests/test_ucla.py +++ b/junifer/configs/juseless/datagrabbers/tests/test_ucla.py @@ -80,14 +80,6 @@ def test_JuselessUCLA_partial_data_access( assert types in out -def test_JuselessUCLA_incorrect_data_type() -> None: - """Test JuselessUCLA DataGrabber incorrect data type.""" - with pytest.raises( - ValueError, match="`patterns` must contain all `types`" - ): - _ = JuselessUCLA(types="Eunomia") - - @pytest.mark.parametrize( "tasks", [None, "rest", ["rest", "stopsignal"]], @@ -125,11 +117,3 @@ def test_JuselessUCLA_task_params(tasks: Optional[str]) -> None: else: for el in all_elements: assert el[1] in ["rest", "stopsignal"] - - -def test_JuselessUCLA_invalid_tasks() -> None: - """Test JuselessUCLA with invalid task parameters.""" - with pytest.raises( - ValueError, match="invalid is not a valid task in the UCLA" - ): - JuselessUCLA(tasks="invalid") From 467d5cda77c6ced7b4325bebfe6a57d7cce0bf77 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 13:59:55 +0100 Subject: [PATCH 83/99] chore: fix method in DataDispatcher --- junifer/data/_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/junifer/data/_dispatch.py b/junifer/data/_dispatch.py index 99305dfff..6ace2d7a8 100644 --- a/junifer/data/_dispatch.py +++ b/junifer/data/_dispatch.py @@ -95,7 +95,7 @@ def __setitem__( # Update global self._registries[key] = value - def popitem(): + def popitem(self): """Not implemented.""" pass From 53cbd1db83fc8638efa0030f2a3e6fa33e21142d Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 14:05:03 +0100 Subject: [PATCH 84/99] chore: improve docstrings --- junifer/api/queue_context/gnu_parallel_local_adapter.py | 2 +- junifer/api/queue_context/htcondor_adapter.py | 2 +- junifer/datagrabber/base.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/junifer/api/queue_context/gnu_parallel_local_adapter.py b/junifer/api/queue_context/gnu_parallel_local_adapter.py index 4867922ab..16de556ed 100644 --- a/junifer/api/queue_context/gnu_parallel_local_adapter.py +++ b/junifer/api/queue_context/gnu_parallel_local_adapter.py @@ -32,7 +32,7 @@ class GnuParallelLocalAdapter(QueueContextAdapter): The path to the job directory. yaml_config_path : pathlib.Path The path to the YAML config file. - elements : Elements + elements : ``Elements`` Element(s) to process. Will be used to index the DataGrabber. pre_run_cmds : str or None, optional Extra shell commands to source before the run (default None). diff --git a/junifer/api/queue_context/htcondor_adapter.py b/junifer/api/queue_context/htcondor_adapter.py index ac68b0d8f..f895eb3a8 100644 --- a/junifer/api/queue_context/htcondor_adapter.py +++ b/junifer/api/queue_context/htcondor_adapter.py @@ -49,7 +49,7 @@ class HTCondorAdapter(QueueContextAdapter): The path to the job directory. yaml_config_path : pathlib.Path The path to the YAML config file. - elements : Elements + elements : ``Elements`` Element(s) to process. Will be used to index the DataGrabber. pre_run_cmds : str or None, optional Extra shell commands to source before the run (default None). 
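Both adapter docstrings above now point at the shared ``Elements`` type, and the ``filter`` docstring below does the same; its shape, with made-up element values:

    # Plain strings for DataGrabbers with a single replacement, tuples when
    # there are several (e.g. subject and session).
    elements = [
        "sub-01",
        ("sub-02", "ses-01"),
    ]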
diff --git a/junifer/datagrabber/base.py b/junifer/datagrabber/base.py index 029a50358..bcdaffaf6 100644 --- a/junifer/datagrabber/base.py +++ b/junifer/datagrabber/base.py @@ -164,7 +164,7 @@ def filter(self, selection: Elements) -> Iterator: Parameters ---------- - selection : `Elements` + selection : ``Elements`` The list of partial or complete element selectors to filter using. Yields @@ -179,7 +179,7 @@ def filter_func(element: Element) -> bool: Parameters ---------- - element : `Elements` + element : ``Elements`` The element to be filtered. Returns From 7afbd172faeacc5ad7f7ed85b6eb0c3027982f89 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 14:20:37 +0100 Subject: [PATCH 85/99] update: add pydantic.validate_call to get_aggfunc_by_name --- junifer/stats.py | 65 ++++++++++++++++++++----------------- junifer/tests/test_stats.py | 2 -- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/junifer/stats.py b/junifer/stats.py index 61a35bdca..94e204b0f 100644 --- a/junifer/stats.py +++ b/junifer/stats.py @@ -4,37 +4,58 @@ # Synchon Mandal # License: AGPL +from enum import Enum from typing import Any, Callable, Optional import numpy as np +from pydantic import validate_call from scipy.stats import mode, trim_mean from scipy.stats.mstats import winsorize from .utils import logger, raise_error -__all__ = ["count", "get_aggfunc_by_name", "select", "winsorized_mean"] +__all__ = [ + "AggFunc", + "count", + "get_aggfunc_by_name", + "select", + "winsorized_mean", +] +class AggFunc(str, Enum): + """Accepted aggregation function names. + + * ``mean`` -> :func:`numpy.mean` + * ``winsorized_mean`` -> :func:`scipy.stats.mstats.winsorize` + * ``trim_mean`` -> :func:`scipy.stats.trim_mean` + * ``mode`` -> :func:`scipy.stats.mode` + * ``std`` -> :func:`numpy.std` + * ``count`` -> :func:`.count` + * ``select`` -> :func:`.select` + + """ + + Mean = "mean" + WinsorizedMean = "winsorized_mean" + TrimMean = "trim_mean" + Mode = "mode" + Std = "std" + Count = "count" + Select = "select" + + +@validate_call def get_aggfunc_by_name( - name: str, func_params: Optional[dict[str, Any]] = None + name: AggFunc, func_params: Optional[dict[str, Any]] = None ) -> Callable: """Get an aggregation function by its name. Parameters ---------- - name : str - Name to identify the function. Currently supported names and - corresponding functions are: - - * ``mean`` -> :func:`numpy.mean` - * ``winsorized_mean`` -> :func:`scipy.stats.mstats.winsorize` - * ``trim_mean`` -> :func:`scipy.stats.trim_mean` - * ``mode`` -> :func:`scipy.stats.mode` - * ``std`` -> :func:`numpy.std` - * ``count`` -> :func:`.count` - * ``select`` -> :func:`.select` - + name : :enum:`.AggFunc` + Aggregation function name. func_params : dict, optional Parameters to pass to the function. E.g. for ``winsorized_mean``: ``func_params = {'limits': [0.1, 0.1]}`` @@ -48,16 +69,6 @@ def get_aggfunc_by_name( """ from functools import partial # local import to avoid sphinx error - # check validity of names - _valid_func_names = { - "winsorized_mean", - "mean", - "std", - "trim_mean", - "count", - "select", - "mode", - } if func_params is None: func_params = {} # apply functions @@ -101,11 +112,7 @@ def get_aggfunc_by_name( func = partial(select, **func_params) elif name == "mode": func = partial(mode, **func_params) - else: - raise_error( - f"Function {name} unknown. 
Please provide any of " - f"{_valid_func_names}" - ) + return func diff --git a/junifer/tests/test_stats.py b/junifer/tests/test_stats.py index a0f9f7828..cc8bfac41 100644 --- a/junifer/tests/test_stats.py +++ b/junifer/tests/test_stats.py @@ -41,8 +41,6 @@ def test_get_aggfunc_by_name(name: str, params: Optional[dict]) -> None: def test_get_aggfunc_by_name_errors() -> None: """Test aggregation function retrieval using wrong name.""" - with pytest.raises(ValueError, match=r"unknown. Please provide any of"): - get_aggfunc_by_name(name="invalid", func_params=None) with pytest.raises(ValueError, match=r"list of limits"): get_aggfunc_by_name(name="winsorized_mean", func_params=None) From 46dac91e9d9a0812ebd8731253ca5d732999562d Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 14:38:27 +0100 Subject: [PATCH 86/99] docs: update docs and add deps --- docs/api/onthefly.rst | 2 +- docs/builtin.rst | 6 +++--- docs/conf.py | 5 +++++ docs/whats_new.rst | 2 +- pyproject.toml | 2 ++ 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/docs/api/onthefly.rst b/docs/api/onthefly.rst index e97f2f555..6974788d3 100644 --- a/docs/api/onthefly.rst +++ b/docs/api/onthefly.rst @@ -5,5 +5,5 @@ On-the-fly :members: :imported-members: -.. automodule:: junifer.onthefly.brainprint +.. automodule:: junifer.onthefly._brainprint :members: diff --git a/docs/builtin.rst b/docs/builtin.rst index 10683255c..48db25860 100644 --- a/docs/builtin.rst +++ b/docs/builtin.rst @@ -158,16 +158,16 @@ Available | (subject-native or other template spaces) - Done - 0.0.4 - * - ``Smoothing`` + * - :class:`.Smoothing` - | Apply smoothing to data, particularly useful when dealing with | ``fMRIPrep``-ed data - In Progress - :gh:`161` - * - ``TemporalSlicer`` + * - :class:`.TemporalSlicer` - Slice ``BOLD`` data temporally - | Done - :gh:`443` - * - ``TemporalFilter`` + * - :class:`.TemporalFilter` - Filter (clean) ``BOLD`` data temporally - | Done - :gh:`432` diff --git a/docs/conf.py b/docs/conf.py index e4353cd37..b4a16bafa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -65,6 +65,8 @@ "sphinx_copybutton", # copy button for code blocks "sphinxcontrib.mermaid", # mermaid support "sphinxcontrib.towncrier.ext", # towncrier fragment support + "sphinxcontrib.autodoc_pydantic", # autodoc support for pydantic models + "enum_tools.autoenum", # enum support ] if use_multiversion: @@ -97,6 +99,8 @@ ("py:class", "pipeline.Pipeline"), # nilearn ("py:obj", "neurokit2.*"), # ignore neurokit2 ("py:obj", "datalad.*"), # ignore datalad + ("py:obj", "junifer.*"), # ignore junifer internal + ("py:class", "annotated_types.*") # ignore pydantic annotated types ] # -- Options for HTML output ------------------------------------------------- @@ -154,6 +158,7 @@ "pandas": ("https://pandas.pydata.org/pandas-docs/dev", None), # "sqlalchemy": ("https://docs.sqlalchemy.org/en/20/", None), "scipy": ("https://docs.scipy.org/doc/scipy/", None), + "pydantic": ("https://docs.pydantic.dev/latest/", None), } # -- sphinx.ext.extlinks configuration --------------------------------------- diff --git a/docs/whats_new.rst b/docs/whats_new.rst index 1f9f64f04..d5f4fe202 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -146,7 +146,7 @@ Features ^^^^^^^^ - Introduce :func:`.normalize` and :func:`.reweight` functions for downstream - BrainPrint analysis in :mod:`.onthefly.brainprint` by `Synchon Mandal`_ + BrainPrint analysis in :mod:`.onthefly._brainprint` by `Synchon Mandal`_ (:gh:`354`) - Introduce 
:class:`junifer.pipeline.PipelineComponentRegistry` to centralise pipeline component management by `Synchon Mandal`_ (:gh:`362`) diff --git a/pyproject.toml b/pyproject.toml index 544ae4223..64061d1ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,8 @@ docs = [ "sphinxcontrib-mermaid>=0.8.1,<0.10", "sphinxcontrib-towncrier==0.4.0a0", "setuptools-scm>=8", + "autodoc_pydantic>=2.0.0", + "enum-tools[sphinx]>=0.13.0,<0.14.0", ] ################ From 6de05aa162ba077cc0b8007fadd9baf4694493e9 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 14:38:52 +0100 Subject: [PATCH 87/99] chore: add missing imports --- junifer/configs/juseless/datagrabbers/__init__.pyi | 3 ++- junifer/configs/juseless/datagrabbers/ixi_vbm.py | 2 +- junifer/configs/juseless/datagrabbers/ucla.py | 2 +- junifer/testing/datagrabbers.py | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/junifer/configs/juseless/datagrabbers/__init__.pyi b/junifer/configs/juseless/datagrabbers/__init__.pyi index 28efa1a9c..2964baee7 100644 --- a/junifer/configs/juseless/datagrabbers/__init__.pyi +++ b/junifer/configs/juseless/datagrabbers/__init__.pyi @@ -4,11 +4,12 @@ __all__ = [ "JuselessDataladIXIVBM", "IXISite", "JuselessUCLA", + "UCLATask", "JuselessDataladUKBVBM", ] from .aomic_id1000_vbm import JuselessDataladAOMICID1000VBM from .camcan_vbm import JuselessDataladCamCANVBM from .ixi_vbm import JuselessDataladIXIVBM, IXISite -from .ucla import JuselessUCLA +from .ucla import JuselessUCLA, UCLATask from .ukb_vbm import JuselessDataladUKBVBM diff --git a/junifer/configs/juseless/datagrabbers/ixi_vbm.py b/junifer/configs/juseless/datagrabbers/ixi_vbm.py index 34680a786..01c936062 100644 --- a/junifer/configs/juseless/datagrabbers/ixi_vbm.py +++ b/junifer/configs/juseless/datagrabbers/ixi_vbm.py @@ -38,7 +38,7 @@ class JuselessDataladIXIVBM(PatternDataladDataGrabber): That path where the datalad dataset will be cloned. If not specified, the datalad dataset will be cloned into a temporary directory. - sites : list of :obj:`IXISite`, optional + sites : list of :enum:`.IXISite`, optional IXI sites. By default, all available sites are selected. diff --git a/junifer/configs/juseless/datagrabbers/ucla.py b/junifer/configs/juseless/datagrabbers/ucla.py index 886beacc6..e82eac8b6 100644 --- a/junifer/configs/juseless/datagrabbers/ucla.py +++ b/junifer/configs/juseless/datagrabbers/ucla.py @@ -13,7 +13,7 @@ from ....typing import DataGrabberPatterns -__all__ = ["JuselessUCLA"] +__all__ = ["JuselessUCLA", "UCLATask"] class UCLATask(str, Enum): diff --git a/junifer/testing/datagrabbers.py b/junifer/testing/datagrabbers.py index 4c45e7ca7..ab20cc035 100644 --- a/junifer/testing/datagrabbers.py +++ b/junifer/testing/datagrabbers.py @@ -17,6 +17,7 @@ __all__ = [ "OasisVBMTestingDataGrabber", + "PartlyCloudyAgeGroup", "PartlyCloudyTestingDataGrabber", "SPMAuditoryTestingDataGrabber", ] @@ -178,7 +179,7 @@ class PartlyCloudyTestingDataGrabber(BaseDataGrabber): purpose of having realistic examples. Depending on your research question, other confounds might be more appropriate. If False, returns all :term:`fMRIPrep` confounds (default True). - age_group : `PartlyCloudyAgeGroup`, optional + age_group : :enum:`.PartlyCloudyAgeGroup`, optional Age group to fetch (default PartlyCloudyAgeGroup.Both). 
""" From 8609c335fbe874ed9793494c8b07de749c36d4ca Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 14:39:23 +0100 Subject: [PATCH 88/99] chore: add typing_extensions to deps --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 64061d1ca..ae64af64b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "structlog>=25.0.0,<26.0.0", "pydantic>=2.11.4", "aenum>=3.1.0,<3.2.0", + "typing_extensions>=4.15.0,<4.16.0; python_version<'3.12'", ] dynamic = ["version"] From db6b17804aa1bfc9b09a7c756ff704f5b37c7572 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 17 Nov 2025 14:39:43 +0100 Subject: [PATCH 89/99] docs: update extending docs --- docs/extending/datagrabber.rst | 117 ++++++++++++++------------------ docs/extending/dependencies.rst | 20 +++--- docs/extending/marker.rst | 78 ++++++++++----------- docs/extending/preprocessor.rst | 41 +++++------ 4 files changed, 110 insertions(+), 146 deletions(-) diff --git a/docs/extending/datagrabber.rst b/docs/extending/datagrabber.rst index f76303f35..46395eb04 100644 --- a/docs/extending/datagrabber.rst +++ b/docs/extending/datagrabber.rst @@ -140,29 +140,24 @@ With the variables defined above, we can create our DataGrabber and name it from pathlib import Path - from junifer.datagrabber import PatternDataGrabber + from junifer.datagrabber import PatternDataGrabber, DataType + from junifer.typing import DataGrabberPatterns class ExampleBIDSDataGrabber(PatternDataGrabber): - def __init__(self, datadir: str | Path) -> None: - types = ["T1w", "BOLD"] - patterns = { - "T1w": { - "pattern": "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz", - "space": "native", - }, - "BOLD": { - "pattern": "{subject}/{session}/func/{subject}_{session}_task-rest_bold.nii.gz", - "space": "MNI152NLin6Asym", - }, - } - replacements = ["subject", "session"] - super().__init__( - datadir=datadir, - types=types, - patterns=patterns, - replacements=replacements, - ) + + types: list[DataType] = [DataType.T1w, DataType.BOLD] + patterns: DataGrabberPatterns = { + "T1w": { + "pattern": "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz", + "space": "native", + }, + "BOLD": { + "pattern": "{subject}/{session}/func/{subject}_{session}_task-rest_bold.nii.gz", + "space": "MNI152NLin6Asym", + }, + } + replacements: list[str] = ["subject", "session"] Our DataGrabber is ready to be used by ``junifer``. However, it is still unknown to the library. We need to register it in the library. To do so, we need to @@ -175,29 +170,24 @@ use the :func:`.register_datagrabber` decorator. 
from junifer.api.decorators import register_datagrabber from junifer.datagrabber import PatternDataGrabber + from junifer.typing import DataGrabberPatterns @register_datagrabber class ExampleBIDSDataGrabber(PatternDataGrabber): - def __init__(self, datadir: str | Path) -> None: - types = ["T1w", "BOLD"] - patterns = { - "T1w": { - "pattern": "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz", - "space": "native", - }, - "BOLD": { - "pattern": "{subject}/{session}/func/{subject}_{session}_task-rest_bold.nii.gz", - "space": "MNI152NLin6Asym", - }, - } - replacements = ["subject", "session"] - super().__init__( - datadir=datadir, - types=types, - patterns=patterns, - replacements=replacements, - ) + + types: list[DataType] = [DataType.T1w, DataType.BOLD] + patterns: DataGrabberPatterns = { + "T1w": { + "pattern": "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz", + "space": "native", + }, + "BOLD": { + "pattern": "{subject}/{session}/func/{subject}_{session}_task-rest_bold.nii.gz", + "space": "MNI152NLin6Asym", + }, + } + replacements: list[str] = ["subject", "session"] Now, we can use our DataGrabber in ``junifer``, by setting the ``datagrabber`` @@ -259,35 +249,30 @@ And we can create our DataGrabber: .. code-block:: python + from pathlib import Path + from junifer.api.decorators import register_datagrabber from junifer.datagrabber import PatternDataladDataGrabber + from pydantic import HttpUrl @register_datagrabber class ExampleBIDSDataGrabber(PatternDataladDataGrabber): - def __init__(self) -> None: - types = ["T1w", "BOLD"] - patterns = { - "T1w": { - "pattern": "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz", - "space": "native", - }, - "BOLD": { - "pattern": "{subject}/{session}/func/{subject}_{session}_task-rest_bold.nii.gz", - "space": "MNI152NLin6Asym", - }, - } - replacements = ["subject", "session"] - uri = "https://gin.g-node.org/juaml/datalad-example-bids" - rootdir = "example_bids_ses" - super().__init__( - datadir=None, - uri=uri, - rootdir=rootdir, - types=types, - patterns=patterns, - replacements=replacements, - ) + + uri: HttpUrl = HttpUrl("https://gin.g-node.org/juaml/datalad-example-bids") + types: list[DataType] = [DataType.T1w, DataType.BOLD] + patterns: DataGrabberPatterns = { + "T1w": { + "pattern": "{subject}/{session}/anat/{subject}_{session}_T1w.nii.gz", + "space": "native", + }, + "BOLD": { + "pattern": "{subject}/{session}/func/{subject}_{session}_task-rest_bold.nii.gz", + "space": "MNI152NLin6Asym", + }, + } + replacements: list[str] = ["subject", "session"] + rootdir: Path = Path("example_bids_ses") This approach can be used directly from the YAML, like so: @@ -376,8 +361,8 @@ need to implement the following methods: .. note:: - The ``__init__`` method could also be implemented, but it is not mandatory. - This is required if the DataGrabber requires any extra parameter. + If the DataGrabber requires any extra parameter, they could be defined as + class attributes. We will now implement our BIDS example with this method. @@ -494,8 +479,8 @@ more information about the format of the confounds file. Thus, the ``BOLD.confounds`` element is a dictionary with the following keys: - ``path``: the path to the confounds file. -- ``format``: the format of the confounds file. Currently, this can be either - ``fmriprep`` or ``adhoc``. +- ``format``: the format of the confounds file. Check :enum:`.ConfoundsFormat` + for options. The ``fmriprep`` format corresponds to the format of the confounds files generated by `fMRIPrep`_. 
The ``adhoc`` format corresponds to a format that is diff --git a/docs/extending/dependencies.rst b/docs/extending/dependencies.rst index 48132f990..659583621 100644 --- a/docs/extending/dependencies.rst +++ b/docs/extending/dependencies.rst @@ -46,7 +46,7 @@ by having a class attribute like so: _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "afni", + "name": ExtDep.AFNI, "commands": ["3dReHo", "3dAFNItoNIFTI"], }, ] @@ -55,7 +55,7 @@ The above example is taken from the class which computes regional homogeneity (ReHo) using AFNI. The general pattern is that you need to have the value of ``_EXT_DEPENDENCIES`` as a list of dictionary with two keys: -* ``name`` (str) : lowercased name of the toolbox +* ``name`` (:enum:`.ExtDep`) : name of the toolbox * ``commands`` (list of str) : actual names of the commands you need to use This is simple but powerful as we will see in the following sub-sections. @@ -81,11 +81,11 @@ that it shows the problem a bit better and how we solve it: _CONDITIONAL_DEPENDENCIES: ClassVar[ConditionalDependencies] = [ { "using": "fsl", - "depends_on": FSLWarper, + "depends_on": [FSLWarper], }, { "using": "ants", - "depends_on": ANTsWarper, + "depends_on": [ANTSWarper], }, { "using": "auto", @@ -93,18 +93,16 @@ that it shows the problem a bit better and how we solve it: }, ] - def __init__( - self, using: str, reference: str, on: Union[List[str], str] - ) -> None: - # validation and setting up - ... + using: str + reference: str + on: List[DataType] Here, you see a new class attribute ``_CONDITIONAL_DEPENDENCIES`` which is a list of dictionaries with two keys: * ``using`` (str) : lowercased name of the toolbox -* ``depends_on`` (object or list of objects) : a class or list of classes which \ +* ``depends_on`` (list of objects) : list of classes which \ implements the particular tool's use It is mandatory to have the ``using`` positional argument in the constructor in @@ -128,7 +126,7 @@ similar. ``FSLWarper`` looks like this (only the relevant part is shown here): _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ { - "name": "fsl", + "name": ExtDep.FSL, "commands": ["flirt", "applywarp"], }, ] diff --git a/docs/extending/marker.rst b/docs/extending/marker.rst index d73c6e821..d65d150f4 100644 --- a/docs/extending/marker.rst +++ b/docs/extending/marker.rst @@ -14,7 +14,7 @@ Most of the functionality of a ``junifer`` Marker has been taken care by the :class:`.BaseMarker` class. Thus, only a few methods and class attributes are required: -#. ``__init__``: The initialisation method, where the Marker is configured. +#. (optional) ``validate_marker_params``: The method to perform logical validation of parameters (if required). #. ``compute``: The method that given the data, computes the Marker. As an example, we will develop a ``ParcelMean`` Marker, a Marker that first @@ -29,8 +29,8 @@ Step 1: Configure input and output This step is quite simple: we need to define the input and output of the Marker. Based on the current :ref:`data types `, we can have ``BOLD``, ``VBM_WM`` and ``VBM_GM`` as valid inputs. The output of the Marker depends on -the input. For ``BOLD``, it will be ``timeseries``, while for the rest of the -inputs, it will be ``vector``. Thus, we have a class attribute like so: +the input. For ``BOLD``, it will be ``Timeseries``, while for the rest of the +inputs, it will be ``Vector``. Thus, we have a class attribute like so: .. code-block:: python @@ -38,14 +38,14 @@ inputs, it will be ``vector``. 
Thus, we have a class attribute like so:
 
 .. code-block:: python
 
     # You can have multiple features for one data type,
     # each feature having same or different storage type
     _MARKER_INOUT_MAPPINGS = {
-        "BOLD": {
-            "parcel_mean": "timeseries",
+        DataType.BOLD: {
+            "parcel_mean": StorageType.Timeseries,
         },
-        "VBM_WM": {
-            "parcel_mean": "vector",
+        DataType.VBM_WM: {
+            "parcel_mean": StorageType.Vector,
         },
-        "VBM_GM": {
-            "parcel_mean": "vector",
+        DataType.VBM_GM: {
+            "parcel_mean": StorageType.Vector,
         },
     }
 
@@ -57,13 +57,11 @@ Step 2: Initialise the Marker
 
 In this step we need to define the parameters of the Marker the user can
 provide to configure how the Marker will behave.
 
-The parameters of the Marker are defined in the ``__init__`` method. The
-:class:`.BaseMarker` class requires two optional parameters:
+The parameters of the Marker are defined as class attributes. The
+:class:`.BaseMarker` class defines two optional parameters:
 
-1. ``name``: the name of the Marker. This is used to identify the Marker in the
-   configuration file.
-2. ``on``: a list or string with the data types that the Marker will be applied
-   to.
+1. ``name``: the name of the Marker. This is used to identify the Marker in the configuration file.
+2. ``on``: a list of :enum:`.DataType` with the data types that the Marker will be applied to.
 
 .. attention::
 
@@ -72,18 +70,11 @@ The parameters of the Marker are defined in the ``__init__`` method. The
    JSON format, and JSON only supports these types.
 
 In this example, the only parameter required for the computation is the name of the
-parcellation to use. Thus, we can define the ``__init__`` method as follows:
+parcellation to use. Thus, we can define it as follows:
 
 .. code-block:: python
 
-    def __init__(
-        self,
-        parcellation: str,
-        on: str | list[str] | None = None,
-        name: str | None = None,
-    ) -> None:
-        self.parcellation = parcellation
-        super().__init__(on=on, name=name)
+    parcellation: str
 
 .. caution::
 
@@ -121,7 +112,7 @@ and the values would be a dictionary of storage type specific key-value pairs.
 
    To simplify the ``store`` method, define keys of the dictionary based on
   the corresponding store functions in the :ref:`storage types `.
-   For example, if the output is a ``vector``, the keys of the dictionary should
+   For example, if the output is a ``Vector``, the keys of the dictionary should
    be ``data`` and ``col_names``.
 
 .. code-block:: python
@@ -142,7 +133,7 @@ and the values would be a dictionary of storage type specific key-value pairs.
# Get the parcellation tailored for the target t_parcellation, t_labels, _ = get_parcellation( - name=self.parcellation_name, + name=self.parcellation, target_data=input, extra_input=extra_input, ) @@ -195,7 +186,9 @@ Finally, we need to register the Marker using the ``@register_marker`` decorator from junifer.api.decorators import register_marker from junifer.data import get_parcellation + from junifer.datagrabber import DataType from junifer.markers import BaseMarker + from junifer.storage import StorageType from junifer.typing import Dependencies, MarkerInOutMappings from nilearn.maskers import NiftiLabelsMasker @@ -206,25 +199,18 @@ Finally, we need to register the Marker using the ``@register_marker`` decorator _DEPENDENCIES: ClassVar[Dependencies] = {"nilearn", "numpy"} _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { - "BOLD": { - "parcel_mean": "timeseries", + DataType.BOLD: { + "parcel_mean": StorageType.Timeseries, }, - "VBM_WM": { - "parcel_mean": "vector", + DataType.VBM_WM: { + "parcel_mean": StorageType.Vector, }, - "VBM_GM": { - "parcel_mean": "vector", + DataType.VBM_GM: { + "parcel_mean": StorageType.Vector, }, } - def __init__( - self, - parcellation: str, - on: str | list[str] | None = None, - name: str | None = None, - ) -> None: - self.parcellation = parcellation - super().__init__(on=on, name=name) + parcellation: str def compute( self, @@ -236,7 +222,7 @@ Finally, we need to register the Marker using the ``@register_marker`` decorator # Get the parcellation tailored for the target t_parcellation, t_labels, _ = get_parcellation( - name=self.parcellation_name, + name=self.parcellation, target_data=input, extra_input=extra_input, ) @@ -280,9 +266,13 @@ Template for a custom Marker # TODO: add the input-output mappings _MARKER_INOUT_MAPPINGS = {} - def __init__(self, on=None, name=None): - # TODO: add marker-specific parameters - super().__init__(on=on, name=name) + # TODO: define marker-specific parameters + + # optional + def validate_marker_params(self): + # TODO: add validation logic for marker parameters + pass def compute(self, input, extra_input): # TODO: compute the marker and create the output dictionary + return {} diff --git a/docs/extending/preprocessor.rst b/docs/extending/preprocessor.rst index 7dea78df7..47756dbd4 100644 --- a/docs/extending/preprocessor.rst +++ b/docs/extending/preprocessor.rst @@ -16,8 +16,7 @@ own Preprocessor. While implementing your own Preprocessor, you need to always inherit from :class:`.BasePreprocessor` and implement a few methods and class attributes: -#. ``__init__``: The initialisation method, where the Preprocessor is - configured. +#. (optional) ``validate_preprocessor_params``: The method to perform logical validation of parameters (if required). #. ``preprocess``: The method that given the data, preprocesses the data. As an example, we will develop a ``NilearnSmoothing`` Preprocessor, which @@ -36,15 +35,15 @@ For input we can accept ``T1w``, ``T2w`` and ``BOLD`` .. code-block:: python - _VALID_DATA_TYPES = ["T1w", "T2w", "BOLD"] + _VALID_DATA_TYPES = [DataType.T1w, DataType.T2w, DataType.BOLD] .. _extending_preprocessors_init: Step 2: Initialise the Preprocessor ----------------------------------- -Now we need to define our Preprocessor class' constructor which is also how -you configure it. Our class will have the following arguments: +Now we need to define our Preprocessor class' parameters. +Our class will have the following arguments: 1. 
``fwhm``: The smoothing strength as a full-width at half maximum (in millimetres). Since we depend on :func:`nilearn.image.smooth_img`, we @@ -59,6 +58,8 @@ you configure it. Our class will have the following arguments: are allowed as parameters. This is because the parameters are stored in JSON format, and JSON only supports these types. +As :class:`.BasePreprocessor` already defines ``on``, we can define the other: + .. code-block:: python from typing import Literal @@ -68,15 +69,7 @@ you configure it. Our class will have the following arguments: ... - - def __init__( - self, - fwhm: int | float | ArrayLike | Literal["fast"] | None, - on: str | list[str] | None = None, - ) -> None: - self.fwhm = fwhm - super().__init__(on=on) - + fwhm: int | float | ArrayLike | Literal["fast"] | None ... @@ -154,6 +147,7 @@ decorator and our final code should look like this: from typing import Any, ClassVar, Literal from junifer.api.decorators import register_preprocessor + from junifer.datagrabber import DataType from junifer.preprocess import BasePreprocessor from nilearn import image as nimg @@ -165,15 +159,9 @@ decorator and our final code should look like this: _DEPENDENCIES = {"nilearn"} - _VALID_DATA_TYPES: ClassVar[Sequence[str]] = ["T1w", "T2w", "BOLD"] + _VALID_DATA_TYPES: ClassVar[Sequence[DataType]] = [DataType.T1w, DataType.T2w, DataType.BOLD] - def __init__( - self, - fwhm: int | float | ArrayLike | Literal["fast"] | None, - on: str | list[str] | None = None, - ) -> None: - self.fwhm = fwhm - super().__init__(on=on) + fwhm: int | float | ArrayLike | Literal["fast"] | None def preprocess( self, @@ -204,9 +192,12 @@ Template for a custom Preprocessor # TODO: add the inputs _VALID_DATA_TYPES = [] - def __init__(self, on=None): - # TODO: add preprocessor-specific parameters - super().__init__(on=on) + # TODO: define preprocessor-specific parameters + + # optional + def validate_preprocessor_params(self): + # TODO: add validation logic for preprocessor parameters + pass def preprocess(self, input, extra_input): # TODO: add the preprocessor logic From d354e58252e0fd1642a511f36c0090e704272e57 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 09:05:15 +0100 Subject: [PATCH 90/99] chore: fix tests for pipeline --- .../pipeline/tests/test_marker_collection.py | 22 +++++++++---------- .../tests/test_pipeline_step_mixin.py | 21 ------------------ .../pipeline/tests/test_update_meta_mixin.py | 5 +++-- 3 files changed, 14 insertions(+), 34 deletions(-) diff --git a/junifer/pipeline/tests/test_marker_collection.py b/junifer/pipeline/tests/test_marker_collection.py index 9a96893f9..79b4fa92c 100644 --- a/junifer/pipeline/tests/test_marker_collection.py +++ b/junifer/pipeline/tests/test_marker_collection.py @@ -26,12 +26,12 @@ def test_marker_collection_incorrect_markers() -> None: """Test incorrect markers for MarkerCollection.""" wrong_markers = [ ParcelAggregation( - parcellation="Schaefer100x7", + parcellation=["Schaefer100x7"], method="mean", name="gmd_schaefer100x7_mean", ), ParcelAggregation( - parcellation="Schaefer100x7", + parcellation=["Schaefer100x7"], method="mean", name="gmd_schaefer100x7_mean", ), @@ -44,17 +44,17 @@ def test_marker_collection() -> None: """Test MarkerCollection.""" markers = [ ParcelAggregation( - parcellation="TianxS2x3TxMNInonlinear2009cAsym", + parcellation=["TianxS2x3TxMNInonlinear2009cAsym"], method="mean", name="tian_mean", ), ParcelAggregation( - parcellation="TianxS2x3TxMNInonlinear2009cAsym", + parcellation=["TianxS2x3TxMNInonlinear2009cAsym"], 
method="std", name="tian_std", ), ParcelAggregation( - parcellation="TianxS2x3TxMNInonlinear2009cAsym", + parcellation=["TianxS2x3TxMNInonlinear2009cAsym"], method="trim_mean", method_params={"proportiontocut": 0.1}, name="tian_trim_mean90", @@ -115,12 +115,12 @@ def test_marker_collection_with_preprocessing() -> None: """Test MarkerCollection with preprocessing.""" markers = [ FunctionalConnectivityParcels( - parcellation="Schaefer100x17", + parcellation=["Schaefer100x17"], agg_method="mean", name="Schaefer100x17_mean_FC", ), FunctionalConnectivityParcels( - parcellation="TianxS2x3TxMNInonlinear2009cAsym", + parcellation=["TianxS2x3TxMNInonlinear2009cAsym"], agg_method="mean", name="TianxS2x3TxMNInonlinear2009cAsym_mean_FC", ), @@ -150,17 +150,17 @@ def test_marker_collection_storage(tmp_path: Path) -> None: """ markers = [ ParcelAggregation( - parcellation="TianxS2x3TxMNInonlinear2009cAsym", + parcellation=["TianxS2x3TxMNInonlinear2009cAsym"], method="mean", name="tian_mean", ), ParcelAggregation( - parcellation="TianxS2x3TxMNInonlinear2009cAsym", + parcellation=["TianxS2x3TxMNInonlinear2009cAsym"], method="std", name="tian_std", ), ParcelAggregation( - parcellation="TianxS2x3TxMNInonlinear2009cAsym", + parcellation=["TianxS2x3TxMNInonlinear2009cAsym"], method="trim_mean", method_params={"proportiontocut": 0.1}, name="tian_trim_mean90", @@ -170,7 +170,7 @@ def test_marker_collection_storage(tmp_path: Path) -> None: dg = PartlyCloudyTestingDataGrabber() # Setup storage storage = SQLiteFeatureStorage( - tmp_path / "test_marker_collection_storage.sqlite" + uri=tmp_path / "test_marker_collection_storage.sqlite" ) mc = MarkerCollection( markers=markers, # type: ignore diff --git a/junifer/pipeline/tests/test_pipeline_step_mixin.py b/junifer/pipeline/tests/test_pipeline_step_mixin.py index 1b3b63084..d7b3329cb 100644 --- a/junifer/pipeline/tests/test_pipeline_step_mixin.py +++ b/junifer/pipeline/tests/test_pipeline_step_mixin.py @@ -127,27 +127,6 @@ def _fit_transform(self, input: dict[str, dict]) -> dict[str, dict]: mixer.fit_transform({}) -def test_PipelineStepMixin_incorrect_ext_dependencies() -> None: - """Test fit-transform with incorrect external dependencies.""" - - class IncorrectMixer(PipelineStepMixin): - """Test class for validation.""" - - _EXT_DEPENDENCIES: ClassVar[ExternalDependencies] = [ - {"name": "foobar", "optional": True} - ] - - def validate_input(self, input: list[str]) -> list[str]: - return input - - def _fit_transform(self, input: dict[str, dict]) -> dict[str, dict]: - return {"input": input} - - mixer = IncorrectMixer() - with pytest.raises(ValueError, match="Invalid value"): - mixer.fit_transform({}) - - def test_PipelineStepMixin_correct_conditional_dependencies() -> None: """Test fit-transform with correct conditional dependencies.""" diff --git a/junifer/pipeline/tests/test_update_meta_mixin.py b/junifer/pipeline/tests/test_update_meta_mixin.py index 2aed1478c..2f4b596c6 100644 --- a/junifer/pipeline/tests/test_update_meta_mixin.py +++ b/junifer/pipeline/tests/test_update_meta_mixin.py @@ -7,8 +7,9 @@ from typing import Union import pytest +from pydantic import BaseModel -from junifer.pipeline.update_meta_mixin import UpdateMetaMixin +from junifer.pipeline import UpdateMetaMixin @pytest.mark.parametrize( @@ -41,7 +42,7 @@ def test_UpdateMetaMixin( """ - class TestUpdateMetaMixin(UpdateMetaMixin): + class TestUpdateMetaMixin(BaseModel, UpdateMetaMixin): """Test UpdateMetaMixin.""" _DEPENDENCIES = dependencies From eb8762a41628d7c7a67846aec0c826990f345b9b Mon Sep 
17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 09:05:49 +0100 Subject: [PATCH 91/99] chore: fix tests and data for cli --- junifer/cli/tests/data/gmd_mean.yaml | 2 +- junifer/cli/tests/data/gmd_mean_htcondor.yaml | 3 ++- junifer/cli/tests/data/partly_cloudy_agg_mean_tian.yml | 2 +- junifer/cli/tests/test_cli_utils.py | 3 +++ 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/junifer/cli/tests/data/gmd_mean.yaml b/junifer/cli/tests/data/gmd_mean.yaml index 4ee9319cf..2779719df 100644 --- a/junifer/cli/tests/data/gmd_mean.yaml +++ b/junifer/cli/tests/data/gmd_mean.yaml @@ -7,7 +7,7 @@ elements: [1, 2] markers: - name: Schaefer1000x7_Mean kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: mean storage: kind: SQLiteFeatureStorage diff --git a/junifer/cli/tests/data/gmd_mean_htcondor.yaml b/junifer/cli/tests/data/gmd_mean_htcondor.yaml index 892efb48d..e1c722b4e 100644 --- a/junifer/cli/tests/data/gmd_mean_htcondor.yaml +++ b/junifer/cli/tests/data/gmd_mean_htcondor.yaml @@ -6,7 +6,7 @@ datagrabber: markers: - name: Schaefer1000x7_Mean kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: mean storage: kind: SQLiteFeatureStorage @@ -17,4 +17,5 @@ queue: env: kind: conda name: junifer + shell: bash mem: 8G diff --git a/junifer/cli/tests/data/partly_cloudy_agg_mean_tian.yml b/junifer/cli/tests/data/partly_cloudy_agg_mean_tian.yml index da4ee0448..be38c03e3 100644 --- a/junifer/cli/tests/data/partly_cloudy_agg_mean_tian.yml +++ b/junifer/cli/tests/data/partly_cloudy_agg_mean_tian.yml @@ -7,7 +7,7 @@ datagrabber: markers: - kind: ParcelAggregation - parcellation: TianxS1x3TxMNInonlinear2009cAsym + parcellation: [TianxS1x3TxMNInonlinear2009cAsym] method: mean name: tian-s1-3T_mean diff --git a/junifer/cli/tests/test_cli_utils.py b/junifer/cli/tests/test_cli_utils.py index 1f47cc2d6..7fbb516a4 100644 --- a/junifer/cli/tests/test_cli_utils.py +++ b/junifer/cli/tests/test_cli_utils.py @@ -34,6 +34,7 @@ def test_get_dependency_information_short() -> None: """Test short version of _get_dependency_information().""" dependency_information = _get_dependency_information(long_=False) dependency_list = [ + "aenum", "click", "numpy", "scipy", @@ -50,6 +51,8 @@ def test_get_dependency_information_short() -> None: "looseversion", "junifer_data", "structlog", + "pydantic", + "typing_extensions", ] if sys.version_info < (3, 11): From f1a064cd2de4bb683fda42a9f0aa325ea0a41afb Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 09:06:13 +0100 Subject: [PATCH 92/99] chore: fix tests for api --- junifer/api/tests/test_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/junifer/api/tests/test_functions.py b/junifer/api/tests/test_functions.py index 5dc0e2c59..1f4fc6eae 100644 --- a/junifer/api/tests/test_functions.py +++ b/junifer/api/tests/test_functions.py @@ -62,19 +62,19 @@ def datagrabber() -> dict[str, str]: @pytest.fixture -def markers() -> list[dict[str, str]]: +def markers() -> list[dict[str, Union[list[str], str]]]: """Return markers as a list of dictionary.""" return [ { "name": "tian-s1-3T_mean", "kind": "ParcelAggregation", - "parcellation": "TianxS1x3TxMNInonlinear2009cAsym", + "parcellation": ["TianxS1x3TxMNInonlinear2009cAsym"], "method": "mean", }, { "name": "tian-s1-3T_std", "kind": "ParcelAggregation", - "parcellation": "TianxS1x3TxMNInonlinear2009cAsym", + "parcellation": ["TianxS1x3TxMNInonlinear2009cAsym"], "method": "std", }, ] From 
97f84c6b9b735805c8aab4546d9e988d37c6cff0 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 09:06:43 +0100 Subject: [PATCH 93/99] update: improve constraints for QueueContextEnv --- junifer/api/queue_context/queue_context_adapter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/junifer/api/queue_context/queue_context_adapter.py b/junifer/api/queue_context/queue_context_adapter.py index 71dcb1a7b..311395e45 100644 --- a/junifer/api/queue_context/queue_context_adapter.py +++ b/junifer/api/queue_context/queue_context_adapter.py @@ -4,6 +4,7 @@ # License: AGPL import sys +from typing import Required if sys.version_info < (3, 12): # pragma: no cover @@ -40,9 +41,9 @@ class EnvShell(str, Enum): class QueueContextEnv(TypedDict, total=False): """Accepted environment configuration for queue context.""" + kind: Required[EnvKind] name: str - kind: EnvKind - shell: EnvShell + shell: Required[EnvShell] class QueueContextAdapter(BaseModel, ABC): From 5d02165ff0187e946a345649ce7fe556a0f84d87 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 09:07:39 +0100 Subject: [PATCH 94/99] update: add guard for env.name in context adapters --- junifer/api/queue_context/gnu_parallel_local_adapter.py | 4 +++- junifer/api/queue_context/htcondor_adapter.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/junifer/api/queue_context/gnu_parallel_local_adapter.py b/junifer/api/queue_context/gnu_parallel_local_adapter.py index 16de556ed..9c828e2c0 100644 --- a/junifer/api/queue_context/gnu_parallel_local_adapter.py +++ b/junifer/api/queue_context/gnu_parallel_local_adapter.py @@ -9,7 +9,7 @@ from typing import Any, Optional from ...typing import Elements -from ...utils import logger, make_executable, run_ext_cmd +from ...utils import logger, make_executable, raise_error, run_ext_cmd from .queue_context_adapter import ( EnvKind, EnvShell, @@ -79,6 +79,8 @@ def model_post_init(self, context: Any): # noqa: D102 self._executable = "junifer" self._arguments = "" else: + if self.env["name"] is None: + raise_error("`env.name` is required") self._executable = f"run_{self.env['kind']}.{self.env['shell']}" self._arguments = f"{self.env['name']} junifer" self._exec_path = self.job_dir / self._executable diff --git a/junifer/api/queue_context/htcondor_adapter.py b/junifer/api/queue_context/htcondor_adapter.py index f895eb3a8..06c368c5c 100644 --- a/junifer/api/queue_context/htcondor_adapter.py +++ b/junifer/api/queue_context/htcondor_adapter.py @@ -10,7 +10,7 @@ from typing import Any, Optional from ...typing import Elements -from ...utils import logger, make_executable, run_ext_cmd +from ...utils import logger, make_executable, raise_error, run_ext_cmd from .queue_context_adapter import ( EnvKind, EnvShell, @@ -112,6 +112,8 @@ def model_post_init(self, context: Any): # noqa: D102 self._executable = "junifer" self._arguments = "" else: + if self.env["name"] is None: + raise_error("`env.name` is required") self._executable = f"run_{self.env['kind']}.{self.env['shell']}" self._arguments = f"{self.env['name']} junifer" self._exec_path = self.job_dir / self._executable From 1b1cf17da635094e2a639407b8777fbca6310fbd Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 09:23:28 +0100 Subject: [PATCH 95/99] docs: update examples --- examples/run_compute_parcel_mean.py | 11 +++++++++-- examples/run_datagrabber_bids_datalad.py | 10 ++++++---- examples/run_ets_rss_marker.py | 4 ++-- examples/run_junifer_julearn.py | 6 +++--- 
examples/run_run_gmd_mean.py | 6 +++--- examples/yamls/gmd_mean.yaml | 6 +++--- examples/yamls/gmd_mean_htcondor.yaml | 2 +- examples/yamls/partly_cloudy_agg_mean_tian.yml | 2 +- examples/yamls/ukb_gmd_mean.yaml | 6 +++--- 9 files changed, 31 insertions(+), 22 deletions(-) diff --git a/examples/run_compute_parcel_mean.py b/examples/run_compute_parcel_mean.py index f2aa546a5..5ad37f39f 100644 --- a/examples/run_compute_parcel_mean.py +++ b/examples/run_compute_parcel_mean.py @@ -15,8 +15,10 @@ OasisVBMTestingDataGrabber, SPMAuditoryTestingDataGrabber, ) +from junifer.datagrabber import DataType from junifer.datareader import DefaultDataReader from junifer.markers import ParcelAggregation +from junifer.stats import AggFunc from junifer.utils import configure_logging @@ -32,7 +34,10 @@ # Read the element element_data = DefaultDataReader().fit_transform(dg[element]) # Initialize marker - marker = ParcelAggregation(parcellation="Schaefer100x7", method="mean") + marker = ParcelAggregation( + parcellation=["Schaefer100x7"], + method=AggFunc.Mean, + ) # Compute feature feature = marker.fit_transform(element_data) # Print the output @@ -48,7 +53,9 @@ element_data = DefaultDataReader().fit_transform(dg[element]) # Initialize marker marker = ParcelAggregation( - parcellation="Schaefer100x7", method="mean", on="BOLD" + parcellation=["Schaefer100x7"], + method=AggFunc.Mean, + on=[DataType.BOLD], ) # Compute feature feature = marker.fit_transform(element_data) diff --git a/examples/run_datagrabber_bids_datalad.py b/examples/run_datagrabber_bids_datalad.py index ce1488636..7f5be5624 100644 --- a/examples/run_datagrabber_bids_datalad.py +++ b/examples/run_datagrabber_bids_datalad.py @@ -10,8 +10,10 @@ License: BSD 3 clause """ -from junifer.datagrabber import PatternDataladDataGrabber +from pathlib import Path +from junifer.datagrabber import DataType, PatternDataladDataGrabber from junifer.utils import configure_logging +from pydantic import HttpUrl ############################################################################### @@ -23,7 +25,7 @@ # The BIDS DataGrabber requires three parameters: the types of data we want, # the specific pattern that matches each type, and the variables that will be # replaced in the patterns. -types = ["T1w", "BOLD"] +types = [DataType.T1w, DataType.BOLD] patterns = { "T1w": { "pattern": "{subject}/anat/{subject}_T1w.nii.gz", @@ -38,8 +40,8 @@ ############################################################################### # Additionally, a datalad-based DataGrabber requires the URI of the remote # sibling and the location of the dataset within the remote sibling. -repo_uri = "https://gin.g-node.org/juaml/datalad-example-bids" -rootdir = "example_bids" +repo_uri = HttpUrl("https://gin.g-node.org/juaml/datalad-example-bids") +rootdir = Path("example_bids") ############################################################################### # Now we can use the DataGrabber within a `with` context. 
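As an aside, the ``with`` usage that this example script goes on to demonstrate
appears unchanged in this diff; a minimal sketch, assuming the ``types``,
``patterns``, ``repo_uri`` and ``rootdir`` variables from the hunk above and
the element iteration/indexing API of the base DataGrabber:

.. code-block:: python

    # Sketch: build the datalad-backed DataGrabber from the variables
    # defined above and fetch files per element; datalad materializes
    # data on access and the context manager handles cleanup.
    with PatternDataladDataGrabber(
        uri=repo_uri,
        rootdir=rootdir,
        types=types,
        patterns=patterns,
        replacements=["subject"],
    ) as dg:
        for element in dg:
            out = dg[element]
            print(out["T1w"]["path"])
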
diff --git a/examples/run_ets_rss_marker.py b/examples/run_ets_rss_marker.py index 53dcb4049..f39e44c14 100644 --- a/examples/run_ets_rss_marker.py +++ b/examples/run_ets_rss_marker.py @@ -36,12 +36,12 @@ { "name": "Schaefer100x17_RSSETS", "kind": "RSSETSMarker", - "parcellation": "Schaefer100x17", + "parcellation": ["Schaefer100x17"], }, { "name": "Schaefer200x17_RSSETS", "kind": "RSSETSMarker", - "parcellation": "Schaefer200x17", + "parcellation": ["Schaefer200x17"], }, ] diff --git a/examples/run_junifer_julearn.py b/examples/run_junifer_julearn.py index ad55c1d1b..56539ad51 100644 --- a/examples/run_junifer_julearn.py +++ b/examples/run_junifer_julearn.py @@ -20,7 +20,7 @@ import junifer.testing.registry # noqa: F401 from junifer.api import collect, run -from junifer.storage.sqlite import SQLiteFeatureStorage +from junifer.storage import SQLiteFeatureStorage from junifer.utils import configure_logging @@ -36,14 +36,14 @@ { "name": "Schaefer100x17_TrimMean80", "kind": "ParcelAggregation", - "parcellation": "Schaefer100x17", + "parcellation": ["Schaefer100x17"], "method": "trim_mean", "method_params": {"proportiontocut": 0.2}, }, { "name": "Schaefer200x17_Mean", "kind": "ParcelAggregation", - "parcellation": "Schaefer200x17", + "parcellation": ["Schaefer200x17"], "method": "mean", }, ] diff --git a/examples/run_run_gmd_mean.py b/examples/run_run_gmd_mean.py index 277ba8f42..cbe953730 100644 --- a/examples/run_run_gmd_mean.py +++ b/examples/run_run_gmd_mean.py @@ -20,20 +20,20 @@ { "name": "Schaefer1000x7_TrimMean80", "kind": "ParcelAggregation", - "parcellation": "Schaefer1000x7", + "parcellation": ["Schaefer1000x7"], "method": "trim_mean", "method_params": {"proportiontocut": 0.2}, }, { "name": "Schaefer1000x7_Mean", "kind": "ParcelAggregation", - "parcellation": "Schaefer1000x7", + "parcellation": ["Schaefer1000x7"], "method": "mean", }, { "name": "Schaefer1000x7_Std", "kind": "ParcelAggregation", - "parcellation": "Schaefer1000x7", + "parcellation": ["Schaefer1000x7"], "method": "std", }, ] diff --git a/examples/yamls/gmd_mean.yaml b/examples/yamls/gmd_mean.yaml index 9b944b407..c6bff0d20 100644 --- a/examples/yamls/gmd_mean.yaml +++ b/examples/yamls/gmd_mean.yaml @@ -7,17 +7,17 @@ elements: markers: - name: Schaefer1000x7_TrimMean80 kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: trim_mean method_params: proportiontocut: 0.2 - name: Schaefer1000x7_Mean kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: mean - name: Schaefer1000x7_Std kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: std storage: kind: SQLiteFeatureStorage diff --git a/examples/yamls/gmd_mean_htcondor.yaml b/examples/yamls/gmd_mean_htcondor.yaml index f29be473c..d12ba61f4 100644 --- a/examples/yamls/gmd_mean_htcondor.yaml +++ b/examples/yamls/gmd_mean_htcondor.yaml @@ -6,7 +6,7 @@ datagrabber: markers: - name: Schaefer1000x7_Mean kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: mean storage: kind: SQLiteFeatureStorage diff --git a/examples/yamls/partly_cloudy_agg_mean_tian.yml b/examples/yamls/partly_cloudy_agg_mean_tian.yml index da4ee0448..be38c03e3 100644 --- a/examples/yamls/partly_cloudy_agg_mean_tian.yml +++ b/examples/yamls/partly_cloudy_agg_mean_tian.yml @@ -7,7 +7,7 @@ datagrabber: markers: - kind: ParcelAggregation - parcellation: TianxS1x3TxMNInonlinear2009cAsym + parcellation: [TianxS1x3TxMNInonlinear2009cAsym] method: mean 
name: tian-s1-3T_mean diff --git a/examples/yamls/ukb_gmd_mean.yaml b/examples/yamls/ukb_gmd_mean.yaml index 921410a1a..f715452df 100644 --- a/examples/yamls/ukb_gmd_mean.yaml +++ b/examples/yamls/ukb_gmd_mean.yaml @@ -7,17 +7,17 @@ elements: markers: - name: Schaefer1000x7_TrimMean80 kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: trim_mean method_params: proportiontocut: 0.2 - name: Schaefer1000x7_Mean kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: mean - name: Schaefer1000x7_Std kind: ParcelAggregation - parcellation: Schaefer1000x7 + parcellation: [Schaefer1000x7] method: std storage: kind: SQLiteFeatureStorage From 8af38994b4eedf768c4e7820853e793e8a1f48ed Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 10:28:23 +0100 Subject: [PATCH 96/99] docs: update extending and using docs --- docs/extending/marker.rst | 14 ++++++++------ docs/extending/masks.rst | 2 +- docs/extending/parcellations.rst | 2 +- docs/using/codeless.rst | 8 ++++---- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/docs/extending/marker.rst b/docs/extending/marker.rst index d65d150f4..303b32868 100644 --- a/docs/extending/marker.rst +++ b/docs/extending/marker.rst @@ -119,7 +119,7 @@ and the values would be a dictionary of storage type specific key-value pairs. from typing import Any - from junifer.data import get_parcellation + from junifer.data import get_data from nilearn.maskers import NiftiLabelsMasker @@ -132,8 +132,9 @@ and the values would be a dictionary of storage type specific key-value pairs. data = input["data"] # Get the parcellation tailored for the target - t_parcellation, t_labels, _ = get_parcellation( - name=self.parcellation, + t_parcellation, t_labels, _ = get_data( + kind="parcellation", + name=[self.parcellation], target_data=input, extra_input=extra_input, ) @@ -185,7 +186,7 @@ Finally, we need to register the Marker using the ``@register_marker`` decorator from typing import Any, ClassVar from junifer.api.decorators import register_marker - from junifer.data import get_parcellation + from junifer.data import get_data from junifer.datagrabber import DataType from junifer.markers import BaseMarker from junifer.storage import StorageType @@ -221,8 +222,9 @@ Finally, we need to register the Marker using the ``@register_marker`` decorator data = input["data"] # Get the parcellation tailored for the target - t_parcellation, t_labels, _ = get_parcellation( - name=self.parcellation, + t_parcellation, t_labels, _ = get_data( + kind="parcellation", + name=[self.parcellation], target_data=input, extra_input=extra_input, ) diff --git a/docs/extending/masks.rst b/docs/extending/masks.rst index 05763f9a1..df54be372 100644 --- a/docs/extending/masks.rst +++ b/docs/extending/masks.rst @@ -86,7 +86,7 @@ mask as an argument. For example: markers: - name: CustomMaskParcelAggregation_mean kind: ParcelAggregation - parcellation: Schaefer200x17 + parcellation: [Schaefer200x17] method: mean masks: "my_custom_mask" diff --git a/docs/extending/parcellations.rst b/docs/extending/parcellations.rst index f49b300b4..a6d76a47d 100644 --- a/docs/extending/parcellations.rst +++ b/docs/extending/parcellations.rst @@ -124,7 +124,7 @@ parcellation when registering it. For example, we can add a markers: - name: CustomParcellation_mean kind: ParcelAggregation - parcellation: my_custom_parcellation + parcellation: [] method: mean Now, you can simply use this YAML file to run your pipeline. 
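For symmetry with the YAML snippets above, the same marker specification can be
mirrored in Python; a rough sketch (the ``masks`` argument and the registered
name ``my_custom_mask`` are taken from the masks example above, and whether a
bare string is still accepted for ``masks`` under the pydantic models is
assumed, not verified):

.. code-block:: python

    from junifer.markers import ParcelAggregation

    # Sketch: Python counterpart of the YAML marker spec, using the
    # list-valued ``parcellation`` field introduced in this series.
    marker = ParcelAggregation(
        parcellation=["Schaefer200x17"],
        method="mean",
        masks="my_custom_mask",
        name="CustomMaskParcelAggregation_mean",
    )
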
diff --git a/docs/using/codeless.rst b/docs/using/codeless.rst index 0a075c4fe..8dbe209f8 100644 --- a/docs/using/codeless.rst +++ b/docs/using/codeless.rst @@ -171,11 +171,11 @@ For the ``Oasis VBM Testing dataset`` example, we want to compute the mean markers: - name: Schaefer100x7_mean kind: ParcelAggregation - parcellation: Schaefer100x7 + parcellation: [Schaefer100x7] method: mean - name: Schaefer200x7_mean kind: ParcelAggregation - parcellation: Schaefer200x7 + parcellation: [Schaefer200x7] method: mean - name: DMNBuckner_5mm_mean kind: SphereAggregation @@ -219,11 +219,11 @@ looks like: markers: - name: Schaefer100x7_mean kind: ParcelAggregation - parcellation: Schaefer100x7 + parcellation: [Schaefer100x7] method: mean - name: Schaefer200x7_mean kind: ParcelAggregation - parcellation: Schaefer200x7 + parcellation: [Schaefer200x7] method: mean - name: DMNBuckner_5mm_mean kind: SphereAggregation From d9061c3157dab4bcdb22cd64d0eeae3f3b09ea2b Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 10:37:19 +0100 Subject: [PATCH 97/99] chore: add changelog 364.enh --- docs/changes/newsfragments/364.enh | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/changes/newsfragments/364.enh diff --git a/docs/changes/newsfragments/364.enh b/docs/changes/newsfragments/364.enh new file mode 100644 index 000000000..be3654b9a --- /dev/null +++ b/docs/changes/newsfragments/364.enh @@ -0,0 +1 @@ +Adopt Pydantic for user-facing core objects to perform automatic validation using typing annotations by `Synchon Mandal`_ From 5590869b50abe15f171445c829424d55203c5326 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 14:06:25 +0100 Subject: [PATCH 98/99] fix: conditional import for Required in queue_context_adapter.py --- junifer/api/queue_context/queue_context_adapter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/junifer/api/queue_context/queue_context_adapter.py b/junifer/api/queue_context/queue_context_adapter.py index 311395e45..a5bc48c5a 100644 --- a/junifer/api/queue_context/queue_context_adapter.py +++ b/junifer/api/queue_context/queue_context_adapter.py @@ -4,7 +4,6 @@ # License: AGPL import sys -from typing import Required if sys.version_info < (3, 12): # pragma: no cover @@ -12,6 +11,11 @@ else: from typing import TypedDict +if sys.version_info < (3, 11): # pragma: no cover + from typing_extensions import Required +else: + from typing import Required + from abc import ABC, abstractmethod from enum import Enum From df9cc955954466f9d0de15f4ff4a8cd48a87e671 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Tue, 18 Nov 2025 17:19:37 +0100 Subject: [PATCH 99/99] fix: conditional import for TypedDict --- junifer/datagrabber/pattern_validation_mixin.py | 10 +++++++++- .../preprocess/confounds/fmriprep_confound_remover.py | 1 - junifer/typing/_typing.py | 9 ++++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/junifer/datagrabber/pattern_validation_mixin.py b/junifer/datagrabber/pattern_validation_mixin.py index 4a108d946..eadd0e675 100644 --- a/junifer/datagrabber/pattern_validation_mixin.py +++ b/junifer/datagrabber/pattern_validation_mixin.py @@ -3,8 +3,16 @@ # Authors: Synchon Mandal # License: AGPL +import sys + + +if sys.version_info < (3, 12): # pragma: no cover + from typing_extensions import TypedDict +else: + from typing import TypedDict + + from collections.abc import Iterator, MutableMapping -from typing import TypedDict from aenum import extend_enum diff --git 
a/junifer/preprocess/confounds/fmriprep_confound_remover.py b/junifer/preprocess/confounds/fmriprep_confound_remover.py index 8ae721dfc..5eae6e41b 100644 --- a/junifer/preprocess/confounds/fmriprep_confound_remover.py +++ b/junifer/preprocess/confounds/fmriprep_confound_remover.py @@ -19,7 +19,6 @@ Any, ClassVar, Optional, - TypedDict, Union, ) diff --git a/junifer/typing/_typing.py b/junifer/typing/_typing.py index a53b3181a..eb89727fa 100644 --- a/junifer/typing/_typing.py +++ b/junifer/typing/_typing.py @@ -3,10 +3,17 @@ # Authors: Synchon Mandal # License: AGPL +import sys + + +if sys.version_info < (3, 12): # pragma: no cover + from typing_extensions import TypedDict +else: + from typing import TypedDict + from collections.abc import Sequence from typing import ( TYPE_CHECKING, - TypedDict, Union, )
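Taken together, the last two patches settle on one version-gating pattern:
``typing_extensions`` supplies ``TypedDict`` below Python 3.12 (pydantic can
only validate the ``typing_extensions`` variant on older interpreters, which
is also why patch 88 pins it for ``python_version<'3.12'``) and ``Required``
below 3.11, where it is not yet in ``typing``. A condensed sketch of the
pattern, with a hypothetical ``EnvConfig`` as a stand-in:

.. code-block:: python

    import sys

    # pydantic cannot validate the stdlib TypedDict before Python 3.12,
    # so fall back to typing_extensions there.
    if sys.version_info < (3, 12):  # pragma: no cover
        from typing_extensions import TypedDict
    else:
        from typing import TypedDict

    # Required/NotRequired only entered the stdlib in Python 3.11.
    if sys.version_info < (3, 11):  # pragma: no cover
        from typing_extensions import Required
    else:
        from typing import Required


    class EnvConfig(TypedDict, total=False):
        """Hypothetical config: only ``kind`` must always be present."""

        kind: Required[str]
        name: str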