diff --git a/HISTORY.rst b/HISTORY.rst
index 9c0dc908f..6ae266cad 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -4,6 +4,7 @@ History
 
 0.2.11 (YYYY-MM-DD)
 -------------------
+* Interpolate missing nominal values during Averaging (:pr:`246`)
 * Baseline-Dependent Time-and-Channel Averaging (:pr:`173`, :pr:`243`)
 
 0.2.10 (2021-02-09)
diff --git a/africanus/averaging/bda_mapping.py b/africanus/averaging/bda_mapping.py
index 34360afa9..5d80839d8 100644
--- a/africanus/averaging/bda_mapping.py
+++ b/africanus/averaging/bda_mapping.py
@@ -168,8 +168,6 @@ def start_bin(self, row, time, interval, flag_row):
         self.rs = row
         self.re = row
         self.bin_count = 1
-        self.time_sum = time[row]
-        self.interval_sum = interval[row]
         self.bin_flag_count = (1 if flag_row is not None
                                and flag_row[row] != 0 else 0)
 
@@ -196,8 +194,6 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row):
             self.re = row
             self.bin_half_Δψ = self.decorrelation
             self.bin_count += 1
-            self.time_sum += time[row]
-            self.interval_sum += interval[row]
 
             if flag_row is not None and flag_row[row] != 0:
                 self.bin_flag_count += 1
@@ -233,8 +229,6 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row):
             self.re = row
             self.bin_half_Δψ = half_𝞓𝞇
             self.bin_count += 1
-            self.time_sum += time[row]
-            self.interval_sum += interval[row]
 
             if flag_row is not None and flag_row[row] != 0:
                 self.bin_flag_count += 1
@@ -245,7 +239,9 @@ def add_row(self, row, auto_corr, time, interval, uvw, flag_row):
     def empty(self):
         return self.bin_count == 0
 
-    def finalise_bin(self, auto_corr, uvw, nchan_factors,
+    def finalise_bin(self, auto_corr,
+                     time, interval,
+                     uvw, nchan_factors,
                      chan_width, chan_freq):
         """ Finalise the contents of this bin """
         if self.bin_count == 0:
@@ -301,10 +297,20 @@ def finalise_bin(self, auto_corr, uvw, nchan_factors,
         s = np.searchsorted(nchan_factors, nchan, side='left')
         nchan = nchan_factors[min(nchan_factors.shape[0] - 1, s)]
 
+        if rs == re:
+            # single value in the bin, re-use time and interval
+            bin_time = time[rs]
+            bin_interval = interval[rs]
+        else:
+            # take the midpoint
+            dt = time[re] - time[rs]
+            bin_time = 0.5*(time[re] + time[rs])
+            bin_interval = 0.5*interval[re] + 0.5*interval[rs] + dt
+
         # Finalise bin values for return
         out = FinaliseOutput(self.tbin,
-                             self.time_sum / self.bin_count,
-                             self.interval_sum,
+                             bin_time,
+                             bin_interval,
                              nchan,
                              self.bin_count == self.bin_flag_count)
 
@@ -487,8 +493,8 @@ def update_lookups(finalised, bl):
             elif not binner.add_row(r, auto_corr,
                                     time, interval,
                                     uvw, flag_row):
-                f = binner.finalise_bin(auto_corr, uvw,
-                                        nchan_factors,
+                f = binner.finalise_bin(auto_corr, time, interval,
+                                        uvw, nchan_factors,
                                         chan_width, chan_freq)
                 update_lookups(f, bl)
                 # Post-finalisation, the bin is empty, start a new bin
@@ -499,9 +505,8 @@ def update_lookups(finalised, bl):
 
         # Finalise any remaining data in the bin
         if not binner.empty:
-            f = binner.finalise_bin(auto_corr, uvw,
-                                    nchan_factors,
-                                    chan_width, chan_freq)
+            f = binner.finalise_bin(auto_corr, time, interval, uvw,
+                                    nchan_factors, chan_width, chan_freq)
             update_lookups(f, bl)
 
         nr_of_time_bins += binner.tbin
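# Annotation, not part of the patch: a minimal sketch of the bin statistics
# introduced above. With rows rs..re in a bin, bin_time is the midpoint of
# the first and last nominal times, and bin_interval spans the bin from the
# leading edge of the first sample to the trailing edge of the last, so a
# missing interior row no longer changes either value. All array values
# below are illustrative assumptions.
import numpy as np

time = np.array([1.0, 2.0, 3.0])
interval = np.array([1.0, 1.0, 1.0])
rs, re = 0, 2   # first and last rows in the bin

bin_time = 0.5*(time[re] + time[rs])
bin_interval = 0.5*interval[re] + 0.5*interval[rs] + (time[re] - time[rs])

assert bin_time == 2.0       # midpoint of the [0.5, 3.5] extent
assert bin_interval == 3.0   # 3.5 - 0.5, unchanged if row 1 were absent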
diff --git a/africanus/averaging/tests/test_mapping.py b/africanus/averaging/tests/test_mapping.py
index 5a8273e31..02c76ebaf 100644
--- a/africanus/averaging/tests/test_mapping.py
+++ b/africanus/averaging/tests/test_mapping.py
@@ -16,9 +16,8 @@ def time():
 
 
 @pytest.fixture
-def interval():
-    data = np.asarray([1.9, 2.0, 2.1, 1.85, 1.95, 2.0, 2.05, 2.1, 2.05, 1.9])
-    return data*0.1
+def interval(time):
+    return np.full_like(time, 1.0)
 
 
 @pytest.fixture
@@ -108,6 +107,68 @@ def test_row_mapper(time, interval, ant1, ant2,
     assert_array_almost_equal(new_exp, new_exp2)
 
 
+@pytest.mark.parametrize("time_bin_secs", [3.0])
+@pytest.mark.parametrize("keep", [
+    [0, 1, 3, 4, 5, 7, 8, 9],
+    [0, 1, 2, 3, 4, 5, 7, 8, 9],
+])
+def test_interpolation(time_bin_secs, keep):
+    time = np.linspace(1.0, 10.0, 10)
+    interval = np.full_like(time, 1.0, time.dtype)
+
+    ant1 = np.full_like(time, 0, np.int32)
+    ant2 = np.full_like(time, 1, np.int32)
+    flag_row = np.full_like(time, 0, np.uint8)
+
+    full = row_mapper(time, interval, ant1, ant2, flag_row, time_bin_secs)
+
+    holes = row_mapper(time[keep], interval[keep],
+                       ant1[keep], ant2[keep],
+                       flag_row[keep], time_bin_secs)
+
+    assert_array_almost_equal(full.time, holes.time)
+    assert_array_almost_equal(full.interval, holes.interval)
+
+
+@pytest.mark.parametrize("time_bin_secs", [3.0])
+def test_interpolation_edge(time_bin_secs):
+    time = np.linspace(1.0, 10.0, 10)
+    interval = np.full_like(time, 1.0, time.dtype)
+
+    ant1 = np.full_like(time, 0, np.int32)
+    ant2 = np.full_like(time, 1, np.int32)
+    flag_row = np.full_like(time, 0, np.uint8)
+
+    # First and last time centroids removed
+    keep = [1, 2, 3, 4, 5, 6, 7, 8]
+    holes = row_mapper(time[keep], interval[keep],
+                       ant1[keep], ant2[keep],
+                       flag_row[keep], time_bin_secs)
+
+    assert_array_almost_equal(holes.time, [3, 6, 8.5])
+    assert_array_almost_equal(holes.interval, [3, 3, 2])
+
+    # First and last time centroids removed as well
+    # as an interval value
+    keep = [1, 2, 3, 4, 5, 6, 8]
+    holes = row_mapper(time[keep], interval[keep],
+                       ant1[keep], ant2[keep],
+                       flag_row[keep], time_bin_secs)
+
+    assert_array_almost_equal(holes.time, [3, 6, 9])
+    assert_array_almost_equal(holes.interval, [3, 3, 1])
+
+    # First and last time centroids removed as well
+    # as an internal value
+    keep = [1, 3, 4, 5, 6, 7, 8]
+    holes = row_mapper(time[keep], interval[keep],
+                       ant1[keep], ant2[keep],
+                       flag_row[keep], time_bin_secs)
+
+    assert_array_almost_equal(holes.time, [3, 6, 8.5])
+    assert_array_almost_equal(holes.interval, [3, 3, 2])
+
+
 def test_channel_mapper():
     chan_map, out_chans = channel_mapper(64, 17)
 
@@ -122,3 +183,33 @@ def test_channel_mapper():
 
     assert_array_equal(counts, [17, 17, 17, 13])
     assert out_chans == 4
+
+
+@pytest.mark.parametrize("time_bin_secs", [3])
+def test_row_mapper2(time_bin_secs):
+    time = np.linspace(1.0, 10.0, 10)
+    interval = np.full_like(time, 1.0)
+
+    min_time_i = time.argmin()
+    max_time_i = time.argmax()
+
+    time_min = time[min_time_i] - interval[min_time_i] / 2
+    time_max = time[max_time_i] + interval[max_time_i] / 2
+    grid = [time_min]
+    next_time = time_min + time_bin_secs
+
+    while next_time < time_max:
+        grid.append(next_time)
+        next_time += time_bin_secs
+
+    grid.append(time_max)
+    grid = np.asarray(grid)
+    print(grid, np.diff(grid))
+
+    for j, (t, i) in enumerate(zip(time, interval)):
+        half_i = i / 2
+        l = np.searchsorted(grid, t - half_i, side="left")  # noqa
+        u = np.searchsorted(grid, t + half_i, side="left")
+        vals = ([((t - half_i, t + half_i), (l, u), (grid[l], grid[u]))] +
+                [time[k] for k in range(l, u)])
+        print(*vals, sep=", ", end="\n")
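# Annotation, not part of the patch: a worked version of the first edge case
# in test_interpolation_edge above, assuming time = [1.0, ..., 10.0] with
# unit intervals and time_bin_secs = 3.0. With rows 0 and 9 missing, the
# first surviving row has time 2.0, so the first bin opens at its leading
# edge and may extend a full time_bin_secs, covering times 2, 3 and 4.
bin_start = 2.0 - 0.5                     # leading edge of the first row
time_bin_secs = 3.0
bin_end = bin_start + time_bin_secs       # 4.5
bin_time = 0.5*(bin_start + bin_end)      # 3.0, as asserted above
bin_interval = bin_end - bin_start        # 3.0, as asserted above
assert (bin_time, bin_interval) == (3.0, 3.0)
# The last bin only reaches the trailing edge of time 9.0, i.e. 9.5, giving
# the asserted midpoint of 8.5 and an interval of 2 rather than 3.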
diff --git a/africanus/averaging/tests/test_time_and_channel_averaging.py b/africanus/averaging/tests/test_time_and_channel_averaging.py
index 9b821232f..428748213 100644
--- a/africanus/averaging/tests/test_time_and_channel_averaging.py
+++ b/africanus/averaging/tests/test_time_and_channel_averaging.py
@@ -46,9 +46,8 @@ def uvw():
 
 
 @pytest.fixture
-def interval():
-    data = np.asarray([1.9, 2.0, 2.1, 1.85, 1.95, 2.0, 2.05, 2.1, 2.05, 1.9])
-    return 0.1 * data
+def interval(time):
+    return np.full_like(time, 1.0)
 
 
 @pytest.fixture
@@ -140,6 +139,37 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs,
     #    data
     # 2. Nominal row bin, which includes both flagged and unflagged rows
 
+    def _can_add_row(high, low):
+        dt = ((time[high] + 0.5*interval[high]) -
+              (time[low] - 0.5*interval[low]))
+
+        if dt > time_bin_secs:
+            return False
+
+        return True
+
+    def _time_avg(nominal_rows):
+        if len(nominal_rows) == 0:
+            raise ValueError("nominal_rows == 0")
+        elif len(nominal_rows) == 1:
+            return time[nominal_rows[0]]
+        else:
+            low = nominal_rows[0]
+            high = nominal_rows[-1]
+            return 0.5*(time[high] + time[low])
+
+    def _int_sum(nominal_rows):
+        if len(nominal_rows) == 0:
+            raise ValueError("nominal_rows == 0")
+        elif len(nominal_rows) == 1:
+            return interval[nominal_rows[0]]
+        else:
+            low = nominal_rows[0]
+            high = nominal_rows[-1]
+
+            return (0.5*interval[high] + 0.5*interval[low] +
+                    (time[high] - time[low]))
+
     for bl, (a1, a2) in enumerate(ubl):
         bl_row_idx = bl_time_lookup[bl, :]
 
@@ -153,13 +183,13 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs,
             if ri == -1:
                 continue
 
-            half_int = 0.5 * interval[ri]
-
             # We're starting a new bin
             if len(nominal_map) == 0:
-                bin_low = time[ri] - half_int
+                rs = ri
+                effective_map = []
+                nominal_map = []
             # Reached passed the endpoint of the bin, start a new one
-            elif time[ri] + half_int - bin_low > time_bin_secs:
+            elif not _can_add_row(ri, rs):
                 if len(effective_map) > 0:
                     effective_bin_map.append(effective_map)
                     nominal_bin_map.append(nominal_map)
@@ -170,6 +200,7 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs,
                 else:
                     raise ValueError("Zero-filled bin")
 
+                rs = ri
                 effective_map = []
                 nominal_map = []
 
@@ -190,13 +221,15 @@ def _gen_testing_lookup(time, interval, ant1, ant2, flag_row, time_bin_secs,
             effective_bin_map.append(nominal_map)
             nominal_bin_map.append(nominal_map)
 
-        # Produce a (avg_time, bl, effective_rows, nominal_rows) tuple
-        time_bl_row_map.extend((time[nrows].mean(), (a1, a2), erows, nrows)
+        # Produce a tuple of the form
+        # (avg_time, bl, interval, effective_rows, nominal_rows)
+        time_bl_row_map.extend((_time_avg(nrows), (a1, a2),
+                                _int_sum(nrows), erows, nrows)
                                for erows, nrows in zip(effective_bin_map,
                                                        nominal_bin_map))
 
-    # Sort lookup sorted on averaged times
-    return sorted(time_bl_row_map, key=lambda tup: tup[0])
+    # Sort the lookup on averaged times and baselines
+    return sorted(time_bl_row_map, key=lambda tup: tup[:2])
 
 
 def _calc_sigma(sigma, weight, idx):
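# Annotation, not part of the patch: _time_avg and _int_sum above reduce a
# bin to its first and last nominal rows, so a dropped interior row leaves
# both values unchanged. A small worked example with assumed values:
import numpy as np

time = np.array([2.0, 3.0, 4.0])
interval = np.array([1.0, 1.0, 1.0])

for rows in ([0, 1, 2], [0, 2]):          # with and without the middle row
    low, high = rows[0], rows[-1]
    time_avg = 0.5*(time[high] + time[low])
    int_sum = (0.5*interval[high] + 0.5*interval[low] +
               (time[high] - time[low]))
    assert (time_avg, int_sum) == (3.0, 3.0)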
@@ -239,19 +272,22 @@ def test_averager(time, ant1, ant2, flagged_rows,
     row_meta = row_mapper(time, interval, ant1, ant2, flag_row, time_bin_secs)
     chan_map, chan_bins = channel_mapper(nchan, chan_bin_size)
 
-    time_bl_row_map = _gen_testing_lookup(time_centroid, exposure, ant1, ant2,
+    time_bl_row_map = _gen_testing_lookup(time, interval, ant1, ant2,
                                           flag_row, time_bin_secs,
                                           row_meta)
 
     # Effective and Nominal rows associated with each output row
-    eff_idx, nom_idx = zip(*[(nrows, erows) for _, _, nrows, erows
+    eff_idx, nom_idx = zip(*[(nrows, erows) for _, _, _, nrows, erows
                              in time_bl_row_map])
 
     eff_idx = [ei for ei in eff_idx if len(ei) > 0]
 
-    # Check that the averaged times from the test and accelerated lookup match
-    assert_array_equal([t for t, _, _, _ in time_bl_row_map],
+    # Check that the times and intervals from the test lookup
+    # match those of the accelerated lookup
+    assert_array_equal([t for t, _, _, _, _ in time_bl_row_map],
                        row_meta.time)
+    assert_array_equal([i for _, _, i, _, _ in time_bl_row_map],
+                       row_meta.interval)
 
     avg = time_and_channel(time, interval, ant1, ant2,
                            flag_row=flag_row,
@@ -266,20 +302,16 @@ def test_averager(time, ant1, ant2, flagged_rows,
 
     # Take mean time, but first ant1 and ant2
     expected_time_centroids = [time_centroid[i].mean(axis=0) for i in eff_idx]
-    expected_times = [time[i].mean(axis=0) for i in nom_idx]
     expected_ant1 = [ant1[i[0]] for i in nom_idx]
     expected_ant2 = [ant2[i[0]] for i in nom_idx]
     expected_flag_row = [flag_row[i].any(axis=0) for i in eff_idx]
 
     # Take mean average, but sum of interval and exposure
     expected_uvw = [uvw[i].mean(axis=0) for i in eff_idx]
-    expected_interval = [interval[i].sum(axis=0) for i in nom_idx]
     expected_exposure = [exposure[i].sum(axis=0) for i in eff_idx]
     expected_weight = [weight[i].sum(axis=0) for i in eff_idx]
     expected_sigma = [_calc_sigma(sigma, weight, i) for i in eff_idx]
 
-    assert_array_equal(row_meta.time, expected_times)
-    assert_array_equal(row_meta.interval, expected_interval)
     assert_array_equal(row_meta.flag_row, expected_flag_row)
     assert_array_equal(avg.antenna1, expected_ant1)
     assert_array_equal(avg.antenna2, expected_ant2)
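# Annotation, not part of the patch: the testing lookup now produces
# (avg_time, baseline, interval, effective_rows, nominal_rows) tuples and
# sorts them on the first two fields. A toy illustration with made-up rows:
entries = [(2.0, (0, 2), 3.0, [3], [3, 4]),
           (2.0, (0, 1), 3.0, [0], [0, 1]),
           (5.0, (0, 1), 3.0, [2], [2])]

# key=lambda tup: tup[:2] orders by averaged time first, then baseline
ordered = sorted(entries, key=lambda tup: tup[:2])
assert [e[1] for e in ordered] == [(0, 1), (0, 2), (0, 1)]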
diff --git a/africanus/averaging/time_and_channel_mapping.py b/africanus/averaging/time_and_channel_mapping.py
index a3f2984a5..d0bdf8306 100644
--- a/africanus/averaging/time_and_channel_mapping.py
+++ b/africanus/averaging/time_and_channel_mapping.py
@@ -2,54 +2,238 @@
 
 from collections import namedtuple
+from numbers import Number
 
 import numpy as np
 import numba
+from numba.experimental import jitclass
 
 from africanus.averaging.support import unique_time, unique_baselines
-from africanus.util.numba import is_numba_type_none, generated_jit, njit, jit
+from africanus.util.numba import is_numba_type_none, generated_jit, njit
 
 
 class RowMapperError(Exception):
     pass
 
 
-def is_flagged_factory(have_flag_row):
-    if have_flag_row:
-        def impl(flag_row, r):
-            return flag_row[r] != 0
+def _numba_type(obj):
+    if isinstance(obj, np.ndarray):
+        return numba.typeof(obj.dtype).dtype
+    elif isinstance(obj, numba.types.npytypes.Array):
+        return obj.dtype
+    elif isinstance(obj, (np.dtype, numba.types.Type)):
+        return numba.typeof(obj).dtype
+    elif isinstance(obj, Number):
+        return numba.typeof(obj)
     else:
-        def impl(flag_row, r):
-            return False
-
-    return njit(nogil=True, cache=True)(impl)
+        raise TypeError(f"Unhandled type {type(obj)}")
 
 
-def output_factory(have_flag_row):
-    if have_flag_row:
-        def impl(rows, flag_row):
-            return np.zeros(rows, dtype=flag_row.dtype)
+def binner_factory(time, interval, antenna1, antenna2,
+                   flag_row, time_bin_secs):
+    if flag_row is None:
+        have_flag_row = False
     else:
-        def impl(rows, flag_row):
-            return None
+        have_flag_row = not is_numba_type_none(flag_row)
+
+    class Binner:
+        def __init__(self, time, interval, antenna1, antenna2,
+                     flag_row, time_bin_secs):
+            ubl, _, bl_inv, _ = unique_baselines(antenna1, antenna2)
+            utime, _, time_inv, _ = unique_time(time)
+
+            ntime = utime.shape[0]
+            nbl = ubl.shape[0]
+            self.bl_inv = bl_inv
+            self.time_inv = time_inv
+            self.out_rows = 0
+            row_lookup = np.full((nbl, ntime), -1, dtype=np.intp)
+
+            # Create a mapping from the full bl x time resolution back
+            # to the original input rows
+            for r, (t, bl) in enumerate(zip(time_inv, bl_inv)):
+                if row_lookup[bl, t] == -1:
+                    row_lookup[bl, t] = r
+                else:
+                    raise ValueError("Duplicate (TIME, ANTENNA1, ANTENNA2) "
+                                     "combinations were discovered in the "
+                                     "input data. This is usually caused by "
+                                     "not partitioning your data sufficiently "
+                                     "by indexing columns, DATA_DESC_ID "
+                                     "and SCAN_NUMBER in particular.")
+
+            sentinel = np.finfo(time.dtype).max
+
+            self.row_lookup = row_lookup
+            self.time_bin_secs = time_bin_secs
+            self.time_lookup = np.full((nbl, ntime), sentinel, time.dtype)
+            self.interval_lookup = np.zeros((nbl, ntime), interval.dtype)
+            self.bin_flagged = np.full((nbl, ntime), False)
+            self.bin_lookup = np.full((nbl, ntime), -1)
+
+            self.time = time
+            self.interval = interval
+
+            if have_flag_row:
+                self.flag_row = flag_row
+
+        def start_baseline(self):
+            self.tbin = 0
+            self.bin_count = 0
+            self.bin_flag_count = 0
+
+        def finalise_baseline(self):
+            self.out_rows += self.tbin
+
+        @property
+        def bin_empty(self):
+            return self.bin_count == 0
+
+        def start_bin(self, bl, row):
+            # Establish the starting time of the bin
+            #
+            # (1) Preferably use the edge of the last bin, as this
+            #     is more accurate
+            # (2) Otherwise, use the first discovered row in the current
+            #     bin; this is less accurate than (1), but is best effort
+            #     when data is missing
+            if self.tbin > 0:
+                last_tbin = self.tbin - 1
+                self.time_start = (self.time_lookup[bl, last_tbin] +
+                                   self.interval_lookup[bl, last_tbin]*0.5)
+            else:
+                self.time_start = self.time[row] - 0.5*self.interval[row]
 
-    return njit(nogil=True, cache=True)(impl)
+            self.rc = row
+            self.bin_count = 1
+            self.bin_flag_count = int(have_flag_row and
+                                      self.flag_row[row] != 0)
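# Annotation, not part of the patch: start_bin prefers the previous bin's
# trailing edge as the new starting time. A toy recreation with assumed
# values, mirroring the time_lookup/interval_lookup entries above:
prev_bin_time, prev_bin_interval = 3.0, 3.0   # previous bin spans [1.5, 4.5]
row_time, row_interval = 6.0, 1.0             # first row of the new bin

# (1) preferred: the previous bin's trailing edge
time_start = prev_bin_time + prev_bin_interval*0.5     # 4.5
# (2) fallback for a baseline's first bin: the row's own leading edge
first_bin_start = row_time - 0.5*row_interval          # 5.5
assert (time_start, first_bin_start) == (4.5, 5.5)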
+        def add_row(self, row):
+            if self.rc == row:
+                raise ValueError("start_bin should be called "
+                                 "to start a bin before add_row "
+                                 "is called.")
 
-def set_flag_row_factory(have_flag_row):
-    if have_flag_row:
-        def impl(flag_row, in_row, out_flag_row, out_row, flagged):
-            if flag_row[in_row] == 0 and flagged:
-                raise RowMapperError("Unflagged input row contributing "
-                                     "to flagged output row. "
-                                     "This should never happen!")
+            dt = ((self.time[row] + 0.5*self.interval[row]) -
+                  self.time_start)
 
-            out_flag_row[out_row] = (1 if flagged else 0)
-    else:
-        def impl(flag_row, in_row, out_flag_row, out_row, flagged):
-            pass
+            if dt > self.time_bin_secs:
+                return False
 
-    return njit(nogil=True, cache=True)(impl)
+            self.rc = row
+            self.bin_count += 1
+            flagged = have_flag_row and self.flag_row[row] != 0
+            self.bin_flag_count += int(flagged)
+
+            return True
+
+        def finalise_bin(self, bl, next_row):
+            rc = self.rc
+
+            # No interpolation required
+            if self.bin_count == 1:
+                bin_time = self.time[rc]
+                bin_interval = self.interval[rc]
+            else:
+                # Interpolate between bin start and end times.
+                #
+                # 1. Where possible, we use the first row of the next bin
+                #    to establish this, as these points determine
+                #    the full bin extent.
+                # 2. Otherwise we must use the last row of the bin.
+                #    This is not as accurate as (1), but is best effort
+                #    in the case of missing edge data.

+                # Find the bin ending time
+                if next_row != rc:
+                    # Use the time and interval of the next row outside
+                    # the bin to establish the end time of the bin
+                    time_end = (self.time[next_row] -
+                                0.5*self.interval[next_row])
+
+                    # But we cannot exceed the prescribed interval
+                    if time_end - self.time_start > self.time_bin_secs:
+                        time_end = self.time_start + self.time_bin_secs
+                else:
+                    # Use the time and interval of the ending row
+                    # to establish the end time of the bin
+                    time_end = self.time[rc] + 0.5*self.interval[rc]
+
+                # Establish the midpoint
+                bin_time = 0.5*(self.time_start + time_end)
+                bin_interval = time_end - self.time_start
+
+            self.time_lookup[bl, self.tbin] = bin_time
+            self.interval_lookup[bl, self.tbin] = bin_interval
+            flagged = self.bin_count == self.bin_flag_count
+            self.bin_flagged[bl, self.tbin] = flagged
+
+            self.tbin += 1
+
+        def execute(self):
+            row_lookup = self.row_lookup
+            bin_lookup = self.bin_lookup
+
+            # Average times over each baseline and construct the
+            # bin_lookup and time_lookup arrays
+            for bl in range(row_lookup.shape[0]):
+                self.start_baseline()
+
+                for t in range(row_lookup.shape[1]):
+                    r = row_lookup[bl, t]
+
+                    if r == -1:
+                        continue
+
+                    if self.bin_empty:
+                        self.start_bin(bl, r)
+                    elif not self.add_row(r):
+                        # Can't add a new row to this bin, close it
+                        # and start a new one
+                        self.finalise_bin(bl, r)
+                        self.start_bin(bl, r)
+
+                    # Register the output time bin for this row
+                    bin_lookup[bl, t] = self.tbin
+
+                # Close any open bins. No next row is available here,
+                # so signal this by passing the bin's own closing row
+                if not self.bin_empty:
+                    self.finalise_bin(bl, self.rc)
+
+                self.finalise_baseline()
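# Annotation, not part of the patch: finalise_bin derives the bin's end time
# from the next row's leading edge, clipped so the bin never exceeds
# time_bin_secs. A small worked example with assumed values:
time_start, time_bin_secs = 1.5, 3.0
next_row_time, next_row_interval = 8.0, 1.0    # first row of the next bin

time_end = next_row_time - 0.5*next_row_interval   # 7.5
if time_end - time_start > time_bin_secs:          # 6.0 > 3.0, so clip
    time_end = time_start + time_bin_secs          # 4.5

bin_time = 0.5*(time_start + time_end)             # 3.0
bin_interval = time_end - time_start               # 3.0
assert (bin_time, bin_interval) == (3.0, 3.0)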
+
+    time = _numba_type(time)
+    interval = _numba_type(interval)
+    antenna1 = _numba_type(antenna1)
+    antenna2 = _numba_type(antenna2)
+    time_bin_secs = _numba_type(time_bin_secs)
+
+    spec = [
+        ('out_rows', numba.uintp),
+        ('rc', numba.intp),
+        ('tbin', numba.intp),
+        ('time_start', time),
+        ('bin_count', numba.uintp),
+        ('bin_flag_count', numba.uintp),
+        ('bl_inv', numba.uintp[:]),
+        ('time_inv', numba.uintp[:])]
+
+    spec.extend([
+        ('time_lookup', time[:, :]),
+        ('interval_lookup', interval[:, :]),
+        ('row_lookup', numba.intp[:, :]),
+        ('bin_lookup', numba.intp[:, :]),
+        ('bin_flagged', numba.bool_[:, :]),
+        ('time_bin_secs', time_bin_secs)])
+
+    spec.extend([
+        ('time', time[:]),
+        ('interval', interval[:])])
+
+    if have_flag_row:
+        flag_row = _numba_type(flag_row)
+        spec.append(('flag_row', flag_row[:]))
+
+    return jitclass(spec)(Binner)
 
 
 RowMapOutput = namedtuple("RowMapOutput",
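# Annotation, not part of the patch: binner_factory compiles one Binner
# jitclass per combination of input types, with the spec derived from the
# arrays via _numba_type. A minimal sketch of the same factory pattern,
# using a hypothetical Accumulator class that is not part of this PR:
import numba
import numpy as np
from numba.experimental import jitclass


def accumulator_factory(values):
    # Derive the numba scalar type from the array's numba type
    dtype = numba.typeof(values).dtype

    class Accumulator:
        def __init__(self):
            self.total = 0

        def add(self, v):
            self.total += v

    return jitclass([('total', dtype)])(Accumulator)


Accumulator = accumulator_factory(np.zeros(10, np.float64))
acc = Accumulator()
acc.add(2.5)
assert acc.total == 2.5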
@@ -179,133 +363,47 @@ def row_mapper(time, interval, antenna1, antenna2,
     """
     have_flag_row = not is_numba_type_none(flag_row)
-    is_flagged_fn = is_flagged_factory(have_flag_row)
+    have_time_bin_secs = not is_numba_type_none(time_bin_secs)
+    time_bin_secs_type = time_bin_secs if have_time_bin_secs else time.dtype
 
-    output_flag_row = output_factory(have_flag_row)
-    set_flag_row = set_flag_row_factory(have_flag_row)
+    JitBinner = binner_factory(time, interval, antenna1, antenna2,
+                               flag_row, time_bin_secs_type)
 
     def impl(time, interval, antenna1, antenna2,
              flag_row=None, time_bin_secs=1):
-        ubl, _, bl_inv, _ = unique_baselines(antenna1, antenna2)
-        utime, _, time_inv, _ = unique_time(time)
-
-        nbl = ubl.shape[0]
-        ntime = utime.shape[0]
-
-        sentinel = np.finfo(time.dtype).max
-        out_rows = numba.uint32(0)
-
-        scratch = np.full(3*nbl*ntime, -1, dtype=np.int32)
-        row_lookup = scratch[:nbl*ntime].reshape(nbl, ntime)
-        bin_lookup = scratch[nbl*ntime:2*nbl*ntime].reshape(nbl, ntime)
-        inv_argsort = scratch[2*nbl*ntime:]
-        time_lookup = np.zeros((nbl, ntime), dtype=time.dtype)
-        interval_lookup = np.zeros((nbl, ntime), dtype=interval.dtype)
-
-        # Is the entire bin flagged?
-        bin_flagged = np.zeros((nbl, ntime), dtype=np.bool_)
-
-        # Create a mapping from the full bl x time resolution back
-        # to the original input rows
-        for r in range(time.shape[0]):
-            bl = bl_inv[r]
-            t = time_inv[r]
-
-            if row_lookup[bl, t] == -1:
-                row_lookup[bl, t] = r
-            else:
-                raise ValueError("Duplicate (TIME, ANTENNA1, ANTENNA2) "
-                                 "combinations were discovered in the input "
-                                 "data. This is usually caused by not "
-                                 "partitioning your data sufficiently "
-                                 "by indexing columns, DATA_DESC_ID "
-                                 "and SCAN_NUMBER in particular.")
-
-        # Average times over each baseline and construct the
-        # bin_lookup and time_lookup arrays
-        for bl in range(ubl.shape[0]):
-            tbin = numba.int32(0)
-            bin_count = numba.int32(0)
-            bin_flag_count = numba.int32(0)
-            bin_low = time.dtype.type(0)
-
-            for t in range(utime.shape[0]):
-                # Lookup input row
-                r = row_lookup[bl, t]
-
-                # Ignore if not present
-                if r == -1:
-                    continue
-
-                # At this point, we decide whether to contribute to
-                # the current bin, or create a new one. We don't add
-                # the current sample to the current bin if
-                # high - low >= time_bin_secs
-                half_int = interval[r] * 0.5
-
-                # We're starting a new bin anyway,
-                # just set the lower bin value
-                if bin_count == 0:
-                    bin_low = time[r] - half_int
-                # If we exceed the seconds in the bin,
-                # normalise the time and start a new bin
-                elif time[r] + half_int - bin_low > time_bin_secs:
-                    # Normalise and flag the bin
-                    # if total counts match flagged counts
-                    if bin_count > 0:
-                        time_lookup[bl, tbin] /= bin_count
-                        bin_flagged[bl, tbin] = bin_count == bin_flag_count
-                    # There was nothing in the bin
-                    else:
-                        time_lookup[bl, tbin] = sentinel
-                        bin_flagged[bl, tbin] = False
-
-                    tbin += 1
-                    bin_count = 0
-                    bin_low = time[r] - half_int
-                    bin_flag_count = 0
-
-                # Record the output bin associated with the row
-                bin_lookup[bl, t] = tbin
-
-                # Time + Interval take unflagged + unflagged
-                # samples into account (nominal value)
-                time_lookup[bl, tbin] += time[r]
-                interval_lookup[bl, tbin] += interval[r]
-                bin_count += 1
-
-                # Record flags
-                if is_flagged_fn(flag_row, r):
-                    bin_flag_count += 1
-
-            # Normalise the last bin if it has entries in it
-            if bin_count > 0:
-                time_lookup[bl, tbin] /= bin_count
-                bin_flagged[bl, tbin] = bin_count == bin_flag_count
-                tbin += 1
-
-            # Add this baseline's number of bins to the output rows
-            out_rows += tbin
-
-            # Set any remaining bins to sentinel value and unflagged
-            for b in range(tbin, ntime):
-                time_lookup[bl, b] = sentinel
-                bin_flagged[bl, b] = False
+        # If we don't have time_bin_secs
+        # set it to the maximum floating point value,
+        # effectively ignoring this limit
+        if not have_time_bin_secs:
+            time_bin_secs = np.finfo(time.dtype).max
+
+        binner = JitBinner(time, interval, antenna1, antenna2,
+                           flag_row, time_bin_secs)
+        binner.execute()
 
         # Flatten the time lookup and argsort it
-        flat_time = time_lookup.ravel()
-        flat_int = interval_lookup.ravel()
+        flat_time = binner.time_lookup.ravel()
+        flat_int = binner.interval_lookup.ravel()
         argsort = np.argsort(flat_time, kind='mergesort')
+        inv_argsort = np.empty_like(argsort)
 
         # Generate lookup from flattened (bl, time) to output row
        for i, a in enumerate(argsort):
             inv_argsort[a] = i
 
         # Construct the final row map
-        row_map = np.empty((time.shape[0]), dtype=np.uint32)
+        row_map = np.empty(time.shape[0], dtype=np.uint32)
+
+        nbl, ntime = binner.row_lookup.shape
+        out_rows = binner.out_rows
+        bin_lookup = binner.bin_lookup
+        bin_flagged = binner.bin_flagged
+        bl_inv = binner.bl_inv
+        time_inv = binner.time_inv
 
         # Construct output flag row, if necessary
-        out_flag_row = output_flag_row(out_rows, flag_row)
+        out_flag_row = (np.zeros(out_rows, dtype=flag_row.dtype)
+                        if have_flag_row else None)
 
         # foreach input row
         for in_row in range(time.shape[0]):
@@ -321,10 +419,14 @@ def impl(time, interval, antenna1, antenna2,
             if out_row >= out_rows:
                 raise RowMapperError("out_row >= out_rows")
 
-            # Handle output row flagging
-            set_flag_row(flag_row, in_row,
-                         out_flag_row, out_row,
-                         bin_flagged[bl, tbin])
+            if have_flag_row:
+                flagged = bin_flagged[bl, tbin]
+                if flag_row[in_row] == 0 and flagged:
+                    raise RowMapperError("Unflagged input row contributing "
+                                         "to flagged output row. "
+                                         "This should never happen!")
+
+                out_flag_row[out_row] = 1 if flagged else 0
 
             row_map[in_row] = out_row
 
@@ -336,7 +438,7 @@ def impl(time, interval, antenna1, antenna2,
     return impl
 
 
-@jit(nopython=True, nogil=True, cache=True)
+@njit(nogil=True, cache=True)
 def channel_mapper(nchan, chan_bin_size=1):
     chan_map = np.empty(nchan, dtype=np.uint32)
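# Annotation, not part of the patch: channel_mapper assigns each of nchan
# input channels to an output bin of at most chan_bin_size channels. A
# pure-Python equivalent of the mapping exercised by test_channel_mapper:
nchan, chan_bin_size = 64, 17
chan_map = [c // chan_bin_size for c in range(nchan)]
out_chans = max(chan_map) + 1

assert out_chans == 4
assert [chan_map.count(b) for b in range(out_chans)] == [17, 17, 17, 13]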