diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 82d4ca80db..75ed1f85af 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -29,6 +29,7 @@ from firedrake.utils import ScalarType, assert_empty, tuplify from pyop2 import op2 from pyop2.exceptions import MapValueError, SparsityFormatError +from pyop2.types.mat import _GlobalMatPayload, _DatMatPayload from pyop2.utils import cached_property @@ -965,22 +966,24 @@ def assemble(self, tensor=None): Result of assembly: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms. """ - self._check_tensor(tensor) - if tensor is None: - tensor = self.allocate() - needs_zeroing = False - else: - needs_zeroing = self._needs_zeroing if annotate_tape(): raise NotImplementedError( "Taping with explicit FormAssembler objects is not supported yet. " "Use assemble instead." ) - if needs_zeroing: - type(self)._as_pyop2_type(tensor).zero() + + if tensor is None: + tensor = self.allocate() + else: + self._check_tensor(tensor) + if self._needs_zeroing: + self._as_pyop2_type(tensor).zero() + self.execute_parloops(tensor) + for bc in self._bcs: self._apply_bc(tensor, bc) + return self.result(tensor) @abc.abstractmethod @@ -992,9 +995,9 @@ def _check_tensor(self, tensor): """Check input tensor.""" @staticmethod - def _as_pyop2_type(tensor): - """Return tensor as pyop2 type.""" - raise NotImplementedError + @abc.abstractmethod + def _as_pyop2_type(tensor, indices=None): + """Cast a Firedrake tensor into a PyOP2 data structure, optionally indexing it.""" def execute_parloops(self, tensor): for parloop in self.parloops(tensor): @@ -1003,20 +1006,14 @@ def execute_parloops(self, tensor): def parloops(self, tensor): if hasattr(self, "_parloops"): for (lknl, _), parloop in zip(self.local_kernels, self._parloops): - data = _FormHandler.index_tensor(tensor, self._form, lknl.indices, self.diagonal) + data = self._as_pyop2_type(tensor, lknl.indices) parloop.arguments[0].data = data + else: # Make parloops for one concrete output tensor and cache them. - # TODO: Make parloops only with some symbolic information of the output tensor. - self._parloops = tuple(parloop_builder.build(tensor) for parloop_builder in self.parloop_builders) - return self._parloops - - @cached_property - def parloop_builders(self): - out = [] - for local_kernel, subdomain_id in self.local_kernels: - out.append( - ParloopBuilder( + parloops_ = [] + for local_kernel, subdomain_id in self.local_kernels: + parloop_builder = ParloopBuilder( self._form, self._bcs, local_kernel, @@ -1024,8 +1021,12 @@ def parloop_builders(self): self.all_integer_subdomain_ids, diagonal=self.diagonal, ) - ) - return tuple(out) + pyop2_tensor = self._as_pyop2_type(tensor, local_kernel.indices) + parloop = parloop_builder.build(pyop2_tensor) + parloops_.append(parloop) + self._parloops = tuple(parloops_) + + return self._parloops @cached_property def local_kernels(self): @@ -1120,10 +1121,11 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - assert tensor is None + pass @staticmethod - def _as_pyop2_type(tensor): + def _as_pyop2_type(tensor, indices=None): + assert not indices return tensor def result(self, tensor): @@ -1198,15 +1200,16 @@ def _apply_dirichlet_bc(self, tensor, bc): bc.zero(tensor) def _check_tensor(self, tensor): - rank = len(self._form.arguments()) - if rank == 1: - test, = self._form.arguments() - if tensor is not None and test.function_space() != tensor.function_space(): - raise ValueError("Form's argument does not match provided result tensor") + if tensor.function_space() != self._form.arguments()[0].function_space(): + raise ValueError("Form's argument does not match provided result tensor") @staticmethod - def _as_pyop2_type(tensor): - return tensor.dat + def _as_pyop2_type(tensor, indices=None): + if indices is not None and any(index is not None for index in indices): + i, = indices + return tensor.dat[i] + else: + return tensor.dat def execute_parloops(self, tensor): # We are repeatedly incrementing into the same Dat so intermediate halo exchanges @@ -1454,12 +1457,26 @@ def _apply_bcs_mat_real_block(op2tensor, i, j, component, node_set): dat.zero(subset=node_set) def _check_tensor(self, tensor): - if tensor is not None and tensor.a.arguments() != self._form.arguments(): + if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") @staticmethod - def _as_pyop2_type(tensor): - return tensor.M + def _as_pyop2_type(tensor, indices=None): + if indices is not None and any(index is not None for index in indices): + i, j = indices + mat = tensor.M[i, j] + else: + mat = tensor.M + + if mat.handle.getType() == "python": + mat_context = mat.handle.getPythonContext() + if isinstance(mat_context, _GlobalMatPayload): + mat = mat_context.global_ + else: + assert isinstance(mat_context, _DatMatPayload) + mat = mat_context.dat + + return mat def result(self, tensor): tensor.M.assemble() @@ -1471,7 +1488,7 @@ class MatrixFreeAssembler(FormAssembler): Parameters ---------- - form : ufl.Form or slate.TensorBasehe + form : ufl.Form or slate.TensorBase 2-form. Notes @@ -1498,14 +1515,15 @@ def allocate(self): appctx=self._appctx or {}) def assemble(self, tensor=None): - self._check_tensor(tensor) if tensor is None: tensor = self.allocate() + else: + self._check_tensor(tensor) tensor.assemble() return tensor def _check_tensor(self, tensor): - if tensor is not None and tensor.a.arguments() != self._form.arguments(): + if tensor.a.arguments() != self._form.arguments(): raise ValueError("Form's arguments do not match provided result tensor") @@ -1820,12 +1838,12 @@ def __init__(self, form, bcs, local_knl, subdomain_id, self._active_coefficients = _FormHandler.iter_active_coefficients(form, local_knl.kinfo) self._constants = _FormHandler.iter_constants(form, local_knl.kinfo) - def build(self, tensor): + def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop: """Construct the parloop. Parameters ---------- - tensor : op2.Global or firedrake.cofunction.Cofunction or matrix.MatrixBase + tensor : The output tensor. """ @@ -1909,17 +1927,28 @@ def collect_lgmaps(self): :param local_knl: A :class:`tsfc_interface.SplitKernel`. :param bcs: Iterable of boundary conditions. """ + if len(self._form.arguments()) == 2 and not self._diagonal: if not self._bcs: return None - lgmaps = [] - for i, j in self.get_indicess(): + + if any(i is not None for i in self._local_knl.indices): + i, j = self._local_knl.indices row_bcs, col_bcs = self._filter_bcs(i, j) - rlgmap, clgmap = self._tensor.M[i, j].local_to_global_maps + # the tensor is already indexed + rlgmap, clgmap = self._tensor.local_to_global_maps rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap) clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap) - lgmaps.append((rlgmap, clgmap)) - return tuple(lgmaps) + return ((rlgmap, clgmap),) + else: + lgmaps = [] + for i, j in self.get_indicess(): + row_bcs, col_bcs = self._filter_bcs(i, j) + rlgmap, clgmap = self._tensor[i, j].local_to_global_maps + rlgmap = self.test_function_space[i].local_to_global_map(row_bcs, rlgmap) + clgmap = self.trial_function_space[j].local_to_global_map(col_bcs, clgmap) + lgmaps.append((rlgmap, clgmap)) + return tuple(lgmaps) else: return None @@ -1939,10 +1968,6 @@ def _integral_type(self): def _indexed_function_spaces(self): return _FormHandler.index_function_spaces(self._form, self._indices) - @property - def _indexed_tensor(self): - return _FormHandler.index_tensor(self._tensor, self._form, self._indices, self._diagonal) - @cached_property def _mesh(self): return tuple(self._form.ufl_domains())[self._kinfo.domain_number] @@ -1990,28 +2015,27 @@ def _as_parloop_arg(tsfc_arg, self): @_as_parloop_arg.register(kernel_args.OutputKernelArg) def _as_parloop_arg_output(_, self): rank = len(self._form.arguments()) - tensor = self._indexed_tensor Vs = self._indexed_function_spaces if rank == 0: - return op2.GlobalParloopArg(tensor) + return op2.GlobalParloopArg(self._tensor) elif rank == 1 or rank == 2 and self._diagonal: V, = Vs if V.ufl_element().family() == "Real": - return op2.GlobalParloopArg(tensor) + return op2.GlobalParloopArg(self._tensor) else: - return op2.DatParloopArg(tensor, self._get_map(V)) + return op2.DatParloopArg(self._tensor, self._get_map(V)) elif rank == 2: rmap, cmap = [self._get_map(V) for V in Vs] if all(V.ufl_element().family() == "Real" for V in Vs): assert rmap is None and cmap is None - return op2.GlobalParloopArg(tensor.handle.getPythonContext().global_) + return op2.GlobalParloopArg(self._tensor) elif any(V.ufl_element().family() == "Real" for V in Vs): m = rmap or cmap - return op2.DatParloopArg(tensor.handle.getPythonContext().dat, m) + return op2.DatParloopArg(self._tensor, m) else: - return op2.MatParloopArg(tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) + return op2.MatParloopArg(self._tensor, (rmap, cmap), lgmaps=self.collect_lgmaps()) else: raise AssertionError @@ -2122,22 +2146,3 @@ def index_function_spaces(form, indices): return tuple(a.ufl_function_space()[i] for i, a in zip(indices, form.arguments())) else: raise AssertionError - - @staticmethod - def index_tensor(tensor, form, indices, diagonal): - """Return the PyOP2 data structure tied to ``tensor``, indexed - if necessary. - """ - rank = len(form.arguments()) - is_indexed = any(i is not None for i in indices) - - if rank == 0: - return tensor - elif rank == 1 or rank == 2 and diagonal: - i, = indices - return tensor.dat[i] if is_indexed else tensor.dat - elif rank == 2: - i, j = indices - return tensor.M[i, j] if is_indexed else tensor.M - else: - raise AssertionError diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 074606c124..53f69e92ce 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -843,8 +843,7 @@ def local_to_global_map(self, bcs, lgmap=None): return PETSc.LGMap().create(indices, bsize=bsize, comm=lgmap.comm) def collapse(self): - from firedrake import FunctionSpace - return FunctionSpace(self.mesh(), self.ufl_element()) + return type(self)(self.mesh(), self.ufl_element()) class RestrictedFunctionSpace(FunctionSpace): @@ -1161,8 +1160,7 @@ def _ises(self): return self.dof_dset.field_ises def collapse(self): - from firedrake import MixedFunctionSpace - return MixedFunctionSpace([V_ for V_ in self]) + return type(self)([V_ for V_ in self], self.mesh()) class ProxyFunctionSpace(FunctionSpace): diff --git a/tests/conftest.py b/tests/conftest.py index 0874f5b01e..9401b565b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ """Global test configuration.""" import pytest -from firedrake.petsc import get_external_packages +from firedrake.petsc import PETSc, get_external_packages def pytest_configure(config): @@ -122,3 +122,34 @@ def fin(): assert len(tape.get_blocks()) == 0 request.addfinalizer(fin) + + +class _petsc_raises: + """Context manager for catching PETSc-raised exceptions. + + The usual `pytest.raises` exception handler is not suitable for errors + raised inside a callback to PETSc because the error is wrapped inside a + `PETSc.Error` object and so this context manager unpacks this to access + the actual internal error. + + Parameters + ---------- + exc_type : + The exception type that is expected to be raised inside a PETSc callback. + + """ + def __init__(self, exc_type): + self.exc_type = exc_type + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, traceback): + if exc_type is PETSc.Error and isinstance(exc_val.__cause__, self.exc_type): + return True + + +@pytest.fixture +def petsc_raises(): + # This function is needed because pytest does not support classes as fixtures. + return _petsc_raises diff --git a/tests/macro/test_macro_multigrid.py b/tests/macro/test_macro_multigrid.py index 0c91119383..c45b085555 100644 --- a/tests/macro/test_macro_multigrid.py +++ b/tests/macro/test_macro_multigrid.py @@ -140,7 +140,7 @@ def test_macro_grid_transfer(hierarchy, space, degrees, variant, transfer_type): @pytest.mark.parametrize("degree", (1,)) -def test_macro_multigrid_poisson(hierarchy, degree, variant): +def test_macro_multigrid_poisson(hierarchy, degree, variant, petsc_raises): mesh = hierarchy[-1] V = FunctionSpace(mesh, "CG", degree, variant=variant) u = TrialFunction(V) @@ -153,7 +153,7 @@ def test_macro_multigrid_poisson(hierarchy, degree, variant): problem = LinearVariationalProblem(a, L, uh, bcs=bcs) solver = LinearVariationalSolver(problem, solver_parameters=mg_params) if complex_mode and variant == "alfeld": - with pytest.raises(NotImplementedError): + with petsc_raises(NotImplementedError): solver.solve() else: solver.solve() @@ -172,7 +172,7 @@ def square_hierarchy(): @pytest.mark.parametrize("family", ("HCT-red", "HCT")) -def test_macro_multigrid_biharmonic(square_hierarchy, family): +def test_macro_multigrid_biharmonic(square_hierarchy, family, petsc_raises): mesh = square_hierarchy[-1] V = FunctionSpace(mesh, family, 3) u = TrialFunction(V) @@ -185,7 +185,7 @@ def test_macro_multigrid_biharmonic(square_hierarchy, family): problem = LinearVariationalProblem(a, L, uh, bcs=bcs) solver = LinearVariationalSolver(problem, solver_parameters=mg_params) if complex_mode: - with pytest.raises(NotImplementedError): + with petsc_raises(NotImplementedError): solver.solve() else: solver.solve() diff --git a/tests/regression/test_assemble.py b/tests/regression/test_assemble.py index a80b46d5f0..9ee0e1d9e7 100644 --- a/tests/regression/test_assemble.py +++ b/tests/regression/test_assemble.py @@ -1,6 +1,7 @@ import pytest import numpy as np from firedrake import * +from firedrake.assemble import TwoFormAssembler from firedrake.utils import ScalarType, IntType @@ -125,6 +126,23 @@ def test_assemble_mat_with_tensor(mesh): assert np.allclose(M.M.values, 2*assemble(a).M.values, rtol=1e-14) +@pytest.mark.skipcomplex +def test_mat_nest_real_block_assembler_correctly_reuses_tensor(mesh): + V = FunctionSpace(mesh, "CG", 1) + R = FunctionSpace(mesh, "R", 0) + W = V * R + + u = TrialFunction(W) + v = TestFunction(W) + a = inner(v, u) * dx + + assembler = TwoFormAssembler(a, mat_type="nest") + A1 = assembler.assemble() + A2 = assembler.assemble(tensor=A1) + + assert A2.M is A1.M + + def test_assemble_diagonal(mesh): V = FunctionSpace(mesh, "P", 3) u = TrialFunction(V) diff --git a/tests/slate/test_slate_hybridization.py b/tests/slate/test_slate_hybridization.py index 267c904867..ab7d4d4415 100644 --- a/tests/slate/test_slate_hybridization.py +++ b/tests/slate/test_slate_hybridization.py @@ -130,7 +130,7 @@ def test_slate_hybridization(degree, hdiv_family, quadrilateral): assert u_err < 1e-11 -def test_slate_hybridization_wrong_option(setup_poisson): +def test_slate_hybridization_wrong_option(setup_poisson, petsc_raises): a, L, W = setup_poisson w = Function(W) @@ -145,18 +145,9 @@ def test_slate_hybridization_wrong_option(setup_poisson): 'pc_fieldsplit_type': 'frog'}}} problem = LinearVariationalProblem(a, L, w) solver = LinearVariationalSolver(problem, solver_parameters=params) - with pytest.raises(ValueError): - # HybridizationPC isn't called directly from the Python interpreter, - # it's a callback that PETSc calls. This means that the call stack from pytest - # down to HybridizationPC goes via PETSc C code, which interferes with the exception - # before it is observed outside. Hence removing PETSc's error handler - # makes the problem go away, because PETSc stops interfering. - # We need to repush the error handler because popErrorHandler globally changes - # the system state for all future tests. - from firedrake.petsc import PETSc - PETSc.Sys.pushErrorHandler("ignore") + + with petsc_raises(ValueError): solver.solve() - PETSc.Sys.popErrorHandler("ignore") def test_slate_hybridization_nested_schur(setup_poisson):