Skip to content

Commit

Permalink
To improve the aggregate_downsample performance by using a new defa…
Browse files Browse the repository at this point in the history
…ult `aggregate_func` (astropy#17574)

Add a new default aggregate_func that can take an array and do `nanmean` on pieces
much more efficiently than looping over the pieces, by replacing `nan` with `0`, using
`np.add.reduceat`, and counting the non-NaN values appropriately when calculating the mean.

---------

Co-authored-by: Clément Robert <[email protected]>
  • Loading branch information
prajwel and neutrinoceros authored Jan 7, 2025
1 parent 654e241 commit 954300c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
26 changes: 24 additions & 2 deletions astropy/timeseries/downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,35 @@
__all__ = ["aggregate_downsample"]


def nanmean_reduceat(data, indices):
    """
    Compute the NaN-ignoring mean of consecutive pieces of ``data``.

    The pieces are defined by ``indices`` in the same way as for
    ``np.add.reduceat``: piece ``i`` is ``data[indices[i]:indices[i + 1]]``
    (the last piece runs to the end of ``data``), and wherever
    ``indices[i] >= indices[i + 1]`` the piece is the single element
    ``data[indices[i]]``. Pieces are taken along the first axis.

    Parameters
    ----------
    data : array-like
        Values to average. The input is never modified.
    indices : array-like of int
        Start positions of the pieces.

    Returns
    -------
    mean : array
        One mean per entry of ``indices`` (per element of the trailing axes
        for multi-dimensional ``data``). A piece that contains only NaN
        values yields NaN.
    """
    # Accept lists/tuples as well as arrays (array subclasses pass through).
    data = np.asanyarray(data)
    mask = np.isnan(data)

    if mask.any():
        # Zero the NaNs on a copy so they do not contribute to the sums and
        # the caller's array stays untouched.
        data = data.copy()
        data[mask] = 0
        # Number of non-NaN values in each piece; same shape as the sums.
        count_data = np.add.reduceat(~mask, indices)
        # An all-NaN piece has count 0; replace it by NaN so the result is
        # NaN (0 / NaN) without a division-by-zero warning.
        count_data = count_data.astype(data.dtype)
        count_data[count_data == 0] = np.nan
    else:
        # Without NaNs the piece sizes follow from the indices alone.
        count_data = np.diff(indices, append=len(data))
        # Where the indices do not increase, reduceat returns one element.
        count_data[count_data <= 0] = 1
        # Counts are per piece (first axis). Add trailing length-1 axes so
        # they divide the sums row-wise for multi-dimensional data instead
        # of broadcasting against the last axis.
        count_data = count_data.reshape((-1,) + (1,) * (data.ndim - 1))

    sum_data = np.add.reduceat(data, indices)
    return sum_data / count_data


def reduceat(array, indices, function):
"""
Manual reduceat functionality for cases where Numpy functions don't have a reduceat.
It will check if the input function has a reduceat and call that if it does.
"""
if len(indices) == 0:
return np.array([])
elif function is nanmean_reduceat:
return np.array(function(array, indices))
elif hasattr(function, "reduceat"):
return np.array(function.reduceat(array, indices))
else:
Expand Down Expand Up @@ -93,7 +115,7 @@ def aggregate_downsample(
parameter will be ignored.
aggregate_func : callable, optional
The function to use for combining points in the same bin. Defaults
to np.nanmean.
to an internal implementation of nanmean.
Returns
-------
Expand Down Expand Up @@ -175,7 +197,7 @@ def aggregate_downsample(
)

if aggregate_func is None:
aggregate_func = np.nanmean
aggregate_func = nanmean_reduceat

# Start and end times of the binned timeseries
bin_start = binned.time_bin_start
Expand Down
30 changes: 29 additions & 1 deletion astropy/timeseries/tests/test_downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

from astropy import units as u
from astropy.time import Time
from astropy.timeseries.downsample import aggregate_downsample, reduceat
from astropy.timeseries.downsample import (
aggregate_downsample,
nanmean_reduceat,
reduceat,
)
from astropy.timeseries.sampled import TimeSeries
from astropy.utils.exceptions import AstropyUserWarning

Expand Down Expand Up @@ -39,6 +43,30 @@ def test_reduceat():
)


def test_nanmean_reduceat():
    """Compare nanmean_reduceat against reduceat driven by np.nanmean."""
    values = np.arange(8)
    # Start positions are unsorted and contain repeats.
    groups = [0, 4, 1, 5, 5, 2, 6, 6, 3, 7]

    # Case 1: integer input without any NaN.
    reference = reduceat(values, groups, np.nanmean)
    candidate = nanmean_reduceat(values, groups)
    assert_equal(reference, candidate)

    # Case 2: float input where every second value is NaN.
    values = values.astype("float")
    values[::2] = np.nan
    with np.testing.suppress_warnings() as quiet:
        # Only the reference computation is run with this warning filtered.
        quiet.filter(RuntimeWarning, "Mean of empty slice")
        reference = reduceat(values, groups, np.nanmean)
    candidate = nanmean_reduceat(values, groups)
    assert_equal(reference, candidate)

    # Case 3: input consisting solely of NaN.
    values[:] = np.nan
    with np.testing.suppress_warnings() as quiet:
        quiet.filter(RuntimeWarning, "Mean of empty slice")
        reference = reduceat(values, groups, np.nanmean)
    candidate = nanmean_reduceat(values, groups)
    assert_equal(reference, candidate)


def test_timeseries_invalid():
with pytest.raises(TypeError, match="time_series should be a TimeSeries"):
aggregate_downsample(None)
Expand Down
1 change: 1 addition & 0 deletions docs/changes/timeseries/17574.perf.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved the ``aggregate_downsample`` performance using a new default ``aggregate_func``.

0 comments on commit 954300c

Please sign in to comment.