diff --git a/africanus/experimental/rime/fused/terms/cube_dde.py b/africanus/experimental/rime/fused/terms/cube_dde.py index aaa87e8c..8bc86d33 100644 --- a/africanus/experimental/rime/fused/terms/cube_dde.py +++ b/africanus/experimental/rime/fused/terms/cube_dde.py @@ -174,6 +174,9 @@ def beam( (nsrc, ntime, nfeed, nantenna, nchan, ncorr), beam.dtype ) + corr_sum = np.zeros(ncorr, beam.dtype) + absc_sum = np.zeros(ncorr, beam.real.dtype) + for s in range(nsrc): l = lm[s, 0] # noqa m = lm[s, 1] @@ -231,114 +234,68 @@ def beam( ld = vl - gl0 md = vm - gm0 - corr_sum = zero_vis(beam.dtype.type(0)) - absc_sum = zero_vis(beam.real.dtype.type(0)) + # Zero the accumulators + for co in range(ncorr): + absc_sum[co] = 0 + corr_sum[co] = 0 # Lower cube weight = (1.0 - ld) * (1.0 - md) * nud for co in range(ncorr): value = beam[gl0, gm0, gc0, co] - absc_sum = tuple_setitem( - absc_sum, - co, - weight * np.abs(value) + absc_sum[co], - ) - corr_sum = tuple_setitem( - corr_sum, co, weight * value + corr_sum[co] - ) + absc_sum[co] += weight * np.abs(value) + corr_sum[co] += weight * value weight = ld * (1.0 - md) * nud for co in range(ncorr): value = beam[gl1, gm0, gc0, co] - absc_sum = tuple_setitem( - absc_sum, - co, - weight * np.abs(value) + absc_sum[co], - ) - corr_sum = tuple_setitem( - corr_sum, co, weight * value + corr_sum[co] - ) + absc_sum[co] += weight * np.abs(value) + corr_sum[co] += weight * value weight = (1.0 - ld) * md * nud for co in range(ncorr): value = beam[gl0, gm1, gc0, co] - absc_sum = tuple_setitem( - absc_sum, - co, - weight * np.abs(value) + absc_sum[co], - ) - corr_sum = tuple_setitem( - corr_sum, co, weight * value + corr_sum[co] - ) + absc_sum[co] += weight * np.abs(value) + corr_sum[co] += weight * value weight = ld * md * nud for co in range(ncorr): value = beam[gl1, gm1, gc0, co] - absc_sum = tuple_setitem( - absc_sum, - co, - weight * np.abs(value) + absc_sum[co], - ) - corr_sum = tuple_setitem( - corr_sum, co, weight * value + corr_sum[co] - ) + absc_sum[co] += weight * np.abs(value) + corr_sum[co] += weight * value # Upper cube weight = (1.0 - ld) * (1.0 - md) * inv_nud for co in range(ncorr): value = beam[gl0, gm0, gc1, co] - absc_sum = tuple_setitem( - absc_sum, - co, - weight * np.abs(value) + absc_sum[co], - ) - corr_sum = tuple_setitem( - corr_sum, co, weight * value + corr_sum[co] - ) + absc_sum[co] += weight * np.abs(value) + corr_sum[co] += weight * value weight = ld * (1.0 - md) * inv_nud for co in range(ncorr): value = beam[gl1, gm0, gc1, co] - absc_sum = tuple_setitem( - absc_sum, - co, - weight * np.abs(value) + absc_sum[co], - ) - corr_sum = tuple_setitem( - corr_sum, co, weight * value + corr_sum[co] - ) + absc_sum[co] += weight * np.abs(value) + corr_sum[co] += weight * value weight = (1.0 - ld) * md * inv_nud for co in range(ncorr): value = beam[gl0, gm1, gc1, co] - absc_sum = tuple_setitem( - absc_sum, - co, - weight * np.abs(value) + absc_sum[co], - ) - corr_sum = tuple_setitem( - corr_sum, co, weight * value + corr_sum[co] - ) + absc_sum[co] += weight * np.abs(value) + corr_sum[co] += weight * value weight = ld * md * inv_nud for co in range(ncorr): value = beam[gl1, gm1, gc1, co] - absc_sum = tuple_setitem( - absc_sum, - co, - weight * np.abs(value) + absc_sum[co], - ) - corr_sum = tuple_setitem( - corr_sum, co, weight * value + corr_sum[co] - ) + absc_sum[co] += weight * np.abs(value) + corr_sum[co] += weight * value for co in range(ncorr): div = np.abs(corr_sum[co]) @@ -347,10 +304,7 @@ def beam( if div != 0.0: value /= div - corr_sum = tuple_setitem(corr_sum, co, value) - - for co in range(ncorr): - sampled_beam[s, t, f, a, c, co] = corr_sum[co] + sampled_beam[s, t, f, a, c, co] = value return sampled_beam