Skip to content

Commit

Permalink
feat: add support for Array collection broadcasting with Scalar collection (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis authored Oct 27, 2023
1 parent 816cb30 commit c9f10ab
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/dask_awkward/layers/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,13 @@ def mock(self) -> MaterializedLayer:
return self
name = next(iter(mapping))[0]

npln = len(self.previous_layer_names)
# one previous layer name
#
# this case is used for mocking repartition or slicing where
# we may have multiple partitions that need to be included
# in a task.
if len(self.previous_layer_names) == 1:
if npln == 1:
prev_name: str = self.previous_layer_names[0]
if (name, 0) in mapping:
task = mapping[(name, 0)]
Expand All @@ -284,6 +285,12 @@ def mock(self) -> MaterializedLayer:
return MaterializedLayer({(name, 0): task})
return self

# zero previous layers; this is likely a known scalar.
#
# we just use the existing mapping
elif npln == 0:
return MaterializedLayer({(name, 0): mapping[(name, 0)]})

# more than one previous_layer_names
#
# this case is needed for dak.concatenate on axis=0; we need
Expand Down
5 changes: 4 additions & 1 deletion src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ def partitionwise_layer(
"""
pairs: list[Any] = []
numblocks: dict[Any, int | tuple[int, ...]] = {}
numblocks: dict[str, tuple[int, ...]] = {}
for arg in args:
if isinstance(arg, Array):
pairs.extend([arg.name, "i"])
Expand All @@ -1575,6 +1575,9 @@ def partitionwise_layer(
elif is_arraylike(arg) and is_dask_collection(arg) and arg.ndim == 1:
pairs.extend([arg.name, "i"])
numblocks[arg.name] = arg.numblocks
elif isinstance(arg, Scalar):
pairs.extend([arg.name, "i"])
numblocks[arg.name] = (1,)
elif is_dask_collection(arg):
raise DaskAwkwardNotImplemented(
"Use of Array with other Dask collections is currently unsupported."
Expand Down
21 changes: 21 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,27 @@ def test_scalar_unary_ops(op: Callable, daa: Array, caa: ak.Array) -> None:
assert_eq(op(-a1), op(-a2))


@pytest.mark.parametrize("op", [operator.add, operator.sub, operator.pow])
def test_array_broadcast_scalar(op: Callable, daa: Array, caa: ak.Array) -> None:
    """Check binary ops between an Array collection and a scalar operand.

    Each case applies ``op`` to a field of the lazy ``daa`` with a scalar
    operand and compares the result, via ``assert_eq``, against the same
    operation on the matching field of ``caa`` with the eager scalar.

    Parameters
    ----------
    op
        Binary operator under test (add, sub, or pow).
    daa
        The dask-awkward Array fixture.
    caa
        The eager counterpart of ``daa``; it is passed to ``ak.min`` and
        ``ak.max`` below, hence the ``ak.Array`` annotation (matching the
        other tests in this module).
    """
    # Case 1: scalar built directly from a Python value.
    # NOTE(review): presumably ``new_known_scalar`` returns a Scalar
    # collection with a known value -- confirm against its definition.
    lazy_const = new_known_scalar(3)
    eager_const = 3
    assert_eq(op(daa.points.x, lazy_const), op(caa.points.x, eager_const))

    # Case 2: scalar produced by a full reduction (axis=None) of one field,
    # then combined with a *different* field of the same array.
    lazy_min = dak.min(daa.points.x, axis=None)
    eager_min = ak.min(caa.points.x, axis=None)
    assert_eq(op(daa.points.y, lazy_min), op(caa.points.y, eager_min))

    # Case 3: same as case 2 with the fields swapped and ``max`` as the
    # reduction.
    lazy_max = dak.max(daa.points.y, axis=None)
    eager_max = ak.max(caa.points.y, axis=None)
    assert_eq(op(daa.points.x, lazy_max), op(caa.points.x, eager_max))


@pytest.mark.parametrize(
"where",
[
Expand Down

0 comments on commit c9f10ab

Please sign in to comment.