From f5ec1761cda7ffce9bff8d796973d279e3ab5dfc Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 14 Mar 2024 17:36:40 +0000 Subject: [PATCH] this it? --- demo.py | 21 ++++--- narwhals/__init__.py | 4 +- narwhals/dataframe.py | 108 ++++++++++++++-------------------- narwhals/expression.py | 32 +++++----- narwhals/pandas_like/utils.py | 2 +- narwhals/polars.py | 2 +- narwhals/translate.py | 6 +- tests/tpch_q1_test.py | 4 +- 8 files changed, 77 insertions(+), 102 deletions(-) diff --git a/demo.py b/demo.py index 592b3b5486..6c8b50fafe 100644 --- a/demo.py +++ b/demo.py @@ -6,25 +6,24 @@ def func(df_raw: Any) -> Any: - df = nw.NarwhalsFrame(df_raw) - - print(df) + df = nw.DataFrame(df_raw) res = df.with_columns( d=nw.col("a") + 1, e=nw.col("a") + nw.col("b"), ) - - res = res.group_by("a").agg(nw.col("b").sum()) - print(res) - + res = res.group_by(["a"]).agg( + nw.col("b").sum(), + d=nw.col("c").sum(), + # e=nw.len(), + ) return nw.to_native(res) import pandas as pd -# df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) -# print(func(df)) -df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) +df = pd.DataFrame({"a": [1, 1, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) +print(func(df)) +df = pl.DataFrame({"a": [1, 1, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) print(func(df)) -df = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) +df = pl.LazyFrame({"a": [1, 1, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) print(func(df).collect()) diff --git a/narwhals/__init__.py b/narwhals/__init__.py index b9be564f84..18021dcc41 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -3,7 +3,7 @@ from narwhals.containers import is_pandas from narwhals.containers import is_polars from narwhals.containers import is_series -from narwhals.dataframe import NarwhalsFrame +from narwhals.dataframe import DataFrame from narwhals.expression import col from narwhals.expression import len from narwhals.translate import get_namespace @@ -27,5 +27,5 @@ "to_native", "col", "len", - "NarwhalsFrame", + "DataFrame", ] diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index de199487c0..dd3949071c 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -1,12 +1,14 @@ from __future__ import annotations -from narwhals.pandas_like.utils import evaluate_into_exprs +from narwhals.pandas_like.dataframe import PandasDataFrame +from narwhals.polars import PolarsDataFrame from narwhals.translate import get_pandas from narwhals.translate import get_polars def extract_native(obj: Any, implementation) -> Any: from narwhals.expression import NarwhalsExpr + from narwhals.series import Series # if isinstance(obj, NarwhalsExpr): # return obj._call(pl.col) @@ -17,14 +19,16 @@ def extract_native(obj: Any, implementation) -> Any: return obj._call(pl.col) # if isinstance(obj, DType): # return obj._dtype - if isinstance(obj, NarwhalsFrame): + if isinstance(obj, DataFrame): return obj._dataframe + if isinstance(obj, Series): + return obj._series # if isinstance(obj, PolarsSeries): # return obj._series return obj -class NarwhalsFrame: +class DataFrame: def __init__( self, df, *, is_eager=False, is_lazy=False, implementation: str | None = None ): @@ -35,24 +39,14 @@ def __init__( self._implementation = implementation return if (pl := get_polars()) is not None: - if isinstance(df, pl.DataFrame): - if is_lazy: - raise ValueError( - "can't instantiate with `is_lazy` if you pass a polars DataFrame" - ) - self._dataframe = df - self._implementation = "polars" - return - elif isinstance(df, pl.LazyFrame): - if is_eager: - raise ValueError( - "can't instantiate with `is_eager` if you pass a polars LazyFrame" - ) - self._dataframe = df + if isinstance(df, (pl.DataFrame, pl.LazyFrame)): + self._dataframe = PolarsDataFrame(df, is_eager=is_eager, is_lazy=is_lazy) self._implementation = "polars" return if (pd := get_pandas()) is not None and isinstance(df, pd.DataFrame): - self._dataframe = df + self._dataframe = PandasDataFrame( + df, is_eager=is_eager, is_lazy=is_lazy, implementation="pandas" + ) self._implementation = "pandas" return raise TypeError( @@ -68,43 +62,22 @@ def _from_dataframe(self, df: Any) -> Self: implementation=self._implementation, ) - def _extract_native(self, obj): - return extract_native(obj, implementation=self._implementation) - def with_columns( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr ) -> Self: - if self._implementation == "polars": - return self._from_dataframe( - self._dataframe.with_columns( - *[self._extract_native(v) for v in exprs], - **{ - key: self._extract_native(value) - for key, value in named_exprs.items() - }, - ) - ) - elif self._implementation == "pandas": - new_series = evaluate_into_exprs(self, *exprs, **named_exprs) - df = self._dataframe.assign( - **{series.name: series._series for series in new_series} - ) - return self._from_dataframe(df) + return self._from_dataframe( + self._dataframe.with_columns(*exprs, **named_exprs), + ) def filter(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> Self: return self._from_dataframe( - self._dataframe.filter(*[self._extract_native(v) for v in predicates]) + self._dataframe.filter(*predicates), ) def group_by(self, *keys: str | Iterable[str]) -> GroupBy: from narwhals.group_by import NarwhalsGroupBy - return NarwhalsGroupBy( - self, - *keys, - is_eager=self._is_eager, - is_lazy=self._is_lazy, - ) + return NarwhalsGroupBy(self, *keys) def sort( self, @@ -112,29 +85,34 @@ def sort( *more_by: str, descending: bool | Sequence[bool] = False, ) -> Self: - if self._implementation == "polars": - return self._from_dataframe( - self._dataframe.sort(by, *more_by, descending=descending) - ) + return self._from_dataframe( + self._dataframe.sort(by, *more_by, descending=descending) + ) def collect(self) -> Self: - if not self._is_lazy: - raise RuntimeError( - "DataFrame.collect can only be called if frame was instantiated with `is_lazy=True`" - ) - if self._implementation == "polars": - import polars as pl - - assert isinstance(self._dataframe, pl.LazyFrame) - return self.__class__(self._dataframe.collect(), is_eager=True, is_lazy=False) + return self.__class__( + self._dataframe.collect(), + is_eager=True, + is_lazy=False, + implementation=self._implementation, + ) def to_dict(self, *, as_series: bool = True) -> dict[str, Any]: - if not self._is_eager: - raise RuntimeError( - "DataFrame.to_dict can only be called if frame was instantiated with `is_eager=True`" - ) - if self._implementation == "polars": - import polars as pl + return self._dataframe.to_dict(as_series=as_series) - assert isinstance(self._dataframe, pl.DataFrame) - return self._dataframe.to_dict(as_series=as_series) + def join( + self, + other: Self, + *, + how: Literal[inner] = "inner", + left_on: str | list[str], + right_on: str | list[str], + ) -> Self: + return self._from_dataframe( + self._dataframe.join( + other._dataframe, + how=how, + left_on=left_on, + right_on=right_on, + ) + ) diff --git a/narwhals/expression.py b/narwhals/expression.py index f80044bd53..6156718c95 100644 --- a/narwhals/expression.py +++ b/narwhals/expression.py @@ -2,8 +2,6 @@ from typing import Any -from narwhals.translate import get_polars - def extract_native(expr, other: Any) -> Any: if isinstance(other, NarwhalsExpr): @@ -97,46 +95,46 @@ def sum(self) -> Expr: return self.__class__(lambda expr: self._call(expr).sum()) def min(self) -> Expr: - return self.__class__(self._expr.min()) + return self.__class__(lambda expr: self._call(expr).min()) def max(self) -> Expr: - return self.__class__(self._expr.max()) + return self.__class__(lambda expr: self._call(expr).max()) def n_unique(self) -> Expr: - return self.__class__(self._expr.n_unique()) + return self.__class__(lambda expr: self._call(expr).n_unique()) def unique(self) -> Expr: - return self.__class__(self._expr.unique()) + return self.__class__(lambda expr: self._call(expr).unique()) # --- transform --- def is_between( self, lower_bound: Any, upper_bound: Any, closed: str = "both" ) -> Expr: - return self.__class__(self._expr.is_between(lower_bound, upper_bound, closed)) # type: ignore[arg-type] + return self.__class__( + lambda expr: self._call(expr).is_between(lower_bound, upper_bound, closed) + ) # type: ignore[arg-type] def is_in(self, other: Any) -> Expr: - return self.__class__(self._expr.is_in(other)) + return self.__class__(lambda expr: self._call(expr).is_in(other)) def is_null(self) -> Expr: - return self.__class__(self._expr.is_null()) + return self.__class__(lambda expr: self._call(expr).is_null()) # --- partial reduction --- def drop_nulls(self) -> Expr: - return self.__class__(self._expr.drop_nulls()) + return self.__class__(lambda expr: self._call(expr).drop_nulls()) def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Expr: return self.__class__( - self._expr.sample(n, fraction=fraction, with_replacement=with_replacement) + lambda expr: self._call(expr).sample( + n, fraction=fraction, with_replacement=with_replacement + ) ) def col(col_name: str): - return NarwhalsExpr(lambda expr: expr(col_name)) + return NarwhalsExpr(lambda plx: plx.col(col_name)) def len(): - def func(expr): - if (pl := get_polars()) is not None and issubclass(expr, pl.col): - return pl.len() - - return NarwhalsExpr(func) + return NarwhalsExpr(lambda plx: plx.len()) diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py index 2092270995..cd755a1ca3 100644 --- a/narwhals/pandas_like/utils.py +++ b/narwhals/pandas_like/utils.py @@ -105,7 +105,7 @@ def parse_into_expr(implementation: str, into_expr: IntoExpr) -> Expr: plx = Namespace(implementation=implementation) if isinstance(into_expr, NarwhalsExpr): - return into_expr._call(plx.col) + return into_expr._call(plx) if isinstance(into_expr, str): return plx.col(into_expr) if isinstance(into_expr, Expr): diff --git a/narwhals/polars.py b/narwhals/polars.py index 23d3f884b2..337fec776e 100644 --- a/narwhals/polars.py +++ b/narwhals/polars.py @@ -29,7 +29,7 @@ def extract_native(obj: Any) -> Any: from narwhals.expression import NarwhalsExpr if isinstance(obj, NarwhalsExpr): - return obj._call(pl.col) + return obj._call(pl) if isinstance(obj, Expr): return obj._expr if isinstance(obj, DType): diff --git a/narwhals/translate.py b/narwhals/translate.py index 5f6d1a4124..deb81508ec 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -175,14 +175,14 @@ def get_namespace(obj: Any) -> Namespace: def to_native(obj: Any) -> Any: - from narwhals.dataframe import NarwhalsFrame + from narwhals.dataframe import DataFrame from narwhals.pandas_like.dataframe import PandasDataFrame from narwhals.pandas_like.series import PandasSeries from narwhals.polars import PolarsDataFrame from narwhals.polars import PolarsSeries - if isinstance(obj, NarwhalsFrame): - return obj._dataframe + if isinstance(obj, DataFrame): + return obj._dataframe._dataframe if isinstance(obj, PandasDataFrame): return obj._dataframe if isinstance(obj, PandasSeries): diff --git a/tests/tpch_q1_test.py b/tests/tpch_q1_test.py index e3141a1861..bb16dd5693 100644 --- a/tests/tpch_q1_test.py +++ b/tests/tpch_q1_test.py @@ -13,13 +13,13 @@ @pytest.mark.parametrize( "df_raw", [ - # (polars.read_parquet("tests/data/lineitem.parquet").to_pandas()), + (polars.read_parquet("tests/data/lineitem.parquet").to_pandas()), polars.scan_parquet("tests/data/lineitem.parquet"), ], ) def test_q1(df_raw: Any) -> None: var_1 = datetime(1998, 9, 2) - df = nw.NarwhalsFrame(df_raw, is_lazy=True) + df = nw.DataFrame(df_raw, is_lazy=True) query_result = ( df.filter(nw.col("l_shipdate") <= var_1) .group_by(["l_returnflag", "l_linestatus"])