From 954300c11198c46c5f364b1947534294bca88ca3 Mon Sep 17 00:00:00 2001 From: Prajwel Joseph Date: Tue, 7 Jan 2025 21:25:15 +0530 Subject: [PATCH] To improve the `aggregate_downsample` performance by using a new default `aggregate_func` (#17574) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a new default aggregate_func that can take an array and do `nanmean` on pieces much more efficiently than looping over the pieces by replacing `nan` with `0`, using `np.add.reduceat` and doing counts appropriately for calculating the mean. --------- Co-authored-by: Clément Robert --- astropy/timeseries/downsample.py | 26 ++++++++++++++++-- astropy/timeseries/tests/test_downsample.py | 30 ++++++++++++++++++++- docs/changes/timeseries/17574.perf.rst | 1 + 3 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 docs/changes/timeseries/17574.perf.rst diff --git a/astropy/timeseries/downsample.py b/astropy/timeseries/downsample.py index 4b089f1f31f6..47f21153aa2a 100644 --- a/astropy/timeseries/downsample.py +++ b/astropy/timeseries/downsample.py @@ -13,6 +13,26 @@ __all__ = ["aggregate_downsample"] +def nanmean_reduceat(data, indices): + mask = np.isnan(data) + + if mask.any(): # If there are NaNs + # Create a writeable copy and mask NaNs + data = data.copy() + data[mask] = 0 + count_data = np.add.reduceat(~mask, indices) + # Avoid division by zero warnings + count_data = count_data.astype(data.dtype) + count_data[count_data == 0] = np.nan + else: + # Derive counts from indices + count_data = np.diff(indices, append=len(data)) + count_data[count_data <= 0] = 1 + + sum_data = np.add.reduceat(data, indices) + return sum_data / count_data + + def reduceat(array, indices, function): """ Manual reduceat functionality for cases where Numpy functions don't have a reduceat.
@@ -20,6 +40,8 @@ def reduceat(array, indices, function): """ if len(indices) == 0: return np.array([]) + elif function is nanmean_reduceat: + return np.array(function(array, indices)) elif hasattr(function, "reduceat"): return np.array(function.reduceat(array, indices)) else: @@ -93,7 +115,7 @@ def aggregate_downsample( parameter will be ignored. aggregate_func : callable, optional The function to use for combining points in the same bin. Defaults - to np.nanmean. + to an internal implementation of nanmean. Returns ------- @@ -175,7 +197,7 @@ def aggregate_downsample( ) if aggregate_func is None: - aggregate_func = np.nanmean + aggregate_func = nanmean_reduceat # Start and end times of the binned timeseries bin_start = binned.time_bin_start diff --git a/astropy/timeseries/tests/test_downsample.py b/astropy/timeseries/tests/test_downsample.py index 63f7d96ee30e..2912ce8114ee 100644 --- a/astropy/timeseries/tests/test_downsample.py +++ b/astropy/timeseries/tests/test_downsample.py @@ -8,7 +8,11 @@ from astropy import units as u from astropy.time import Time -from astropy.timeseries.downsample import aggregate_downsample, reduceat +from astropy.timeseries.downsample import ( + aggregate_downsample, + nanmean_reduceat, + reduceat, +) from astropy.timeseries.sampled import TimeSeries from astropy.utils.exceptions import AstropyUserWarning @@ -39,6 +43,30 @@ def test_reduceat(): ) +def test_nanmean_reduceat(): + data = np.arange(8) + indices = [0, 4, 1, 5, 5, 2, 6, 6, 3, 7] + + reduceat_output1 = reduceat(data, indices, np.nanmean) + nanmean_output1 = nanmean_reduceat(data, indices) + assert_equal(reduceat_output1, nanmean_output1) + + data = data.astype("float") + data[::2] = np.nan + with np.testing.suppress_warnings() as sup: + sup.filter(RuntimeWarning, "Mean of empty slice") + reduceat_output2 = reduceat(data, indices, np.nanmean) + nanmean_output2 = nanmean_reduceat(data, indices) + assert_equal(reduceat_output2, nanmean_output2) + + data[:] = np.nan + with 
np.testing.suppress_warnings() as sup: + sup.filter(RuntimeWarning, "Mean of empty slice") + reduceat_output3 = reduceat(data, indices, np.nanmean) + nanmean_output3 = nanmean_reduceat(data, indices) + assert_equal(reduceat_output3, nanmean_output3) + + def test_timeseries_invalid(): with pytest.raises(TypeError, match="time_series should be a TimeSeries"): aggregate_downsample(None) diff --git a/docs/changes/timeseries/17574.perf.rst b/docs/changes/timeseries/17574.perf.rst new file mode 100644 index 000000000000..051746703972 --- /dev/null +++ b/docs/changes/timeseries/17574.perf.rst @@ -0,0 +1 @@ +Improved the ``aggregate_downsample`` performance using a new default ``aggregate_func``.