From a7fc0db7dd716e80a3d52be4ae6b570a3fbecbc3 Mon Sep 17 00:00:00 2001 From: JosePizarro3 Date: Tue, 5 Mar 2024 09:44:40 +0100 Subject: [PATCH] Split resolve_chemical_symbol and resolve_atomic_number in AtomsState Added testing for TestAtomsState --- simulationdataschema/atoms_state.py | 34 ++++++++++++++++++------ tests/test_atoms_state.py | 40 +++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/simulationdataschema/atoms_state.py b/simulationdataschema/atoms_state.py index 5584f913..0ed6f53b 100644 --- a/simulationdataschema/atoms_state.py +++ b/simulationdataschema/atoms_state.py @@ -627,31 +627,49 @@ class AtomsState(ArchiveSection): sub_section=HubbardInteractions.m_def, repeats=False ) - def resolve_chemical_symbol_and_number(self, logger: BoundLogger) -> None: + def resolve_chemical_symbol(self, logger: BoundLogger) -> Optional[str]: """ - Resolves the chemical symbol from the atomic number and viceversa. + Resolves the `chemical_symbol` from the `atomic_number`. Args: logger (BoundLogger): The logger to log messages. + + Returns: + (Optional[str]): The resolved `chemical_symbol`. """ - f = lambda x: tuple(map(bool, x)) - if f((self.chemical_symbol, self.atomic_number)) == f((None, not None)): + if self.atomic_number is not None: try: - self.chemical_symbol = ase.data.chemical_symbols[self.atomic_number] + return ase.data.chemical_symbols[self.atomic_number] except IndexError: logger.error( 'The `AtomsState.atomic_number` is out of range of the periodic table.' ) - elif f((self.chemical_symbol, self.atomic_number)) == f((not None, None)): + return None + + def resolve_atomic_number(self, logger: BoundLogger) -> Optional[int]: + """ + Resolves the `atomic_number` from the `chemical_symbol`. + + Args: + logger (BoundLogger): The logger to log messages. + + Returns: + (Optional[int]): The resolved `atomic_number`. + """ + if self.chemical_symbol is not None: try: - self.atomic_number = ase.data.atomic_numbers[self.chemical_symbol] + return ase.data.atomic_numbers[self.chemical_symbol] except IndexError: logger.error( 'The `AtomsState.chemical_symbol` is not recognized in the periodic table.' ) + return None def normalize(self, archive, logger) -> None: super().normalize(archive, logger) # Get chemical_symbol from atomic_number and viceversa - self.resolve_chemical_symbol_and_number(logger) + if self.chemical_symbol is None: + self.chemical_symbol = self.resolve_chemical_symbol(logger) + if self.atomic_number is None: + self.atomic_number = self.resolve_atomic_number(logger) diff --git a/tests/test_atoms_state.py b/tests/test_atoms_state.py index f037e3e7..9f30e2bb 100644 --- a/tests/test_atoms_state.py +++ b/tests/test_atoms_state.py @@ -344,3 +344,43 @@ def test_normalize(self, hubbard_interactions): assert np.isclose(hubbard_interactions.u_effective.to('eV').magnitude, 1.0) assert np.isclose(hubbard_interactions.u_interaction.to('eV').magnitude, 3.0) assert hubbard_interactions.slater_integrals is None + + +class TestAtomsState: + """ + Tests the `AtomsState` class defined in atoms_state.py. + """ + + logger = logging.getLogger(__name__) + + @staticmethod + def add_element_information(atom_state, quantity_name, value) -> None: + setattr(atom_state, quantity_name, value) + + @pytest.fixture(autouse=True) + def atom_state(self) -> AtomsState: + return AtomsState() + + @pytest.mark.parametrize( + 'chemical_symbol, atomic_number', + [ + ('Fe', 26), + ('H', 1), + ('Cu', 29), + ('O', 8), + ], + ) + def test_chemical_symbol_and_atomic_number( + self, atom_state, chemical_symbol, atomic_number + ): + """ + Test the `chemical_symbol` and `atomic_number` resolution for the `AtomsState` section. + """ + # Testing `chemical_symbol` + self.add_element_information(atom_state, 'chemical_symbol', chemical_symbol) + resolved_atomic_number = atom_state.resolve_atomic_number(self.logger) + assert resolved_atomic_number == atomic_number + # Testing `atomic_number` + self.add_element_information(atom_state, 'atomic_number', atomic_number) + resolved_chemical_symbol = atom_state.resolve_chemical_symbol(self.logger) + assert resolved_chemical_symbol == chemical_symbol