Skip to content

Commit 585ae2d

Browse files
committed
Improvements acquisitions with MPI.
Proxy the detector MPI communicator to the instrument instance. When indexing the detectors (or the samplings) of an acquisition, change the communicator of the detectors (or the samplings) to MPI.COMM_SELF. Fix clearing the cached operator when indexing the acquisition. Add MPI-specific tests and execute them with mpirun in the CI.
1 parent 7d8cca6 commit 585ae2d

File tree

7 files changed

+190
-35
lines changed

7 files changed

+190
-35
lines changed

.github/workflows/build-test-publish.yml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,19 @@ jobs:
8888
CIBW_SKIP: "*-musllinux_*"
8989
CIBW_ARCHS: ${{ matrix.platform.arch }}
9090
CIBW_BEFORE_TEST_LINUX: |
91-
yum install -y openmpi-devel
92-
MPICC=/lib64/openmpi/bin/mpicc pip install mpi4py
91+
yum install -y openmpi-devel environment-modules
92+
source /usr/share/Modules/init/sh
93+
module load mpi
94+
pip install mpi4py
9395
CIBW_BEFORE_TEST_MACOS: brew install openmpi
9496
CIBW_TEST_EXTRAS: dev
95-
CIBW_TEST_COMMAND: pytest {package}/tests
97+
# pytest {package}/tests &&
98+
CIBW_TEST_COMMAND: |
99+
if [[ ${{ matrix.platform.os }} == ubuntu-20.04 ]]; then
100+
source /usr/share/Modules/init/sh
101+
module load mpi
102+
fi
103+
mpirun -np 6 --oversubscribe --allow-run-as-root pytest -m mpi --no-cov {package}/tests
96104
PYTHONFAULTHANDLER: "1"
97105

98106
- name: Build macosx_arm64

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ repos:
1313
- --all
1414

1515
- repo: https://github.com/PyCQA/isort
16-
rev: '5.10.1'
16+
rev: '5.12.0'
1717
hooks:
1818
- id: isort
1919
args:

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ addopts = "-ra --cov=pysimulators"
8282
testpaths = [
8383
"tests",
8484
]
85+
markers = [
86+
"mpi: mark tests to be run using mpirun.",
87+
]
8588

8689
[tool.setuptools_scm]
8790
version_scheme = "post-release"

src/pysimulators/acquisitions.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,23 +144,43 @@ def __getitem__(self, x):
144144
Restrict to the first 10 pixels of the scene:
145145
>>> new_acq = acq[..., :10]
146146
"""
147+
148+
def is_colon(x):
149+
return isinstance(x, slice) and x == slice(None)
150+
147151
out = copy(self)
148152
if not isinstance(x, tuple):
149-
out.instrument = self.instrument[x]
150-
return out
151-
if len(x) == 2 and x[0] is Ellipsis:
152-
x = Ellipsis, Ellipsis, x[1]
153-
if len(x) > 3:
153+
x = (x,)
154+
elif len(x) == 2 and x[0] is Ellipsis:
155+
x = slice(None), slice(None), x[1]
156+
elif len(x) > 3:
154157
raise ValueError('Invalid selection.')
155-
x = x + (3 - len(x)) * (Ellipsis,)
156-
if x[2] is not Ellipsis and (
157-
not isinstance(x[2], slice) or x[2] == slice(None)
158-
):
158+
159+
x = tuple(slice(None) if _ is Ellipsis else _ for _ in x)
160+
x = x + (3 - len(x)) * (slice(None),)
161+
162+
if all(is_colon(_) for _ in x):
163+
return out
164+
165+
if any(not is_colon(_) for _ in x):
159166
self._operator = None
160167
gc.collect()
168+
161169
out.instrument = self.instrument[x[0]]
170+
if not is_colon(x[0]):
171+
object.__setattr__(out.instrument.detector, 'comm', MPI.COMM_SELF)
162172
out.sampling = self.sampling[x[1]] # XXX FIX BLOCKS!!!
173+
if not is_colon(x[1]):
174+
object.__setattr__(out.sampling, 'comm', MPI.COMM_SELF)
163175
out.scene = self.scene[x[2]]
176+
177+
if not is_colon(x[0]) and not is_colon(x[1]):
178+
out.comm = MPI.COMM_SELF
179+
elif not is_colon(x[0]):
180+
out.comm = self.sampling.comm.Create_cart([1, self.sampling.comm.size])
181+
elif not is_colon(x[1]):
182+
out.comm = self.instrument.comm.Create_cart([self.instrument.comm.size, 1])
183+
164184
return out
165185

166186
def __str__(self):

src/pysimulators/instruments.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def __iter__(self):
5757
def __len__(self):
5858
return len(self.detector)
5959

60+
@property
61+
def comm(self):
62+
return self.detector.comm
63+
6064
def pack(self, x):
6165
return self.detector.pack(x)
6266

tests/test_acquisitions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_get_noise1(shape, fknee):
3939
fsamp = 5
4040
sigma = 0.3
4141
scene = Scene(10)
42-
sampling = Sampling(2e4, period=1 / fsamp)
42+
sampling = Sampling(20_000, period=1 / fsamp)
4343
np.random.seed(0)
4444

4545
class MyAcquisition1(Acquisition):
@@ -58,7 +58,7 @@ def test_get_noise2(shape):
5858
fsamp = 5
5959
sigma = 0.3
6060
scene = Scene(10)
61-
sampling = Sampling(2e4, period=1 / fsamp)
61+
sampling = Sampling(20_000, period=1 / fsamp)
6262
freq = np.arange(6) / 6 * fsamp
6363
psd = np.array([0, 1, 1, 1, 1, 1], float) * sigma**2 / fsamp
6464

@@ -80,7 +80,7 @@ def test_get_noise3(shape):
8080
fsamp = 5
8181
sigma = 0.3
8282
scene = Scene(10)
83-
sampling = Sampling(2e4, period=1 / fsamp)
83+
sampling = Sampling(20_000, period=1 / fsamp)
8484
freq = np.arange(4) / 6 * fsamp
8585
psd = np.array([0, 2, 2, 1], float) * sigma**2 / fsamp
8686

tests/test_acquisitions_mpi.py

Lines changed: 139 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,144 @@
1+
from itertools import chain, product
2+
3+
import numpy as np
14
import pytest
25

3-
from pyoperators import MPI
6+
from pyoperators import MPI, MPIDistributionIdentityOperator
7+
from pyoperators.utils.testing import assert_same
48
from pysimulators import Acquisition, Instrument, PackedTable, Sampling, Scene
9+
from pysimulators.operators import ProjectionOperator
10+
from pysimulators.sparse import FSRMatrix
11+
12+
pytestmark = pytest.mark.mpi
13+
14+
RANK = MPI.COMM_WORLD.rank
15+
SIZE = MPI.COMM_WORLD.size
16+
NPROCS_INSTRUMENT = sorted(
17+
{int(n) for n in [1, SIZE / 3, SIZE / 2, SIZE] if int(n) == n}
18+
)
19+
NSCENE = 10
20+
NSAMPLING_GLOBAL = 100
21+
NDETECTOR_GLOBAL = 16
22+
SCENE = Scene(NSCENE)
23+
SAMPLING = Sampling(NSAMPLING_GLOBAL, period=1.0)
24+
INSTRUMENT = Instrument('', PackedTable(NDETECTOR_GLOBAL))
25+
26+
27+
class MyAcquisition(Acquisition):
28+
def get_projection_operator(self):
29+
dtype = [('index', int), ('value', float)]
30+
data = np.recarray((len(self.instrument), len(self.sampling), 1), dtype=dtype)
31+
for ilocal, iglobal in enumerate(self.instrument.detector.index):
32+
data[ilocal].value = iglobal
33+
data[ilocal, :, 0].index = [
34+
(iglobal + int(t)) % NSCENE for t in self.sampling.time
35+
]
36+
37+
matrix = FSRMatrix(
38+
(len(self.instrument) * len(self.sampling), NSCENE),
39+
data=data.reshape((-1, 1)),
40+
)
41+
42+
return ProjectionOperator(
43+
matrix, dtype=float, shapeout=(len(self.instrument), len(self.sampling))
44+
)
45+
46+
47+
def get_acquisition(comm, nprocs_instrument):
48+
return MyAcquisition(
49+
INSTRUMENT, SAMPLING, SCENE, comm=comm, nprocs_instrument=nprocs_instrument
50+
)
51+
52+
53+
@pytest.mark.parametrize('nprocs_instrument', NPROCS_INSTRUMENT)
54+
def test_communicators(nprocs_instrument):
55+
sky = SCENE.ones()
56+
nprocs_sampling = SIZE // nprocs_instrument
57+
serial_acq = get_acquisition(MPI.COMM_SELF, 1)
58+
assert serial_acq.comm.size == 1
59+
assert serial_acq.instrument.comm.size == 1
60+
assert serial_acq.sampling.comm.size == 1
61+
assert len(serial_acq.instrument) == NDETECTOR_GLOBAL
62+
assert len(serial_acq.sampling) == NSAMPLING_GLOBAL
63+
64+
parallel_acq = get_acquisition(MPI.COMM_WORLD, nprocs_instrument)
65+
assert parallel_acq.comm.size == SIZE
66+
assert parallel_acq.instrument.comm.size == nprocs_instrument
67+
assert parallel_acq.sampling.comm.size == nprocs_sampling
68+
assert (
69+
parallel_acq.instrument.comm.allreduce(len(parallel_acq.instrument))
70+
== NDETECTOR_GLOBAL
71+
)
72+
assert (
73+
parallel_acq.sampling.comm.allreduce(len(parallel_acq.sampling))
74+
== NSAMPLING_GLOBAL
75+
)
576

6-
rank = MPI.COMM_WORLD.rank
7-
size = MPI.COMM_WORLD.size
8-
9-
10-
def test():
11-
scene = Scene(1024)
12-
instrument = Instrument('instrument', PackedTable((32, 32)))
13-
sampling = Sampling(1000)
14-
acq = Acquisition(instrument, sampling, scene, nprocs_sampling=max(size // 2, 1))
15-
print(
16-
acq.comm.rank,
17-
acq.instrument.detector.comm.rank,
18-
'/',
19-
acq.instrument.detector.comm.size,
20-
acq.sampling.comm.rank,
21-
'/',
22-
acq.sampling.comm.size,
77+
serial_H = serial_acq.get_projection_operator()
78+
ref_tod = serial_H(sky)
79+
80+
parallel_H = (
81+
parallel_acq.get_projection_operator()
82+
* MPIDistributionIdentityOperator(parallel_acq.comm)
83+
)
84+
local_tod = parallel_H(sky)
85+
actual_tod = np.vstack(
86+
parallel_acq.instrument.comm.allgather(
87+
np.hstack(parallel_acq.sampling.comm.allgather(local_tod))
88+
)
2389
)
24-
pytest.xfail('the test is not finished.')
90+
assert_same(actual_tod, ref_tod, atol=20)
91+
92+
ref_backproj = serial_H.T(ref_tod)
93+
actual_backproj = parallel_H.T(local_tod)
94+
assert_same(actual_backproj, ref_backproj, atol=20)
95+
96+
97+
@pytest.mark.parametrize('nprocs_instrument', NPROCS_INSTRUMENT)
98+
@pytest.mark.parametrize(
99+
'selection',
100+
[
101+
Ellipsis,
102+
slice(None),
103+
]
104+
+ list(chain(*(product([slice(None), Ellipsis], repeat=n) for n in [1, 2, 3]))),
105+
)
106+
def test_communicators_getitem_all(nprocs_instrument, selection):
107+
acq = get_acquisition(MPI.COMM_WORLD, nprocs_instrument)
108+
assert acq.instrument.comm.size == nprocs_instrument
109+
assert acq.sampling.comm.size == MPI.COMM_WORLD.size / nprocs_instrument
110+
assert acq.comm.size == MPI.COMM_WORLD.size
111+
restricted_acq = acq[selection]
112+
assert restricted_acq.instrument.comm.size == nprocs_instrument
113+
assert restricted_acq.sampling.comm.size == MPI.COMM_WORLD.size / nprocs_instrument
114+
assert restricted_acq.comm.size == MPI.COMM_WORLD.size
115+
116+
117+
@pytest.mark.parametrize('nprocs_instrument', NPROCS_INSTRUMENT)
118+
@pytest.mark.parametrize('selection', [0, slice(None, 1), np.array])
119+
def test_communicators_getitem_instrument(nprocs_instrument, selection):
120+
acq = get_acquisition(MPI.COMM_WORLD, nprocs_instrument)
121+
if selection is np.array:
122+
selection = np.zeros(len(acq.instrument), bool)
123+
selection[0] = True
124+
restricted_acq = acq[selection]
125+
assert restricted_acq.instrument.comm.size == 1
126+
assert restricted_acq.sampling.comm.size == acq.sampling.comm.size
127+
assert restricted_acq.comm.size == acq.sampling.comm.size
128+
129+
130+
SELECTION_GETITEM_SAMPLING = np.zeros(NSAMPLING_GLOBAL, bool)
131+
SELECTION_GETITEM_SAMPLING[0] = True
132+
133+
134+
@pytest.mark.parametrize('nprocs_instrument', NPROCS_INSTRUMENT)
135+
@pytest.mark.parametrize('selection', [0, slice(None, 1), np.array])
136+
def test_communicators_getitem_sampling(nprocs_instrument, selection):
137+
acq = get_acquisition(MPI.COMM_WORLD, nprocs_instrument)
138+
if selection is np.array:
139+
selection = np.zeros(len(acq.sampling), bool)
140+
selection[0] = True
141+
restricted_acq = acq[:, selection]
142+
assert restricted_acq.instrument.comm.size == acq.instrument.comm.size
143+
assert restricted_acq.sampling.comm.size == 1
144+
assert restricted_acq.comm.size == acq.instrument.comm.size

0 commit comments

Comments
 (0)