Skip to content

Commit

Permalink
Merge branch 'master' into JDBetteridge/merge_pyop2_tsfc
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Nov 9, 2024
2 parents 94aa126 + cdf04f6 commit bb2c7b7
Show file tree
Hide file tree
Showing 15 changed files with 356 additions and 92 deletions.
26 changes: 18 additions & 8 deletions demos/full_waveform_inversion/full_waveform_inversion.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,29 @@ The source number is defined with the ``Ensemble.ensemble_comm`` rank::
source_number = my_ensemble.ensemble_comm.rank

In this example, we consider a two-dimensional square domain with a side length of 1.0 km. The mesh is
built over the ``my_ensemble.comm`` (spatial) communicator::
Lx, Lz = 1.0, 1.0
mesh = UnitSquareMesh(80, 80, comm=my_ensemble.comm)
built over the ``my_ensemble.comm`` (spatial) communicator.

::

import os
if os.getenv("FIREDRAKE_CI_TESTS") == "1":
# Setup for a faster test execution.
dt = 0.03 # time step in seconds
final_time = 0.6 # final time in seconds
nx, ny = 15, 15
else:
dt = 0.002 # time step in seconds
final_time = 1.0 # final time in seconds
nx, ny = 80, 80

The basic input for the FWI problem are defined as follows::
mesh = UnitSquareMesh(nx, ny, comm=my_ensemble.comm)

The frequency of the Ricker wavelet, the source and receiver locations are defined as follows::

import numpy as np
frequency_peak = 7.0 # The dominant frequency of the Ricker wavelet in Hz.
source_locations = np.linspace((0.3, 0.1), (0.7, 0.1), num_sources)
receiver_locations = np.linspace((0.2, 0.9), (0.8, 0.9), 20)
dt = 0.002 # time step in seconds
final_time = 1.0 # final time in seconds
frequency_peak = 7.0 # The dominant frequency of the Ricker wavelet in Hz.

Sources and receivers locations are illustrated in the following figure:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/advanced_tut.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ element systems.
A pressure-convection-diffusion preconditioner for the Navier-Stokes equations.</demos/navier_stokes.py>
Rayleigh-Benard convection.<demos/rayleigh-benard.py>
Netgen support.<demos/netgen_mesh.py>
Full-waveform inversion: Full-waveform inversion: spatial and wave sources parallelism.<demos/full_waveform_inversion.py>
Full-waveform inversion: spatial and wave sources parallelism.<demos/full_waveform_inversion.py>
138 changes: 92 additions & 46 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import ufl
from ufl import replace
from ufl.formatting.ufl2unicode import ufl2unicode
from enum import Enum

from pyadjoint import Block, stop_annotating
from pyadjoint import Block, stop_annotating, get_working_tape
from pyadjoint.enlisting import Enlist
import firedrake
from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint
Expand All @@ -24,6 +25,12 @@ def extract_subfunction(u, V):
return u


class Solver(Enum):
"""Enum for solver types."""
FORWARD = 0
ADJOINT = 1


class GenericSolveBlock(Block):
pop_kwargs_keys = ["adj_cb", "adj_bdy_cb", "adj2_cb", "adj2_bdy_cb",
"forward_args", "forward_kwargs", "adj_args",
Expand Down Expand Up @@ -206,15 +213,17 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):

adj_sol_bdy = None
if compute_bdy:
adj_sol_bdy = firedrake.Function(
self.function_space.dual(),
dJdu_copy.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)
).dat
)
adj_sol_bdy = self._compute_adj_bdy(
adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy)

return adj_sol, adj_sol_bdy

def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
adj_sol_bdy = firedrake.Function(
self.function_space.dual(), dJdu.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)).dat)
return adj_sol_bdy

def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
prepared=None):
if not self.linear and self.func == block_variable.output:
Expand Down Expand Up @@ -604,12 +613,11 @@ def _init_solver_parameters(self, args, kwargs):


class NonlinearVariationalSolveBlock(GenericSolveBlock):
def __init__(self, equation, func, bcs, adj_F, adj_cache, problem_J,
def __init__(self, equation, func, bcs, adj_cache, problem_J,
solver_params, solver_kwargs, **kwargs):
lhs = equation.lhs
rhs = equation.rhs

self.adj_F = adj_F
self._adj_cache = adj_cache
self._dFdm_cache = adj_cache.setdefault("dFdm_cache", {})
self.problem_J = problem_J
Expand All @@ -626,15 +634,62 @@ def _init_solver_parameters(self, args, kwargs):
super()._init_solver_parameters(args, kwargs)
solve_init_params(self, args, kwargs, varform=True)

def recompute_component(self, inputs, block_variable, idx, prepared):
tape = get_working_tape()
if self._ad_solvers["recompute_count"] == tape.recompute_count - 1:
# Update how many times the block has been recomputed.
self._ad_solvers["recompute_count"] = tape.recompute_count
if self._ad_solvers["forward_nlvs"]._problem._constant_jacobian:
self._ad_solvers["forward_nlvs"].invalidate_jacobian()
self._ad_solvers["update_adjoint"] = True
return super().recompute_component(inputs, block_variable, idx, prepared)

def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
self._ad_nlvs_replace_forms()
self._ad_nlvs.parameters.update(self.solver_params)
self._ad_nlvs.solve()
func.assign(self._ad_nlvs._problem.u)
self._ad_solver_replace_forms()
self._ad_solvers["forward_nlvs"].parameters.update(self.solver_params)
self._ad_solvers["forward_nlvs"].solve()
func.assign(self._ad_solvers["forward_nlvs"]._problem.u)
return func

def _ad_assign_map(self, form):
count_map = self._ad_nlvs._problem._ad_count_map
def _adjoint_solve(self, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
for bc in bcs:
bc.apply(dJdu)

if (
self._ad_solvers["forward_nlvs"]._problem._constant_jacobian
and self._ad_solvers["update_adjoint"]
):
# Update left hand side of the adjoint equation.
self._ad_solver_replace_forms(Solver.ADJOINT)
self._ad_solvers["adjoint_lvs"].invalidate_jacobian()
self._ad_solvers["update_adjoint"] = False
elif not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian:
# Update left hand side of the adjoint equation.
self._ad_solver_replace_forms(Solver.ADJOINT)

# Update the right hand side of the adjoint equation.
# problem.F._component[1] is the right hand side of the adjoint.
self._ad_solvers["adjoint_lvs"]._problem.F._components[1].assign(dJdu)

# Solve the adjoint linear variational solver.
self._ad_solvers["adjoint_lvs"].solve()
u_sol = self._ad_solvers["adjoint_lvs"]._problem.u

adj_sol_bdy = None
if compute_bdy:
jac_adj = self._ad_solvers["adjoint_lvs"]._problem.J
adj_sol_bdy = self._compute_adj_bdy(
u_sol, adj_sol_bdy, jac_adj, dJdu_copy)
return u_sol, adj_sol_bdy

def _ad_assign_map(self, form, solver):
if solver == Solver.FORWARD:
count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map
else:
count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map
assign_map = {}
form_ad_count_map = dict((count_map[coeff], coeff)
for coeff in form.coefficients())
Expand All @@ -647,55 +702,46 @@ def _ad_assign_map(self, form):
if coeff_count in form_ad_count_map:
assign_map[form_ad_count_map[coeff_count]] = \
block_variable.saved_output

if (
solver == Solver.ADJOINT
and not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian
):
block_variable = self.get_outputs()[0]
coeff_count = block_variable.output.count()
if coeff_count in form_ad_count_map:
assign_map[form_ad_count_map[coeff_count]] = \
block_variable.saved_output
return assign_map

def _ad_assign_coefficients(self, form):
assign_map = self._ad_assign_map(form)
def _ad_assign_coefficients(self, form, solver):
assign_map = self._ad_assign_map(form, solver)
for coeff, value in assign_map.items():
coeff.assign(value)

def _ad_nlvs_replace_forms(self):
problem = self._ad_nlvs._problem
self._ad_assign_coefficients(problem.F)
self._ad_assign_coefficients(problem.J)

def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):
if "dFdu_adj" in self._adj_cache:
dFdu = self._adj_cache["dFdu_adj"]
def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
if solver == Solver.FORWARD:
problem = self._ad_solvers["forward_nlvs"]._problem
self._ad_assign_coefficients(problem.F, solver)
self._ad_assign_coefficients(problem.J, solver)
else:
dFdu = super()._assemble_dFdu_adj(dFdu_adj_form, **kwargs)
if self._ad_nlvs._problem._constant_jacobian:
self._adj_cache["dFdu_adj"] = dFdu
return dFdu
self._ad_assign_coefficients(
self._ad_solvers["adjoint_lvs"]._problem.J, solver)

def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
dJdu = adj_inputs[0]

F_form = self._create_F_form()

dFdu_form = self.adj_F
dJdu = dJdu.copy()

# Replace the form coefficients with checkpointed values.
replace_map = self._replace_map(dFdu_form)
replace_map[self.func] = self.get_outputs()[0].saved_output
dFdu_form = replace(dFdu_form, replace_map)

compute_bdy = self._should_compute_boundary_adjoint(
relevant_dependencies
)
adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq(
dFdu_form, dJdu, compute_bdy
)
adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy)
self.adj_state = adj_sol
if self.adj_cb is not None:
self.adj_cb(adj_sol)
if self.adj_bdy_cb is not None and compute_bdy:
self.adj_bdy_cb(adj_sol_bdy)

r = {}
r["form"] = F_form
r["adj_sol"] = adj_sol
r["form"] = self._create_F_form()
r["adj_sol"] = self.adj_state
r["adj_sol_bdy"] = adj_sol_bdy
return r

Expand Down
93 changes: 69 additions & 24 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def wrapper(self, problem, *args, **kwargs):
self._ad_problem = problem
self._ad_args = args
self._ad_kwargs = kwargs
self._ad_nlvs = None
self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None,
"recompute_count": 0}
self._ad_adj_cache = {}

return wrapper
Expand All @@ -58,7 +59,7 @@ def wrapper(self, **kwargs):
Firedrake solve call. This is useful in cases where the solve is known to be irrelevant or diagnostic
for the purposes of the adjoint computation (such as projecting fields to other function spaces
for the purposes of visualisation)."""

from firedrake import LinearVariationalSolver
annotate = annotate_tape(kwargs)
if annotate:
tape = get_working_tape()
Expand All @@ -69,20 +70,31 @@ def wrapper(self, **kwargs):
block = NonlinearVariationalSolveBlock(problem._ad_F == 0,
problem._ad_u,
problem._ad_bcs,
problem._ad_adj_F,
adj_cache=self._ad_adj_cache,
problem_J=problem._ad_J,
solver_params=self.parameters,
solver_kwargs=self._ad_kwargs,
ad_block_tag=self.ad_block_tag,
**sb_kwargs)
if not self._ad_nlvs:
self._ad_nlvs = type(self)(

# Forward variational solver.
if not self._ad_solvers["forward_nlvs"]:
self._ad_solvers["forward_nlvs"] = type(self)(
self._ad_problem_clone(self._ad_problem, block.get_dependencies()),
**self._ad_kwargs
)

block._ad_nlvs = self._ad_nlvs
# Adjoint variational solver.
if not self._ad_solvers["adjoint_lvs"]:
with stop_annotating():
self._ad_solvers["adjoint_lvs"] = LinearVariationalSolver(
self._ad_adj_lvs_problem(block, problem._ad_adj_F),
*block.adj_args, **block.adj_kwargs)
if self._ad_problem._constant_jacobian:
self._ad_solvers["update_adjoint"] = False

block._ad_solvers = self._ad_solvers

tape.add_block(block)

with stop_annotating():
Expand All @@ -103,22 +115,62 @@ def _ad_problem_clone(self, problem, dependencies):
affect the user-defined self._ad_problem.F, self._ad_problem.J and self._ad_problem.u
expressions, we'll instead create clones of them.
"""
from firedrake import Function, NonlinearVariationalProblem
from firedrake import NonlinearVariationalProblem
_ad_count_map, J_replace_map, F_replace_map = self._build_count_map(
problem.J, dependencies, F=problem.F)
nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map),
F_replace_map[problem.u_restrict],
bcs=problem.bcs,
J=replace(problem.J, J_replace_map))
nlvp.is_linear = problem.is_linear
nlvp._constant_jacobian = problem._constant_jacobian
nlvp._ad_count_map_update(_ad_count_map)
return nlvp

@no_annotations
def _ad_adj_lvs_problem(self, block, adj_F):
"""Create the adjoint variational problem."""
from firedrake import Function, Cofunction, LinearVariationalProblem
# Homogeneous boundary conditions for the adjoint problem
# when Dirichlet boundary conditions are applied.
bcs = block._homogenize_bcs()
adj_sol = Function(block.function_space)
right_hand_side = Cofunction(block.function_space.dual())
tmp_problem = LinearVariationalProblem(
adj_F, right_hand_side, adj_sol, bcs=bcs,
constant_jacobian=self._ad_problem._constant_jacobian)
# The `block.adj_F` coefficients hold the output references.
# We do not want to modify the user-defined values. Hence, the adjoint
# linear variational problem is created with a deep copy of the
# `block.adj_F` coefficients.
_ad_count_map, J_replace_map, _ = self._build_count_map(
adj_F, block._dependencies)
lvp = LinearVariationalProblem(
replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
bcs=tmp_problem.bcs,
constant_jacobian=self._ad_problem._constant_jacobian)
lvp._ad_count_map_update(_ad_count_map)
return lvp

def _build_count_map(self, J, dependencies, F=None):
from firedrake import Function

F_replace_map = {}
J_replace_map = {}

F_coefficients = problem.F.coefficients()
J_coefficients = problem.J.coefficients()
if F:
F_coefficients = F.coefficients()
J_coefficients = J.coefficients()

_ad_count_map = {}
for block_variable in dependencies:
coeff = block_variable.output
if coeff in F_coefficients and coeff not in F_replace_map:
if isinstance(coeff, Function) and coeff.ufl_element().family() == "Real":
F_replace_map[coeff] = copy.deepcopy(coeff)
else:
F_replace_map[coeff] = coeff.copy(deepcopy=True)
_ad_count_map[F_replace_map[coeff]] = coeff.count()
if F:
if coeff in F_coefficients and coeff not in F_replace_map:
if isinstance(coeff, Function) and coeff.ufl_element().family() == "Real":
F_replace_map[coeff] = copy.deepcopy(coeff)
else:
F_replace_map[coeff] = coeff.copy(deepcopy=True)
_ad_count_map[F_replace_map[coeff]] = coeff.count()

if coeff in J_coefficients and coeff not in J_replace_map:
if coeff in F_replace_map:
Expand All @@ -128,11 +180,4 @@ def _ad_problem_clone(self, problem, dependencies):
else:
J_replace_map[coeff] = coeff.copy()
_ad_count_map[J_replace_map[coeff]] = coeff.count()

nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map),
F_replace_map[problem.u_restrict],
bcs=problem.bcs,
J=replace(problem.J, J_replace_map))
nlvp._constant_jacobian = problem._constant_jacobian
nlvp._ad_count_map_update(_ad_count_map)
return nlvp
return _ad_count_map, J_replace_map, F_replace_map
Loading

0 comments on commit bb2c7b7

Please sign in to comment.