Skip to content

Commit

Permalink
first attempts at getting mixed restricted spaces to work
Browse files Browse the repository at this point in the history
  • Loading branch information
emmarothwell1 committed Feb 20, 2024
1 parent 438628f commit ed5a318
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 70 deletions.
5 changes: 3 additions & 2 deletions firedrake/functionspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ def rec(eles):
spaces = tuple(s.topological for s in flatten(spaces))
# Error checking
for space in spaces:
if type(space) in (impl.FunctionSpace, impl.RealFunctionSpace):
if type(space) in (impl.FunctionSpace, impl.RealFunctionSpace, impl.RestrictedFunctionSpace):
continue
elif type(space) is impl.ProxyFunctionSpace:
elif type(space) in (impl.ProxyFunctionSpace, impl.ProxyRestrictedFunctionSpace):
if space.component is not None:
raise ValueError("Can't make mixed space with %s" % space)
continue
Expand All @@ -296,6 +296,7 @@ def rec(eles):
new = cls.create(new, mesh)
return new


@PETSc.Log.EventDecorator("CreateFunctionSpace")
def RestrictedFunctionSpace(function_space, name=None, boundary_set=[]):
"""Create a :class:`.RestrictedFunctionSpace`.
Expand Down
187 changes: 119 additions & 68 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,74 @@ def collapse(self):
return FunctionSpace(self.mesh(), self.ufl_element())


class RestrictedFunctionSpace(FunctionSpace):
def __init__(self, function_space, name=None, boundary_set=frozenset()):
label = ""
for boundary_domain in boundary_set:
label += str(boundary_domain)
self.boundary_set = frozenset(boundary_set)
super().__init__(function_space._mesh.topology,
function_space.ufl_element(), function_space.name)
self._label = label
self._ufl_function_space = ufl.FunctionSpace(function_space._mesh.ufl_mesh(),
function_space.ufl_element(),
label=self._label)
self.function_space = function_space
self.name = name or (function_space.name or "Restricted" + "_"
+ "_".join(sorted(
[str(i) for i in self.boundary_set])))

def set_shared_data(self):
sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set)
self._shared_data = sdata
self.node_set = sdata.node_set
r"""A :class:`pyop2.types.set.Set` representing the function space nodes."""
self.dof_dset = op2.DataSet(self.node_set, self.shape or 1,
name="%s_nodes_dset" % self.name)
r"""A :class:`pyop2.types.dataset.DataSet` representing the function space
degrees of freedom."""
self.finat_element = create_element(self.ufl_element())
# Used for reconstruction of mixed/component spaces.
# sdata carries real_tensorproduct.
self.real_tensorproduct = sdata.real_tensorproduct
self.extruded = sdata.extruded
self.offset = sdata.offset
self.offset_quotient = sdata.offset_quotient
self.cell_boundary_masks = sdata.cell_boundary_masks
self.interior_facet_boundary_masks = sdata.interior_facet_boundary_masks
self.global_numbering = sdata.global_numbering

def __eq__(self, other):
if not isinstance(other, RestrictedFunctionSpace):
return False
return self.function_space == other.function_space and \
self.boundary_set == other.boundary_set

def __ne__(self, other):
return not self.__eq__(other)

def __hash__(self):
return hash((self.function_space.mesh(), self.function_space.dof_dset,
self.function_space.ufl_element(), self.boundary_set))

def __repr__(self):
return self.__class__.__name__ + "(%r, name=%r, boundary_set=%r)" % (
str(self.function_space), self.name, self.boundary_set)

def __str__(self):
return self.__repr__()

@utils.cached_property
def dof_count(self):
node_count = self.node_count
for sub_domain in self.boundary_set:
node_count -= len(self._shared_data.boundary_nodes(self, sub_domain))
return node_count*self.value_size

def local_to_global_map(self, bcs, lgmap=None):
return lgmap or self.dof_dset.lgmap


class MixedFunctionSpace(object):
r"""A function space on a mixed finite element.
Expand All @@ -822,6 +890,7 @@ def __init__(self, spaces, name=None):
self._ufl_function_space = ufl.FunctionSpace(mesh.ufl_mesh(),
finat.ufl.MixedElement(*[s.ufl_element() for s in spaces]))
self.name = name or "_".join(str(s.name) for s in spaces)
self._label = "_".join(str(s._label) for s in spaces)
self._subspaces = {}
self._mesh = mesh
self.comm = mesh.comm
Expand Down Expand Up @@ -1055,6 +1124,54 @@ def make_dat(self, *args, **kwargs):
return super(ProxyFunctionSpace, self).make_dat(*args, **kwargs)


class ProxyRestrictedFunctionSpace(RestrictedFunctionSpace):
r"""A :class:`RestrictedFunctionSpace` that one can attach extra properties to.
:arg function_space: The function space to be restricted.
:arg name: The name of the restricted function space.
:arg boundary_set: The boundary domains on which boundary conditions will be specified
.. warning::
Users should not build a :class:`ProxyRestrictedFunctionSpace` directly,
it is mostly used as an internal implementation detail.
"""
def __new__(cls, function_space, name=None, boundary_set=frozenset()):
topology = function_space._mesh.topology
self = super(ProxyRestrictedFunctionSpace, cls).__new__(cls)
if function_space._mesh is not topology:
return WithGeometry.create(self, function_space._mesh)
else:
return self

def __repr__(self):
return "%sProxyRestrictedFunctionSpace(%r, name=%r, boundary_set=%r, index=%r, component=%r)" % \
(str(self.identifier).capitalize(),
str(self.function_space),
self.name,
self.boundary_set,
self.index,
self.component)

def __str__(self):
return self.__repr__()

identifier = None
r"""An optional identifier, for debugging purposes."""

no_dats = False
r"""Can this proxy make :class:`pyop2.types.dat.Dat` objects"""

def make_dat(self, *args, **kwargs):
r"""Create a :class:`pyop2.types.dat.Dat`.
:raises ValueError: if :attr:`no_dats` is ``True``.
"""
if self.no_dats:
raise ValueError("Can't build Function on %s function space" % self.identifier)
return super(ProxyRestrictedFunctionSpace, self).make_dat(*args, **kwargs)


def IndexedFunctionSpace(index, space, parent):
r"""Build a new FunctionSpace that remembers it is a particular
subspace of a :class:`MixedFunctionSpace`.
Expand All @@ -1067,6 +1184,8 @@ def IndexedFunctionSpace(index, space, parent):
"""
if space.ufl_element().family() == "Real":
new = RealFunctionSpace(space.mesh(), space.ufl_element(), name=space.name)
elif len(space.boundary_set) > 0:
new = ProxyRestrictedFunctionSpace(space.function_space, name=space.name, boundary_set=space.boundary_set)
else:
new = ProxyFunctionSpace(space.mesh(), space.ufl_element(), name=space.name)
new.index = index
Expand Down Expand Up @@ -1160,74 +1279,6 @@ def local_to_global_map(self, bcs, lgmap=None):
return None


class RestrictedFunctionSpace(FunctionSpace):
def __init__(self, function_space, name=None, boundary_set=frozenset()):
label = ""
for boundary_domain in boundary_set:
label += str(boundary_domain)
self.boundary_set = frozenset(boundary_set)
super().__init__(function_space._mesh.topology,
function_space.ufl_element(), function_space.name)
self._label = label
self._ufl_function_space = ufl.FunctionSpace(function_space._mesh.ufl_mesh(),
function_space.ufl_element(),
label=self._label)
self.function_space = function_space
self.name = name or (function_space.name or "Restricted" + "_"
+ "_".join(sorted(
[str(i) for i in self.boundary_set])))

def set_shared_data(self):
sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set)
self._shared_data = sdata
self.node_set = sdata.node_set
r"""A :class:`pyop2.types.set.Set` representing the function space nodes."""
self.dof_dset = op2.DataSet(self.node_set, self.shape or 1,
name="%s_nodes_dset" % self.name)
r"""A :class:`pyop2.types.dataset.DataSet` representing the function space
degrees of freedom."""
self.finat_element = create_element(self.ufl_element())
# Used for reconstruction of mixed/component spaces.
# sdata carries real_tensorproduct.
self.real_tensorproduct = sdata.real_tensorproduct
self.extruded = sdata.extruded
self.offset = sdata.offset
self.offset_quotient = sdata.offset_quotient
self.cell_boundary_masks = sdata.cell_boundary_masks
self.interior_facet_boundary_masks = sdata.interior_facet_boundary_masks
self.global_numbering = sdata.global_numbering

def __eq__(self, other):
if not isinstance(other, RestrictedFunctionSpace):
return False
return self.function_space == other.function_space and \
self.boundary_set == other.boundary_set

def __ne__(self, other):
return not self.__eq__(other)

def __hash__(self):
return hash((self.function_space.mesh(), self.function_space.dof_dset,
self.function_space.ufl_element(), self.boundary_set))

def __repr__(self):
return self.__class__.__name__ + "(%r, name=%r, boundary_set=%r)" % (
str(self.function_space), self.name, self.boundary_set)

def __str__(self):
return self.__repr__()

@utils.cached_property
def dof_count(self):
node_count = self.node_count
for sub_domain in self.boundary_set:
node_count -= len(self._shared_data.boundary_nodes(self, sub_domain))
return node_count*self.value_size

def local_to_global_map(self, bcs, lgmap=None):
return lgmap or self.dof_dset.lgmap


@dataclass
class FunctionSpaceCargo:
"""Helper class carrying data for a :class:`WithGeometryBase`.
Expand Down

0 comments on commit ed5a318

Please sign in to comment.