diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 607661831f..c93594ce52 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -128,13 +128,87 @@ def get_form_assembler(form, tensor, *args, **kwargs): raise ValueError('Expecting a 0-, 1-, or 2-form: got %s' % (form)) elif isinstance(form, ufl.core.expr.Expr) and not isinstance(form, ufl.core.base_form_operator.BaseFormOperator): # BaseForm preprocessing can turn BaseForm into an Expr (cf. case (6) in `restructure_base_form`) - return functools.partial(_assemble_expr, form) + return functools.partial(ExprAssembler(form).assemble, tensor=tensor) elif isinstance(form, ufl.form.BaseForm): return functools.partial(BaseFormAssembler(form, *args, **kwargs).assemble, tensor=tensor) else: raise ValueError(f'Expecting a BaseForm, slate.TensorBase, or Expr object: got {form}') +class ExprAssembler: + """Expression assembler. + + Parameters + ---------- + expr : ufl.core.expr.Expr + Expression. + + """ + + def __init__(self, expr): + self._expr = expr + + def assemble(self, tensor=None): + """Assemble the pointwise expression. + + Parameters + ---------- + tensor : firedrake.function.Function or firedrake.cofunction.Cofunction or matrix.MatrixBase + Output tensor. + + Returns + ------- + float or firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase + Result of evaluation: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms. + + """ + from ufl.algorithms.analysis import extract_base_form_operators + from ufl.checks import is_scalar_constant_expression + + assert tensor is None + expr = self._expr + # Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`) + base_form_operators = extract_base_form_operators(expr) + + # -- Linear combination involving 2-form BaseFormOperators -- # + # Example: a * dNdu1(u1, u2; v1, v*) + b * dNdu2(u1, u2; v2, v*) + # with u1, u2 Functions, v1, v2, v* BaseArguments, dNdu1, dNdu2 BaseFormOperators, and a, b scalars. + if len(base_form_operators) and any(len(e.arguments()) > 1 for e in base_form_operators): + if isinstance(expr, ufl.algebra.Sum): + a, b = [assemble(e) for e in expr.ufl_operands] + # Only Expr resulting in a Matrix if assembled are BaseFormOperator + if not all(isinstance(op, matrix.AssembledMatrix) for op in (a, b)): + raise TypeError('Mismatching Sum shapes') + return get_form_assembler(ufl.FormSum((a, 1), (b, 1)), None)() + elif isinstance(expr, ufl.algebra.Product): + a, b = expr.ufl_operands + scalar = [e for e in expr.ufl_operands if is_scalar_constant_expression(e)] + if scalar: + base_form = a if a is scalar else b + assembled_mat = assemble(base_form) + return get_form_assembler(ufl.FormSum((assembled_mat, scalar[0])), None)() + a, b = [assemble(e) for e in (a, b)] + return get_form_assembler(ufl.action(a, b), None)() + # -- Linear combination of Functions and 1-form BaseFormOperators -- # + # Example: a * u1 + b * u2 + c * N(u1; v*) + d * N(u2; v*) + # with u1, u2 Functions, N a BaseFormOperator, and a, b, c, d scalars or 0-form BaseFormOperators. + else: + base_form_operators = extract_base_form_operators(expr) + assembled_bfops = [firedrake.assemble(e) for e in base_form_operators] + # Substitute base form operators with their output before examining the expression + # which avoids conflict when determining function space, for example: + # extract_coefficients(Interpolate(u, V2)) with u \in V1 will result in an output function space V1 + # instead of V2. + if base_form_operators: + expr = ufl.replace(expr, dict(zip(base_form_operators, assembled_bfops))) + try: + coefficients = ufl.algorithms.extract_coefficients(expr) + V, = set(c.function_space() for c in coefficients) - {None} + except ValueError: + raise ValueError("Cannot deduce correct target space from pointwise expression") + return firedrake.Function(V).assign(expr) + + class AbstractFormAssembler(abc.ABC): """Abstract assembler class for forms. @@ -807,57 +881,6 @@ def allocate_matrix(expr, bcs=None, *, mat_type=None, sub_mat_type=None, options_prefix=options_prefix) -def _assemble_expr(expr): - """Assemble a pointwise expression. - - :arg expr: The :class:`ufl.core.expr.Expr` to be evaluated. - :returns: A :class:`firedrake.Function` containing the result of this evaluation. - """ - from ufl.algorithms.analysis import extract_base_form_operators - from ufl.checks import is_scalar_constant_expression - - # Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`) - base_form_operators = extract_base_form_operators(expr) - - # -- Linear combination involving 2-form BaseFormOperators -- # - # Example: a * dNdu1(u1, u2; v1, v*) + b * dNdu2(u1, u2; v2, v*) - # with u1, u2 Functions, v1, v2, v* BaseArguments, dNdu1, dNdu2 BaseFormOperators, and a, b scalars. - if len(base_form_operators) and any(len(e.arguments()) > 1 for e in base_form_operators): - if isinstance(expr, ufl.algebra.Sum): - a, b = [assemble(e) for e in expr.ufl_operands] - # Only Expr resulting in a Matrix if assembled are BaseFormOperator - if not all(isinstance(op, matrix.AssembledMatrix) for op in (a, b)): - raise TypeError('Mismatching Sum shapes') - return get_form_assembler(ufl.FormSum((a, 1), (b, 1)), None)() - elif isinstance(expr, ufl.algebra.Product): - a, b = expr.ufl_operands - scalar = [e for e in expr.ufl_operands if is_scalar_constant_expression(e)] - if scalar: - base_form = a if a is scalar else b - assembled_mat = assemble(base_form) - return get_form_assembler(ufl.FormSum((assembled_mat, scalar[0])), None)() - a, b = [assemble(e) for e in (a, b)] - return get_form_assembler(ufl.action(a, b), None)() - # -- Linear combination of Functions and 1-form BaseFormOperators -- # - # Example: a * u1 + b * u2 + c * N(u1; v*) + d * N(u2; v*) - # with u1, u2 Functions, N a BaseFormOperator, and a, b, c, d scalars or 0-form BaseFormOperators. - else: - base_form_operators = extract_base_form_operators(expr) - assembled_bfops = [firedrake.assemble(e) for e in base_form_operators] - # Substitute base form operators with their output before examining the expression - # which avoids conflict when determining function space, for example: - # extract_coefficients(Interpolate(u, V2)) with u \in V1 will result in an output function space V1 - # instead of V2. - if base_form_operators: - expr = ufl.replace(expr, dict(zip(base_form_operators, assembled_bfops))) - try: - coefficients = ufl.algorithms.extract_coefficients(expr) - V, = set(c.function_space() for c in coefficients) - {None} - except ValueError: - raise ValueError("Cannot deduce correct target space from pointwise expression") - return firedrake.Function(V).assign(expr) - - class FormAssembler(AbstractFormAssembler): """Form assembler.