Skip to content

Commit

Permalink
Add support for maintain_order param in joins
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Jan 8, 2025
1 parent 807de8f commit 6f1741f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 29 deletions.
78 changes: 54 additions & 24 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,9 +1139,9 @@ def __init__(
self.options = options
self.children = (left, right)
self._non_child_args = (self.left_on, self.right_on, self.options)
# TODO: Implement maintain_order
if options[5] != "none":
raise NotImplementedError("maintain_order not implemented yet")
# # TODO: Implement maintain_order
# if options[5] != "none":
# raise NotImplementedError("maintain_order not implemented yet")
if any(
isinstance(e.value, expr.Literal)
for e in itertools.chain(self.left_on, self.right_on)
Expand Down Expand Up @@ -1195,6 +1195,7 @@ def _reorder_maps(
right_rows: int,
rg: plc.Column,
right_policy: plc.copying.OutOfBoundsPolicy,
maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
) -> list[plc.Column]:
"""
Reorder gather maps to satisfy polars join order restrictions.
Expand All @@ -1213,31 +1214,55 @@ def _reorder_maps(
Right gather map
right_policy
Nullify policy for right map
maintain_order
Which DataFrame row order to preserve
Returns
-------
list of reordered left and right gather maps.
Notes
-----
For a left join, the polars result preserves the order of the
left keys, and is stable wrt the right keys. For all other
joins, there is no order obligation.
"""
dt = plc.interop.to_arrow(plc.types.SIZE_TYPE)
init = plc.interop.from_arrow(pa.scalar(0, type=dt))
step = plc.interop.from_arrow(pa.scalar(1, type=dt))
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
)
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy
)

if maintain_order in {"none", "left_right", "right_left"}:
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]),
lg,
left_policy,
)
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]),
rg,
right_policy,
)
elif maintain_order == "left":
left_order = plc.copying.gather(
plc.Table([plc.filling.sequence(left_rows, init, step)]),
lg,
left_policy,
)
elif maintain_order == "right":
right_order = plc.copying.gather(
plc.Table([plc.filling.sequence(right_rows, init, step)]),
rg,
right_policy,
)
if maintain_order == "left":
sort_keys = left_order.columns()
elif maintain_order == "right":
sort_keys = right_order.columns()
elif maintain_order in {"none", "left_right"}:
sort_keys = left_order.columns() + right_order.columns()
elif maintain_order == "right_left":
sort_keys = right_order.columns() + left_order.columns()
else:
sort_keys = []
return plc.sorting.stable_sort_by_key(
plc.Table([lg, rg]),
plc.Table([*left_order.columns(), *right_order.columns()]),
[plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
[plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
plc.Table(sort_keys),
[plc.types.Order.ASCENDING] * len(sort_keys),
[plc.types.NullOrder.AFTER] * len(sort_keys),
).columns()

@classmethod
Expand All @@ -1257,7 +1282,7 @@ def do_evaluate(
right: DataFrame,
) -> DataFrame:
"""Evaluate and return a dataframe."""
how, join_nulls, zlice, suffix, coalesce, _ = options
how, join_nulls, zlice, suffix, coalesce, maintain_order = options
if how == "cross":
# Separate implementation, since cross_join returns the
# result, not the gather maps
Expand Down Expand Up @@ -1300,11 +1325,16 @@ def do_evaluate(
left, right = right, left
left_on, right_on = right_on, left_on
lg, rg = join_fn(left_on.table, right_on.table, null_equality)
if how == "left" or how == "right":
# Order of left table is preserved
lg, rg = cls._reorder_maps(
left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
)
# Reorder maps based on maintain_order
lg, rg = cls._reorder_maps(
left.num_rows,
lg,
left_policy,
right.num_rows,
rg,
right_policy,
maintain_order,
)
if coalesce and how == "inner":
right = right.discard_columns(right_on.column_names_set)
left = DataFrame.from_table(
Expand Down
18 changes: 13 additions & 5 deletions python/cudf_polars/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,20 @@ def right():


@pytest.mark.parametrize(
"maintain_order", ["left", "left_right", "right_left", "right"]
"join_expr",
[
pl.col("a"),
["c", "a"],
],
)
def test_join_maintain_order_param_unsupported(left, right, maintain_order):
q = left.join(right, on=pl.col("a"), how="inner", maintain_order=maintain_order)

assert_ir_translation_raises(q, NotImplementedError)
@pytest.mark.parametrize(
"maintain_order", ["none", "left", "right", "left_right", "right_left"]
)
def test_join_preserving_different_orderings(
left, right, how, join_expr, maintain_order
):
query = left.join(right, on=join_expr, how=how, maintain_order=maintain_order)
assert_gpu_result_equal(query)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 6f1741f

Please sign in to comment.