
Commit

fix: is_duplicated was returning wrong-length result for PyArrow and Dask (#1679)



---------

Co-authored-by: FBruzzesi <[email protected]>
Co-authored-by: Francesco Bruzzesi <[email protected]>
3 people authored Dec 30, 2024
1 parent 0a67642 commit e63923a
Showing 4 changed files with 58 additions and 16 deletions.
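
For context, a minimal sketch of the behaviour this commit targets, using the public narwhals API on a PyArrow-backed frame (the table and column here are made up for illustration, not taken from the commit):

# Hypothetical reproduction sketch -- illustrative only.
import pyarrow as pa
import narwhals as nw

# A key column containing a null. Before this fix, the PyArrow backend built
# is_duplicated from an inner join without restoring row order, so the mask
# could come back shorter than, or misaligned with, the original frame.
tbl = pa.table({"a": [1, 1, None]})
df = nw.from_native(tbl, eager_only=True)

result = df.select(nw.col("a").is_duplicated())
print(result.to_native())  # after the fix: [True, True, False], one value per input row
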
20 changes: 14 additions & 6 deletions narwhals/_arrow/dataframe.py
@@ -602,10 +602,15 @@ def is_duplicated(self: Self) -> ArrowSeries:

         from narwhals._arrow.series import ArrowSeries
 
-        df = self._native_frame
-
         columns = self.columns
-        col_token = generate_temporary_column_name(n_bytes=8, columns=columns)
+        index_token = generate_temporary_column_name(n_bytes=8, columns=columns)
+        col_token = generate_temporary_column_name(
+            n_bytes=8,
+            columns=[*columns, index_token],
+        )
+
+        df = self.with_row_index(index_token)._native_frame
+
         row_count = (
             df.append_column(col_token, pa.repeat(pa.scalar(1), len(self)))
             .group_by(columns)
@@ -616,17 +621,20 @@ def is_duplicated(self: Self) -> ArrowSeries:
                 row_count,
                 keys=columns,
                 right_keys=columns,
-                join_type="inner",
+                join_type="left outer",
                 use_threads=False,
-            ).column(f"{col_token}_sum"),
+            )
+            .sort_by(index_token)
+            .column(f"{col_token}_sum"),
             1,
         )
-        return ArrowSeries(
+        res = ArrowSeries(
             is_duplicated,
             name="",
             backend_version=self._backend_version,
             version=self._version,
         )
+        return res.fill_null(res.null_count() > 1, strategy=None, limit=None)
 
     def is_unique(self: Self) -> ArrowSeries:
         import pyarrow.compute as pc
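
A note on the PyArrow change above: is_duplicated is computed by summing a constant column per group of key columns and joining those per-group counts back onto the frame. The previous inner join silently dropped rows whose keys contain nulls and gave no guarantee about row order; the fix tags each row with a temporary index, switches to a left outer join, sorts by that index to restore the original order, and then fills the null counts that unmatched (null-keyed) rows receive. Below is a standalone sketch of that idea in plain PyArrow; the data and the __count/__index names are made up, and this is an illustration rather than the library code:

import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"a": [1, 1, None]})
keys = ["a"]

# Count occurrences of each key combination.
counts = (
    tbl.append_column("__count", pa.repeat(pa.scalar(1), len(tbl)))
    .group_by(keys)
    .aggregate([("__count", "sum")])
)

# Inner join: the null-keyed row finds no match and is dropped, and the join
# does not preserve input order -- hence the wrong-length result being fixed.
inner = tbl.join(counts, keys=keys, join_type="inner", use_threads=False)
print(len(inner), len(tbl))  # can differ when the keys contain nulls

# The fixed approach: remember each row's position, keep unmatched rows with a
# left outer join, restore the order, then handle the null counts explicitly.
indexed = tbl.append_column("__index", pa.array(range(len(tbl))))
joined = indexed.join(
    counts, keys=keys, join_type="left outer", use_threads=False
).sort_by("__index")
mask = pc.greater(joined.column("__count_sum"), 1)
print(pc.fill_null(mask, False))  # [True, True, False], aligned with tbl
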
10 changes: 8 additions & 2 deletions narwhals/_dask/expr.py
@@ -768,7 +768,10 @@ def is_duplicated(self: Self) -> Self:
         def func(_input: dask_expr.Series) -> dask_expr.Series:
             _name = _input.name
             return (
-                _input.to_frame().groupby(_name).transform("size", meta=(_name, int)) > 1
+                _input.to_frame()
+                .groupby(_name, dropna=False)
+                .transform("size", meta=(_name, int))
+                > 1
             )
 
         return self._from_call(
@@ -781,7 +784,10 @@ def is_unique(self: Self) -> Self:
         def func(_input: dask_expr.Series) -> dask_expr.Series:
             _name = _input.name
             return (
-                _input.to_frame().groupby(_name).transform("size", meta=(_name, int)) == 1
+                _input.to_frame()
+                .groupby(_name, dropna=False)
+                .transform("size", meta=(_name, int))
+                == 1
             )
 
         return self._from_call(
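
A note on the Dask change above: both expressions count group sizes with a groupby-transform, and pandas-style groupby drops null keys by default, so rows holding nulls either lose their count or fall out of the result entirely (the wrong-length behaviour named in the commit title). Passing dropna=False keeps the null group. A small pandas illustration of the underlying semantics (the data is made up; this is a sketch, not the Dask code path):

import pandas as pd

s = pd.Series([1, None, None], name="a")

# Default dropna=True: null keys belong to no group, so those rows get NaN
# instead of a real count and the size-based checks cannot be trusted for them.
print(s.groupby(s).transform("size").tolist())                # [1.0, nan, nan]

# dropna=False keeps a group for the nulls, so duplicated nulls are counted.
print(s.groupby(s, dropna=False).transform("size").tolist())  # [1, 2, 2]
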
12 changes: 10 additions & 2 deletions tests/expr_and_series/is_duplicated_test.py
@@ -5,17 +5,25 @@
 from tests.utils import ConstructorEager
 from tests.utils import assert_equal_data
 
-data = {"a": [1, 1, 2], "b": [1, 2, 3], "index": [0, 1, 2]}
-
 
 def test_is_duplicated_expr(constructor: Constructor) -> None:
+    data = {"a": [1, 1, 2], "b": [1, 2, 3], "index": [0, 1, 2]}
     df = nw.from_native(constructor(data))
     result = df.select(nw.col("a", "b").is_duplicated(), "index").sort("index")
     expected = {"a": [True, True, False], "b": [False, False, False], "index": [0, 1, 2]}
     assert_equal_data(result, expected)
 
 
+def test_is_duplicated_w_nulls_expr(constructor: Constructor) -> None:
+    data = {"a": [1, 1, None], "b": [1, None, None], "index": [0, 1, 2]}
+    df = nw.from_native(constructor(data))
+    result = df.select(nw.col("a", "b").is_duplicated(), "index").sort("index")
+    expected = {"a": [True, True, False], "b": [False, True, True], "index": [0, 1, 2]}
+    assert_equal_data(result, expected)
+
+
 def test_is_duplicated_series(constructor_eager: ConstructorEager) -> None:
+    data = {"a": [1, 1, 2], "b": [1, 2, 3], "index": [0, 1, 2]}
     series = nw.from_native(constructor_eager(data), eager_only=True)["a"]
     result = series.is_duplicated()
     expected = {"a": [True, True, False]}
32 changes: 26 additions & 6 deletions tests/expr_and_series/is_unique_test.py
@@ -5,14 +5,13 @@
 from tests.utils import ConstructorEager
 from tests.utils import assert_equal_data
 
-data = {
-    "a": [1, 1, 2],
-    "b": [1, 2, 3],
-    "index": [0, 1, 2],
-}
-
 
 def test_is_unique_expr(constructor: Constructor) -> None:
+    data = {
+        "a": [1, 1, 2],
+        "b": [1, 2, 3],
+        "index": [0, 1, 2],
+    }
     df = nw.from_native(constructor(data))
     result = df.select(nw.col("a", "b").is_unique(), "index").sort("index")
     expected = {
@@ -23,7 +22,28 @@ def test_is_unique_expr(constructor: Constructor) -> None:
     assert_equal_data(result, expected)
 
 
+def test_is_unique_w_nulls_expr(constructor: Constructor) -> None:
+    data = {
+        "a": [None, 1, 2],
+        "b": [None, 2, None],
+        "index": [0, 1, 2],
+    }
+    df = nw.from_native(constructor(data))
+    result = df.select(nw.col("a", "b").is_unique(), "index").sort("index")
+    expected = {
+        "a": [True, True, True],
+        "b": [False, True, False],
+        "index": [0, 1, 2],
+    }
+    assert_equal_data(result, expected)
+
+
 def test_is_unique_series(constructor_eager: ConstructorEager) -> None:
+    data = {
+        "a": [1, 1, 2],
+        "b": [1, 2, 3],
+        "index": [0, 1, 2],
+    }
     series = nw.from_native(constructor_eager(data), eager_only=True)["a"]
     result = series.is_unique()
     expected = {
