Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for maintain_order param in joins #17698

Open
wants to merge 21 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 48 additions & 24 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,9 +1139,6 @@ def __init__(
self.options = options
self.children = (left, right)
self._non_child_args = (self.left_on, self.right_on, self.options)
# TODO: Implement maintain_order
if options[5] != "none":
raise NotImplementedError("maintain_order not implemented yet")
if any(
isinstance(e.value, expr.Literal)
for e in itertools.chain(self.left_on, self.right_on)
Expand Down Expand Up @@ -1195,6 +1192,7 @@ def _reorder_maps(
right_rows: int,
rg: plc.Column,
right_policy: plc.copying.OutOfBoundsPolicy,
maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
maintain_order: Literal["left", "right", "left_right", "right_left"],

Or accept "none" but just return the input maps immediately.

) -> list[plc.Column]:
"""
Reorder gather maps to satisfy polars join order restrictions.
Expand All @@ -1213,31 +1211,52 @@ def _reorder_maps(
Right gather map
right_policy
Nullify policy for right map
maintain_order
Which DataFrame row order to preserve

Returns
-------
list of reordered left and right gather maps.

Notes
-----
For a left join, the polars result preserves the order of the
left keys, and is stable wrt the right keys. For all other
joins, there is no order obligation.
"""
dt = plc.interop.to_arrow(plc.types.SIZE_TYPE)
init = plc.interop.from_arrow(pa.scalar(0, type=dt))
step = plc.interop.from_arrow(pa.scalar(1, type=dt))
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
)
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy
)
if maintain_order in {"none", "left_right", "right_left"}:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue/question: If we have no obligation maintain_order == "none" I think we should not be doing any work, what is happening here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you're correct, I'll also need to update the other tests in the suite since polars defaults to "none" . So I'll add maintain_order="left" to ensure those tests are reproducible.

left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]),
lg,
left_policy,
)
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]),
rg,
right_policy,
)
elif maintain_order == "left":
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]),
lg,
left_policy,
)
elif maintain_order == "right":
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]),
rg,
right_policy,
)
if maintain_order == "left":
sort_keys = left_order.columns()
elif maintain_order == "right":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the reviewer: This PR needs more work, but I'm opening it up for review so I can get some help handling a special case: full right joins. Specifically, the case where the test fails is when there are unmatched keys in the left dataframe. Any advice on how to handle this?

Example:

left = pl.LazyFrame(
    {
        "a": [1, 2, 3, 1, None],
        "b": [1, 2, 3, 4, 5],
        "c": [2, 3, 4, 5, 6],
    }
)
right = pl.LazyFrame(
    {
        "a": [1, 4, 3, 7, None, None, 1],
        "c": [2, 3, 4, 5, 6, 7, 8],
        "d": [6, None, 7, 8, -1, 2, 4],
    }
)
q = left.join(right, on=pl.col("a"), how="full", maintain_order="right")
q.collect(engine="gpu")

The dataframe differ at column "a"

AssertionError: DataFrames are different (value mismatch for column 'a')
[left]:  [1, 1, None, 3, None, None, None, 1, 1, **None, 2]**
[right]: [1, 1, None, 3, None, None, None, 1, 1, **2, None]**

The a=2 entry is unmatched in the right dataframe, so it should be appended to the end of the result, not included with the other matches.

sort_keys = right_order.columns()
elif maintain_order in {"none", "left_right"}:
sort_keys = left_order.columns() + right_order.columns()
elif maintain_order == "right_left":
sort_keys = right_order.columns() + left_order.columns()
Comment on lines +1225 to +1254
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid repetition here by just immediately making the sort_keys list?

return plc.sorting.stable_sort_by_key(
plc.Table([lg, rg]),
plc.Table([*left_order.columns(), *right_order.columns()]),
[plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
[plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
plc.Table(sort_keys),
[plc.types.Order.ASCENDING] * len(sort_keys),
[plc.types.NullOrder.AFTER] * len(sort_keys),
).columns()

@classmethod
Expand All @@ -1257,7 +1276,7 @@ def do_evaluate(
right: DataFrame,
) -> DataFrame:
"""Evaluate and return a dataframe."""
how, join_nulls, zlice, suffix, coalesce, _ = options
how, join_nulls, zlice, suffix, coalesce, maintain_order = options
if how == "cross":
# Separate implementation, since cross_join returns the
# result, not the gather maps
Expand Down Expand Up @@ -1300,11 +1319,16 @@ def do_evaluate(
left, right = right, left
left_on, right_on = right_on, left_on
lg, rg = join_fn(left_on.table, right_on.table, null_equality)
if how == "left" or how == "right":
# Order of left table is preserved
lg, rg = cls._reorder_maps(
left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
)
# Reorder maps based on maintain_order
lg, rg = cls._reorder_maps(
left.num_rows,
lg,
left_policy,
right.num_rows,
rg,
right_policy,
maintain_order,
)
Comment on lines +1322 to +1331
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Reorder maps based on maintain_order
lg, rg = cls._reorder_maps(
left.num_rows,
lg,
left_policy,
right.num_rows,
rg,
right_policy,
maintain_order,
)
if maintain_order != "none":
lg, rg = cls._reorder_maps(
left.num_rows,
lg,
left_policy,
right.num_rows,
rg,
right_policy,
maintain_order,
)

if coalesce and how == "inner":
right = right.discard_columns(right_on.column_names_set)
left = DataFrame.from_table(
Expand Down
16 changes: 12 additions & 4 deletions python/cudf_polars/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,21 @@ def right():
)


@pytest.mark.parametrize(
"join_expr",
[
pl.col("a"),
pl.col("a") * 2,
[pl.col("a"), pl.col("c") + 1],
["c", "a"],
],
)
@pytest.mark.parametrize(
"maintain_order", ["left", "left_right", "right_left", "right"]
)
def test_join_maintain_order_param_unsupported(left, right, maintain_order):
q = left.join(right, on=pl.col("a"), how="inner", maintain_order=maintain_order)

assert_ir_translation_raises(q, NotImplementedError)
def test_order_preserving_joins(left, right, how, join_expr, maintain_order):
query = left.join(right, on=join_expr, how=how, maintain_order=maintain_order)
assert_gpu_result_equal(query)


@pytest.mark.parametrize(
Expand Down
Loading