9 | 9 | from abc import ABCMeta, abstractmethod
10 | 10 | from collections import defaultdict
11 | 11 |
| 12 | +import nibabel as nb
12 | 13 | import numpy as np
| 14 | +from nilearn._utils.niimg_conversions import _check_same_fov
| 15 | +from nilearn.image import concat_imgs, resample_to_img
13 | 16 |
14 | 17 | from .results import MetaResult
15 | 18 | from .utils import get_masker
@@ -283,14 +286,44 @@ def __init__(self, *args, **kwargs):
283 | 286 |         mask = get_masker(mask)
284 | 287 |         self.masker = mask
285 | 288 |
| 289 | +        self.resample = kwargs.get("resample", False)
| 290 | +
| 291 | +        # defaults for resampling images (nilearn's defaults do not work well)
| 292 | +        self._resample_kwargs = {"clip": True, "interpolation": "linear"}
| 293 | +        self._resample_kwargs.update(
| 294 | +            {k.split("resample__")[1]: v for k, v in kwargs.items() if k.startswith("resample__")}
| 295 | +        )
| 296 | +
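The `resample__` prefix acts as a small keyword namespace: any `resample__<option>` argument is stripped of its prefix and overrides the resampling defaults set above. A minimal standalone sketch of that parsing (the keyword values below are illustrative, not taken from NiMARE):

```python
# Illustrative only: mirrors the kwargs-splitting logic added in __init__ above.
defaults = {"clip": True, "interpolation": "linear"}
kwargs = {"resample": True, "resample__interpolation": "nearest", "low_memory": False}

resample_kwargs = dict(defaults)
resample_kwargs.update(
    {k.split("resample__")[1]: v for k, v in kwargs.items() if k.startswith("resample__")}
)
print(resample_kwargs)  # {'clip': True, 'interpolation': 'nearest'}
```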
286 | 297 |     def _preprocess_input(self, dataset):
287 | 298 |         """Preprocess inputs to the Estimator from the Dataset as needed."""
288 | 299 |         masker = self.masker or dataset.masker
| 300 | +
| 301 | +        mask_img = masker.mask_img or masker.labels_img
| 302 | +        if isinstance(mask_img, str):
| 303 | +            mask_img = nb.load(mask_img)
| 304 | +
289 | 305 |         for name, (type_, _) in self._required_inputs.items():
290 | 306 |             if type_ == "image":
| 307 | +                # If resampling is not requested, check that all images match the mask's FOV
| 308 | +                if not self.resample:
| 309 | +                    check_imgs = {img: nb.load(img) for img in self.inputs_[name]}
| 310 | +                    _check_same_fov(**check_imgs, reference_masker=mask_img, raise_error=True)
| 311 | +                    imgs = list(check_imgs.values())
| 312 | +                else:
| 313 | +                    # resampling will only occur if shape/affines are different,
| 314 | +                    # making this harmless if all img shapes/affines are the same
| 315 | +                    # as the reference
| 316 | +                    imgs = [
| 317 | +                        resample_to_img(nb.load(img), mask_img, **self._resample_kwargs)
| 318 | +                        for img in self.inputs_[name]
| 319 | +                    ]
| 320 | +
| 321 | +                # input to NiftiLabelsMasker must be 4d
| 322 | +                img4d = concat_imgs(imgs, ensure_ndim=4)
| 323 | +
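Outside the diff, the same resample-then-stack step can be sketched with nilearn's public API; the file names below are placeholders, and the keyword values are the defaults set in `__init__`:

```python
# Placeholder file names; in the estimator these come from the masker and
# self.inputs_[name]. This mirrors the branch above where self.resample is True.
import nibabel as nb
from nilearn.image import concat_imgs, resample_to_img

mask_img = nb.load("mask.nii.gz")
files = ["con_01.nii.gz", "con_02.nii.gz"]
imgs = [
    resample_to_img(nb.load(f), mask_img, clip=True, interpolation="linear")
    for f in files
]
img4d = concat_imgs(imgs, ensure_ndim=4)  # maskers expect a single 4D image
```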
291 | 324 |                 # Mask required input images using either the dataset's mask or
292 | 325 |                 # the estimator's.
293 | | -                temp_arr = masker.transform(self.inputs_[name])
| 326 | +                temp_arr = masker.transform(img4d)
294 | 327 |
295 | 328 |                 # An intermediate step to mask out bad voxels. Can be dropped
296 | 329 |                 # once PyMARE is able to handle masked arrays or missing data.
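For reference, `masker.transform` on the concatenated 4D image yields a 2D array of shape (n_images, n_voxels). A rough, hypothetical sketch of that step and one possible bad-voxel screen (not NiMARE's actual criterion):

```python
# Hypothetical sketch: mask a 4D image and drop all-zero voxels.
import numpy as np
from nilearn.input_data import NiftiMasker  # nilearn.maskers in newer releases

masker = NiftiMasker(mask_img="mask.nii.gz").fit()  # placeholder mask file
temp_arr = masker.transform("images_4d.nii.gz")     # shape: (n_images, n_voxels)
good_voxels = np.all(temp_arr != 0, axis=0)         # example criterion only
temp_arr = temp_arr[:, good_voxels]
```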