Skip to content

Commit

Permalink
Added typing for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Feb 19, 2024
1 parent 88c6240 commit 4ef8eca
Showing 1 changed file with 28 additions and 29 deletions.
57 changes: 28 additions & 29 deletions simulationdataschema/model_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#

import numpy as np
import pint
from typing import Optional
from ase.dft.kpoints import monkhorst_pack

from nomad.units import ureg
Expand Down Expand Up @@ -101,13 +103,6 @@ class Mesh(ArchiveSection):
""",
)

n_points = Quantity(
type=np.int32,
description="""
Total number of points in the mesh, accounting for the multiplicities.
""",
)

grid = Quantity(
type=np.int32,
shape=["dimensionality"],
Expand All @@ -128,27 +123,27 @@ class Mesh(ArchiveSection):
type=np.float64,
shape=["*"],
description="""
The amount of times the same point reappears. These are accounted for in `n_points`.
A value larger than 1, typically indicates a symmtery operation that was applied to the mesh.
The amount of times the same point reappears. A value larger than 1, typically indicates
a symmtery operation that was applied to the `Mesh`.
""",
)

# ! is this description correct?
weights = Quantity(
type=np.float64,
shape=["*"],
description="""
The frequency of times the same point reappears.
A value larger than 1, typically indicates a symmtery operation that was applied to the mesh.
The frequency of times the same point reappears. A value larger than 1, typically
indicates a symmtery operation that was applied to the mesh.
""",
)

def normalize(self, archive, logger):
def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

self.dimensionality = 3 if not self.dimensionality else self.dimensionality
if self.grid is None:
return
self.n_points = np.prod(self.grid) if not self.n_points else self.n_points


class LinePathSegment(ArchiveSection):
Expand All @@ -160,7 +155,7 @@ class LinePathSegment(ArchiveSection):
start_point = Quantity(
type=str,
description="""
Name of the hihg-symmetry starting point of the line path segment.
Name of the high-symmetry starting point of the line path segment.
""",
)

Expand All @@ -180,12 +175,15 @@ class LinePathSegment(ArchiveSection):

points = Quantity(
type=np.float64,
shape=["*", 3],
shape=["n_points", 3],
description="""
List of all the points in the line path segment.
""",
)

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)


class KMesh(Mesh):
"""
Expand Down Expand Up @@ -239,20 +237,21 @@ class KMesh(Mesh):

line_path_segments = SubSection(sub_section=LinePathSegment.m_def, repeats=True)

def get_k_line_density(self, reciprocal_lattice_vectors):
def get_k_line_density(
self, reciprocal_lattice_vectors: pint.Quantity
) -> Optional[np.float64]:
"""
Calculates the k-line density of the KMesh. This quantity is used to have an idea
of the precision of the KMesh sampling.
Gets the k-line density of the `KMesh`. This quantity is used as a precision measure
of the `KMesh` sampling.
Args:
reciprocal_lattice_vectors (np.array): Reciprocal lattice vectors of the
atomic cell.
reciprocal_lattice_vectors (np.array): Reciprocal lattice vectors of the atomic cell.
Returns:
(np.float64): The k-line density of the KMesh.
(np.float64): The k-line density of the `KMesh`.
"""
if reciprocal_lattice_vectors is None:
return
return None
if len(reciprocal_lattice_vectors) != 3 or len(self.grid) != 3:
return None

Expand All @@ -264,7 +263,7 @@ def get_k_line_density(self, reciprocal_lattice_vectors):
]
)

def normalize(self, archive, logger):
def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# If `grid` is not defined, we do not normalize the KMesh
Expand All @@ -281,12 +280,13 @@ def normalize(self, archive, logger):
pass # this is a quick workaround: k_mesh.grid should be symmetry reduced

# Calculate k_line_density for precision
model_system = self.m_xpath("m_parent.m_parent.model_system[-1]", dict=False)
if self.k_line_density is None and model_system is not None:
model_systems = self.m_xpath("m_parent.m_parent.model_system", dict=False)
if self.k_line_density is not None:
return
for model_system in model_systems:
if not model_system.is_representative:
logger.warning(
"The last ModelSystem was not found to be representative. We will not "
"extract k_line_density."
"The last ModelSystem was not found to be representative. We will not extract k_line_density."
)
return
if model_system.type != "bulk":
Expand All @@ -297,8 +297,7 @@ def normalize(self, archive, logger):
atomic_cell = model_system.atomic_cell
if atomic_cell is None:
logger.warning(
"Atomic cell was not found in the ModelSystem. We will not extract "
"k_line_density."
"Atomic cell was not found in the ModelSystem. We will not extract k_line_density."
)
return
ase_atoms = atomic_cell[0].to_ase_atoms(logger)
Expand Down

0 comments on commit 4ef8eca

Please sign in to comment.