Skip to content

Commit

Permalink
Fix OrbitalsState testing
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Mar 4, 2024
1 parent 1dfbdfb commit ffa47bc
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 19 deletions.
5 changes: 2 additions & 3 deletions simulationdataschema/atoms_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,8 @@ def normalize(self, archive, logger) -> None:
setattr(self, f'{quantum_name}_quantum_{quantum_type}', quantity)

# Resolve the degeneracy
self.degeneracy = (
self.resolve_degeneracy() if self.degeneracy is None else self.degeneracy
)
if self.degeneracy is None:
self.degeneracy = self.resolve_degeneracy()


class CoreHole(ArchiveSection):
Expand Down
57 changes: 41 additions & 16 deletions tests/test_atoms_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ class TestOrbitalsState:

logger = logging.getLogger(__name__)

@staticmethod
def add_quantum_numbers(orbital_state, quantum_name, quantum_type, value) -> None:
"""Adds quantum numbers to the `OrbitalsState` section."""
if quantum_name == 'ml': # l_quantum_number must be specified
orbital_state.l_quantum_number = 1
setattr(orbital_state, f'{quantum_name}_quantum_{quantum_type}', value)

@staticmethod
def add_state(
orbital_state, l_number, ml_number, ms_number, j_number, mj_number
) -> None:
"""Adds l and ml quantum numbers to the `OrbitalsState` section."""
orbital_state.l_quantum_number = l_number
orbital_state.ml_quantum_number = ml_number
orbital_state.ms_quantum_number = ms_number
orbital_state.j_quantum_number = j_number
orbital_state.mj_quantum_number = mj_number

@pytest.fixture(autouse=True)
def orbital_state(self) -> OrbitalsState:
return OrbitalsState(n_quantum_number=2)
Expand Down Expand Up @@ -71,13 +89,16 @@ def test_number_and_symbol(
"""
Test the number and symbol resolution for each of the quantum numbers defined in the parametrization.
"""
if quantum_name == 'ml': # l_quantum_number must be specified
orbital_state.l_quantum_number = 1
setattr(orbital_state, f'{quantum_name}_quantum_{quantum_type}', value)
# Adding quantum numbers to the `OrbitalsState` section
self.add_quantum_numbers(orbital_state, quantum_name, quantum_type, value)

# Making sure that the `quantum_type` is assigned
resolved_type = orbital_state.resolve_number_and_symbol(
quantum_name, quantum_type, self.logger
)
assert resolved_type == value

# Resolving if the counter-type is assigned
resolved_countertype = orbital_state.resolve_number_and_symbol(
quantum_name, countertype, self.logger
)
Expand Down Expand Up @@ -108,21 +129,29 @@ def test_degeneracy(
"""
Test the degeneracy of each of orbital states defined in the parametrization.
"""
orbital_state.l_quantum_number = l_quantum_number
orbital_state.ml_quantum_number = ml_quantum_number
orbital_state.j_quantum_number = j_quantum_number
orbital_state.mj_quantum_number = mj_quantum_number
orbital_state.ms_quantum_number = ms_quantum_number
self.add_state(
orbital_state,
l_quantum_number,
ml_quantum_number,
ms_quantum_number,
j_quantum_number,
mj_quantum_number,
)
resolved_degeneracy = orbital_state.resolve_degeneracy()
assert resolved_degeneracy == degeneracy

def test_normalize(self, orbital_state):
"""
Test the normalization of the `OrbitalsState`. Inputs are defined as the quantities of the `OrbitalsState` section.
"""
self.add_state(orbital_state, 2, -2, None, None, None)
orbital_state.normalize(None, self.logger)
# assert orbital_state.degeneracy == 6
# assert orbital_state.occupation == 0.0
assert orbital_state.n_quantum_number == 2
assert orbital_state.l_quantum_number == 2
assert orbital_state.l_quantum_symbol == 'd'
assert orbital_state.ml_quantum_number == -2
assert orbital_state.ml_quantum_symbol == 'xy'
assert orbital_state.degeneracy == 2


class TestCoreHole:
Expand Down Expand Up @@ -202,22 +231,18 @@ class TestHubbardInteractions:
logger = logging.getLogger(__name__)

@staticmethod
def add_slater_interactions(
hubbard_interactions, slater_integrals
) -> HubbardInteractions:
def add_slater_interactions(hubbard_interactions, slater_integrals) -> None:
"""Adds `slater_integrals` (in eV) to the `HubbardInteractions` section."""
if slater_integrals is not None:
hubbard_interactions.slater_integrals = slater_integrals * ureg('eV')
return hubbard_interactions

@staticmethod
def add_u_j(hubbard_interactions, u, j) -> HubbardInteractions:
def add_u_j(hubbard_interactions, u, j) -> None:
"""Adds `u_interaction` and `j_local_exchange_interaction` (in eV) to the `HubbardInteractions` section."""
if u is not None:
hubbard_interactions.u_interaction = u * ureg('eV')
if j is not None:
hubbard_interactions.j_local_exchange_interaction = j * ureg('eV')
return hubbard_interactions

@pytest.fixture(autouse=True)
def hubbard_interactions(self) -> HubbardInteractions:
Expand Down

0 comments on commit ffa47bc

Please sign in to comment.