Skip to content

Commit

Permalink
assemble: introduce ExprAssembler
Browse files Browse the repository at this point in the history
_assemble_expr -> ExprAssembler().assemble
  • Loading branch information
ksagiyam committed Feb 28, 2024
1 parent 353ca0f commit 04f3534
Showing 1 changed file with 75 additions and 52 deletions.
127 changes: 75 additions & 52 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,87 @@ def get_form_assembler(form, tensor, *args, **kwargs):
raise ValueError('Expecting a 0-, 1-, or 2-form: got %s' % (form))
elif isinstance(form, ufl.core.expr.Expr) and not isinstance(form, ufl.core.base_form_operator.BaseFormOperator):
# BaseForm preprocessing can turn BaseForm into an Expr (cf. case (6) in `restructure_base_form`)
return functools.partial(_assemble_expr, form)
return functools.partial(ExprAssembler(form).assemble, tensor=tensor)
elif isinstance(form, ufl.form.BaseForm):
return functools.partial(BaseFormAssembler(form, *args, **kwargs).assemble, tensor=tensor)
else:
raise ValueError(f'Expecting a BaseForm, slate.TensorBase, or Expr object: got {form}')


class ExprAssembler:
"""Expression assembler.
Parameters
----------
expr : ufl.core.expr.Expr
Expression.
"""

def __init__(self, expr):
self._expr = expr

def assemble(self, tensor=None):
"""Assemble the pointwise expression.
Parameters
----------
tensor : firedrake.function.Function or firedrake.cofunction.Cofunction or matrix.MatrixBase
Output tensor.
Returns
-------
float or firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
Result of evaluation: `float` for 0-forms, `firedrake.cofunction.Cofunction` or `firedrake.function.Function` for 1-forms, and `matrix.MatrixBase` for 2-forms.
"""
from ufl.algorithms.analysis import extract_base_form_operators
from ufl.checks import is_scalar_constant_expression

assert tensor is None
expr = self._expr
# Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`)
base_form_operators = extract_base_form_operators(expr)

# -- Linear combination involving 2-form BaseFormOperators -- #
# Example: a * dNdu1(u1, u2; v1, v*) + b * dNdu2(u1, u2; v2, v*)
# with u1, u2 Functions, v1, v2, v* BaseArguments, dNdu1, dNdu2 BaseFormOperators, and a, b scalars.
if len(base_form_operators) and any(len(e.arguments()) > 1 for e in base_form_operators):
if isinstance(expr, ufl.algebra.Sum):
a, b = [assemble(e) for e in expr.ufl_operands]
# Only Expr resulting in a Matrix if assembled are BaseFormOperator
if not all(isinstance(op, matrix.AssembledMatrix) for op in (a, b)):
raise TypeError('Mismatching Sum shapes')
return get_form_assembler(ufl.FormSum((a, 1), (b, 1)), None)()
elif isinstance(expr, ufl.algebra.Product):
a, b = expr.ufl_operands
scalar = [e for e in expr.ufl_operands if is_scalar_constant_expression(e)]
if scalar:
base_form = a if a is scalar else b
assembled_mat = assemble(base_form)
return get_form_assembler(ufl.FormSum((assembled_mat, scalar[0])), None)()
a, b = [assemble(e) for e in (a, b)]
return get_form_assembler(ufl.action(a, b), None)()
# -- Linear combination of Functions and 1-form BaseFormOperators -- #
# Example: a * u1 + b * u2 + c * N(u1; v*) + d * N(u2; v*)
# with u1, u2 Functions, N a BaseFormOperator, and a, b, c, d scalars or 0-form BaseFormOperators.
else:
base_form_operators = extract_base_form_operators(expr)
assembled_bfops = [firedrake.assemble(e) for e in base_form_operators]
# Substitute base form operators with their output before examining the expression
# which avoids conflict when determining function space, for example:
# extract_coefficients(Interpolate(u, V2)) with u \in V1 will result in an output function space V1
# instead of V2.
if base_form_operators:
expr = ufl.replace(expr, dict(zip(base_form_operators, assembled_bfops)))
try:
coefficients = ufl.algorithms.extract_coefficients(expr)
V, = set(c.function_space() for c in coefficients) - {None}
except ValueError:
raise ValueError("Cannot deduce correct target space from pointwise expression")
return firedrake.Function(V).assign(expr)


class AbstractFormAssembler(abc.ABC):
"""Abstract assembler class for forms.
Expand Down Expand Up @@ -807,57 +881,6 @@ def allocate_matrix(expr, bcs=None, *, mat_type=None, sub_mat_type=None,
options_prefix=options_prefix)


def _assemble_expr(expr):
"""Assemble a pointwise expression.
:arg expr: The :class:`ufl.core.expr.Expr` to be evaluated.
:returns: A :class:`firedrake.Function` containing the result of this evaluation.
"""
from ufl.algorithms.analysis import extract_base_form_operators
from ufl.checks import is_scalar_constant_expression

# Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`)
base_form_operators = extract_base_form_operators(expr)

# -- Linear combination involving 2-form BaseFormOperators -- #
# Example: a * dNdu1(u1, u2; v1, v*) + b * dNdu2(u1, u2; v2, v*)
# with u1, u2 Functions, v1, v2, v* BaseArguments, dNdu1, dNdu2 BaseFormOperators, and a, b scalars.
if len(base_form_operators) and any(len(e.arguments()) > 1 for e in base_form_operators):
if isinstance(expr, ufl.algebra.Sum):
a, b = [assemble(e) for e in expr.ufl_operands]
# Only Expr resulting in a Matrix if assembled are BaseFormOperator
if not all(isinstance(op, matrix.AssembledMatrix) for op in (a, b)):
raise TypeError('Mismatching Sum shapes')
return get_form_assembler(ufl.FormSum((a, 1), (b, 1)), None)()
elif isinstance(expr, ufl.algebra.Product):
a, b = expr.ufl_operands
scalar = [e for e in expr.ufl_operands if is_scalar_constant_expression(e)]
if scalar:
base_form = a if a is scalar else b
assembled_mat = assemble(base_form)
return get_form_assembler(ufl.FormSum((assembled_mat, scalar[0])), None)()
a, b = [assemble(e) for e in (a, b)]
return get_form_assembler(ufl.action(a, b), None)()
# -- Linear combination of Functions and 1-form BaseFormOperators -- #
# Example: a * u1 + b * u2 + c * N(u1; v*) + d * N(u2; v*)
# with u1, u2 Functions, N a BaseFormOperator, and a, b, c, d scalars or 0-form BaseFormOperators.
else:
base_form_operators = extract_base_form_operators(expr)
assembled_bfops = [firedrake.assemble(e) for e in base_form_operators]
# Substitute base form operators with their output before examining the expression
# which avoids conflict when determining function space, for example:
# extract_coefficients(Interpolate(u, V2)) with u \in V1 will result in an output function space V1
# instead of V2.
if base_form_operators:
expr = ufl.replace(expr, dict(zip(base_form_operators, assembled_bfops)))
try:
coefficients = ufl.algorithms.extract_coefficients(expr)
V, = set(c.function_space() for c in coefficients) - {None}
except ValueError:
raise ValueError("Cannot deduce correct target space from pointwise expression")
return firedrake.Function(V).assign(expr)


class FormAssembler(AbstractFormAssembler):
"""Form assembler.
Expand Down

0 comments on commit 04f3534

Please sign in to comment.