diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index fd56329a48e..6d5684edbc3 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -1139,9 +1139,9 @@ def __init__( self.options = options self.children = (left, right) self._non_child_args = (self.left_on, self.right_on, self.options) - # TODO: Implement maintain_order - if options[5] != "none": - raise NotImplementedError("maintain_order not implemented yet") + # # TODO: Implement maintain_order + # if options[5] != "none": + # raise NotImplementedError("maintain_order not implemented yet") if any( isinstance(e.value, expr.Literal) for e in itertools.chain(self.left_on, self.right_on) @@ -1195,6 +1195,7 @@ def _reorder_maps( right_rows: int, rg: plc.Column, right_policy: plc.copying.OutOfBoundsPolicy, + maintain_order: Literal["none", "left", "right", "left_right", "right_left"], ) -> list[plc.Column]: """ Reorder gather maps to satisfy polars join order restrictions. @@ -1213,31 +1214,55 @@ def _reorder_maps( Right gather map right_policy Nullify policy for right map + maintain_order + Which DataFrame row order to preserve Returns ------- list of reordered left and right gather maps. - - Notes - ----- - For a left join, the polars result preserves the order of the - left keys, and is stable wrt the right keys. For all other - joins, there is no order obligation. """ dt = plc.interop.to_arrow(plc.types.SIZE_TYPE) init = plc.interop.from_arrow(pa.scalar(0, type=dt)) step = plc.interop.from_arrow(pa.scalar(1, type=dt)) - left_order = plc.copying.gather( - plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy - ) - right_order = plc.copying.gather( - plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy - ) + + if maintain_order in {"none", "left_right", "right_left"}: + left_order = plc.copying.gather( + plc.Table([plc.filling.sequence(left_rows, init, step)]), + lg, + left_policy, + ) + right_order = plc.copying.gather( + plc.Table([plc.filling.sequence(right_rows, init, step)]), + rg, + right_policy, + ) + elif maintain_order == "left": + left_order = plc.copying.gather( + plc.Table([plc.filling.sequence(left_rows, init, step)]), + lg, + left_policy, + ) + elif maintain_order == "right": + right_order = plc.copying.gather( + plc.Table([plc.filling.sequence(right_rows, init, step)]), + rg, + right_policy, + ) + if maintain_order == "left": + sort_keys = left_order.columns() + elif maintain_order == "right": + sort_keys = right_order.columns() + elif maintain_order in {"none", "left_right"}: + sort_keys = left_order.columns() + right_order.columns() + elif maintain_order == "right_left": + sort_keys = right_order.columns() + left_order.columns() + else: + sort_keys = [] return plc.sorting.stable_sort_by_key( plc.Table([lg, rg]), - plc.Table([*left_order.columns(), *right_order.columns()]), - [plc.types.Order.ASCENDING, plc.types.Order.ASCENDING], - [plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER], + plc.Table(sort_keys), + [plc.types.Order.ASCENDING] * len(sort_keys), + [plc.types.NullOrder.AFTER] * len(sort_keys), ).columns() @classmethod @@ -1257,7 +1282,7 @@ def do_evaluate( right: DataFrame, ) -> DataFrame: """Evaluate and return a dataframe.""" - how, join_nulls, zlice, suffix, coalesce, _ = options + how, join_nulls, zlice, suffix, coalesce, maintain_order = options if how == "cross": # Separate implementation, since cross_join returns the # result, not the gather maps @@ -1300,11 +1325,16 @@ def do_evaluate( left, right = right, left left_on, right_on = right_on, left_on lg, rg = join_fn(left_on.table, right_on.table, null_equality) - if how == "left" or how == "right": - # Order of left table is preserved - lg, rg = cls._reorder_maps( - left.num_rows, lg, left_policy, right.num_rows, rg, right_policy - ) + # Reorder maps based on maintain_order + lg, rg = cls._reorder_maps( + left.num_rows, + lg, + left_policy, + right.num_rows, + rg, + right_policy, + maintain_order, + ) if coalesce and how == "inner": right = right.discard_columns(right_on.column_names_set) left = DataFrame.from_table( diff --git a/python/cudf_polars/tests/test_join.py b/python/cudf_polars/tests/test_join.py index f1f47bfb9f1..0a08d46525b 100644 --- a/python/cudf_polars/tests/test_join.py +++ b/python/cudf_polars/tests/test_join.py @@ -54,12 +54,20 @@ def right(): @pytest.mark.parametrize( - "maintain_order", ["left", "left_right", "right_left", "right"] + "join_expr", + [ + pl.col("a"), + ["c", "a"], + ], ) -def test_join_maintain_order_param_unsupported(left, right, maintain_order): - q = left.join(right, on=pl.col("a"), how="inner", maintain_order=maintain_order) - - assert_ir_translation_raises(q, NotImplementedError) +@pytest.mark.parametrize( + "maintain_order", ["none", "left", "right", "left_right", "right_left"] +) +def test_join_preserving_different_orderings( + left, right, how, join_expr, maintain_order +): + query = left.join(right, on=join_expr, how=how, maintain_order=maintain_order) + assert_gpu_result_equal(query) @pytest.mark.parametrize(