Skip to content

Commit

Permalink
hex: enable interior facet integration
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Nov 5, 2024
1 parent f401d80 commit 89b86b4
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 0 deletions.
20 changes: 20 additions & 0 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,16 @@ def _as_global_kernel_arg_interior_facet(_, self):
return op2.DatKernelArg((2,))


@_as_global_kernel_arg.register(kernel_args.ExteriorFacetOrientationKernelArg)
def _as_global_kernel_arg_exterior_facet_orientation(_, self):
return op2.DatKernelArg((1,))


@_as_global_kernel_arg.register(kernel_args.InteriorFacetOrientationKernelArg)
def _as_global_kernel_arg_interior_facet_orientation(_, self):
return op2.DatKernelArg((2,))


@_as_global_kernel_arg.register(CellFacetKernelArg)
def _as_global_kernel_arg_cell_facet(_, self):
if self._mesh.extruded:
Expand Down Expand Up @@ -2053,6 +2063,16 @@ def _as_parloop_arg_interior_facet(_, self):
return op2.DatParloopArg(self._mesh.interior_facets.local_facet_dat)


@_as_parloop_arg.register(kernel_args.ExteriorFacetOrientationKernelArg)
def _as_parloop_arg_exterior_facet_orientation(_, self):
return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_orientation_dat)


@_as_parloop_arg.register(kernel_args.InteriorFacetOrientationKernelArg)
def _as_parloop_arg_interior_facet_orientation(_, self):
return op2.DatParloopArg(self._mesh.interior_facets.local_facet_orientation_dat)


@_as_parloop_arg.register(CellFacetKernelArg)
def _as_parloop_arg_cell_facet(_, self):
return op2.DatParloopArg(self._mesh.cell_to_facets)
Expand Down
40 changes: 40 additions & 0 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from firedrake.adjoint_utils import MeshGeometryMixin
from pyadjoint import stop_annotating
import gem

try:
import netgen
Expand All @@ -44,6 +45,7 @@
ngsPETSc = None
# Only for docstring
import mpi4py # noqa: F401
from tsfc.finatinterface import as_fiat_cell


__all__ = [
Expand Down Expand Up @@ -317,6 +319,44 @@ def facet_cell_map(self):
return op2.Map(self.set, self.mesh.cell_set, self._rank, self.facet_cell,
"facet_to_cell_map")

@utils.cached_property
def local_facet_orientation_dat(self):
"""Dat for the local facet orientations."""
dtype = gem.uint_type
# Make a map from cell to facet orientations.
fiat_cell = as_fiat_cell(self.mesh.ufl_cell())
topo = fiat_cell.topology
num_entities = [0]
for d in range(len(topo)):
num_entities.append(len(topo[d]))
offsets = np.cumsum(num_entities)
local_facet_start = offsets[-3]
local_facet_end = offsets[-2]
map_from_cell_to_facet_orientations = self.mesh.entity_orientations[:, local_facet_start:local_facet_end]
# Make output data;
# this is a map from an exterior/interior facet to the corresponding local facet orientation/orientations.
# Halo data are required by design, but not actually used.
# -- Reshape as (-1, self._rank) to uniformly handle exterior and interior facets.
data = np.empty_like(self.local_facet_dat.data_ro_with_halos).reshape((-1, self._rank))
data.fill(np.iinfo(dtype).max)
# Set local facet orientations on the block corresponding to the owned facets; i.e., data[:shape[0], :] below.
local_facets = self.local_facet_dat.data_ro # do not need halos.
# -- Reshape as (-1, self._rank) to uniformly handle exterior and interior facets.
local_facets = local_facets.reshape((-1, self._rank))
shape = local_facets.shape
map_from_owned_facet_to_cells = self.facet_cell[:shape[0], :]
data[:shape[0], :] = np.take_along_axis(
map_from_cell_to_facet_orientations[map_from_owned_facet_to_cells],
local_facets.reshape(shape + (1, )), # reshape as required by take_along_axis.
axis=2,
).reshape(shape)
return op2.Dat(
self.local_facet_dat.dataset,
data,
dtype,
f"{self.mesh.name}_{self.kind}_local_facet_orientation"
)


@PETSc.Log.EventDecorator()
def _from_gmsh(filename, comm=None):
Expand Down
3 changes: 3 additions & 0 deletions firedrake/mg/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def compile_element(expression, dual_space=None, parameters=None,

config = dict(interface=builder,
ufl_cell=cell,
integral_type="cell",
point_indices=(),
point_expr=point,
argument_multiindices=argument_multiindices,
Expand Down Expand Up @@ -537,6 +538,7 @@ def dg_injection_kernel(Vf, Vc, ncell):
integration_dim, entity_ids = lower_integral_type(Vfe.cell, "cell")
macro_cfg = dict(interface=macro_builder,
ufl_cell=Vf.ufl_cell(),
integral_type="cell",
integration_dim=integration_dim,
entity_ids=entity_ids,
index_cache=index_cache,
Expand Down Expand Up @@ -573,6 +575,7 @@ def dg_injection_kernel(Vf, Vc, ncell):

coarse_cfg = dict(interface=coarse_builder,
ufl_cell=Vc.ufl_cell(),
integral_type="cell",
integration_dim=integration_dim,
entity_ids=entity_ids,
index_cache=index_cache,
Expand Down
1 change: 1 addition & 0 deletions firedrake/pointeval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def compile_element(expression, coordinates, parameters=None):

config = dict(interface=builder,
ufl_cell=extract_unique_domain(coordinates).ufl_cell(),
integral_type="cell",
point_indices=(),
point_expr=point,
scalar_type=utils.ScalarType)
Expand Down
1 change: 1 addition & 0 deletions firedrake/pointquery_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def to_reference_coords_newton_step(ufl_coordinate_element, parameters, x0_dtype
context = tsfc.fem.GemPointContext(
interface=builder,
ufl_cell=cell,
integral_type="cell",
point_indices=(),
point_expr=point,
scalar_type=parameters["scalar_type"]
Expand Down
35 changes: 35 additions & 0 deletions tests/regression/test_integral_hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,38 @@ def test_integral_hex_exterior_facet(mesh_from_file, family):
x, y, z = SpatialCoordinate(mesh)
f = Function(V).interpolate(2 * x + 3 * y * y + 4 * z * z * z)
assert abs(assemble(f * ds) - (2 + 4 + 2 + 5 + 2 + 6)) < 1.e-10


@pytest.mark.parallel(nprocs=2)
@pytest.mark.parametrize('mesh_from_file', [False, True])
@pytest.mark.parametrize('family', ["Q", "DQ"])
def test_integral_hex_interior_facet(mesh_from_file, family):
if mesh_from_file:
mesh = Mesh(join(cwd, "..", "meshes", "cube_hex.msh"))
else:
mesh = UnitCubeMesh(2, 3, 5, hexahedral=True)
V = FunctionSpace(mesh, family, 3)
x, y, z = SpatialCoordinate(mesh)
f = Function(V).interpolate(2 * x + 3 * y * y + 4 * z * z * z)
assert assemble((f('+') - f('-'))**2 * dS)**0.5 < 1.e-14


@pytest.mark.parallel(nprocs=2)
@pytest.mark.parametrize('mesh_from_file', [False, True])
def test_integral_hex_interior_facet_solve(mesh_from_file):
if mesh_from_file:
mesh = Mesh(join(cwd, "..", "meshes", "cube_hex.msh"))
else:
mesh = UnitCubeMesh(2, 3, 5, hexahedral=True)
V = FunctionSpace(mesh, "Q", 1)
x, y, z = SpatialCoordinate(mesh)
f = Function(V).interpolate(2 * x + 3 * y + 5 * z)
u = TrialFunction(V)
v = TestFunction(V)
a = inner(u('+'), v('+')) * dS(degree=3)
L = inner(f('+'), v('-')) * dS(degree=3)
bc = DirichletBC(V, f, "on_boundary")
sol = Function(V)
solve(a == L, sol, bcs=[bc])
err = assemble((sol - f)**2 * dx)**0.5
assert err < 1.e-14

0 comments on commit 89b86b4

Please sign in to comment.