Skip to content

Commit

Permalink
Add 'validate' parameter to Mesh (#789)
Browse files Browse the repository at this point in the history
* Add 'validate' parameter to Mesh
useful to suppress validation even when loglevel<=DEBUG

* skip validation on temp mesh in from_meshio

* added a completion debug message to is_valid
  • Loading branch information
gatling-nrl authored Nov 16, 2021
1 parent c828ab6 commit ddae31d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 25 deletions.
2 changes: 1 addition & 1 deletion skfem/io/meshio.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def from_meshio(m,
if meshio_type in v}

# create temporary mesh for matching boundary elements
mtmp = mesh_type(p, t)
mtmp = mesh_type(p, t, validate=False)
bnd_type = BOUNDARY_TYPE_MAPPING[meshio_type]

# parse boundaries from cell_sets
Expand Down
67 changes: 43 additions & 24 deletions skfem/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Mesh:
# However, some algorithms (e.g., adaptive refinement) require switching
# off this behaviour and, hence, this flag exists.
sort_t: bool = False
validate: bool = True # run validation check if log_level<=DEBUG

@property
def p(self):
Expand Down Expand Up @@ -186,7 +187,7 @@ def with_subdomains(self,
**({} if self._subdomains is None else self._subdomains),
**{name: self.elements_satisfying(test)
for name, test in subdomains.items()},
}
},
)

def _encode_point_data(self) -> Dict[str, List[ndarray]]:
Expand Down Expand Up @@ -418,7 +419,43 @@ def __post_init__(self):
self.t = np.ascontiguousarray(self.t)

# run validation
self.is_valid(debug=False)
if self.validate and logger.getEffectiveLevel() <= logging.DEBUG:
self.is_valid()

def is_valid(self, raise_=False) -> bool:
"""Perform expensive mesh validation.
Parameters
----------
raise_: raise an exception if the mesh is invalid.
Returns
-------
bool: True if the mesh is valid.
"""
logger.debug("Running mesh validation.")

# check that there are no duplicate points
tmp = np.ascontiguousarray(self.p.T)
p_unique = np.unique(tmp.view([('', tmp.dtype)] * tmp.shape[1]))
if self.p.shape[1] != p_unique.shape[0]:
msg = "Mesh contains duplicate vertices."
if raise_:
raise ValueError(msg)
logger.debug(msg)
return False

# check that all points are at least in some element
if len(np.setdiff1d(np.arange(self.p.shape[1]),
np.unique(self.t))) > 0:
msg = "Mesh contains a vertex not belonging to any element."
if raise_:
raise ValueError(msg)
logger.debug(msg)
return False

logger.debug("Mesh validation completed with no warnings.")
return True

def __rmatmul__(self, other):
out = self.__matmul__(other)
Expand All @@ -442,26 +479,6 @@ def __matmul__(self, other):
]
raise NotImplementedError

def is_valid(self, debug=True) -> bool:
"""Perform expensive mesh validation (if logging set to DEBUG)."""

if debug or logger.getEffectiveLevel() <= 10:
# check that there are no duplicate points
tmp = np.ascontiguousarray(self.p.T)
if self.p.shape[1] != np.unique(tmp.view([('', tmp.dtype)]
* tmp.shape[1])).shape[0]:
logger.debug("Mesh contains duplicate vertices.")
return False

# check that all points are at least in some element
if len(np.setdiff1d(np.arange(self.p.shape[1]),
np.unique(self.t))) > 0:
logger.debug("Mesh contains a vertex not belonging "
"to any element.")
return False

return True

def __add__(self, other):
"""Join two meshes."""
cls = type(self)
Expand Down Expand Up @@ -593,7 +610,7 @@ def from_mesh(cls, mesh, t: Optional[ndarray] = None):
@classmethod
def init_refdom(cls):
"""Initialize a mesh corresponding to the reference domain."""
return cls(cls.elem.refdom.p, cls.elem.refdom.t)
return cls(cls.elem.refdom.p, cls.elem.refdom.t, validate=False)

def refined(self, times_or_ix: Union[int, ndarray] = 1):
"""Return a refined mesh.
Expand Down Expand Up @@ -760,7 +777,9 @@ def _splitref(self, nrefs: int = 1):
for itr in range(len(x) - 1):
p = np.vstack((p, x[itr + 1].flatten()))

return cls(p, t)
# Always creates a mesh with duplicate verticies. Use validate=False to
# suppress confusing logger DEBUG messages.
return cls(p, t, validate=False)

@staticmethod
def build_entities(t, indices, sort=True):
Expand Down

0 comments on commit ddae31d

Please sign in to comment.