Skip to content

Commit

Permalink
Emmarothwell1/restricted _function_space: Implement a RestrictedFunct…
Browse files Browse the repository at this point in the history
…ionSpace class. (#3215)

* If `restrict=True`, solver replaces test and trial functions with those on the RestrictedFunctionSpace.

* If `restrict=True`, eigenproblem solver takes the constrained DoFs out of the system (default).

Co-authored-by: David A. Ham <[email protected]>
Co-authored-by: ksagiyam <[email protected]>
  • Loading branch information
3 people authored Apr 26, 2024
1 parent 486fa41 commit 02bdb5f
Show file tree
Hide file tree
Showing 19 changed files with 745 additions and 147 deletions.
9 changes: 6 additions & 3 deletions demos/eigenvalues_QG_basinmodes/qgbasinmodes.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ Oceanic Basin Modes: Quasi-Geostrophic approach
.. rst-class:: emphasis

This tutorial was contributed by Christine Kaufhold and `Francis
Poulin <mailto:[email protected]>`__.
Poulin <mailto:[email protected]>`__. The tutorial was later updated by
Emma Rothwell to add in the restrict flag in the LinearEigenproblem.

As a continuation of the Quasi-Geostrophic (QG) model described in the other
tutorial, we will now see how we can use Firedrake to compute the spatial
Expand Down Expand Up @@ -138,12 +139,14 @@ We define the Test Function :math:`\phi` and the Trial Function

To build the weak formulation of our equation we need to build two PETSc
matrices in the form of a generalized eigenvalue problem,
:math:`A\psi = \lambda M\psi` ::
:math:`A\psi = \lambda M\psi`. This eigenproblem takes `restrict=True` to help
users to avoid convergence failures by removing eigenvalues on the
boundary, while preserving the original function space for the eigenmodes. ::

eigenproblem = LinearEigenproblem(
A=beta*phi*psi.dx(0)*dx,
M=-inner(grad(psi), grad(phi))*dx - F*psi*phi*dx,
bcs=bc)
bcs=bc, restrict=True)

Next we program our eigenvalue solver through the PETSc options system. The
first is specifying that we have an generalized eigenvalue problem that is
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@
\sphinxDUC{22C5}{$\cdot$}
\sphinxDUC{25A3}{$\boxdot$}
\sphinxDUC{03BB}{$\lambda$}
\sphinxDUC{0393}{$\Gamma$}
% Sphinx equivalent of
% \DeclareUnicodeCharacter{}{}
Expand Down
8 changes: 4 additions & 4 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u
self._ad_u = self.u_restrict
self._ad_bcs = self.bcs
self._ad_J = self.J
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u,
TrialFunction(self.u.function_space()))
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
except (TypeError, NotImplementedError):
self._ad_adj_F = None
Expand Down Expand Up @@ -130,7 +130,7 @@ def _ad_problem_clone(self, problem, dependencies):
_ad_count_map[J_replace_map[coeff]] = coeff.count()

nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map),
F_replace_map[problem.u],
F_replace_map[problem.u_restrict],
bcs=problem.bcs,
J=replace(problem.J, J_replace_map))
nlvp._ad_count_map_update(_ad_count_map)
Expand Down
38 changes: 21 additions & 17 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,15 +1631,15 @@ def _as_global_kernel_arg(self, tsfc_arg):
# TODO Make singledispatchmethod with Python 3.8
return _as_global_kernel_arg(tsfc_arg, self)

def _get_map_arg(self, finat_element):
def _get_map_arg(self, finat_element, boundary_set):
"""Get the appropriate map argument for the given FInAT element.
:arg finat_element: A FInAT element.
:returns: A :class:`op2.MapKernelArg` instance corresponding to
the given FInAT element. This function uses a cache to ensure
that PyOP2 knows when it can reuse maps.
"""
key = self._get_map_id(finat_element)
key = self._get_map_id(finat_element), boundary_set

try:
return self._map_arg_cache[key]
Expand Down Expand Up @@ -1679,28 +1679,28 @@ def _get_dim(self, finat_element):
else:
return (1,)

def _make_dat_global_kernel_arg(self, finat_element, index=None):
def _make_dat_global_kernel_arg(self, finat_element, boundary_set, index=None):
if isinstance(finat_element, finat.EnrichedElement) and finat_element.is_mixed:
assert index is None
subargs = tuple(self._make_dat_global_kernel_arg(subelem.element)
subargs = tuple(self._make_dat_global_kernel_arg(subelem.element, boundary_set)
for subelem in finat_element.elements)
return op2.MixedDatKernelArg(subargs)
else:
dim = self._get_dim(finat_element)
map_arg = self._get_map_arg(finat_element)
map_arg = self._get_map_arg(finat_element, boundary_set)
return op2.DatKernelArg(dim, map_arg, index)

def _make_mat_global_kernel_arg(self, relem, celem):
def _make_mat_global_kernel_arg(self, relem, celem, rbset, cbset):
if any(isinstance(e, finat.EnrichedElement) and e.is_mixed for e in {relem, celem}):
subargs = tuple(self._make_mat_global_kernel_arg(rel.element, cel.element)
subargs = tuple(self._make_mat_global_kernel_arg(rel.element, cel.element, rbset, cbset)
for rel, cel in product(relem.elements, celem.elements))
shape = len(relem.elements), len(celem.elements)
return op2.MixedMatKernelArg(subargs, shape)
else:
# PyOP2 matrix objects have scalar dims so we flatten them here
rdim = numpy.prod(self._get_dim(relem), dtype=int)
cdim = numpy.prod(self._get_dim(celem), dtype=int)
map_args = self._get_map_arg(relem), self._get_map_arg(celem)
map_args = self._get_map_arg(relem, rbset), self._get_map_arg(celem, cbset)
return op2.MatKernelArg((((rdim, cdim),),), map_args, unroll=self._unroll)

@staticmethod
Expand Down Expand Up @@ -1737,25 +1737,29 @@ def _as_global_kernel_arg_output(_, self):
if V.ufl_element().family() == "Real":
return op2.GlobalKernelArg((1,))
else:
return self._make_dat_global_kernel_arg(create_element(V.ufl_element()))
return self._make_dat_global_kernel_arg(create_element(V.ufl_element()), V.boundary_set)
elif rank == 2:
if all(V.ufl_element().family() == "Real" for V in Vs):
return op2.GlobalKernelArg((1,))
elif any(V.ufl_element().family() == "Real" for V in Vs):
el, = (create_element(V.ufl_element()) for V in Vs
if V.ufl_element().family() != "Real")
return self._make_dat_global_kernel_arg(el)
for V in Vs:
if V.ufl_element().family() != "Real":
el = create_element(V.ufl_element())
boundary_set = V.boundary_set
break
return self._make_dat_global_kernel_arg(el, boundary_set)
else:
rel, cel = (create_element(V.ufl_element()) for V in Vs)
return self._make_mat_global_kernel_arg(rel, cel)
rbset, cbset = (V.boundary_set for V in Vs)
return self._make_mat_global_kernel_arg(rel, cel, rbset, cbset)
else:
raise AssertionError


@_as_global_kernel_arg.register(kernel_args.CoordinatesKernelArg)
def _as_global_kernel_arg_coordinates(_, self):
finat_element = create_element(self._mesh.ufl_coordinate_element())
return self._make_dat_global_kernel_arg(finat_element)
return self._make_dat_global_kernel_arg(finat_element, self._mesh.coordinates.function_space().boundary_set)


@_as_global_kernel_arg.register(kernel_args.CoefficientKernelArg)
Expand All @@ -1773,7 +1777,7 @@ def _as_global_kernel_arg_coefficient(_, self):
return op2.GlobalKernelArg((ufl_element.value_size,))
else:
finat_element = create_element(ufl_element)
return self._make_dat_global_kernel_arg(finat_element, index)
return self._make_dat_global_kernel_arg(finat_element, V.boundary_set, index)


@_as_global_kernel_arg.register(kernel_args.ConstantKernelArg)
Expand All @@ -1788,7 +1792,7 @@ def _as_global_kernel_arg_cell_sizes(_, self):
# this mirrors tsfc.kernel_interface.firedrake_loopy.KernelBuilder.set_cell_sizes
ufl_element = finat.ufl.FiniteElement("P", self._mesh.ufl_cell(), 1)
finat_element = create_element(ufl_element)
return self._make_dat_global_kernel_arg(finat_element)
return self._make_dat_global_kernel_arg(finat_element, self._mesh.coordinates.function_space().boundary_set)


@_as_global_kernel_arg.register(kernel_args.ExteriorFacetKernelArg)
Expand All @@ -1815,7 +1819,7 @@ def _as_global_kernel_arg_cell_orientations(_, self):
# this mirrors firedrake.mesh.MeshGeometry.init_cell_orientations
ufl_element = finat.ufl.FiniteElement("DG", cell=self._mesh.ufl_cell(), degree=0)
finat_element = create_element(ufl_element)
return self._make_dat_global_kernel_arg(finat_element)
return self._make_dat_global_kernel_arg(finat_element, self._mesh.coordinates.function_space().boundary_set)


@_as_global_kernel_arg.register(LayerCountKernelArg)
Expand Down
2 changes: 2 additions & 0 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ def __init__(self, V, g, sub_domain, method=None):
warnings.simplefilter('always', DeprecationWarning)
warnings.warn("Selecting a bcs method is deprecated. Only topological association is supported",
DeprecationWarning)
if len(V.boundary_set) and sub_domain not in V.boundary_set:
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
super().__init__(V, sub_domain)
if len(V) > 1:
raise ValueError("Cannot apply boundary conditions on mixed spaces directly.\n"
Expand Down
2 changes: 1 addition & 1 deletion firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,7 +1437,7 @@ def _get_dm_for_checkpointing(self, tV):
sd_key = self._get_shared_data_key_for_checkpointing(tV.mesh(), tV.ufl_element())
if isinstance(tV.ufl_element(), (finat.ufl.VectorElement, finat.ufl.TensorElement)):
nodes_per_entity, real_tensorproduct, block_size = sd_key
global_numbering = tV.mesh().create_section(nodes_per_entity, real_tensorproduct, block_size=block_size)
global_numbering, _ = tV.mesh().create_section(nodes_per_entity, real_tensorproduct, block_size=block_size)
topology_dm = tV.mesh().topology_dm
dm = PETSc.DMShell().create(tV.mesh()._comm)
dm.setPointSF(topology_dm.getPointSF())
Expand Down
Loading

0 comments on commit 02bdb5f

Please sign in to comment.