From bb04bb00860b9f697df226b1ec1ce35f35e4c2bb Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 31 Dec 2024 20:01:24 -0600 Subject: [PATCH] Replace empty Jacobians with ZeroBaseForm --- firedrake/adjoint_utils/variational_solver.py | 8 +-- firedrake/assemble.py | 11 ++- firedrake/formmanipulation.py | 67 +++++++++++-------- firedrake/preconditioners/massinv.py | 2 +- firedrake/solving_utils.py | 9 ++- firedrake/tsfc_interface.py | 4 +- .../firedrake/slate/test_assemble_tensors.py | 17 +++-- 7 files changed, 70 insertions(+), 48 deletions(-) diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index c90d2668e0..79eb09096e 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -2,6 +2,7 @@ from functools import wraps from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock +from firedrake.ufl_expr import derivative, adjoint from ufl import replace @@ -11,7 +12,6 @@ def _ad_annotate_init(init): @no_annotations @wraps(init) def wrapper(self, *args, **kwargs): - from firedrake import derivative, adjoint, TrialFunction init(self, *args, **kwargs) self._ad_F = self.F self._ad_u = self.u_restrict @@ -20,10 +20,8 @@ def wrapper(self, *args, **kwargs): try: # Some forms (e.g. SLATE tensors) are not currently # differentiable. - dFdu = derivative(self.F, - self.u_restrict, - TrialFunction(self.u_restrict.function_space())) - self._ad_adj_F = adjoint(dFdu) + dFdu = derivative(self.F, self.u_restrict) + self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) except (TypeError, NotImplementedError): self._ad_adj_F = None self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear} diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f3049ae01c..60c934b6c7 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args): @staticmethod def update_tensor(assembled_base_form, tensor): if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): - assembled_base_form.dat.copy(tensor.dat) + if isinstance(assembled_base_form, ufl.ZeroBaseForm): + tensor.dat.zero() + else: + assembled_base_form.dat.copy(tensor.dat) elif isinstance(tensor, matrix.MatrixBase): - # Uses the PETSc copy method. - assembled_base_form.petscmat.copy(tensor.petscmat) + if isinstance(assembled_base_form, ufl.ZeroBaseForm): + tensor.petscmat.zero() + else: + assembled_base_form.petscmat.copy(tensor.petscmat) else: raise NotImplementedError("Cannot update tensor of type %s" % type(tensor)) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 35a6789107..3179961df8 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -2,13 +2,29 @@ import numpy import collections -from ufl import as_vector +from ufl import as_vector, split, ZeroBaseForm from ufl.classes import Zero, FixedIndex, ListTensor from ufl.algorithms.map_integrands import map_integrand_dags +from ufl.algorithms import expand_derivatives from ufl.corealg.map_dag import MultiFunction, map_expr_dags from firedrake.petsc import PETSc from firedrake.ufl_expr import Argument +from firedrake.functionspace import MixedFunctionSpace, FunctionSpace + + +def subspace(V, indices): + try: + indices = tuple(indices) + except TypeError: + # Only one index provided. + indices = (indices, ) + if len(indices) == 1: + W = V[indices[0]] + W = FunctionSpace(W.mesh(), W.ufl_element()) + else: + W = MixedFunctionSpace([V[i] for i in indices]) + return W class ExtractSubBlock(MultiFunction): @@ -26,9 +42,11 @@ def indexed(self, o, child, multiindex): indices = multiindex.indices() if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices): if len(indices) == 1: - return child.ufl_operands[indices[0]._value] + return child[indices[0]] + elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)): + return child else: - return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices())) + return ListTensor(*(child[i] for i in indices)) return self.expr(o, child, multiindex) index_inliner = IndexInliner() @@ -57,6 +75,11 @@ def split(self, form, argument_indices): assert (idx[0] == 0 for idx in self.blocks.values()) return form f = map_integrand_dags(self, form) + f = expand_derivatives(f) + if f.empty(): + f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), indices), + arg.number(), part=arg.part()) + for arg, indices in zip(form.arguments(), argument_indices))) return f expr = MultiFunction.reuse_if_untouched @@ -85,8 +108,6 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds): @PETSc.Log.EventDecorator() def argument(self, o): - from ufl import split - from firedrake import MixedFunctionSpace, FunctionSpace V = o.function_space() if len(V) == 1: # Not on a mixed space, just return ourselves. @@ -95,36 +116,29 @@ def argument(self, o): if o in self._arg_cache: return self._arg_cache[o] - V_is = V.subfunctions indices = self.blocks[o.number()] try: indices = tuple(indices) - nidx = len(indices) except TypeError: # Only one index provided. indices = (indices, ) - nidx = 1 - if nidx == 1: - W = V_is[indices[0]] - W = FunctionSpace(W.mesh(), W.ufl_element()) - a = (Argument(W, o.number(), part=o.part()), ) - else: - W = MixedFunctionSpace([V_is[i] for i in indices]) - a = split(Argument(W, o.number(), part=o.part())) + W = subspace(V, indices) + a = Argument(W, o.number(), part=o.part()) + a = (a, ) if len(W) == 1 else split(a) + args = [] - for i in range(len(V_is)): + for i in range(len(V)): if i in indices: c = indices.index(i) a_ = a[c] if len(a_.ufl_shape) == 0: - args += [a_] + args.append(a_) else: - args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)] + args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape)) else: - args += [Zero() - for j in numpy.ndindex(V_is[i].value_shape)] + args.extend(Zero() for j in numpy.ndindex(V[i].value_shape)) return self._arg_cache.setdefault(o, as_vector(args)) @@ -168,11 +182,10 @@ def split_form(form, diagonal=False): assert len(shape) == 2 for idx in numpy.ndindex(shape): f = splitter.split(form, idx) - if len(f.integrals()) > 0: - if diagonal: - i, j = idx - if i != j: - continue - idx = (i, ) - forms.append(SplitForm(indices=idx, form=f)) + if diagonal: + i, j = idx + if i != j: + continue + idx = (i, ) + forms.append(SplitForm(indices=idx, form=f)) return tuple(forms) diff --git a/firedrake/preconditioners/massinv.py b/firedrake/preconditioners/massinv.py index 92f286c708..d29c704e8b 100644 --- a/firedrake/preconditioners/massinv.py +++ b/firedrake/preconditioners/massinv.py @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC): context, keyed on ``"mu"``. """ def form(self, pc, test, trial): - _, bcs = super(MassInvPC, self).form(pc, test, trial) + _, bcs = super(MassInvPC, self).form(pc) appctx = self.get_appctx(pc) mu = appctx.get("mu", 1.0) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index 9e843016b5..789a6f1880 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -12,8 +12,8 @@ def _make_reasons(reasons): - return dict([(getattr(reasons, r), r) - for r in dir(reasons) if not r.startswith('_')]) + return {getattr(reasons, r): r + for r in dir(reasons) if not r.startswith('_')} KSPReasons = _make_reasons(PETSc.KSP.ConvergedReason()) @@ -333,7 +333,7 @@ def split(self, fields): # Split it apart to shove in the form. subsplit = split(subu) # Permutation from field indexing to indexing of pieces - field_renumbering = dict([f, i] for i, f in enumerate(field)) + field_renumbering = {f: i for i, f in enumerate(field)} vec = [] for i, u in enumerate(us): if i in field: @@ -344,8 +344,7 @@ def split(self, fields): if u.ufl_shape == (): vec.append(u) else: - for idx in numpy.ndindex(u.ufl_shape): - vec.append(u[idx]) + vec.extend(u[idx] for idx in numpy.ndindex(u.ufl_shape)) # So now we have a new representation for the solution # vector in the old problem. For the fields we're going diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index ba10d79507..1117f54bd4 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -11,7 +11,7 @@ import ufl import finat.ufl -from ufl import Form, conj +from ufl import conj, Form, ZeroBaseForm from .ufl_expr import TestFunction from tsfc import compile_form as original_tsfc_compile_form @@ -203,7 +203,7 @@ def compile_form(form, name, parameters=None, split=True, interface=None, diagon iterable = ([(None, )*nargs, form], ) for idx, f in iterable: f = _real_mangle(f) - if not f.integrals(): + if isinstance(f, ZeroBaseForm) or f.empty(): # If we're assembling the R space component of a mixed argument, # and that component doesn't actually appear in the form then we # have an empty form, which we should not attempt to assemble. diff --git a/tests/firedrake/slate/test_assemble_tensors.py b/tests/firedrake/slate/test_assemble_tensors.py index 5aff159b9b..c35d43e27e 100644 --- a/tests/firedrake/slate/test_assemble_tensors.py +++ b/tests/firedrake/slate/test_assemble_tensors.py @@ -249,9 +249,13 @@ def test_matrix_subblocks(mesh): refs = dict(split_form(A.form)) _A = A.blocks for x, y in indices: - ref = assemble(refs[x, y]).M.values block = _A[x, y] - assert np.allclose(assemble(block).M.values, ref, rtol=1e-14) + ref = refs[x, y] + if isinstance(ref, Form): + assert np.allclose(assemble(block).M.values, + assemble(ref).M.values, rtol=1e-14) + elif isinstance(ref, ZeroBaseForm): + assert block.form == ref # Mixed blocks A0101 = _A[:2, :2] @@ -280,9 +284,12 @@ def test_matrix_subblocks(mesh): (A1212_10, refs[(2, 1)])] # Test assembly of blocks of mixed blocks - for tensor, form in items: - ref = assemble(form).M.values - assert np.allclose(assemble(tensor).M.values, ref, rtol=1e-14) + for block, ref in items: + if isinstance(ref, Form): + assert np.allclose(assemble(block).M.values, + assemble(ref).M.values, rtol=1e-14) + elif isinstance(ref, ZeroBaseForm): + assert block.form == ref def test_diagonal(mass, matrix_mixed_nofacet):