Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic multi-partition GroupBy support to cuDF-Polars #17503

Open
wants to merge 16 commits into
base: branch-25.04
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Parallel GroupBy Logic."""

from __future__ import annotations

import operator
from functools import reduce
from typing import TYPE_CHECKING, Any

import pylibcudf as plc

from cudf_polars.dsl.expr import Agg, BinOp, Cast, Col, Len, NamedExpr
from cudf_polars.dsl.ir import GroupBy, Select
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node

if TYPE_CHECKING:
from collections.abc import MutableMapping

from cudf_polars.dsl.expr import Expr
from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.parallel import LowerIRTransformer


class GroupByTree(GroupBy):
    """Groupby tree-reduction operation.

    A GroupBy node whose tasks combine already-aggregated partitions
    (see the ``generate_ir_tasks`` registration below) into a single
    output partition via a batched tree reduction.
    """


_GB_AGG_SUPPORTED = ("sum", "count", "mean")


def _single_fallback(
    ir: IR,
    children: tuple[IR],
    partition_info: MutableMapping[IR, PartitionInfo],
    unsupported_agg: Expr | None = None,
):
    """Reconstruct *ir* as a single-partition node.

    Raises ``NotImplementedError`` when any child spans multiple
    partitions, optionally naming the aggregation that blocked the
    multi-partition lowering.
    """
    multi_partition = any(
        partition_info[child].count > 1 for child in children
    )
    if multi_partition:  # pragma: no cover
        msg = f"Class {type(ir)} does not support multiple partitions."
        if unsupported_agg:
            # Swap the trailing period for a pointer at the offending expression.
            msg = msg[:-1] + f" with {unsupported_agg} expression."
        raise NotImplementedError(msg)

    fallback = ir.reconstruct(children)
    partition_info[fallback] = PartitionInfo(count=1)
    return fallback, partition_info


@lower_ir_node.register(GroupBy)
def _(
    ir: GroupBy, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """Lower a GroupBy for multi-partition execution.

    Decomposes ``GroupBy(df, keys, aggs)`` into a partition-wise
    ``GroupBy``, a single-partition ``GroupByTree`` combine step, and a
    final ``Select`` that finishes split aggregations (e.g. mean =
    sum / count). Falls back to a single partition for any aggregation
    outside ``_GB_AGG_SUPPORTED``.
    """
    # Lower children
    children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
    partition_info = reduce(operator.or_, _partition_info)

    if partition_info[children[0]].count == 1:
        # Single partition
        return _single_fallback(ir, children, partition_info)

    # Check that we are grouping on element-wise
    # keys (is this already guaranteed?)
    # NOTE(review): per PR discussion, group keys are believed to be
    # trivially element-wise, so this Col-only check may be droppable.
    for ne in ir.keys:
        if not isinstance(ne.value, Col):  # pragma: no cover
            return _single_fallback(ir, children, partition_info)

    # name_map records, per output column, the temporary column names a
    # split aggregation (currently only "mean") was decomposed into.
    name_map: MutableMapping[str, Any] = {}
    agg_tree: Cast | Agg | None = None
    agg_requests_pwise = []  # Partition-wise requests
    agg_requests_tree = []  # Tree-node requests

    for ne in ir.agg_requests:
        name = ne.name
        agg: Expr = ne.value
        dtype = agg.dtype
        # Strip an outer Cast; the tree stage re-applies the cast below.
        agg = agg.children[0] if isinstance(agg, Cast) else agg
        if isinstance(agg, Len):
            # Per-partition lengths combine by summation.
            agg_requests_pwise.append(ne)
            agg_requests_tree.append(
                NamedExpr(
                    name,
                    Cast(
                        dtype,
                        Agg(dtype, "sum", None, Col(dtype, name)),
                    ),
                )
            )
        elif isinstance(agg, Agg):
            if agg.name not in _GB_AGG_SUPPORTED:
                return _single_fallback(ir, children, partition_info, agg)

            if agg.name in ("sum", "count"):
                # Both partial sums and partial counts combine by "sum".
                agg_requests_pwise.append(ne)
                agg_requests_tree.append(
                    NamedExpr(
                        name,
                        Cast(
                            dtype,
                            Agg(dtype, "sum", agg.options, Col(dtype, name)),
                        ),
                    )
                )
            elif agg.name == "mean":
                # Split mean into sum and count; the final Select divides them.
                name_map[name] = {agg.name: {}}
                for sub in ["sum", "count"]:
                    # Partwise
                    tmp_name = f"{name}__{sub}"
                    name_map[name][agg.name][sub] = tmp_name
                    agg_pwise = Agg(dtype, sub, agg.options, *agg.children)
                    agg_requests_pwise.append(NamedExpr(tmp_name, agg_pwise))
                    # Tree
                    child = Col(dtype, tmp_name)
                    agg_tree = Agg(dtype, "sum", agg.options, child)
                    agg_requests_tree.append(NamedExpr(tmp_name, agg_tree))
        else:
            # Unsupported expression (e.g. BinOp over aggs) -> raise via fallback.
            return _single_fallback(
                ir, children, partition_info, agg
            )  # pragma: no cover

    # Stage 1: partition-wise groupby over each child partition.
    gb_pwise = GroupBy(
        ir.schema,
        ir.keys,
        agg_requests_pwise,
        ir.maintain_order,
        ir.options,
        *children,
    )
    child_count = partition_info[children[0]].count
    partition_info[gb_pwise] = PartitionInfo(count=child_count)

    # Stage 2: tree reduction down to a single partition.
    gb_tree = GroupByTree(
        ir.schema,
        ir.keys,
        agg_requests_tree,
        ir.maintain_order,
        ir.options,
        gb_pwise,
    )
    partition_info[gb_tree] = PartitionInfo(count=1)

    # Stage 3: Select that passes columns through and finishes split
    # aggregations (mean = sum / count) into the original schema.
    schema = ir.schema
    output_exprs = []
    for name, dtype in schema.items():
        agg_mapping = name_map.get(name, None)
        if agg_mapping is None:
            output_exprs.append(NamedExpr(name, Col(dtype, name)))
        elif "mean" in agg_mapping:
            mean_cols = agg_mapping["mean"]
            output_exprs.append(
                NamedExpr(
                    name,
                    BinOp(
                        dtype,
                        plc.binaryop.BinaryOperator.DIV,
                        Col(dtype, mean_cols["sum"]),
                        Col(dtype, mean_cols["count"]),
                    ),
                )
            )
    should_broadcast: bool = False
    new_node = Select(
        schema,
        output_exprs,
        should_broadcast,
        gb_tree,
    )
    partition_info[new_node] = PartitionInfo(count=1)
    return new_node, partition_info


def _tree_node(do_evaluate, batch, *args):
    """Concatenate a batch of partitions and evaluate the node on the result."""
    combined = _concat(batch)
    return do_evaluate(*args, combined)


@generate_ir_tasks.register(GroupByTree)
def _(
    ir: GroupByTree, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
    """Build the task graph for a GroupByTree reduction.

    Child-partition keys are collapsed in batches of ``split_every``
    until a single output task ``(name, 0)`` remains.
    """
    child = ir.children[0]
    input_count = partition_info[child].count
    input_name = get_key_name(child)
    output_name = get_key_name(ir)

    graph: MutableMapping[Any, Any] = {}
    split_every = 32

    # One key per input partition; each pass of the loop reduces the
    # key list by a factor of `split_every`, adding one graph level.
    keys: list[Any] = [(input_name, i) for i in range(input_count)]
    level = 0
    while len(keys) > split_every:
        reduced: list[Any] = []
        for i, start in enumerate(range(0, len(keys), split_every)):
            batch = keys[start : start + split_every]
            task_key = (output_name, level, i)
            graph[task_key] = (
                _tree_node,
                ir.do_evaluate,
                batch,
                *ir._non_child_args,
            )
            reduced.append(task_key)
        level += 1
        keys = reduced

    # Final combine producing the single output partition. The 2-tuple
    # key cannot collide with the 3-tuple intermediate keys above.
    graph[(output_name, 0)] = (
        _tree_node,
        ir.do_evaluate,
        keys,
        *ir._non_child_args,
    )
    return graph
4 changes: 3 additions & 1 deletion python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

import cudf_polars.experimental.groupby
import cudf_polars.experimental.io # noqa: F401
from cudf_polars.dsl.ir import IR, Cache, Projection, Union
from cudf_polars.dsl.ir import IR, Cache, GroupBy, Projection, Union
from cudf_polars.dsl.traversal import CachingVisitor, traversal
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import (
Expand Down Expand Up @@ -243,5 +244,6 @@ def _generate_ir_tasks_pwise(
}


# Partition-wise nodes: one task per partition, no inter-partition
# communication. GroupBy here handles the partition-wise stage that the
# groupby lowering produces; the combine stage is GroupByTree.
generate_ir_tasks.register(GroupBy, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise)
53 changes: 53 additions & 0 deletions python/cudf_polars/tests/experimental/test_groupby.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(scope="module")
def engine():
    """GPU engine configured for multi-partition (dask-experimental) execution."""
    executor_options = {"max_rows_per_partition": 4}
    return pl.GPUEngine(
        raise_on_fail=True,
        executor="dask-experimental",
        executor_options=executor_options,
    )


@pytest.fixture(scope="module")
def df():
    """150-row LazyFrame: numeric column `x`, grouping columns `y` and `z`."""
    data = {
        "x": range(150),
        "y": ["cat", "dog", "fish"] * 50,
        "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 30,
    }
    return pl.LazyFrame(data)


@pytest.mark.parametrize("op", ["sum", "mean", "len"])
@pytest.mark.parametrize("keys", [("y",), ("y", "z")])
def test_groupby(df, engine, op, keys):
    """Whole-frame groupby reductions match CPU results (row order ignored)."""
    grouped = df.group_by(*keys)
    q = getattr(grouped, op)()
    assert_gpu_result_equal(q, engine=engine, check_row_order=False)


@pytest.mark.parametrize("op", ["sum", "mean", "len", "count"])
@pytest.mark.parametrize("keys", [("y",), ("y", "z")])
def test_groupby_agg(df, engine, op, keys):
    """Explicit `.agg(...)` reductions on a single column match CPU results."""
    agg_expr = getattr(pl.col("x"), op)()
    q = df.group_by(*keys).agg(agg_expr)
    assert_gpu_result_equal(q, engine=engine, check_row_order=False)


def test_groupby_raises(df, engine):
    """Unsupported aggregations (median) surface NotImplementedError."""
    q = df.group_by("y").median()
    expected_match = "NotImplementedError"
    with pytest.raises(pl.exceptions.ComputeError, match=expected_match):
        assert_gpu_result_equal(q, engine=engine, check_row_order=False)
Loading