From 3d06fc56e58302c9ff179eb9890986fe9e4f50ee Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 11:26:48 -0600 Subject: [PATCH] ImplicitMatrixContext: handle empty action --- firedrake/assemble.py | 13 ++++++++----- firedrake/matrix_free/operators.py | 28 +++++++++++++++++----------- firedrake/slate/slate.py | 19 +++++++++++-------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 60c934b6c7..88d00c6db8 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -311,7 +311,8 @@ def __init__(self, zero_bc_nodes=False, diagonal=False, weight=1.0, - allocation_integral_types=None): + allocation_integral_types=None, + needs_zeroing=False): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) self._mat_type = mat_type self._sub_mat_type = sub_mat_type @@ -321,6 +322,7 @@ def __init__(self, self._diagonal = diagonal self._weight = weight self._allocation_integral_types = allocation_integral_types + assert not needs_zeroing def allocate(self): rank = len(self._form.arguments()) @@ -1127,7 +1129,8 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - pass + if not isinstance(tensor, op2.Global): + raise TypeError(f"Expecting a op2.Global, got {tensor!r}.") @staticmethod def _as_pyop2_type(tensor, indices=None): @@ -1143,7 +1146,7 @@ class OneFormAssembler(ParloopFormAssembler): Parameters ---------- - form : ufl.Form or slate.TensorBasehe + form : ufl.Form or slate.TensorBase 1-form. Notes @@ -1189,8 +1192,8 @@ def _apply_bc(self, tensor, bc): self._apply_dirichlet_bc(tensor, bc) elif isinstance(bc, EquationBCSplit): bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index 3ee448730e..b111cbde76 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -10,6 +10,9 @@ from firedrake.bcs import DirichletBC, EquationBCSplit from firedrake.petsc import PETSc from firedrake.utils import cached_property +from firedrake.function import Function +from firedrake.cofunction import Cofunction +from ufl.form import ZeroBaseForm __all__ = ("ImplicitMatrixContext", ) @@ -107,23 +110,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[], # create functions from test and trial space to help # with 1-form assembly - test_space, trial_space = [ - a.arguments()[i].function_space() for i in (0, 1) - ] - from firedrake import function, cofunction + test_space, trial_space = ( + arg.function_space() for arg in a.arguments() + ) # Need a cofunction since y receives the assembled result of Ax - self._ystar = cofunction.Cofunction(test_space.dual()) - self._y = function.Function(test_space) - self._x = function.Function(trial_space) - self._xstar = cofunction.Cofunction(trial_space.dual()) + self._ystar = Cofunction(test_space.dual()) + self._y = Function(test_space) + self._x = Function(trial_space) + self._xstar = Cofunction(trial_space.dual()) # These are temporary storage for holding the BC # values during matvec application. _xbc is for # the action and ._ybc is for transpose. if len(self.bcs) > 0: - self._xbc = cofunction.Cofunction(trial_space.dual()) + self._xbc = Cofunction(trial_space.dual()) if len(self.col_bcs) > 0: - self._ybc = cofunction.Cofunction(test_space.dual()) + self._ybc = Cofunction(test_space.dual()) # Get size information from template vecs on test and trial spaces trial_vec = trial_space.dof_dset.layout_vec @@ -135,6 +137,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[], self.action = action(self.a, self._x) self.actionT = action(self.aT, self._y) + # TODO prevent action from returning empty Forms + if self.action.empty(): + self.action = ZeroBaseForm(self.a.arguments()[:-1]) + if self.actionT.empty(): + self.actionT = ZeroBaseForm(self.aT.arguments()[:-1]) # For assembling action(f, self._x) self.bcs_action = [] @@ -170,7 +177,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[], @cached_property def _diagonal(self): - from firedrake import Cofunction assert self.on_diag return Cofunction(self._x.function_space().dual()) diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index fd9535c31a..1a8c792414 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -21,7 +21,10 @@ from ufl import Constant from ufl.coefficient import BaseCoefficient +from firedrake.formmanipulation import ExtractSubBlock from firedrake.function import Function, Cofunction +from firedrake.functionspace import FunctionSpace, MixedFunctionSpace +from firedrake.ufl_expr import Argument, TestFunction from firedrake.utils import cached_property, unique from itertools import chain, count @@ -35,8 +38,6 @@ from ufl.form import Form, ZeroBaseForm import hashlib -from firedrake.formmanipulation import ExtractSubBlock - from tsfc.ufl_utils import extract_firedrake_constants @@ -293,6 +294,10 @@ def solve(self, B, decomposition=None): """ return Solve(self, B, decomposition=decomposition) + def empty(self): + """Returns whether the form associated with the tensor is empty.""" + return False + @cached_property def blocks(self): """Returns an object containing the blocks of the tensor defined @@ -461,8 +466,6 @@ def arg_function_spaces(self): @cached_property def _argument(self): """Generates a 'test function' associated with this class.""" - from firedrake.ufl_expr import TestFunction - V, = self.arg_function_spaces return TestFunction(V) @@ -543,7 +546,6 @@ def arg_function_spaces(self): @cached_property def _argument(self): """Generates a tuple of 'test function' associated with this class.""" - from firedrake.ufl_expr import TestFunction return tuple(TestFunction(fs) for fs in self.arg_function_spaces) def arguments(self): @@ -668,9 +670,6 @@ def _split_arguments(self): """Splits the function space and stores the component spaces determined by the indices. """ - from firedrake.functionspace import FunctionSpace, MixedFunctionSpace - from firedrake.ufl_expr import Argument - tensor, = self.operands nargs = [] for i, arg in enumerate(tensor.arguments()): @@ -938,6 +937,10 @@ def subdomain_data(self): """ return self.form.subdomain_data() + def empty(self): + """Returns whether the form associated with the tensor is empty.""" + return self.form.empty() + def _output_string(self, prec=None): """Creates a string representation of the tensor.""" return ["S", "V", "M"][self.rank] + "_%d" % self.id