Skip to content

Commit

Permalink
test1 working for polars!
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 11, 2024
1 parent dd1f70e commit 2bf5c2b
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 17 deletions.
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from narwhals.containers import is_series
from narwhals.dataframe import NarwhalsFrame
from narwhals.expression import col
from narwhals.expression import len
from narwhals.translate import get_namespace
from narwhals.translate import to_native
from narwhals.translate import translate_any
Expand All @@ -25,5 +26,6 @@
"get_namespace",
"to_native",
"col",
"len",
"NarwhalsFrame",
]
35 changes: 34 additions & 1 deletion narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,40 @@ def group_by(self, *keys: str | Iterable[str]) -> GroupBy:

return NarwhalsGroupBy(
self,
keys,
*keys,
is_eager=self._is_eager,
is_lazy=self._is_lazy,
)

def sort(
self,
by: str | Iterable[str],
*more_by: str,
descending: bool | Sequence[bool] = False,
) -> Self:
if self._implementation == "polars":
return self._from_dataframe(
self._dataframe.sort(by, *more_by, descending=descending)
)

def collect(self) -> Self:
if not self._is_lazy:
raise RuntimeError(
"DataFrame.collect can only be called if frame was instantiated with `is_lazy=True`"
)
if self._implementation == "polars":
import polars as pl

assert isinstance(self._dataframe, pl.LazyFrame)
return self.__class__(self._dataframe.collect(), is_eager=True, is_lazy=False)

def to_dict(self, *, as_series: bool = True) -> dict[str, Any]:
if not self._is_eager:
raise RuntimeError(
"DataFrame.to_dict can only be called if frame was instantiated with `is_eager=True`"
)
if self._implementation == "polars":
import polars as pl

assert isinstance(self._dataframe, pl.DataFrame)
return self._dataframe.to_dict(as_series=as_series)
10 changes: 10 additions & 0 deletions narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Any

from narwhals.translate import get_polars


def extract_native(expr, other: Any) -> Any:
if isinstance(other, NarwhalsExpr):
Expand Down Expand Up @@ -130,3 +132,11 @@ def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Expr:

def col(col_name: str):
return NarwhalsExpr(lambda expr: expr(col_name))


def len():
def func(expr):
if (pl := get_polars()) is not None and issubclass(expr, pl.col):
return pl.len()

return NarwhalsExpr(func)
30 changes: 15 additions & 15 deletions tests/tpch_q1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,21 @@ def test_q1(df_raw: Any) -> None:
.agg(
[
nw.col("l_quantity").sum().alias("sum_qty"),
# nw.col("l_extendedprice").sum().alias("sum_base_price"),
# (nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
# .sum()
# .alias("sum_disc_price"),
# (
# nw.col("l_extendedprice")
# * (1.0 - nw.col("l_discount"))
# * (1.0 + nw.col("l_tax"))
# )
# .sum()
# .alias("sum_charge"),
# nw.col("l_quantity").mean().alias("avg_qty"),
# nw.col("l_extendedprice").mean().alias("avg_price"),
# nw.col("l_discount").mean().alias("avg_disc"),
# nw.len().alias("count_order"),
nw.col("l_extendedprice").sum().alias("sum_base_price"),
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
)
.sum()
.alias("sum_charge"),
nw.col("l_quantity").mean().alias("avg_qty"),
nw.col("l_extendedprice").mean().alias("avg_price"),
nw.col("l_discount").mean().alias("avg_disc"),
nw.len().alias("count_order"),
],
)
.sort(["l_returnflag", "l_linestatus"])
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def compare_dicts(result: dict[str, Any], expected: dict[str, Any]) -> None:
for key in result:
for key in expected:
for lhs, rhs in zip(result[key], expected[key]):
if isinstance(lhs, float):
assert abs(lhs - rhs) < 1e-6
Expand Down

0 comments on commit 2bf5c2b

Please sign in to comment.