Skip to content

Commit

Permalink
unique
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 20, 2024
1 parent 5bad661 commit a21638c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
3 changes: 3 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def head(self, n: int) -> Self:
def drop(self, *columns: str | Iterable[str]) -> Self:
    """Return a new frame with the given columns dropped.

    The ``columns`` arguments are forwarded unchanged to the wrapped
    dataframe's ``drop`` method, and the result is re-wrapped with
    ``_from_dataframe``.
    """
    remaining = self._dataframe.drop(*columns)
    return self._from_dataframe(remaining)

def unique(self, subset: str | list[str]) -> Self:
    """Return a new frame with duplicate rows removed.

    ``subset`` (a single column name or a list of names) is passed as the
    ``subset`` keyword to the wrapped dataframe's ``unique`` method; the
    result is re-wrapped with ``_from_dataframe``.
    """
    deduplicated = self._dataframe.unique(subset=subset)
    return self._from_dataframe(deduplicated)

def filter(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> Self:
predicates, _ = self._flatten_and_extract(*predicates)
return self._from_dataframe(
Expand Down
3 changes: 2 additions & 1 deletion narwhals/pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def join(
def head(self, n: int) -> Self:
    """Return a new frame holding the result of ``head(n)``.

    ``n`` is forwarded to the wrapped dataframe's ``head`` method and the
    result is re-wrapped with ``_from_dataframe``.
    """
    leading_rows = self._dataframe.head(n)
    return self._from_dataframe(leading_rows)

def unique(self, subset: list[str]) -> Self:
def unique(self, subset: str | list[str]) -> Self:
    """Return a new frame with duplicate rows removed.

    ``subset`` is first passed through ``flatten_str`` (defined elsewhere;
    presumably it normalises a single column name into a list -- confirm),
    then handed to the wrapped dataframe's ``drop_duplicates`` as its
    ``subset`` keyword. The result is re-wrapped with ``_from_dataframe``.
    """
    columns = flatten_str(subset)
    deduplicated = self._dataframe.drop_duplicates(subset=columns)
    return self._from_dataframe(deduplicated)

# --- lazy-only ---
Expand Down
8 changes: 8 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,11 @@ def test_head(df_raw: Any) -> None:
result = nw.to_native(df.head(2))
expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]}
compare_dicts(result, expected)


@pytest.mark.parametrize("df_raw", [df_pandas, df_lazy])
def test_unique(df_raw: Any) -> None:
    """Check that ``unique("b")`` followed by ``sort("b")`` keeps one row per ``b``."""
    # NOTE(review): the expected "a"/"z" values assume the first row of each
    # duplicate group is the one retained. Confirm this holds for every
    # backend in the parametrization (polars' `unique` defaults to keep="any").
    frame = nw.LazyFrame(df_raw)
    deduplicated = frame.unique("b").sort("b")
    result = nw.to_native(deduplicated)
    expected = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}
    compare_dicts(result, expected)
23 changes: 10 additions & 13 deletions tpch/q4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import polars

from narwhals import translate_frame
import narwhals as nw

Q_NUM = 4

Expand All @@ -16,28 +16,25 @@ def q4(
var_1 = datetime(1993, 7, 1)
var_2 = datetime(1993, 10, 1)

line_item_ds, pl = translate_frame(lineitem_ds_raw, is_lazy=True)
orders_ds, _ = translate_frame(orders_ds_raw, is_lazy=True)
line_item_ds = nw.LazyFrame(lineitem_ds_raw)
orders_ds = nw.LazyFrame(orders_ds_raw)

result = (
line_item_ds.join(orders_ds, left_on="l_orderkey", right_on="o_orderkey")
.filter(pl.col("o_orderdate").is_between(var_1, var_2, closed="left"))
.filter(pl.col("l_commitdate") < pl.col("l_receiptdate"))
.filter(nw.col("o_orderdate").is_between(var_1, var_2, closed="left"))
.filter(nw.col("l_commitdate") < nw.col("l_receiptdate"))
.unique(subset=["o_orderpriority", "l_orderkey"])
.group_by("o_orderpriority")
.agg(pl.len().alias("order_count"))
.agg(nw.len().alias("order_count"))
.sort(by="o_orderpriority")
.with_columns(
pl.col("order_count")
# .cast(pl.datatypes.Int64)
)
.with_columns(nw.col("order_count").cast(nw.Int64))
)

return result.collect().to_native()
return nw.to_native(result.collect())


lineitem_ds = polars.scan_parquet("../tpch-data/lineitem.parquet")
orders_ds = polars.scan_parquet("../tpch-data/orders.parquet")
lineitem_ds = polars.scan_parquet("../tpch-data/s1/lineitem.parquet")
orders_ds = polars.scan_parquet("../tpch-data/s1/orders.parquet")
print(
q4(
lineitem_ds.collect().to_pandas(),
Expand Down

0 comments on commit a21638c

Please sign in to comment.