Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fieldsplit: replace empty Forms with ZeroBaseForm #3947

Merged
merged 42 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
a86614a
Restricted Cofunction RHS
pbrubeck Dec 11, 2024
aef7886
Fix BCs on Cofunction
pbrubeck Dec 11, 2024
8e5603e
LinearSolver: check function spaces
pbrubeck Dec 11, 2024
2dd4f76
assemble(form, zero_bc_nodes=True) as default
pbrubeck Dec 12, 2024
3c5e64f
Fix FunctionAssignBlock
pbrubeck Dec 12, 2024
e3449f5
Allow Cofunction.assign take in constants
pbrubeck Dec 12, 2024
b755c81
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 13, 2024
0b0296d
Merge branch 'pbrubeck/fix/restricted-cofunction' of github.com:fired…
pbrubeck Dec 13, 2024
40cf6d7
suggestion from code review
pbrubeck Dec 13, 2024
fe30b48
more suggestions from review
pbrubeck Dec 19, 2024
3d49f31
remove BaseFormAssembler test
pbrubeck Dec 19, 2024
6742374
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 20, 2024
df04f4b
only supply relevant kwargs to OneFormAssembler
pbrubeck Dec 20, 2024
474edb3
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 20, 2024
950e42d
Only interpolate the residual, not every cofunction in the RHS
pbrubeck Dec 21, 2024
a86f3f5
DROP BEFORE MERGE
pbrubeck Dec 21, 2024
337d087
Fix tests
pbrubeck Dec 21, 2024
ed34164
Fix adjoint utils
pbrubeck Dec 22, 2024
027ad37
More robust test for (unrestricted) Cofunction RHS
pbrubeck Dec 22, 2024
885958f
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 23, 2024
2286596
DO NOT MERGE
pbrubeck Dec 30, 2024
bb04bb0
Replace empty Jacobians with ZeroBaseForm
pbrubeck Jan 1, 2025
d82039d
Split Cofunction
pbrubeck Jan 2, 2025
af53302
Do not split off-diagonal blocks if we only want the diagonal
pbrubeck Jan 2, 2025
7f40504
Zero-simplify slate Tensors
pbrubeck Jan 3, 2025
b48c77c
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Jan 3, 2025
d68113f
set bcs directly on diagonal Cofunction
pbrubeck Jan 3, 2025
3d06fc5
ImplicitMatrixContext: handle empty action
pbrubeck Jan 3, 2025
6078f93
Only extract constants referenced in the kernel
pbrubeck Jan 4, 2025
5894b49
Adjoint: only skip expand_derivatives if necessary
pbrubeck Jan 4, 2025
d99ba50
style
pbrubeck Jan 4, 2025
d6bb7dd
EquationBC: do not reconstruct empty Forms
pbrubeck Jan 5, 2025
ed58467
lower degree for EquationBC tests
pbrubeck Jan 6, 2025
2a0c03b
style
pbrubeck Jan 6, 2025
934ff6f
FunctionSpace: multiindex returns subspace
pbrubeck Jan 9, 2025
8688f5e
Revert WithGeometry.__getitem__
pbrubeck Jan 10, 2025
7c2354e
Merge branch 'master' into pbrubeck/simplify-indexed
pbrubeck Jan 15, 2025
e99ce9a
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Jan 15, 2025
70a45fd
Merge branch 'master' into pbrubeck/simplify-indexed
pbrubeck Jan 15, 2025
605e52f
DROP BEFORE MERGE (2)
pbrubeck Jan 15, 2025
f3c4ef6
Do not zero a ZeroBaseForm
pbrubeck Jan 15, 2025
8f7ca9b
Update .github/workflows/build.yml
pbrubeck Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions demos/netgen/netgen_mesh.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,7 @@ We will now show how to solve the Poisson problem on a high-order mesh, of order

bc = DirichletBC(V, 0.0, [1])
A = assemble(a, bcs=bc)
b = assemble(l)
bc.apply(b)
b = assemble(l, bcs=bc)
solve(A, sol, b, solver_parameters={"ksp_type": "cg", "pc_type": "lu"})

VTKFile("output/Sphere.pvd").write(sol)
Expand Down
6 changes: 3 additions & 3 deletions firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = None
for adj_input in adj_inputs:
if isconstant(c):
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
if self.function_space != self.parent_space:
vec = extract_bc_subvector(
Expand Down Expand Up @@ -88,11 +88,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# you can even use the Function outside its domain.
# For now we will just assume the FunctionSpace is the same for
# the BC and the Function.
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
r = extract_bc_subvector(
adj_value, c.function_space(), bc
)
).riesz_representation("l2")
if adj_output is None:
adj_output = r
else:
Expand Down
1 change: 1 addition & 0 deletions firedrake/adjoint_utils/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
)
diff_expr_assembled = firedrake.Function(adj_input_func.function_space())
diff_expr_assembled.interpolate(ufl.conj(diff_expr))
diff_expr_assembled = diff_expr_assembled.riesz_representation(riesz_map="l2")
adj_output = firedrake.Function(
R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func))
)
Expand Down
26 changes: 12 additions & 14 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,12 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):

def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
kwargs = self.assemble_kwargs.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
kwargs["bcs"] = bcs
dFdu = self._assemble_dFdu_adj(dFdu_adj_form, **kwargs)
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)

for bc in bcs:
bc.apply(dJdu)
bc.zero(dJdu)

adj_sol = firedrake.Function(self.function_space)
firedrake.solve(
Expand All @@ -219,10 +217,8 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
return adj_sol, adj_sol_bdy

def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
adj_sol_bdy = firedrake.Function(
self.function_space.dual(), dJdu.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)).dat)
return adj_sol_bdy
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
return adj_sol_bdy.riesz_representation("l2")

def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
prepared=None):
Expand Down Expand Up @@ -264,8 +260,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
return dFdm

dFdm = -firedrake.derivative(F_form, c_rep, trial_function)
dFdm = firedrake.adjoint(dFdm)
dFdm = dFdm * adj_sol
if isinstance(dFdm, ufl.Form):
dFdm = firedrake.adjoint(dFdm)
dFdm = firedrake.action(dFdm, adj_sol)
else:
dFdm = dFdm(adj_sol)
dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)
return dFdm

Expand Down Expand Up @@ -654,9 +653,8 @@ def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
def _adjoint_solve(self, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
for bc in bcs:
bc.apply(dJdu)
for bc in self.bcs:
bc.zero(dJdu)

if (
self._ad_solvers["forward_nlvs"]._problem._constant_jacobian
Expand Down Expand Up @@ -876,7 +874,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs):
self.add_dependency(bc, no_duplicates=True)

def apply_mixedmass(self, a):
b = firedrake.Function(self.target_space)
b = firedrake.Function(self.target_space.dual())
with a.dat.vec_ro as vsrc, b.dat.vec_wo as vrhs:
self.mixed_mass.mult(vsrc, vrhs)
return b
Expand Down
13 changes: 8 additions & 5 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
from firedrake.ufl_expr import derivative, adjoint
from ufl import replace


Expand All @@ -11,7 +12,6 @@ def _ad_annotate_init(init):
@no_annotations
@wraps(init)
def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u_restrict
Expand All @@ -20,10 +20,13 @@ def wrapper(self, *args, **kwargs):
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
dFdu = derivative(self.F, self.u_restrict)
try:
self._ad_adj_F = adjoint(dFdu)
except ValueError:
# Try again without expanding derivatives,
# as dFdu might have been simplied to an empty Form
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
except (TypeError, NotImplementedError):
self._ad_adj_F = None
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
Expand Down
75 changes: 44 additions & 31 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def assemble(expr, *args, **kwargs):
zero_bc_nodes : bool
If `True`, set the boundary condition nodes in the
output tensor to zero rather than to the values prescribed by the
boundary condition. Default is `False`.
boundary condition. Default is `True`.
diagonal : bool
If assembling a matrix is it diagonal?
weight : float
Expand Down Expand Up @@ -143,7 +143,6 @@ def get_assembler(form, *args, **kwargs):

"""
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
bcs = kwargs.get('bcs', None)
fc_params = kwargs.get('form_compiler_parameters', None)
if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed:
mat_type = kwargs.get('mat_type', None)
Expand All @@ -155,8 +154,13 @@ def get_assembler(form, *args, **kwargs):
if len(form.arguments()) == 0:
return ZeroFormAssembler(form, form_compiler_parameters=fc_params)
elif len(form.arguments()) == 1 or diagonal:
return OneFormAssembler(form, *args, bcs=bcs, form_compiler_parameters=fc_params, needs_zeroing=kwargs.get('needs_zeroing', True),
zero_bc_nodes=kwargs.get('zero_bc_nodes', False), diagonal=diagonal)
return OneFormAssembler(form, *args,
bcs=kwargs.get("bcs", None),
form_compiler_parameters=fc_params,
needs_zeroing=kwargs.get("needs_zeroing", True),
zero_bc_nodes=kwargs.get("zero_bc_nodes", True),
diagonal=diagonal,
weight=kwargs.get("weight", 1.0))
elif len(form.arguments()) == 2:
return TwoFormAssembler(form, *args, **kwargs)
else:
Expand Down Expand Up @@ -308,7 +312,7 @@ def __init__(self,
sub_mat_type=None,
options_prefix=None,
appctx=None,
zero_bc_nodes=False,
zero_bc_nodes=True,
diagonal=False,
weight=1.0,
allocation_integral_types=None):
Expand Down Expand Up @@ -381,6 +385,12 @@ def visitor(e, *operands):
visited = {}
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited)

# Apply BCs after assembly
rank = len(self._form.arguments())
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
for bc in self._bcs:
bc.zero(result)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
return tensor
Expand All @@ -405,8 +415,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
if rank == 0:
assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params)
elif rank == 1 or (rank == 2 and self._diagonal):
assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal)
assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
elif rank == 2:
assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
Expand Down Expand Up @@ -577,10 +587,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
@staticmethod
def update_tensor(assembled_base_form, tensor):
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
assembled_base_form.dat.copy(tensor.dat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.dat.zero()
else:
assembled_base_form.dat.copy(tensor.dat)
elif isinstance(tensor, matrix.MatrixBase):
# Uses the PETSc copy method.
assembled_base_form.petscmat.copy(tensor.petscmat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.petscmat.zeroEntries()
else:
assembled_base_form.petscmat.copy(tensor.petscmat)
else:
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))

Expand Down Expand Up @@ -807,9 +822,9 @@ def restructure_base_form(expr, visited=None):
return ufl.action(expr, ustar)

# -- Case (6) -- #
if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()):
# Return ufl.Sum
return sum([c for c in expr.components()])
if isinstance(expr, ufl.FormSum) and all(ufl.duals.is_dual(a.function_space()) for a in expr.arguments()):
# Return ufl.Sum if we are assembling a FormSum with Coarguments (a primal expression)
return sum(w*c for w, c in zip(expr.weights(), expr.components()))
return expr

@staticmethod
Expand Down Expand Up @@ -1138,7 +1153,7 @@ class OneFormAssembler(ParloopFormAssembler):

Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
1-form.

Notes
Expand All @@ -1149,14 +1164,15 @@ class OneFormAssembler(ParloopFormAssembler):

@classmethod
def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=True, diagonal=False, weight=1.0):
bcs = solving._extract_bcs(bcs)
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal, weight

@FormAssembler._skip_if_initialised
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=True, diagonal=False, weight=1.0):
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
self._weight = weight
self._diagonal = diagonal
self._zero_bc_nodes = zero_bc_nodes
if self._diagonal and any(isinstance(bc, EquationBCSplit) for bc in self._bcs):
Expand Down Expand Up @@ -1185,23 +1201,21 @@ def _apply_bc(self, tensor, bc):
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if not self._zero_bc_nodes:
tensor_func = tensor.riesz_representation(riesz_map="l2")
if self._diagonal:
bc.set(tensor_func, 1)
else:
bc.apply(tensor_func)
tensor.assign(tensor_func.riesz_representation(riesz_map="l2"))
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space():
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
Expand Down Expand Up @@ -2127,14 +2141,13 @@ def iter_active_coefficients(form, kinfo):

@staticmethod
def iter_constants(form, kinfo):
"""Yield the form constants"""
"""Yield the form constants referenced in ``kinfo``."""
if isinstance(form, slate.TensorBase):
for const in form.constants():
yield const
all_constants = form.constants()
else:
all_constants = extract_firedrake_constants(form)
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]

@staticmethod
def index_function_spaces(form, indices):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
return
rank = len(self.f.arguments())
splitter = ExtractSubBlock()
if rank == 1:
form = splitter.split(self.f, argument_indices=(row_field, ))
elif rank == 2:
form = splitter.split(self.f, argument_indices=(row_field, col_field))
form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank])
if isinstance(form, ufl.ZeroBaseForm) or form.empty():
# form is empty, do nothing
return
if u is not None:
form = firedrake.replace(form, {self.u: u})
if action_x is not None:
Expand Down
6 changes: 4 additions & 2 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ def assign(self, expr, subset=None, expr_from_assemble=False):
return self.assign(
assembled_expr, subset=subset,
expr_from_assemble=True)

raise ValueError('Cannot assign %s' % expr)
else:
from firedrake.assign import Assigner
Assigner(self, expr, subset).assign()
return self

def riesz_representation(self, riesz_map='L2', **solver_options):
"""Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map.
Expand Down
Loading
Loading