diff --git a/firedrake/functionspace.py b/firedrake/functionspace.py index 29546acdf1..ca67c2ad38 100644 --- a/firedrake/functionspace.py +++ b/firedrake/functionspace.py @@ -282,9 +282,9 @@ def rec(eles): spaces = tuple(s.topological for s in flatten(spaces)) # Error checking for space in spaces: - if type(space) in (impl.FunctionSpace, impl.RealFunctionSpace): + if type(space) in (impl.FunctionSpace, impl.RealFunctionSpace, impl.RestrictedFunctionSpace): continue - elif type(space) is impl.ProxyFunctionSpace: + elif type(space) in (impl.ProxyFunctionSpace, impl.ProxyRestrictedFunctionSpace): if space.component is not None: raise ValueError("Can't make mixed space with %s" % space) continue @@ -296,6 +296,7 @@ def rec(eles): new = cls.create(new, mesh) return new + @PETSc.Log.EventDecorator("CreateFunctionSpace") def RestrictedFunctionSpace(function_space, name=None, boundary_set=[]): """Create a :class:`.RestrictedFunctionSpace`. diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index d63f84d9aa..3136302975 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -799,6 +799,74 @@ def collapse(self): return FunctionSpace(self.mesh(), self.ufl_element()) +class RestrictedFunctionSpace(FunctionSpace): + def __init__(self, function_space, name=None, boundary_set=frozenset()): + label = "" + for boundary_domain in boundary_set: + label += str(boundary_domain) + self.boundary_set = frozenset(boundary_set) + super().__init__(function_space._mesh.topology, + function_space.ufl_element(), function_space.name) + self._label = label + self._ufl_function_space = ufl.FunctionSpace(function_space._mesh.ufl_mesh(), + function_space.ufl_element(), + label=self._label) + self.function_space = function_space + self.name = name or (function_space.name or "Restricted" + "_" + + "_".join(sorted( + [str(i) for i in self.boundary_set]))) + + def set_shared_data(self): + sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set) + self._shared_data = sdata + self.node_set = sdata.node_set + r"""A :class:`pyop2.types.set.Set` representing the function space nodes.""" + self.dof_dset = op2.DataSet(self.node_set, self.shape or 1, + name="%s_nodes_dset" % self.name) + r"""A :class:`pyop2.types.dataset.DataSet` representing the function space + degrees of freedom.""" + self.finat_element = create_element(self.ufl_element()) + # Used for reconstruction of mixed/component spaces. + # sdata carries real_tensorproduct. + self.real_tensorproduct = sdata.real_tensorproduct + self.extruded = sdata.extruded + self.offset = sdata.offset + self.offset_quotient = sdata.offset_quotient + self.cell_boundary_masks = sdata.cell_boundary_masks + self.interior_facet_boundary_masks = sdata.interior_facet_boundary_masks + self.global_numbering = sdata.global_numbering + + def __eq__(self, other): + if not isinstance(other, RestrictedFunctionSpace): + return False + return self.function_space == other.function_space and \ + self.boundary_set == other.boundary_set + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.function_space.mesh(), self.function_space.dof_dset, + self.function_space.ufl_element(), self.boundary_set)) + + def __repr__(self): + return self.__class__.__name__ + "(%r, name=%r, boundary_set=%r)" % ( + str(self.function_space), self.name, self.boundary_set) + + def __str__(self): + return self.__repr__() + + @utils.cached_property + def dof_count(self): + node_count = self.node_count + for sub_domain in self.boundary_set: + node_count -= len(self._shared_data.boundary_nodes(self, sub_domain)) + return node_count*self.value_size + + def local_to_global_map(self, bcs, lgmap=None): + return lgmap or self.dof_dset.lgmap + + class MixedFunctionSpace(object): r"""A function space on a mixed finite element. @@ -822,6 +890,7 @@ def __init__(self, spaces, name=None): self._ufl_function_space = ufl.FunctionSpace(mesh.ufl_mesh(), finat.ufl.MixedElement(*[s.ufl_element() for s in spaces])) self.name = name or "_".join(str(s.name) for s in spaces) + self._label = "_".join(str(s._label) for s in spaces) self._subspaces = {} self._mesh = mesh self.comm = mesh.comm @@ -1055,6 +1124,54 @@ def make_dat(self, *args, **kwargs): return super(ProxyFunctionSpace, self).make_dat(*args, **kwargs) +class ProxyRestrictedFunctionSpace(RestrictedFunctionSpace): + r"""A :class:`RestrictedFunctionSpace` that one can attach extra properties to. + + :arg function_space: The function space to be restricted. + :arg name: The name of the restricted function space. + :arg boundary_set: The boundary domains on which boundary conditions will be specified + + .. warning:: + + Users should not build a :class:`ProxyRestrictedFunctionSpace` directly, + it is mostly used as an internal implementation detail. + """ + def __new__(cls, function_space, name=None, boundary_set=frozenset()): + topology = function_space._mesh.topology + self = super(ProxyRestrictedFunctionSpace, cls).__new__(cls) + if function_space._mesh is not topology: + return WithGeometry.create(self, function_space._mesh) + else: + return self + + def __repr__(self): + return "%sProxyRestrictedFunctionSpace(%r, name=%r, boundary_set=%r, index=%r, component=%r)" % \ + (str(self.identifier).capitalize(), + str(self.function_space), + self.name, + self.boundary_set, + self.index, + self.component) + + def __str__(self): + return self.__repr__() + + identifier = None + r"""An optional identifier, for debugging purposes.""" + + no_dats = False + r"""Can this proxy make :class:`pyop2.types.dat.Dat` objects""" + + def make_dat(self, *args, **kwargs): + r"""Create a :class:`pyop2.types.dat.Dat`. + + :raises ValueError: if :attr:`no_dats` is ``True``. + """ + if self.no_dats: + raise ValueError("Can't build Function on %s function space" % self.identifier) + return super(ProxyRestrictedFunctionSpace, self).make_dat(*args, **kwargs) + + def IndexedFunctionSpace(index, space, parent): r"""Build a new FunctionSpace that remembers it is a particular subspace of a :class:`MixedFunctionSpace`. @@ -1067,6 +1184,8 @@ def IndexedFunctionSpace(index, space, parent): """ if space.ufl_element().family() == "Real": new = RealFunctionSpace(space.mesh(), space.ufl_element(), name=space.name) + elif len(space.boundary_set) > 0: + new = ProxyRestrictedFunctionSpace(space.function_space, name=space.name, boundary_set=space.boundary_set) else: new = ProxyFunctionSpace(space.mesh(), space.ufl_element(), name=space.name) new.index = index @@ -1160,74 +1279,6 @@ def local_to_global_map(self, bcs, lgmap=None): return None -class RestrictedFunctionSpace(FunctionSpace): - def __init__(self, function_space, name=None, boundary_set=frozenset()): - label = "" - for boundary_domain in boundary_set: - label += str(boundary_domain) - self.boundary_set = frozenset(boundary_set) - super().__init__(function_space._mesh.topology, - function_space.ufl_element(), function_space.name) - self._label = label - self._ufl_function_space = ufl.FunctionSpace(function_space._mesh.ufl_mesh(), - function_space.ufl_element(), - label=self._label) - self.function_space = function_space - self.name = name or (function_space.name or "Restricted" + "_" - + "_".join(sorted( - [str(i) for i in self.boundary_set]))) - - def set_shared_data(self): - sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set) - self._shared_data = sdata - self.node_set = sdata.node_set - r"""A :class:`pyop2.types.set.Set` representing the function space nodes.""" - self.dof_dset = op2.DataSet(self.node_set, self.shape or 1, - name="%s_nodes_dset" % self.name) - r"""A :class:`pyop2.types.dataset.DataSet` representing the function space - degrees of freedom.""" - self.finat_element = create_element(self.ufl_element()) - # Used for reconstruction of mixed/component spaces. - # sdata carries real_tensorproduct. - self.real_tensorproduct = sdata.real_tensorproduct - self.extruded = sdata.extruded - self.offset = sdata.offset - self.offset_quotient = sdata.offset_quotient - self.cell_boundary_masks = sdata.cell_boundary_masks - self.interior_facet_boundary_masks = sdata.interior_facet_boundary_masks - self.global_numbering = sdata.global_numbering - - def __eq__(self, other): - if not isinstance(other, RestrictedFunctionSpace): - return False - return self.function_space == other.function_space and \ - self.boundary_set == other.boundary_set - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash((self.function_space.mesh(), self.function_space.dof_dset, - self.function_space.ufl_element(), self.boundary_set)) - - def __repr__(self): - return self.__class__.__name__ + "(%r, name=%r, boundary_set=%r)" % ( - str(self.function_space), self.name, self.boundary_set) - - def __str__(self): - return self.__repr__() - - @utils.cached_property - def dof_count(self): - node_count = self.node_count - for sub_domain in self.boundary_set: - node_count -= len(self._shared_data.boundary_nodes(self, sub_domain)) - return node_count*self.value_size - - def local_to_global_map(self, bcs, lgmap=None): - return lgmap or self.dof_dset.lgmap - - @dataclass class FunctionSpaceCargo: """Helper class carrying data for a :class:`WithGeometryBase`.