Skip to content

Commit

Permalink
Fixing errors related to default L2 grad in adjoints (#3579)
Browse files Browse the repository at this point in the history
Also add a test to verify the Riesz maps for adjoint gradients
---------

Co-authored-by: Alberto Paganini <[email protected]>
  • Loading branch information
Ig-dolci and APaganini committed Jul 6, 2024
1 parent c5ddd06 commit 428e49f
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 59 deletions.
50 changes: 20 additions & 30 deletions firedrake/adjoint_utils/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,48 +57,38 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# Catch the case where adj_inputs[0] is just a float
return adj_inputs[0]
elif isconstant(block_variable.output):
R = block_variable.output._ad_function_space(
prepared.function_space().mesh()
adj_output = self._adj_assign_constant(
prepared, block_variable.output.function_space()
)
return self._adj_assign_constant(prepared, R)
else:
adj_output = firedrake.Function(
block_variable.output.function_space())
block_variable.output.function_space()
)
adj_output.assign(prepared)
adj_output = adj_output.riesz_representation(riesz_map="l2")
return adj_output
return adj_output.riesz_representation(riesz_map="l2")
else:
# Linear combination
expr, adj_input_func = prepared
adj_output = firedrake.Function(adj_input_func.function_space())
if not isconstant(block_variable.output):
diff_expr = ufl.algorithms.expand_derivatives(
ufl.derivative(
expr, block_variable.saved_output, adj_input_func
)
if isconstant(block_variable.output):
R = block_variable.output._ad_function_space(
adj_input_func.function_space().mesh()
)
# Firedrake does not support assignment of conjugate functions
adj_output.interpolate(ufl.conj(diff_expr))
adj_output = adj_output.riesz_representation(riesz_map="l2")
else:
mesh = adj_output.function_space().mesh()
diff_expr = ufl.algorithms.expand_derivatives(
ufl.derivative(
expr,
block_variable.saved_output,
firedrake.Constant(1., domain=mesh)
)
ufl.derivative(expr, block_variable.saved_output,
firedrake.Function(R, val=1.0))
)
adj_output.assign(diff_expr)
return adj_output.dat.inner(adj_input_func.dat)

if isconstant(block_variable.output):
R = block_variable.output._ad_function_space(
adj_output.function_space().mesh()
diff_expr_assembled = firedrake.Function(adj_input_func.function_space())
diff_expr_assembled.interpolate(ufl.conj(diff_expr))
adj_output = firedrake.Function(
R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func))
)
return self._adj_assign_constant(adj_output, R)
else:
return adj_output
adj_output = firedrake.Function(adj_input_func.function_space())
diff_expr = ufl.algorithms.expand_derivatives(
ufl.derivative(expr, block_variable.saved_output, adj_input_func)
)
adj_output.interpolate(ufl.conj(diff_expr))
return adj_output.riesz_representation(riesz_map="l2")

def _adj_assign_constant(self, adj_output, constant_fs):
r = firedrake.Function(constant_fs)
Expand Down
40 changes: 20 additions & 20 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import firedrake
from .checkpointing import disk_checkpointing, CheckpointFunction, \
CheckpointBase, checkpoint_init_data, DelegatedFunctionCheckpoint
from numbers import Number


class FunctionMixin(FloatingType):
Expand Down Expand Up @@ -224,26 +223,27 @@ def _ad_create_checkpoint(self):
def _ad_convert_riesz(self, value, options=None):
from firedrake import Function, Cofunction

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "l2")
solver_options = options.get("solver_options", {})
V = options.get("function_space", self.function_space())
if value == 0.:
# In adjoint-based differentiation, value == 0. arises only when
# the functional is independent of the control variable.
# In this case, we do not apply the Riesz map and return a zero
# Cofunction.
return Cofunction(V.dual())

if riesz_representation != "l2" and not isinstance(value, Cofunction):
raise TypeError("Expected a Cofunction")
elif not isinstance(value, (Number, Cofunction, Function)):
raise TypeError("Expected a Cofunction, Function or a float")
options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
solver_options = options.get("solver_options", {})
if not isinstance(value, (Cofunction, Function)):
raise TypeError("Expected a Cofunction or a Function")

if riesz_representation == "l2":
if isinstance(value, (Cofunction, Function)):
return Function(V, val=value.dat)
else:
f = Function(V)
with stop_annotating():
f.assign(value)
return f
return Function(V, val=value.dat)

elif riesz_representation in ("L2", "H1"):
if not isinstance(value, Cofunction):
raise TypeError("Expected a Cofunction")

ret = Function(V)
a = self._define_riesz_map_form(riesz_representation, V)
firedrake.solve(a == value, ret, **solver_options)
Expand All @@ -253,7 +253,7 @@ def _ad_convert_riesz(self, value, options=None):
return riesz_representation(value)

else:
raise NotImplementedError(
raise ValueError(
"Unknown Riesz representation %s" % riesz_representation)

def _define_riesz_map_form(self, riesz_representation, V):
Expand All @@ -276,9 +276,9 @@ def _define_riesz_map_form(self, riesz_representation, V):
@no_annotations
def _ad_convert_type(self, value, options=None):
# `_ad_convert_type` is not annotated, unlike `_ad_convert_riesz`
options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
if riesz_representation is None:
options = {} if options is None else options.copy()
options.setdefault("riesz_representation", "L2")
if options["riesz_representation"] is None:
return value
else:
return self._ad_convert_riesz(value, options=options)
Expand Down Expand Up @@ -333,7 +333,7 @@ def _ad_dot(self, other, options=None):
from firedrake import assemble

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "l2")
riesz_representation = options.get("riesz_representation", "L2")
if riesz_representation == "l2":
return self.dat.inner(other.dat)
elif riesz_representation == "L2":
Expand Down
2 changes: 1 addition & 1 deletion firedrake/ml/pytorch/fem_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def backward(ctx, grad_output):
adj_input = float(adj_input)

# Compute adjoint model of `F`: delegated to pyadjoint.ReducedFunctional
adj_output = F.derivative(adj_input=adj_input)
adj_output = F.derivative(adj_input=adj_input, options={"riesz_representation": "l2"})

# Tuplify adjoint output
adj_output = (adj_output,) if not isinstance(adj_output, collections.abc.Sequence) else adj_output
Expand Down
46 changes: 38 additions & 8 deletions tests/regression/test_adjoint_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,13 +897,43 @@ def test_cofunction_subfunctions_with_adjoint():


@pytest.mark.skipcomplex # Taping for complex-valued 0-forms not yet done
def test_none_riesz_representation_to_derivative():
def test_riesz_representation_for_adjoints():
# Check that each Riesz representation option for adjoint derivatives gives the expected result.
mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)
u = Function(space).interpolate(SpatialCoordinate(mesh)[0])
J = assemble((u ** 2) * dx)
rf = ReducedFunctional(J, Control(u))
assert isinstance(rf.derivative(), Function)
assert isinstance(rf.derivative(options={"riesz_representation": "H1"}), Function)
assert isinstance(rf.derivative(options={"riesz_representation": "L2"}), Function)
assert isinstance(rf.derivative(options={"riesz_representation": None}), Cofunction)
f = Function(space).interpolate(SpatialCoordinate(mesh)[0])
J = assemble((f ** 2) * dx)
rf = ReducedFunctional(J, Control(f))
with stop_annotating():
v = TestFunction(space)
u = TrialFunction(space)
dJdu_cofunction = assemble(derivative((f ** 2) * dx, f, v))

# Riesz representation with l2
dJdu_function_l2 = Function(space, val=dJdu_cofunction.dat)

# Riesz representation with H1
a = u * v * dx + inner(grad(u), grad(v)) * dx
dJdu_function_H1 = Function(space)
solve(a == dJdu_cofunction, dJdu_function_H1)

# Riesz representation with L2
a = u*v*dx
dJdu_function_L2 = Function(space)
solve(a == dJdu_cofunction, dJdu_function_L2)

dJdu_none = rf.derivative(options={"riesz_representation": None})
dJdu_l2 = rf.derivative(options={"riesz_representation": "l2"})
dJdu_H1 = rf.derivative(options={"riesz_representation": "H1"})
dJdu_L2 = rf.derivative(options={"riesz_representation": "L2"})
dJdu_default_L2 = rf.derivative()
assert (
isinstance(dJdu_none, Cofunction) and isinstance(dJdu_function_l2, Function)
and isinstance(dJdu_H1, Function) and isinstance(dJdu_default_L2, Function)
and isinstance(dJdu_L2, Function)
and np.allclose(dJdu_none.dat.data, dJdu_cofunction.dat.data)
and np.allclose(dJdu_l2.dat.data, dJdu_function_l2.dat.data)
and np.allclose(dJdu_H1.dat.data, dJdu_function_H1.dat.data)
and np.allclose(dJdu_default_L2.dat.data, dJdu_function_L2.dat.data)
and np.allclose(dJdu_L2.dat.data, dJdu_function_L2.dat.data)
)

0 comments on commit 428e49f

Please sign in to comment.