Skip to content

Commit b69cedd

Browse files
authored
Merge pull request #439 from jdkent/fix/resampling
[FIX,REF] start changing how to handle resampling
2 parents c221825 + a3bf504 commit b69cedd

File tree

3 files changed

+99
-1
lines changed

3 files changed

+99
-1
lines changed

nimare/base.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from abc import ABCMeta, abstractmethod
1010
from collections import defaultdict
1111

12+
import nibabel as nb
1213
import numpy as np
14+
from nilearn._utils.niimg_conversions import _check_same_fov
15+
from nilearn.image import concat_imgs, resample_to_img
1316

1417
from .results import MetaResult
1518
from .utils import get_masker
@@ -283,14 +286,44 @@ def __init__(self, *args, **kwargs):
283286
mask = get_masker(mask)
284287
self.masker = mask
285288

289+
self.resample = kwargs.get("resample", False)
290+
291+
# defaults for resampling images (nilearn's defaults do not work well)
292+
self._resample_kwargs = {"clip": True, "interpolation": "linear"}
293+
self._resample_kwargs.update(
294+
{k.split("resample__")[1]: v for k, v in kwargs.items() if k.startswith("resample__")}
295+
)
296+
286297
def _preprocess_input(self, dataset):
287298
"""Preprocess inputs to the Estimator from the Dataset as needed."""
288299
masker = self.masker or dataset.masker
300+
301+
mask_img = masker.mask_img or masker.labels_img
302+
if isinstance(mask_img, str):
303+
mask_img = nb.load(mask_img)
304+
289305
for name, (type_, _) in self._required_inputs.items():
290306
if type_ == "image":
307+
# If no resampling is requested, check if resampling is required
308+
if not self.resample:
309+
check_imgs = {img: nb.load(img) for img in self.inputs_[name]}
310+
_check_same_fov(**check_imgs, reference_masker=mask_img, raise_error=True)
311+
imgs = list(check_imgs.values())
312+
else:
313+
# resampling will only occur if shape/affines are different
314+
# making this harmless if all img shapes/affines are the same
315+
# as the reference
316+
imgs = [
317+
resample_to_img(nb.load(img), mask_img, **self._resample_kwargs)
318+
for img in self.inputs_[name]
319+
]
320+
321+
# input to NiFtiLabelsMasker must be 4d
322+
img4d = concat_imgs(imgs, ensure_ndim=4)
323+
291324
# Mask required input images using either the dataset's mask or
292325
# the estimator's.
293-
temp_arr = masker.transform(self.inputs_[name])
326+
temp_arr = masker.transform(img4d)
294327

295328
# An intermediate step to mask out bad voxels. Can be dropped
296329
# once PyMARE is able to handle masked arrays or missing data.

nimare/tests/conftest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import pytest
88
import nibabel as nib
9+
import numpy as np
10+
from nilearn.image import resample_img
911

1012
import nimare
1113
from nimare.tests.utils import get_test_data_path
@@ -79,3 +81,43 @@ def mni_mask():
7981
return nib.load(
8082
os.path.join(get_resource_path(), "templates", "MNI152_2x2x2_brainmask.nii.gz")
8183
)
84+
85+
86+
@pytest.fixture(scope="session")
87+
def testdata_ibma_resample(tmp_path_factory):
88+
tmpdir = tmp_path_factory.mktemp("testdata_ibma_resample")
89+
90+
# Load dataset
91+
dset_file = os.path.join(get_test_data_path(), "test_pain_dataset.json")
92+
dset_dir = os.path.join(get_test_data_path(), "test_pain_dataset")
93+
mask_file = os.path.join(dset_dir, "mask.nii.gz")
94+
dset = nimare.dataset.Dataset(dset_file, mask=mask_file)
95+
dset.update_path(dset_dir)
96+
97+
# create reproducible random number generator for resampling
98+
rng = np.random.default_rng(seed=123)
99+
# Move image contents of Dataset to temporary directory
100+
for c in dset.images.columns:
101+
if c.endswith("__relative"):
102+
continue
103+
for f in dset.images[c].values:
104+
if (f is None) or not os.path.isfile(f):
105+
continue
106+
new_f = f.replace(
107+
dset_dir.rstrip(os.path.sep), str(tmpdir.absolute()).rstrip(os.path.sep)
108+
)
109+
dirname = os.path.dirname(new_f)
110+
if not os.path.isdir(dirname):
111+
os.makedirs(dirname)
112+
# create random affine to make images different shapes
113+
affine = np.eye(3)
114+
np.fill_diagonal(affine, rng.choice([1, 2, 3]))
115+
img = resample_img(
116+
nib.load(f),
117+
target_affine=affine,
118+
interpolation="linear",
119+
clip=True,
120+
)
121+
nib.save(img, new_f)
122+
dset.update_path(tmpdir)
123+
return dset

nimare/tests/test_meta_ibma.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
Test nimare.meta.ibma (image-based meta-analytic estimators).
33
"""
44
import os.path as op
5+
from contextlib import ExitStack as does_not_raise
56

67
from nilearn.input_data import NiftiLabelsMasker
8+
import pytest
79

810
import nimare
911
from nimare.correct import FDRCorrector, FWECorrector
@@ -135,3 +137,24 @@ def test_ibma_with_custom_masker(testdata_ibma):
135137
meta.fit(testdata_ibma)
136138
assert isinstance(meta.results, nimare.results.MetaResult)
137139
assert meta.results.maps["z"].shape == (5,)
140+
141+
142+
@pytest.mark.parametrize(
143+
"resample,resample_kwargs,expectation",
144+
[
145+
(False, {}, pytest.raises(ValueError)),
146+
(None, {}, pytest.raises(ValueError)),
147+
(True, {}, does_not_raise()),
148+
(
149+
True,
150+
{"resample__clip": False, "resample__interpolation": "continuous"},
151+
does_not_raise(),
152+
),
153+
],
154+
)
155+
def test_ibma_resampling(testdata_ibma_resample, resample, resample_kwargs, expectation):
156+
meta = ibma.Fishers(resample=resample, **resample_kwargs)
157+
with expectation:
158+
meta.fit(testdata_ibma_resample)
159+
if isinstance(expectation, does_not_raise):
160+
assert isinstance(meta.results, nimare.results.MetaResult)

0 commit comments

Comments
 (0)