Skip to content

Commit

Permalink
ImplicitMatrixContext: handle empty action
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 3, 2025
1 parent 7f40504 commit 3d06fc5
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
13 changes: 8 additions & 5 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ def __init__(self,
zero_bc_nodes=False,
diagonal=False,
weight=1.0,
allocation_integral_types=None):
allocation_integral_types=None,
needs_zeroing=False):
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
self._mat_type = mat_type
self._sub_mat_type = sub_mat_type
Expand All @@ -321,6 +322,7 @@ def __init__(self,
self._diagonal = diagonal
self._weight = weight
self._allocation_integral_types = allocation_integral_types
assert not needs_zeroing

def allocate(self):
rank = len(self._form.arguments())
Expand Down Expand Up @@ -1127,7 +1129,8 @@ def _apply_bc(self, tensor, bc):
pass

def _check_tensor(self, tensor):
pass
if not isinstance(tensor, op2.Global):
raise TypeError(f"Expecting a op2.Global, got {tensor!r}.")

@staticmethod
def _as_pyop2_type(tensor, indices=None):
Expand All @@ -1143,7 +1146,7 @@ class OneFormAssembler(ParloopFormAssembler):
Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
1-form.
Notes
Expand Down Expand Up @@ -1189,8 +1192,8 @@ def _apply_bc(self, tensor, bc):
self._apply_dirichlet_bc(tensor, bc)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
else:
raise AssertionError

Expand Down
28 changes: 17 additions & 11 deletions firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from firedrake.bcs import DirichletBC, EquationBCSplit
from firedrake.petsc import PETSc
from firedrake.utils import cached_property
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from ufl.form import ZeroBaseForm


__all__ = ("ImplicitMatrixContext", )
Expand Down Expand Up @@ -107,23 +110,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

# create functions from test and trial space to help
# with 1-form assembly
test_space, trial_space = [
a.arguments()[i].function_space() for i in (0, 1)
]
from firedrake import function, cofunction
test_space, trial_space = (
arg.function_space() for arg in a.arguments()
)
# Need a cofunction since y receives the assembled result of Ax
self._ystar = cofunction.Cofunction(test_space.dual())
self._y = function.Function(test_space)
self._x = function.Function(trial_space)
self._xstar = cofunction.Cofunction(trial_space.dual())
self._ystar = Cofunction(test_space.dual())
self._y = Function(test_space)
self._x = Function(trial_space)
self._xstar = Cofunction(trial_space.dual())

# These are temporary storage for holding the BC
# values during matvec application. _xbc is for
# the action and ._ybc is for transpose.
if len(self.bcs) > 0:
self._xbc = cofunction.Cofunction(trial_space.dual())
self._xbc = Cofunction(trial_space.dual())
if len(self.col_bcs) > 0:
self._ybc = cofunction.Cofunction(test_space.dual())
self._ybc = Cofunction(test_space.dual())

# Get size information from template vecs on test and trial spaces
trial_vec = trial_space.dof_dset.layout_vec
Expand All @@ -135,6 +137,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

self.action = action(self.a, self._x)
self.actionT = action(self.aT, self._y)
# TODO prevent action from returning empty Forms
if self.action.empty():
self.action = ZeroBaseForm(self.a.arguments()[:-1])
if self.actionT.empty():
self.actionT = ZeroBaseForm(self.aT.arguments()[:-1])

# For assembling action(f, self._x)
self.bcs_action = []
Expand Down Expand Up @@ -170,7 +177,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

@cached_property
def _diagonal(self):
from firedrake import Cofunction
assert self.on_diag
return Cofunction(self._x.function_space().dual())

Expand Down
19 changes: 11 additions & 8 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from ufl import Constant
from ufl.coefficient import BaseCoefficient

from firedrake.formmanipulation import ExtractSubBlock
from firedrake.function import Function, Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace
from firedrake.ufl_expr import Argument, TestFunction
from firedrake.utils import cached_property, unique

from itertools import chain, count
Expand All @@ -35,8 +38,6 @@
from ufl.form import Form, ZeroBaseForm
import hashlib

from firedrake.formmanipulation import ExtractSubBlock

from tsfc.ufl_utils import extract_firedrake_constants


Expand Down Expand Up @@ -293,6 +294,10 @@ def solve(self, B, decomposition=None):
"""
return Solve(self, B, decomposition=decomposition)

def empty(self):
"""Returns whether the form associated with the tensor is empty."""
return False

@cached_property
def blocks(self):
"""Returns an object containing the blocks of the tensor defined
Expand Down Expand Up @@ -461,8 +466,6 @@ def arg_function_spaces(self):
@cached_property
def _argument(self):
"""Generates a 'test function' associated with this class."""
from firedrake.ufl_expr import TestFunction

V, = self.arg_function_spaces
return TestFunction(V)

Expand Down Expand Up @@ -543,7 +546,6 @@ def arg_function_spaces(self):
@cached_property
def _argument(self):
"""Generates a tuple of 'test function' associated with this class."""
from firedrake.ufl_expr import TestFunction
return tuple(TestFunction(fs) for fs in self.arg_function_spaces)

def arguments(self):
Expand Down Expand Up @@ -668,9 +670,6 @@ def _split_arguments(self):
"""Splits the function space and stores the component
spaces determined by the indices.
"""
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace
from firedrake.ufl_expr import Argument

tensor, = self.operands
nargs = []
for i, arg in enumerate(tensor.arguments()):
Expand Down Expand Up @@ -938,6 +937,10 @@ def subdomain_data(self):
"""
return self.form.subdomain_data()

def empty(self):
"""Returns whether the form associated with the tensor is empty."""
return self.form.empty()

def _output_string(self, prec=None):
"""Creates a string representation of the tensor."""
return ["S", "V", "M"][self.rank] + "_%d" % self.id
Expand Down

0 comments on commit 3d06fc5

Please sign in to comment.