diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0eb616c24d..f6beb41c4d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,6 +84,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch pyadjoint JHopeCollins/mark_evaluate_tlm \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | diff --git a/firedrake/__init__.py b/firedrake/__init__.py index 3c89e61429..0162671b5a 100644 --- a/firedrake/__init__.py +++ b/firedrake/__init__.py @@ -103,6 +103,7 @@ from firedrake.vector import * from firedrake.version import __version__ as ver, __version_info__, check # noqa: F401 from firedrake.ensemble import * +from firedrake.ensemblefunction import * from firedrake.randomfunctiongen import * from firedrake.external_operators import * from firedrake.progress_bar import ProgressBar # noqa: F401 diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index c48b990420..9155a93c37 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -38,6 +38,7 @@ from firedrake.adjoint.ufl_constraints import UFLInequalityConstraint, \ UFLEqualityConstraint # noqa F401 from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional # noqa F401 +from firedrake.adjoint.fourdvar_reduced_functional import FourDVarReducedFunctional # noqa F401 import numpy_adjoint # noqa F401 import firedrake.ufl_expr import types diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py deleted file mode 100644 index 429000b2a2..0000000000 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ /dev/null @@ -1,592 +0,0 @@ -from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \ - stop_annotating, no_annotations, get_working_tape, set_working_tape -from pyadjoint.enlisting import Enlist -from functools import wraps, cached_property -from typing import Callable, Optional - -__all__ = ['AllAtOnceReducedFunctional'] - - -def sc_passthrough(func): - """ - Wraps standard ReducedFunctional methods to differentiate strong or - weak constraint implementations. - - If using strong constraint, makes sure strong_reduced_functional - is instantiated then passes args/kwargs through to the - corresponding strong_reduced_functional method. - - If using weak constraint, returns the AllAtOnceReducedFunctional - method definition. - """ - @wraps(func) - def wrapper(self, *args, **kwargs): - if self.weak_constraint: - return func(self, *args, **kwargs) - else: - sc_func = getattr(self.strong_reduced_functional, func.__name__) - return sc_func(*args, **kwargs) - return wrapper - - -def _rename(obj, name): - if hasattr(obj, "rename"): - obj.rename(name) - - -def _ad_sub(left, right): - result = right._ad_copy() - result._ad_imul(-1) - result._ad_iadd(left) - return result - - -class AllAtOnceReducedFunctional(ReducedFunctional): - """ReducedFunctional for 4DVar data assimilation. - - Creates either the strong constraint or weak constraint system incrementally - by logging observations through the initial forward model run. - - Warning: Weak constraint 4DVar not implemented yet. - - Parameters - ---------- - - control - The initial condition :math:`x_{0}`. Starting value is used as the - background (prior) data :math:`x_{b}`. - - background_iprod - The inner product to calculate the background error functional - from the background error :math:`x_{0} - x_{b}`. 
Can include the - error covariance matrix. - - observation_err - Given a state :math:`x`, returns the observations error - :math:`y_{0} - \\mathcal{H}_{0}(x)` where :math:`y_{0}` are the - observations at the initial time and :math:`\\mathcal{H}_{0}` is - the observation operator for the initial time. Optional. - - observation_iprod - The inner product to calculate the observation error functional - from the observation error :math:`y_{0} - \\mathcal{H}_{0}(x)`. - Can include the error covariance matrix. Must be provided if - observation_err is provided. - - weak_constraint - Whether to use the weak or strong constraint 4DVar formulation. - - tape - The tape to record on. - - See Also - -------- - :class:`pyadjoint.ReducedFunctional`. - """ - - def __init__(self, control: Control, - background_iprod: Callable[[OverloadedType], AdjFloat], - observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None, - observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None, - weak_constraint: bool = True, - tape: Optional[Tape] = None, - _annotate_accumulation: bool = False): - - self.tape = get_working_tape() if tape is None else tape - - self.weak_constraint = weak_constraint - self.initial_observations = observation_err is not None - - # We need a copy for the prior, but this shouldn't be part of the tape - with stop_annotating(): - self.background = control.copy_data() - - if self.weak_constraint: - self._annotate_accumulation = _annotate_accumulation - - # new tape for background error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - control_copy = control.copy_data() - _rename(control_copy, "Control_0_bkg_copy") - - # vector of x_0 - x_b - bkg_err_vec = _ad_sub(control_copy, self.background) - _rename(bkg_err_vec, "bkg_err_vec") - - # RF to recover x_0 - x_b - self.background_error = ReducedFunctional( - bkg_err_vec, Control(control_copy), tape=tape) - - # new tape for background error reduction - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - bkg_err_vec_copy = bkg_err_vec._ad_copy() - _rename(bkg_err_vec_copy, "bkg_err_vec_copy") - - # inner product |x_0 - x_b|_B - bkg_err = background_iprod(bkg_err_vec_copy) - - # RF to recover |x_0 - x_b|_B - self.background_rf = ReducedFunctional( - bkg_err, Control(bkg_err_vec_copy), tape=tape) - - self.controls = [control] # The solution at the beginning of each time-chunk - self.states = [] # The model propogation at the end of each time-chunk - self.forward_model_stages = [] # ReducedFunctional for each model propogation (returns state) - self.forward_model_errors = [] # Inner product for model errors (possibly including error covariance) - self.forward_model_rfs = [] # Inner product for model errors (possibly including error covariance) - self.observation_errors = [] # ReducedFunctional for each observation set (returns observation error) - self.observation_rfs = [] # Inner product for observation errors (possibly including error covariance) - - if self.initial_observations: - - # new tape for observation error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - control_copy = control.copy_data() - _rename(control_copy, "Control_0_obs_copy") - - # vector of H(x_0) - y_0 - obs_err_vec = observation_err(control_copy) - _rename(obs_err_vec, "obs_err_vec_0") - - # RF to recover H(x_0) - y_0 - 
self.observation_errors.append(ReducedFunctional( - obs_err_vec, Control(control_copy), tape=tape) - ) - - # new tape for observation error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - obs_err_vec_copy = obs_err_vec._ad_copy() - _rename(obs_err_vec_copy, "obs_err_vec_0_copy") - - # inner product |H(x_0) - y_0|_R - obs_err = observation_iprod(obs_err_vec_copy) - - # RF to recover |H(x_0) - y_0|_R - self.observation_rfs.append(ReducedFunctional( - obs_err, Control(obs_err_vec_copy), tape=tape) - ) - - # new tape for the next stage - set_working_tape() - self._stage_tape = get_working_tape() - - else: - self._annotate_accumulation = True - - # initial conditions guess to be updated - self.controls = Enlist(control) - - # Strong constraint functional to be converted to ReducedFunctional later - - # penalty for straying from prior - self._accumulate_functional( - background_iprod(control.control - self.background)) - - # penalty for not hitting observations at initial time - if self.initial_observations: - self._accumulate_functional( - observation_iprod(observation_err(control.control))) - - def set_observation(self, state: OverloadedType, - observation_err: Callable[[OverloadedType], OverloadedType], - observation_iprod: Callable[[OverloadedType], AdjFloat], - forward_model_iprod: Optional[Callable[[OverloadedType], AdjFloat]]): - """ - Record an observation at the time of `state`. - - Parameters - ---------- - - state - The state at the current observation time. - - observation_err - Given a state :math:`x`, returns the observations error - :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are - the observations at the current observation time and - :math:`\\mathcal{H}_{i}` is the observation operator for the - current observation time. - - observation_iprod - The inner product to calculate the observation error functional - from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`. - Can include the error covariance matrix. - - forward_model_iprod - The inner product to calculate the model error functional from - the model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`. Can - include the error covariance matrix. Ignored if using the strong - constraint formulation. - """ - if self.weak_constraint: - - stage_index = len(self.controls) - - # Cut the tape into seperate time-chunks. - # State is output from previous control i.e. forward model - # propogation over previous time-chunk. - - # get the tape used for this stage and make sure its the right one - prev_stage_tape = get_working_tape() - if prev_stage_tape is not self._stage_tape: - raise ValueError( - "Working tape at the end of the observation stage" - " differs from the tape at the stage beginning." - ) - - # # record forward propogation - with set_working_tape(prev_stage_tape.copy()) as tape: - prev_control = self.controls[-1] - self.forward_model_stages.append(ReducedFunctional( - state._ad_copy(), controls=prev_control, tape=tape) - ) - - # Beginning of next time-chunk is the control for this observation - # and the state at the end of the next time-chunk. 
- with stop_annotating(): - # smuggle initial guess at this time into the control without the tape seeing - next_control_state = state._ad_copy() - _rename(next_control_state, f"Control_{len(self.controls)}") - next_control = Control(next_control_state) - self.controls.append(next_control) - - # model error links time-chunks by depending on both the - # previous and current controls - - # new tape for model error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - state_copy = state._ad_copy() - _rename(state_copy, f"state_{stage_index}_copy") - next_control_copy = next_control.copy_data() - _rename(next_control_copy, f"Control_{stage_index}_model_copy") - - # vector of M_i - x_i - model_err_vec = _ad_sub(state_copy, next_control_copy) - _rename(model_err_vec, f"model_err_vec_{stage_index}") - - # RF to recover M_i - x_i - fmcontrols = [Control(state_copy), Control(next_control_copy)] - self.forward_model_errors.append(ReducedFunctional( - model_err_vec, fmcontrols, tape=tape) - ) - - # new tape for model error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - model_err_vec_copy = model_err_vec._ad_copy() - _rename(model_err_vec_copy, f"model_err_vec_{stage_index}_copy") - - # inner product |M_i - x_i|_Q - model_err = forward_model_iprod(model_err_vec_copy) - - # RF to recover |M_i - x_i|_Q - self.forward_model_rfs.append(ReducedFunctional( - model_err, Control(model_err_vec_copy), tape=tape) - ) - - # Observations after tape cut because this is now a control, not a state - - # new tape for observation error vector - with set_working_tape() as tape: - # start from a control independent of any other tapes - with stop_annotating(): - next_control_copy = next_control.copy_data() - _rename(next_control_copy, f"Control_{stage_index}_obs_copy") - - # vector of H(x_i) - y_i - obs_err_vec = observation_err(next_control_copy) - _rename(obs_err_vec, f"obs_err_vec_{stage_index}") - - # RF to recover H(x_i) - y_i - self.observation_errors.append(ReducedFunctional( - obs_err_vec, Control(next_control_copy), tape=tape) - ) - - # new tape for observation error reduction - with set_working_tape() as tape: - # start from a control independent of any othe tapes - with stop_annotating(): - obs_err_vec_copy = obs_err_vec._ad_copy() - _rename(obs_err_vec_copy, f"obs_err_vec_{stage_index}_copy") - - # inner product |H(x_i) - y_i|_R - obs_err = observation_iprod(obs_err_vec_copy) - - # RF to recover |H(x_i) - y_i|_R - self.observation_rfs.append(ReducedFunctional( - obs_err, Control(obs_err_vec_copy), tape=tape) - ) - - # new tape for the next stage - - set_working_tape() - self._stage_tape = get_working_tape() - - # Look we're starting this time-chunk from an "unrelated" value... really! - state.assign(next_control.control) - - else: - - if hasattr(self, "_strong_reduced_functional"): - msg = "Cannot add observations once strong constraint ReducedFunctional instantiated" - raise ValueError(msg) - self._accumulate_functional( - observation_iprod(observation_err(state))) - - @cached_property - def strong_reduced_functional(self): - """A :class:`pyadjoint.ReducedFunctional` for the strong constraint 4DVar system. - - Only instantiated if using the strong constraint formulation, and cannot be used - before all observations are recorded. 
- """ - if self.weak_constraint: - msg = "Strong constraint ReducedFunctional not instantiated for weak constraint 4DVar" - raise AttributeError(msg) - self._strong_reduced_functional = ReducedFunctional( - self._total_functional, self.controls, tape=self.tape) - return self._strong_reduced_functional - - def __getattr__(self, attr): - """ - If using strong constraint then grab attributes from self.strong_reduced_functional. - """ - if self.weak_constraint: - raise AttributeError(f"'{type(self)}' object has no attribute '{attr}'") - else: - return getattr(self.strong_reduced_functional, attr) - - @sc_passthrough - @no_annotations - def __call__(self, values: OverloadedType): - """Computes the reduced functional with supplied control value. - - Parameters - ---------- - - values - If you have multiple controls this should be a list of new values - for each control in the order you listed the controls to the constructor. - If you have a single control it can either be a list or a single object. - Each new value should have the same type as the corresponding control. - - Returns - ------- - pyadjoint.OverloadedType - The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. - - """ - # controls are updated by the sub ReducedFunctionals - # so we don't need to do it ourselves - - # Shift lists so indexing matches standard nomenclature: - # index 0 is initial condition. - # Model i propogates from i-1 to i. - # Observation i is at i. - - for c, v in zip(self.controls, values): - c.control.assign(v) - - model_stages = [None, *self.forward_model_stages] - model_errors = [None, *self.forward_model_errors] - model_rfs = [None, *self.forward_model_rfs] - - observation_errors = (self.observation_errors if self.initial_observations - else [None, *self.observation_errors]) - - observation_rfs = (self.observation_rfs if self.initial_observations - else [None, *self.observation_rfs]) - - # Initial condition functionals - bkg_err_vec = self.background_error(values[0]) - J = self.background_rf(bkg_err_vec) - - # observations at time 0 - if self.initial_observations: - obs_err_vec = observation_errors[0](values[0]) - J += observation_rfs[0](obs_err_vec) - - for i in range(1, len(observation_rfs)): - prev_control = values[i-1] - this_control = values[i] - - # observation error - do we match the 'real world'? - obs_err_vec = observation_errors[i](this_control) - J += observation_rfs[i](obs_err_vec) - - # Model error - does propogation from last control match this control? - Mi = model_stages[i](prev_control) - model_err_vec = model_errors[i]([Mi, this_control]) - J += model_rfs[i](model_err_vec) - - return J - - @sc_passthrough - @no_annotations - def derivative(self, adj_input: float = 1.0, options: dict = {}): - """Returns the derivative of the functional w.r.t. the control. - Using the adjoint method, the derivative of the functional with - respect to the control, around the last supplied value of the - control, is computed and returned. - - Parameters - ---------- - adj_input - The adjoint input. - - options - Additional options for the derivative computation. - - Returns - ------- - pyadjoint.OverloadedType - The derivative with respect to the control. - Should be an instance of the same type as the control. 
- """ - # create a list of overloaded types to put derivative into - derivatives = [] - - # chaining ReducedFunctionals means we need to pass Cofunctions not Functions - intermediate_options = { - 'riesz_representation': None, - **{k: v for k, v in options.items() - if (k != 'riesz_representation')} - } - - # Shift lists so indexing matches standard nomenclature: - # index 0 is initial condition. Model i propogates from i-1 to i. - model_stages = [None, *self.forward_model_stages] - model_errors = [None, *self.forward_model_errors] - model_rfs = [None, *self.forward_model_rfs] - - observation_errors = (self.observation_errors if self.initial_observations - else [None, *self.observation_errors]) - - observation_rfs = (self.observation_rfs if self.initial_observations - else [None, *self.observation_rfs]) - - # initial condition derivatives - bkg_deriv = self.background_rf.derivative(adj_input=adj_input, - options=intermediate_options) - derivatives.append(self.background_error.derivative(adj_input=bkg_deriv, - options=options)) - - # observations at time 0 - if self.initial_observations: - obs_deriv = observation_rfs[0].derivative(adj_input=adj_input, - options=intermediate_options) - derivatives[0] += observation_errors[0].derivative(adj_input=obs_deriv, - options=options) - - for i in range(1, len(observation_rfs)): - obs_deriv = observation_rfs[i].derivative(adj_input=adj_input, - options=intermediate_options) - derivatives.append(observation_errors[i].derivative(adj_input=obs_deriv, - options=options)) - - # derivative of model error will split: - # wrt x_i through error vector - # wrt x_i-1 through stage propogation - model_deriv = model_rfs[i].derivative(adj_input=adj_input, - options=intermediate_options) - model_err_derivs = model_errors[i].derivative(adj_input=model_deriv, - options=intermediate_options) - model_stage_deriv = model_stages[i].derivative(adj_input=model_err_derivs[0], - options=options) - - derivatives[i-1] += model_stage_deriv - derivatives[i] += model_err_derivs[1].riesz_representation() - - return derivatives - - @sc_passthrough - @no_annotations - def hessian(self, m_dot: OverloadedType, options: dict = {}): - """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot. - - Using the second-order adjoint method, the action of the Hessian of the - functional with respect to the control, around the last supplied value - of the control, is computed and returned. - - Parameters - ---------- - - m_dot - The direction in which to compute the action of the Hessian. - - options - A dictionary of options. To find a list of available options - have a look at the specific control type. - - Returns - ------- - pyadjoint.OverloadedType - The action of the Hessian in the direction m_dot. - Should be an instance of the same type as the control. - """ - # create a list of overloaded types to put hessian into - hessians = [] - - kwargs = {'options': options} - - # Shift lists so indexing matches standard nomenclature: - # index 0 is initial condition. Model i propogates from i-1 to i. 
-        model_rfs = [None, *self.forward_model_rfs]
-
-        observation_rfs = (self.observation_rfs if self.initial_observations
-                           else [None, *self.observation_rfs])
-
-        # initial condition hessians
-        hessians.append(
-            self.background_rf.hessian(m_dot[0], **kwargs))
-
-        if self.initial_observations:
-            hessians[0] += observation_rfs[0].hessian(m_dot[0], **kwargs)
-
-        for i in range(1, len(model_rfs)):
-            hessians.append(observation_rfs[i].hessian(m_dot[i], **kwargs))
-
-            mhess = model_rfs[i].hessian(m_dot[i-1:i+1], **kwargs)
-
-            hessians[i-1] += mhess[0]
-            hessians[i] += mhess[1]
-
-        return hessians
-
-    @no_annotations
-    def hessian_matrix(self):
-        # Other reduced functionals don't have this.
-        if not self.weak_constraint:
-            raise AttributeError("Strong constraint 4DVar does not form a Hessian matrix")
-        raise NotImplementedError
-
-    @sc_passthrough
-    @no_annotations
-    def optimize_tape(self):
-        for rf in (self.background_error,
-                   self.background_rf,
-                   *self.observation_errors,
-                   *self.observation_rfs,
-                   *self.forward_model_stages,
-                   *self.forward_model_errors,
-                   *self.forward_model_rfs):
-            rf.optimize_tape()
-
-    def _accumulate_functional(self, val):
-        if not self._annotate_accumulation:
-            return
-        if hasattr(self, '_total_functional'):
-            self._total_functional += val
-        else:
-            self._total_functional = val
diff --git a/firedrake/adjoint/composite_reduced_functional.py b/firedrake/adjoint/composite_reduced_functional.py
new file mode 100644
index 0000000000..d09dfebe9f
--- /dev/null
+++ b/firedrake/adjoint/composite_reduced_functional.py
@@ -0,0 +1,332 @@
+from pyadjoint import stop_annotating, get_working_tape, OverloadedType, Control, Tape, ReducedFunctional
+from pyadjoint.enlisting import Enlist
+from typing import Optional
+
+
+def intermediate_options(options: dict):
+    """
+    Options set for the intermediate stages of a chain of ReducedFunctionals.
+
+    Takes all elements of the options except riesz_representation, which
+    is set to None to prevent returning derivatives to the primal space.
+
+    Parameters
+    ----------
+    options
+        The dictionary of options provided by the user.
+
+    Returns
+    -------
+    dict
+        The options for ReducedFunctionals at intermediate stages.
+
+    """
+    return {
+        **{k: v for k, v in (options or {}).items()
+           if k != 'riesz_representation'},
+        'riesz_representation': None
+    }
+
+
+def compute_tlm(J: OverloadedType,
+                m: Control,
+                m_dot: OverloadedType,
+                options: Optional[dict] = None,
+                tape: Optional[Tape] = None):
+    """
+    Compute the tangent linear model of J in a direction m_dot at the current value of m.
+
+    Parameters
+    ----------
+
+    J
+        The objective functional.
+    m
+        The (list of) :class:`pyadjoint.Control` for the functional.
+    m_dot
+        The direction in which to compute the tangent linear model.
+        Must be a (list of) :class:`pyadjoint.OverloadedType`.
+    options
+        A dictionary of options. To find a list of available options
+        have a look at the specific control type.
+    tape
+        The tape to use. Default is the current tape.
+
+    Returns
+    -------
+    pyadjoint.OverloadedType
+        The tangent linear model with respect to the control in direction m_dot.
+        Should be an instance of the same type as the control.
+
+    """
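+    # Forward-mode differentiation: seed each control's tlm_value with the
+    # matching m_dot entry, propagate tangents forward through the tape, then
+    # read dJ/dm . m_dot off the functional's block variable.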
+ + """ + tape = tape or get_working_tape() + + # reset tlm values + tape.reset_tlm_values() + + m = Enlist(m) + m_dot = Enlist(m_dot) + + # set initial tlm values + for mi, mdi in zip(m, m_dot): + mi.tlm_value = mdi + + # evaluate tlm + with stop_annotating(): + with tape.marked_nodes(m): + tape.evaluate_tlm(markings=True) + + # return functional's tlm + return J._ad_convert_type(J.block_variable.tlm_value, + options=options or {}) + + +def compute_hessian(J: OverloadedType, + m: Control, + options: Optional[dict] = None, + tape: Optional[Tape] = None, + hessian_value: Optional[OverloadedType] = 0.): + """ + Compute the Hessian of J at the current value of m with the current tlm values on the tape. + + Parameters + ---------- + J + The objective functional. + m + The (list of) :class:`pyadjoint.Control` for the functional. + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + tape + The tape to use. Default is the current tape. + hessian_value + The initial hessian_value to start accumulating from. + + Returns + ------- + pyadjoint.OverloadedType + The second derivative with respect to the control in direction m_dot. + Should be an instance of the same type as the control. + + """ + tape = tape or get_working_tape() + + # reset hessian values + tape.reset_hessian_values() + + m = Enlist(m) + + # set initial hessian_value + J.block_variable.hessian_value = J._ad_convert_type( + hessian_value, options=intermediate_options(options)) + + # evaluate hessian + with stop_annotating(): + with tape.marked_nodes(m): + tape.evaluate_hessian(markings=True) + + # return controls' hessian values + return m.delist([v.get_hessian(options=options or {}) for v in m]) + + +def tlm(rf: ReducedFunctional, + m_dot: OverloadedType, + options: Optional[dict] = None): + """Returns the action of the tangent linear model of the functional w.r.t. the control on a vector m_dot. + + Parameters + ---------- + rf + The :class:`pyadjoint.ReducedFunctional` to evaluate the tlm of. + m_dot + The direction in which to compute the action of the tangent linear model. + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + + Returns + ------- + pyadjoint.OverloadedType + The action of the tangent linear model in the direction m_dot. + Should be an instance of the same type as the control. + + """ + return compute_tlm(rf.functional, rf.controls, m_dot, + tape=rf.tape, options=options) + + +def hessian(rf: ReducedFunctional, + options: Optional[dict] = None, + hessian_value: Optional[OverloadedType] = 0.): + """Returns the action of the Hessian of the functional w.r.t. the control. + + Using the second-order adjoint method, the action of the Hessian of the + functional with respect to the control, around the last supplied value + of the control and the last tlm values, is computed and returned. + + Parameters + ---------- + rf + The :class:`pyadjoint.ReducedFunctional` to evaluate the tlm of. + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + hessian_value + The initial hessian_value to start accumulating from. + + Returns + ------- + pyadjoint.OverloadedType + The action of the Hessian. Should be an instance of the same type as the control. 
+ + """ + return rf.controls.delist( + compute_hessian(rf.functional, rf.controls, + tape=rf.tape, options=options, + hessian_value=hessian_value)) + + +class CompositeReducedFunctional: + """Class representing the composition of two reduced functionals. + + For two reduced functionals J1: X->Y and J2: Y->Z, this is a convenience + class representing the composition J12: X->Z = J2(J1(x)) and providing + methods for the evaluation, derivative, tlm, and hessian action of J12. + + Parameters + ---------- + rf1 + The first :class:`pyadjoint.ReducedFunctional` in the composition. + rf2 + The second :class:`pyadjoint.ReducedFunctional` in the composition. + The control for rf2 must have the same type as the functional of rf1. + + """ + def __init__(self, rf1, rf2): + self.rf1 = rf1 + self.rf2 = rf2 + + def __call__(self, values: OverloadedType): + """Computes the reduced functional with supplied control value. + + Parameters + ---------- + + values + If you have multiple controls this should be a list of new values + for each control in the order you listed the controls to the constructor. + If you have a single control it can either be a list or a single object. + Each new value should have the same type as the corresponding control. + + Returns + ------- + pyadjoint.OverloadedType + The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. + + """ + return self.rf2(self.rf1(values)) + + def derivative(self, adj_input: Optional[float] = 1.0, options: Optional[dict] = None): + """Returns the derivative of the functional w.r.t. the control. + Using the adjoint method, the derivative of the functional with + respect to the control, around the last supplied value of the + control, is computed and returned. + + Parameters + ---------- + adj_input + The adjoint input. + + options + Additional options for the derivative computation. + + Returns + ------- + pyadjoint.OverloadedType + The derivative with respect to the control. + Should be an instance of the same type as the control. + + """ + deriv2 = self.rf2.derivative( + adj_input=adj_input, options=intermediate_options(options)) + deriv1 = self.rf1.derivative( + adj_input=deriv2, options=options or {}) + return deriv1 + + def tlm(self, m_dot: OverloadedType, options: Optional[dict] = None): + """Returns the action of the tangent linear model of the functional w.r.t. the control on a vector m_dot. + + Parameters + ---------- + + m_dot + The direction in which to compute the action of the Hessian. + + options + A dictionary of options. To find a list of available options + have a look at the specific control type. + + Returns + ------- + pyadjoint.OverloadedType + The action of the Hessian in the direction m_dot. + Should be an instance of the same type as the control. + + """ + tlm1 = self._eval_tlm( + self.rf1, m_dot, intermediate_options(options)), + tlm2 = self._eval_tlm( + self.rf2, tlm1, options) + return tlm2 + + def hessian(self, m_dot: OverloadedType, + options: Optional[dict] = None, + evaluate_tlm: Optional[bool] = True): + """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot. + + Using the second-order adjoint method, the action of the Hessian of the + functional with respect to the control, around the last supplied value + of the control, is computed and returned. + + Parameters + ---------- + + m_dot + The direction in which to compute the action of the Hessian. + + options + A dictionary of options. To find a list of available options + have a look at the specific control type. 
+ + evaluate_tlm + If True, the tlm values on the tape will be reset and evaluated before + the Hessian action is evaluated. If False, the existing tlm values on + the tape will be used. + + Returns + ------- + pyadjoint.OverloadedType + The action of the Hessian in the direction m_dot. + Should be an instance of the same type as the control. + + """ + if evaluate_tlm: + self.tlm(m_dot, options=intermediate_options(options)) + hess2 = self._eval_hessian( + self.rf2, 0., intermediate_options(options)) + hess1 = self._eval_hessian( + self.rf1, hess2, options or {}) + return hess1 + + def _eval_tlm(self, rf, m_dot, options): + if isinstance(rf, CompositeReducedFunctional): + return rf.tlm(m_dot, options=options) + else: + return tlm(rf, m_dot=m_dot, options=options) + + def _eval_hessian(self, rf, hessian_value, options): + if isinstance(rf, CompositeReducedFunctional): + return rf.hessian(None, options, evaluate_tlm=False) + else: + return hessian(rf, hessian_value=hessian_value, options=options) diff --git a/firedrake/adjoint/fourdvar_reduced_functional.py b/firedrake/adjoint/fourdvar_reduced_functional.py new file mode 100644 index 0000000000..82effbf514 --- /dev/null +++ b/firedrake/adjoint/fourdvar_reduced_functional.py @@ -0,0 +1,1108 @@ +from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \ + stop_annotating, get_working_tape, set_working_tape +from pyadjoint.enlisting import Enlist +from firedrake.function import Function +from firedrake.ensemblefunction import EnsembleFunction, EnsembleCofunction +from firedrake.adjoint.composite_reduced_functional import ( + CompositeReducedFunctional, tlm, hessian, intermediate_options) + +from functools import wraps, cached_property +from typing import Callable, Optional +from types import SimpleNamespace +from contextlib import contextmanager +from mpi4py import MPI + +__all__ = ['FourDVarReducedFunctional'] + + +# @set_working_tape() # ends up using old_tape = None because evaluates when imported - need separate decorator +def isolated_rf(operation, control, + functional_name=None, + control_name=None): + """ + Return a ReducedFunctional where the functional is `operation` applied + to a copy of `control`, and the tape contains only `operation`. + """ + with stop_annotating(): + controls = Enlist(control) + control_copies = [control._ad_copy() for control in controls] + + if control_name: + for control, name in zip(control_copies, Enlist(control_name)): + _rename(control, name) + + with set_working_tape(): + functional = operation(controls.delist(control_copies)) + + if functional_name: + _rename(functional, functional_name) + + control = controls.delist([Control(control_copy) + for control_copy in control_copies]) + + return ReducedFunctional( + functional, control) + + +def sc_passthrough(func): + """ + Wraps standard ReducedFunctional methods to differentiate strong or + weak constraint implementations. + + If using strong constraint, makes sure strong_reduced_functional + is instantiated then passes args/kwargs through to the + corresponding strong_reduced_functional method. + + If using weak constraint, returns the FourDVarReducedFunctional + method definition. 
+ """ + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.weak_constraint: + return func(self, *args, **kwargs) + else: + sc_func = getattr(self.strong_reduced_functional, func.__name__) + return sc_func(*args, **kwargs) + return wrapper + + +def _rename(obj, name): + if hasattr(obj, "rename"): + obj.rename(name) + + +def _ad_sub(left, right): + result = right._ad_copy() + result._ad_imul(-1) + result._ad_iadd(left) + return result + + +class FourDVarReducedFunctional(ReducedFunctional): + """ReducedFunctional for 4DVar data assimilation. + + Creates either the strong constraint or weak constraint system + by logging observations through the initial time propagator run. + + Parameters + ---------- + + control + The :class:`.EnsembleFunction` for the control x_{i} at the initial condition + and at the end of each observation stage. + + background_iprod + The inner product to calculate the background error functional + from the background error :math:`x_{0} - x_{b}`. Can include the + error covariance matrix. Only used on ensemble rank 0. + + background + The background (prior) data for the initial condition :math:`x_{b}`. + If not provided, the value of the first subfunction on the first ensemble + member of the control :class:`.EnsembleFunction` will be used. + + observation_err + Given a state :math:`x`, returns the observations error + :math:`y_{0} - \\mathcal{H}_{0}(x)` where :math:`y_{0}` are the + observations at the initial time and :math:`\\mathcal{H}_{0}` is + the observation operator for the initial time. Only used on + ensemble rank 0. Optional. + + observation_iprod + The inner product to calculate the observation error functional + from the observation error :math:`y_{0} - \\mathcal{H}_{0}(x)`. + Can include the error covariance matrix. Must be provided if + observation_err is provided. Only used on ensemble rank 0 + + weak_constraint + Whether to use the weak or strong constraint 4DVar formulation. + + See Also + -------- + :class:`pyadjoint.ReducedFunctional`. 
+ """ + + def __init__(self, control: Control, + background_iprod: Optional[Callable[[OverloadedType], AdjFloat]], + background: Optional[OverloadedType] = None, + observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None, + observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None, + weak_constraint: bool = True, + tape: Optional[Tape] = None, + _annotate_accumulation: bool = False): + + self.tape = get_working_tape() if tape is None else tape + + self.weak_constraint = weak_constraint + self.initial_observations = observation_err is not None + + if self.weak_constraint: + self._annotate_accumulation = _annotate_accumulation + self._accumulation_started = False + + if not isinstance(control.control, EnsembleFunction): + raise TypeError( + "Control for weak constraint 4DVar must be an EnsembleFunction" + ) + + with stop_annotating(): + if background: + self.background = background._ad_copy() + else: + self.background = control.control.subfunctions[0]._ad_copy() + _rename(self.background, "Background") + + ensemble = control.ensemble + self.ensemble = ensemble + self.trank = ensemble.ensemble_comm.rank if ensemble else 0 + self.nchunks = ensemble.ensemble_comm.size if ensemble else 1 + + # because we need to manually evaluate the different bits + # of the functional, we need an internal set of controls + # to use for the stage ReducedFunctionals + self._cbuf = control.copy_data() + _x = self._cbuf.subfunctions + self._x = _x + self._controls = tuple(Control(xi) for xi in _x) + + self.control = control + self.controls = [control] + + # first control on rank 0 is initial conditions, not end of observation stage + self.nlocal_stages = len(_x) - (1 if self.trank == 0 else 0) + + self.stages = [] # The record of each observation stage + + # first rank sets up functionals for background initial observations + if self.trank == 0: + + # RF to recalculate error vector (x_0 - x_b) + self.background_error = isolated_rf( + operation=lambda x0: _ad_sub(x0, self.background), + control=_x[0], + functional_name="bkg_err_vec", + control_name="Control_0_bkg_copy") + + # RF to recalculate inner product |x_0 - x_b|_B + self.background_norm = isolated_rf( + operation=background_iprod, + control=self.background_error.functional, + control_name="bkg_err_vec_copy") + + # compose background reduced functionals to evaluate both together + self.background_rf = CompositeReducedFunctional( + self.background_error, self.background_norm) + + if self.initial_observations: + + # RF to recalculate error vector (H(x_0) - y_0) + self.initial_observation_error = isolated_rf( + operation=observation_err, + control=_x[0], + functional_name="obs_err_vec_0", + control_name="Control_0_obs_copy") + + # RF to recalculate inner product |H(x_0) - y_0|_R + self.initial_observation_norm = isolated_rf( + operation=observation_iprod, + control=self.initial_observation_error.functional, + functional_name="obs_err_vec_0_copy") + + # compose initial observation reduced functionals to evaluate both together + self.initial_observation_rf = CompositeReducedFunctional( + self.initial_observation_error, self.initial_observation_norm) + + # create halo for previous state + if self.ensemble and self.trank != 0: + with stop_annotating(): + self.xprev = _x[0]._ad_copy() + self._control_prev = Control(self.xprev) + + # halo for the derivative from the next chunk + if self.ensemble and self.trank != self.nchunks - 1: + with stop_annotating(): + self.xnext = _x[0]._ad_copy() + + else: + self._annotate_accumulation = True 
+ self._accumulation_started = False + + if not isinstance(control.control, Function): + raise TypeError( + "Control for strong constraint 4DVar must be a Function" + ) + + with stop_annotating(): + if background: + self.background = background._ad_copy() + else: + self.background = control.control._ad_copy() + _rename(self.background, "Background") + + # initial conditions guess to be updated + self.controls = Enlist(control) + + # Strong constraint functional to be converted to ReducedFunctional later + + # penalty for straying from prior + self._accumulate_functional( + background_iprod(control.control - self.background)) + + # penalty for not hitting observations at initial time + if self.initial_observations: + self._accumulate_functional( + observation_iprod(observation_err(control.control))) + + @cached_property + def strong_reduced_functional(self): + """A :class:`pyadjoint.ReducedFunctional` for the strong constraint 4DVar system. + + Only instantiated if using the strong constraint formulation, and cannot be used + before all observations are recorded. + """ + if self.weak_constraint: + msg = "Strong constraint ReducedFunctional cannot be instantiated for weak constraint 4DVar" + raise AttributeError(msg) + self._strong_reduced_functional = ReducedFunctional( + self._total_functional, self.controls.delist(), tape=self.tape) + return self._strong_reduced_functional + + def __getattr__(self, attr): + """ + If using strong constraint then grab attributes from self.strong_reduced_functional. + """ + # hasattr calls getattr, so check self.__dir__ directly here to avoid recursion + if self.weak_constraint or "_strong_reduced_functional" not in dir(self): + raise AttributeError(f"'{type(self)}' object has no attribute '{attr}'") + return getattr(self.strong_reduced_functional, attr) + + @sc_passthrough + @stop_annotating() + def __call__(self, values: OverloadedType): + """Computes the reduced functional with supplied control value. + + Parameters + ---------- + + values + If you have multiple controls this should be a list of new values + for each control in the order you listed the controls to the constructor. + If you have a single control it can either be a list or a single object. + Each new value should have the same type as the corresponding control. + + Returns + ------- + pyadjoint.OverloadedType + The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`. + + """ + value = values[0] if isinstance(values, list) else values + + if not isinstance(value, type(self.control.control)): + raise ValueError(f"Value must be of type {type(self.control.control)} not type {type(value)}") + + self.control.update(value) + # put the new value into our internal set of controls to pass to each stage + self._cbuf.assign(value) + + trank = self.trank + + # first "control" for later ranks is the halo + if self.ensemble and trank != 0: + x = [self.xprev, *self._x] + else: + x = [*self._x] + + # post messages for control of time propagator on next chunk + if self.ensemble: + src = trank - 1 + dst = trank + 1 + + if trank != self.nchunks - 1: + self.ensemble.isend( + x[-1], dest=dst, tag=dst) + + if trank != 0: + recv_reqs = self.ensemble.irecv( + self.xprev, source=src, tag=trank) + + # Initial condition functionals + if trank == 0: + Jlocal = self.background_rf(x[0]) + + # observations at time 0 + if self.initial_observations: + Jlocal += self.initial_observation_rf(x[0]) + else: + Jlocal = 0. 
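+
+        # The local cost on this chunk is (sketch, for covariance-weighted
+        # quadratic error norms)
+        #   J_local = sum_i ( |y_i - H_i(x_i)|_R + |x_i - M_i(x_{i-1})|_Q )
+        # with rank 0 also adding |x_0 - x_b|_B (and |y_0 - H_0(x_0)|_R if
+        # initial observations were given); the global J is the allreduce below.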
+ + # evaluate all stages on chunk except first + for i in range(1, len(self.stages)): + Jlocal += self.stages[i](x[i:i+2]) + + # wait for halo swap to finish + if trank != 0: + MPI.Request.Waitall(recv_reqs) + + # evaluate first stage model on chunk now we have data + Jlocal += self.stages[0](x[0:2]) + + # sum all stages + if self.ensemble: + J = self.ensemble.ensemble_comm.allreduce(Jlocal) + else: + J = Jlocal + + return J + + @sc_passthrough + @stop_annotating() + def derivative(self, adj_input: float = 1.0, options: dict = {}): + """Returns the derivative of the functional w.r.t. the control. + Using the adjoint method, the derivative of the functional with + respect to the control, around the last supplied value of the + control, is computed and returned. + + Parameters + ---------- + adj_input + The adjoint input. + + options + Additional options for the derivative computation. + + Returns + ------- + pyadjoint.OverloadedType + The derivative with respect to the control. + Should be an instance of the same type as the control. + """ + trank = self.trank + + # chaining ReducedFunctionals means we need to pass Cofunctions not Functions + options = options or {} + + # evaluate first time propagator, which contributes to previous chunk + sderiv0 = self.stages[0].derivative( + adj_input=adj_input, options=options) + + # create the derivative in the right primal or dual space + from ufl.duals import is_primal, is_dual + if is_primal(sderiv0[0]): + derivatives = EnsembleFunction( + self.ensemble, self.control.local_function_spaces) + else: + if not is_dual(sderiv0[0]): + raise ValueError( + "Do not know how to handle stage derivative which is not primal or dual") + derivatives = EnsembleCofunction( + self.ensemble, [V.dual() for V in self.control.local_function_spaces]) + + derivatives.zero() + + if self.ensemble: + with stop_annotating(): + xprev = derivatives.subfunctions[0]._ad_copy() + xnext = derivatives.subfunctions[0]._ad_copy() + xprev.zero() + xnext.zero() + if trank != 0: + derivs = [xprev, *derivatives.subfunctions] + else: + derivs = [*derivatives.subfunctions] + + # start accumulating the complete derivative + derivs[0] += sderiv0[0] + derivs[1] += sderiv0[1] + + # post the derivative halo exchange + if self.ensemble: + # halos sent backward in time + src = trank + 1 + dst = trank - 1 + + if trank != 0: + self.ensemble.isend( + derivs[0], dest=dst, tag=dst) + + if trank != self.nchunks - 1: + recv_reqs = self.ensemble.irecv( + xnext, source=src, tag=trank) + + # initial condition derivatives + if trank == 0: + derivs[0] += self.background_rf.derivative( + adj_input=adj_input, options=options) + + # observations at time 0 + if self.initial_observations: + derivs[0] += self.initial_observation_rf.derivative( + adj_input=adj_input, options=options) + + # # evaluate all time stages on chunk except first while halo in flight + for i in range(1, len(self.stages)): + sderiv = self.stages[i].derivative( + adj_input=adj_input, options=options) + + derivs[i] += sderiv[0] + derivs[i+1] += sderiv[1] + + # finish the derivative halo exchange + if self.ensemble: + if trank != self.nchunks - 1: + MPI.Request.Waitall(recv_reqs) + derivs[-1] += xnext + + return derivatives + + @sc_passthrough + @stop_annotating() + def hessian(self, m_dot: OverloadedType, options: dict = {}): + """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot. 
+
+        Using the second-order adjoint method, the action of the Hessian of the
+        functional with respect to the control, around the last supplied value
+        of the control, is computed and returned.
+
+        Parameters
+        ----------
+
+        m_dot
+            The direction in which to compute the action of the Hessian.
+
+        options
+            A dictionary of options. To find a list of available options
+            have a look at the specific control type.
+
+        Returns
+        -------
+        pyadjoint.OverloadedType
+            The action of the Hessian in the direction m_dot.
+            Should be an instance of the same type as the control.
+        """
+        trank = self.trank
+
+        hess = self.control.copy_data()
+        hess.zero()
+
+        # set up arrays including halos
+        if trank == 0:
+            hs = [*hess.subfunctions]
+            mdot = [*m_dot[0].subfunctions]
+        else:
+            hprev = hess.subfunctions[0].copy(deepcopy=True)
+            mprev = m_dot[0].subfunctions[0].copy(deepcopy=True)
+            hs = [hprev, *hess.subfunctions]
+            mdot = [mprev, *m_dot[0].subfunctions]
+
+        if trank != self.nchunks - 1:
+            hnext = hess.subfunctions[0].copy(deepcopy=True)
+
+        # send m_dot halo forward
+        if self.ensemble:
+            src = trank - 1
+            dst = trank + 1
+
+            if trank != self.nchunks - 1:
+                self.ensemble.isend(
+                    mdot[-1], dest=dst, tag=dst)
+
+            if trank != 0:
+                recv_reqs = self.ensemble.irecv(
+                    mdot[0], source=src, tag=trank)
+
+        # hessian actions at the initial condition
+        if trank == 0:
+            hs[0] += self.background_rf.hessian(
+                mdot[0], options=options)
+
+            if self.initial_observations:
+                hs[0] += self.initial_observation_rf.hessian(
+                    mdot[0], options=options)
+
+        # evaluate all stages on chunk except first
+        for i in range(1, len(self.stages)):
+            hms = self.stages[i].hessian(
+                mdot[i:i+2], options=options)
+
+            hs[i] += hms[0]
+            hs[i+1] += hms[1]
+
+        # wait for halo swap to finish
+        if trank != 0:
+            MPI.Request.Waitall(recv_reqs)
+
+        # evaluate first stage on chunk now we have the halo
+        hms = self.stages[0].hessian(
+            mdot[:2], options=options)
+
+        hs[0] += hms[0]
+        hs[1] += hms[1]
+
+        # send result halo backward
+        if self.ensemble:
+            src = trank + 1
+            dst = trank - 1
+
+            if trank != 0:
+                self.ensemble.isend(
+                    hs[0], dest=dst, tag=dst)
+
+            if trank != self.nchunks - 1:
+                recv_reqs = self.ensemble.irecv(
+                    hnext, source=src, tag=trank)
+
+        # finish the result halo
+        if trank != self.nchunks - 1:
+            MPI.Request.Waitall(recv_reqs)
+            hs[-1] += hnext
+
+        return hess
+
+    @stop_annotating()
+    def hessian_matrix(self):
+        # Other reduced functionals don't have this.
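+        # (The weak constraint Hessian would be block tridiagonal: the model
+        # error terms couple each stage control to its neighbours, as in the
+        # stage loop of `hessian` above.)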
+ if not self.weak_constraint: + raise AttributeError("Strong constraint 4DVar does not form a Hessian matrix") + raise NotImplementedError + + def _accumulate_functional(self, val): + if not self._annotate_accumulation: + return + if self._accumulation_started: + self._total_functional += val + else: + self._total_functional = val + self._accumulation_started = True + + @contextmanager + def recording_stages(self, sequential=True, nstages=None, **stage_kwargs): + if not sequential: + raise ValueError("Recording stages concurrently not yet implemented") + + # record over ensemble + if self.weak_constraint: + + trank = self.trank + + # index of "previous" stage and observation in global series + global_index = -1 + observation_index = 0 if self.initial_observations else -1 + with stop_annotating(): + xhalo = self._x[0]._ad_copy() + + # add our data onto the user's context data + ekwargs = {k: v for k, v in stage_kwargs.items()} + ekwargs['global_index'] = global_index + ekwargs['observation_index'] = observation_index + + ekwargs['xhalo'] = xhalo + + # proceed one ensemble rank at a time + with self.ensemble.sequential(**ekwargs) as ectx: + + # later ranks start from halo + if trank == 0: + controls = self._controls + else: + controls = [self._control_prev, *self._controls] + with stop_annotating(): + controls[0].assign(ectx.xhalo) + + # grab the user's data from the ensemble context + local_stage_kwargs = { + k: getattr(ectx, k) for k in stage_kwargs.keys() + } + + # initialise iterator for local stages + stage_sequence = ObservationStageSequence( + controls, self, ectx.global_index, + ectx.observation_index, + local_stage_kwargs, sequential) + + # let the user record the local stages + yield stage_sequence + + # send the state forward + with stop_annotating(): + state = self.stages[-1].controls[1].control + ectx.xhalo.assign(state) + # grab the user's information to send forward + for k in local_stage_kwargs.keys(): + setattr(ectx, k, getattr(stage_sequence.ctx, k)) + # increment the global indices for the last local stage + ectx.global_index = self.stages[-1].global_index + ectx.observation_index = self.stages[-1].observation_index + + # make sure that self.control now holds the + # values of the initial timeseris + self.control.assign(self._cbuf) + + else: # strong constraint + + yield ObservationStageSequence( + self.controls, self, global_index=-1, + observation_index=0 if self.initial_observations else -1, + stage_kwargs=stage_kwargs, nstages=nstages) + + +class ObservationStageSequence: + def __init__(self, controls: Control, + aaorf: FourDVarReducedFunctional, + global_index: int, + observation_index: int, + stage_kwargs: dict = None, + nstages: Optional[int] = None): + self.controls = controls + self.aaorf = aaorf + self.ctx = SimpleNamespace(**(stage_kwargs or {})) + self.weak_constraint = aaorf.weak_constraint + self.global_index = global_index + self.observation_index = observation_index + self.local_index = -1 + self.nstages = (len(controls) - 1 if self.weak_constraint + else nstages) + + def __iter__(self): + return self + + def __next__(self): + + # increment global indices. + self.local_index += 1 + self.global_index += 1 + self.observation_index += 1 + + if self.weak_constraint: + stages = self.aaorf.stages + + # control for the start of the next stage. + next_control = self.controls[self.local_index] + + # smuggle state forward into aaorf's next control. 
+ if self.local_index > 0: + state = stages[-1].controls[1].control + with stop_annotating(): + next_control.control.assign(state) + + # now we know that the aaorf's controls have + # been updated from the previous stage's controls, + # we can check if we need to exit. + if self.local_index >= self.nstages: + raise StopIteration + + stage = WeakObservationStage(next_control, + local_index=self.local_index, + global_index=self.global_index, + observation_index=self.observation_index) + stages.append(stage) + + else: # strong constraint + + # stop after we've recorded all stages + if self.local_index >= self.nstages: + raise StopIteration + + # dummy control to "start" stage from + control = (self.aaorf.controls[0].control if self.local_index == 0 + else self._prev_stage.state) + + stage = StrongObservationStage( + control, self.aaorf, + index=self.local_index, + observation_index=self.observation_index) + + self._prev_stage = stage + + return stage, self.ctx + + +class StrongObservationStage: + """ + Record an observation for strong constraint 4DVar at the time of `state`. + + Parameters + ---------- + + aaorf + The strong constraint FourDVarReducedFunctional. + + """ + + def __init__(self, control: OverloadedType, + aaorf: FourDVarReducedFunctional, + index: Optional[int] = None, + observation_index: Optional[int] = None): + self.aaorf = aaorf + self.control = control + self.index = index + self.observation_index = observation_index + + def set_observation(self, state: OverloadedType, + observation_err: Callable[[OverloadedType], OverloadedType], + observation_iprod: Callable[[OverloadedType], AdjFloat]): + """ + Record an observation at the time of `state`. + + Parameters + ---------- + + state + The state at the current observation time. + + observation_err + Given a state :math:`x`, returns the observations error + :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are + the observations at the current observation time and + :math:`\\mathcal{H}_{i}` is the observation operator for the + current observation time. + + observation_iprod + The inner product to calculate the observation error functional + from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`. + Can include the error covariance matrix. + """ + if hasattr(self.aaorf, "_strong_reduced_functional"): + raise ValueError("Cannot add observations once strong" + " constraint ReducedFunctional instantiated") + self.aaorf._accumulate_functional( + observation_iprod(observation_err(state))) + # save the user's state to hand back for beginning of next stage + self.state = state + + +class WeakObservationStage: + """ + A single stage for weak constraint 4DVar at the time of `state`. + + Records the time propagator from the control at the beginning + of the stage, and the model and observation errors at the end of the stage. + + Parameters + ---------- + + control + The control x_{i-1} at the beginning of the stage + + local_index + The index of this stage in the timeseries on the + local ensemble member. + + global_index + The index of this stage in the global timeseries. + + observation_index + The index of the observation at the end of this stage in + the global timeseries. May be different from global_index if + an observation is taken at the initial time. + + """ + def __init__(self, control: Control, + local_index: Optional[int] = None, + global_index: Optional[int] = None, + observation_index: Optional[int] = None): + # "control" to use as initial condition. 
+        # Not an actual `Control`; named this way for consistency
+        # with the strong constraint stage.
+        self.control = control.control
+
+        self.controls = Enlist(control)
+        self.local_index = local_index
+        self.global_index = global_index
+        self.observation_index = observation_index
+        set_working_tape()
+        self._stage_tape = get_working_tape()
+
+    def set_observation(self, state: OverloadedType,
+                        observation_err: Callable[[OverloadedType], OverloadedType],
+                        observation_iprod: Callable[[OverloadedType], AdjFloat],
+                        forward_model_iprod: Callable[[OverloadedType], AdjFloat]):
+        """
+        Record an observation at the time of `state`.
+
+        Parameters
+        ----------
+
+        state
+            The state at the current observation time.
+
+        observation_err
+            Given a state :math:`x`, returns the observations error
+            :math:`y_{i} - \\mathcal{H}_{i}(x)` where :math:`y_{i}` are
+            the observations at the current observation time and
+            :math:`\\mathcal{H}_{i}` is the observation operator for the
+            current observation time.
+
+        observation_iprod
+            The inner product to calculate the observation error functional
+            from the observation error :math:`y_{i} - \\mathcal{H}_{i}(x)`.
+            Can include the error covariance matrix.
+
+        forward_model_iprod
+            The inner product to calculate the model error functional from
+            the model error :math:`x_{i} - \\mathcal{M}_{i}(x_{i-1})`. Can
+            include the error covariance matrix.
+        """
+        # get the tape used for this stage and make sure it's the right one
+        stage_tape = get_working_tape()
+        if stage_tape is not self._stage_tape:
+            raise ValueError(
+                "Working tape at the end of the observation stage"
+                " differs from the tape at the stage beginning."
+            )
+
+        # record forward propagation
+        with set_working_tape(stage_tape.copy()) as tape:
+            self.forward_model = ReducedFunctional(
+                state._ad_copy(), controls=self.controls[0], tape=tape)
+
+        # Beginning of next time-chunk is the control for this observation
+        # and the state at the end of the next time-chunk.
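+        # This is the tape cut: the propagator M_i was recorded on a private
+        # copy of the stage tape above, and the end-of-stage state re-enters
+        # the system only as a fresh, independent control below.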
+        with stop_annotating():
+            # smuggle the initial guess at this time into the control without the tape seeing
+            self.controls.append(Control(state._ad_copy()))
+            if self.global_index:
+                _rename(self.controls[-1].control, f"Control_{self.global_index}")
+
+        # model error links time-chunks by depending on both the
+        # previous and current controls
+
+        # RF to recalculate error vector (M_i - x_i)
+        names = {
+            'functional_name': f"model_err_vec_{self.global_index}",
+            'control_name': [f"state_{self.global_index}_copy",
+                             f"Control_{self.global_index}_model_copy"]
+        } if self.global_index else {}
+
+        self.model_error = isolated_rf(
+            operation=lambda controls: _ad_sub(*controls),
+            control=[state, self.controls[-1].control],
+            **names)
+
+        # RF to recalculate inner product |M_i - x_i|_Q
+        names = {
+            'control_name': f"model_err_vec_{self.global_index}_copy"
+        } if self.global_index else {}
+
+        self.model_norm = isolated_rf(
+            operation=forward_model_iprod,
+            control=self.model_error.functional,
+            **names)
+
+        # compose model error reduced functionals to evaluate both together
+        self.model_error_rf = CompositeReducedFunctional(
+            self.model_error, self.model_norm)
+
+        # Observations come after the tape cut because this is now a control, not a state
+
+        # RF to recalculate error vector (H(x_i) - y_i)
+        names = {
+            'functional_name': f"obs_err_vec_{self.global_index}",
+            'control_name': f"Control_{self.global_index}_obs_copy"
+        } if self.global_index else {}
+
+        self.observation_error = isolated_rf(
+            operation=observation_err,
+            control=self.controls[-1],
+            **names)
+
+        # RF to recalculate inner product |H(x_i) - y_i|_R
+        names = {
+            'control_name': f"obs_err_vec_{self.global_index}_copy"
+        } if self.global_index else {}
+
+        self.observation_norm = isolated_rf(
+            operation=observation_iprod,
+            control=self.observation_error.functional,
+            **names)
+
+        # compose observation reduced functionals to evaluate both together
+        self.observation_rf = CompositeReducedFunctional(
+            self.observation_error, self.observation_norm)
+
+        # remove the stage initial condition "control" now we've finished recording
+        delattr(self, "control")
+
+        # stop the stage tape recording anything else
+        set_working_tape()
+
+    @stop_annotating()
+    def __call__(self, values: OverloadedType,
+                 rftype: Optional[str] = None):
+        """Computes the reduced functional with the supplied control value.
+
+        Parameters
+        ----------
+
+        values
+            If you have multiple controls this should be a list of new values
+            for each control in the order you listed the controls to the constructor.
+            If you have a single control it can either be a list or a single object.
+            Each new value should have the same type as the corresponding control.
+
+        rftype
+            Whether to evaluate:
+            - the model error ('model'),
+            - the observation error ('obs'),
+            - both model and observation errors (None).
+
+        Returns
+        -------
+        pyadjoint.OverloadedType
+            The computed value. Typically an instance of :class:`pyadjoint.AdjFloat`.
+
+        """
+        J = 0.0
+
+        # evaluate model error
+        if rftype in (None, 'model'):
+            J += self.model_error_rf(
+                [self.forward_model(values[0]), values[1]])
+
+        # evaluate observation errors
+        if rftype in (None, 'obs'):
+            J += self.observation_rf(values[1])
+
+        return J
+
+    @stop_annotating()
+    def derivative(self, adj_input: float = 1.0, options: dict = {},
+                   rftype: Optional[str] = None):
+        """Returns the derivative of the functional w.r.t. the control.
+        Using the adjoint method, the derivative of the functional with
+        respect to the control, around the last supplied value of the
+        control, is computed and returned.
+
+        Parameters
+        ----------
+        adj_input
+            The adjoint input.
+
+        options
+            Additional options for the derivative computation.
+
+        rftype
+            Whether to evaluate:
+            - the model error ('model'),
+            - the observation error ('obs'),
+            - both model and observation errors (None).
+
+        Returns
+        -------
+        pyadjoint.OverloadedType
+            The derivative with respect to the control.
+            Should be an instance of the same type as the control.
+        """
+        # create a list of overloaded types to put the derivative into
+        derivatives = []
+
+        # chaining ReducedFunctionals means we need to pass Cofunctions not Functions
+        options = options or {}
+        ioptions = intermediate_options(options)
+
+        if rftype in (None, 'model'):
+            # derivative of reduction and difference
+            model_err_derivs = self.model_error_rf.derivative(
+                adj_input=adj_input, options=ioptions)
+
+            # derivative through the time propagator w.r.t. xprev
+            model_forward_deriv = self.forward_model.derivative(
+                adj_input=model_err_derivs[0], options=options)
+
+            derivatives.append(model_forward_deriv)
+
+            # model_err_derivs is still in the dual space, so we need to convert it to the
+            # type that the user has requested - this will be the type of model_forward_deriv.
+            derivatives.append(
+                model_forward_deriv._ad_convert_type(
+                    model_err_derivs[1], options))
+
+        if rftype in (None, 'obs'):
+            obs_deriv = self.observation_rf.derivative(
+                adj_input=adj_input, options=options)
+
+            if len(derivatives) == 0:
+                derivatives.append(None)
+                derivatives.append(obs_deriv)
+            else:
+                derivatives[1] += obs_deriv
+
+        return derivatives
+
+    @stop_annotating()
+    def hessian(self, m_dot: OverloadedType, options: dict = {},
+                rftype: Optional[str] = None):
+        """Returns the action of the Hessian of the functional w.r.t. the control on a vector m_dot.
+
+        Using the second-order adjoint method, the action of the Hessian of the
+        functional with respect to the control, around the last supplied value
+        of the control, is computed and returned.
+
+        Parameters
+        ----------
+
+        m_dot
+            The direction in which to compute the action of the Hessian.
+
+        options
+            A dictionary of options. To find a list of available options
+            have a look at the specific control type.
+
+        rftype
+            Whether to evaluate:
+            - the model error ('model'),
+            - the observation error ('obs'),
+            - both model and observation errors (None).
+
+        Returns
+        -------
+        pyadjoint.OverloadedType
+            The action of the Hessian in the direction m_dot.
+            Should be an instance of the same type as the control.
+ """ + hessian_value = [] + + if rftype in (None, 'model'): + hessian_value.extend(self._model_hessian( + m_dot, options=options)) + + if rftype in (None, 'obs'): + obs_hessian = self.observation_rf.hessian( + m_dot[1], options=options) + if len(hessian_value) == 0: + hessian_value.append(None) + hessian_value.append(obs_hessian) + else: + hessian_value[1] += obs_hessian + + return hessian_value + + def _model_hessian(self, m_dot, options): + iopts = intermediate_options(options) + + # TLM for model from mdot[0] + forward_tlm = tlm(self.forward_model, m_dot[0], + options=iopts) + + # combine model TLM and mdot[1] + mdot_error = [forward_tlm, m_dot[1]] + + # Hessian (dual) for error + error_hessian = self.model_error_rf.hessian( + mdot_error, options=iopts, evaluate_tlm=True) + + # Hessian for model + model_hessian = hessian( + self.forward_model, options=options, + hessian_value=error_hessian[0]) + + # combine model Hessian and converted error Hessian + return [ + model_hessian, + model_hessian._ad_convert_type(error_hessian[1], + options=options) + ] diff --git a/firedrake/adjoint_utils/__init__.py b/firedrake/adjoint_utils/__init__.py index 3b3426a850..c71da61ade 100644 --- a/firedrake/adjoint_utils/__init__.py +++ b/firedrake/adjoint_utils/__init__.py @@ -12,3 +12,4 @@ from firedrake.adjoint_utils.solving import * # noqa: F401 from firedrake.adjoint_utils.mesh import * # noqa: F401 from firedrake.adjoint_utils.checkpointing import * # noqa: F401 +from firedrake.adjoint_utils.ensemblefunction import * # noqa: F401 diff --git a/firedrake/adjoint_utils/blocks/function.py b/firedrake/adjoint_utils/blocks/function.py index e31a0c4567..dcb02da108 100644 --- a/firedrake/adjoint_utils/blocks/function.py +++ b/firedrake/adjoint_utils/blocks/function.py @@ -242,11 +242,13 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, else: return adj_inputs[0] - def evaluate_tlm(self): + def evaluate_tlm(self, markings=False): tlm_input = self.get_dependencies()[0].tlm_value if tlm_input is None: return output = self.get_outputs()[0] + if markings and not output.marked_in_path: + return fs = output.output.function_space() f = type(output.output)(fs) output.add_tlm_output( diff --git a/firedrake/adjoint_utils/ensemblefunction.py b/firedrake/adjoint_utils/ensemblefunction.py new file mode 100644 index 0000000000..ba6882f054 --- /dev/null +++ b/firedrake/adjoint_utils/ensemblefunction.py @@ -0,0 +1,101 @@ +from pyadjoint.overloaded_type import OverloadedType +from firedrake.petsc import PETSc +from .checkpointing import disk_checkpointing + +from functools import wraps + + +class EnsembleFunctionMixin(OverloadedType): + + @staticmethod + def _ad_annotate_init(init): + @wraps(init) + def wrapper(self, *args, **kwargs): + OverloadedType.__init__(self) + init(self, *args, **kwargs) + return wrapper + + @staticmethod + def _ad_to_list(m): + with m.vec_ro() as gvec: + lcomm = PETSc.COMM_SELF + gsize = gvec.size + lvec = PETSc.Vec().createSeq(gsize, comm=lcomm) + is_ = PETSc.IS().createStride(gsize, 0, 1, comm=lcomm) + + mode = PETSc.InsertMode.INSERT_VALUES + scatter = PETSc.Scatter().create(gvec, is_, lvec, None) + scatter.scatterBegin(gvec, lvec, addv=mode) + scatter.scatterEnd(gvec, lvec, addv=mode) + + return lvec.array_r.tolist() + + @staticmethod + def _ad_assign_numpy(dst, src, offset): + with dst.vec_wo() as vec: + begin, end = vec.owner_range + src_array = src[offset + begin: offset + end] + vec.array[:] = src_array + offset += vec.size + return dst, offset + + def _ad_dot(self, 
+        # local dot product
+        ldot = sum(
+            uself._ad_dot(uother, options=options)
+            for uself, uother in zip(self.subfunctions,
+                                     other.subfunctions))
+        # global dot product
+        gdot = self.ensemble.ensemble_comm.allreduce(ldot)
+        return gdot
+
+    def _ad_add(self, other):
+        new = self.copy()
+        new += other
+        return new
+
+    def _ad_mul(self, other):
+        new = self.copy()
+        # `self` can be a Cofunction, in which case only left multiplication with a scalar is allowed.
+        other = other._fbuf if type(other) is type(self) else other
+        new._fbuf.assign(other*new._fbuf)
+        return new
+
+    def _ad_iadd(self, other):
+        self += other
+        return self
+
+    def _ad_imul(self, other):
+        self *= other
+        return self
+
+    def _ad_copy(self):
+        return self.copy()
+
+    def _ad_convert_riesz(self, value, options=None):
+        raise NotImplementedError(
+            "Riesz conversion is not implemented yet for EnsembleFunctions")
+
+    def _ad_create_checkpoint(self):
+        if disk_checkpointing():
+            raise NotImplementedError(
+                "Disk checkpointing not implemented for EnsembleFunctions")
+        else:
+            return self.copy()
+
+    def _ad_restore_at_checkpoint(self, checkpoint):
+        if isinstance(checkpoint, type(self)):
+            return checkpoint
+        raise NotImplementedError(
+            "Checkpointing not implemented for EnsembleFunctions")
+
+    def _ad_from_petsc(self, vec):
+        with self.vec_wo() as self_v:
+            vec.copy(result=self_v)
+
+    def _ad_to_petsc(self, vec=None):
+        with self.vec_ro() as self_v:
+            if vec:
+                self_v.copy(result=vec)
+            else:
+                vec = self_v.copy()
+        return vec
diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py
index 5e87751d36..e69e5045e9 100644
--- a/firedrake/adjoint_utils/function.py
+++ b/firedrake/adjoint_utils/function.py
@@ -233,14 +233,14 @@ def _ad_convert_riesz(self, value, options=None):
             return Function(V)
 
         if not isinstance(value, (Cofunction, Function)):
-            raise TypeError("Expected a Cofunction or a Function")
+            raise TypeError(f"Expected a Cofunction or a Function, not a {type(value)}")
 
         if riesz_representation == "l2":
             return Function(V, val=value.dat)
 
         elif riesz_representation in ("L2", "H1"):
             if not isinstance(value, Cofunction):
-                raise TypeError("Expected a Cofunction")
+                raise TypeError(f"Expected a Cofunction, not a {type(value)}")
 
             ret = Function(V)
             a = self._define_riesz_map_form(riesz_representation, V)
diff --git a/firedrake/ensemble.py b/firedrake/ensemble.py
index f847be51bf..53049fe7d7 100644
--- a/firedrake/ensemble.py
+++ b/firedrake/ensemble.py
@@ -1,8 +1,12 @@
 import weakref
+from contextlib import contextmanager
+from itertools import zip_longest
+from types import SimpleNamespace
 
 from firedrake.petsc import PETSc
+from firedrake.function import Function
+from firedrake.cofunction import Cofunction
 from pyop2.mpi import MPI, internal_comm
-from itertools import zip_longest
 
 __all__ = ("Ensemble", )
 
@@ -283,3 +287,60 @@ def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, r
         requests.extend([self._ensemble_comm.Irecv(dat.data, source=source, tag=recvtag) for dat in frecv.dat])
 
         return requests
+
+    @contextmanager
+    def sequential(self, synchronise=False, **kwargs):
+        """
+        Context manager for executing code on each ensemble
+        member in turn, ordered by `ensemble_comm.rank`.
+
+        Any data in `kwargs` will be made available in the context
+        and will be communicated forward after each ensemble member
+        exits. Firedrake Functions/Cofunctions will be sent with the
+        corresponding Ensemble methods.
+        Parameters
+        ----------
+
+        synchronise
+            Whether to synchronise all ensemble members with a global
+            barrier on entering and exiting the context.
+
+        For example::
+
+            with ensemble.sequential(index=0) as ctx:
+                print(ensemble.ensemble_comm.rank, ctx.index)
+                ctx.index += 2
+
+        Would print::
+
+            0 0
+            1 2
+            2 4
+            3 6
+            ... etc ...
+
+        """
+        rank = self.ensemble_comm.rank
+        first_rank = (rank == 0)
+        last_rank = (rank == self.ensemble_comm.size - 1)
+
+        if not first_rank:
+            src = rank - 1
+            for i, (k, v) in enumerate(kwargs.items()):
+                recv_kwargs = {'source': src, 'tag': rank + i*100}
+                if isinstance(v, (Function, Cofunction)):
+                    self.recv(kwargs[k], **recv_kwargs)
+                else:
+                    kwargs[k] = self.ensemble_comm.recv(
+                        **recv_kwargs)
+
+        ctx = SimpleNamespace(**kwargs)
+
+        if synchronise:
+            self.global_comm.Barrier()
+            yield ctx
+            self.global_comm.Barrier()
+        else:
+            yield ctx
+
+        if not last_rank:
+            dst = rank + 1
+            for i, v in enumerate((getattr(ctx, k)
+                                   for k in kwargs.keys())):
+                send_kwargs = {'dest': dst, 'tag': dst + i*100}
+                if isinstance(v, (Function, Cofunction)):
+                    self.send(v, **send_kwargs)
+                else:
+                    self.ensemble_comm.send(v, **send_kwargs)
diff --git a/firedrake/ensemblefunction.py b/firedrake/ensemblefunction.py
new file mode 100644
index 0000000000..33ed7e7b2f
--- /dev/null
+++ b/firedrake/ensemblefunction.py
@@ -0,0 +1,295 @@
+from firedrake.petsc import PETSc
+from firedrake.adjoint_utils import EnsembleFunctionMixin
+from firedrake.functionspace import MixedFunctionSpace
+from firedrake.function import Function
+from ufl.duals import is_primal, is_dual
+from pyop2 import MixedDat
+
+from functools import cached_property
+from contextlib import contextmanager
+
+__all__ = ("EnsembleFunction", "EnsembleCofunction")
+
+
+class EnsembleFunctionBase(EnsembleFunctionMixin):
+    """
+    A mixed finite element (co)function distributed over an ensemble.
+
+    Parameters
+    ----------
+
+    ensemble
+        The ensemble communicator. The sub(co)functions are distributed
+        over the different ensemble members.
+
+    function_spaces
+        A list of function spaces for each (co)function on the
+        local ensemble member.
+    """
+
+    @PETSc.Log.EventDecorator()
+    @EnsembleFunctionMixin._ad_annotate_init
+    def __init__(self, ensemble, function_spaces):
+        self.ensemble = ensemble
+        self.local_function_spaces = function_spaces
+        self.local_size = len(function_spaces)
+
+        # the local functions are stored as a big mixed space
+        self._function_space = MixedFunctionSpace(function_spaces)
+        self._fbuf = Function(self._function_space)
+
+        # create a Vec containing the data for all functions on all
+        # ensemble members. Because we use the Vec of each local mixed
+        # function as the storage, if the data in the Function Vec
+        # is valid then the data in the EnsembleFunction Vec is valid.
+
+        with self._fbuf.dat.vec as fvec:
+            local_size = self._function_space.node_set.size
+            sizes = (local_size, PETSc.DETERMINE)
+            self._vec = PETSc.Vec().createWithArray(fvec.array,
+                                                    size=sizes,
+                                                    comm=ensemble.global_comm)
+            self._vec.setFromOptions()
+
+    @cached_property
+    def subfunctions(self):
+        """
+        The (co)functions on the local ensemble member.
+        """
+        def local_function(i):
+            V = self.local_function_spaces[i]
+            usubs = self._subcomponents(i)
+            if len(usubs) == 1:
+                dat = usubs[0].dat
+            else:
+                dat = MixedDat((u.dat for u in usubs))
+            return Function(V, val=dat)
+
+        self._subfunctions = tuple(local_function(i)
+                                   for i in range(self.local_size))
+        return self._subfunctions
+
+    def _subcomponents(self, i):
+        """
+        Return the subfunctions of the local mixed function storage
+        corresponding to the i-th local function.
+ """ + return tuple(self._fbuf.subfunctions[j] + for j in self._component_indices(i)) + + def _component_indices(self, i): + """ + Return the indices into the local mixed function storage + corresponding to the i-th local function. + """ + V = self.local_function_spaces[i] + offset = sum(len(V) for V in self.local_function_spaces[:i]) + return tuple(offset + i for i in range(len(V))) + + @PETSc.Log.EventDecorator() + def riesz_representation(self, riesz_map="L2", **kwargs): + """ + Return the Riesz representation of this :class:`EnsembleFunction` + with respect to the given Riesz map. + + Parameters + ---------- + + riesz_map + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. + + kwargs + other arguments to be passed to the firedrake.riesz_map. + """ + DualType = { + EnsembleFunction: EnsembleCofunction, + EnsembleCofunction: EnsembleFunction, + }[type(self)] + Vdual = [V.dual() for V in self.local_function_spaces] + riesz = DualType(self.ensemble, Vdual) + for u in riesz.subfunctions: + u.assign(u.riesz_representation(riesz_map=riesz_map, **kwargs)) + return riesz + + @PETSc.Log.EventDecorator() + def assign(self, other, subsets=None): + r"""Set the :class:`EnsembleFunction` to the value of another + :class:`EnsembleFunction` other. + + Parameters + ---------- + + other + The :class:`EnsembleFunction` to assign from. + + subsets + An iterable of :class:`pyop2.types.set.Subset`, one for each local :class:`Function`. + The values of each local function will then only + be assigned on the nodes on the corresponding subset. + """ + if type(other) is not type(self): + raise ValueError( + f"Cannot assign {type(self)} from {type(other)}") + if subsets: + for i in range(self.local_size): + self.subfunctions[i].assign( + other.subfunctions[i], subset=subsets[i]) + else: + for i in range(self.local_size): + self.subfunctions[i].assign(other.subfunctions[i]) + return self + + @PETSc.Log.EventDecorator() + def copy(self): + """ + Return a deep copy of the :class:`EnsembleFunction`. + """ + new = type(self)(self.ensemble, self.local_function_spaces) + new.assign(self) + return new + + @PETSc.Log.EventDecorator() + def zero(self, subsets=None): + """ + Set values to zero. + + Parameters + ---------- + + subsets + An iterable of :class:`pyop2.types.set.Subset`, one for each local :class:`Function`. + The values of each local function will then only + be assigned on the nodes on the corresponding subset. 
+ """ + if subsets: + for i in range(self.local_size): + self.subfunctions[i].zero(subsets[i]) + else: + for u in self.subfunctions: + u.zero() + return self + + @PETSc.Log.EventDecorator() + def __iadd__(self, other): + for us, uo in zip(self.subfunctions, other.subfunctions): + us.assign(us + uo) + return self + + @PETSc.Log.EventDecorator() + def __imul__(self, other): + if type(other) is type(self): + for us, uo in zip(self.subfunctions, other.subfunctions): + us.assign(us*uo) + else: + for us in self.subfunctions: + us *= other + return self + + @PETSc.Log.EventDecorator() + def __add__(self, other): + new = self.copy() + for i in range(self.local_size): + new.subfunctions[i] += other.subfunctions[i] + return new + + @PETSc.Log.EventDecorator() + def __mul__(self, other): + new = self.copy() + if type(other) is type(self): + for i in range(self.local_size): + self.subfunctions[i].assign(other.subfunctions[i]*self.subfunctions[i]) + else: + for i in range(self.local_size): + self.subfunctions[i].assign(other*self.subfunctions[i]) + return new + + @PETSc.Log.EventDecorator() + def __rmul__(self, other): + return self.__mul__(other) + + @contextmanager + def vec(self): + """ + Context manager for the global PETSc Vec with read/write access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager so that the data gets copied to/from + # the Function.dat storage and _vec. + # However, this copy is done without _vec knowing, so we have + # to manually increment the state. + with self._fbuf.dat.vec: + self._vec.stateIncrease() + yield self._vec + + @contextmanager + def vec_ro(self): + """ + Context manager for the global PETSc Vec with read only access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager to make sure that the data gets copied + # to the Function.dat storage and _vec. + with self._fbuf.dat.vec_ro: + self._vec.stateIncrease() + yield self._vec + + @contextmanager + def vec_wo(self): + """ + Context manager for the global PETSc Vec with write only access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager to make sure that the data gets copied + # from the Function.dat storage and _vec. + with self._fbuf.dat.vec_wo: + yield self._vec + + +class EnsembleFunction(EnsembleFunctionBase): + """ + A mixed finite element Function distributed over an ensemble. + + Parameters + ---------- + + ensemble + The ensemble communicator. The subfunctions are distributed + over the different ensemble members. + + function_spaces + A list of function spaces for each function on the + local ensemble member. + """ + def __init__(self, ensemble, function_spaces): + if not all(is_primal(V) for V in function_spaces): + raise TypeError( + "EnsembleFunction must be created using primal FunctionSpaces") + super().__init__(ensemble, function_spaces) + + +class EnsembleCofunction(EnsembleFunctionBase): + """ + A mixed finite element Cofunction distributed over an ensemble. + + Parameters + ---------- + + ensemble + The ensemble communicator. The subcofunctions are distributed + over the different ensemble members. + + function_spaces + A list of dual function spaces for each cofunction on the + local ensemble member. 
+ """ + def __init__(self, ensemble, function_spaces): + if not all(is_dual(V) for V in function_spaces): + raise TypeError( + "EnsembleCofunction must be created using dual FunctionSpaces") + super().__init__(ensemble, function_spaces) diff --git a/tests/firedrake/regression/test_4dvar_reduced_functional.py b/tests/firedrake/regression/test_4dvar_reduced_functional.py new file mode 100644 index 0000000000..67fde266b7 --- /dev/null +++ b/tests/firedrake/regression/test_4dvar_reduced_functional.py @@ -0,0 +1,459 @@ +import pytest +import firedrake as fd +from firedrake.__future__ import interpolate +from firedrake.adjoint import ( + continue_annotation, pause_annotation, stop_annotating, + set_working_tape, get_working_tape, Control, taylor_test, taylor_to_dict, + ReducedFunctional, FourDVarReducedFunctional) +from numpy import mean + + +@pytest.fixture(autouse=True) +def clear_tape_teardown(): + yield + get_working_tape().clear_tape() + + +def function_space(comm): + """DG0 periodic advection""" + mesh = fd.PeriodicUnitIntervalMesh(nx, comm=comm) + return fd.FunctionSpace(mesh, "DG", 0) + + +def timestepper(V): + """Implicit midpoint timestepper for the advection equation""" + qn = fd.Function(V, name="qn") + qn1 = fd.Function(V, name="qn1") + + def mass(q, phi): + return fd.inner(q, phi)*fd.dx + + def tendency(q, phi): + u = fd.as_vector([vconst]) + n = fd.FacetNormal(V.mesh()) + un = fd.Constant(0.5)*(fd.dot(u, n) + abs(fd.dot(u, n))) + return (- q*fd.div(phi*u)*fd.dx + + fd.jump(phi)*fd.jump(un*q)*fd.dS) + + # midpoint rule + q = fd.TrialFunction(V) + phi = fd.TestFunction(V) + + qh = fd.Constant(0.5)*(q + qn) + eqn = mass(q - qn, phi) + fd.Constant(dt)*tendency(qh, phi) + + stepper = fd.LinearVariationalSolver( + fd.LinearVariationalProblem( + fd.lhs(eqn), fd.rhs(eqn), qn1, + constant_jacobian=True)) + + return qn, qn1, stepper + + +def prod2(w): + """generate weighted inner products to pass to FourDVarReducedFunctional""" + def n2(x): + return fd.assemble(fd.inner(x, fd.Constant(w)*x)*fd.dx)**2 + return n2 + + +prodB = prod2(0.1) # background error +prodR = prod2(10.) 
+prodQ = prod2(1.0)   # model error
+
+
+"""Advecting velocity"""
+velocity = 1
+vconst = fd.Constant(velocity)
+
+"""Number of cells"""
+nx = 16
+
+"""Timestep size"""
+cfl = 2.3523
+dx = 1.0/nx
+dt = cfl*dx/velocity
+
+"""How many times / how often we take observations
+(one extra at initial time)"""
+observation_frequency = 5
+observation_n = 6
+observation_times = [i*observation_frequency*dt
+                     for i in range(observation_n+1)]
+
+
+def nlocal_observations(ensemble):
+    """How many observations on the current ensemble member"""
+    esize = ensemble.ensemble_comm.size
+    erank = ensemble.ensemble_comm.rank
+    if esize == 1:
+        return observation_n + 1
+    assert (observation_n % esize == 0), "Must be able to split observations across ensemble"  # noqa: E501
+    return observation_n//esize + (1 if erank == 0 else 0)
+
+
+def analytic_solution(V, t, mag=1.0, phase=0.0):
+    """Exact advection of sin wave after time t"""
+    x, = fd.SpatialCoordinate(V.mesh())
+    return fd.Function(V).interpolate(
+        mag*fd.sin(2*fd.pi*((x + phase) - vconst*t)))
+
+
+def analytic_series(V, tshift=0.0, mag=1.0, phase=0.0, ensemble=None):
+    """Timeseries of the analytic solution"""
+    series = [analytic_solution(V, t+tshift, mag=mag, phase=phase)
+              for t in observation_times]
+
+    if ensemble is None:
+        return series
+    else:
+        nlocal_obs = nlocal_observations(ensemble)
+        rank = ensemble.ensemble_comm.rank
+        offset = (0 if rank == 0 else rank*nlocal_obs + 1)
+
+        efunc = fd.EnsembleFunction(
+            ensemble, [V for _ in range(nlocal_obs)])
+
+        for e, s in zip(efunc.subfunctions,
+                        series[offset:offset+nlocal_obs]):
+            e.assign(s)
+        return efunc
+
+
+def observation_errors(V):
+    """Generate functions to evaluate the observation error
+    at each observation time"""
+
+    observation_locations = [
+        [x] for x in [0.13, 0.18, 0.34, 0.36, 0.49, 0.61, 0.72, 0.99]
+    ]
+
+    observation_mesh = fd.VertexOnlyMesh(V.mesh(), observation_locations)
+    Vobs = fd.FunctionSpace(observation_mesh, "DG", 0)
+
+    # observation operator
+    def H(x):
+        return fd.assemble(interpolate(x, Vobs))
+
+    # ground truth
+    targets = analytic_series(V)
+
+    # take observations
+    y = [H(x) for x in targets]
+
+    # generate function to evaluate observation error at observation time i
+    def observation_error(i):
+        def obs_err(x):
+            return fd.Function(Vobs).assign(H(x) - y[i])
+        return obs_err
+
+    return observation_error
+
+
+def background(V):
+    """Prior for initial condition"""
+    return analytic_solution(V, t=0, mag=0.9, phase=0.1)
+
+
+def m(V, ensemble=None):
+    """The expansion points for the Taylor test"""
+    return analytic_series(V, tshift=0.1, mag=1.1, phase=-0.2,
+                           ensemble=ensemble)
+
+
+def h(V, ensemble=None):
+    """The perturbation direction for the Taylor test"""
+    return analytic_series(V, tshift=0.3, mag=0.1, phase=0.3,
+                           ensemble=ensemble)
+
+
+def strong_fdvar_pyadjoint(V):
+    """Build a pyadjoint ReducedFunctional for the strong constraint 4DVar system"""
+    qn, qn1, stepper = timestepper(V)
+
+    # prior data
+    bkg = background(V)
+    control = bkg.copy(deepcopy=True)
+
+    # generate ground truths
+    obs_errors = observation_errors(V)
+
+    continue_annotation()
+    set_working_tape()
+
+    # background functional
+    J = prodB(control - bkg)
+
+    # initial observation functional
+    J += prodR(obs_errors(0)(control))
+
+    qn.assign(control)
+
+    # record observation stages
+    for i in range(1, len(observation_times)):
+
+        for _ in range(observation_frequency):
+            qn1.assign(qn)
+            stepper.solve()
+            qn.assign(qn1)
+
+        # observation functional
+        J += prodR(obs_errors(i)(qn))
+
+    pause_annotation()
+
+    Jhat = ReducedFunctional(J, Control(control))
+
+    return Jhat
+
+
+def strong_fdvar_firedrake(V):
+    """Build a FourDVarReducedFunctional for the strong constraint 4DVar system"""
+    qn, qn1, stepper = timestepper(V)
+
+    # prior data
+    bkg = background(V)
+    control = bkg.copy(deepcopy=True)
+
+    # generate ground truths
+    obs_errors = observation_errors(V)
+
+    continue_annotation()
+    set_working_tape()
+
+    # create 4DVar reduced functional and record
+    # background and initial observation functionals
+
+    Jhat = FourDVarReducedFunctional(
+        Control(control),
+        background_iprod=prodB,
+        observation_iprod=prodR,
+        observation_err=obs_errors(0),
+        weak_constraint=False)
+
+    # record observation stages
+    with Jhat.recording_stages(nstages=len(observation_times)-1) as stages:
+        # loop over stages
+        for stage, ctx in stages:
+            # start forward model
+            qn.assign(stage.control)
+
+            # propagate
+            for _ in range(observation_frequency):
+                qn1.assign(qn)
+                stepper.solve()
+                qn.assign(qn1)
+
+            # take observation
+            obs_index = stage.observation_index
+            stage.set_observation(qn, obs_errors(obs_index),
+                                  observation_iprod=prodR)
+
+    pause_annotation()
+    return Jhat
+
+
+def weak_fdvar_pyadjoint(V):
+    """Build a pyadjoint ReducedFunctional for the weak constraint 4DVar system"""
+    qn, qn1, stepper = timestepper(V)
+
+    # One control for each observation time
+    controls = [fd.Function(V)
+                for _ in range(len(observation_times))]
+
+    # Prior
+    bkg = background(V)
+
+    controls[0].assign(bkg)
+
+    # generate ground truths
+    obs_errors = observation_errors(V)
+
+    # start building the 4DVar system
+    continue_annotation()
+    set_working_tape()
+
+    # background error
+    J = prodB(controls[0] - bkg)
+
+    # initial observation error
+    J += prodR(obs_errors(0)(controls[0]))
+
+    # record observation stages
+    for i in range(1, len(controls)):
+        qn.assign(controls[i-1])
+
+        # forward model propagation
+        for _ in range(observation_frequency):
+            qn1.assign(qn)
+            stepper.solve()
+            qn.assign(qn1)
+
+        # we need to smuggle the state over to the next
+        # control without the tape seeing so that we
+        # can continue the timeseries through the next
+        # stage but with the tape thinking that the
+        # forward model in each stage is independent.
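+        # (this hand-off mirrors the tape cut that WeakObservationStage
+        # performs in set_observation for the FourDVarReducedFunctional
+        # version being tested against)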
+        with stop_annotating():
+            controls[i].assign(qn)
+
+        # model error for this stage
+        J += prodQ(qn - controls[i])
+
+        # observation error
+        J += prodR(obs_errors(i)(controls[i]))
+
+    pause_annotation()
+
+    Jhat = ReducedFunctional(
+        J, [Control(c) for c in controls])
+
+    return Jhat
+
+
+def weak_fdvar_firedrake(V, ensemble):
+    """Build a FourDVarReducedFunctional for the weak constraint 4DVar system"""
+    qn, qn1, stepper = timestepper(V)
+
+    # One control for each observation time
+
+    nlocal_obs = nlocal_observations(ensemble)
+
+    control = fd.EnsembleFunction(
+        ensemble, [V for _ in range(nlocal_obs)])
+
+    # Prior
+    bkg = background(V)
+
+    if ensemble.ensemble_comm.rank == 0:
+        control.subfunctions[0].assign(bkg)
+
+    # generate ground truths
+    obs_errors = observation_errors(V)
+
+    # start building the 4DVar system
+    continue_annotation()
+    set_working_tape()
+
+    # create 4DVar reduced functional and record
+    # background and initial observation functionals
+
+    Jhat = FourDVarReducedFunctional(
+        Control(control),
+        background_iprod=prodB,
+        observation_iprod=prodR,
+        observation_err=obs_errors(0),
+        weak_constraint=True)
+
+    # record observation stages
+    with Jhat.recording_stages() as stages:
+
+        # loop over stages
+        for stage, ctx in stages:
+            # start forward model
+            qn.assign(stage.control)
+
+            # propagate
+            for _ in range(observation_frequency):
+                qn1.assign(qn)
+                stepper.solve()
+                qn.assign(qn1)
+
+            # take observation
+            obs_err = obs_errors(stage.observation_index)
+            stage.set_observation(qn, obs_err,
+                                  observation_iprod=prodR,
+                                  forward_model_iprod=prodQ)
+
+    pause_annotation()
+
+    return Jhat
+
+
+def main_test_strong_4dvar_advection():
+    V = function_space(fd.COMM_WORLD)
+
+    # set up the reference pyadjoint rf
+    Jhat_pyadj = strong_fdvar_pyadjoint(V)
+    mp = m(V)[0]
+    hp = h(V)[0]
+
+    # make sure we've set up the reference rf correctly
+    assert taylor_test(Jhat_pyadj, mp, hp) > 1.99
+
+    Jhat_aaorf = strong_fdvar_firedrake(V)
+
+    ma = m(V)[0]
+    ha = h(V)[0]
+
+    eps = 1e-12
+
+    # Does evaluating the functional match the reference rf?
+    assert abs(Jhat_pyadj(mp) - Jhat_aaorf(ma)) < eps
+    assert abs(Jhat_pyadj(hp) - Jhat_aaorf(ha)) < eps
+
+    # If we match the functional, then passing the Taylor tests
+    # should mean that we match the derivative too.
+    taylor = taylor_to_dict(Jhat_aaorf, ma, ha)
+    assert mean(taylor['R0']['Rate']) > 0.9
+    assert mean(taylor['R1']['Rate']) > 1.9
+    assert mean(taylor['R2']['Rate']) > 2.9
+
+
+def main_test_weak_4dvar_advection():
+    global_comm = fd.COMM_WORLD
+    if global_comm.size in (1, 2):  # time serial
+        nspace = global_comm.size
+    elif global_comm.size == 3:  # time parallel
+        nspace = 1
+    elif global_comm.size == 4:  # space-time parallel
+        nspace = 2
+
+    ensemble = fd.Ensemble(global_comm, nspace)
+    V = function_space(ensemble.comm)
+
+    erank = ensemble.ensemble_comm.rank
+
+    # only set up the reference pyadjoint rf on the first ensemble member
+    if erank == 0:
+        Jhat_pyadj = weak_fdvar_pyadjoint(V)
+        mp = m(V)
+        hp = h(V)
+        # make sure we've set up the reference rf correctly
+        assert taylor_test(Jhat_pyadj, mp, hp) > 1.99
+
+    Jpm = ensemble.ensemble_comm.bcast(Jhat_pyadj(mp) if erank == 0 else None)
+    Jph = ensemble.ensemble_comm.bcast(Jhat_pyadj(hp) if erank == 0 else None)
+
+    Jhat_aaorf = weak_fdvar_firedrake(V, ensemble)
+
+    ma = m(V, ensemble)
+    ha = h(V, ensemble)
+
+    eps = 1e-10
+    # Does evaluating the functional match the reference rf?
+    assert abs(Jpm - Jhat_aaorf(ma)) < eps
+    assert abs(Jph - Jhat_aaorf(ha)) < eps
+
+    # If we match the functional, then passing the Taylor tests
+    # should mean that we match the derivative too.
+    taylor = taylor_to_dict(Jhat_aaorf, ma, ha)
+    assert mean(taylor['R0']['Rate']) > 0.9
+    assert mean(taylor['R1']['Rate']) > 1.9
+    assert mean(taylor['R2']['Rate']) > 2.9
+
+
+@pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
+@pytest.mark.parallel(nprocs=[1, 2])
+def test_strong_4dvar_advection():
+    main_test_strong_4dvar_advection()
+
+
+@pytest.mark.skipcomplex  # Taping for complex-valued 0-forms not yet done
+@pytest.mark.parallel(nprocs=[1, 2, 3, 4])
+def test_weak_4dvar_advection():
+    main_test_weak_4dvar_advection()
+
+
+if __name__ == '__main__':
+    main_test_strong_4dvar_advection()
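
Usage note (not part of the patch): the tests above only evaluate the
functionals and check their Taylor expansions. A complete strong-constraint
assimilation run would hand the recorded functional to an optimiser. A
minimal sketch follows, reusing the helpers from the new test module
(function_space, timestepper, background, observation_errors, prodB, prodR,
observation_times, observation_frequency) and assuming pyadjoint's
scipy-based minimize is importable from firedrake.adjoint and compatible
with the strong-constraint functional; treat it as illustrative only.

    import firedrake as fd
    from firedrake.adjoint import (
        continue_annotation, pause_annotation, set_working_tape,
        Control, minimize, FourDVarReducedFunctional)

    V = function_space(fd.COMM_WORLD)
    qn, qn1, stepper = timestepper(V)

    bkg = background(V)                # prior guess x_b
    control = bkg.copy(deepcopy=True)  # optimisation starts from the prior
    obs_errors = observation_errors(V)

    continue_annotation()
    set_working_tape()
    Jhat = FourDVarReducedFunctional(
        Control(control),
        background_iprod=prodB,
        observation_iprod=prodR,
        observation_err=obs_errors(0),
        weak_constraint=False)

    # record one stage per observation time, exactly as in the tests
    with Jhat.recording_stages(nstages=len(observation_times)-1) as stages:
        for stage, ctx in stages:
            qn.assign(stage.control)
            for _ in range(observation_frequency):
                qn1.assign(qn)
                stepper.solve()
                qn.assign(qn1)
            stage.set_observation(qn, obs_errors(stage.observation_index),
                                  observation_iprod=prodR)
    pause_annotation()

    # recover the analysis x_a; method and iteration count are illustrative
    x_a = minimize(Jhat, method="L-BFGS-B", options={"maxiter": 20})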