Skip to content

Commit

Permalink
Hopefully passing number of constrained nodes back up.
Browse files Browse the repository at this point in the history
  • Loading branch information
emmarothwell1 committed Jan 5, 2024
1 parent c1de314 commit cdde2d0
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
13 changes: 9 additions & 4 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1232,10 +1232,11 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
get_chart(dm.dm, &pStart, &pEnd)
section.setChart(pStart, pEnd)
if boundary_set:
renumbering = plex_renumbering(dm, mesh._entity_classes, reordering=None,
renumbering, constrained_nodes = plex_renumbering(dm, mesh._entity_classes, reordering=None,
boundary_set=boundary_set)
else:
renumbering = mesh._dm_renumbering
constrained_nodes = 0
CHKERR(PetscSectionSetPermutation(section.sec, renumbering.iset))
dimension = get_topological_dimension(dm)
nodes = nodes_per_entity.reshape(dimension + 1, -1)
Expand Down Expand Up @@ -1290,7 +1291,7 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
dof_array[j] = j
CHKERR(PetscSectionSetConstraintIndices(section.sec, p, dof_array))
CHKERR(PetscFree(dof_array))
return section
return section, constrained_nodes


@cython.boundscheck(False)
Expand Down Expand Up @@ -2428,15 +2429,19 @@ def plex_renumbering(PETSc.DM plex,
CHKERR(DMLabelDestroyIndex(labels[c]))

CHKERR(PetscBTDestroy(&seen))

if boundary_set:
CHKERR(PetscBTDestroy(&seen_boundary))

perm_is = PETSc.IS().create(comm=plex.comm)
perm_is.setType("general")
CHKERR(ISGeneralSetIndices(perm_is.iset, pEnd - pStart,
perm, PETSC_OWN_POINTER))
return perm_is
constrained_nodes = constrained_core + constrained_owned
if boundary_set:
return perm_is, constrained_nodes
else:
return perm_is

@cython.boundscheck(False)
@cython.wraparound(False)
Expand Down
6 changes: 3 additions & 3 deletions firedrake/functionspacedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def get_node_set(mesh, key):
:returns: A :class:`pyop2.Set` for the function space nodes.
"""
nodes_per_entity, real_tensorproduct, _ = key
global_numbering = get_global_numbering(mesh, key)
global_numbering, constrained_nodes = get_global_numbering(mesh, key)
node_classes = mesh.node_classes(nodes_per_entity, real_tensorproduct=real_tensorproduct)
halo = halo_mod.Halo(mesh.topology_dm, global_numbering, comm=mesh.comm)
node_set = op2.Set(node_classes, halo=halo, comm=mesh.comm)
node_set = op2.Set(node_classes, halo=halo, comm=mesh.comm, constrained_nodes=constrained_nodes)
extruded = mesh.cell_set._extruded

assert global_numbering.getStorageSize() == node_set.total_size
Expand Down Expand Up @@ -425,7 +425,7 @@ def __init__(self, mesh, ufl_element, boundary_set=None):
# For non-scalar valued function spaces, there are multiple dofs per node.
key = (nodes_per_entity, real_tensorproduct, boundary_set)
# These are keyed only on nodes per topological entity.
global_numbering = get_global_numbering(mesh, key)
global_numbering, constrained_nodes = get_global_numbering(mesh, key)
node_set = get_node_set(mesh, key)

edofs_key = entity_dofs_key(entity_dofs)
Expand Down
6 changes: 3 additions & 3 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,16 +578,16 @@ def callback(self):
tdim = dmcommon.get_topological_dimension(self.topology_dm)
entity_dofs = np.zeros(tdim+1, dtype=IntType)
entity_dofs[-1] = 1
self._cell_numbering = self.create_section(entity_dofs)
self._cell_numbering = self.create_section(entity_dofs)[0]
if tdim == 0:
self._vertex_numbering = self._cell_numbering
else:
entity_dofs[:] = 0
entity_dofs[0] = 1
self._vertex_numbering = self.create_section(entity_dofs)
self._vertex_numbering = self.create_section(entity_dofs)[0]
entity_dofs[:] = 0
entity_dofs[-2] = 1
facet_numbering = self.create_section(entity_dofs)
facet_numbering = self.create_section(entity_dofs)[0]
self._facet_ordering = dmcommon.get_facet_ordering(self.topology_dm, facet_numbering)
self._callback = callback
self.name = name
Expand Down
2 changes: 1 addition & 1 deletion tests/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
v2 = TestFunction(V_res)
restricted_form = u2 * v2 * dx

matrix_res = assemble(u2 * v2 * dx) # still fails here, 4x4 size but wrong 2x2 values (due to messing with plex_renumbering)
matrix_res = assemble(u2 * v2 * dx) # current work gives a number 11 SEGV error.
print(matrix_res.M.values)

matrix_normal_bcs = assemble(u * v * dx, bcs=[bc])
Expand Down

0 comments on commit cdde2d0

Please sign in to comment.