diff --git a/simulationdataschema/atoms_state.py b/simulationdataschema/atoms_state.py index 73fcb8f4..26eb549c 100644 --- a/simulationdataschema/atoms_state.py +++ b/simulationdataschema/atoms_state.py @@ -294,9 +294,8 @@ def normalize(self, archive, logger) -> None: setattr(self, f'{quantum_name}_quantum_{quantum_type}', quantity) # Resolve the degeneracy - self.degeneracy = ( - self.resolve_degeneracy() if self.degeneracy is None else self.degeneracy - ) + if self.degeneracy is None: + self.degeneracy = self.resolve_degeneracy() class CoreHole(ArchiveSection): diff --git a/tests/test_atoms_state.py b/tests/test_atoms_state.py index b4c3ee1d..383dcfe3 100644 --- a/tests/test_atoms_state.py +++ b/tests/test_atoms_state.py @@ -37,6 +37,24 @@ class TestOrbitalsState: logger = logging.getLogger(__name__) + @staticmethod + def add_quantum_numbers(orbital_state, quantum_name, quantum_type, value) -> None: + """Adds quantum numbers to the `OrbitalsState` section.""" + if quantum_name == 'ml': # l_quantum_number must be specified + orbital_state.l_quantum_number = 1 + setattr(orbital_state, f'{quantum_name}_quantum_{quantum_type}', value) + + @staticmethod + def add_state( + orbital_state, l_number, ml_number, ms_number, j_number, mj_number + ) -> None: + """Adds l and ml quantum numbers to the `OrbitalsState` section.""" + orbital_state.l_quantum_number = l_number + orbital_state.ml_quantum_number = ml_number + orbital_state.ms_quantum_number = ms_number + orbital_state.j_quantum_number = j_number + orbital_state.mj_quantum_number = mj_number + @pytest.fixture(autouse=True) def orbital_state(self) -> OrbitalsState: return OrbitalsState(n_quantum_number=2) @@ -71,13 +89,16 @@ def test_number_and_symbol( """ Test the number and symbol resolution for each of the quantum numbers defined in the parametrization. """ - if quantum_name == 'ml': # l_quantum_number must be specified - orbital_state.l_quantum_number = 1 - setattr(orbital_state, f'{quantum_name}_quantum_{quantum_type}', value) + # Adding quantum numbers to the `OrbitalsState` section + self.add_quantum_numbers(orbital_state, quantum_name, quantum_type, value) + + # Making sure that the `quantum_type` is assigned resolved_type = orbital_state.resolve_number_and_symbol( quantum_name, quantum_type, self.logger ) assert resolved_type == value + + # Resolving if the counter-type is assigned resolved_countertype = orbital_state.resolve_number_and_symbol( quantum_name, countertype, self.logger ) @@ -108,11 +129,14 @@ def test_degeneracy( """ Test the degeneracy of each of orbital states defined in the parametrization. """ - orbital_state.l_quantum_number = l_quantum_number - orbital_state.ml_quantum_number = ml_quantum_number - orbital_state.j_quantum_number = j_quantum_number - orbital_state.mj_quantum_number = mj_quantum_number - orbital_state.ms_quantum_number = ms_quantum_number + self.add_state( + orbital_state, + l_quantum_number, + ml_quantum_number, + ms_quantum_number, + j_quantum_number, + mj_quantum_number, + ) resolved_degeneracy = orbital_state.resolve_degeneracy() assert resolved_degeneracy == degeneracy @@ -120,9 +144,14 @@ def test_normalize(self, orbital_state): """ Test the normalization of the `OrbitalsState`. Inputs are defined as the quantities of the `OrbitalsState` section. """ + self.add_state(orbital_state, 2, -2, None, None, None) orbital_state.normalize(None, self.logger) - # assert orbital_state.degeneracy == 6 - # assert orbital_state.occupation == 0.0 + assert orbital_state.n_quantum_number == 2 + assert orbital_state.l_quantum_number == 2 + assert orbital_state.l_quantum_symbol == 'd' + assert orbital_state.ml_quantum_number == -2 + assert orbital_state.ml_quantum_symbol == 'xy' + assert orbital_state.degeneracy == 2 class TestCoreHole: @@ -202,22 +231,18 @@ class TestHubbardInteractions: logger = logging.getLogger(__name__) @staticmethod - def add_slater_interactions( - hubbard_interactions, slater_integrals - ) -> HubbardInteractions: + def add_slater_interactions(hubbard_interactions, slater_integrals) -> None: """Adds `slater_integrals` (in eV) to the `HubbardInteractions` section.""" if slater_integrals is not None: hubbard_interactions.slater_integrals = slater_integrals * ureg('eV') - return hubbard_interactions @staticmethod - def add_u_j(hubbard_interactions, u, j) -> HubbardInteractions: + def add_u_j(hubbard_interactions, u, j) -> None: """Adds `u_interaction` and `j_local_exchange_interaction` (in eV) to the `HubbardInteractions` section.""" if u is not None: hubbard_interactions.u_interaction = u * ureg('eV') if j is not None: hubbard_interactions.j_local_exchange_interaction = j * ureg('eV') - return hubbard_interactions @pytest.fixture(autouse=True) def hubbard_interactions(self) -> HubbardInteractions: