From 51a7f98852ab65136b68ce992152feb7573e42dd Mon Sep 17 00:00:00 2001 From: Alberto Paganini <20994366+APaganini@users.noreply.github.com> Date: Fri, 17 May 2024 17:14:32 +0100 Subject: [PATCH 01/16] changed l2 into L2 --- firedrake/adjoint_utils/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 04df8eebd6..85cc47c124 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -225,7 +225,7 @@ def _ad_convert_riesz(self, value, options=None): from firedrake import Function, Cofunction options = {} if options is None else options - riesz_representation = options.get("riesz_representation", "l2") + riesz_representation = options.get("riesz_representation", "L2") solver_options = options.get("solver_options", {}) V = options.get("function_space", self.function_space()) From a0097dde1ef7d3d455378a2df54889b716c1dea7 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Tue, 21 May 2024 15:26:02 +0100 Subject: [PATCH 02/16] Cofunction assign number for L2 and H1 riesz maps --- firedrake/adjoint_utils/function.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 85cc47c124..3f429d02a6 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -229,9 +229,7 @@ def _ad_convert_riesz(self, value, options=None): solver_options = options.get("solver_options", {}) V = options.get("function_space", self.function_space()) - if riesz_representation != "l2" and not isinstance(value, Cofunction): - raise TypeError("Expected a Cofunction") - elif not isinstance(value, (Number, Cofunction, Function)): + if not isinstance(value, (Number, Cofunction, Function)): raise TypeError("Expected a Cofunction, Function or a float") if riesz_representation == "l2": @@ -242,11 +240,15 @@ def _ad_convert_riesz(self, value, options=None): with stop_annotating(): f.assign(value) return f - elif riesz_representation in ("L2", "H1"): + if isinstance(value, Number): + b = Cofunction(V.dual()) + b.assign(value) + else: + b = value ret = Function(V) a = self._define_riesz_map_form(riesz_representation, V) - firedrake.solve(a == value, ret, **solver_options) + firedrake.solve(a == b, ret, **solver_options) return ret elif callable(riesz_representation): @@ -277,6 +279,8 @@ def _define_riesz_map_form(self, riesz_representation, V): def _ad_convert_type(self, value, options=None): # `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz` options = {} if options is None else options + if "riesz_representation" not in options: + options = {"riesz_representation": "L2"} riesz_representation = options.get("riesz_representation", "L2") if riesz_representation is None: return value From 37f5ba7d2363a3c139ade1416f99c5e331bf25cf Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Tue, 21 May 2024 17:02:25 +0100 Subject: [PATCH 03/16] L2 riesz maps also _ad_dot method --- firedrake/adjoint_utils/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 3f429d02a6..712663a7da 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -321,7 +321,7 @@ def _ad_dot(self, other, options=None): from firedrake import assemble options = {} if options is None else options - riesz_representation = options.get("riesz_representation", "l2") + riesz_representation = options.get("riesz_representation", "L2") if riesz_representation == "l2": return self.dat.inner(other.dat) elif riesz_representation == "L2": From 03d49c4e9e29e65204b085f91acee9c6a3c1b6b1 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Tue, 21 May 2024 17:46:12 +0100 Subject: [PATCH 04/16] add test to verify riesz maps for grad adjoints --- firedrake/adjoint_utils/function.py | 14 ++++--- tests/regression/test_adjoint_operators.py | 47 +++++++++++++++++----- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 712663a7da..7e0f6eb9a6 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -232,6 +232,13 @@ def _ad_convert_riesz(self, value, options=None): if not isinstance(value, (Number, Cofunction, Function)): raise TypeError("Expected a Cofunction, Function or a float") + if value == 0.: + # Default of a function datatype is zero + return Function(V) + + if riesz_representation != "l2" and not isinstance(value, Cofunction): + raise TypeError("Expected a Cofunction") + if riesz_representation == "l2": if isinstance(value, (Cofunction, Function)): return Function(V, val=value.dat) @@ -241,14 +248,9 @@ def _ad_convert_riesz(self, value, options=None): f.assign(value) return f elif riesz_representation in ("L2", "H1"): - if isinstance(value, Number): - b = Cofunction(V.dual()) - b.assign(value) - else: - b = value ret = Function(V) a = self._define_riesz_map_form(riesz_representation, V) - firedrake.solve(a == b, ret, **solver_options) + firedrake.solve(a == value, ret, **solver_options) return ret elif callable(riesz_representation): diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 4f8b046606..3d8e718578 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -744,7 +744,8 @@ def test_copy_function(): def test_consecutive_nonlinear_solves(): mesh = UnitSquareMesh(1, 1) V = FunctionSpace(mesh, "CG", 1) - uic = Constant(2.0, domain=mesh) + uic = Function(V) + uic.assign(2.0) u1 = Function(V).assign(uic) u0 = Function(u1) v = TestFunction(V) @@ -756,8 +757,7 @@ def test_consecutive_nonlinear_solves(): solver.solve() J = assemble(u1**16*dx) rf = ReducedFunctional(J, Control(uic)) - h = Constant(0.01, domain=mesh) - assert taylor_test(rf, uic, h) > 1.9 + assert taylor_test(rf, uic, Function(V).assign(0.1)) > 1.9 @pytest.mark.skipcomplex @@ -892,10 +892,37 @@ def test_cofunction_subfunctions_with_adjoint(): def test_none_riesz_representation_to_derivative(): mesh = UnitIntervalMesh(1) space = FunctionSpace(mesh, "Lagrange", 1) - u = Function(space).interpolate(SpatialCoordinate(mesh)[0]) - J = assemble((u ** 2) * dx) - rf = ReducedFunctional(J, Control(u)) - assert isinstance(rf.derivative(), Function) - assert isinstance(rf.derivative(options={"riesz_representation": "H1"}), Function) - assert isinstance(rf.derivative(options={"riesz_representation": "L2"}), Function) - assert isinstance(rf.derivative(options={"riesz_representation": None}), Cofunction) + f = Function(space).interpolate(SpatialCoordinate(mesh)[0]) + J = assemble((f ** 2) * dx) + rf = ReducedFunctional(J, Control(f)) + with stop_annotating(): + v = TestFunction(space) + u = TrialFunction(space) + dJdu_cofunction = assemble(2 * inner(f, v) * dx) + + # Riesz representation with l2 + dJdu_function_l2 = Function(space, val=dJdu_cofunction.dat) + + # Riesz representation with H1 + a = firedrake.inner(u, v)*firedrake.dx \ + + firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx + dJdu_function_H1 = Function(space) + solve(a == dJdu_cofunction, dJdu_function_H1) + + # Riesz representation with L2 + a = firedrake.inner(u, v)*firedrake.dx + dJdu_function_L2 = Function(space) + solve(a == dJdu_cofunction, dJdu_function_L2) + + dJdu_none = rf.derivative(options={"riesz_representation": None}) + dJdu_l2 = rf.derivative(options={"riesz_representation": "l2"}) + dJdu_H1 = rf.derivative(options={"riesz_representation": "H1"}) + dJdu_default_L2 = rf.derivative() + assert ( + isinstance(dJdu_none, Cofunction) and isinstance(dJdu_function_l2, Function) + and isinstance(dJdu_H1, Function) and isinstance(dJdu_default_L2, Function) + and np.allclose(dJdu_none.dat.data, dJdu_cofunction.dat.data) + and np.allclose(dJdu_l2.dat.data, dJdu_function_l2.dat.data) + and np.allclose(dJdu_H1.dat.data, dJdu_function_H1.dat.data) + and np.allclose(dJdu_default_L2.dat.data, dJdu_function_L2.dat.data) + ) From 94c8b90103606a0956a562cd0cab90240ca0a04b Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Wed, 22 May 2024 15:36:48 +0100 Subject: [PATCH 05/16] use l2 as default when value is a Number --- firedrake/adjoint_utils/function.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 7e0f6eb9a6..ad32ad6478 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -232,21 +232,15 @@ def _ad_convert_riesz(self, value, options=None): if not isinstance(value, (Number, Cofunction, Function)): raise TypeError("Expected a Cofunction, Function or a float") - if value == 0.: + if isinstance(value, Number): # Default of a function datatype is zero - return Function(V) - - if riesz_representation != "l2" and not isinstance(value, Cofunction): - raise TypeError("Expected a Cofunction") + # Only works for l2 riesz representation + f = Function(V) + f.assign(value) + return f if riesz_representation == "l2": - if isinstance(value, (Cofunction, Function)): - return Function(V, val=value.dat) - else: - f = Function(V) - with stop_annotating(): - f.assign(value) - return f + return Function(V, val=value.dat) elif riesz_representation in ("L2", "H1"): ret = Function(V) a = self._define_riesz_map_form(riesz_representation, V) From 7ca0885a30493e15e1eec47369e0a67dd12bd9c4 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Wed, 22 May 2024 17:57:29 +0100 Subject: [PATCH 06/16] get derivative instead gradient in fem_operotor --- firedrake/ml/pytorch/fem_operator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py index 2ddaf3b8e5..cfcaa320bb 100644 --- a/firedrake/ml/pytorch/fem_operator.py +++ b/firedrake/ml/pytorch/fem_operator.py @@ -81,9 +81,8 @@ def backward(ctx, grad_output): if isinstance(adj_input, Constant) and adj_input.ufl_shape == (): # This will later on result in an `AdjFloat` adjoint input instead of a Constant adj_input = float(adj_input) - # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - adj_output = F.derivative(adj_input=adj_input) + adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": None}) # Tuplify adjoint output adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output From edd46a5b8e2364af165944bef22466b1a18a5c6a Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Thu, 23 May 2024 11:18:20 +0100 Subject: [PATCH 07/16] Use l2 in fem operator --- firedrake/ml/pytorch/fem_operator.py | 2 +- tests/regression/test_adjoint_operators.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py index cfcaa320bb..aaed11edb1 100644 --- a/firedrake/ml/pytorch/fem_operator.py +++ b/firedrake/ml/pytorch/fem_operator.py @@ -82,7 +82,7 @@ def backward(ctx, grad_output): # This will later on result in an `AdjFloat` adjoint input instead of a Constant adj_input = float(adj_input) # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": None}) + adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": 'l2'}) # Tuplify adjoint output adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 3d8e718578..e72c0b1796 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -744,8 +744,7 @@ def test_copy_function(): def test_consecutive_nonlinear_solves(): mesh = UnitSquareMesh(1, 1) V = FunctionSpace(mesh, "CG", 1) - uic = Function(V) - uic.assign(2.0) + uic = Constant(2.0, domain=mesh) u1 = Function(V).assign(uic) u0 = Function(u1) v = TestFunction(V) @@ -757,7 +756,8 @@ def test_consecutive_nonlinear_solves(): solver.solve() J = assemble(u1**16*dx) rf = ReducedFunctional(J, Control(uic)) - assert taylor_test(rf, uic, Function(V).assign(0.1)) > 1.9 + h = Constant(0.01, domain=mesh) + assert taylor_test(rf, uic, h) > 1.9 @pytest.mark.skipcomplex From b892395cc8018f093c7b8d65c08a182a650fcfd6 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Thu, 23 May 2024 13:52:29 +0100 Subject: [PATCH 08/16] Write Constant(val, mesh) as a function in R --- firedrake/adjoint_utils/blocks/function.py | 5 +++-- firedrake/adjoint_utils/function.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index 1d14e82c3b..38efaa5d1f 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -60,7 +60,8 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, R = block_variable.output._ad_function_space( prepared.function_space().mesh() ) - return self._adj_assign_constant(prepared, R) + adj_output = self._adj_assign_constant(prepared, R) + return adj_output.riesz_representation(riesz_map="l2") else: adj_output = firedrake.Function( block_variable.output.function_space()) @@ -86,7 +87,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, ufl.derivative( expr, block_variable.saved_output, - firedrake.Constant(1., domain=mesh) + firedrake.Function(firedrake.FunctionSpace(mesh, "R", 0), val=1.0) ) ) adj_output.assign(diff_expr) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index ad32ad6478..8d3114cbf8 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -233,14 +233,17 @@ def _ad_convert_riesz(self, value, options=None): raise TypeError("Expected a Cofunction, Function or a float") if isinstance(value, Number): - # Default of a function datatype is zero - # Only works for l2 riesz representation + # It is applied l2 representation for this case. f = Function(V) f.assign(value) return f if riesz_representation == "l2": return Function(V, val=value.dat) + + if riesz_representation != "l2" and not isinstance(value, Cofunction): + raise TypeError("Expected a Cofunction") + elif riesz_representation in ("L2", "H1"): ret = Function(V) a = self._define_riesz_map_form(riesz_representation, V) @@ -275,8 +278,6 @@ def _define_riesz_map_form(self, riesz_representation, V): def _ad_convert_type(self, value, options=None): # `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz` options = {} if options is None else options - if "riesz_representation" not in options: - options = {"riesz_representation": "L2"} riesz_representation = options.get("riesz_representation", "L2") if riesz_representation is None: return value From 84e3532b823d6a373a3ab09f0634001e7774d65f Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Thu, 23 May 2024 13:57:05 +0100 Subject: [PATCH 09/16] minor changes --- firedrake/ml/pytorch/fem_operator.py | 1 + tests/regression/test_adjoint_operators.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py index aaed11edb1..a489f169cb 100644 --- a/firedrake/ml/pytorch/fem_operator.py +++ b/firedrake/ml/pytorch/fem_operator.py @@ -81,6 +81,7 @@ def backward(ctx, grad_output): if isinstance(adj_input, Constant) and adj_input.ufl_shape == (): # This will later on result in an `AdjFloat` adjoint input instead of a Constant adj_input = float(adj_input) + # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": 'l2'}) diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index e72c0b1796..1ca7c6e360 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -889,7 +889,8 @@ def test_cofunction_subfunctions_with_adjoint(): @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done -def test_none_riesz_representation_to_derivative(): +def test_riesz_representation_for_adjoints(): + # Test is the riesz representation norms for adjoints are working as expected. mesh = UnitIntervalMesh(1) space = FunctionSpace(mesh, "Lagrange", 1) f = Function(space).interpolate(SpatialCoordinate(mesh)[0]) From 1cda85d31af74c17a6ae397e8d4a16a54812f8c0 Mon Sep 17 00:00:00 2001 From: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com> Date: Thu, 23 May 2024 14:21:30 +0100 Subject: [PATCH 10/16] Update tests/regression/test_adjoint_operators.py --- tests/regression/test_adjoint_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 1ca7c6e360..5d7cee4134 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -890,7 +890,7 @@ def test_cofunction_subfunctions_with_adjoint(): @pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done def test_riesz_representation_for_adjoints(): - # Test is the riesz representation norms for adjoints are working as expected. + # Check if the Riesz representation norms for adjoints are working as expected. mesh = UnitIntervalMesh(1) space = FunctionSpace(mesh, "Lagrange", 1) f = Function(space).interpolate(SpatialCoordinate(mesh)[0]) From 0a00bb4db5a3ef3527c2704e5d485ca4bb39611e Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Tue, 28 May 2024 12:39:46 +0100 Subject: [PATCH 11/16] review test script; improve docs for l2 riesz maps when the value is a number. --- firedrake/adjoint_utils/function.py | 18 +++++++++++++----- tests/regression/test_adjoint_operators.py | 10 ++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 8d3114cbf8..cba113beee 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -229,14 +229,22 @@ def _ad_convert_riesz(self, value, options=None): solver_options = options.get("solver_options", {}) V = options.get("function_space", self.function_space()) - if not isinstance(value, (Number, Cofunction, Function)): + if not isinstance(value, (Number, Cofunction, Function, Number)): raise TypeError("Expected a Cofunction, Function or a float") if isinstance(value, Number): - # It is applied l2 representation for this case. - f = Function(V) - f.assign(value) - return f + if value == 0.: + # l2 Riesz map is directly applied when the value is a real number 0.. + # This is seen in adjoint-based derivative when the functional + # is independent of the control variable. + return Function(V) + elif self.ufl_element().family() == "Real": + # Apply the l2 Riesz map for the case where self is a function in Real space. + f = Function(V) + f.assign(value) + return f + else: + raise TypeError("Riesz map of a non-zero scalar is not supported for non-Real function spaces.") if riesz_representation == "l2": return Function(V, val=value.dat) diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 5d7cee4134..bb3a33e07c 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -899,31 +899,33 @@ def test_riesz_representation_for_adjoints(): with stop_annotating(): v = TestFunction(space) u = TrialFunction(space) - dJdu_cofunction = assemble(2 * inner(f, v) * dx) + dJdu_cofunction = assemble(derivative((f ** 2) * dx, f, v)) # Riesz representation with l2 dJdu_function_l2 = Function(space, val=dJdu_cofunction.dat) # Riesz representation with H1 - a = firedrake.inner(u, v)*firedrake.dx \ - + firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx + a = u * v * dx + inner(grad(u), grad(v)) * dx dJdu_function_H1 = Function(space) solve(a == dJdu_cofunction, dJdu_function_H1) # Riesz representation with L2 - a = firedrake.inner(u, v)*firedrake.dx + a = u*v*dx dJdu_function_L2 = Function(space) solve(a == dJdu_cofunction, dJdu_function_L2) dJdu_none = rf.derivative(options={"riesz_representation": None}) dJdu_l2 = rf.derivative(options={"riesz_representation": "l2"}) dJdu_H1 = rf.derivative(options={"riesz_representation": "H1"}) + dJdu_L2 = rf.derivative(options={"riesz_representation": "L2"}) dJdu_default_L2 = rf.derivative() assert ( isinstance(dJdu_none, Cofunction) and isinstance(dJdu_function_l2, Function) and isinstance(dJdu_H1, Function) and isinstance(dJdu_default_L2, Function) + and isinstance(dJdu_L2, Function) and np.allclose(dJdu_none.dat.data, dJdu_cofunction.dat.data) and np.allclose(dJdu_l2.dat.data, dJdu_function_l2.dat.data) and np.allclose(dJdu_H1.dat.data, dJdu_function_H1.dat.data) and np.allclose(dJdu_default_L2.dat.data, dJdu_function_L2.dat.data) + and np.allclose(dJdu_L2.dat.data, dJdu_function_L2.dat.data) ) From 85ad86b3f2fcae5003fb06935de1f347a91fea44 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Wed, 29 May 2024 11:25:50 +0100 Subject: [PATCH 12/16] use setdefault L2 --- firedrake/adjoint_utils/function.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index cba113beee..eb3cbef12a 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -285,9 +285,9 @@ def _define_riesz_map_form(self, riesz_representation, V): @no_annotations def _ad_convert_type(self, value, options=None): # `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz` - options = {} if options is None else options - riesz_representation = options.get("riesz_representation", "L2") - if riesz_representation is None: + options = {} if options is None else options.copy() + options.setdefault("riesz_representation", "L2") + if options["riesz_representation"] is None: return value else: return self._ad_convert_riesz(value, options=options) From c182a96b7780a0e1a80ec02b3a3b236ba4322884 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Sat, 1 Jun 2024 11:58:46 +0100 Subject: [PATCH 13/16] Perhaps a code enhancement for firedrake Constant (Real space) control variables --- firedrake/adjoint_utils/blocks/function.py | 31 ++++++++-------------- firedrake/adjoint_utils/function.py | 27 ++++++------------- 2 files changed, 19 insertions(+), 39 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index 38efaa5d1f..b893bec655 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -61,45 +61,36 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared.function_space().mesh() ) adj_output = self._adj_assign_constant(prepared, R) - return adj_output.riesz_representation(riesz_map="l2") else: adj_output = firedrake.Function( block_variable.output.function_space()) adj_output.assign(prepared) - adj_output = adj_output.riesz_representation(riesz_map="l2") - return adj_output + return adj_output.riesz_representation(riesz_map="l2") else: # Linear combination expr, adj_input_func = prepared adj_output = firedrake.Function(adj_input_func.function_space()) - if not isconstant(block_variable.output): + if isconstant(block_variable.output): + R = block_variable.output._ad_function_space(adj_output.function_space().mesh()) diff_expr = ufl.algorithms.expand_derivatives( ufl.derivative( - expr, block_variable.saved_output, adj_input_func + expr, block_variable.saved_output, + type(block_variable.output)(R, val=1.0) ) ) - # Firedrake does not support assignment of conjugate functions - adj_output.interpolate(ufl.conj(diff_expr)) - adj_output = adj_output.riesz_representation(riesz_map="l2") else: - mesh = adj_output.function_space().mesh() diff_expr = ufl.algorithms.expand_derivatives( ufl.derivative( - expr, - block_variable.saved_output, - firedrake.Function(firedrake.FunctionSpace(mesh, "R", 0), val=1.0) + expr, block_variable.saved_output, adj_input_func ) ) - adj_output.assign(diff_expr) - return adj_output.dat.inner(adj_input_func.dat) - + # Firedrake does not support assignment of conjugate functions + adj_output.interpolate(ufl.conj(diff_expr)) if isconstant(block_variable.output): - R = block_variable.output._ad_function_space( - adj_output.function_space().mesh() + adj_output = type(block_variable.output)( + R, val=firedrake.assemble(ufl.Action(adj_output, adj_input_func)) ) - return self._adj_assign_constant(adj_output, R) - else: - return adj_output + return adj_output.riesz_representation(riesz_map="l2") def _adj_assign_constant(self, adj_output, constant_fs): r = firedrake.Function(constant_fs) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index eb3cbef12a..62bf73c42b 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -7,7 +7,6 @@ import firedrake from .checkpointing import disk_checkpointing, CheckpointFunction, \ CheckpointBase, checkpoint_init_data, DelegatedFunctionCheckpoint -from numbers import Number class FunctionMixin(FloatingType): @@ -224,27 +223,17 @@ def _ad_create_checkpoint(self): def _ad_convert_riesz(self, value, options=None): from firedrake import Function, Cofunction + V = options.get("function_space", self.function_space()) + if value == 0.: + # This is seen in adjoint-based derivative when the functional + # is independent of the control variable. + return Cofunction(V.dual()) + options = {} if options is None else options riesz_representation = options.get("riesz_representation", "L2") solver_options = options.get("solver_options", {}) - V = options.get("function_space", self.function_space()) - - if not isinstance(value, (Number, Cofunction, Function, Number)): - raise TypeError("Expected a Cofunction, Function or a float") - - if isinstance(value, Number): - if value == 0.: - # l2 Riesz map is directly applied when the value is a real number 0.. - # This is seen in adjoint-based derivative when the functional - # is independent of the control variable. - return Function(V) - elif self.ufl_element().family() == "Real": - # Apply the l2 Riesz map for the case where self is a function in Real space. - f = Function(V) - f.assign(value) - return f - else: - raise TypeError("Riesz map of a non-zero scalar is not supported for non-Real function spaces.") + if not isinstance(value, (Cofunction, Function)): + raise TypeError("Expected a Cofunction or a Function") if riesz_representation == "l2": return Function(V, val=value.dat) From ade48156cbffeb2f201c4517205f614511e0454f Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Mon, 3 Jun 2024 09:37:03 +0100 Subject: [PATCH 14/16] wip --- firedrake/adjoint_utils/blocks/function.py | 36 ++++++++++------------ firedrake/adjoint_utils/function.py | 8 +++-- firedrake/ml/pytorch/fem_operator.py | 2 +- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index b893bec655..fc5be8486a 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -57,39 +57,37 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, # Catch the case where adj_inputs[0] is just a float return adj_inputs[0] elif isconstant(block_variable.output): - R = block_variable.output._ad_function_space( - prepared.function_space().mesh() + adj_output = self._adj_assign_constant( + prepared, block_variable.output.function_space() ) - adj_output = self._adj_assign_constant(prepared, R) else: adj_output = firedrake.Function( - block_variable.output.function_space()) + block_variable.output.function_space() + ) adj_output.assign(prepared) return adj_output.riesz_representation(riesz_map="l2") else: # Linear combination expr, adj_input_func = prepared - adj_output = firedrake.Function(adj_input_func.function_space()) if isconstant(block_variable.output): - R = block_variable.output._ad_function_space(adj_output.function_space().mesh()) + R = block_variable.output._ad_function_space( + adj_input_func.function_space().mesh() + ) diff_expr = ufl.algorithms.expand_derivatives( - ufl.derivative( - expr, block_variable.saved_output, - type(block_variable.output)(R, val=1.0) - ) + ufl.derivative(expr, block_variable.saved_output, + firedrake.Function(R, val=1.0)) + ) + diff_expr_assembled = firedrake.Function(adj_input_func.function_space()) + diff_expr_assembled.interpolate(ufl.conj(diff_expr)) + adj_output = firedrake.Function( + R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func)) ) else: + adj_output = firedrake.Function(adj_input_func.function_space()) diff_expr = ufl.algorithms.expand_derivatives( - ufl.derivative( - expr, block_variable.saved_output, adj_input_func - ) - ) - # Firedrake does not support assignment of conjugate functions - adj_output.interpolate(ufl.conj(diff_expr)) - if isconstant(block_variable.output): - adj_output = type(block_variable.output)( - R, val=firedrake.assemble(ufl.Action(adj_output, adj_input_func)) + ufl.derivative(expr, block_variable.saved_output, adj_input_func) ) + adj_output.interpolate(ufl.conj(diff_expr)) return adj_output.riesz_representation(riesz_map="l2") def _adj_assign_constant(self, adj_output, constant_fs): diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 62bf73c42b..70bceb8f91 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -225,8 +225,10 @@ def _ad_convert_riesz(self, value, options=None): V = options.get("function_space", self.function_space()) if value == 0.: - # This is seen in adjoint-based derivative when the functional - # is independent of the control variable. + # In adjoint-based differentiation, value == 0. arises only when + # the functional is independent on the control variable. + # In this case, we do not apply the Riesz map and return a zero + # Cofunction. return Cofunction(V.dual()) options = {} if options is None else options @@ -238,7 +240,7 @@ def _ad_convert_riesz(self, value, options=None): if riesz_representation == "l2": return Function(V, val=value.dat) - if riesz_representation != "l2" and not isinstance(value, Cofunction): + if not isinstance(value, Cofunction): raise TypeError("Expected a Cofunction") elif riesz_representation in ("L2", "H1"): diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py index a489f169cb..2ddaf3b8e5 100644 --- a/firedrake/ml/pytorch/fem_operator.py +++ b/firedrake/ml/pytorch/fem_operator.py @@ -83,7 +83,7 @@ def backward(ctx, grad_output): adj_input = float(adj_input) # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": 'l2'}) + adj_output = F.derivative(adj_input=adj_input) # Tuplify adjoint output adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output From 68ce4f4432ff42d85607c4b55e55e9208e11eb13 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Mon, 3 Jun 2024 10:20:06 +0100 Subject: [PATCH 15/16] adj_output in primal space from L2 Riesz maps --- firedrake/ml/pytorch/fem_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py index 2ddaf3b8e5..0021172495 100644 --- a/firedrake/ml/pytorch/fem_operator.py +++ b/firedrake/ml/pytorch/fem_operator.py @@ -83,7 +83,7 @@ def backward(ctx, grad_output): adj_input = float(adj_input) # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - adj_output = F.derivative(adj_input=adj_input) + adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "L2"}) # Tuplify adjoint output adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output From 2aa0032db2d3b41c92e1697b4c51549fc88b9f30 Mon Sep 17 00:00:00 2001 From: Ig-dolci Date: Wed, 5 Jun 2024 12:20:45 +0100 Subject: [PATCH 16/16] l2 for adj_input; minor changes --- firedrake/adjoint_utils/function.py | 8 ++++---- firedrake/ml/pytorch/fem_operator.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index 70bceb8f91..b490f34128 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -240,10 +240,10 @@ def _ad_convert_riesz(self, value, options=None): if riesz_representation == "l2": return Function(V, val=value.dat) - if not isinstance(value, Cofunction): - raise TypeError("Expected a Cofunction") - elif riesz_representation in ("L2", "H1"): + if not isinstance(value, Cofunction): + raise TypeError("Expected a Cofunction") + ret = Function(V) a = self._define_riesz_map_form(riesz_representation, V) firedrake.solve(a == value, ret, **solver_options) @@ -253,7 +253,7 @@ def _ad_convert_riesz(self, value, options=None): return riesz_representation(value) else: - raise NotImplementedError( + raise ValueError( "Unknown Riesz representation %s" % riesz_representation) def _define_riesz_map_form(self, riesz_representation, V): diff --git a/firedrake/ml/pytorch/fem_operator.py b/firedrake/ml/pytorch/fem_operator.py index 0021172495..e58b858493 100644 --- a/firedrake/ml/pytorch/fem_operator.py +++ b/firedrake/ml/pytorch/fem_operator.py @@ -83,7 +83,7 @@ def backward(ctx, grad_output): adj_input = float(adj_input) # Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional - adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "L2"}) + adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "l2"}) # Tuplify adjoint output adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output