diff --git a/demos/full_waveform_inversion/full_waveform_inversion.py.rst b/demos/full_waveform_inversion/full_waveform_inversion.py.rst index 8d0b415888..45d200a05d 100644 --- a/demos/full_waveform_inversion/full_waveform_inversion.py.rst +++ b/demos/full_waveform_inversion/full_waveform_inversion.py.rst @@ -99,19 +99,29 @@ The source number is defined with the ``Ensemble.ensemble_comm`` rank:: source_number = my_ensemble.ensemble_comm.rank In this example, we consider a two-dimensional square domain with a side length of 1.0 km. The mesh is -built over the ``my_ensemble.comm`` (spatial) communicator:: - - Lx, Lz = 1.0, 1.0 - mesh = UnitSquareMesh(80, 80, comm=my_ensemble.comm) +built over the ``my_ensemble.comm`` (spatial) communicator. + +:: + + import os + if os.getenv("FIREDRAKE_CI_TESTS") == "1": + # Setup for a faster test execution. + dt = 0.03 # time step in seconds + final_time = 0.6 # final time in seconds + nx, ny = 15, 15 + else: + dt = 0.002 # time step in seconds + final_time = 1.0 # final time in seconds + nx, ny = 80, 80 -The basic input for the FWI problem are defined as follows:: + mesh = UnitSquareMesh(nx, ny, comm=my_ensemble.comm) + +The frequency of the Ricker wavelet, the source and receiver locations are defined as follows:: import numpy as np + frequency_peak = 7.0 # The dominant frequency of the Ricker wavelet in Hz. source_locations = np.linspace((0.3, 0.1), (0.7, 0.1), num_sources) receiver_locations = np.linspace((0.2, 0.9), (0.8, 0.9), 20) - dt = 0.002 # time step in seconds - final_time = 1.0 # final time in seconds - frequency_peak = 7.0 # The dominant frequency of the Ricker wavelet in Hz. Sources and receivers locations are illustrated in the following figure: diff --git a/docs/source/advanced_tut.rst b/docs/source/advanced_tut.rst index a6f4642aa1..d9fbf252ba 100644 --- a/docs/source/advanced_tut.rst +++ b/docs/source/advanced_tut.rst @@ -23,4 +23,4 @@ element systems. A pressure-convection-diffusion preconditioner for the Navier-Stokes equations. Rayleigh-Benard convection. Netgen support. - Full-waveform inversion: Full-waveform inversion: spatial and wave sources parallelism. + Full-waveform inversion: spatial and wave sources parallelism. diff --git a/firedrake/adjoint_utils/blocks/solving.py b/firedrake/adjoint_utils/blocks/solving.py index a96889b7ba..e4664665b0 100644 --- a/firedrake/adjoint_utils/blocks/solving.py +++ b/firedrake/adjoint_utils/blocks/solving.py @@ -2,8 +2,9 @@ import ufl from ufl import replace from ufl.formatting.ufl2unicode import ufl2unicode +from enum import Enum -from pyadjoint import Block, stop_annotating +from pyadjoint import Block, stop_annotating, get_working_tape from pyadjoint.enlisting import Enlist import firedrake from firedrake.adjoint_utils.checkpointing import maybe_disk_checkpoint @@ -24,6 +25,12 @@ def extract_subfunction(u, V): return u +class Solver(Enum): + """Enum for solver types.""" + FORWARD = 0 + ADJOINT = 1 + + class GenericSolveBlock(Block): pop_kwargs_keys = ["adj_cb", "adj_bdy_cb", "adj2_cb", "adj2_bdy_cb", "forward_args", "forward_kwargs", "adj_args", @@ -206,15 +213,17 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy): adj_sol_bdy = None if compute_bdy: - adj_sol_bdy = firedrake.Function( - self.function_space.dual(), - dJdu_copy.dat - firedrake.assemble( - firedrake.action(dFdu_adj_form, adj_sol) - ).dat - ) + adj_sol_bdy = self._compute_adj_bdy( + adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu_copy) return adj_sol, adj_sol_bdy + def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu): + adj_sol_bdy = firedrake.Function( + self.function_space.dual(), dJdu.dat - firedrake.assemble( + firedrake.action(dFdu_adj_form, adj_sol)).dat) + return adj_sol_bdy + def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None): if not self.linear and self.func == block_variable.output: @@ -604,12 +613,11 @@ def _init_solver_parameters(self, args, kwargs): class NonlinearVariationalSolveBlock(GenericSolveBlock): - def __init__(self, equation, func, bcs, adj_F, adj_cache, problem_J, + def __init__(self, equation, func, bcs, adj_cache, problem_J, solver_params, solver_kwargs, **kwargs): lhs = equation.lhs rhs = equation.rhs - self.adj_F = adj_F self._adj_cache = adj_cache self._dFdm_cache = adj_cache.setdefault("dFdm_cache", {}) self.problem_J = problem_J @@ -626,15 +634,62 @@ def _init_solver_parameters(self, args, kwargs): super()._init_solver_parameters(args, kwargs) solve_init_params(self, args, kwargs, varform=True) + def recompute_component(self, inputs, block_variable, idx, prepared): + tape = get_working_tape() + if self._ad_solvers["recompute_count"] == tape.recompute_count - 1: + # Update how many times the block has been recomputed. + self._ad_solvers["recompute_count"] = tape.recompute_count + if self._ad_solvers["forward_nlvs"]._problem._constant_jacobian: + self._ad_solvers["forward_nlvs"].invalidate_jacobian() + self._ad_solvers["update_adjoint"] = True + return super().recompute_component(inputs, block_variable, idx, prepared) + def _forward_solve(self, lhs, rhs, func, bcs, **kwargs): - self._ad_nlvs_replace_forms() - self._ad_nlvs.parameters.update(self.solver_params) - self._ad_nlvs.solve() - func.assign(self._ad_nlvs._problem.u) + self._ad_solver_replace_forms() + self._ad_solvers["forward_nlvs"].parameters.update(self.solver_params) + self._ad_solvers["forward_nlvs"].solve() + func.assign(self._ad_solvers["forward_nlvs"]._problem.u) return func - def _ad_assign_map(self, form): - count_map = self._ad_nlvs._problem._ad_count_map + def _adjoint_solve(self, dJdu, compute_bdy): + dJdu_copy = dJdu.copy() + # Homogenize and apply boundary conditions on adj_dFdu and dJdu. + bcs = self._homogenize_bcs() + for bc in bcs: + bc.apply(dJdu) + + if ( + self._ad_solvers["forward_nlvs"]._problem._constant_jacobian + and self._ad_solvers["update_adjoint"] + ): + # Update left hand side of the adjoint equation. + self._ad_solver_replace_forms(Solver.ADJOINT) + self._ad_solvers["adjoint_lvs"].invalidate_jacobian() + self._ad_solvers["update_adjoint"] = False + elif not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian: + # Update left hand side of the adjoint equation. + self._ad_solver_replace_forms(Solver.ADJOINT) + + # Update the right hand side of the adjoint equation. + # problem.F._component[1] is the right hand side of the adjoint. + self._ad_solvers["adjoint_lvs"]._problem.F._components[1].assign(dJdu) + + # Solve the adjoint linear variational solver. + self._ad_solvers["adjoint_lvs"].solve() + u_sol = self._ad_solvers["adjoint_lvs"]._problem.u + + adj_sol_bdy = None + if compute_bdy: + jac_adj = self._ad_solvers["adjoint_lvs"]._problem.J + adj_sol_bdy = self._compute_adj_bdy( + u_sol, adj_sol_bdy, jac_adj, dJdu_copy) + return u_sol, adj_sol_bdy + + def _ad_assign_map(self, form, solver): + if solver == Solver.FORWARD: + count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map + else: + count_map = self._ad_solvers["adjoint_lvs"]._problem._ad_count_map assign_map = {} form_ad_count_map = dict((count_map[coeff], coeff) for coeff in form.coefficients()) @@ -647,46 +702,37 @@ def _ad_assign_map(self, form): if coeff_count in form_ad_count_map: assign_map[form_ad_count_map[coeff_count]] = \ block_variable.saved_output + + if ( + solver == Solver.ADJOINT + and not self._ad_solvers["forward_nlvs"]._problem._constant_jacobian + ): + block_variable = self.get_outputs()[0] + coeff_count = block_variable.output.count() + if coeff_count in form_ad_count_map: + assign_map[form_ad_count_map[coeff_count]] = \ + block_variable.saved_output return assign_map - def _ad_assign_coefficients(self, form): - assign_map = self._ad_assign_map(form) + def _ad_assign_coefficients(self, form, solver): + assign_map = self._ad_assign_map(form, solver) for coeff, value in assign_map.items(): coeff.assign(value) - def _ad_nlvs_replace_forms(self): - problem = self._ad_nlvs._problem - self._ad_assign_coefficients(problem.F) - self._ad_assign_coefficients(problem.J) - - def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs): - if "dFdu_adj" in self._adj_cache: - dFdu = self._adj_cache["dFdu_adj"] + def _ad_solver_replace_forms(self, solver=Solver.FORWARD): + if solver == Solver.FORWARD: + problem = self._ad_solvers["forward_nlvs"]._problem + self._ad_assign_coefficients(problem.F, solver) + self._ad_assign_coefficients(problem.J, solver) else: - dFdu = super()._assemble_dFdu_adj(dFdu_adj_form, **kwargs) - if self._ad_nlvs._problem._constant_jacobian: - self._adj_cache["dFdu_adj"] = dFdu - return dFdu + self._ad_assign_coefficients( + self._ad_solvers["adjoint_lvs"]._problem.J, solver) def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): - dJdu = adj_inputs[0] - - F_form = self._create_F_form() - - dFdu_form = self.adj_F - dJdu = dJdu.copy() - - # Replace the form coefficients with checkpointed values. - replace_map = self._replace_map(dFdu_form) - replace_map[self.func] = self.get_outputs()[0].saved_output - dFdu_form = replace(dFdu_form, replace_map) - compute_bdy = self._should_compute_boundary_adjoint( relevant_dependencies ) - adj_sol, adj_sol_bdy = self._assemble_and_solve_adj_eq( - dFdu_form, dJdu, compute_bdy - ) + adj_sol, adj_sol_bdy = self._adjoint_solve(adj_inputs[0], compute_bdy) self.adj_state = adj_sol if self.adj_cb is not None: self.adj_cb(adj_sol) @@ -694,8 +740,8 @@ def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): self.adj_bdy_cb(adj_sol_bdy) r = {} - r["form"] = F_form - r["adj_sol"] = adj_sol + r["form"] = self._create_F_form() + r["adj_sol"] = self.adj_state r["adj_sol_bdy"] = adj_sol_bdy return r diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index a6811002ac..c90d2668e0 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -45,7 +45,8 @@ def wrapper(self, problem, *args, **kwargs): self._ad_problem = problem self._ad_args = args self._ad_kwargs = kwargs - self._ad_nlvs = None + self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None, + "recompute_count": 0} self._ad_adj_cache = {} return wrapper @@ -58,7 +59,7 @@ def wrapper(self, **kwargs): Firedrake solve call. This is useful in cases where the solve is known to be irrelevant or diagnostic for the purposes of the adjoint computation (such as projecting fields to other function spaces for the purposes of visualisation).""" - + from firedrake import LinearVariationalSolver annotate = annotate_tape(kwargs) if annotate: tape = get_working_tape() @@ -69,20 +70,31 @@ def wrapper(self, **kwargs): block = NonlinearVariationalSolveBlock(problem._ad_F == 0, problem._ad_u, problem._ad_bcs, - problem._ad_adj_F, adj_cache=self._ad_adj_cache, problem_J=problem._ad_J, solver_params=self.parameters, solver_kwargs=self._ad_kwargs, ad_block_tag=self.ad_block_tag, **sb_kwargs) - if not self._ad_nlvs: - self._ad_nlvs = type(self)( + + # Forward variational solver. + if not self._ad_solvers["forward_nlvs"]: + self._ad_solvers["forward_nlvs"] = type(self)( self._ad_problem_clone(self._ad_problem, block.get_dependencies()), **self._ad_kwargs ) - block._ad_nlvs = self._ad_nlvs + # Adjoint variational solver. + if not self._ad_solvers["adjoint_lvs"]: + with stop_annotating(): + self._ad_solvers["adjoint_lvs"] = LinearVariationalSolver( + self._ad_adj_lvs_problem(block, problem._ad_adj_F), + *block.adj_args, **block.adj_kwargs) + if self._ad_problem._constant_jacobian: + self._ad_solvers["update_adjoint"] = False + + block._ad_solvers = self._ad_solvers + tape.add_block(block) with stop_annotating(): @@ -103,22 +115,62 @@ def _ad_problem_clone(self, problem, dependencies): affect the user-defined self._ad_problem.F, self._ad_problem.J and self._ad_problem.u expressions, we'll instead create clones of them. """ - from firedrake import Function, NonlinearVariationalProblem + from firedrake import NonlinearVariationalProblem + _ad_count_map, J_replace_map, F_replace_map = self._build_count_map( + problem.J, dependencies, F=problem.F) + nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map), + F_replace_map[problem.u_restrict], + bcs=problem.bcs, + J=replace(problem.J, J_replace_map)) + nlvp.is_linear = problem.is_linear + nlvp._constant_jacobian = problem._constant_jacobian + nlvp._ad_count_map_update(_ad_count_map) + return nlvp + + @no_annotations + def _ad_adj_lvs_problem(self, block, adj_F): + """Create the adjoint variational problem.""" + from firedrake import Function, Cofunction, LinearVariationalProblem + # Homogeneous boundary conditions for the adjoint problem + # when Dirichlet boundary conditions are applied. + bcs = block._homogenize_bcs() + adj_sol = Function(block.function_space) + right_hand_side = Cofunction(block.function_space.dual()) + tmp_problem = LinearVariationalProblem( + adj_F, right_hand_side, adj_sol, bcs=bcs, + constant_jacobian=self._ad_problem._constant_jacobian) + # The `block.adj_F` coefficients hold the output references. + # We do not want to modify the user-defined values. Hence, the adjoint + # linear variational problem is created with a deep copy of the + # `block.adj_F` coefficients. + _ad_count_map, J_replace_map, _ = self._build_count_map( + adj_F, block._dependencies) + lvp = LinearVariationalProblem( + replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol, + bcs=tmp_problem.bcs, + constant_jacobian=self._ad_problem._constant_jacobian) + lvp._ad_count_map_update(_ad_count_map) + return lvp + + def _build_count_map(self, J, dependencies, F=None): + from firedrake import Function + F_replace_map = {} J_replace_map = {} - - F_coefficients = problem.F.coefficients() - J_coefficients = problem.J.coefficients() + if F: + F_coefficients = F.coefficients() + J_coefficients = J.coefficients() _ad_count_map = {} for block_variable in dependencies: coeff = block_variable.output - if coeff in F_coefficients and coeff not in F_replace_map: - if isinstance(coeff, Function) and coeff.ufl_element().family() == "Real": - F_replace_map[coeff] = copy.deepcopy(coeff) - else: - F_replace_map[coeff] = coeff.copy(deepcopy=True) - _ad_count_map[F_replace_map[coeff]] = coeff.count() + if F: + if coeff in F_coefficients and coeff not in F_replace_map: + if isinstance(coeff, Function) and coeff.ufl_element().family() == "Real": + F_replace_map[coeff] = copy.deepcopy(coeff) + else: + F_replace_map[coeff] = coeff.copy(deepcopy=True) + _ad_count_map[F_replace_map[coeff]] = coeff.count() if coeff in J_coefficients and coeff not in J_replace_map: if coeff in F_replace_map: @@ -128,11 +180,4 @@ def _ad_problem_clone(self, problem, dependencies): else: J_replace_map[coeff] = coeff.copy() _ad_count_map[J_replace_map[coeff]] = coeff.count() - - nlvp = NonlinearVariationalProblem(replace(problem.F, F_replace_map), - F_replace_map[problem.u_restrict], - bcs=problem.bcs, - J=replace(problem.J, J_replace_map)) - nlvp._constant_jacobian = problem._constant_jacobian - nlvp._ad_count_map_update(_ad_count_map) - return nlvp + return _ad_count_map, J_replace_map, F_replace_map diff --git a/firedrake/assemble.py b/firedrake/assemble.py index b72f99ba2c..181909161e 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1759,6 +1759,16 @@ def _as_global_kernel_arg_interior_facet(_, self): return op2.DatKernelArg((2,)) +@_as_global_kernel_arg.register(kernel_args.ExteriorFacetOrientationKernelArg) +def _as_global_kernel_arg_exterior_facet_orientation(_, self): + return op2.DatKernelArg((1,)) + + +@_as_global_kernel_arg.register(kernel_args.InteriorFacetOrientationKernelArg) +def _as_global_kernel_arg_interior_facet_orientation(_, self): + return op2.DatKernelArg((2,)) + + @_as_global_kernel_arg.register(CellFacetKernelArg) def _as_global_kernel_arg_cell_facet(_, self): if self._mesh.extruded: @@ -2053,6 +2063,16 @@ def _as_parloop_arg_interior_facet(_, self): return op2.DatParloopArg(self._mesh.interior_facets.local_facet_dat) +@_as_parloop_arg.register(kernel_args.ExteriorFacetOrientationKernelArg) +def _as_parloop_arg_exterior_facet_orientation(_, self): + return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_orientation_dat) + + +@_as_parloop_arg.register(kernel_args.InteriorFacetOrientationKernelArg) +def _as_parloop_arg_interior_facet_orientation(_, self): + return op2.DatParloopArg(self._mesh.interior_facets.local_facet_orientation_dat) + + @_as_parloop_arg.register(CellFacetKernelArg) def _as_parloop_arg_cell_facet(_, self): return op2.DatParloopArg(self._mesh.cell_to_facets) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index ae20ab0e2c..ee38b524e4 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -459,6 +459,9 @@ def extract_form(self, form_type): # DirichletBC is directly used in assembly. return self + def _as_nonlinear_variational_problem_arg(self): + return self + class EquationBC(object): r'''Construct and store EquationBCSplit objects (for `F`, `J`, and `Jp`). @@ -549,12 +552,15 @@ def extract_form(self, form_type): return getattr(self, f"_{form_type}") @PETSc.Log.EventDecorator() - def reconstruct(self, V, subu, u, field): + def reconstruct(self, V, subu, u, field, is_linear): _F = self._F.reconstruct(field=field, V=V, subu=subu, u=u) _J = self._J.reconstruct(field=field, V=V, subu=subu, u=u) _Jp = self._Jp.reconstruct(field=field, V=V, subu=subu, u=u) if all([_F is not None, _J is not None, _Jp is not None]): - return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=self.is_linear) + return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear) + + def _as_nonlinear_variational_problem_arg(self): + return self class EquationBCSplit(BCBase): @@ -645,6 +651,20 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col ebc.add(bc_temp) return ebc + def _as_nonlinear_variational_problem_arg(self): + # NonlinearVariationalProblem expects EquationBC, not EquationBCSplit. + # -- This method is required when NonlinearVariationalProblem is constructed inside PC. + if len(self.f.arguments()) != 2: + raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)") + J = self.f + Vcol = J.arguments()[-1].function_space() + u = firedrake.Function(Vcol) + F = ufl_expr.action(J, u) + Vrow = self._function_space + sub_domain = self.sub_domain + bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs) + return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow) + @PETSc.Log.EventDecorator() def homogenize(bc): diff --git a/firedrake/mesh.py b/firedrake/mesh.py index a844889274..05d7a4a77f 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -36,6 +36,7 @@ ) from firedrake.adjoint_utils import MeshGeometryMixin from pyadjoint import stop_annotating +import gem try: import netgen @@ -44,6 +45,7 @@ ngsPETSc = None # Only for docstring import mpi4py # noqa: F401 +from tsfc.finatinterface import as_fiat_cell __all__ = [ @@ -317,6 +319,44 @@ def facet_cell_map(self): return op2.Map(self.set, self.mesh.cell_set, self._rank, self.facet_cell, "facet_to_cell_map") + @utils.cached_property + def local_facet_orientation_dat(self): + """Dat for the local facet orientations.""" + dtype = gem.uint_type + # Make a map from cell to facet orientations. + fiat_cell = as_fiat_cell(self.mesh.ufl_cell()) + topo = fiat_cell.topology + num_entities = [0] + for d in range(len(topo)): + num_entities.append(len(topo[d])) + offsets = np.cumsum(num_entities) + local_facet_start = offsets[-3] + local_facet_end = offsets[-2] + map_from_cell_to_facet_orientations = self.mesh.entity_orientations[:, local_facet_start:local_facet_end] + # Make output data; + # this is a map from an exterior/interior facet to the corresponding local facet orientation/orientations. + # Halo data are required by design, but not actually used. + # -- Reshape as (-1, self._rank) to uniformly handle exterior and interior facets. + data = np.empty_like(self.local_facet_dat.data_ro_with_halos).reshape((-1, self._rank)) + data.fill(np.iinfo(dtype).max) + # Set local facet orientations on the block corresponding to the owned facets; i.e., data[:shape[0], :] below. + local_facets = self.local_facet_dat.data_ro # do not need halos. + # -- Reshape as (-1, self._rank) to uniformly handle exterior and interior facets. + local_facets = local_facets.reshape((-1, self._rank)) + shape = local_facets.shape + map_from_owned_facet_to_cells = self.facet_cell[:shape[0], :] + data[:shape[0], :] = np.take_along_axis( + map_from_cell_to_facet_orientations[map_from_owned_facet_to_cells], + local_facets.reshape(shape + (1, )), # reshape as required by take_along_axis. + axis=2, + ).reshape(shape) + return op2.Dat( + self.local_facet_dat.dataset, + data, + dtype, + f"{self.mesh.name}_{self.kind}_local_facet_orientation" + ) + @PETSc.Log.EventDecorator() def _from_gmsh(filename, comm=None): diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index 0eda241218..2c82092f51 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -143,6 +143,7 @@ def compile_element(expression, dual_space=None, parameters=None, config = dict(interface=builder, ufl_cell=cell, + integral_type="cell", point_indices=(), point_expr=point, argument_multiindices=argument_multiindices, @@ -537,6 +538,7 @@ def dg_injection_kernel(Vf, Vc, ncell): integration_dim, entity_ids = lower_integral_type(Vfe.cell, "cell") macro_cfg = dict(interface=macro_builder, ufl_cell=Vf.ufl_cell(), + integral_type="cell", integration_dim=integration_dim, entity_ids=entity_ids, index_cache=index_cache, @@ -573,6 +575,7 @@ def dg_injection_kernel(Vf, Vc, ncell): coarse_cfg = dict(interface=coarse_builder, ufl_cell=Vc.ufl_cell(), + integral_type="cell", integration_dim=integration_dim, entity_ids=entity_ids, index_cache=index_cache, diff --git a/firedrake/pointeval_utils.py b/firedrake/pointeval_utils.py index bf858f9c72..f8a985273d 100644 --- a/firedrake/pointeval_utils.py +++ b/firedrake/pointeval_utils.py @@ -71,6 +71,7 @@ def compile_element(expression, coordinates, parameters=None): config = dict(interface=builder, ufl_cell=extract_unique_domain(coordinates).ufl_cell(), + integral_type="cell", point_indices=(), point_expr=point, scalar_type=utils.ScalarType) diff --git a/firedrake/pointquery_utils.py b/firedrake/pointquery_utils.py index 03db28fa4b..82059def13 100644 --- a/firedrake/pointquery_utils.py +++ b/firedrake/pointquery_utils.py @@ -160,6 +160,7 @@ def to_reference_coords_newton_step(ufl_coordinate_element, parameters, x0_dtype context = tsfc.fem.GemPointContext( interface=builder, ufl_cell=cell, + integral_type="cell", point_indices=(), point_expr=point, scalar_type=parameters["scalar_type"] diff --git a/firedrake/preconditioners/base.py b/firedrake/preconditioners/base.py index e7b809024e..0bdfc97a37 100644 --- a/firedrake/preconditioners/base.py +++ b/firedrake/preconditioners/base.py @@ -95,11 +95,11 @@ def form(self, obj, *args): if P.getType() == "python": ctx = P.getPythonContext() a = ctx.a - bcs = tuple(ctx.row_bcs) + bcs = tuple(ctx.bcs) else: ctx = get_appctx(pc.getDM()) a = ctx.Jp or ctx.J - bcs = tuple(ctx._problem.bcs) + bcs = ctx.bcs_Jp if len(args): a = a(*args) return a, bcs @@ -121,6 +121,8 @@ def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None): old_appctx = get_appctx(dm).appctx u = Function(op.arguments()[-1].function_space()) F = action(op, u) + if bcs: + bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in bcs) nprob = NonlinearVariationalProblem(F, u, bcs=bcs, J=op, diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index a9dda71038..9e843016b5 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -369,7 +369,7 @@ def split(self, fields): if isinstance(bc, DirichletBC): bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain) elif isinstance(bc, EquationBC): - bc_temp = bc.reconstruct(field, V, subu, u) + bc_temp = bc.reconstruct(V, subu, u, field, False) if bc_temp is not None: bcs.append(bc_temp) new_problem = NLVP(F, subu, bcs=bcs, J=J, Jp=Jp, diff --git a/tests/firedrake/equation_bcs/test_equation_bcs.py b/tests/firedrake/equation_bcs/test_equation_bcs.py index 0037537ae4..087b07aa36 100644 --- a/tests/firedrake/equation_bcs/test_equation_bcs.py +++ b/tests/firedrake/equation_bcs/test_equation_bcs.py @@ -354,3 +354,46 @@ def test_EquationBC_mixedpoisson_matfree_fieldsplit(): err.append(nonlinear_poisson_mixed(solver_parameters, mesh_num, porder)) assert abs(math.log2(err[0][0]) - math.log2(err[1][0]) - (porder+1)) < 0.05 + + +def test_equation_bcs_pc(): + mesh = UnitSquareMesh(2**6, 2**6) + CG = FunctionSpace(mesh, "CG", 3) + R = FunctionSpace(mesh, "R", 0) + V = CG * R + f = Function(V) + u, l = split(f) + v, w = split(TestFunction(V)) + x, y = SpatialCoordinate(mesh) + exact = cos(2 * pi * x) * cos(2 * pi * y) + g = Function(CG).interpolate(8 * pi**2 * exact) + F = inner(grad(u), grad(v)) * dx + inner(l, w) * dx - inner(g, v) * dx + bc = EquationBC(inner((u - exact), v) * ds == 0, f, (1, 2, 3, 4), V=V.sub(0)) + params = { + "mat_type": "matfree", + "ksp_type": "fgmres", + "pc_type": "fieldsplit", + "pc_fieldsplit_type": "schur", + "pc_fieldsplit_schur_fact_type": "full", + "pc_fieldsplit_0_fields": "0", + "pc_fieldsplit_1_fields": "1", + "fieldsplit_0": { + "ksp_type": "preonly", + "pc_type": "python", + "pc_python_type": "firedrake.AssembledPC", + "assembled": { + "ksp_type": "cg", + "pc_type": "asm", + }, + }, + "fieldsplit_1": { + "ksp_type": "gmres", + "max_it": 1, + "convergence_test": "skip", + } + } + problem = NonlinearVariationalProblem(F, f, bcs=[bc]) + solver = NonlinearVariationalSolver(problem, solver_parameters=params) + solver.solve() + error = assemble(inner(u - exact, u - exact) * dx)**0.5 + assert error < 1.e-7 diff --git a/tests/firedrake/regression/test_adjoint_operators.py b/tests/firedrake/regression/test_adjoint_operators.py index f31a9b4acb..a8f5149657 100644 --- a/tests/firedrake/regression/test_adjoint_operators.py +++ b/tests/firedrake/regression/test_adjoint_operators.py @@ -840,7 +840,7 @@ def test_assign_cofunction(solve_type): solver.solve() J += assemble(((sol + Constant(1.0)) ** 2) * dx) rf = ReducedFunctional(J, Control(k)) - assert rf(k) == J + assert np.isclose(rf(k), J, rtol=1e-10) assert taylor_test(rf, k, Function(V).assign(0.1)) > 1.9 @@ -969,17 +969,15 @@ def test_lvs_constant_jacobian(constant_jacobian): solver.solve() J = assemble(v * v * dx) - assert "dFdu_adj" not in solver._ad_adj_cache + J_hat = ReducedFunctional(J, Control(u)) - dJ = compute_gradient(J, Control(u), options={"riesz_representation": "l2"}) - - cached_dFdu_adj = solver._ad_adj_cache.get("dFdu_adj", None) - assert (cached_dFdu_adj is None) == (not constant_jacobian) + dJ = J_hat.derivative(options={"riesz_representation": None}) assert np.allclose(dJ.dat.data_ro, 2 * assemble(inner(u_ref, test) * dx).dat.data_ro) - dJ = compute_gradient(J, Control(u), options={"riesz_representation": "l2"}) + u_ref = Function(space, name="u").interpolate(X[0] - 0.1) + J_hat(u_ref) - assert cached_dFdu_adj is solver._ad_adj_cache.get("dFdu_adj", None) + dJ = J_hat.derivative(options={"riesz_representation": None}) assert np.allclose(dJ.dat.data_ro, 2 * assemble(inner(u_ref, test) * dx).dat.data_ro) diff --git a/tests/firedrake/regression/test_integral_hex.py b/tests/firedrake/regression/test_integral_hex.py index 525dff9af8..4d9a68e538 100644 --- a/tests/firedrake/regression/test_integral_hex.py +++ b/tests/firedrake/regression/test_integral_hex.py @@ -18,3 +18,38 @@ def test_integral_hex_exterior_facet(mesh_from_file, family): x, y, z = SpatialCoordinate(mesh) f = Function(V).interpolate(2 * x + 3 * y * y + 4 * z * z * z) assert abs(assemble(f * ds) - (2 + 4 + 2 + 5 + 2 + 6)) < 1.e-10 + + +@pytest.mark.parallel(nprocs=2) +@pytest.mark.parametrize('mesh_from_file', [False, True]) +@pytest.mark.parametrize('family', ["Q", "DQ"]) +def test_integral_hex_interior_facet(mesh_from_file, family): + if mesh_from_file: + mesh = Mesh(join(cwd, "..", "meshes", "cube_hex.msh")) + else: + mesh = UnitCubeMesh(2, 3, 5, hexahedral=True) + V = FunctionSpace(mesh, family, 3) + x, y, z = SpatialCoordinate(mesh) + f = Function(V).interpolate(2 * x + 3 * y * y + 4 * z * z * z) + assert assemble((f('+') - f('-'))**2 * dS)**0.5 < 1.e-14 + + +@pytest.mark.parallel(nprocs=2) +@pytest.mark.parametrize('mesh_from_file', [False, True]) +def test_integral_hex_interior_facet_solve(mesh_from_file): + if mesh_from_file: + mesh = Mesh(join(cwd, "..", "meshes", "cube_hex.msh")) + else: + mesh = UnitCubeMesh(2, 3, 5, hexahedral=True) + V = FunctionSpace(mesh, "Q", 1) + x, y, z = SpatialCoordinate(mesh) + f = Function(V).interpolate(2 * x + 3 * y + 5 * z) + u = TrialFunction(V) + v = TestFunction(V) + a = inner(u('+'), v('+')) * dS(degree=3) + L = inner(f('+'), v('-')) * dS(degree=3) + bc = DirichletBC(V, f, "on_boundary") + sol = Function(V) + solve(a == L, sol, bcs=[bc]) + err = assemble((sol - f)**2 * dx)**0.5 + assert err < 1.e-14