Skip to content

Commit

Permalink
[fmt] lib (#4804)
Browse files Browse the repository at this point in the history
Co-authored-by: Egor Marin <[email protected]>
  • Loading branch information
RMeli and marinegor authored Dec 6, 2024
1 parent 4e903c7 commit 25e755f
Show file tree
Hide file tree
Showing 23 changed files with 3,576 additions and 2,063 deletions.
38 changes: 22 additions & 16 deletions package/MDAnalysis/lib/NeighborSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class AtomNeighborSearch(object):
:class:`~MDAnalysis.lib.distances.capped_distance`.
"""

def __init__(self, atom_group: AtomGroup,
box: Optional[npt.ArrayLike] = None) -> None:
def __init__(
self, atom_group: AtomGroup, box: Optional[npt.ArrayLike] = None
) -> None:
"""
Parameters
Expand All @@ -62,10 +63,9 @@ def __init__(self, atom_group: AtomGroup,
self._u = atom_group.universe
self._box = box

def search(self, atoms: AtomGroup,
radius: float,
level: str = 'A'
) -> Optional[Union[AtomGroup, ResidueGroup, SegmentGroup]]:
def search(
self, atoms: AtomGroup, radius: float, level: str = "A"
) -> Optional[Union[AtomGroup, ResidueGroup, SegmentGroup]]:
"""
Return all atoms/residues/segments that are within *radius* of the
atoms in *atoms*.
Expand Down Expand Up @@ -102,17 +102,21 @@ def search(self, atoms: AtomGroup,
except AttributeError:
# For atom, take the position attribute
position = atoms.position
pairs = capped_distance(position, self.atom_group.positions,
radius, box=self._box, return_distances=False)
pairs = capped_distance(
position,
self.atom_group.positions,
radius,
box=self._box,
return_distances=False,
)

if pairs.size > 0:
unique_idx = unique_int_1d(np.asarray(pairs[:, 1], dtype=np.intp))
return self._index2level(unique_idx, level)

def _index2level(self,
indices: List[int],
level: str
) -> Union[AtomGroup, ResidueGroup, SegmentGroup]:
def _index2level(
self, indices: List[int], level: str
) -> Union[AtomGroup, ResidueGroup, SegmentGroup]:
"""Convert list of atom_indices in a AtomGroup to either the
Atoms or segments/residues containing these atoms.
Expand All @@ -125,11 +129,13 @@ def _index2level(self,
*radius* of *atoms*.
"""
atomgroup = self.atom_group[indices]
if level == 'A':
if level == "A":
return atomgroup
elif level == 'R':
elif level == "R":
return atomgroup.residues
elif level == 'S':
elif level == "S":
return atomgroup.segments
else:
raise NotImplementedError('{0}: level not implemented'.format(level))
raise NotImplementedError(
"{0}: level not implemented".format(level)
)
21 changes: 16 additions & 5 deletions package/MDAnalysis/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,17 @@
================================================================
"""

__all__ = ['log', 'transformations', 'util', 'mdamath', 'distances',
'NeighborSearch', 'formats', 'pkdtree', 'nsgrid']
__all__ = [
"log",
"transformations",
"util",
"mdamath",
"distances",
"NeighborSearch",
"formats",
"pkdtree",
"nsgrid",
]

from . import log
from . import transformations
Expand All @@ -39,6 +48,8 @@
from . import formats
from . import pkdtree
from . import nsgrid
from .picklable_file_io import (FileIOPicklable,
BufferIOPicklable,
TextIOPicklable)
from .picklable_file_io import (
FileIOPicklable,
BufferIOPicklable,
TextIOPicklable,
)
25 changes: 13 additions & 12 deletions package/MDAnalysis/lib/_distopia.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@

# check for compatibility: currently needs to be >=0.2.0,<0.3.0 (issue
# #4740) No distopia.__version__ available so we have to do some probing.
needed_funcs = ['calc_bonds_no_box_float', 'calc_bonds_ortho_float']
needed_funcs = ["calc_bonds_no_box_float", "calc_bonds_ortho_float"]
has_distopia_020 = all([hasattr(distopia, func) for func in needed_funcs])
if not has_distopia_020:
warnings.warn("Install 'distopia>=0.2.0,<0.3.0' to be used with this "
"release of MDAnalysis. Your installed version of "
"distopia >=0.3.0 will NOT be used.",
category=RuntimeWarning)
warnings.warn(
"Install 'distopia>=0.2.0,<0.3.0' to be used with this "
"release of MDAnalysis. Your installed version of "
"distopia >=0.3.0 will NOT be used.",
category=RuntimeWarning,
)
del distopia
HAS_DISTOPIA = False

Expand All @@ -59,23 +61,22 @@
def calc_bond_distance_ortho(
coords1, coords2: np.ndarray, box: np.ndarray, results: np.ndarray
) -> None:
distopia.calc_bonds_ortho_float(
coords1, coords2, box[:3], results=results
)
distopia.calc_bonds_ortho_float(coords1, coords2, box[:3], results=results)
# upcast is currently required, change for 3.0, see #3927


def calc_bond_distance(
coords1: np.ndarray, coords2: np.ndarray, results: np.ndarray
) -> None:
distopia.calc_bonds_no_box_float(
coords1, coords2, results=results
)
distopia.calc_bonds_no_box_float(coords1, coords2, results=results)
# upcast is currently required, change for 3.0, see #3927


def calc_bond_distance_triclinic(
coords1: np.ndarray, coords2: np.ndarray, box: np.ndarray, results: np.ndarray
coords1: np.ndarray,
coords2: np.ndarray,
box: np.ndarray,
results: np.ndarray,
) -> None:
# redirect to serial backend
warnings.warn(
Expand Down
15 changes: 10 additions & 5 deletions package/MDAnalysis/lib/correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,18 @@ def autocorrelation(list_of_sets, tau_max, window_step=1):
"""

# check types
if (type(list_of_sets) != list and len(list_of_sets) != 0) or type(list_of_sets[0]) != set:
raise TypeError("list_of_sets must be a one-dimensional list of sets") # pragma: no cover
if (type(list_of_sets) != list and len(list_of_sets) != 0) or type(
list_of_sets[0]
) != set:
raise TypeError(
"list_of_sets must be a one-dimensional list of sets"
) # pragma: no cover

# Check dimensions of parameters
if len(list_of_sets) < tau_max:
raise ValueError("tau_max cannot be greater than the length of list_of_sets") # pragma: no cover
raise ValueError(
"tau_max cannot be greater than the length of list_of_sets"
) # pragma: no cover

tau_timeseries = list(range(1, tau_max + 1))
timeseries_data = [[] for _ in range(tau_max)]
Expand All @@ -157,7 +163,7 @@ def autocorrelation(list_of_sets, tau_max, window_step=1):
break

# continuous: IDs that survive from t to t + tau and at every frame in between
Ntau = len(set.intersection(*list_of_sets[t:t + tau + 1]))
Ntau = len(set.intersection(*list_of_sets[t : t + tau + 1]))
timeseries_data[tau - 1].append(Ntau / float(Nt))

timeseries = [np.mean(x) for x in timeseries_data]
Expand Down Expand Up @@ -257,4 +263,3 @@ def correct_intermittency(list_of_sets, intermittency):

seen_frames_ago[element] = 0
return list_of_sets

Loading

0 comments on commit 25e755f

Please sign in to comment.