Skip to content

Commit

Permalink
Zero-simplify slate Tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 3, 2025
1 parent af53302 commit 7f40504
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
1 change: 1 addition & 0 deletions firedrake/slate/slac/tsfc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None):
assert tensor.terminal, (
"Only terminal tensors have forms associated with them!"
)

# Sets a default name for the subkernel prefix.
mapper = RemoveRestrictions()
integrals = map(partial(map_integrand_dags, mapper),
Expand Down
45 changes: 35 additions & 10 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ufl.corealg.multifunction import MultiFunction
from ufl.classes import Zero
from ufl.domain import join_domains, sort_domains
from ufl.form import Form
from ufl.form import Form, ZeroBaseForm
import hashlib

from firedrake.formmanipulation import ExtractSubBlock
Expand Down Expand Up @@ -237,7 +237,7 @@ def coeff_map(self):
coeff_map[m].update(c.indices[0])
else:
m = self.coefficients().index(c)
split_map = tuple(range(len(c.subfunctions))) if isinstance(c, Function) or isinstance(c, Constant) or isinstance(c, Cofunction) else tuple(range(1))
split_map = tuple(range(len(c.subfunctions))) if isinstance(c, (Function, Constant, Cofunction)) else (0,)
coeff_map[m].update(split_map)
return tuple((k, tuple(sorted(v)))for k, v in coeff_map.items())

Expand Down Expand Up @@ -382,6 +382,10 @@ def __eq__(self, other):
"""Determines whether two TensorBase objects are equal using their
associated keys.
"""
if isinstance(other, (int, float)) and other == 0:
if isinstance(self, Tensor):
return isinstance(self.form, ZeroBaseForm) or self.form.empty()
return False
return self._key == other._key

def __ne__(self, other):
Expand Down Expand Up @@ -650,7 +654,7 @@ def __init__(self, tensor, indices):
"""Constructor for the Block class."""
super(Block, self).__init__()
self.operands = (tensor,)
self._blocks = dict(enumerate(indices))
self._blocks = dict(enumerate(map(as_tuple, indices)))
self._indices = indices

@cached_property
Expand All @@ -671,14 +675,12 @@ def _split_arguments(self):
nargs = []
for i, arg in enumerate(tensor.arguments()):
V = arg.function_space()
V_is = V.subfunctions
idx = as_tuple(self._blocks[i])
idx = self._blocks[i]
if len(idx) == 1:
fidx, = idx
W = V_is[fidx]
W = V[idx[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
else:
W = MixedFunctionSpace([V_is[fidx] for fidx in idx])
W = MixedFunctionSpace([V[fidx] for fidx in idx])

nargs.append(Argument(W, arg.number(), part=arg.part()))

Expand Down Expand Up @@ -880,7 +882,7 @@ class Tensor(TensorBase):

def __init__(self, form, diagonal=False):
"""Constructor for the Tensor class."""
if not isinstance(form, Form):
if not isinstance(form, (Form, ZeroBaseForm)):
if isinstance(form, Function):
raise TypeError("Use AssembledVector instead of Tensor.")
raise TypeError("Only UFL forms are acceptable inputs.")
Expand Down Expand Up @@ -1103,6 +1105,10 @@ def _output_string(self, prec=None):

class Transpose(UnaryOp):
"""An abstract Slate class representing the transpose of a tensor."""
def __new__(cls, A):
if A == 0:
return Tensor(ZeroBaseForm(A.form.arguments()[::-1]))
return BinaryOp.__new__(cls)

@cached_property
def arg_function_spaces(self):
Expand All @@ -1127,6 +1133,10 @@ def _output_string(self, prec=None):

class Negative(UnaryOp):
"""Abstract Slate class representing the negation of a tensor object."""
def __new__(cls, A):
if A == 0:
return A
return BinaryOp.__new__(cls)

@cached_property
def arg_function_spaces(self):
Expand Down Expand Up @@ -1197,6 +1207,12 @@ class Add(BinaryOp):
:arg A: a :class:`~.firedrake.slate.TensorBase` object.
:arg B: another :class:`~.firedrake.slate.TensorBase` object.
"""
def __new__(cls, A, B):
if A == 0:
return B
elif B == 0:
return A
return BinaryOp.__new__(cls)

def __init__(self, A, B):
"""Constructor for the Add class."""
Expand Down Expand Up @@ -1238,6 +1254,10 @@ class Mul(BinaryOp):
:arg A: a :class:`~.firedrake.slate.TensorBase` object.
:arg B: another :class:`~.firedrake.slate.TensorBase` object.
"""
def __new__(cls, A, B):
if A == 0 or B == 0:
return Tensor(ZeroBaseForm(A.arguments()[:-1] + B.arguments()[1:]))
return BinaryOp.__new__(cls)

def __init__(self, A, B):
"""Constructor for the Mul class."""
Expand Down Expand Up @@ -1295,7 +1315,7 @@ def __new__(cls, A, B, decomposition=None):
raise ValueError("Illegal op on a %s-tensor with a %s-tensor."
% (A.shape, B.shape))

fsA = A.arg_function_spaces[::-1][-1]
fsA = A.arg_function_spaces[0]
fsB = B.arg_function_spaces[0]

assert space_equivalence(fsA, fsB), (
Expand Down Expand Up @@ -1348,6 +1368,11 @@ class DiagonalTensor(UnaryOp):
"""
diagonal = True

def __new__(cls, A):
if A == 0:
return Tensor(ZeroBaseForm(A.arguments()[:1]))
return BinaryOp.__new__(cls)

def __init__(self, A):
"""Constructor for the Diagonal class."""
assert A.rank == 2, "The tensor must be rank 2."
Expand Down

0 comments on commit 7f40504

Please sign in to comment.