Skip to content

Commit

Permalink
aaorf - restore strong constraint 4dvar
Browse files Browse the repository at this point in the history
  • Loading branch information
JHopeCollins committed Dec 23, 2024
1 parent 6c4c5b7 commit c22531f
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 60 deletions.
103 changes: 59 additions & 44 deletions firedrake/adjoint/all_at_once_reduced_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \
stop_annotating, get_working_tape, set_working_tape
from pyadjoint.enlisting import Enlist
from firedrake.function import Function
from firedrake.ensemblefunction import EnsembleFunction, EnsembleCofunction

from functools import wraps, cached_property
from typing import Callable, Optional
from contextlib import contextmanager
Expand Down Expand Up @@ -93,17 +96,14 @@ def _intermediate_options(final_options):
class AllAtOnceReducedFunctional(ReducedFunctional):
"""ReducedFunctional for 4DVar data assimilation.
Creates either the strong constraint or weak constraint system incrementally
Creates either the strong constraint or weak constraint system
by logging observations through the initial forward model run.
Warning: Weak constraint 4DVar not implemented yet.
Parameters
----------
control
The :class:`EnsembleFunction` for the control x_{i} at the initial
condition and at the end of each observation stage.
The :class:`EnsembleFunction` for the control x_{i} at the initial condition and at the end of each observation stage.
background_iprod
The inner product to calculate the background error functional
Expand Down Expand Up @@ -150,17 +150,22 @@ def __init__(self, control: Control,
self.weak_constraint = weak_constraint
self.initial_observations = observation_err is not None

with stop_annotating():
if background:
self.background = background._ad_copy()
else:
self.background = control.control.subfunctions[0]._ad_copy()
_rename(self.background, "Background")

if self.weak_constraint:
self._annotate_accumulation = _annotate_accumulation
self._accumulation_started = False

if not isinstance(control.control, EnsembleFunction):
raise TypeError(
"Control for weak constraint 4DVar must be an EnsembleFunction"
)

with stop_annotating():
if background:
self.background = background._ad_copy()
else:
self.background = control.control.subfunctions[0]._ad_copy()
_rename(self.background, "Background")

ensemble = control.ensemble
self.ensemble = ensemble
self.trank = ensemble.ensemble_comm.rank if ensemble else 0
Expand Down Expand Up @@ -225,11 +230,21 @@ def __init__(self, control: Control,
self._annotate_accumulation = True
self._accumulation_started = False

if not isinstance(control.control, Function):
raise TypeError(
"Control for strong constraint 4DVar must be a Function"
)

with stop_annotating():
if background:
self.background = background._ad_copy()
else:
self.background = control.control._ad_copy()
_rename(self.background, "Background")

# initial conditions guess to be updated
self.controls = Enlist(control)

self.tape = get_working_tape() if tape is None else tape

# Strong constraint functional to be converted to ReducedFunctional later

# penalty for straying from prior
Expand All @@ -249,10 +264,10 @@ def strong_reduced_functional(self):
before all observations are recorded.
"""
if self.weak_constraint:
msg = "Strong constraint ReducedFunctional not instantiated for weak constraint 4DVar"
msg = "Strong constraint ReducedFunctional cannot be instantiated for weak constraint 4DVar"
raise AttributeError(msg)
self._strong_reduced_functional = ReducedFunctional(
self._total_functional, self.controls, tape=self.tape)
self._total_functional, self.controls.delist(), tape=self.tape)
return self._strong_reduced_functional

def __getattr__(self, attr):
Expand Down Expand Up @@ -381,14 +396,12 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}):
# create the derivative in the right primal or dual space
from ufl.duals import is_primal, is_dual
if is_primal(sderiv0[0]):
from firedrake.ensemblefunction import EnsembleFunction
derivatives = EnsembleFunction(
self.ensemble, self.control.local_function_spaces)
else:
if not is_dual(sderiv0[0]):
raise ValueError(
"Do not know how to handle stage derivative which is not primal or dual")
from firedrake.ensemblefunction import EnsembleCofunction
derivatives = EnsembleCofunction(
self.ensemble, [V.dual() for V in self.control.local_function_spaces])

Expand Down Expand Up @@ -505,7 +518,7 @@ def _accumulate_functional(self, val):
self._accumulation_started = True

@contextmanager
def recording_stages(self, sequential=True, **stage_kwargs):
def recording_stages(self, sequential=True, nstages=None, **stage_kwargs):
if not sequential:
raise ValueError("Recording stages concurrently not yet implemented")

Expand Down Expand Up @@ -566,7 +579,9 @@ def recording_stages(self, sequential=True, **stage_kwargs):
else: # strong constraint

yield ObservationStageSequence(
self.controls, self, stage_kwargs, sequential=True)
self.controls, self, global_index=-1,
observation_index=0 if self.initial_observations else -1,
stage_kwargs=stage_kwargs, nstages=nstages)


class ObservationStageSequence:
Expand All @@ -575,29 +590,34 @@ def __init__(self, controls: Control,
global_index: int,
observation_index: int,
stage_kwargs: dict = None,
sequential: bool = True):
nstages: Optional[int] = None):
self.controls = controls
self.nstages = len(controls) - 1
self.aaorf = aaorf
self.ctx = StageContext(**(stage_kwargs or {}))
self.weak_constraint = aaorf.weak_constraint
self.global_index = global_index
self.observation_index = observation_index
self.local_index = -1
self.nstages = (len(controls) - 1 if self.weak_constraint
else nstages)

def __iter__(self):
return self

def __next__(self):

# increment global indices
self.local_index += 1
self.global_index += 1
self.observation_index += 1

# stop after we've recorded all stages
if self.local_index >= self.nstages:
raise StopIteration

if self.weak_constraint:
stages = self.aaorf.stages

# increment global indices
self.local_index += 1
self.global_index += 1
self.observation_index += 1

# start of the next stage
next_control = self.controls[self.local_index]

Expand All @@ -607,10 +627,6 @@ def __next__(self):
with stop_annotating():
next_control.control.assign(state)

# stop after we've recorded all stages
if self.local_index >= self.nstages:
raise StopIteration

stage = WeakObservationStage(next_control,
local_index=self.local_index,
global_index=self.global_index,
Expand All @@ -619,21 +635,15 @@ def __next__(self):

else: # strong constraint

# increment stage indices
self.local_index += 1
self.global_index += 1
self.observation_index += 1

# stop after we've recorded all stages
if self.index >= self.nstages:
raise StopIteration
self.index += 1

# dummy control to "start" stage from
control = (self.aaorf.controls[0].control if self.index == 0
control = (self.aaorf.controls[0].control if self.local_index == 0
else self._prev_stage.state)

stage = StrongObservationStage(control, self.aaorf)
stage = StrongObservationStage(
control, self.aaorf,
index=self.local_index,
observation_index=self.observation_index)

self._prev_stage = stage

return stage, self.ctx
Expand All @@ -658,9 +668,13 @@ class StrongObservationStage:
"""

def __init__(self, control: OverloadedType,
aaorf: AllAtOnceReducedFunctional):
aaorf: AllAtOnceReducedFunctional,
index: Optional[int] = None,
observation_index: Optional[int] = None):
self.aaorf = aaorf
self.control = control
self.index = index
self.observation_index = observation_index

def set_observation(self, state: OverloadedType,
observation_err: Callable[[OverloadedType], OverloadedType],
Expand Down Expand Up @@ -691,6 +705,7 @@ def set_observation(self, state: OverloadedType,
" constraint ReducedFunctional instantiated")
self.aaorf._accumulate_functional(
observation_iprod(observation_err(state)))
# save the user's state to hand back for beginning of next stage
self.state = state


Expand Down
5 changes: 3 additions & 2 deletions firedrake/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,13 @@ def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, r
def sequential(self, **kwargs):
"""
Context manager for executing code on each ensemble
member in turn.
member consecutively by `ensemble_comm.rank`.
Any data in `kwargs` will be made available in the context
and will be communicated forward after each ensemble member
exits. Firedrake Functions/Cofunctions will be send with the
exits. Firedrake Functions/Cofunctions will be sent with the
corresponding Ensemble methods.
For example:
with ensemble.sequential(index=0) as ctx:
print(ensemble.ensemble_comm.rank, ctx.index)
Expand Down
Loading

0 comments on commit c22531f

Please sign in to comment.