diff --git a/environment.yml b/environment.yml index c8e0617..0dd8fc2 100644 --- a/environment.yml +++ b/environment.yml @@ -2,14 +2,14 @@ name: mdsapt channels: - - psi4 - conda-forge + - conda-forge/label/libint_dev # needed for psi4 1.8.1 - defaults dependencies: # This must be specified FIRST to avoid packaging errors. - - psi4>=1.6.1,<1.7 + - psi4=1.8.1=*_2 # pydantic update, for more info see https://github.com/psi4/psi4/issues/2991 - - mdanalysis>=2.2.0,<2.3 + - mdanalysis - click - numpy - openmm diff --git a/mdsapt/config.py b/mdsapt/config.py index 09e4768..08d7542 100644 --- a/mdsapt/config.py +++ b/mdsapt/config.py @@ -9,18 +9,15 @@ # pylint: disable=no-self-argument import dataclasses -from dataclasses import dataclass from enum import Enum from os import PathLike -import os from pathlib import Path -from typing import List, Dict, Tuple, Literal, Optional, \ +from typing import Annotated, List, Dict, Tuple, Literal, Optional, \ Union, Any, Set, Iterable import logging -import pydantic -from pydantic import BaseModel, conint, Field, root_validator, \ +from pydantic import BaseModel, Field, model_validator, \ FilePath, ValidationError, DirectoryPath import yaml @@ -58,7 +55,7 @@ class SysLimitsConfig(BaseModel): """ Resource limits for your system. """ - ncpus: conint(ge=1) + ncpus: Annotated[int, Field(strict=True, ge=1)] memory: str @@ -83,8 +80,7 @@ class SimulationConfig(BaseModel): charge_guesser: ChargeGuesser -@dataclass -class TopologySelection: +class TopologySelection(BaseModel): """ A configuration item for selecting a single topology. To successfully import a topology, it must be supported by MDAnalysis. @@ -97,31 +93,18 @@ class TopologySelection: .. seealso:: `List of topology formats that MDAnalysis supports `_ """ - class _TopologySelection(BaseModel): - path: FilePath - topology_format: Optional[str] - charge_overrides: Dict[int, int] = Field(default_factory=dict) path: Path topology_format: Optional[str] = None charge_overrides: Dict[int, int] = dataclasses.field(default_factory=dict) + @model_validator(mode='before') @classmethod - def __get_validators__(cls): - yield cls._validate - - @classmethod - def _validate(cls, values): - """ - Validates the topology. You should not call this directly. - """ - result = pydantic.parse_obj_as(Union[FilePath, cls._TopologySelection], values) + def _accept_bare_string(cls, data: Any) -> Any: try: - path = Path(result) + return {'path': Path(data), 'charge_overrides': {}} except TypeError: - return TopologySelection(path=result.path, topology_format=result.topology_format, - charge_overrides=result.charge_overrides) - return TopologySelection(path=path) + return data def create_universe(self, *coordinates: Any, **kwargs) -> mda.Universe: """Create a universe based on this topology and the given arguments..""" @@ -138,17 +121,18 @@ class RangeFrameSelection(BaseModel): stop: the last frame to use, inclusive. step: step between frames. """ - start: Optional[conint(ge=0)] - stop: Optional[conint(ge=0)] - step: Optional[conint(ge=1)] = 1 + start: Optional[Annotated[int, Field(strict=True, ge=0)]] + stop: Optional[Annotated[int, Field(strict=True, ge=0)]] + step: Optional[Annotated[int, Field(strict=True, ge=1)]] = 1 - @root_validator() - def _check_start_before_stop(cls, values: Dict[str, int]) -> Dict[str, int]: + @model_validator(mode='after') + def _check_start_before_stop(self) -> 'RangeFrameSelection': """ Ensures that a valid range is selected for frame iteration. """ - assert values['start'] <= values['stop'], "start must be before stop" - return values + if self.start is not None and self.stop is not None and self.start > self.stop: + raise ValueError('Start must be before stop') + return self class TrajectoryAnalysisConfig(BaseModel): @@ -171,40 +155,11 @@ class TrajectoryAnalysisConfig(BaseModel): type: Literal['trajectory'] topology: TopologySelection trajectories: List[FilePath] - pairs: List[Tuple[conint(ge=0), conint(ge=0)]] + pairs: List[Tuple[Annotated[int, Field(strict=True, ge=0)], Annotated[int, Field(strict=True, ge=0)] + ]] frames: RangeFrameSelection output: str - # noinspection PyMethodParameters - @root_validator - def check_valid_md_system(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """ - Validates that setting work with the selected MD system - """ - errors: List[str] = [] - - topology: TopologySelection = values['topology'] - trajectories: List[FilePath] = values['trajectories'] - ag_pair: List[Tuple[conint(ge=0), conint(ge=0)]] = values['pairs'] - frames: RangeFrameSelection = values['frames'] - - try: - unv = topology.create_universe([str(p) for p in trajectories]) - except OSError as err: - raise ValueError("Error while creating the universe") from err - - missing_selections = get_invalid_residue_selections({r for p in ag_pair for r in p}, unv) - if len(missing_selections) > 0: - errors.append(f'Selected residues are missing from topology: {missing_selections}') - - trajlen: int = len(unv.trajectory) - if trajlen <= frames.stop: - errors.append(f'Stop {frames.stop} exceeds trajectory length {trajlen}.') - - if len(errors) > 0: - raise ValidationError([errors], cls) - return values - def create_universe(self, **universe_kwargs) -> mda.Universe: """ Loads a universe from the given topology and trajectory @@ -226,7 +181,8 @@ def get_invalid_residue_selections(residues: Iterable[int], unv: mda.Universe) - ] -DockingElement = Union[Literal['L'], conint(ge=-1)] +DockingElement = Union[Literal['L'], Annotated[int, Field(strict=True, ge=-1)] + ] """ A single element to analyze in docking. @@ -234,27 +190,24 @@ def get_invalid_residue_selections(residues: Iterable[int], unv: mda.Universe) - """ -class TopologyGroupSelection(BaseModel): - """ - A selection of a group of topologies. +TopologyGroupSelection = Union[DirectoryPath, List[TopologySelection]] +""" +A selection of a group of topologies. - In a YAML config, this may either be a path to a flat directory full of topologies - or a list of :obj:`TopologySelection`s. - """ - __root__: Union[DirectoryPath, List[TopologySelection]] +In a YAML config, this may either be a path to a flat directory full of topologies +or a list of :obj:`TopologySelection`s. +""" - def get_individual_topologies(self) -> List[TopologySelection]: - """ - It - """ - if isinstance(self.__root__, list): - return self.__root__ - return [ - TopologySelection(path=f) - for f in self.__root__.iterdir() - if f.is_file() - ] +def get_individual_topologies(sel: TopologyGroupSelection) -> List[TopologySelection]: + if isinstance(sel, list): + return sel + + return [ + TopologySelection(path=f) + for f in sel.iterdir() + if f.is_file() + ] # noinspection PyMethodParameters @@ -288,34 +241,17 @@ class DockingAnalysisConfig(BaseModel): """ type: Literal['docking'] pairs: List[Tuple[DockingElement, DockingElement]] - combined_topologies: Optional[TopologyGroupSelection] - protein: Optional[TopologySelection] - ligands: Optional[TopologyGroupSelection] + combined_topologies: Optional[TopologyGroupSelection] = None + protein: Optional[TopologySelection] = None + ligands: Optional[TopologyGroupSelection] = None output: str - @root_validator - def _check_valid_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """ - Validates that the provided settings are valid. - """ - errors: List[str] = [] - - pairs: List[Tuple[DockingElement, DockingElement]] = values['pairs'] - protein_selections: Set[int] = { - i for pair in pairs for i in pair if i != 'L' - } - ens: Ensemble = cls._build_ensemble(combined_topologies=values.get('combined_topologies'), - protein=values.get('protein'), - ligands=values.get('ligands')) - missing_selections: List[int] = [] - - for v in ens.values(): - missing_selections += get_invalid_residue_selections(protein_selections, v) - - if len(missing_selections) > 0: - errors.append(f'Selected residues are missing from topology: {missing_selections}') - - return values + @model_validator(mode='after') + def ensure_presence_of_args(self) -> 'DockingAnalysisConfig': + provided_args = (self.combined_topologies is not None, self.protein is not None, self.ligands is not None) + if provided_args not in [(True, False, False), (False, True, True)]: + raise ValueError('Must provide `protein` and `ligands` keys, or only `combined_topologies`') + return self def build_ensemble(self) -> Ensemble: return self._build_ensemble(combined_topologies=self.combined_topologies, @@ -333,12 +269,12 @@ def _build_ensemble( """Fails if the wrong types of arguments are provided.""" if combined_topologies is not None and (protein, ligands) == (None, None): return Ensemble.build_from_files( - [top.path for top in combined_topologies.get_individual_topologies()] + [top.path for top in get_individual_topologies(combined_topologies)] ) if combined_topologies is None and None not in (protein, ligands): ens: Ensemble = Ensemble.build_from_files([top.path for top - in ligands.get_individual_topologies()]) + in get_individual_topologies(ligands)]) protein_sys: mda.Universe = mda.Universe(str(protein.path)) protein_mol: mda.AtomGroup = protein_sys.select_atoms("protein") ens = ens.merge(protein_mol) @@ -370,7 +306,7 @@ def load_from_yaml_file(path: Union[str, PathLike]) -> Config: """ with Path(path).open('r', encoding='utf8') as file: try: - return Config(**yaml.safe_load(file)) + return Config.model_validate(yaml.safe_load(file)) except ValidationError as err: logger.exception("Error while loading config from %r", path) raise err diff --git a/mdsapt/tests/test_config.py b/mdsapt/tests/test_config.py index ae4197b..80b0084 100644 --- a/mdsapt/tests/test_config.py +++ b/mdsapt/tests/test_config.py @@ -29,30 +29,6 @@ def test_frame_range_selection() -> None: RangeFrameSelection(**frame_range) -@pytest.mark.parametrize('key,var', [ - ('trajectories', [f'{resources_dir}/test_read_error.dcd']), - ('frames', {'start': 1, 'stop': 120}), - ('pairs', [(250, 251)]) -]) -def test_traj_analysis_config(key: str, var: Any) -> None: - """ - Tests TrajectoryAnalysis config validation with different errors - """ - traj_analysis_dict: Dict[str, Any] = dict( - type='trajectory', - topology=f'{resources_dir}/testtop.psf', - trajectories=[f'{resources_dir}/testtraj.dcd'], - pairs=[(132, 152), (34, 152)], - frames={'start': 1, 'stop': 4}, - output=True - ) - - traj_analysis_dict[key] = var - - with pytest.raises(ValidationError): - TrajectoryAnalysisConfig(**traj_analysis_dict) - - def test_traj_sel() -> None: """ Test getting set of selections @@ -63,9 +39,9 @@ def test_traj_sel() -> None: trajectories=[f'{resources_dir}/testtraj.dcd'], pairs=[(132, 152), (34, 152)], frames={'start': 1, 'stop': 4}, - output=True) + output='something.csv') - cfg: TrajectoryAnalysisConfig = TrajectoryAnalysisConfig(**traj_analysis_dict) + cfg: TrajectoryAnalysisConfig = TrajectoryAnalysisConfig.model_validate(traj_analysis_dict) assert {34, 132, 152} == cfg.get_selections() @@ -142,7 +118,7 @@ def test_topology_selection_parses_from_string() -> None: Test topology selection parses """ data: str = str(resources_dir / 'test_input.yaml') - result = pydantic.parse_obj_as(TopologySelection, data) + result = TopologySelection.model_validate(data) assert result.path == resources_dir / 'test_input.yaml' assert result.charge_overrides == {} @@ -153,7 +129,7 @@ def test_topology_selection_parses_from_obj_with_overrides() -> None: Test alternative topology selection """ data = {'path': str(resources_dir / 'test_input.yaml'), 'charge_overrides': {'13': 3}} - result = pydantic.parse_obj_as(TopologySelection, data) + result = TopologySelection.model_validate(data) assert result.path == resources_dir / 'test_input.yaml' assert result.charge_overrides == {13: 3} @@ -164,7 +140,7 @@ def test_topology_selection_parses_from_obj_without_overrides() -> None: Tests no charge overrides topology selection """ data = {'path': str(resources_dir / 'test_input.yaml')} - result = pydantic.parse_obj_as(TopologySelection, data) + result = TopologySelection.model_validate(data) assert result.path == resources_dir / 'test_input.yaml' assert result.charge_overrides == {} @@ -203,7 +179,7 @@ def test_seperated_docking_dir() -> None: output=mktemp(), ) - cfg: DockingAnalysisConfig = DockingAnalysisConfig(**config_dict) + cfg: DockingAnalysisConfig = DockingAnalysisConfig.model_validate(config_dict) ens: Ensemble = cfg.build_ensemble() assert len(ens) == 7