Skip to content

Commit 32577f7

Browse files
authored
Make torch optional (#836)
* Make `torch` optional * Update test_meta_cbmr.py * try import torch for cbmr models as well * Update test_meta_cbmr.py * Update test_meta_cbmr.py * Skip model function * Update test_meta_cbmr.py * Update test_meta_cbmr.py
1 parent 63f0a7e commit 32577f7

File tree

8 files changed

+129
-24
lines changed

8 files changed

+129
-24
lines changed

.github/workflows/testing.yml

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
python-version: ${{ matrix.python-version }}
5252
- name: 'Install NiMARE'
5353
shell: bash {0}
54-
run: pip install -e .[tests]
54+
run: pip install -e .[tests,cbmr]
5555
- name: 'Run tests'
5656
shell: bash {0}
5757
run: make unittest
@@ -83,7 +83,7 @@ jobs:
8383
python-version: 3.8
8484
- name: 'Install NiMARE'
8585
shell: bash {0}
86-
run: pip install -e .[minimum,tests]
86+
run: pip install -e .[minimum,tests,cbmr]
8787
- name: 'Run tests'
8888
shell: bash {0}
8989
run: make unittest
@@ -115,7 +115,7 @@ jobs:
115115
python-version: ${{ matrix.python-version }}
116116
- name: 'Install NiMARE'
117117
shell: bash {0}
118-
run: pip install -e .[tests]
118+
run: pip install -e .[tests,cbmr]
119119
- name: 'Run tests'
120120
shell: bash {0}
121121
run: make test_performance_estimators
@@ -147,7 +147,7 @@ jobs:
147147
python-version: ${{ matrix.python-version }}
148148
- name: 'Install NiMARE'
149149
shell: bash {0}
150-
run: pip install -e .[tests]
150+
run: pip install -e .[tests,cbmr]
151151
- name: 'Run tests'
152152
shell: bash {0}
153153
run: make test_performance_correctors
@@ -179,7 +179,7 @@ jobs:
179179
python-version: ${{ matrix.python-version }}
180180
- name: 'Install NiMARE'
181181
shell: bash {0}
182-
run: pip install -e .[tests]
182+
run: pip install -e .[tests,cbmr]
183183
- name: 'Run tests'
184184
shell: bash {0}
185185
run: make test_performance_smoke
@@ -190,9 +190,41 @@ jobs:
190190
path: coverage.xml
191191
if: success()
192192

193+
test_cbmr_importerror:
194+
name: CBMR ImportError tests
195+
needs: check_skip
196+
if: ${{ needs.check_skip.outputs.skip == 'false' }}
197+
runs-on: ${{ matrix.os }}
198+
strategy:
199+
fail-fast: false
200+
matrix:
201+
os: ["ubuntu-latest"]
202+
python-version: ["3.8"]
203+
defaults:
204+
run:
205+
shell: bash
206+
steps:
207+
- uses: actions/checkout@v2
208+
- name: 'Set up python'
209+
uses: actions/setup-python@v2
210+
with:
211+
python-version: ${{ matrix.python-version }}
212+
- name: 'Install NiMARE'
213+
shell: bash {0}
214+
run: pip install -e .[tests]
215+
- name: 'Run tests'
216+
shell: bash {0}
217+
run: make test_cbmr_importerror
218+
- name: Upload artifacts
219+
uses: actions/upload-artifact@v2
220+
with:
221+
name: cbmr_importerror
222+
path: coverage.xml
223+
if: success()
224+
193225
upload_to_codecov:
194226
name: Upload coverage
195-
needs: [run_unit_tests,run_unit_tests_with_minimum_dependencies,test_performance_estimators,test_performance_correctors,test_performance_smoke]
227+
needs: [run_unit_tests,run_unit_tests_with_minimum_dependencies,test_performance_estimators,test_performance_correctors,test_performance_smoke,test_cbmr_importerror]
196228
runs-on: "ubuntu-latest"
197229
steps:
198230
- name: Checkout

.readthedocs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ python:
2020
path: .
2121
extra_requirements:
2222
- doc
23+
- cbmr

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ help:
99
@echo " test_performance_estimators to run performance tests on meta estimators"
1010
@echo " test_performance_correctors to run performance tests on correctors"
1111
@echo " test_performance_smoke to run performance smoke tests"
12+
@echo " test_cbmr_importerror to run cbmr importerror tests"
1213

1314
lint:
1415
@flake8 nimare
1516

1617
unittest:
17-
@py.test -m "not performance_estimators and not performance_correctors and not performance_smoke" --cov-append --cov-report=xml --cov=nimare nimare
18+
@py.test -m "not performance_estimators and not performance_correctors and not performance_smoke and not cbmr_importerror" --cov-append --cov-report=xml --cov=nimare nimare
1819

1920
test_performance_estimators:
2021
@py.test -m "performance_estimators" --cov-append --cov-report=xml --cov=nimare nimare
@@ -24,3 +25,6 @@ test_performance_correctors:
2425

2526
test_performance_smoke:
2627
@py.test -m "performance_smoke" --cov-append --cov-report=xml --cov=nimare nimare
28+
29+
test_cbmr_importerror:
30+
@py.test -m "cbmr_importerror" --cov-append --cov-report=xml --cov=nimare nimare

nimare/meta/cbmr.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Coordinate Based Meta Regression Methods."""
2+
23
import logging
34
import re
45
from functools import wraps
@@ -7,7 +8,13 @@
78
import numpy as np
89
import pandas as pd
910
import scipy
10-
import torch
11+
12+
try:
13+
import torch
14+
except ImportError as e:
15+
raise ImportError(
16+
"Torch is required to use `CBMR` classes. Install with `pip install 'nimare[cbmr]'`."
17+
) from e
1118

1219
from nimare import _version
1320
from nimare.diagnostics import FocusFilter
@@ -22,6 +29,8 @@
2229
class CBMREstimator(Estimator):
2330
"""Coordinate-based meta-regression with a spatial model.
2431
32+
.. versionadded:: 0.1.0
33+
2534
Parameters
2635
----------
2736
group_categories : :obj:`~str` or obj:`~list` or obj:`~None`, optional
@@ -408,6 +417,8 @@ def _fit(self, dataset):
408417
class CBMRInference(object):
409418
"""Statistical inference on outcomes of CBMR.
410419
420+
.. versionadded:: 0.1.0
421+
411422
(intensity estimation and study-level moderator regressors)
412423
413424
Parameters

nimare/meta/models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
import numpy as np
66
import pandas as pd
7-
import torch
7+
8+
try:
9+
import torch
10+
except ImportError as e:
11+
raise ImportError(
12+
"Torch is required to use `CBMR` models. Install with `pip install 'nimare[cbmr]'`."
13+
) from e
814

915
LGR = logging.getLogger(__name__)
1016

nimare/tests/test_meta_cbmr.py

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
"""Tests for CBMR meta-analytic methods."""
22
import logging
3+
import warnings
34

45
import pytest
5-
import torch
6+
7+
try:
8+
import torch
9+
except ImportError:
10+
warnings.warn("Torch not installed. CBMR tests will be skipped.")
11+
TORCH_INSTALLED = False
12+
else:
13+
TORCH_INSTALLED = True
14+
from nimare.meta import models
15+
from nimare.meta.cbmr import CBMREstimator, CBMRInference
616

717
import nimare
818
from nimare.correct import FDRCorrector, FWECorrector
9-
from nimare.meta import models
10-
from nimare.meta.cbmr import CBMREstimator, CBMRInference
1119
from nimare.transforms import StandardizeField
1220

1321
# numba has a lot of debug messages that are not useful for testing
@@ -16,17 +24,24 @@
1624
logging.getLogger("indexed_gzip").setLevel(logging.WARNING)
1725

1826

19-
@pytest.fixture(
20-
scope="session",
21-
params=[
22-
pytest.param(models.PoissonEstimator, id="Poisson"),
23-
pytest.param(models.NegativeBinomialEstimator, id="NegativeBinomial"),
24-
pytest.param(models.ClusteredNegativeBinomialEstimator, id="ClusteredNegativeBinomial"),
25-
],
26-
)
27-
def model(request):
28-
"""CBMR models."""
29-
return request.param
27+
if TORCH_INSTALLED:
28+
29+
@pytest.fixture(
30+
scope="session",
31+
params=[
32+
pytest.param(models.PoissonEstimator, id="Poisson"),
33+
pytest.param(models.NegativeBinomialEstimator, id="NegativeBinomial"),
34+
pytest.param(
35+
models.ClusteredNegativeBinomialEstimator, id="ClusteredNegativeBinomial"
36+
),
37+
],
38+
)
39+
def model(request):
40+
"""CBMR models."""
41+
return request.param
42+
43+
else:
44+
model = None
3045

3146

3247
@pytest.fixture(scope="session")
@@ -243,3 +258,37 @@ def test_StandardizeField(testdata_cbmr_simulated):
243258
assert dset.annotations["standardized_sample_sizes"].std() == pytest.approx(1.0, abs=1e-3)
244259
assert dset.annotations["standardized_avg_age"].mean() == pytest.approx(0.0, abs=1e-3)
245260
assert dset.annotations["standardized_avg_age"].std() == pytest.approx(1.0, abs=1e-3)
261+
262+
263+
@pytest.mark.cbmr_importerror
264+
def test_cbmr_importerror():
265+
"""Test that ImportErrors are raised when torch is not installed."""
266+
with pytest.raises(ImportError):
267+
from nimare.meta.cbmr import CBMREstimator
268+
269+
CBMREstimator()
270+
271+
with pytest.raises(ImportError):
272+
from nimare.meta.cbmr import CBMRInference
273+
274+
CBMRInference()
275+
276+
with pytest.raises(ImportError):
277+
from nimare.meta.models import GeneralLinearModelEstimator
278+
279+
GeneralLinearModelEstimator()
280+
281+
with pytest.raises(ImportError):
282+
from nimare.meta.models import PoissonEstimator
283+
284+
PoissonEstimator()
285+
286+
with pytest.raises(ImportError):
287+
from nimare.meta.models import NegativeBinomialEstimator
288+
289+
NegativeBinomialEstimator()
290+
291+
with pytest.raises(ImportError):
292+
from nimare.meta.models import ClusteredNegativeBinomialEstimator
293+
294+
ClusteredNegativeBinomialEstimator()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ markers = [
2828
"performance_smoke: mark smoke tests that measure performance",
2929
"performance_estimators: mark tests that measure estimator performance",
3030
"performance_correctors: mark tests that measure corrector performance",
31+
"cbmr_importerror: mark tests that should fail due to missing torch dependencies",
3132
]
3233

3334
[tool.isort]

setup.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,15 @@ install_requires =
5858
seaborn # nimare.reports
5959
sparse>=0.13.0 # for kernel transformers
6060
statsmodels!=0.13.2 # this version doesn't install properly
61-
torch>=2.0 # for cbmr models
6261
tqdm # progress bars throughout package
6362
packages = find:
6463
include_package_data = False
6564

6665
[options.extras_require]
6766
gzip =
6867
indexed_gzip>=1.4.0 # working with gzipped niftis
68+
cbmr =
69+
torch>=2.0 # for cbmr models
6970
doc =
7071
m2r
7172
matplotlib

0 commit comments

Comments
 (0)