diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index 1d14e82c3b..fc5be8486a 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -57,48 +57,38 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, # Catch the case where adj_inputs[0] is just a float return adj_inputs[0] elif isconstant(block_variable.output): - R = block_variable.output._ad_function_space( - prepared.function_space().mesh() + adj_output = self._adj_assign_constant( + prepared, block_variable.output.function_space() ) - return self._adj_assign_constant(prepared, R) else: adj_output = firedrake.Function( - block_variable.output.function_space()) + block_variable.output.function_space() + ) adj_output.assign(prepared) - adj_output = adj_output.riesz_representation(riesz_map="l2") - return adj_output + return adj_output.riesz_representation(riesz_map="l2") else: # Linear combination expr, adj_input_func = prepared - adj_output = firedrake.Function(adj_input_func.function_space()) - if not isconstant(block_variable.output): - diff_expr = ufl.algorithms.expand_derivatives( - ufl.derivative( - expr, block_variable.saved_output, adj_input_func - ) + if isconstant(block_variable.output): + R = block_variable.output._ad_function_space( + adj_input_func.function_space().mesh() ) - # Firedrake does not support assignment of conjugate functions - adj_output.interpolate(ufl.conj(diff_expr)) - adj_output = adj_output.riesz_representation(riesz_map="l2") - else: - mesh = adj_output.function_space().mesh() diff_expr = ufl.algorithms.expand_derivatives( - ufl.derivative( - expr, - block_variable.saved_output, - firedrake.Constant(1., domain=mesh) - ) + ufl.derivative(expr, block_variable.saved_output, + firedrake.Function(R, val=1.0)) ) - adj_output.assign(diff_expr) - return adj_output.dat.inner(adj_input_func.dat) - - if isconstant(block_variable.output): - R = block_variable.output._ad_function_space( - adj_output.function_space().mesh() + diff_expr_assembled = firedrake.Function(adj_input_func.function_space()) + diff_expr_assembled.interpolate(ufl.conj(diff_expr)) + adj_output = firedrake.Function( + R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func)) ) - return self._adj_assign_constant(adj_output, R) else: - return adj_output + adj_output = firedrake.Function(adj_input_func.function_space()) + diff_expr = ufl.algorithms.expand_derivatives( + ufl.derivative(expr, block_variable.saved_output, adj_input_func) + ) + adj_output.interpolate(ufl.conj(diff_expr)) + return adj_output.riesz_representation(riesz_map="l2") def _adj_assign_constant(self, adj_output, constant_fs): r = firedrake.Function(constant_fs) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 04df8eebd6..b490f34128 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -7,7 +7,6 @@ import firedrake from .checkpointing import disk_checkpointing, CheckpointFunction, \ CheckpointBase, checkpoint_init_data, DelegatedFunctionCheckpoint -from numbers import Number class FunctionMixin(FloatingType): @@ -224,26 +223,27 @@ def _ad_create_checkpoint(self): def _ad_convert_riesz(self, value, options=None): from firedrake import Function, Cofunction - options = {} if options is None else options - riesz_representation = options.get("riesz_representation", "l2") - solver_options = options.get("solver_options", {}) V = options.get("function_space", self.function_space()) + if value == 0.: + # In adjoint-based differentiation, value == 0. arises only when + # the functional is independent on the control variable. + # In this case, we do not apply the Riesz map and return a zero + # Cofunction. + return Cofunction(V.dual()) - if riesz_representation != "l2" and not isinstance(value, Cofunction): - raise TypeError("Expected a Cofunction") - elif not isinstance(value, (Number, Cofunction, Function)): - raise TypeError("Expected a Cofunction, Function or a float") + options = {} if options is None else options + riesz_representation = options.get("riesz_representation", "L2") + solver_options = options.get("solver_options", {}) + if not isinstance(value, (Cofunction, Function)): + raise TypeError("Expected a Cofunction or a Function") if riesz_representation == "l2": - if isinstance(value, (Cofunction, Function)): - return Function(V, val=value.dat) - else: - f = Function(V) - with stop_annotating(): - f.assign(value) - return f + return Function(V, val=value.dat) elif riesz_representation in ("L2", "H1"): + if not isinstance(value, Cofunction): + raise TypeError("Expected a Cofunction") + ret = Function(V) a = self._define_riesz_map_form(riesz_representation, V) firedrake.solve(a == value, ret, **solver_options) @@ -253,7 +253,7 @@ def _ad_convert_riesz(self, value, options=None): return riesz_representation(value) else: - raise NotImplementedError( + raise ValueError( "Unknown Riesz representation %s" % riesz_representation) def _define_riesz_map_form(self, riesz_representation, V): @@ -276,9 +276,9 @@ def _define_riesz_map_form(self, riesz_representation, V): @no_annotations def _ad_convert_type(self, value, options=None): # `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz` - options = {} if options is None else options - riesz_representation = options.get("riesz_representation", "L2") - if riesz_representation is None: + options = {} if options is None else options.copy() + options.setdefault("riesz_representation", "L2") + if options["riesz_representation"] is None: return value else: return self._ad_convert_riesz(value, options=options) @@ -317,7 +317,7 @@ def _ad_dot(self, other, options=None): from firedrake import assemble options = {} if options is None else options - riesz_representation = options.get("riesz_representation", "l2") + riesz_representation = options.get("riesz_representation", "L2") if riesz_representation == "l2": return self.dat.inner(other.dat) elif riesz_representation == "L2": diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py index 2ddaf3b8e5..e58b858493 100644 --- a/firedrake/ml/pytorch/fem_operator.py +++ b/firedrake/ml/pytorch/fem_operator.py @@ -83,7 +83,7 @@ def backward(ctx, grad_output): adj_input = float(adj_input) # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - adj_output = F.derivative(adj_input=adj_input) + adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "l2"}) # Tuplify adjoint output adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 4f8b046606..bb3a33e07c 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -889,13 +889,43 @@ def test_cofunction_subfunctions_with_adjoint(): @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done -def test_none_riesz_representation_to_derivative(): +def test_riesz_representation_for_adjoints(): + # Check if the Riesz representation norms for adjoints are working as expected. mesh = UnitIntervalMesh(1) space = FunctionSpace(mesh, "Lagrange", 1) - u = Function(space).interpolate(SpatialCoordinate(mesh)[0]) - J = assemble((u ** 2) * dx) - rf = ReducedFunctional(J, Control(u)) - assert isinstance(rf.derivative(), Function) - assert isinstance(rf.derivative(options={"riesz_representation": "H1"}), Function) - assert isinstance(rf.derivative(options={"riesz_representation": "L2"}), Function) - assert isinstance(rf.derivative(options={"riesz_representation": None}), Cofunction) + f = Function(space).interpolate(SpatialCoordinate(mesh)[0]) + J = assemble((f ** 2) * dx) + rf = ReducedFunctional(J, Control(f)) + with stop_annotating(): + v = TestFunction(space) + u = TrialFunction(space) + dJdu_cofunction = assemble(derivative((f ** 2) * dx, f, v)) + + # Riesz representation with l2 + dJdu_function_l2 = Function(space, val=dJdu_cofunction.dat) + + # Riesz representation with H1 + a = u * v * dx + inner(grad(u), grad(v)) * dx + dJdu_function_H1 = Function(space) + solve(a == dJdu_cofunction, dJdu_function_H1) + + # Riesz representation with L2 + a = u*v*dx + dJdu_function_L2 = Function(space) + solve(a == dJdu_cofunction, dJdu_function_L2) + + dJdu_none = rf.derivative(options={"riesz_representation": None}) + dJdu_l2 = rf.derivative(options={"riesz_representation": "l2"}) + dJdu_H1 = rf.derivative(options={"riesz_representation": "H1"}) + dJdu_L2 = rf.derivative(options={"riesz_representation": "L2"}) + dJdu_default_L2 = rf.derivative() + assert ( + isinstance(dJdu_none, Cofunction) and isinstance(dJdu_function_l2, Function) + and isinstance(dJdu_H1, Function) and isinstance(dJdu_default_L2, Function) + and isinstance(dJdu_L2, Function) + and np.allclose(dJdu_none.dat.data, dJdu_cofunction.dat.data) + and np.allclose(dJdu_l2.dat.data, dJdu_function_l2.dat.data) + and np.allclose(dJdu_H1.dat.data, dJdu_function_H1.dat.data) + and np.allclose(dJdu_default_L2.dat.data, dJdu_function_L2.dat.data) + and np.allclose(dJdu_L2.dat.data, dJdu_function_L2.dat.data) + )