Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the FormSum memory leak #3897

Merged
merged 12 commits into from
Dec 6, 2024
5 changes: 3 additions & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
return sum(weight * arg for weight, arg in zip(expr.weights(), args))
elif all(isinstance(op, firedrake.Cofunction) for op in args):
V, = set(a.function_space() for a in args)
res = sum([w*op.dat for (op, w) in zip(args, expr.weights())])
return firedrake.Cofunction(V, res)
result = firedrake.Cofunction(V)
result.dat.maxpy(expr.weights(), [a.dat for a in args])
return result
elif all(isinstance(op, ufl.Matrix) for op in args):
res = tensor.petscmat if tensor else PETSc.Mat()
is_set = False
Expand Down
50 changes: 50 additions & 0 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,39 @@ def norm(self):
from math import sqrt
return sqrt(self.inner(self).real)

def maxpy(self, scalar: list, x: list) -> None:
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute a sequence of axpy operations.

This is equivalent to calling :meth:`axpy` for each pair of
scalars and :class:`Dat` in the input sequences.

:arg scalar: A sequence of scalars.
:arg x: A sequence of :class:`Dat`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

See also :meth:`axpy`.

Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""
if len(scalar) != len(x):
raise ValueError("scalar and x must have the same length")
for alpha_i, x_i in zip(scalar, x):
self.axpy(alpha_i, x_i)

def axpy(self, alpha: float, other: 'Dat') -> None:
connorjward marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.

On this case, `self` is `y` and `other` is `x`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

"""
self._check_shape(other)
if isinstance(other._data, np.ndarray):
connorjward marked this conversation as resolved.
Show resolved Hide resolved
if not np.isscalar(alpha):
raise TypeError("alpha must be a scalar")
np.add(
alpha * other.data_ro, self.data_ro,
out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")

def __pos__(self):
pos = Dat(self)
return pos
Expand Down Expand Up @@ -1022,6 +1055,23 @@ def inner(self, other):
ret += s.inner(o)
return ret

def axpy(self, alpha: float, other: 'MixedDat') -> None:
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.

On this case, `self` is `y` and `other` is `x`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

"""
self._check_shape(other)
for dat_result, dat_other in zip(self, other):
connorjward marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(dat_result._data, np.ndarray):
if not np.isscalar(alpha):
raise TypeError("alpha must be a scalar")
np.add(
alpha * dat_other.data_ro, dat_result.data_ro,
out=dat_result.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")

def _op(self, other, op):
ret = []
if np.isscalar(other):
Expand Down
30 changes: 30 additions & 0 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,36 @@ def inner(self, other):
assert issubclass(type(other), type(self))
return np.dot(self.data_ro, np.conj(other.data_ro))

def maxpy(self, scalar: list, x: list) -> None:
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute a sequence of axpy operations.

This is equivalent to calling :meth:`axpy` for each pair of
scalars and :class:`Dat` in the input sequences.

:arg scalar: A sequence of scalars.
:arg x: A sequence of :class:`Dat`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

See also :meth:`axpy`.

Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""
if len(scalar) != len(x):
raise ValueError("scalar and x must have the same length")
for alpha_i, x_i in zip(scalar, x):
self.axpy(alpha_i, x_i)

def axpy(self, alpha: float, other: 'Global') -> None:
"""Compute the operation :math:`y = \\alpha x + y`.

On this case, `self` is `y` and `other` is `x`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

"""
JHopeCollins marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(self._data, np.ndarray):
if not np.isscalar(alpha):
raise ValueError("alpha must be a scalar")
np.add(alpha * other.data_ro, self.data_ro, out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")


# must have comm, can be modified in parloop (implies a reduction)
class Global(SetFreeDataCarrier, VecAccessMixin):
Expand Down
16 changes: 16 additions & 0 deletions tests/pyop2/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,22 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1):
d1.data_with_halos
assert d1.dat_version == 1

def test_axpy(self, d1):
d2 = op2.Dat(d1.dataset)
d1.data[:] = 0
d2.data[:] = 2
d1.axpy(3, d2)
assert (d1.data_ro == 3 * 2).all()

def test_maxpy(self, d1):
d2 = op2.Dat(d1.dataset)
d3 = op2.Dat(d1.dataset)
d1.data[:] = 0
d2.data[:] = 2
d3.data[:] = 3
d1.maxpy((2, 3), (d2, d3))
assert (d1.data_ro == 2 * 2 + 3 * 3).all()


class TestDatView():

Expand Down
Loading