Skip to content

Commit

Permalink
Replace empty Jacobians with ZeroBaseForm
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 2, 2025
1 parent 2286596 commit bb04bb0
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 48 deletions.
8 changes: 3 additions & 5 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
from firedrake.ufl_expr import derivative, adjoint
from ufl import replace


Expand All @@ -11,7 +12,6 @@ def _ad_annotate_init(init):
@no_annotations
@wraps(init)
def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u_restrict
Expand All @@ -20,10 +20,8 @@ def wrapper(self, *args, **kwargs):
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
dFdu = derivative(self.F, self.u_restrict)
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
except (TypeError, NotImplementedError):
self._ad_adj_F = None
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
Expand Down
11 changes: 8 additions & 3 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
@staticmethod
def update_tensor(assembled_base_form, tensor):
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
assembled_base_form.dat.copy(tensor.dat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.dat.zero()
else:
assembled_base_form.dat.copy(tensor.dat)
elif isinstance(tensor, matrix.MatrixBase):
# Uses the PETSc copy method.
assembled_base_form.petscmat.copy(tensor.petscmat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.petscmat.zero()
else:
assembled_base_form.petscmat.copy(tensor.petscmat)
else:
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))

Expand Down
67 changes: 40 additions & 27 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,29 @@
import numpy
import collections

from ufl import as_vector
from ufl import as_vector, split, ZeroBaseForm
from ufl.classes import Zero, FixedIndex, ListTensor
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.functionspace import MixedFunctionSpace, FunctionSpace


def subspace(V, indices):
try:
indices = tuple(indices)
except TypeError:
# Only one index provided.
indices = (indices, )
if len(indices) == 1:
W = V[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
else:
W = MixedFunctionSpace([V[i] for i in indices])
return W


class ExtractSubBlock(MultiFunction):
Expand All @@ -26,9 +42,11 @@ def indexed(self, o, child, multiindex):
indices = multiindex.indices()
if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
if len(indices) == 1:
return child.ufl_operands[indices[0]._value]
return child[indices[0]]
elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)):
return child
else:
return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices()))
return ListTensor(*(child[i] for i in indices))
return self.expr(o, child, multiindex)

index_inliner = IndexInliner()
Expand Down Expand Up @@ -57,6 +75,11 @@ def split(self, form, argument_indices):
assert (idx[0] == 0 for idx in self.blocks.values())
return form
f = map_integrand_dags(self, form)
f = expand_derivatives(f)
if f.empty():
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), indices),
arg.number(), part=arg.part())
for arg, indices in zip(form.arguments(), argument_indices)))
return f

expr = MultiFunction.reuse_if_untouched
Expand Down Expand Up @@ -85,8 +108,6 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds):

@PETSc.Log.EventDecorator()
def argument(self, o):
from ufl import split
from firedrake import MixedFunctionSpace, FunctionSpace
V = o.function_space()
if len(V) == 1:
# Not on a mixed space, just return ourselves.
Expand All @@ -95,36 +116,29 @@ def argument(self, o):
if o in self._arg_cache:
return self._arg_cache[o]

V_is = V.subfunctions
indices = self.blocks[o.number()]

try:
indices = tuple(indices)
nidx = len(indices)
except TypeError:
# Only one index provided.
indices = (indices, )
nidx = 1

if nidx == 1:
W = V_is[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
a = (Argument(W, o.number(), part=o.part()), )
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
a = split(Argument(W, o.number(), part=o.part()))
W = subspace(V, indices)
a = Argument(W, o.number(), part=o.part())
a = (a, ) if len(W) == 1 else split(a)

args = []
for i in range(len(V_is)):
for i in range(len(V)):
if i in indices:
c = indices.index(i)
a_ = a[c]
if len(a_.ufl_shape) == 0:
args += [a_]
args.append(a_)
else:
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape))
else:
args += [Zero()
for j in numpy.ndindex(V_is[i].value_shape)]
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))


Expand Down Expand Up @@ -168,11 +182,10 @@ def split_form(form, diagonal=False):
assert len(shape) == 2
for idx in numpy.ndindex(shape):
f = splitter.split(form, idx)
if len(f.integrals()) > 0:
if diagonal:
i, j = idx
if i != j:
continue
idx = (i, )
forms.append(SplitForm(indices=idx, form=f))
if diagonal:
i, j = idx
if i != j:
continue
idx = (i, )
forms.append(SplitForm(indices=idx, form=f))
return tuple(forms)
2 changes: 1 addition & 1 deletion firedrake/preconditioners/massinv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC):
context, keyed on ``"mu"``.
"""
def form(self, pc, test, trial):
_, bcs = super(MassInvPC, self).form(pc, test, trial)
_, bcs = super(MassInvPC, self).form(pc)

appctx = self.get_appctx(pc)
mu = appctx.get("mu", 1.0)
Expand Down
9 changes: 4 additions & 5 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@


def _make_reasons(reasons):
return dict([(getattr(reasons, r), r)
for r in dir(reasons) if not r.startswith('_')])
return {getattr(reasons, r): r
for r in dir(reasons) if not r.startswith('_')}


KSPReasons = _make_reasons(PETSc.KSP.ConvergedReason())
Expand Down Expand Up @@ -333,7 +333,7 @@ def split(self, fields):
# Split it apart to shove in the form.
subsplit = split(subu)
# Permutation from field indexing to indexing of pieces
field_renumbering = dict([f, i] for i, f in enumerate(field))
field_renumbering = {f: i for i, f in enumerate(field)}
vec = []
for i, u in enumerate(us):
if i in field:
Expand All @@ -344,8 +344,7 @@ def split(self, fields):
if u.ufl_shape == ():
vec.append(u)
else:
for idx in numpy.ndindex(u.ufl_shape):
vec.append(u[idx])
vec.extend(u[idx] for idx in numpy.ndindex(u.ufl_shape))

# So now we have a new representation for the solution
# vector in the old problem. For the fields we're going
Expand Down
4 changes: 2 additions & 2 deletions firedrake/tsfc_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import ufl
import finat.ufl
from ufl import Form, conj
from ufl import conj, Form, ZeroBaseForm
from .ufl_expr import TestFunction

from tsfc import compile_form as original_tsfc_compile_form
Expand Down Expand Up @@ -203,7 +203,7 @@ def compile_form(form, name, parameters=None, split=True, interface=None, diagon
iterable = ([(None, )*nargs, form], )
for idx, f in iterable:
f = _real_mangle(f)
if not f.integrals():
if isinstance(f, ZeroBaseForm) or f.empty():
# If we're assembling the R space component of a mixed argument,
# and that component doesn't actually appear in the form then we
# have an empty form, which we should not attempt to assemble.
Expand Down
17 changes: 12 additions & 5 deletions tests/firedrake/slate/test_assemble_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,13 @@ def test_matrix_subblocks(mesh):
refs = dict(split_form(A.form))
_A = A.blocks
for x, y in indices:
ref = assemble(refs[x, y]).M.values
block = _A[x, y]
assert np.allclose(assemble(block).M.values, ref, rtol=1e-14)
ref = refs[x, y]
if isinstance(ref, Form):
assert np.allclose(assemble(block).M.values,
assemble(ref).M.values, rtol=1e-14)
elif isinstance(ref, ZeroBaseForm):
assert block.form == ref

# Mixed blocks
A0101 = _A[:2, :2]
Expand Down Expand Up @@ -280,9 +284,12 @@ def test_matrix_subblocks(mesh):
(A1212_10, refs[(2, 1)])]

# Test assembly of blocks of mixed blocks
for tensor, form in items:
ref = assemble(form).M.values
assert np.allclose(assemble(tensor).M.values, ref, rtol=1e-14)
for block, ref in items:
if isinstance(ref, Form):
assert np.allclose(assemble(block).M.values,
assemble(ref).M.values, rtol=1e-14)
elif isinstance(ref, ZeroBaseForm):
assert block.form == ref


def test_diagonal(mass, matrix_mixed_nofacet):
Expand Down

0 comments on commit bb04bb0

Please sign in to comment.