Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restricted Cofunction RHS #3922

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a86614a
Restricted Cofunction RHS
pbrubeck Dec 11, 2024
aef7886
Fix BCs on Cofunction
pbrubeck Dec 11, 2024
8e5603e
LinearSolver: check function spaces
pbrubeck Dec 11, 2024
2dd4f76
assemble(form, zero_bc_nodes=True) as default
pbrubeck Dec 12, 2024
3c5e64f
Fix FunctionAssignBlock
pbrubeck Dec 12, 2024
e3449f5
Allow Cofunction.assign take in constants
pbrubeck Dec 12, 2024
b755c81
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 13, 2024
0b0296d
Merge branch 'pbrubeck/fix/restricted-cofunction' of github.com:fired…
pbrubeck Dec 13, 2024
40cf6d7
suggestion from code review
pbrubeck Dec 13, 2024
fe30b48
more suggestions from review
pbrubeck Dec 19, 2024
3d49f31
remove BaseFormAssembler test
pbrubeck Dec 19, 2024
6742374
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 20, 2024
df04f4b
only supply relevant kwargs to OneFormAssembler
pbrubeck Dec 20, 2024
474edb3
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 20, 2024
950e42d
Only interpolate the residual, not every cofunction in the RHS
pbrubeck Dec 21, 2024
a86f3f5
DROP BEFORE MERGE
pbrubeck Dec 21, 2024
337d087
Fix tests
pbrubeck Dec 21, 2024
ed34164
Fix adjoint utils
pbrubeck Dec 22, 2024
027ad37
More robust test for (unrestricted) Cofunction RHS
pbrubeck Dec 22, 2024
885958f
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 23, 2024
b48c77c
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Jan 3, 2025
d68113f
set bcs directly on diagonal Cofunction
pbrubeck Jan 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch ufl pbrubeck/fix/formsum-weights \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
3 changes: 1 addition & 2 deletions demos/netgen/netgen_mesh.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,7 @@ We will now show how to solve the Poisson problem on a high-order mesh, of order

bc = DirichletBC(V, 0.0, [1])
A = assemble(a, bcs=bc)
b = assemble(l)
bc.apply(b)
b = assemble(l, bcs=bc)
solve(A, sol, b, solver_parameters={"ksp_type": "cg", "pc_type": "lu"})

VTKFile("output/Sphere.pvd").write(sol)
Expand Down
6 changes: 3 additions & 3 deletions firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = None
for adj_input in adj_inputs:
if isconstant(c):
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
if self.function_space != self.parent_space:
vec = extract_bc_subvector(
Expand Down Expand Up @@ -88,11 +88,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# you can even use the Function outside its domain.
# For now we will just assume the FunctionSpace is the same for
# the BC and the Function.
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
r = extract_bc_subvector(
adj_value, c.function_space(), bc
)
).riesz_representation("l2")
if adj_output is None:
adj_output = r
else:
Expand Down
1 change: 1 addition & 0 deletions firedrake/adjoint_utils/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
)
diff_expr_assembled = firedrake.Function(adj_input_func.function_space())
diff_expr_assembled.interpolate(ufl.conj(diff_expr))
diff_expr_assembled = diff_expr_assembled.riesz_representation(riesz_map="l2")
adj_output = firedrake.Function(
R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func))
)
Expand Down
26 changes: 12 additions & 14 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,12 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):

def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
kwargs = self.assemble_kwargs.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
kwargs["bcs"] = bcs
dFdu = self._assemble_dFdu_adj(dFdu_adj_form, **kwargs)
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)

for bc in bcs:
bc.apply(dJdu)
bc.zero(dJdu)

adj_sol = firedrake.Function(self.function_space)
firedrake.solve(
Expand All @@ -219,10 +217,8 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
return adj_sol, adj_sol_bdy

def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
adj_sol_bdy = firedrake.Function(
self.function_space.dual(), dJdu.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)).dat)
return adj_sol_bdy
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
return adj_sol_bdy.riesz_representation("l2")

def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
prepared=None):
Expand Down Expand Up @@ -264,8 +260,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
return dFdm

dFdm = -firedrake.derivative(F_form, c_rep, trial_function)
dFdm = firedrake.adjoint(dFdm)
dFdm = dFdm * adj_sol
if isinstance(dFdm, ufl.Form):
dFdm = firedrake.adjoint(dFdm)
dFdm = firedrake.action(dFdm, adj_sol)
else:
dFdm = dFdm(adj_sol)
dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)
return dFdm

Expand Down Expand Up @@ -654,9 +653,8 @@ def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
def _adjoint_solve(self, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
for bc in bcs:
bc.apply(dJdu)
for bc in self.bcs:
bc.zero(dJdu)

if (
self._ad_solvers["forward_nlvs"]._problem._constant_jacobian
Expand Down Expand Up @@ -876,7 +874,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs):
self.add_dependency(bc, no_duplicates=True)

def apply_mixedmass(self, a):
b = firedrake.Function(self.target_space)
b = firedrake.Function(self.target_space.dual())
with a.dat.vec_ro as vsrc, b.dat.vec_wo as vrhs:
self.mixed_mass.mult(vsrc, vrhs)
return b
Expand Down
53 changes: 31 additions & 22 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def assemble(expr, *args, **kwargs):
zero_bc_nodes : bool
If `True`, set the boundary condition nodes in the
output tensor to zero rather than to the values prescribed by the
boundary condition. Default is `False`.
boundary condition. Default is `True`.
diagonal : bool
If assembling a matrix is it diagonal?
weight : float
Expand Down Expand Up @@ -143,7 +143,6 @@ def get_assembler(form, *args, **kwargs):

"""
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
bcs = kwargs.get('bcs', None)
fc_params = kwargs.get('form_compiler_parameters', None)
if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed:
mat_type = kwargs.get('mat_type', None)
Expand All @@ -155,8 +154,13 @@ def get_assembler(form, *args, **kwargs):
if len(form.arguments()) == 0:
return ZeroFormAssembler(form, form_compiler_parameters=fc_params)
elif len(form.arguments()) == 1 or diagonal:
return OneFormAssembler(form, *args, bcs=bcs, form_compiler_parameters=fc_params, needs_zeroing=kwargs.get('needs_zeroing', True),
zero_bc_nodes=kwargs.get('zero_bc_nodes', False), diagonal=diagonal)
return OneFormAssembler(form, *args,
bcs=kwargs.get("bcs", None),
form_compiler_parameters=fc_params,
needs_zeroing=kwargs.get("needs_zeroing", True),
zero_bc_nodes=kwargs.get("zero_bc_nodes", True),
diagonal=diagonal,
weight=kwargs.get("weight", 1.0))
elif len(form.arguments()) == 2:
return TwoFormAssembler(form, *args, **kwargs)
else:
Expand Down Expand Up @@ -308,7 +312,7 @@ def __init__(self,
sub_mat_type=None,
options_prefix=None,
appctx=None,
zero_bc_nodes=False,
zero_bc_nodes=True,
diagonal=False,
weight=1.0,
allocation_integral_types=None):
Expand Down Expand Up @@ -381,6 +385,12 @@ def visitor(e, *operands):
visited = {}
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited)

# Apply BCs after assembly
rank = len(self._form.arguments())
if rank == 1:
for bc in self._bcs:
bc.zero(result)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
return tensor
Expand All @@ -405,8 +415,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
if rank == 0:
assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params)
elif rank == 1 or (rank == 2 and self._diagonal):
assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal)
assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
elif rank == 2:
assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
Expand Down Expand Up @@ -807,9 +817,9 @@ def restructure_base_form(expr, visited=None):
return ufl.action(expr, ustar)

# -- Case (6) -- #
if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()):
# Return ufl.Sum
return sum([c for c in expr.components()])
if isinstance(expr, ufl.FormSum) and all(ufl.duals.is_dual(a.function_space()) for a in expr.arguments()):
# Return ufl.Sum if we are assembling a FormSum with Coarguments (a primal expression)
return sum(w*c for w, c in zip(expr.weights(), expr.components()))
return expr

@staticmethod
Expand Down Expand Up @@ -1149,14 +1159,15 @@ class OneFormAssembler(ParloopFormAssembler):

@classmethod
def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=True, diagonal=False, weight=1.0):
bcs = solving._extract_bcs(bcs)
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal, weight

@FormAssembler._skip_if_initialised
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=True, diagonal=False, weight=1.0):
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
self._weight = weight
self._diagonal = diagonal
self._zero_bc_nodes = zero_bc_nodes
if self._diagonal and any(isinstance(bc, EquationBCSplit) for bc in self._bcs):
Expand Down Expand Up @@ -1185,23 +1196,21 @@ def _apply_bc(self, tensor, bc):
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if not self._zero_bc_nodes:
tensor_func = tensor.riesz_representation(riesz_map="l2")
if self._diagonal:
bc.set(tensor_func, 1)
else:
bc.apply(tensor_func)
tensor.assign(tensor_func.riesz_representation(riesz_map="l2"))
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space():
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
Expand Down
6 changes: 4 additions & 2 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,10 @@ def assign(self, expr, subset=None, expr_from_assemble=False):
return self.assign(
assembled_expr, subset=subset,
expr_from_assemble=True)

raise ValueError('Cannot assign %s' % expr)
else:
from firedrake.assign import Assigner
Assigner(self, expr, subset).assign()
return self

def riesz_representation(self, riesz_map='L2', **solver_options):
"""Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map.
Expand Down
4 changes: 4 additions & 0 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ufl
import finat.ufl

from ufl.duals import is_dual, is_primal
from pyop2 import op2, mpi
from pyop2.utils import as_tuple

Expand Down Expand Up @@ -296,6 +297,9 @@ def restore_work_function(self, function):
cache[function] = False

def __eq__(self, other):
if is_primal(self) != is_primal(other) or \
is_dual(self) != is_dual(other):
return False
try:
return self.topological == other.topological and \
self.mesh() is other.mesh()
Expand Down
7 changes: 7 additions & 0 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ def solve(self, x, b):
if not isinstance(b, (function.Function, cofunction.Cofunction)):
raise TypeError("Provided RHS is a '%s', not a Function or Cofunction" % type(b).__name__)

# When solving `Ax = b`, with A: V x U -> R, or equivalently A: V -> U*,
# we need to make sure that x and b belong to V and U*, respectively.
if x.function_space() != self.trial_space:
raise ValueError(f"x must be a Function in {self.trial_space}.")
if b.function_space() != self.test_space.dual():
raise ValueError(f"b must be a Cofunction in {self.test_space.dual()}.")

if len(self.trial_space) > 1 and self.nullspace is not None:
self.nullspace._apply(self.trial_space.dof_dset.field_ises)
if len(self.test_space) > 1 and self.transpose_nullspace is not None:
Expand Down
6 changes: 2 additions & 4 deletions firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(self, a, row_bcs=[], col_bcs=[],
self._assemble_action = get_assembler(self.action,
bcs=self.bcs_action,
form_compiler_parameters=self.fc_params,
zero_bc_nodes=True).assemble
).assemble

# For assembling action(adjoint(f), self._y)
# Sorted list of equation bcs
Expand Down Expand Up @@ -183,11 +183,9 @@ def _assemble_diagonal(self):

def getDiagonal(self, mat, vec):
self._assemble_diagonal(tensor=self._diagonal)
diagonal_func = self._diagonal.riesz_representation(riesz_map="l2")
for bc in self.bcs:
# Operator is identity on boundary nodes
bc.set(diagonal_func, 1)
self._diagonal.assign(diagonal_func.riesz_representation(riesz_map="l2"))
bc.set(self._diagonal, 1)
with self._diagonal.dat.vec_ro as v:
v.copy(vec)

Expand Down
2 changes: 1 addition & 1 deletion firedrake/slate/static_condensation/scpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def initialize(self, pc):
r_expr = reduced_sys.rhs

# Construct the condensed right-hand side
self._assemble_Srhs = get_assembler(r_expr, bcs=bcs, zero_bc_nodes=True, form_compiler_parameters=self.cxt.fc_params).assemble
self._assemble_Srhs = get_assembler(r_expr, bcs=bcs, form_compiler_parameters=self.cxt.fc_params).assemble

# Allocate and set the condensed operator
form_assembler = get_assembler(S_expr, bcs=bcs, form_compiler_parameters=self.cxt.fc_params, mat_type=mat_type, options_prefix=prefix, appctx=self.get_appctx(pc))
Expand Down
2 changes: 1 addition & 1 deletion firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=True).assemble
).assemble

self._jacobian_assembled = False
self._splits = {}
Expand Down
12 changes: 8 additions & 4 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
DEFAULT_SNES_PARAMETERS
)
from firedrake.function import Function
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake.ufl_expr import TrialFunction, TestFunction, action
from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space
from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin
from ufl import replace
from firedrake.__future__ import interpolate
from ufl import replace, Form

__all__ = ["LinearVariationalProblem",
"LinearVariationalSolver",
Expand Down Expand Up @@ -91,8 +92,11 @@ def __init__(self, F, u, bcs=None, J=None,
bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
F_arg, = F.arguments()
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
if isinstance(F, Form):
F_arg, = F.arguments()
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
else:
self.F = action(replace(F, {self.u: self.u_restrict}), interpolate(v_res, V))
v_arg, u_arg = self.J.arguments()
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
if self.Jp:
Expand Down
7 changes: 3 additions & 4 deletions tests/firedrake/multigrid/test_poisson_gmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,11 @@ def test_baseform_coarsening(solver_type, mixed):
a_terms.append(inner(grad(u), grad(v)) * dx)
a = sum(a_terms)

assemble_bcs = lambda L: assemble(L, bcs=bcs, zero_bc_nodes=True)
# These are equivalent right-hand sides
sources = [sum(forms), # purely symbolic linear form
assemble_bcs(sum(forms)), # purely numerical cofunction
sum(assemble_bcs(form) for form in forms), # symbolic combination of numerical cofunctions
forms[0] + assemble_bcs(sum(forms[1:])), # symbolic plus numerical
assemble(sum(forms), bcs=bcs), # purely numerical cofunction
sum(assemble(form, bcs=bcs) for form in forms), # symbolic combination of numerical cofunctions
forms[0] + assemble(sum(forms[1:]), bcs=bcs), # symbolic plus numerical
]
solutions = []
for L in sources:
Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/regression/test_assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_one_form_assembler_cache(mesh):
assert len(L._cache[_FORM_CACHE_KEY]) == 3

# changing zero_bc_nodes should increase the cache size
assemble(L, zero_bc_nodes=True)
assemble(L, zero_bc_nodes=False)
assert len(L._cache[_FORM_CACHE_KEY]) == 4


Expand Down
Loading
Loading