diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 0dd2ccafe8..db7d23642a 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -1,6 +1,9 @@ from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \ stop_annotating, get_working_tape, set_working_tape from pyadjoint.enlisting import Enlist +from firedrake.function import Function +from firedrake.ensemblefunction import EnsembleFunction, EnsembleCofunction + from functools import wraps, cached_property from typing import Callable, Optional from contextlib import contextmanager @@ -93,17 +96,14 @@ def _intermediate_options(final_options): class AllAtOnceReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. - Creates either the strong constraint or weak constraint system incrementally + Creates either the strong constraint or weak constraint system by logging observations through the initial forward model run. - Warning: Weak constraint 4DVar not implemented yet. - Parameters ---------- control - The :class:`EnsembleFunction` for the control x_{i} at the initial - condition and at the end of each observation stage. + The :class:`EnsembleFunction` for the control x_{i} at the initial condition and at the end of each observation stage. background_iprod The inner product to calculate the background error functional @@ -150,17 +150,22 @@ def __init__(self, control: Control, self.weak_constraint = weak_constraint self.initial_observations = observation_err is not None - with stop_annotating(): - if background: - self.background = background._ad_copy() - else: - self.background = control.control.subfunctions[0]._ad_copy() - _rename(self.background, "Background") - if self.weak_constraint: self._annotate_accumulation = _annotate_accumulation self._accumulation_started = False + if not isinstance(control.control, EnsembleFunction): + raise TypeError( + "Control for weak constraint 4DVar must be an EnsembleFunction" + ) + + with stop_annotating(): + if background: + self.background = background._ad_copy() + else: + self.background = control.control.subfunctions[0]._ad_copy() + _rename(self.background, "Background") + ensemble = control.ensemble self.ensemble = ensemble self.trank = ensemble.ensemble_comm.rank if ensemble else 0 @@ -225,11 +230,21 @@ def __init__(self, control: Control, self._annotate_accumulation = True self._accumulation_started = False + if not isinstance(control.control, Function): + raise TypeError( + "Control for strong constraint 4DVar must be a Function" + ) + + with stop_annotating(): + if background: + self.background = background._ad_copy() + else: + self.background = control.control._ad_copy() + _rename(self.background, "Background") + # initial conditions guess to be updated self.controls = Enlist(control) - self.tape = get_working_tape() if tape is None else tape - # Strong constraint functional to be converted to ReducedFunctional later # penalty for straying from prior @@ -249,10 +264,10 @@ def strong_reduced_functional(self): before all observations are recorded. """ if self.weak_constraint: - msg = "Strong constraint ReducedFunctional not instantiated for weak constraint 4DVar" + msg = "Strong constraint ReducedFunctional cannot be instantiated for weak constraint 4DVar" raise AttributeError(msg) self._strong_reduced_functional = ReducedFunctional( - self._total_functional, self.controls, tape=self.tape) + self._total_functional, self.controls.delist(), tape=self.tape) return self._strong_reduced_functional def __getattr__(self, attr): @@ -381,14 +396,12 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # create the derivative in the right primal or dual space from ufl.duals import is_primal, is_dual if is_primal(sderiv0[0]): - from firedrake.ensemblefunction import EnsembleFunction derivatives = EnsembleFunction( self.ensemble, self.control.local_function_spaces) else: if not is_dual(sderiv0[0]): raise ValueError( "Do not know how to handle stage derivative which is not primal or dual") - from firedrake.ensemblefunction import EnsembleCofunction derivatives = EnsembleCofunction( self.ensemble, [V.dual() for V in self.control.local_function_spaces]) @@ -505,7 +518,7 @@ def _accumulate_functional(self, val): self._accumulation_started = True @contextmanager - def recording_stages(self, sequential=True, **stage_kwargs): + def recording_stages(self, sequential=True, nstages=None, **stage_kwargs): if not sequential: raise ValueError("Recording stages concurrently not yet implemented") @@ -566,7 +579,9 @@ def recording_stages(self, sequential=True, **stage_kwargs): else: # strong constraint yield ObservationStageSequence( - self.controls, self, stage_kwargs, sequential=True) + self.controls, self, global_index=-1, + observation_index=0 if self.initial_observations else -1, + stage_kwargs=stage_kwargs, nstages=nstages) class ObservationStageSequence: @@ -575,29 +590,34 @@ def __init__(self, controls: Control, global_index: int, observation_index: int, stage_kwargs: dict = None, - sequential: bool = True): + nstages: Optional[int] = None): self.controls = controls - self.nstages = len(controls) - 1 self.aaorf = aaorf self.ctx = StageContext(**(stage_kwargs or {})) self.weak_constraint = aaorf.weak_constraint self.global_index = global_index self.observation_index = observation_index self.local_index = -1 + self.nstages = (len(controls) - 1 if self.weak_constraint + else nstages) def __iter__(self): return self def __next__(self): + # increment global indices + self.local_index += 1 + self.global_index += 1 + self.observation_index += 1 + + # stop after we've recorded all stages + if self.local_index >= self.nstages: + raise StopIteration + if self.weak_constraint: stages = self.aaorf.stages - # increment global indices - self.local_index += 1 - self.global_index += 1 - self.observation_index += 1 - # start of the next stage next_control = self.controls[self.local_index] @@ -607,10 +627,6 @@ def __next__(self): with stop_annotating(): next_control.control.assign(state) - # stop after we've recorded all stages - if self.local_index >= self.nstages: - raise StopIteration - stage = WeakObservationStage(next_control, local_index=self.local_index, global_index=self.global_index, @@ -619,21 +635,15 @@ def __next__(self): else: # strong constraint - # increment stage indices - self.local_index += 1 - self.global_index += 1 - self.observation_index += 1 - - # stop after we've recorded all stages - if self.index >= self.nstages: - raise StopIteration - self.index += 1 - # dummy control to "start" stage from - control = (self.aaorf.controls[0].control if self.index == 0 + control = (self.aaorf.controls[0].control if self.local_index == 0 else self._prev_stage.state) - stage = StrongObservationStage(control, self.aaorf) + stage = StrongObservationStage( + control, self.aaorf, + index=self.local_index, + observation_index=self.observation_index) + self._prev_stage = stage return stage, self.ctx @@ -658,9 +668,13 @@ class StrongObservationStage: """ def __init__(self, control: OverloadedType, - aaorf: AllAtOnceReducedFunctional): + aaorf: AllAtOnceReducedFunctional, + index: Optional[int] = None, + observation_index: Optional[int] = None): self.aaorf = aaorf self.control = control + self.index = index + self.observation_index = observation_index def set_observation(self, state: OverloadedType, observation_err: Callable[[OverloadedType], OverloadedType], @@ -691,6 +705,7 @@ def set_observation(self, state: OverloadedType, " constraint ReducedFunctional instantiated") self.aaorf._accumulate_functional( observation_iprod(observation_err(state))) + # save the user's state to hand back for beginning of next stage self.state = state diff --git a/firedrake/ensemble.py b/firedrake/ensemble.py index 5bb5519044..265ad8ba2b 100644 --- a/firedrake/ensemble.py +++ b/firedrake/ensemble.py @@ -291,12 +291,13 @@ def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, r def sequential(self, **kwargs): """ Context manager for executing code on each ensemble - member in turn. + member consecutively by `ensemble_comm.rank`. Any data in `kwargs` will be made available in the context and will be communicated forward after each ensemble member - exits. Firedrake Functions/Cofunctions will be send with the + exits. Firedrake Functions/Cofunctions will be sent with the corresponding Ensemble methods. + For example: with ensemble.sequential(index=0) as ctx: print(ensemble.ensemble_comm.rank, ctx.index) diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py index bf0530293f..894a25849c 100644 --- a/tests/firedrake/regression/test_4dvar_reduced_functional.py +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -2,8 +2,15 @@ import firedrake as fd from firedrake.__future__ import interpolate from firedrake.adjoint import ( - continue_annotation, pause_annotation, stop_annotating, set_working_tape, - Control, taylor_test, ReducedFunctional, AllAtOnceReducedFunctional) + continue_annotation, pause_annotation, stop_annotating, + set_working_tape, get_working_tape, Control, taylor_test, + ReducedFunctional, AllAtOnceReducedFunctional) + + +@pytest.fixture(autouse=True) +def clear_tape_teardown(): + yield + get_working_tape().clear_tape() def function_space(comm): @@ -159,8 +166,95 @@ def h(V, ensemble=None): ensemble=ensemble) -def fdvar_pyadjoint(V): - """Build a pyadjoint ReducedFunctional for the 4DVar system""" +def strong_fdvar_pyadjoint(V): + """Build a pyadjoint ReducedFunctional for the strong constraint 4DVar system""" + qn, qn1, stepper = timestepper(V) + + # prior data + bkg = background(V) + control = bkg.copy(deepcopy=True) + + # generate ground truths + obs_errors = observation_errors(V) + + continue_annotation() + set_working_tape() + + # background functional + J = prodB(control - bkg) + + # initial observation functional + J += prodR(obs_errors(0)(control)) + + qn.assign(control) + + # record observation stages + for i in range(1, len(observation_times)): + + for _ in range(observation_frequency): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + + # observation functional + J += prodR(obs_errors(i)(qn)) + + pause_annotation() + + Jhat = ReducedFunctional(J, Control(control)) + + return Jhat + + +def strong_fdvar_firedrake(V): + """Build an AllAtOnceReducedFunctional for the strong constraint 4DVar system""" + qn, qn1, stepper = timestepper(V) + + # prior data + bkg = background(V) + control = bkg.copy(deepcopy=True) + + # generate ground truths + obs_errors = observation_errors(V) + + continue_annotation() + set_working_tape() + + # create 4DVar reduced functional and record + # background and initial observation functionals + + Jhat = AllAtOnceReducedFunctional( + Control(control), + background_iprod=prodB, + observation_iprod=prodR, + observation_err=obs_errors(0), + weak_constraint=False) + + # record observation stages + with Jhat.recording_stages(nstages=len(observation_times)-1) as stages: + # loop over stages + for stage, ctx in stages: + # start forward model + qn.assign(stage.control) + + # propogate + for _ in range(observation_frequency): + qn1.assign(qn) + stepper.solve() + qn.assign(qn1) + + obs_index = stage.index + 1 + + # take observation + stage.set_observation(qn, obs_errors(obs_index), + observation_iprod=prodR) + + pause_annotation() + return Jhat + + +def weak_fdvar_pyadjoint(V): + """Build a pyadjoint ReducedFunctional for the weak constraint 4DVar system""" qn, qn1, stepper = timestepper(V) # One control for each observation time @@ -217,8 +311,8 @@ def fdvar_pyadjoint(V): return Jhat -def fdvar_firedrake(V, ensemble): - """Build an AllAtOnceReducedFunctional for the 4DVar system""" +def weak_fdvar_firedrake(V, ensemble): + """Build an AllAtOnceReducedFunctional for the weak constraint 4DVar system""" qn, qn1, stepper = timestepper(V) # One control for each observation time @@ -276,12 +370,34 @@ def fdvar_firedrake(V, ensemble): return Jhat -@pytest.mark.parallel(nprocs=[1, 2, 3, 4]) -def test_advection(): - main_test_advection() +def main_test_strong_4dvar_advection(): + V = function_space(fd.COMM_WORLD) + + # setup the reference pyadjoint rf + Jhat_pyadj = strong_fdvar_pyadjoint(V) + mp = m(V)[0] + hp = h(V)[0] + # make sure we've set up the reference rf correctly + assert taylor_test(Jhat_pyadj, mp, hp) > 1.99 -def main_test_advection(): + Jhat_aaorf = strong_fdvar_firedrake(V) + + ma = m(V)[0] + ha = h(V)[0] + + eps = 1e-12 + + # Does evaluating the functional match the reference rf? + assert abs(Jhat_pyadj(mp) - Jhat_aaorf(ma)) < eps + assert abs(Jhat_pyadj(hp) - Jhat_aaorf(ha)) < eps + + # If we match the functional, then passing the taylor test + # should mean that we match the derivative too. + assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 + + +def main_test_weak_4dvar_advection(): global_comm = fd.COMM_WORLD if global_comm.size in (1, 2): # time serial nspace = global_comm.size @@ -297,7 +413,7 @@ def main_test_advection(): # only setup the reference pyadjoint rf on the first ensemble member if erank == 0: - Jhat_pyadj = fdvar_pyadjoint(V) + Jhat_pyadj = weak_fdvar_pyadjoint(V) mp = m(V) hp = h(V) # make sure we've set up the reference rf correctly @@ -306,7 +422,7 @@ def main_test_advection(): Jpm = ensemble.ensemble_comm.bcast(Jhat_pyadj(mp) if erank == 0 else None) Jph = ensemble.ensemble_comm.bcast(Jhat_pyadj(hp) if erank == 0 else None) - Jhat_aaorf = fdvar_firedrake(V, ensemble) + Jhat_aaorf = weak_fdvar_firedrake(V, ensemble) ma = m(V, ensemble) ha = h(V, ensemble) @@ -317,9 +433,19 @@ def main_test_advection(): assert abs(Jph - Jhat_aaorf(ha)) < eps # If we match the functional, then passing the taylor test - # should mean we match the derivative too. + # should mean that we match the derivative too. assert taylor_test(Jhat_aaorf, ma, ha) > 1.99 +@pytest.mark.parallel(nprocs=[1, 2]) +def test_strong_4dvar_advection(): + main_test_strong_4dvar_advection() + + +@pytest.mark.parallel(nprocs=[1, 2, 3, 4]) +def test_weak_4dvar_advection(): + main_test_weak_4dvar_advection() + + if __name__ == '__main__': - main_test_advection() + main_test_strong_4dvar_advection()