diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index 11b21566a2..4f18dc0559 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -1244,12 +1244,9 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary section.setChart(pStart, pEnd) if boundary_set and not extruded: - renumbering, (constrainedStart, constrainedEnd) = plex_renumbering(dm, - mesh._entity_classes, reordering=mesh._default_reordering, boundary_set=boundary_set) + renumbering = plex_renumbering(dm, mesh._entity_classes, reordering=mesh._default_reordering, boundary_set=boundary_set) else: renumbering = mesh._dm_renumbering - constrainedStart = -1 - constrainedEnd = -1 CHKERR(PetscSectionSetPermutation(section.sec, renumbering.iset)) for i in range(dimension + 1): @@ -2495,7 +2492,7 @@ def plex_renumbering(PETSc.DM plex, perm_is.setType("general") CHKERR(ISGeneralSetIndices(perm_is.iset, pEnd - pStart, perm, PETSC_OWN_POINTER)) - return perm_is, (lidx[1], lidx[3]) + return perm_is @cython.boundscheck(False) @cython.wraparound(False) @@ -3345,25 +3342,38 @@ def make_global_numbering(PETSc.Section lsec, PETSc.Section gsec): :arg lsec: Section describing local dof layout and numbers. :arg gsec: Section describing global dof layout and numbers.""" cdef: - PetscInt c, p, pStart, pEnd, dof, cdof, loff, goff + PetscInt c, cc, p, pStart, pEnd, dof, cdof, loff, goff, max_dof np.ndarray val + PetscInt *dof_array = NULL val = np.empty(lsec.getStorageSize(), dtype=IntType) pStart, pEnd = lsec.getChart() - + CHKERR(PetscSectionGetMaxDof(lsec.sec, &max_dof)) + #CHKERR(PetscMalloc1(max_dof, &dof_array)) for p in range(pStart, pEnd): CHKERR(PetscSectionGetDof(lsec.sec, p, &dof)) CHKERR(PetscSectionGetConstraintDof(lsec.sec, p, &cdof)) if dof > 0: CHKERR(PetscSectionGetOffset(lsec.sec, p, &loff)) CHKERR(PetscSectionGetOffset(gsec.sec, p, &goff)) + print(dof, cdof, loff, goff) + goff = cabs(goff) if cdof > 0: + CHKERR(PetscSectionGetConstraintIndices(lsec.sec, p, &dof_array)) + for c in range(dof): + val[loff + c] = -2 + for c in range(cdof): + val[loff + dof_array[c]] = -1 + cc = 0 for c in range(dof): - val[loff + c] = -1 + if val[loff + c] < -1: + print(c, cc) + val[loff + c] = goff + cc + cc += 1 else: - goff = cabs(goff) for c in range(dof): val[loff + c] = goff + c + #CHKERR(PetscFree(dof_array)) return val diff --git a/firedrake/cython/petschdr.pxi b/firedrake/cython/petschdr.pxi index 55786e7184..9a0bff609d 100644 --- a/firedrake/cython/petschdr.pxi +++ b/firedrake/cython/petschdr.pxi @@ -97,6 +97,7 @@ cdef extern from "petscis.h" nogil: int PetscSectionGetConstraintDof(PETSc.PetscSection,PetscInt,PetscInt*) int PetscSectionSetConstraintDof(PETSc.PetscSection,PetscInt,PetscInt) int PetscSectionSetConstraintIndices(PETSc.PetscSection,PetscInt, PetscInt[]) + int PetscSectionGetConstraintIndices(PETSc.PetscSection,PetscInt, const PetscInt**) int PetscSectionGetMaxDof(PETSc.PetscSection,PetscInt*) int PetscSectionSetPermutation(PETSc.PetscSection,PETSc.PetscIS) int ISGetIndices(PETSc.PetscIS,PetscInt*[]) diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 53f69e92ce..28e56515e1 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -16,7 +16,8 @@ from pyop2 import op2, mpi from firedrake import dmhooks, utils -from firedrake.functionspacedata import get_shared_data, create_element +from firedrake import extrusion_utils as eutils +from firedrake.functionspacedata import get_shared_data, create_element, get_node_set from firedrake.petsc import PETSc @@ -888,7 +889,23 @@ def __init__(self, function_space, name=None, boundary_set=frozenset()): [str(i) for i in self.boundary_set]))) def set_shared_data(self): - sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set) + if False:#self._mesh.cell_set._extruded: + sdata = get_shared_data(self._mesh, self.ufl_element(), None) + finat_element = create_element(self.ufl_element()) + real_tensorproduct = eutils.is_real_tensor_product_element(finat_element) + entity_dofs = finat_element.entity_dofs() + nodes_per_entity = tuple(self._mesh.make_dofs_per_plex_entity(entity_dofs)) + key = (nodes_per_entity, real_tensorproduct, self.boundary_set) + node_set = get_node_set(self._mesh, key) + # Get constrained global section. + gsec = node_set.halo.dm.getGlobalSection() + # Set unconstrained local section. + node_set.halo.dm.setLocalSection(sdata.node_set.halo.dm.getLocalSection()) + # Set constrained global section before anyone calls dm.getGlobalSection(). + node_set.halo.dm.setGlobalSection(gsec) + sdata.node_set = node_set + else: + sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set) self._shared_data = sdata self.node_set = sdata.node_set r"""A :class:`pyop2.types.set.Set` representing the function space nodes.""" diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 906547ae72..0c3f37b3bc 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -1245,7 +1245,7 @@ def _renumber_entities(self, reorder): else: # No reordering reordering = None - return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering)[0] + return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering) @utils.cached_property def cell_closure(self): @@ -1979,7 +1979,7 @@ def _renumber_entities(self, reorder): perm_is.setIndices(perm) return perm_is else: - return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None)[0] + return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None) @utils.cached_property # TODO: Recalculate if mesh moves def cell_closure(self): diff --git a/tests/firedrake/regression/test_restricted_function_space.py b/tests/firedrake/regression/test_restricted_function_space.py index cae361cebf..190e597ab2 100644 --- a/tests/firedrake/regression/test_restricted_function_space.py +++ b/tests/firedrake/regression/test_restricted_function_space.py @@ -207,8 +207,20 @@ def test_restricted_mixed_spaces(i, j): assert errornorm(w.subfunctions[1], w2.subfunctions[1]) < 1.e-12 +@pytest.mark.parallel(nprocs=2) def test_restricted_function_space_extrusion(): mesh = UnitIntervalMesh(2) - extm = ExtrudedMesh(mesh, 2) + extm = ExtrudedMesh(mesh, 1) V = FunctionSpace(extm, "CG", 2) V_res = RestrictedFunctionSpace(V, boundary_set=["bottom"]) + #mesh._dm_renumbering.view() + #mesh.topology_dm.viewFromOptions("-dm_view") + #V_res._shared_data.node_set.halo.dm.getLocalSection().view() + #V_res._shared_data.node_set.halo.dm.getGlobalSection().view() + #V._shared_data.node_set.halo.dm.getGlobalSection().view() + lgmap = V_res.topological.local_to_global_map(None) + if mesh.comm.rank == 0: + lgmap_expected = [-1, 0, 1, -1, 2, 3, -1, 8, 9, -1, 4, 5, -1, 6, 7] + else: + lgmap_expected = [-1, 4, 5, -1, 6, 7, -1, 8, 9, -1, 0, 1, -1, 2, 3] + assert np.allclose(lgmap, lgmap_expected)