Skip to content

Commit

Permalink
Merge pull request #44 from ifd3f/pydantic-2-update
Browse files Browse the repository at this point in the history
Update code and environment to be Pydantic 2.0 compatible
  • Loading branch information
ALescoulie authored Nov 8, 2023
2 parents ffc3bd0 + 5e86d5a commit a2a5707
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 144 deletions.
6 changes: 3 additions & 3 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

name: mdsapt
channels:
- psi4
- conda-forge
- conda-forge/label/libint_dev # needed for psi4 1.8.1
- defaults
dependencies:
# This must be specified FIRST to avoid packaging errors.
- psi4>=1.6.1,<1.7
- psi4=1.8.1=*_2 # pydantic update, for more info see https://github.com/psi4/psi4/issues/2991

- mdanalysis>=2.2.0,<2.3
- mdanalysis
- click
- numpy
- openmm
Expand Down
158 changes: 47 additions & 111 deletions mdsapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,15 @@
# pylint: disable=no-self-argument

import dataclasses
from dataclasses import dataclass
from enum import Enum
from os import PathLike
import os
from pathlib import Path
from typing import List, Dict, Tuple, Literal, Optional, \
from typing import Annotated, List, Dict, Tuple, Literal, Optional, \
Union, Any, Set, Iterable

import logging

import pydantic
from pydantic import BaseModel, conint, Field, root_validator, \
from pydantic import BaseModel, Field, model_validator, \
FilePath, ValidationError, DirectoryPath
import yaml

Expand Down Expand Up @@ -58,7 +55,7 @@ class SysLimitsConfig(BaseModel):
"""
Resource limits for your system.
"""
ncpus: conint(ge=1)
ncpus: Annotated[int, Field(strict=True, ge=1)]
memory: str


Expand All @@ -83,8 +80,7 @@ class SimulationConfig(BaseModel):
charge_guesser: ChargeGuesser


@dataclass
class TopologySelection:
class TopologySelection(BaseModel):
"""
A configuration item for selecting a single topology. To successfully import a topology,
it must be supported by MDAnalysis.
Expand All @@ -97,31 +93,18 @@ class TopologySelection:
.. seealso::
`List of topology formats that MDAnalysis supports <https://docs.mdanalysis.org/1.1.1/documentation_pages/topology/init.html>`_
"""
class _TopologySelection(BaseModel):
path: FilePath
topology_format: Optional[str]
charge_overrides: Dict[int, int] = Field(default_factory=dict)

path: Path
topology_format: Optional[str] = None
charge_overrides: Dict[int, int] = dataclasses.field(default_factory=dict)

@model_validator(mode='before')
@classmethod
def __get_validators__(cls):
yield cls._validate

@classmethod
def _validate(cls, values):
"""
Validates the topology. You should not call this directly.
"""
result = pydantic.parse_obj_as(Union[FilePath, cls._TopologySelection], values)
def _accept_bare_string(cls, data: Any) -> Any:
try:
path = Path(result)
return {'path': Path(data), 'charge_overrides': {}}
except TypeError:
return TopologySelection(path=result.path, topology_format=result.topology_format,
charge_overrides=result.charge_overrides)
return TopologySelection(path=path)
return data

def create_universe(self, *coordinates: Any, **kwargs) -> mda.Universe:
"""Create a universe based on this topology and the given arguments.."""
Expand All @@ -138,17 +121,18 @@ class RangeFrameSelection(BaseModel):
stop: the last frame to use, inclusive.
step: step between frames.
"""
start: Optional[conint(ge=0)]
stop: Optional[conint(ge=0)]
step: Optional[conint(ge=1)] = 1
start: Optional[Annotated[int, Field(strict=True, ge=0)]]
stop: Optional[Annotated[int, Field(strict=True, ge=0)]]
step: Optional[Annotated[int, Field(strict=True, ge=1)]] = 1

@root_validator()
def _check_start_before_stop(cls, values: Dict[str, int]) -> Dict[str, int]:
@model_validator(mode='after')
def _check_start_before_stop(self) -> 'RangeFrameSelection':
"""
Ensures that a valid range is selected for frame iteration.
"""
assert values['start'] <= values['stop'], "start must be before stop"
return values
if self.start is not None and self.stop is not None and self.start > self.stop:
raise ValueError('Start must be before stop')
return self


class TrajectoryAnalysisConfig(BaseModel):
Expand All @@ -171,40 +155,11 @@ class TrajectoryAnalysisConfig(BaseModel):
type: Literal['trajectory']
topology: TopologySelection
trajectories: List[FilePath]
pairs: List[Tuple[conint(ge=0), conint(ge=0)]]
pairs: List[Tuple[Annotated[int, Field(strict=True, ge=0)], Annotated[int, Field(strict=True, ge=0)]
]]
frames: RangeFrameSelection
output: str

# noinspection PyMethodParameters
@root_validator
def check_valid_md_system(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validates that setting work with the selected MD system
"""
errors: List[str] = []

topology: TopologySelection = values['topology']
trajectories: List[FilePath] = values['trajectories']
ag_pair: List[Tuple[conint(ge=0), conint(ge=0)]] = values['pairs']
frames: RangeFrameSelection = values['frames']

try:
unv = topology.create_universe([str(p) for p in trajectories])
except OSError as err:
raise ValueError("Error while creating the universe") from err

missing_selections = get_invalid_residue_selections({r for p in ag_pair for r in p}, unv)
if len(missing_selections) > 0:
errors.append(f'Selected residues are missing from topology: {missing_selections}')

trajlen: int = len(unv.trajectory)
if trajlen <= frames.stop:
errors.append(f'Stop {frames.stop} exceeds trajectory length {trajlen}.')

if len(errors) > 0:
raise ValidationError([errors], cls)
return values

def create_universe(self, **universe_kwargs) -> mda.Universe:
"""
Loads a universe from the given topology and trajectory
Expand All @@ -226,35 +181,33 @@ def get_invalid_residue_selections(residues: Iterable[int], unv: mda.Universe) -
]


DockingElement = Union[Literal['L'], conint(ge=-1)]
DockingElement = Union[Literal['L'], Annotated[int, Field(strict=True, ge=-1)]
]
"""
A single element to analyze in docking.
The literal 'L' specifies the ligand, whereas an integer specifies the protein residue number.
"""


class TopologyGroupSelection(BaseModel):
"""
A selection of a group of topologies.
TopologyGroupSelection = Union[DirectoryPath, List[TopologySelection]]
"""
A selection of a group of topologies.
In a YAML config, this may either be a path to a flat directory full of topologies
or a list of :obj:`TopologySelection`s.
"""
__root__: Union[DirectoryPath, List[TopologySelection]]
In a YAML config, this may either be a path to a flat directory full of topologies
or a list of :obj:`TopologySelection`s.
"""

def get_individual_topologies(self) -> List[TopologySelection]:
"""
It
"""
if isinstance(self.__root__, list):
return self.__root__

return [
TopologySelection(path=f)
for f in self.__root__.iterdir()
if f.is_file()
]
def get_individual_topologies(sel: TopologyGroupSelection) -> List[TopologySelection]:
if isinstance(sel, list):
return sel

return [
TopologySelection(path=f)
for f in sel.iterdir()
if f.is_file()
]


# noinspection PyMethodParameters
Expand Down Expand Up @@ -288,34 +241,17 @@ class DockingAnalysisConfig(BaseModel):
"""
type: Literal['docking']
pairs: List[Tuple[DockingElement, DockingElement]]
combined_topologies: Optional[TopologyGroupSelection]
protein: Optional[TopologySelection]
ligands: Optional[TopologyGroupSelection]
combined_topologies: Optional[TopologyGroupSelection] = None
protein: Optional[TopologySelection] = None
ligands: Optional[TopologyGroupSelection] = None
output: str

@root_validator
def _check_valid_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validates that the provided settings are valid.
"""
errors: List[str] = []

pairs: List[Tuple[DockingElement, DockingElement]] = values['pairs']
protein_selections: Set[int] = {
i for pair in pairs for i in pair if i != 'L'
}
ens: Ensemble = cls._build_ensemble(combined_topologies=values.get('combined_topologies'),
protein=values.get('protein'),
ligands=values.get('ligands'))
missing_selections: List[int] = []

for v in ens.values():
missing_selections += get_invalid_residue_selections(protein_selections, v)

if len(missing_selections) > 0:
errors.append(f'Selected residues are missing from topology: {missing_selections}')

return values
@model_validator(mode='after')
def ensure_presence_of_args(self) -> 'DockingAnalysisConfig':
provided_args = (self.combined_topologies is not None, self.protein is not None, self.ligands is not None)
if provided_args not in [(True, False, False), (False, True, True)]:
raise ValueError('Must provide `protein` and `ligands` keys, or only `combined_topologies`')
return self

def build_ensemble(self) -> Ensemble:
return self._build_ensemble(combined_topologies=self.combined_topologies,
Expand All @@ -333,12 +269,12 @@ def _build_ensemble(
"""Fails if the wrong types of arguments are provided."""
if combined_topologies is not None and (protein, ligands) == (None, None):
return Ensemble.build_from_files(
[top.path for top in combined_topologies.get_individual_topologies()]
[top.path for top in get_individual_topologies(combined_topologies)]
)

if combined_topologies is None and None not in (protein, ligands):
ens: Ensemble = Ensemble.build_from_files([top.path for top
in ligands.get_individual_topologies()])
in get_individual_topologies(ligands)])
protein_sys: mda.Universe = mda.Universe(str(protein.path))
protein_mol: mda.AtomGroup = protein_sys.select_atoms("protein")
ens = ens.merge(protein_mol)
Expand Down Expand Up @@ -370,7 +306,7 @@ def load_from_yaml_file(path: Union[str, PathLike]) -> Config:
"""
with Path(path).open('r', encoding='utf8') as file:
try:
return Config(**yaml.safe_load(file))
return Config.model_validate(yaml.safe_load(file))
except ValidationError as err:
logger.exception("Error while loading config from %r", path)
raise err
36 changes: 6 additions & 30 deletions mdsapt/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,6 @@ def test_frame_range_selection() -> None:
RangeFrameSelection(**frame_range)


@pytest.mark.parametrize('key,var', [
('trajectories', [f'{resources_dir}/test_read_error.dcd']),
('frames', {'start': 1, 'stop': 120}),
('pairs', [(250, 251)])
])
def test_traj_analysis_config(key: str, var: Any) -> None:
"""
Tests TrajectoryAnalysis config validation with different errors
"""
traj_analysis_dict: Dict[str, Any] = dict(
type='trajectory',
topology=f'{resources_dir}/testtop.psf',
trajectories=[f'{resources_dir}/testtraj.dcd'],
pairs=[(132, 152), (34, 152)],
frames={'start': 1, 'stop': 4},
output=True
)

traj_analysis_dict[key] = var

with pytest.raises(ValidationError):
TrajectoryAnalysisConfig(**traj_analysis_dict)


def test_traj_sel() -> None:
"""
Test getting set of selections
Expand All @@ -63,9 +39,9 @@ def test_traj_sel() -> None:
trajectories=[f'{resources_dir}/testtraj.dcd'],
pairs=[(132, 152), (34, 152)],
frames={'start': 1, 'stop': 4},
output=True)
output='something.csv')

cfg: TrajectoryAnalysisConfig = TrajectoryAnalysisConfig(**traj_analysis_dict)
cfg: TrajectoryAnalysisConfig = TrajectoryAnalysisConfig.model_validate(traj_analysis_dict)
assert {34, 132, 152} == cfg.get_selections()


Expand Down Expand Up @@ -142,7 +118,7 @@ def test_topology_selection_parses_from_string() -> None:
Test topology selection parses
"""
data: str = str(resources_dir / 'test_input.yaml')
result = pydantic.parse_obj_as(TopologySelection, data)
result = TopologySelection.model_validate(data)

assert result.path == resources_dir / 'test_input.yaml'
assert result.charge_overrides == {}
Expand All @@ -153,7 +129,7 @@ def test_topology_selection_parses_from_obj_with_overrides() -> None:
Test alternative topology selection
"""
data = {'path': str(resources_dir / 'test_input.yaml'), 'charge_overrides': {'13': 3}}
result = pydantic.parse_obj_as(TopologySelection, data)
result = TopologySelection.model_validate(data)

assert result.path == resources_dir / 'test_input.yaml'
assert result.charge_overrides == {13: 3}
Expand All @@ -164,7 +140,7 @@ def test_topology_selection_parses_from_obj_without_overrides() -> None:
Tests no charge overrides topology selection
"""
data = {'path': str(resources_dir / 'test_input.yaml')}
result = pydantic.parse_obj_as(TopologySelection, data)
result = TopologySelection.model_validate(data)

assert result.path == resources_dir / 'test_input.yaml'
assert result.charge_overrides == {}
Expand Down Expand Up @@ -203,7 +179,7 @@ def test_seperated_docking_dir() -> None:
output=mktemp(),
)

cfg: DockingAnalysisConfig = DockingAnalysisConfig(**config_dict)
cfg: DockingAnalysisConfig = DockingAnalysisConfig.model_validate(config_dict)
ens: Ensemble = cfg.build_ensemble()
assert len(ens) == 7

Expand Down

0 comments on commit a2a5707

Please sign in to comment.