Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for maintain_order param in joins #17698

Open
wants to merge 28 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
056ca24
[WIP] Upgrade Polars version to 1.17
Matt711 Dec 19, 2024
b952e15
xfail polars tests
Matt711 Dec 19, 2024
0d702fe
clean up
Matt711 Dec 19, 2024
e07cef8
Update all_cuda-118_arch-x86_64.yaml
Matt711 Dec 20, 2024
3005745
Update all_cuda-125_arch-x86_64.yaml
Matt711 Dec 20, 2024
2695735
Update meta.yaml
Matt711 Dec 20, 2024
01007fb
Update dependencies.yaml
Matt711 Dec 20, 2024
e01aa2e
Update pyproject.toml
Matt711 Dec 20, 2024
84830b6
Merge branch 'branch-25.02' into upgrade-polars-version
Matt711 Dec 20, 2024
7238d16
update ir
Matt711 Dec 20, 2024
5df0303
Merge branch 'branch-25.02' into upgrade-polars-version
Matt711 Jan 2, 2025
08b3a83
update copyright
Matt711 Jan 2, 2025
6022db5
Merge branch 'branch-25.02' into upgrade-polars-version
Matt711 Jan 8, 2025
1daa8c6
remove copyright changes
Matt711 Jan 8, 2025
e38d5a1
remove copyright changes
Matt711 Jan 8, 2025
712146b
remove copyright changes
Matt711 Jan 8, 2025
ad44798
Merge branch 'branch-25.02' into upgrade-polars-version
Matt711 Jan 8, 2025
807de8f
add a test
Matt711 Jan 8, 2025
6f1741f
Add support for maintain_order param in joins
Matt711 Jan 8, 2025
6a10590
merge conflict
Matt711 Jan 9, 2025
14e2508
add a test, clean up reorder gather maps
Matt711 Jan 9, 2025
c2a3be3
Merge branch 'branch-25.02' into fea/polars/support-maintain-order
Matt711 Jan 14, 2025
415b894
Merge branch 'branch-25.02' into fea/polars/support-maintain-order
Matt711 Jan 14, 2025
632c1cb
Merge branch 'branch-25.02' into fea/polars/support-maintain-order
Matt711 Jan 15, 2025
f358d58
address review
Matt711 Jan 15, 2025
82440db
clean up debug print statements
Matt711 Jan 15, 2025
f420425
Merge branch 'branch-25.02' into fea/polars/support-maintain-order
Matt711 Jan 15, 2025
6004241
xfail failing polars tests
Matt711 Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 48 additions & 24 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,9 +1139,6 @@ def __init__(
self.options = options
self.children = (left, right)
self._non_child_args = (self.left_on, self.right_on, self.options)
# TODO: Implement maintain_order
if options[5] != "none":
raise NotImplementedError("maintain_order not implemented yet")
if any(
isinstance(e.value, expr.Literal)
for e in itertools.chain(self.left_on, self.right_on)
Expand Down Expand Up @@ -1195,6 +1192,7 @@ def _reorder_maps(
right_rows: int,
rg: plc.Column,
right_policy: plc.copying.OutOfBoundsPolicy,
maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
) -> list[plc.Column]:
"""
Reorder gather maps to satisfy polars join order restrictions.
Expand All @@ -1213,31 +1211,52 @@ def _reorder_maps(
Right gather map
right_policy
Nullify policy for right map
maintain_order
Which DataFrame row order to preserve

Returns
-------
list of reordered left and right gather maps.

Notes
-----
For a left join, the polars result preserves the order of the
left keys, and is stable wrt the right keys. For all other
joins, there is no order obligation.
"""
dt = plc.interop.to_arrow(plc.types.SIZE_TYPE)
init = plc.interop.from_arrow(pa.scalar(0, type=dt))
step = plc.interop.from_arrow(pa.scalar(1, type=dt))
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
)
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy
)
if maintain_order in {"none", "left_right", "right_left"}:
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]),
lg,
left_policy,
)
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]),
rg,
right_policy,
)
elif maintain_order == "left":
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]),
lg,
left_policy,
)
elif maintain_order == "right":
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]),
rg,
right_policy,
)
if maintain_order == "left":
sort_keys = left_order.columns()
elif maintain_order == "right":
Copy link
Contributor Author

@Matt711 Matt711 Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the reviewer: This PR needs more work, but I'm opening it up for review so I can get some help handling a special case: full joins (where we maintain the order of the right table). Specifically, the case where the test fails is when there are unmatched keys in the left dataframe. Any advice on how to handle this?

Example:

left = pl.LazyFrame(
    {
        "a": [1, 2, 3, 1, None],
        "b": [1, 2, 3, 4, 5],
        "c": [2, 3, 4, 5, 6],
    }
)
right = pl.LazyFrame(
    {
        "a": [1, 4, 3, 7, None, None, 1],
        "c": [2, 3, 4, 5, 6, 7, 8],
        "d": [6, None, 7, 8, -1, 2, 4],
    }
)
q = left.join(right, on=pl.col("a"), how="full", maintain_order="right")
q.collect(engine="gpu")

The dataframe differ at column "a"

AssertionError: DataFrames are different (value mismatch for column 'a')
[left]:  [1, 1, None, 3, None, None, None, 1, 1, **None, 2]**
[right]: [1, 1, None, 3, None, None, None, 1, 1, **2, None]**

The a=2 entry is unmatched in the right dataframe, so it should be appended to the end of the result, not included with the other matches.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is expected. Because those last two rows have the same sort key in the right table column, there's no disambiguator to decide which order the left column result comes in.

e.g.

(No GPU engine involved):

In [18]: q = left.join(right, on=pl.col("a"), how="full", maintain_order="right")

In [19]: q.collect(engine="cpu")
Out[19]: 
shape: (11, 2)
┌──────┬─────────┐
│ a    ┆ a_right │
│ ---  ┆ ---     │
│ i64  ┆ i64     │
╞══════╪═════════╡
│ 1    ┆ 1       │
│ 1    ┆ 1       │
│ null ┆ 4       │
│ 3    ┆ 3       │
│ null ┆ 7       │
│ …    ┆ …       │
│ null ┆ null    │
│ 1    ┆ 1       │
│ 1    ┆ 1       │
│ null ┆ null    │
│ 2    ┆ null    │
└──────┴─────────┘

In [20]: q.collect(engine="cpu")
Out[20]: 
shape: (11, 2)
┌──────┬─────────┐
│ a    ┆ a_right │
│ ---  ┆ ---     │
│ i64  ┆ i64     │
╞══════╪═════════╡
│ 1    ┆ 1       │
│ 1    ┆ 1       │
│ null ┆ 4       │
│ 3    ┆ 3       │
│ null ┆ 7       │
│ …    ┆ …       │
│ null ┆ null    │
│ 1    ┆ 1       │
│ 1    ┆ 1       │
│ 2    ┆ null    │
│ null ┆ null    │
└──────┴─────────┘

Notice how the last two rows flip around between the two runs.

sort_keys = right_order.columns()
elif maintain_order in {"none", "left_right"}:
sort_keys = left_order.columns() + right_order.columns()
elif maintain_order == "right_left":
sort_keys = right_order.columns() + left_order.columns()
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
return plc.sorting.stable_sort_by_key(
plc.Table([lg, rg]),
plc.Table([*left_order.columns(), *right_order.columns()]),
[plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
[plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
plc.Table(sort_keys),
[plc.types.Order.ASCENDING] * len(sort_keys),
[plc.types.NullOrder.AFTER] * len(sort_keys),
).columns()

@classmethod
Expand All @@ -1257,7 +1276,7 @@ def do_evaluate(
right: DataFrame,
) -> DataFrame:
"""Evaluate and return a dataframe."""
how, join_nulls, zlice, suffix, coalesce, _ = options
how, join_nulls, zlice, suffix, coalesce, maintain_order = options
if how == "cross":
# Separate implementation, since cross_join returns the
# result, not the gather maps
Expand Down Expand Up @@ -1300,11 +1319,16 @@ def do_evaluate(
left, right = right, left
left_on, right_on = right_on, left_on
lg, rg = join_fn(left_on.table, right_on.table, null_equality)
if how == "left" or how == "right":
# Order of left table is preserved
lg, rg = cls._reorder_maps(
left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
)
# Reorder maps based on maintain_order
lg, rg = cls._reorder_maps(
left.num_rows,
lg,
left_policy,
right.num_rows,
rg,
right_policy,
maintain_order,
)
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
if coalesce and how == "inner":
right = right.discard_columns(right_on.column_names_set)
left = DataFrame.from_table(
Expand Down
16 changes: 12 additions & 4 deletions python/cudf_polars/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,21 @@ def right():
)


@pytest.mark.parametrize(
"join_expr",
[
pl.col("a"),
pl.col("a") * 2,
[pl.col("a"), pl.col("c") + 1],
["c", "a"],
],
)
@pytest.mark.parametrize(
"maintain_order", ["left", "left_right", "right_left", "right"]
)
def test_join_maintain_order_param_unsupported(left, right, maintain_order):
q = left.join(right, on=pl.col("a"), how="inner", maintain_order=maintain_order)

assert_ir_translation_raises(q, NotImplementedError)
def test_order_preserving_joins(left, right, how, join_expr, maintain_order):
query = left.join(right, on=join_expr, how=how, maintain_order=maintain_order)
assert_gpu_result_equal(query)


@pytest.mark.parametrize(
Expand Down
Loading