Skip to content

Commit

Permalink
Split resolve_chemical_symbol and resolve_atomic_number in AtomsState
Browse files Browse the repository at this point in the history
Added testing for TestAtomsState
  • Loading branch information
JosePizarro3 committed Mar 5, 2024
1 parent 9b96463 commit a7fc0db
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 8 deletions.
34 changes: 26 additions & 8 deletions simulationdataschema/atoms_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,31 +627,49 @@ class AtomsState(ArchiveSection):
sub_section=HubbardInteractions.m_def, repeats=False
)

def resolve_chemical_symbol_and_number(self, logger: BoundLogger) -> None:
def resolve_chemical_symbol(self, logger: BoundLogger) -> Optional[str]:
"""
Resolves the chemical symbol from the atomic number and viceversa.
Resolves the `chemical_symbol` from the `atomic_number`.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[str]): The resolved `chemical_symbol`.
"""
f = lambda x: tuple(map(bool, x))
if f((self.chemical_symbol, self.atomic_number)) == f((None, not None)):
if self.atomic_number is not None:
try:
self.chemical_symbol = ase.data.chemical_symbols[self.atomic_number]
return ase.data.chemical_symbols[self.atomic_number]
except IndexError:
logger.error(
'The `AtomsState.atomic_number` is out of range of the periodic table.'
)
elif f((self.chemical_symbol, self.atomic_number)) == f((not None, None)):
return None

def resolve_atomic_number(self, logger: BoundLogger) -> Optional[int]:
"""
Resolves the `atomic_number` from the `chemical_symbol`.
Args:
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[int]): The resolved `atomic_number`.
"""
if self.chemical_symbol is not None:
try:
self.atomic_number = ase.data.atomic_numbers[self.chemical_symbol]
return ase.data.atomic_numbers[self.chemical_symbol]
except IndexError:
logger.error(
'The `AtomsState.chemical_symbol` is not recognized in the periodic table.'
)
return None

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# Get chemical_symbol from atomic_number and viceversa
self.resolve_chemical_symbol_and_number(logger)
if self.chemical_symbol is None:
self.chemical_symbol = self.resolve_chemical_symbol(logger)
if self.atomic_number is None:
self.atomic_number = self.resolve_atomic_number(logger)
40 changes: 40 additions & 0 deletions tests/test_atoms_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,43 @@ def test_normalize(self, hubbard_interactions):
assert np.isclose(hubbard_interactions.u_effective.to('eV').magnitude, 1.0)
assert np.isclose(hubbard_interactions.u_interaction.to('eV').magnitude, 3.0)
assert hubbard_interactions.slater_integrals is None


class TestAtomsState:
"""
Tests the `AtomsState` class defined in atoms_state.py.
"""

logger = logging.getLogger(__name__)

@staticmethod
def add_element_information(atom_state, quantity_name, value) -> None:
setattr(atom_state, quantity_name, value)

@pytest.fixture(autouse=True)
def atom_state(self) -> AtomsState:
return AtomsState()

@pytest.mark.parametrize(
'chemical_symbol, atomic_number',
[
('Fe', 26),
('H', 1),
('Cu', 29),
('O', 8),
],
)
def test_chemical_symbol_and_atomic_number(
self, atom_state, chemical_symbol, atomic_number
):
"""
Test the `chemical_symbol` and `atomic_number` resolution for the `AtomsState` section.
"""
# Testing `chemical_symbol`
self.add_element_information(atom_state, 'chemical_symbol', chemical_symbol)
resolved_atomic_number = atom_state.resolve_atomic_number(self.logger)
assert resolved_atomic_number == atomic_number
# Testing `atomic_number`
self.add_element_information(atom_state, 'atomic_number', atomic_number)
resolved_chemical_symbol = atom_state.resolve_chemical_symbol(self.logger)
assert resolved_chemical_symbol == chemical_symbol

0 comments on commit a7fc0db

Please sign in to comment.