Skip to content

Commit c0685ac

Browse files
authored
Merge pull request #4884 from tanishy7777/parallize_rdf
Parallelizes `MDAnalysis.analysis.InterRDF` and `MDAnalysis.analysis.InterRDF_s`
2 parents 519ac56 + 79601e6 commit c0685ac

File tree

5 files changed

+172
-45
lines changed

5 files changed

+172
-45
lines changed

package/CHANGELOG

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,13 @@ Enhancements
115115
(Issue #4677, PR #4729)
116116
* Enables parallelization for analysis.contacts.Contacts (Issue #4660)
117117
* Enable parallelization for analysis.nucleicacids.NucPairDist (Issue #4670)
118+
* Add check and warning for empty (all zero) coordinates in RDKit converter (PR #4824)
119+
* Added `precision` for XYZWriter (Issue #4775, PR #4771)
120+
* Parallelize `analysis.rdf.InterRDF` and `analysis.rdf.InterRDF` (Issue #4675)
121+
* Added `precision` for XYZWriter (Issue #4775, PR #4771)
118122
* Add check and warning for empty (all zero) coordinates in RDKit converter
119123
(PR #4824)
120-
* Added `precision` for XYZWriter (Issue #4775, PR #4771)
124+
* Added `precision` for XYZWriter (Issue #4775, PR #4771)
121125

122126
Changes
123127
* MDAnalysis.analysis.psa, MDAnalysis.analysis.waterdynamics and

package/MDAnalysis/analysis/rdf.py

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,43 @@
8080
import numpy as np
8181

8282
from ..lib import distances
83-
from .base import AnalysisBase
83+
from .base import AnalysisBase, ResultsGroup
84+
85+
86+
def nested_array_sum(arrs):
87+
r"""Custom aggregator for nested arrays
88+
89+
This function takes a nested list or tuple of NumPy arrays, flattens it
90+
into a single list, and aggregates the elements at alternating indices
91+
into two separate arrays. The first array accumulates elements at even
92+
indices, while the second accumulates elements at odd indices.
93+
94+
Parameters
95+
----------
96+
arrs : list
97+
List of arrays or nested lists of arrays
98+
99+
Returns
100+
-------
101+
list of ndarray
102+
A list containing two NumPy arrays:
103+
- The first array is the sum of all elements at even indices
104+
in the sum of flattened arrays.
105+
- The second array is the sum of all elements at odd indices
106+
in the sum of flattened arrays.
107+
"""
108+
109+
def flatten(arr):
110+
if isinstance(arr, (list, tuple)):
111+
return [item for sublist in arr for item in flatten(sublist)]
112+
return [arr]
113+
114+
flat = flatten(arrs)
115+
aggregated_arr = [np.zeros_like(flat[0]), np.zeros_like(flat[1])]
116+
for i in range(len(flat) // 2):
117+
aggregated_arr[0] += flat[2 * i] # 0, 2, 4, ...
118+
aggregated_arr[1] += flat[2 * i + 1] # 1, 3, 5, ...
119+
return aggregated_arr
84120

85121

86122
class InterRDF(AnalysisBase):
@@ -221,8 +257,23 @@ class InterRDF(AnalysisBase):
221257
Store results as attributes `bins`, `edges`, `rdf` and `count`
222258
of the `results` attribute of
223259
:class:`~MDAnalysis.analysis.AnalysisBase`.
260+
261+
.. versionchanged:: 2.9.0
262+
Enabled **parallel execution** with the ``multiprocessing`` and ``dask``
263+
backends; use the new method :meth:`get_supported_backends` to see all
264+
supported backends.
224265
"""
225266

267+
@classmethod
268+
def get_supported_backends(cls):
269+
return (
270+
"serial",
271+
"multiprocessing",
272+
"dask",
273+
)
274+
275+
_analysis_algorithm_is_parallelizable = True
276+
226277
def __init__(
227278
self,
228279
g1,
@@ -281,7 +332,7 @@ def _prepare(self):
281332

282333
if self.norm == "rdf":
283334
# Cumulative volume for rdf normalization
284-
self.volume_cum = 0
335+
self.results.volume_cum = 0
285336
# Set the max range to filter the search radius
286337
self._maxrange = self.rdf_settings["range"][1]
287338

@@ -311,7 +362,17 @@ def _single_frame(self):
311362
self.results.count += count
312363

313364
if self.norm == "rdf":
314-
self.volume_cum += self._ts.volume
365+
self.results.volume_cum += self._ts.volume
366+
367+
def _get_aggregator(self):
368+
return ResultsGroup(
369+
lookup={
370+
"count": ResultsGroup.ndarray_sum,
371+
"volume_cum": ResultsGroup.ndarray_sum,
372+
"bins": ResultsGroup.ndarray_sum,
373+
"edges": ResultsGroup.ndarray_mean,
374+
}
375+
)
315376

316377
def _conclude(self):
317378
norm = self.n_frames
@@ -333,6 +394,7 @@ def _conclude(self):
333394
N -= xA * xB * nblocks
334395

335396
# Average number density
397+
self.volume_cum = self.results.volume_cum
336398
box_vol = self.volume_cum / self.n_frames
337399
norm *= N / box_vol
338400

@@ -576,8 +638,32 @@ class InterRDF_s(AnalysisBase):
576638
Instead of `density=True` use `norm='density'`
577639
.. deprecated:: 2.3.0
578640
The `universe` parameter is superflous.
641+
.. versionchanged:: 2.9.0
642+
Enabled **parallel execution** with the ``multiprocessing`` and ``dask``
643+
backends; use the new method :meth:`get_supported_backends` to see all
644+
supported backends.
579645
"""
580646

647+
@classmethod
648+
def get_supported_backends(cls):
649+
return (
650+
"serial",
651+
"multiprocessing",
652+
"dask",
653+
)
654+
655+
_analysis_algorithm_is_parallelizable = True
656+
657+
def _get_aggregator(self):
658+
return ResultsGroup(
659+
lookup={
660+
"count": nested_array_sum,
661+
"volume_cum": ResultsGroup.ndarray_sum,
662+
"bins": ResultsGroup.ndarray_mean,
663+
"edges": ResultsGroup.ndarray_mean,
664+
}
665+
)
666+
581667
def __init__(
582668
self,
583669
u,
@@ -632,7 +718,7 @@ def _prepare(self):
632718

633719
if self.norm == "rdf":
634720
# Cumulative volume for rdf normalization
635-
self.volume_cum = 0
721+
self.results.volume_cum = 0
636722
self._maxrange = self.rdf_settings["range"][1]
637723

638724
def _single_frame(self):
@@ -650,7 +736,7 @@ def _single_frame(self):
650736
self.results.count[i][idx1, idx2, :] += count
651737

652738
if self.norm == "rdf":
653-
self.volume_cum += self._ts.volume
739+
self.results.volume_cum += self._ts.volume
654740

655741
def _conclude(self):
656742
norm = self.n_frames
@@ -661,6 +747,7 @@ def _conclude(self):
661747

662748
if self.norm == "rdf":
663749
# Average number density
750+
self.volume_cum = self.results.volume_cum
664751
norm *= 1 / (self.volume_cum / self.n_frames)
665752

666753
# Empty lists to restore indices, RDF

testsuite/MDAnalysisTests/analysis/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from MDAnalysis.analysis.density import DensityAnalysis
2020
from MDAnalysis.analysis.lineardensity import LinearDensity
2121
from MDAnalysis.analysis.polymer import PersistenceLength
22+
from MDAnalysis.analysis.rdf import InterRDF, InterRDF_s
2223
from MDAnalysis.lib.util import is_installed
2324

2425

@@ -194,3 +195,16 @@ def client_LinearDensity(request):
194195
@pytest.fixture(scope="module", params=params_for_cls(PersistenceLength))
195196
def client_PersistenceLength(request):
196197
return request.param
198+
199+
200+
# MDAnalysis.analysis.rdf
201+
202+
203+
@pytest.fixture(scope="module", params=params_for_cls(InterRDF))
204+
def client_InterRDF(request):
205+
return request.param
206+
207+
208+
@pytest.fixture(scope="module", params=params_for_cls(InterRDF_s))
209+
def client_InterRDF_s(request):
210+
return request.param

testsuite/MDAnalysisTests/analysis/test_rdf.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,83 +49,85 @@ def sels(u):
4949
return s1, s2
5050

5151

52-
def test_nbins(u):
52+
def test_nbins(u, client_InterRDF):
5353
s1 = u.atoms[:3]
5454
s2 = u.atoms[3:]
55-
rdf = InterRDF(s1, s2, nbins=412).run()
55+
rdf = InterRDF(s1, s2, nbins=412).run(**client_InterRDF)
5656

5757
assert len(rdf.results.bins) == 412
5858

5959

60-
def test_range(u):
60+
def test_range(u, client_InterRDF):
6161
s1 = u.atoms[:3]
6262
s2 = u.atoms[3:]
6363
rmin, rmax = 1.0, 13.0
64-
rdf = InterRDF(s1, s2, range=(rmin, rmax)).run()
64+
rdf = InterRDF(s1, s2, range=(rmin, rmax)).run(**client_InterRDF)
6565

6666
assert rdf.results.edges[0] == rmin
6767
assert rdf.results.edges[-1] == rmax
6868

6969

70-
def test_count_sum(sels):
70+
def test_count_sum(sels, client_InterRDF):
7171
# OW vs HW
7272
# should see 8 comparisons in count
7373
s1, s2 = sels
74-
rdf = InterRDF(s1, s2).run()
74+
rdf = InterRDF(s1, s2).run(**client_InterRDF)
7575
assert rdf.results.count.sum() == 8
7676

7777

78-
def test_count(sels):
78+
def test_count(sels, client_InterRDF):
7979
# should see two distances with 4 counts each
8080
s1, s2 = sels
81-
rdf = InterRDF(s1, s2).run()
81+
rdf = InterRDF(s1, s2).run(**client_InterRDF)
8282
assert len(rdf.results.count[rdf.results.count == 4]) == 2
8383

8484

85-
def test_double_run(sels):
85+
def test_double_run(sels, client_InterRDF):
8686
# running rdf twice should give the same result
8787
s1, s2 = sels
88-
rdf = InterRDF(s1, s2).run()
89-
rdf.run()
88+
rdf = InterRDF(s1, s2).run(**client_InterRDF)
89+
rdf.run(**client_InterRDF)
9090
assert len(rdf.results.count[rdf.results.count == 4]) == 2
9191

9292

93-
def test_exclusion(sels):
93+
def test_exclusion(sels, client_InterRDF):
9494
# should see two distances with 4 counts each
9595
s1, s2 = sels
96-
rdf = InterRDF(s1, s2, exclusion_block=(1, 2)).run()
96+
rdf = InterRDF(s1, s2, exclusion_block=(1, 2)).run(**client_InterRDF)
9797
assert rdf.results.count.sum() == 4
9898

9999

100100
@pytest.mark.parametrize(
101101
"attr, count", [("residue", 8), ("segment", 0), ("chain", 8)]
102102
)
103-
def test_ignore_same_residues(sels, attr, count):
103+
def test_ignore_same_residues(sels, attr, count, client_InterRDF):
104104
# should see two distances with 4 counts each
105105
s1, s2 = sels
106-
rdf = InterRDF(s2, s2, exclude_same=attr).run()
106+
rdf = InterRDF(s2, s2, exclude_same=attr).run(**client_InterRDF)
107107
assert rdf.rdf[0] == 0
108108
assert rdf.results.count.sum() == count
109109

110110

111-
def test_ignore_same_residues_fails(sels):
111+
def test_ignore_same_residues_fails(sels, client_InterRDF):
112112
s1, s2 = sels
113113
with pytest.raises(
114114
ValueError, match="The exclude_same argument to InterRDF must be"
115115
):
116-
InterRDF(s2, s2, exclude_same="unsupported").run()
116+
InterRDF(s2, s2, exclude_same="unsupported").run(**client_InterRDF)
117117

118118
with pytest.raises(
119119
ValueError,
120120
match="The exclude_same argument to InterRDF cannot be used with",
121121
):
122-
InterRDF(s2, s2, exclude_same="residue", exclusion_block=tuple()).run()
122+
InterRDF(s2, s2, exclude_same="residue", exclusion_block=tuple()).run(
123+
**client_InterRDF
124+
)
123125

124126

125127
@pytest.mark.parametrize("attr", ("rdf", "bins", "edges", "count"))
126-
def test_rdf_attr_warning(sels, attr):
128+
def test_rdf_attr_warning(sels, attr, client_InterRDF):
127129
s1, s2 = sels
128-
rdf = InterRDF(s1, s2).run()
130+
rdf = InterRDF(s1, s2).run(**client_InterRDF)
129131
wmsg = f"The `{attr}` attribute was deprecated in MDAnalysis 2.0.0"
130132
with pytest.warns(DeprecationWarning, match=wmsg):
131133
getattr(rdf, attr) is rdf.results[attr]
@@ -134,18 +136,18 @@ def test_rdf_attr_warning(sels, attr):
134136
@pytest.mark.parametrize(
135137
"norm, value", [("density", 1.956823), ("rdf", 244602.88385), ("none", 4)]
136138
)
137-
def test_norm(sels, norm, value):
139+
def test_norm(sels, norm, value, client_InterRDF):
138140
s1, s2 = sels
139-
rdf = InterRDF(s1, s2, norm=norm).run()
141+
rdf = InterRDF(s1, s2, norm=norm).run(**client_InterRDF)
140142
assert_allclose(max(rdf.results.rdf), value)
141143

142144

143145
@pytest.mark.parametrize(
144146
"norm, norm_required", [("Density", "density"), (None, "none")]
145147
)
146-
def test_norm_values(sels, norm, norm_required):
148+
def test_norm_values(sels, norm, norm_required, client_InterRDF):
147149
s1, s2 = sels
148-
rdf = InterRDF(s1, s2, norm=norm).run()
150+
rdf = InterRDF(s1, s2, norm=norm).run(**client_InterRDF)
149151
assert rdf.norm == norm_required
150152

151153

0 commit comments

Comments
 (0)