From c1de3148ee29d416ce51ace11c381085cbe92a2d Mon Sep 17 00:00:00 2001 From: Emma Rothwell Date: Fri, 5 Jan 2024 10:41:07 +0000 Subject: [PATCH] plex_renumbering counting constrained nodes first. --- firedrake/cython/dmcommon.pyx | 45 ++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/firedrake/cython/dmcommon.pyx b/firedrake/cython/dmcommon.pyx index 3fcacbcf28..da35df068d 100644 --- a/firedrake/cython/dmcommon.pyx +++ b/firedrake/cython/dmcommon.pyx @@ -2337,6 +2337,7 @@ def plex_renumbering(PETSc.DM plex, PETSc.IS facet_is = None PETSc.IS perm_is = None PetscBT seen = NULL + PetscBT seen_boundary = NULL PetscBool has_point DMLabel labels[3] bint reorder = reordering is not None @@ -2346,6 +2347,8 @@ def plex_renumbering(PETSc.DM plex, get_height_stratum(plex.dm, 0, &cStart, &cEnd) CHKERR(PetscMalloc1(pEnd - pStart, &perm)) CHKERR(PetscBTCreate(pEnd - pStart, &seen)) + if boundary_set: + CHKERR(PetscBTCreate(pEnd - pStart, &seen_boundary)) ncells = np.zeros(3, dtype=IntType) # Get label pointers and label-specific array indices @@ -2355,12 +2358,10 @@ def plex_renumbering(PETSc.DM plex, for l in range(3): CHKERR(DMLabelCreateIndex(labels[l], pStart, pEnd)) entity_classes = entity_classes.astype(IntType) - lidx = np.zeros(4, dtype=IntType) - lidx[1] = sum(entity_classes[:, 0]) - lidx[2] = sum(entity_classes[:, 1]) - lidx[3] = lidx[2] - 1 # used here to count constrained owned dofs - # Get boundary points (if the boundary_set exists) + # Get boundary points (if the boundary_set exists) and count each type + constrained_core = 0 + constrained_owned = 0 boundary_points = np.array([]) if boundary_set: for marker in boundary_set: @@ -2373,18 +2374,36 @@ def plex_renumbering(PETSc.DM plex, if n == 0: continue points = plex.getStratumIS(label, marker).indices - boundary_points = np.concatenate((boundary_points, points)) - + for i in range(n): + p = points[i] + if not PetscBTLookup(seen_boundary, p): + for l in range(3): + CHKERR(DMLabelHasPoint(labels[l], p, &has_point)) + if has_point: + PetscBTSet(seen_boundary, p) + if l == 1: + constrained_owned += 1 + elif l == 0: + constrained_core += 1 + break + + # assign lists + lidx = np.zeros(3, dtype=IntType) + lidx[1] = sum(entity_classes[:, 0]) + lidx[2] = sum(entity_classes[:, 1]) + if boundary_set: + lidx[1] -= constrained_core + lidx = np.concatenate((lidx, np.array([lidx[2]], dtype=IntType))) + lidx[2] -= (constrained_core + constrained_owned) + for c in range(pStart, pEnd): if reorder: cell = reordering[c] else: cell = c - # We always re-order cell-wise so that we inherit any cache # coherency from the reordering provided by the Plex if cStart <= cell < cEnd: - # Get cell closure get_transitive_closure(plex.dm, cell, PETSC_TRUE, &nclosure, &closure) for ci in range(nclosure): @@ -2394,10 +2413,10 @@ def plex_renumbering(PETSc.DM plex, CHKERR(DMLabelHasPoint(labels[l], p, &has_point)) if has_point: PetscBTSet(seen, p) - if boundary_set and p in boundary_points and l == 1: + if boundary_set and PetscBTLookup(seen_boundary, p) and l <= 1: # push boundary point to end of constrained owned dofs perm[lidx[3]] = p - lidx[3] -= 1 + lidx[3] += 1 else: perm[lidx[l]] = p lidx[l] += 1 @@ -2409,6 +2428,10 @@ def plex_renumbering(PETSc.DM plex, CHKERR(DMLabelDestroyIndex(labels[c])) CHKERR(PetscBTDestroy(&seen)) + + if boundary_set: + CHKERR(PetscBTDestroy(&seen_boundary)) + perm_is = PETSc.IS().create(comm=plex.comm) perm_is.setType("general") CHKERR(ISGeneralSetIndices(perm_is.iset, pEnd - pStart,