Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restricted Cofunction RHS #3922

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a86614a
Restricted Cofunction RHS
pbrubeck Dec 11, 2024
aef7886
Fix BCs on Cofunction
pbrubeck Dec 11, 2024
8e5603e
LinearSolver: check function spaces
pbrubeck Dec 11, 2024
2dd4f76
assemble(form, zero_bc_nodes=True) as default
pbrubeck Dec 12, 2024
3c5e64f
Fix FunctionAssignBlock
pbrubeck Dec 12, 2024
e3449f5
Allow Cofunction.assign take in constants
pbrubeck Dec 12, 2024
b755c81
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 13, 2024
0b0296d
Merge branch 'pbrubeck/fix/restricted-cofunction' of github.com:fired…
pbrubeck Dec 13, 2024
40cf6d7
suggestion from code review
pbrubeck Dec 13, 2024
fe30b48
more suggestions from review
pbrubeck Dec 19, 2024
3d49f31
remove BaseFormAssembler test
pbrubeck Dec 19, 2024
6742374
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 20, 2024
df04f4b
only supply relevant kwargs to OneFormAssembler
pbrubeck Dec 20, 2024
474edb3
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 20, 2024
950e42d
Only interpolate the residual, not every cofunction in the RHS
pbrubeck Dec 21, 2024
a86f3f5
DROP BEFORE MERGE
pbrubeck Dec 21, 2024
337d087
Fix tests
pbrubeck Dec 21, 2024
ed34164
Fix adjoint utils
pbrubeck Dec 22, 2024
027ad37
More robust test for (unrestricted) Cofunction RHS
pbrubeck Dec 22, 2024
885958f
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Dec 23, 2024
b48c77c
Merge branch 'master' into pbrubeck/fix/restricted-cofunction
pbrubeck Jan 3, 2025
d68113f
set bcs directly on diagonal Cofunction
pbrubeck Jan 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions demos/netgen/netgen_mesh.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,7 @@ We will now show how to solve the Poisson problem on a high-order mesh, of order

bc = DirichletBC(V, 0.0, [1])
A = assemble(a, bcs=bc)
b = assemble(l)
bc.apply(b)
b = assemble(l, bcs=bc)
solve(A, sol, b, solver_parameters={"ksp_type": "cg", "pc_type": "lu"})

VTKFile("output/Sphere.pvd").write(sol)
Expand Down
6 changes: 3 additions & 3 deletions firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = None
for adj_input in adj_inputs:
if isconstant(c):
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
if self.function_space != self.parent_space:
vec = extract_bc_subvector(
Expand Down Expand Up @@ -88,11 +88,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# you can even use the Function outside its domain.
# For now we will just assume the FunctionSpace is the same for
# the BC and the Function.
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
r = extract_bc_subvector(
adj_value, c.function_space(), bc
)
).riesz_representation("l2")
if adj_output is None:
adj_output = r
else:
Expand Down
1 change: 1 addition & 0 deletions firedrake/adjoint_utils/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
)
diff_expr_assembled = firedrake.Function(adj_input_func.function_space())
diff_expr_assembled.interpolate(ufl.conj(diff_expr))
diff_expr_assembled = diff_expr_assembled.riesz_representation(riesz_map="l2")
adj_output = firedrake.Function(
R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func))
)
Expand Down
19 changes: 7 additions & 12 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,12 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):

def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
kwargs = self.assemble_kwargs.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
kwargs["bcs"] = bcs
dFdu = self._assemble_dFdu_adj(dFdu_adj_form, **kwargs)
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)

for bc in bcs:
bc.apply(dJdu)
bc.zero(dJdu)

adj_sol = firedrake.Function(self.function_space)
firedrake.solve(
Expand All @@ -219,10 +217,8 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
return adj_sol, adj_sol_bdy

def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
adj_sol_bdy = firedrake.Function(
self.function_space.dual(), dJdu.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)).dat)
return adj_sol_bdy
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
return adj_sol_bdy.riesz_representation("l2")

def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
prepared=None):
Expand Down Expand Up @@ -654,9 +650,8 @@ def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
def _adjoint_solve(self, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
for bc in bcs:
bc.apply(dJdu)
for bc in self.bcs:
bc.zero(dJdu)

if (
self._ad_solvers["forward_nlvs"]._problem._constant_jacobian
Expand Down Expand Up @@ -876,7 +871,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs):
self.add_dependency(bc, no_duplicates=True)

def apply_mixedmass(self, a):
b = firedrake.Function(self.target_space)
b = firedrake.Function(self.target_space.dual())
with a.dat.vec_ro as vsrc, b.dat.vec_wo as vrhs:
self.mixed_mass.mult(vsrc, vrhs)
return b
Expand Down
29 changes: 14 additions & 15 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def assemble(expr, *args, **kwargs):
zero_bc_nodes : bool
If `True`, set the boundary condition nodes in the
output tensor to zero rather than to the values prescribed by the
boundary condition. Default is `False`.
boundary condition. Default is `True`.
diagonal : bool
If assembling a matrix is it diagonal?
weight : float
Expand Down Expand Up @@ -156,7 +156,7 @@ def get_assembler(form, *args, **kwargs):
return ZeroFormAssembler(form, form_compiler_parameters=fc_params)
elif len(form.arguments()) == 1 or diagonal:
return OneFormAssembler(form, *args, bcs=bcs, form_compiler_parameters=fc_params, needs_zeroing=kwargs.get('needs_zeroing', True),
zero_bc_nodes=kwargs.get('zero_bc_nodes', False), diagonal=diagonal)
zero_bc_nodes=kwargs.get('zero_bc_nodes', True), diagonal=diagonal)
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
elif len(form.arguments()) == 2:
return TwoFormAssembler(form, *args, **kwargs)
else:
Expand Down Expand Up @@ -308,7 +308,7 @@ def __init__(self,
sub_mat_type=None,
options_prefix=None,
appctx=None,
zero_bc_nodes=False,
zero_bc_nodes=True,
diagonal=False,
weight=1.0,
allocation_integral_types=None):
Expand Down Expand Up @@ -406,7 +406,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params)
elif rank == 1 or (rank == 2 and self._diagonal):
assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal)
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
elif rank == 2:
assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
Expand Down Expand Up @@ -1149,14 +1149,15 @@ class OneFormAssembler(ParloopFormAssembler):

@classmethod
def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=False, diagonal=False, weight=1.0):
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
bcs = solving._extract_bcs(bcs)
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal

@FormAssembler._skip_if_initialised
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=False, diagonal=False, weight=1.0):
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
self._weight = weight
self._diagonal = diagonal
self._zero_bc_nodes = zero_bc_nodes
if self._diagonal and any(isinstance(bc, EquationBCSplit) for bc in self._bcs):
Expand Down Expand Up @@ -1185,23 +1186,21 @@ def _apply_bc(self, tensor, bc):
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if not self._zero_bc_nodes:
tensor_func = tensor.riesz_representation(riesz_map="l2")
if self._diagonal:
bc.set(tensor_func, 1)
else:
bc.apply(tensor_func)
tensor.assign(tensor_func.riesz_representation(riesz_map="l2"))
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# We cannot set primal data on a dual Cofunction, this will throw an error
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space():
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
Expand Down
8 changes: 6 additions & 2 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,12 @@ def assign(self, expr, subset=None, expr_from_assemble=False):
return self.assign(
assembled_expr, subset=subset,
expr_from_assemble=True)

raise ValueError('Cannot assign %s' % expr)
elif expr == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do the same for ufl.classes.Zero on line 199 so this should probably be combined. In general I am a bit concerned about this code because we do not tape Assigner. @Ig-dolci knows more I believe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I copied this from Function.assign, and just noticed that Function.assign is decorated with @FunctionMixin._ad_annotate_assign. This is required for bc.set(Cofunction, float) if we want to set diagonal entries to 1.

self.dat.zero(subset=subset)
else:
from firedrake.assign import Assigner
Assigner(self, expr, subset).assign()
return self

def riesz_representation(self, riesz_map='L2', **solver_options):
"""Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map.
Expand Down
4 changes: 4 additions & 0 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ufl
import finat.ufl

from ufl.duals import is_dual, is_primal
from pyop2 import op2, mpi
from pyop2.utils import as_tuple

Expand Down Expand Up @@ -296,6 +297,9 @@ def restore_work_function(self, function):
cache[function] = False

def __eq__(self, other):
if is_primal(self) != is_primal(other) or \
is_dual(self) != is_dual(other):
return False
try:
return self.topological == other.topological and \
self.mesh() is other.mesh()
Expand Down
5 changes: 5 additions & 0 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def solve(self, x, b):
if not isinstance(b, (function.Function, cofunction.Cofunction)):
raise TypeError("Provided RHS is a '%s', not a Function or Cofunction" % type(b).__name__)

if x.function_space() != self.trial_space or b.function_space() != self.test_space.dual():
# When solving `Ax = b`, with A: V x U -> R, or equivalently A: V -> U*,
# we need to make sure that x and b belong to V and U*, respectively.
raise ValueError("Mismatching function spaces.")
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved

if len(self.trial_space) > 1 and self.nullspace is not None:
self.nullspace._apply(self.trial_space.dof_dset.field_ises)
if len(self.test_space) > 1 and self.transpose_nullspace is not None:
Expand Down
2 changes: 1 addition & 1 deletion firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(self, a, row_bcs=[], col_bcs=[],
self._assemble_action = get_assembler(self.action,
bcs=self.bcs_action,
form_compiler_parameters=self.fc_params,
zero_bc_nodes=True).assemble
).assemble

# For assembling action(adjoint(f), self._y)
# Sorted list of equation bcs
Expand Down
2 changes: 1 addition & 1 deletion firedrake/slate/static_condensation/scpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def initialize(self, pc):
r_expr = reduced_sys.rhs

# Construct the condensed right-hand side
self._assemble_Srhs = get_assembler(r_expr, bcs=bcs, zero_bc_nodes=True, form_compiler_parameters=self.cxt.fc_params).assemble
self._assemble_Srhs = get_assembler(r_expr, bcs=bcs, form_compiler_parameters=self.cxt.fc_params).assemble

# Allocate and set the condensed operator
form_assembler = get_assembler(S_expr, bcs=bcs, form_compiler_parameters=self.cxt.fc_params, mat_type=mat_type, options_prefix=prefix, appctx=self.get_appctx(pc))
Expand Down
2 changes: 1 addition & 1 deletion firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=True).assemble
).assemble

self._jacobian_assembled = False
self._splits = {}
Expand Down
8 changes: 6 additions & 2 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
PETSc, OptionsManager, flatten_parameters, DEFAULT_KSP_PARAMETERS,
DEFAULT_SNES_PARAMETERS
)
from firedrake.function import Function
from firedrake.function import Function, Cofunction
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space
from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin
Expand Down Expand Up @@ -92,7 +92,11 @@ def __init__(self, F, u, bcs=None, J=None,
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
F_arg, = F.arguments()
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
replace_F = {F_arg: v_res, self.u: self.u_restrict}
for c in F.coefficients():
if c.function_space() == V.dual():
replace_F[c] = Cofunction(V_res.dual()).interpolate(c)
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
self.F = replace(F, replace_F)
v_arg, u_arg = self.J.arguments()
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
if self.Jp:
Expand Down
7 changes: 3 additions & 4 deletions tests/firedrake/multigrid/test_poisson_gmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,11 @@ def test_baseform_coarsening(solver_type, mixed):
a_terms.append(inner(grad(u), grad(v)) * dx)
a = sum(a_terms)

assemble_bcs = lambda L: assemble(L, bcs=bcs, zero_bc_nodes=True)
# These are equivalent right-hand sides
sources = [sum(forms), # purely symbolic linear form
assemble_bcs(sum(forms)), # purely numerical cofunction
sum(assemble_bcs(form) for form in forms), # symbolic combination of numerical cofunctions
forms[0] + assemble_bcs(sum(forms[1:])), # symbolic plus numerical
assemble(sum(forms), bcs=bcs), # purely numerical cofunction
sum(assemble(form, bcs=bcs) for form in forms), # symbolic combination of numerical cofunctions
forms[0] + assemble(sum(forms[1:]), bcs=bcs), # symbolic plus numerical
]
solutions = []
for L in sources:
Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/regression/test_assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_one_form_assembler_cache(mesh):
assert len(L._cache[_FORM_CACHE_KEY]) == 3

# changing zero_bc_nodes should increase the cache size
assemble(L, zero_bc_nodes=True)
assemble(L, zero_bc_nodes=False)
assert len(L._cache[_FORM_CACHE_KEY]) == 4


Expand Down
1 change: 1 addition & 0 deletions tests/firedrake/regression/test_assemble_baseform.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def test_zero_form(M, f, one):
assert abs(zero_form - 0.5 * np.prod(f.ufl_shape)) < 1.0e-12


@pytest.mark.xfail(reason="action(M, M) raises primal-dual TypeError")
pbrubeck marked this conversation as resolved.
Show resolved Hide resolved
def test_preprocess_form(M, a, f):
from ufl.algorithms import expand_indices, expand_derivatives

Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/regression/test_bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_bcs_rhs_assemble(a, V):
b1 = assemble(a)
b1_func = b1.riesz_representation(riesz_map="l2")
for bc in bcs:
bc.apply(b1_func)
bc.zero(b1_func)
b1.assign(b1_func.riesz_representation(riesz_map="l2"))
b2 = assemble(a, bcs=bcs)
assert np.allclose(b1.dat.data, b2.dat.data)
Expand Down
6 changes: 2 additions & 4 deletions tests/firedrake/regression/test_netgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def poisson(h, degree=2):

# Assembling matrix
A = assemble(a, bcs=bc)
b = assemble(l)
bc.apply(b)
b = assemble(l, bcs=bc)

# Solving the problem
solve(A, u, b, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"})
Expand Down Expand Up @@ -95,8 +94,7 @@ def poisson3D(h, degree=2):

# Assembling matrix
A = assemble(a, bcs=bc)
b = assemble(l)
bc.apply(b)
b = assemble(l, bcs=bc)

# Solving the problem
solve(A, u, b, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"})
Expand Down
9 changes: 6 additions & 3 deletions tests/firedrake/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def test_poisson_inhomogeneous_bcs_2(j):


@pytest.mark.parallel(nprocs=3)
def test_poisson_inhomogeneous_bcs_high_level_interface():
@pytest.mark.parametrize("assembled_rhs", [False, True], ids=("Form", "Cofunction"))
def test_poisson_inhomogeneous_bcs_high_level_interface(assembled_rhs):
mesh = UnitSquareMesh(8, 8)
V = FunctionSpace(mesh, "CG", 2)
bc1 = DirichletBC(V, 0., 1)
Expand All @@ -155,9 +156,11 @@ def test_poisson_inhomogeneous_bcs_high_level_interface():
v = TestFunction(V)
a = inner(grad(u), grad(v)) * dx
u = Function(V)
L = inner(Constant(0), v) * dx
L = inner(Constant(-2), v) * dx
if assembled_rhs:
L = assemble(L)
solve(a == L, u, bcs=[bc1, bc2], restrict=True)
assert errornorm(SpatialCoordinate(mesh)[0], u) < 1.e-12
assert errornorm(SpatialCoordinate(mesh)[0]**2, u) < 1.e-12


@pytest.mark.parametrize("j", [1, 2, 5])
Expand Down
Loading