Skip to content

Commit

Permalink
Implementation of Parallelization to MDAnalysis.analysis.contacts (#…
Browse files Browse the repository at this point in the history
…4820)

* Fixes #4660
* summary of changes:
    - added supported backends and a results aggregator to Contacts in analysis.contacts
    - added the private _get_box_func static method because lambdas cannot be pickled for parallel execution
    - added the client_Contacts fixture to conftest.py
    - passed client_Contacts to run() in the tests in test_contacts.py
* Update CHANGELOG
  • Loading branch information
talagayev authored Dec 17, 2024
1 parent 80b28c8 commit a3672f2
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 40 deletions.
2 changes: 2 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ Fixes
the function to prevent shared state. (Issue #4655)

Enhancements
* Enable parallelization for analysis.contacts.Contacts (Issue #4660, PR #4820)
* Enable parallelization for analysis.nucleicacids.NucPairDist (Issue #4670)
* Add check and warning for empty (all zero) coordinates in RDKit converter (PR #4824)
* Added `precision` for XYZWriter (Issue #4775, PR #4771)


Changes

Deprecations
Expand Down
48 changes: 41 additions & 7 deletions package/MDAnalysis/analysis/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def is_any_closer(r, r0, dist=2.5):
from MDAnalysis.lib.util import openany
from MDAnalysis.analysis.distances import distance_array
from MDAnalysis.core.groups import AtomGroup, UpdatingAtomGroup
from .base import AnalysisBase
from .base import AnalysisBase, ResultsGroup

logger = logging.getLogger("MDAnalysis.analysis.contacts")

Expand Down Expand Up @@ -376,8 +376,22 @@ class Contacts(AnalysisBase):
:class:`MDAnalysis.analysis.base.Results` instance.
.. versionchanged:: 2.2.0
:class:`Contacts` accepts both AtomGroup and string for `select`
.. versionchanged:: 2.9.0
Introduced :meth:`get_supported_backends` allowing
for parallel execution on :mod:`multiprocessing`
and :mod:`dask` backends.
"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
    """Return the names of the backends this analysis declares as supported.

    Returns
    -------
    tuple of str
        ``("serial", "multiprocessing", "dask")``.
    """
    backends = ("serial", "multiprocessing", "dask")
    return backends

def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
pbc=True, kwargs=None, **basekwargs):
"""
Expand Down Expand Up @@ -444,11 +458,8 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
self.r0 = []
self.initial_contacts = []

#get dimension of box if pbc set to True
if self.pbc:
self._get_box = lambda ts: ts.dimensions
else:
self._get_box = lambda ts: None
# get dimensions via partial for parallelization compatibility
self._get_box = functools.partial(self._get_box_func, pbc=self.pbc)

if isinstance(refgroup[0], AtomGroup):
refA, refB = refgroup
Expand All @@ -464,7 +475,6 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,

self.n_initial_contacts = self.initial_contacts[0].sum()


@staticmethod
def _get_atomgroup(u, sel):
select_error_message = ("selection must be either string or a "
Expand All @@ -480,6 +490,28 @@ def _get_atomgroup(u, sel):
else:
raise TypeError(select_error_message)

@staticmethod
def _get_box_func(ts, pbc):
    """Select the box to use for a frame according to the PBC flag.

    Written as a plain static function (not a lambda) so that ``__init__``
    can bind `pbc` with :func:`functools.partial`.

    Parameters
    ----------
    ts : Timestep
        Current timestep; only its ``dimensions`` attribute is read, and
        only when `pbc` is truthy.
    pbc : bool
        Whether periodic boundary conditions should be taken into account.

    Returns
    -------
    box_dimensions : object or None
        ``ts.dimensions`` when `pbc` is truthy, otherwise ``None``.
    """
    if not pbc:
        return None
    return ts.dimensions

def _prepare(self):
self.results.timeseries = np.empty((self.n_frames, len(self.r0)+1))

Expand All @@ -506,6 +538,8 @@ def timeseries(self):
warnings.warn(wmsg, DeprecationWarning)
return self.results.timeseries

def _get_aggregator(self):
    """Return the ``ResultsGroup`` describing how results are combined.

    Returns
    -------
    ResultsGroup
        Group whose lookup maps ``'timeseries'`` to
        ``ResultsGroup.ndarray_vstack``.
    """
    # NOTE(review): presumably consumed by the parallel backends to stack
    # partial ``timeseries`` arrays -- confirm against AnalysisBase.
    merge_rules = {'timeseries': ResultsGroup.ndarray_vstack}
    return ResultsGroup(lookup=merge_rules)

def _new_selections(u_orig, selections, frame):
"""create stand alone AGs from selections at frame"""
Expand Down
8 changes: 8 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
HydrogenBondAnalysis,
)
from MDAnalysis.analysis.nucleicacids import NucPairDist
from MDAnalysis.analysis.contacts import Contacts
from MDAnalysis.lib.util import is_installed


Expand Down Expand Up @@ -149,3 +150,10 @@ def client_HydrogenBondAnalysis(request):
@pytest.fixture(scope="module", params=params_for_cls(NucPairDist))
def client_NucPairDist(request):
    """Yield, one at a time, the values from ``params_for_cls(NucPairDist)``."""
    current_param = request.param
    return current_param


# MDAnalysis.analysis.contacts

@pytest.fixture(scope="module", params=params_for_cls(Contacts))
def client_Contacts(request):
    """Yield, one at a time, the values from ``params_for_cls(Contacts)``.

    The contacts tests unpack the returned value as keyword arguments to
    ``run()``.
    """
    run_kwargs = request.param
    return run_kwargs
87 changes: 54 additions & 33 deletions testsuite/MDAnalysisTests/analysis/test_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def universe():
return mda.Universe(PSF, DCD)

def _run_Contacts(
self, universe,
start=None, stop=None, step=None, **kwargs
self, universe, client_Contacts, start=None,
stop=None, step=None, **kwargs
):
acidic = universe.select_atoms(self.sel_acidic)
basic = universe.select_atoms(self.sel_basic)
Expand All @@ -181,7 +181,8 @@ def _run_Contacts(
select=(self.sel_acidic, self.sel_basic),
refgroup=(acidic, basic),
radius=6.0,
**kwargs).run(start=start, stop=stop, step=step)
**kwargs
).run(**client_Contacts, start=start, stop=stop, step=step)

@pytest.mark.parametrize("seltxt", [sel_acidic, sel_basic])
def test_select_valid_types(self, universe, seltxt):
Expand All @@ -195,7 +196,7 @@ def test_select_valid_types(self, universe, seltxt):

assert ag_from_string == ag_from_ag

def test_contacts_selections(self, universe):
def test_contacts_selections(self, universe, client_Contacts):
"""Test if Contacts can take both string and AtomGroup as selections.
"""
aga = universe.select_atoms(self.sel_acidic)
Expand All @@ -210,8 +211,8 @@ def test_contacts_selections(self, universe):
refgroup=(aga, agb)
)

cag.run()
csel.run()
cag.run(**client_Contacts)
csel.run(**client_Contacts)

assert cag.grA == csel.grA
assert cag.grB == csel.grB
Expand All @@ -228,26 +229,31 @@ def test_select_wrong_types(self, universe, ag):
) as te:
contacts.Contacts._get_atomgroup(universe, ag)

def test_startframe(self, universe):
def test_startframe(self, universe, client_Contacts):
"""test_startframe: TestContactAnalysis1: start frame set to 0 (resolution of
Issue #624)
"""
CA1 = self._run_Contacts(universe)
CA1 = self._run_Contacts(universe, client_Contacts=client_Contacts)
assert len(CA1.results.timeseries) == universe.trajectory.n_frames

def test_end_zero(self, universe):
def test_end_zero(self, universe, client_Contacts):
"""test_end_zero: TestContactAnalysis1: stop frame 0 is not ignored"""
CA1 = self._run_Contacts(universe, stop=0)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts, stop=0
)
assert len(CA1.results.timeseries) == 0

def test_slicing(self, universe):
def test_slicing(self, universe, client_Contacts):
start, stop, step = 10, 30, 5
CA1 = self._run_Contacts(universe, start=start, stop=stop, step=step)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts,
start=start, stop=stop, step=step
)
frames = np.arange(universe.trajectory.n_frames)[start:stop:step]
assert len(CA1.results.timeseries) == len(frames)

def test_villin_folded(self):
def test_villin_folded(self, client_Contacts):
# one folded, one unfolded
f = mda.Universe(contacts_villin_folded)
u = mda.Universe(contacts_villin_unfolded)
Expand All @@ -259,12 +265,12 @@ def test_villin_folded(self):
select=(sel, sel),
refgroup=(grF, grF),
method="soft_cut")
q.run()
q.run(**client_Contacts)

results = soft_cut(f, u, sel, sel)
assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7)

def test_villin_unfolded(self):
def test_villin_unfolded(self, client_Contacts):
# both folded
f = mda.Universe(contacts_villin_folded)
u = mda.Universe(contacts_villin_folded)
Expand All @@ -276,13 +282,13 @@ def test_villin_unfolded(self):
select=(sel, sel),
refgroup=(grF, grF),
method="soft_cut")
q.run()
q.run(**client_Contacts)

results = soft_cut(f, u, sel, sel)
assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7)

def test_hard_cut_method(self, universe):
ca = self._run_Contacts(universe)
def test_hard_cut_method(self, universe, client_Contacts):
ca = self._run_Contacts(universe, client_Contacts=client_Contacts)
expected = [1., 0.58252427, 0.52427184, 0.55339806, 0.54368932,
0.54368932, 0.51456311, 0.46601942, 0.48543689, 0.52427184,
0.46601942, 0.58252427, 0.51456311, 0.48543689, 0.48543689,
Expand All @@ -306,7 +312,7 @@ def test_hard_cut_method(self, universe):
assert len(ca.results.timeseries) == len(expected)
assert_allclose(ca.results.timeseries[:, 1], expected, rtol=0, atol=1.5e-7)

def test_radius_cut_method(self, universe):
def test_radius_cut_method(self, universe, client_Contacts):
acidic = universe.select_atoms(self.sel_acidic)
basic = universe.select_atoms(self.sel_basic)
r = contacts.distance_array(acidic.positions, basic.positions)
Expand All @@ -316,15 +322,20 @@ def test_radius_cut_method(self, universe):
r = contacts.distance_array(acidic.positions, basic.positions)
expected.append(contacts.radius_cut_q(r[initial_contacts], None, radius=6.0))

ca = self._run_Contacts(universe, method='radius_cut')
ca = self._run_Contacts(
universe, client_Contacts=client_Contacts, method="radius_cut"
)
assert_array_equal(ca.results.timeseries[:, 1], expected)

@staticmethod
def _is_any_closer(r, r0, dist=2.5):
    """Custom contact method: tell whether any distance is below `dist`.

    Parameters
    ----------
    r : array_like
        Distances to test.
    r0 : object
        Accepted for signature compatibility; not used.
    dist : float, optional
        Strict upper bound a distance must stay under to count.

    Returns
    -------
    bool
        Result of ``np.any(r < dist)``.
    """
    below_cutoff = r < dist
    return np.any(below_cutoff)

def test_own_method(self, universe):
ca = self._run_Contacts(universe, method=self._is_any_closer)
def test_own_method(self, universe, client_Contacts):
ca = self._run_Contacts(
universe, client_Contacts=client_Contacts,
method=self._is_any_closer
)

bound_expected = [1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0.,
1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1.,
Expand All @@ -340,21 +351,28 @@ def test_own_method(self, universe):
def _weird_own_method(r, r0):
    """Return a fixed non-numeric value, ignoring both arguments.

    Used by the tests as a deliberately ill-behaved contact method.
    """
    bogus_value = 'aaa'
    return bogus_value

def test_own_method_no_array_cast(self, universe):
def test_own_method_no_array_cast(self, universe, client_Contacts):
with pytest.raises(ValueError):
self._run_Contacts(universe, method=self._weird_own_method, stop=2)

def test_non_callable_method(self, universe):
self._run_Contacts(
universe,
client_Contacts=client_Contacts,
method=self._weird_own_method,
stop=2,
)

def test_non_callable_method(self, universe, client_Contacts):
with pytest.raises(ValueError):
self._run_Contacts(universe, method=2, stop=2)
self._run_Contacts(
universe, client_Contacts=client_Contacts, method=2, stop=2
)

@pytest.mark.parametrize("pbc,expected", [
(True, [1., 0.43138152, 0.3989021, 0.43824337, 0.41948765,
0.42223239, 0.41354071, 0.43641354, 0.41216834, 0.38334858]),
(False, [1., 0.42327791, 0.39192399, 0.40950119, 0.40902613,
0.42470309, 0.41140143, 0.42897862, 0.41472684, 0.38574822])
])
def test_distance_box(self, pbc, expected):
def test_distance_box(self, pbc, expected, client_Contacts):
u = mda.Universe(TPR, XTC)
sel_basic = "(resname ARG LYS)"
sel_acidic = "(resname ASP GLU)"
Expand All @@ -363,13 +381,15 @@ def test_distance_box(self, pbc, expected):

r = contacts.Contacts(u, select=(sel_acidic, sel_basic),
refgroup=(acidic, basic), radius=6.0, pbc=pbc)
r.run()
r.run(**client_Contacts)
assert_allclose(r.results.timeseries[:, 1], expected,rtol=0, atol=1.5e-7)

def test_warn_deprecated_attr(self, universe):
def test_warn_deprecated_attr(self, universe, client_Contacts):
"""Test for warning message emitted on using deprecated `timeseries`
attribute"""
CA1 = self._run_Contacts(universe, stop=1)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts, stop=1
)
wmsg = "The `timeseries` attribute was deprecated in MDAnalysis"
with pytest.warns(DeprecationWarning, match=wmsg):
assert_equal(CA1.timeseries, CA1.results.timeseries)
Expand All @@ -385,10 +405,11 @@ def test_n_initial_contacts(self, datafiles, expected):
r = contacts.Contacts(u, select=select, refgroup=refgroup)
assert_equal(r.n_initial_contacts, expected)

def test_q1q2():

def test_q1q2(client_Contacts):
u = mda.Universe(PSF, DCD)
q1q2 = contacts.q1q2(u, 'name CA', radius=8)
q1q2.run()
q1q2.run(**client_Contacts)

q1_expected = [1., 0.98092643, 0.97366031, 0.97275204, 0.97002725,
0.97275204, 0.96276113, 0.96730245, 0.9582198, 0.96185286,
Expand Down

0 comments on commit a3672f2

Please sign in to comment.