From c9f10ab18d34e72683e0646ebfc29fcfff3bbd68 Mon Sep 17 00:00:00 2001 From: Doug Davis Date: Fri, 27 Oct 2023 13:53:15 -0500 Subject: [PATCH] feat: add support for `Array` collection broadcasting with `Scalar` collection (#398) --- src/dask_awkward/layers/layers.py | 9 ++++++++- src/dask_awkward/lib/core.py | 5 ++++- tests/test_core.py | 21 +++++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/dask_awkward/layers/layers.py b/src/dask_awkward/layers/layers.py index 83e3058f..909dff55 100644 --- a/src/dask_awkward/layers/layers.py +++ b/src/dask_awkward/layers/layers.py @@ -261,12 +261,13 @@ def mock(self) -> MaterializedLayer: return self name = next(iter(mapping))[0] + npln = len(self.previous_layer_names) # one previous layer name # # this case is used for mocking repartition or slicing where # we maybe have multiple partitions that need to be included # in a task. - if len(self.previous_layer_names) == 1: + if npln == 1: prev_name: str = self.previous_layer_names[0] if (name, 0) in mapping: task = mapping[(name, 0)] @@ -284,6 +285,12 @@ def mock(self) -> MaterializedLayer: return MaterializedLayer({(name, 0): task}) return self + # zero previous layers; this is likely a known scalar. + # + # we just use the existing mapping + elif npln == 0: + return MaterializedLayer({(name, 0): mapping[(name, 0)]}) + # more than one previous_layer_names # # this case is needed for dak.concatenate on axis=0; we need diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py index 51b9e66d..26c5843a 100644 --- a/src/dask_awkward/lib/core.py +++ b/src/dask_awkward/lib/core.py @@ -1562,7 +1562,7 @@ def partitionwise_layer( """ pairs: list[Any] = [] - numblocks: dict[Any, int | tuple[int, ...]] = {} + numblocks: dict[str, tuple[int, ...]] = {} for arg in args: if isinstance(arg, Array): pairs.extend([arg.name, "i"]) @@ -1575,6 +1575,9 @@ def partitionwise_layer( elif is_arraylike(arg) and is_dask_collection(arg) and arg.ndim == 1: pairs.extend([arg.name, "i"]) numblocks[arg.name] = arg.numblocks + elif isinstance(arg, Scalar): + pairs.extend([arg.name, "i"]) + numblocks[arg.name] = (1,) elif is_dask_collection(arg): raise DaskAwkwardNotImplemented( "Use of Array with other Dask collections is currently unsupported." diff --git a/tests/test_core.py b/tests/test_core.py index f77cdb78..49a8dd25 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -240,6 +240,27 @@ def test_scalar_unary_ops(op: Callable, daa: Array, caa: ak.Array) -> None: assert_eq(op(-a1), op(-a2)) +@pytest.mark.parametrize("op", [operator.add, operator.sub, operator.pow]) +def test_array_broadcast_scalar(op: Callable, daa: Array, caa: Array) -> None: + s1 = new_known_scalar(3) + s2 = 3 + r1 = op(daa.points.x, s1) + r2 = op(caa.points.x, s2) + assert_eq(r1, r2) + + s3 = dak.min(daa.points.x, axis=None) + s4 = ak.min(caa.points.x, axis=None) + r3 = op(daa.points.y, s3) + r4 = op(caa.points.y, s4) + assert_eq(r3, r4) + + s5 = dak.max(daa.points.y, axis=None) + s6 = ak.max(caa.points.y, axis=None) + r5 = op(daa.points.x, s5) + r6 = op(caa.points.x, s6) + assert_eq(r5, r6) + + @pytest.mark.parametrize( "where", [