From 353ca0f94e5fe58b334075b5f24e5cf54dbea4aa Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Mon, 26 Feb 2024 22:11:08 +0000 Subject: [PATCH] assemble: introduce BaseFormAssembler assemble_base_form -> BaseFormAssembler().assemble --- firedrake/adjoint_utils/assembly.py | 4 +- firedrake/assemble.py | 996 ++++++++++----------- tests/regression/test_assemble_baseform.py | 4 +- 3 files changed, 498 insertions(+), 506 deletions(-) diff --git a/firedrake/adjoint_utils/assembly.py b/firedrake/adjoint_utils/assembly.py index f7bdb287df..cac94728ae 100644 --- a/firedrake/adjoint_utils/assembly.py +++ b/firedrake/adjoint_utils/assembly.py @@ -16,7 +16,7 @@ def wrapper(form, *args, **kwargs): ad_block_tag = kwargs.pop("ad_block_tag", None) annotate = annotate_tape(kwargs) with stop_annotating(): - from firedrake.assemble import preprocess_base_form + from firedrake.assemble import BaseFormAssembler from firedrake.slate import slate if not isinstance(form, slate.TensorBase): # Preprocess the form at the annotation stage so that the `AssembleBlock` @@ -25,7 +25,7 @@ def wrapper(form, *args, **kwargs): # -> `interp = Action(Interpolate(v1, v0), f)` with `v1` and `v0` being respectively `Argument` # and `Coargument`. Differentiating `interp` is not currently supported as the action's left slot # is a 2-form. However, after preprocessing, we obtain `Interpolate(f, v0)`, which can be differentiated. - form = preprocess_base_form(form) + form = BaseFormAssembler.preprocess_base_form(form) kwargs['is_base_form_preprocessed'] = True output = assemble(form, *args, **kwargs) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 9b12bbc83c..607661831f 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -114,8 +114,8 @@ def get_form_assembler(form, tensor, *args, **kwargs): mat_type = kwargs.get('mat_type', None) # Preprocess the DAG and restructure the DAG # Only pre-process `form` once beforehand to avoid pre-processing for each assembly call - form = preprocess_base_form(form, mat_type=mat_type, form_compiler_parameters=fc_params) - if isinstance(form, (ufl.form.Form, slate.TensorBase)) and not base_form_operands(form): + form = BaseFormAssembler.preprocess_base_form(form, mat_type=mat_type, form_compiler_parameters=fc_params) + if isinstance(form, (ufl.form.Form, slate.TensorBase)) and not BaseFormAssembler.base_form_operands(form): diagonal = kwargs.pop('diagonal', False) if len(form.arguments()) == 0: return functools.partial(ZeroFormAssembler(form, form_compiler_parameters=fc_params).assemble, tensor=tensor) @@ -130,7 +130,7 @@ def get_form_assembler(form, tensor, *args, **kwargs): # BaseForm preprocessing can turn BaseForm into an Expr (cf. case (6) in `restructure_base_form`) return functools.partial(_assemble_expr, form) elif isinstance(form, ufl.form.BaseForm): - return functools.partial(assemble_base_form, form, *args, tensor=tensor, **kwargs) + return functools.partial(BaseFormAssembler(form, *args, **kwargs).assemble, tensor=tensor) else: raise ValueError(f'Expecting a BaseForm, slate.TensorBase, or Expr object: got {form}') @@ -177,561 +177,553 @@ def assemble(self, tensor=None): """ -def base_form_postorder_traversal(expr, visitor, visited={}): - if expr in visited: - return visited[expr] +class BaseFormAssembler(AbstractFormAssembler): + """Base form assembler. - stack = [expr] - while stack: - e = stack.pop() - unvisited_children = [] - operands = base_form_operands(e) - for arg in operands: - if arg not in visited: - unvisited_children.append(arg) - - if unvisited_children: - stack.append(e) - stack.extend(unvisited_children) - else: - visited[e] = visitor(e, *(visited[arg] for arg in operands)) + Parameters + ---------- + form : ufl.form.BaseForm + `ufl.form.BaseForm` to assemble. - return visited[expr] + Notes + ----- + See `AbstractFormAssembler` and `assemble` for descriptions of the other parameters. + """ -def base_form_preorder_traversal(expr, visitor, visited={}): - if expr in visited: - return visited[expr] + def __init__(self, + form, + bcs=None, + form_compiler_parameters=None, + mat_type=None, + sub_mat_type=None, + options_prefix=None, + appctx=None, + zero_bc_nodes=False, + diagonal=False, + weight=1.0, + allocation_integral_types=None): + super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) + self._mat_type = mat_type + self._sub_mat_type = sub_mat_type + self._options_prefix = options_prefix + self._appctx = appctx + self._zero_bc_nodes = zero_bc_nodes + self._diagonal = diagonal + self._weight = weight + self._allocation_integral_types = allocation_integral_types - stack = [expr] - while stack: - e = stack.pop() - unvisited_children = [] - operands = base_form_operands(e) - for arg in operands: - if arg not in visited: - unvisited_children.append(arg) + def allocate(self): + pass - if unvisited_children: - stack.extend(unvisited_children) + def assemble(self, tensor=None): + """Assemble the form. - visited[e] = visitor(e) + Parameters + ---------- + tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase + Output tensor to contain the result of assembly. - return visited[expr] + Returns + ------- + float or firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase + Result of assembly: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms. + Notes + ----- + This function assembles a `ufl.form.BaseForm` object by traversing the corresponding DAG + in a post-order fashion and evaluating the nodes on the fly. -def reconstruct_node_from_operands(expr, operands): - if isinstance(expr, (ufl.Adjoint, ufl.Action)): - return expr._ufl_expr_reconstruct_(*operands) - elif isinstance(expr, ufl.FormSum): - return ufl.FormSum(*[(op, w) for op, w in zip(operands, expr.weights())]) - return expr + """ + def visitor(e, *operands): + t = tensor if e is self._form else None + return self.base_form_assembly_visitor(e, t, *operands) + # DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly. + visited = {} + result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited) -def base_form_operands(expr): - if isinstance(expr, (ufl.FormSum, ufl.Adjoint, ufl.Action)): - return expr.ufl_operands - if isinstance(expr, ufl.Form): - # Use reversed to treat base form operators - # in the order in which they have been made. - return list(reversed(expr.base_form_operators())) - if isinstance(expr, ufl.core.base_form_operator.BaseFormOperator): - # Conserve order - children = dict.fromkeys(e for e in (expr.argument_slots() + expr.ufl_operands) - if isinstance(e, ufl.form.BaseForm)) - return list(children) - return [] + if tensor: + BaseFormAssembler.update_tensor(result, tensor) + return tensor + else: + return result + def base_form_assembly_visitor(self, expr, tensor, *args): + r"""Assemble a :class:`~ufl.classes.BaseForm` object given its assembled operands. -def restructure_base_form_postorder(expression, visited=None): - visited = visited or {} + This functions contains the assembly handlers corresponding to the different nodes that + can arise in a `~ufl.classes.BaseForm` object. It is called by :func:`assemble_base_form` + in a post-order fashion. + """ + if isinstance(expr, (ufl.form.Form, slate.TensorBase)): + if args and self._mat_type != "matfree": + # Retrieve the Form's children + base_form_operators = BaseFormAssembler.base_form_operands(expr) + # Substitute the base form operators by their output + expr = ufl.replace(expr, dict(zip(base_form_operators, args))) + form = expr + rank = len(form.arguments()) + if rank == 0: + assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params) + elif rank == 1 or (rank == 2 and self._diagonal): + assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal) + elif rank == 2: + assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, + mat_type=self._mat_type, sub_mat_type=self._sub_mat_type, + options_prefix=self._options_prefix, appctx=self._appctx, weight=self._weight) + else: + raise AssertionError + return assembler.assemble(tensor=tensor) + elif isinstance(expr, ufl.Adjoint): + if len(args) != 1: + raise TypeError("Not enough operands for Adjoint") + mat, = args + res = tensor.petscmat if tensor else PETSc.Mat() + petsc_mat = mat.petscmat + # Out-of-place Hermitian transpose + petsc_mat.hermitianTranspose(out=res) + (row, col) = mat.arguments() + return matrix.AssembledMatrix((col, row), self._bcs, res, + appctx=self._appctx, + options_prefix=self._options_prefix) + elif isinstance(expr, ufl.Action): + if (len(args) != 2): + raise TypeError("Not enough operands for Action") + lhs, rhs = args + if isinstance(lhs, matrix.MatrixBase): + if isinstance(rhs, (firedrake.Cofunction, firedrake.Function)): + petsc_mat = lhs.petscmat + (row, col) = lhs.arguments() + # The matrix-vector product lives in the dual of the test space. + res = firedrake.Function(row.function_space().dual()) + with rhs.dat.vec_ro as v_vec: + with res.dat.vec as res_vec: + petsc_mat.mult(v_vec, res_vec) + return res + elif isinstance(rhs, matrix.MatrixBase): + petsc_mat = lhs.petscmat + (row, col) = lhs.arguments() + res = petsc_mat.matMult(rhs.petscmat) + return matrix.AssembledMatrix(expr, self._bcs, res, + appctx=self._appctx, + options_prefix=self._options_prefix) + else: + raise TypeError("Incompatible RHS for Action.") + elif isinstance(lhs, (firedrake.Cofunction, firedrake.Function)): + if isinstance(rhs, (firedrake.Cofunction, firedrake.Function)): + # Return scalar value + with lhs.dat.vec_ro as x, rhs.dat.vec_ro as y: + res = x.dot(y) + return res + else: + raise TypeError("Incompatible RHS for Action.") + else: + raise TypeError("Incompatible LHS for Action.") + elif isinstance(expr, ufl.FormSum): + if len(args) != len(expr.weights()): + raise TypeError("Mismatching weights and operands in FormSum") + if len(args) == 0: + raise TypeError("Empty FormSum") + if all(isinstance(op, float) for op in args): + return sum(args) + elif all(isinstance(op, firedrake.Cofunction) for op in args): + V, = set(a.function_space() for a in args) + res = sum([w*op.dat for (op, w) in zip(args, expr.weights())]) + return firedrake.Cofunction(V, res) + elif all(isinstance(op, ufl.Matrix) for op in args): + res = tensor.petscmat if tensor else PETSc.Mat() + is_set = False + for (op, w) in zip(args, expr.weights()): + # Make a copy to avoid in-place scaling + petsc_mat = op.petscmat.copy() + petsc_mat.scale(w) + if is_set: + # Modify output tensor in-place + res += petsc_mat + else: + # Copy to output tensor + petsc_mat.copy(result=res) + is_set = True + return matrix.AssembledMatrix(expr, self._bcs, res, + appctx=self._appctx, + options_prefix=self._options_prefix) + else: + raise TypeError("Mismatching FormSum shapes") + elif isinstance(expr, ufl.ExternalOperator): + opts = {'form_compiler_parameters': self._form_compiler_params, + 'mat_type': self._mat_type, 'sub_mat_type': self._sub_mat_type, + 'appctx': self._appctx, 'options_prefix': self._options_prefix, + 'diagonal': self._diagonal} + # External operators might not have any children that needs to be assembled + # -> e.g. N(u; v0, w) with v0 a ufl.Argument and w a ufl.Coefficient + if args: + # Replace base forms in the operands and argument slots of the external operator by their result + v, *assembled_children = args + if assembled_children: + _, *children = BaseFormAssembler.base_form_operands(expr) + # Replace assembled children by their results + expr = ufl.replace(expr, dict(zip(children, assembled_children))) + # Always reconstruct the dual argument (0-slot argument) since it is a BaseForm + # It is also convenient when we have a Form in that slot since Forms don't play well with `ufl.replace` + expr = expr._ufl_expr_reconstruct_(*expr.ufl_operands, argument_slots=(v,) + expr.argument_slots()[1:]) + # Call the external operator assembly + return expr.assemble(assembly_opts=opts) + elif isinstance(expr, ufl.Interpolate): + # Replace assembled children + _, expression = expr.argument_slots() + v, *assembled_expression = args + if assembled_expression: + # Occur in situations such as Interpolate composition + expression = assembled_expression[0] + expr = expr._ufl_expr_reconstruct_(expression, v) + + # Different assembly procedures: + # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix) + # 2) Interpolate(Coefficient(...), Argument(V2.dual(), 0)) -> Operator (or Jacobian action) + # 3) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Jacobian adjoint + # 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint + # This can be generalized to the case where the first slot is an arbitray expression. + rank = len(expr.arguments()) + # If argument numbers have been swapped => Adjoint. + arg_expression = ufl.algorithms.extract_arguments(expression) + is_adjoint = (arg_expression and arg_expression[0].number() == 0) + # Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument. + if not is_adjoint and rank != 1: + _, v1 = expr.arguments() + expression = ufl.replace(expression, {v1: firedrake.Argument(v1.function_space(), number=0, part=v1.part())}) + # Get the interpolator + interp_data = expr.interp_data + default_missing_val = interp_data.pop('default_missing_val', None) + interpolator = firedrake.Interpolator(expression, expr.function_space(), **interp_data) + # Assembly + if rank == 1: + # Assembling the action of the Jacobian adjoint. + if is_adjoint: + output = tensor or firedrake.Cofunction(arg_expression[0].function_space().dual()) + return interpolator._interpolate(v, output=output, transpose=True, default_missing_val=default_missing_val) + # Assembling the Jacobian action. + if interpolator.nargs: + return interpolator._interpolate(expression, output=tensor, default_missing_val=default_missing_val) + # Assembling the operator + if tensor is None: + return interpolator._interpolate(default_missing_val=default_missing_val) + return firedrake.Interpolator(expression, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val) + elif rank == 2: + res = tensor.petscmat if tensor else PETSc.Mat() + # Get the interpolation matrix + op2_mat = interpolator.callable() + petsc_mat = op2_mat.handle + if is_adjoint: + # Out-of-place Hermitian transpose + petsc_mat.hermitianTranspose(out=res) + else: + # Copy the interpolation matrix into the output tensor + petsc_mat.copy(result=res) + return matrix.AssembledMatrix(expr.arguments(), self._bcs, res, + appctx=self._appctx, + options_prefix=self._options_prefix) + else: + # The case rank == 0 is handled via the DAG restructuring + raise ValueError("Incompatible number of arguments.") + elif isinstance(expr, (ufl.Cofunction, ufl.Coargument, ufl.Argument, ufl.Matrix, ufl.ZeroBaseForm)): + return expr + elif isinstance(expr, ufl.Coefficient): + return expr + else: + raise TypeError(f"Unrecognised BaseForm instance: {expr}") - def visitor(expr, *operands): - # Need to reconstruct the expression with its visited operands! - expr = reconstruct_node_from_operands(expr, operands) - # Perform the DAG restructuring when needed - return restructure_base_form(expr, visited) + @staticmethod + def update_tensor(assembled_base_form, tensor): + if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): + assembled_base_form.dat.copy(tensor.dat) + elif isinstance(tensor, matrix.MatrixBase): + # Uses the PETSc copy method. + assembled_base_form.petscmat.copy(tensor.petscmat) + else: + raise NotImplementedError("Cannot update tensor of type %s" % type(tensor)) - return base_form_postorder_traversal(expression, visitor, visited) + @staticmethod + def base_form_postorder_traversal(expr, visitor, visited={}): + if expr in visited: + return visited[expr] + + stack = [expr] + while stack: + e = stack.pop() + unvisited_children = [] + operands = BaseFormAssembler.base_form_operands(e) + for arg in operands: + if arg not in visited: + unvisited_children.append(arg) + + if unvisited_children: + stack.append(e) + stack.extend(unvisited_children) + else: + visited[e] = visitor(e, *(visited[arg] for arg in operands)) + return visited[expr] -def restructure_base_form_preorder(expression, visited=None): - visited = visited or {} + @staticmethod + def base_form_preorder_traversal(expr, visitor, visited={}): + if expr in visited: + return visited[expr] - def visitor(expr): - # Perform the DAG restructuring when needed - return restructure_base_form(expr, visited) + stack = [expr] + while stack: + e = stack.pop() + unvisited_children = [] + operands = BaseFormAssembler.base_form_operands(e) + for arg in operands: + if arg not in visited: + unvisited_children.append(arg) - expression = base_form_preorder_traversal(expression, visitor, visited) - # Need to reconstruct the expression at the end when all its operands have been visited! - operands = [visited.get(args, args) for args in base_form_operands(expression)] - return reconstruct_node_from_operands(expression, operands) + if unvisited_children: + stack.extend(unvisited_children) + visited[e] = visitor(e) -def restructure_base_form(expr, visited=None): - r"""Perform a preorder traversal to simplify and optimize the DAG. - Example: Let's consider F(u, N(u; v*); v) with N(u; v*) a base form operator. + return visited[expr] - We have: dFdu = \frac{\partial F}{\partial u} + Action(dFdN, dNdu) - Now taking the action on a rank-1 object w (e.g. Coefficient/Cofunction) results in: + @staticmethod + def reconstruct_node_from_operands(expr, operands): + if isinstance(expr, (ufl.Adjoint, ufl.Action)): + return expr._ufl_expr_reconstruct_(*operands) + elif isinstance(expr, ufl.FormSum): + return ufl.FormSum(*[(op, w) for op, w in zip(operands, expr.weights())]) + return expr - (1) Action(Action(dFdN, dNdu), w) + @staticmethod + def base_form_operands(expr): + if isinstance(expr, (ufl.FormSum, ufl.Adjoint, ufl.Action)): + return expr.ufl_operands + if isinstance(expr, ufl.Form): + # Use reversed to treat base form operators + # in the order in which they have been made. + return list(reversed(expr.base_form_operators())) + if isinstance(expr, ufl.core.base_form_operator.BaseFormOperator): + # Conserve order + children = dict.fromkeys(e for e in (expr.argument_slots() + expr.ufl_operands) + if isinstance(e, ufl.form.BaseForm)) + return list(children) + return [] - Action Action - / \ / \ - Action w -----> dFdN Action - / \ / \ - dFdN dNdu dNdu w + @staticmethod + def restructure_base_form_postorder(expression, visited=None): + visited = visited or {} - This situations does not only arise for BaseFormOperator but also when we have a 2-form instead of dNdu! + def visitor(expr, *operands): + # Need to reconstruct the expression with its visited operands! + expr = BaseFormAssembler.reconstruct_node_from_operands(expr, operands) + # Perform the DAG restructuring when needed + return BaseFormAssembler.restructure_base_form(expr, visited) - (2) Action(dNdu, w) + return BaseFormAssembler.base_form_postorder_traversal(expression, visitor, visited) - Action - / \ - / w -----> dNdu(u; w, v*) - / - dNdu(u; uhat, v*) + @staticmethod + def restructure_base_form_preorder(expression, visited=None): + visited = visited or {} - (3) Action(F, N) + def visitor(expr): + # Perform the DAG restructuring when needed + return BaseFormAssembler.restructure_base_form(expr, visited) - Action F - / \ -----> F(..., N)[v] = | - F[v] N N + expression = BaseFormAssembler.base_form_preorder_traversal(expression, visitor, visited) + # Need to reconstruct the expression at the end when all its operands have been visited! + operands = [visited.get(args, args) for args in BaseFormAssembler.base_form_operands(expression)] + return BaseFormAssembler.reconstruct_node_from_operands(expression, operands) - (4) Adjoint(dNdu) + @staticmethod + def restructure_base_form(expr, visited=None): + r"""Perform a preorder traversal to simplify and optimize the DAG. + Example: Let's consider F(u, N(u; v*); v) with N(u; v*) a base form operator. - Adjoint - | -----> dNdu(u; v*, uhat) - dNdu(u; uhat, v*) + We have: dFdu = \frac{\partial F}{\partial u} + Action(dFdN, dNdu) + Now taking the action on a rank-1 object w (e.g. Coefficient/Cofunction) results in: - (5) N(u; w) (scalar valued) + (1) Action(Action(dFdN, dNdu), w) - Action - N(u; w) ----> / \ = Action(N, w) - N(u; v*) w + Action Action + / \ / \ + Action w -----> dFdN Action + / \ / \ + dFdN dNdu dNdu w - So from Action(Action(dFdN, dNdu(u; v*)), w) we get: + This situations does not only arise for BaseFormOperator but also when we have a 2-form instead of dNdu! - Action Action Action - / \ (1) / \ (2) / \ (4) dFdN - Action w ----> dFdN Action ----> dFdN dNdu(u; w, v*) ----> dFdN(..., dNdu(u; w, v*)) = | - / \ / \ dNdu(u; w, v*) - dFdN dNdu dNdu w + (2) Action(dNdu, w) - (6) ufl.FormSum(dN1du(u; w, v*), dN2du(u; w, v*)) -> ufl.Sum(dN1du(u; w, v*), dN2du(u; w, v*)) + Action + / \ + / w -----> dNdu(u; w, v*) + / + dNdu(u; uhat, v*) - Let's consider `Action(dN1du, w) + Action(dN2du, w)`, we have: + (3) Action(F, N) - FormSum (2) FormSum (6) Sum - / \ ----> / \ ----> / \ - / \ / \ / \ - Action(dN1du, w) Action(dN2du, w) dN1du(u; w, v*) dN2du(u; w, v*) dN1du(u; w, v*) dN2du(u; w, v*) + Action F + / \ -----> F(..., N)[v] = | + F[v] N N - This case arises as a consequence of (2) which turns sum of `Action`s (i.e. ufl.FormSum since Action is a BaseForm) - into sum of `BaseFormOperator`s (i.e. ufl.Sum since BaseFormOperator is an Expr as well). + (4) Adjoint(dNdu) - (7) Action(w*, dNdu) + Adjoint + | -----> dNdu(u; v*, uhat) + dNdu(u; uhat, v*) - Action - / \ - w* \ -----> dNdu(u; v0, w*) - \ - dNdu(u; v1, v0*) + (5) N(u; w) (scalar valued) - It uses a recursive approach to reconstruct the DAG as we traverse it, enabling to take into account - various dag rotations/manipulations in expr. - """ - if isinstance(expr, ufl.Action): - left, right = expr.ufl_operands - is_rank_1 = lambda x: isinstance(x, (firedrake.Cofunction, firedrake.Function, firedrake.Argument)) or len(x.arguments()) == 1 - is_rank_2 = lambda x: len(x.arguments()) == 2 - - # -- Case (1) -- # - # If left is Action and has a rank 2, then it is an action of a 2-form on a 2-form - if isinstance(left, ufl.Action) and is_rank_2(left): - return ufl.action(left.left(), ufl.action(left.right(), right)) - # -- Case (2) (except if left has only 1 argument, i.e. we have done case (5)) -- # - if isinstance(left, ufl.core.base_form_operator.BaseFormOperator) and is_rank_1(right) and len(left.arguments()) != 1: - # Retrieve the highest numbered argument - arg = max(left.arguments(), key=lambda v: v.number()) - return ufl.replace(left, {arg: right}) - # -- Case (3) -- # - if isinstance(left, ufl.Form) and is_rank_1(right): - # 1) Replace the highest-numbered argument of left by right when needed - # -> e.g. if right is a BaseFormOperator with 1 argument. - # Or - # 2) Let expr as it is by returning `ufl.Action(left, right)`. - return ufl.action(left, right) - # -- Case (7) -- # - if is_rank_1(left) and isinstance(right, ufl.core.base_form_operator.BaseFormOperator) and len(right.arguments()) != 1: - # Action(w*, dNdu(u; v1, v*)) -> dNdu(u; v0, w*) - # Get lowest numbered argument - arg = min(right.arguments(), key=lambda v: v.number()) - # Need to replace lowest numbered argument of right by left - replace_map = {arg: left} - # Decrease number for all the other arguments since the lowest numbered argument will be replaced. - other_args = [a for a in right.arguments() if a is not arg] - new_args = [firedrake.Argument(a.function_space(), number=a.number()-1, part=a.part()) for a in other_args] - replace_map.update(dict(zip(other_args, new_args))) - # Replace arguments - return ufl.replace(right, replace_map) - - # -- Case (4) -- # - if isinstance(expr, ufl.Adjoint) and isinstance(expr.form(), ufl.core.base_form_operator.BaseFormOperator): - B = expr.form() - u, v = B.arguments() - # Let V1 and V2 be primal spaces, B: V1 -> V2 and B*: V2* -> V1*: - # Adjoint(B(Argument(V1, 1), Argument(V2.dual(), 0))) = B(Argument(V1, 0), Argument(V2.dual(), 1)) - reordered_arguments = (firedrake.Argument(u.function_space(), number=v.number(), part=v.part()), - firedrake.Argument(v.function_space(), number=u.number(), part=u.part())) - # Replace arguments in argument slots - return ufl.replace(B, dict(zip((u, v), reordered_arguments))) - - # -- Case (5) -- # - if isinstance(expr, ufl.core.base_form_operator.BaseFormOperator) and not expr.arguments(): - # We are assembling a BaseFormOperator of rank 0 (no arguments). - # B(f, u*) be a BaseFormOperator with u* a Cofunction and f a Coefficient, then: - # B(f, u*) <=> Action(B(f, v*), f) where v* is a Coargument - ustar, *_ = expr.argument_slots() - vstar = firedrake.Argument(ustar.function_space(), 0) - expr = ufl.replace(expr, {ustar: vstar}) - return ufl.action(expr, ustar) - - # -- Case (6) -- # - if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()): - # Return ufl.Sum - return sum([c for c in expr.components()]) - return expr - - -def assemble_base_form(expression, tensor=None, bcs=None, - diagonal=False, - mat_type=None, - sub_mat_type=None, - form_compiler_parameters=None, - appctx=None, - options_prefix=None, - zero_bc_nodes=False, - weight=1.0, - visited=None): - r"""Evaluate expression. - - :arg expression: a :class:`~ufl.classes.BaseForm` - :kwarg tensor: Existing tensor object to place the result in. - :kwarg bcs: Iterable of boundary conditions to apply. - :kwarg diagonal: If assembling a matrix is it diagonal? - :kwarg mat_type: String indicating how a 2-form (matrix) should be - assembled -- either as a monolithic matrix (``"aij"`` or ``"baij"``), - a block matrix (``"nest"``), or left as a :class:`.ImplicitMatrix` giving - matrix-free actions (``'matfree'``). If not supplied, the default value in - ``parameters["default_matrix_type"]`` is used. BAIJ differs - from AIJ in that only the block sparsity rather than the dof - sparsity is constructed. This can result in some memory - savings, but does not work with all PETSc preconditioners. - BAIJ matrices only make sense for non-mixed matrices. - :kwarg sub_mat_type: String indicating the matrix type to - use *inside* a nested block matrix. Only makes sense if - ``mat_type`` is ``nest``. May be one of ``"aij"`` or ``"baij"``. If - not supplied, defaults to ``parameters["default_sub_matrix_type"]``. - :kwarg form_compiler_parameters: Dictionary of parameters to pass to - the form compiler. Ignored if not assembling a :class:`~ufl.classes.Form`. - Any parameters provided here will be overridden by parameters set on the - :class:`~ufl.classes.Measure` in the form. For example, if a - ``quadrature_degree`` of 4 is specified in this argument, but a degree of - 3 is requested in the measure, the latter will be used. - :kwarg appctx: Additional information to hang on the assembled - matrix if an implicit matrix is requested (mat_type ``"matfree"``). - :kwarg options_prefix: PETSc options prefix to apply to matrices. - :kwarg zero_bc_nodes: If ``True``, set the boundary condition nodes in the - output tensor to zero rather than to the values prescribed by the - boundary condition. Default is ``False``. - :kwarg weight: weight of the boundary condition, i.e. the scalar in front of the - identity matrix corresponding to the boundary nodes. - To discretise eigenvalue problems set the weight equal to 0.0. + Action + N(u; w) ----> / \ = Action(N, w) + N(u; v*) w - :returns: a :class:`float` for 0-forms, a :class:`.Cofunction` or a :class:`.Function` for 1-forms, - and a :class:`.MatrixBase` for 2-forms. + So from Action(Action(dFdN, dNdu(u; v*)), w) we get: - This function assembles a :class:`~ufl.classes.BaseForm` object by traversing the corresponding DAG - in a post-order fashion and evaluating the nodes on the fly. - """ + Action Action Action + / \ (1) / \ (2) / \ (4) dFdN + Action w ----> dFdN Action ----> dFdN dNdu(u; w, v*) ----> dFdN(..., dNdu(u; w, v*)) = | + / \ / \ dNdu(u; w, v*) + dFdN dNdu dNdu w - expr = expression + (6) ufl.FormSum(dN1du(u; w, v*), dN2du(u; w, v*)) -> ufl.Sum(dN1du(u; w, v*), dN2du(u; w, v*)) - # Define assembly DAG visitor - assembly_visitor = functools.partial(base_form_assembly_visitor, bcs=bcs, diagonal=diagonal, - form_compiler_parameters=form_compiler_parameters, - mat_type=mat_type, sub_mat_type=sub_mat_type, - appctx=appctx, options_prefix=options_prefix, - zero_bc_nodes=zero_bc_nodes, weight=weight) + Let's consider `Action(dN1du, w) + Action(dN2du, w)`, we have: - def visitor(e, *operands): - t = tensor if e is expr else None - return assembly_visitor(e, t, *operands) + FormSum (2) FormSum (6) Sum + / \ ----> / \ ----> / \ + / \ / \ / \ + Action(dN1du, w) Action(dN2du, w) dN1du(u; w, v*) dN2du(u; w, v*) dN1du(u; w, v*) dN2du(u; w, v*) - # DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly. - visited = visited or {} - result = base_form_postorder_traversal(expr, visitor, visited) + This case arises as a consequence of (2) which turns sum of `Action`s (i.e. ufl.FormSum since Action is a BaseForm) + into sum of `BaseFormOperator`s (i.e. ufl.Sum since BaseFormOperator is an Expr as well). - if tensor: - update_tensor(result, tensor) - return tensor - else: - return result - - -def preprocess_base_form(expr, mat_type=None, form_compiler_parameters=None): - """Preprocess ufl.BaseForm objects""" - original_expr = expr - if mat_type != "matfree": - # For "matfree", Form evaluation is delayed - expr = expand_derivatives_form(expr, form_compiler_parameters) - if not isinstance(expr, (ufl.form.Form, slate.TensorBase)): - # => No restructuring needed for Form and slate.TensorBase - expr = restructure_base_form_preorder(expr) - expr = restructure_base_form_postorder(expr) - # Preprocessing the form makes a new object -> current form caching mechanism - # will populate `expr`'s cache which is now different than `original_expr`'s cache so we need - # to transmit the cache. All of this only holds when both are `ufl.Form` objects. - if isinstance(original_expr, ufl.form.Form) and isinstance(expr, ufl.form.Form): - expr._cache = original_expr._cache - return expr - - -def update_tensor(assembled_base_form, tensor): - if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): - assembled_base_form.dat.copy(tensor.dat) - elif isinstance(tensor, matrix.MatrixBase): - # Uses the PETSc copy method. - assembled_base_form.petscmat.copy(tensor.petscmat) - else: - raise NotImplementedError("Cannot update tensor of type %s" % type(tensor)) + (7) Action(w*, dNdu) + Action + / \ + w* \ -----> dNdu(u; v0, w*) + \ + dNdu(u; v1, v0*) -def expand_derivatives_form(form, fc_params): - """Expand derivatives of ufl.BaseForm objects - :arg form: a :class:`~ufl.classes.BaseForm` - :arg fc_params:: Dictionary of parameters to pass to the form compiler. + It uses a recursive approach to reconstruct the DAG as we traverse it, enabling to take into account + various dag rotations/manipulations in expr. + """ + if isinstance(expr, ufl.Action): + left, right = expr.ufl_operands + is_rank_1 = lambda x: isinstance(x, (firedrake.Cofunction, firedrake.Function, firedrake.Argument)) or len(x.arguments()) == 1 + is_rank_2 = lambda x: len(x.arguments()) == 2 + + # -- Case (1) -- # + # If left is Action and has a rank 2, then it is an action of a 2-form on a 2-form + if isinstance(left, ufl.Action) and is_rank_2(left): + return ufl.action(left.left(), ufl.action(left.right(), right)) + # -- Case (2) (except if left has only 1 argument, i.e. we have done case (5)) -- # + if isinstance(left, ufl.core.base_form_operator.BaseFormOperator) and is_rank_1(right) and len(left.arguments()) != 1: + # Retrieve the highest numbered argument + arg = max(left.arguments(), key=lambda v: v.number()) + return ufl.replace(left, {arg: right}) + # -- Case (3) -- # + if isinstance(left, ufl.Form) and is_rank_1(right): + # 1) Replace the highest-numbered argument of left by right when needed + # -> e.g. if right is a BaseFormOperator with 1 argument. + # Or + # 2) Let expr as it is by returning `ufl.Action(left, right)`. + return ufl.action(left, right) + # -- Case (7) -- # + if is_rank_1(left) and isinstance(right, ufl.core.base_form_operator.BaseFormOperator) and len(right.arguments()) != 1: + # Action(w*, dNdu(u; v1, v*)) -> dNdu(u; v0, w*) + # Get lowest numbered argument + arg = min(right.arguments(), key=lambda v: v.number()) + # Need to replace lowest numbered argument of right by left + replace_map = {arg: left} + # Decrease number for all the other arguments since the lowest numbered argument will be replaced. + other_args = [a for a in right.arguments() if a is not arg] + new_args = [firedrake.Argument(a.function_space(), number=a.number()-1, part=a.part()) for a in other_args] + replace_map.update(dict(zip(other_args, new_args))) + # Replace arguments + return ufl.replace(right, replace_map) + + # -- Case (4) -- # + if isinstance(expr, ufl.Adjoint) and isinstance(expr.form(), ufl.core.base_form_operator.BaseFormOperator): + B = expr.form() + u, v = B.arguments() + # Let V1 and V2 be primal spaces, B: V1 -> V2 and B*: V2* -> V1*: + # Adjoint(B(Argument(V1, 1), Argument(V2.dual(), 0))) = B(Argument(V1, 0), Argument(V2.dual(), 1)) + reordered_arguments = (firedrake.Argument(u.function_space(), number=v.number(), part=v.part()), + firedrake.Argument(v.function_space(), number=u.number(), part=u.part())) + # Replace arguments in argument slots + return ufl.replace(B, dict(zip((u, v), reordered_arguments))) + + # -- Case (5) -- # + if isinstance(expr, ufl.core.base_form_operator.BaseFormOperator) and not expr.arguments(): + # We are assembling a BaseFormOperator of rank 0 (no arguments). + # B(f, u*) be a BaseFormOperator with u* a Cofunction and f a Coefficient, then: + # B(f, u*) <=> Action(B(f, v*), f) where v* is a Coargument + ustar, *_ = expr.argument_slots() + vstar = firedrake.Argument(ustar.function_space(), 0) + expr = ufl.replace(expr, {ustar: vstar}) + return ufl.action(expr, ustar) + + # -- Case (6) -- # + if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()): + # Return ufl.Sum + return sum([c for c in expr.components()]) + return expr - :returns: The resulting preprocessed :class:`~ufl.classes.BaseForm`. - This function preprocess the form, mainly by expanding the derivatives, in order to determine - if we are dealing with a :class:`~ufl.classes.Form` or another :class:`~ufl.classes.BaseForm` object. - This function is called in :func:`base_form_assembly_visitor`. Depending on the type of the resulting tensor, - we may call :func:`assemble_form` or traverse the sub-DAG via :func:`assemble_base_form`. - """ - if isinstance(form, ufl.form.Form): - from firedrake.parameters import parameters as default_parameters - from tsfc.parameters import is_complex + @staticmethod + def preprocess_base_form(expr, mat_type=None, form_compiler_parameters=None): + """Preprocess ufl.BaseForm objects""" + original_expr = expr + if mat_type != "matfree": + # Don't expand derivatives if `mat_type` is 'matfree' + # For "matfree", Form evaluation is delayed + expr = BaseFormAssembler.expand_derivatives_form(expr, form_compiler_parameters) + if not isinstance(expr, (ufl.form.Form, slate.TensorBase)): + # => No restructuring needed for Form and slate.TensorBase + expr = BaseFormAssembler.restructure_base_form_preorder(expr) + expr = BaseFormAssembler.restructure_base_form_postorder(expr) + # Preprocessing the form makes a new object -> current form caching mechanism + # will populate `expr`'s cache which is now different than `original_expr`'s cache so we need + # to transmit the cache. All of this only holds when both are `ufl.Form` objects. + if isinstance(original_expr, ufl.form.Form) and isinstance(expr, ufl.form.Form): + expr._cache = original_expr._cache + return expr - if fc_params is None: - fc_params = default_parameters["form_compiler"].copy() - else: - # Override defaults with user-specified values - _ = fc_params - fc_params = default_parameters["form_compiler"].copy() - fc_params.update(_) - - complex_mode = fc_params and is_complex(fc_params.get("scalar_type")) - - return ufl.algorithms.preprocess_form(form, complex_mode) - # We also need to expand derivatives for `ufl.BaseForm` objects that are not `ufl.Form` - # Example: `Action(A, derivative(B, f))`, where `A` is a `ufl.BaseForm` and `B` can - # be `ufl.BaseForm`, or even an appropriate `ufl.Expr`, since assembly of expressions - # containing derivatives is not supported anymore but might be needed if the expression - # in question is within a `ufl.BaseForm` object. - return ufl.algorithms.ad.expand_derivatives(form) - - -def base_form_assembly_visitor(expr, tensor, *args, bcs, diagonal, - form_compiler_parameters, - mat_type, sub_mat_type, - appctx, options_prefix, - zero_bc_nodes, weight): - r"""Assemble a :class:`~ufl.classes.BaseForm` object given its assembled operands. - - This functions contains the assembly handlers corresponding to the different nodes that - can arise in a `~ufl.classes.BaseForm` object. It is called by :func:`assemble_base_form` - in a post-order fashion. - """ + @staticmethod + def expand_derivatives_form(form, fc_params): + """Expand derivatives of ufl.BaseForm objects + :arg form: a :class:`~ufl.classes.BaseForm` + :arg fc_params:: Dictionary of parameters to pass to the form compiler. + + :returns: The resulting preprocessed :class:`~ufl.classes.BaseForm`. + This function preprocess the form, mainly by expanding the derivatives, in order to determine + if we are dealing with a :class:`~ufl.classes.Form` or another :class:`~ufl.classes.BaseForm` object. + This function is called in :func:`base_form_assembly_visitor`. Depending on the type of the resulting tensor, + we may call :func:`assemble_form` or traverse the sub-DAG via :func:`assemble_base_form`. + """ + if isinstance(form, ufl.form.Form): + from firedrake.parameters import parameters as default_parameters + from tsfc.parameters import is_complex - if isinstance(expr, (ufl.form.Form, slate.TensorBase)): - if args and mat_type != "matfree": - # Retrieve the Form's children - base_form_operators = base_form_operands(expr) - # Substitute the base form operators by their output - expr = ufl.replace(expr, dict(zip(base_form_operators, args))) - form = expr - rank = len(form.arguments()) - if rank == 0: - assembler = ZeroFormAssembler(form, form_compiler_parameters=form_compiler_parameters) - elif rank == 1 or (rank == 2 and diagonal): - assembler = OneFormAssembler(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=True, - zero_bc_nodes=zero_bc_nodes, diagonal=diagonal) - elif rank == 2: - assembler = TwoFormAssembler(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=True, - mat_type=mat_type, sub_mat_type=sub_mat_type, options_prefix=options_prefix, appctx=appctx, weight=weight) - else: - raise AssertionError - return assembler.assemble(tensor=tensor) - elif isinstance(expr, ufl.Adjoint): - if len(args) != 1: - raise TypeError("Not enough operands for Adjoint") - mat, = args - res = tensor.petscmat if tensor else PETSc.Mat() - petsc_mat = mat.petscmat - # Out-of-place Hermitian transpose - petsc_mat.hermitianTranspose(out=res) - (row, col) = mat.arguments() - return matrix.AssembledMatrix((col, row), bcs, res, - appctx=appctx, - options_prefix=options_prefix) - elif isinstance(expr, ufl.Action): - if (len(args) != 2): - raise TypeError("Not enough operands for Action") - lhs, rhs = args - if isinstance(lhs, matrix.MatrixBase): - if isinstance(rhs, (firedrake.Cofunction, firedrake.Function)): - petsc_mat = lhs.petscmat - (row, col) = lhs.arguments() - # The matrix-vector product lives in the dual of the test space. - res = firedrake.Function(row.function_space().dual()) - - with rhs.dat.vec_ro as v_vec: - with res.dat.vec as res_vec: - petsc_mat.mult(v_vec, res_vec) - return res - elif isinstance(rhs, matrix.MatrixBase): - petsc_mat = lhs.petscmat - (row, col) = lhs.arguments() - res = petsc_mat.matMult(rhs.petscmat) - return matrix.AssembledMatrix(expr, bcs, res, - appctx=appctx, - options_prefix=options_prefix) - else: - raise TypeError("Incompatible RHS for Action.") - elif isinstance(lhs, (firedrake.Cofunction, firedrake.Function)): - if isinstance(rhs, (firedrake.Cofunction, firedrake.Function)): - # Return scalar value - with lhs.dat.vec_ro as x, rhs.dat.vec_ro as y: - res = x.dot(y) - return res + if fc_params is None: + fc_params = default_parameters["form_compiler"].copy() else: - raise TypeError("Incompatible RHS for Action.") - else: - raise TypeError("Incompatible LHS for Action.") - elif isinstance(expr, ufl.FormSum): - if len(args) != len(expr.weights()): - raise TypeError("Mismatching weights and operands in FormSum") - if len(args) == 0: - raise TypeError("Empty FormSum") - if all(isinstance(op, float) for op in args): - return sum(args) - elif all(isinstance(op, firedrake.Cofunction) for op in args): - V, = set(a.function_space() for a in args) - res = sum([w*op.dat for (op, w) in zip(args, expr.weights())]) - return firedrake.Cofunction(V, res) - elif all(isinstance(op, ufl.Matrix) for op in args): - res = tensor.petscmat if tensor else PETSc.Mat() - is_set = False - for (op, w) in zip(args, expr.weights()): - # Make a copy to avoid in-place scaling - petsc_mat = op.petscmat.copy() - petsc_mat.scale(w) - if is_set: - # Modify output tensor in-place - res += petsc_mat - else: - # Copy to output tensor - petsc_mat.copy(result=res) - is_set = True - return matrix.AssembledMatrix(expr, bcs, res, - appctx=appctx, - options_prefix=options_prefix) - else: - raise TypeError("Mismatching FormSum shapes") - elif isinstance(expr, ufl.ExternalOperator): - opts = {'form_compiler_parameters': form_compiler_parameters, - 'mat_type': mat_type, 'sub_mat_type': sub_mat_type, - 'appctx': appctx, 'options_prefix': options_prefix, - 'diagonal': diagonal} - # External operators might not have any children that needs to be assembled - # -> e.g. N(u; v0, w) with v0 a ufl.Argument and w a ufl.Coefficient - if args: - # Replace base forms in the operands and argument slots of the external operator by their result - v, *assembled_children = args - if assembled_children: - _, *children = base_form_operands(expr) - # Replace assembled children by their results - expr = ufl.replace(expr, dict(zip(children, assembled_children))) - # Always reconstruct the dual argument (0-slot argument) since it is a BaseForm - # It is also convenient when we have a Form in that slot since Forms don't play well with `ufl.replace` - expr = expr._ufl_expr_reconstruct_(*expr.ufl_operands, argument_slots=(v,) + expr.argument_slots()[1:]) - # Call the external operator assembly - return expr.assemble(assembly_opts=opts) - elif isinstance(expr, ufl.Interpolate): - # Replace assembled children - _, expression = expr.argument_slots() - v, *assembled_expression = args - if assembled_expression: - # Occur in situations such as Interpolate composition - expression = assembled_expression[0] - expr = expr._ufl_expr_reconstruct_(expression, v) - - # Different assembly procedures: - # 1) Interpolate(Argument(V1, 1), Argument(V2.dual(), 0)) -> Jacobian (Interpolate matrix) - # 2) Interpolate(Coefficient(...), Argument(V2.dual(), 0)) -> Operator (or Jacobian action) - # 3) Interpolate(Argument(V1, 0), Argument(V2.dual(), 1)) -> Jacobian adjoint - # 4) Interpolate(Argument(V1, 0), Cofunction(...)) -> Action of the Jacobian adjoint - # This can be generalized to the case where the first slot is an arbitray expression. - rank = len(expr.arguments()) - # If argument numbers have been swapped => Adjoint. - arg_expression = ufl.algorithms.extract_arguments(expression) - is_adjoint = (arg_expression and arg_expression[0].number() == 0) - # Workaround: Renumber argument when needed since Interpolator assumes it takes a zero-numbered argument. - if not is_adjoint and rank != 1: - _, v1 = expr.arguments() - expression = ufl.replace(expression, {v1: firedrake.Argument(v1.function_space(), number=0, part=v1.part())}) - # Get the interpolator - interp_data = expr.interp_data - default_missing_val = interp_data.pop('default_missing_val', None) - interpolator = firedrake.Interpolator(expression, expr.function_space(), **interp_data) - # Assembly - if rank == 1: - # Assembling the action of the Jacobian adjoint. - if is_adjoint: - output = tensor or firedrake.Cofunction(arg_expression[0].function_space().dual()) - return interpolator._interpolate(v, output=output, transpose=True, default_missing_val=default_missing_val) - # Assembling the Jacobian action. - if interpolator.nargs: - return interpolator._interpolate(expression, output=tensor, default_missing_val=default_missing_val) - # Assembling the operator - if tensor is None: - return interpolator._interpolate(default_missing_val=default_missing_val) - return firedrake.Interpolator(expression, tensor, **interp_data)._interpolate(default_missing_val=default_missing_val) - elif rank == 2: - res = tensor.petscmat if tensor else PETSc.Mat() - # Get the interpolation matrix - op2_mat = interpolator.callable() - petsc_mat = op2_mat.handle - if is_adjoint: - # Out-of-place Hermitian transpose - petsc_mat.hermitianTranspose(out=res) - else: - # Copy the interpolation matrix into the output tensor - petsc_mat.copy(result=res) - return matrix.AssembledMatrix(expr.arguments(), bcs, res, - appctx=appctx, - options_prefix=options_prefix) - else: - # The case rank == 0 is handled via the DAG restructuring - raise ValueError("Incompatible number of arguments.") - elif isinstance(expr, (ufl.Cofunction, ufl.Coargument, ufl.Argument, ufl.Matrix, ufl.ZeroBaseForm)): - return expr - elif isinstance(expr, ufl.Coefficient): - return expr - else: - raise TypeError(f"Unrecognised BaseForm instance: {expr}") + # Override defaults with user-specified values + _ = fc_params + fc_params = default_parameters["form_compiler"].copy() + fc_params.update(_) + + complex_mode = fc_params and is_complex(fc_params.get("scalar_type")) + + return ufl.algorithms.preprocess_form(form, complex_mode) + # We also need to expand derivatives for `ufl.BaseForm` objects that are not `ufl.Form` + # Example: `Action(A, derivative(B, f))`, where `A` is a `ufl.BaseForm` and `B` can + # be `ufl.BaseForm`, or even an appropriate `ufl.Expr`, since assembly of expressions + # containing derivatives is not supported anymore but might be needed if the expression + # in question is within a `ufl.BaseForm` object. + return ufl.algorithms.ad.expand_derivatives(form) @PETSc.Log.EventDecorator() diff --git a/tests/regression/test_assemble_baseform.py b/tests/regression/test_assemble_baseform.py index 41a2965944..a0241a845f 100644 --- a/tests/regression/test_assemble_baseform.py +++ b/tests/regression/test_assemble_baseform.py @@ -1,7 +1,7 @@ import pytest import numpy as np from firedrake import * -from firedrake.assemble import preprocess_base_form, allocate_matrix +from firedrake.assemble import BaseFormAssembler, allocate_matrix from firedrake.utils import ScalarType import ufl @@ -159,7 +159,7 @@ def test_preprocess_form(M, a, f): from ufl.algorithms import expand_indices, expand_derivatives expr = action(action(M, M), f) - A = preprocess_base_form(expr) + A = BaseFormAssembler.preprocess_base_form(expr) B = action(expand_derivatives(M), action(M, f)) assert isinstance(A, ufl.Action)