Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change sample method to work for pandas #23

Merged
merged 5 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def __ge__(self, other: Any) -> Expr:
)

# --- unary ---
def __invert__(self) -> Expr:
    """Return a new expression that calls ``__invert__`` on this one's result."""

    def inverted(plx):
        # Evaluate the wrapped expression first, then invert what it produced.
        return self._call(plx).__invert__()

    return self.__class__(inverted)

def mean(self) -> Expr:
    """Return a new expression that calls ``mean()`` on this one's result."""

    def averaged(plx):
        # Evaluate the wrapped expression first, then take its mean.
        return self._call(plx).mean()

    return self.__class__(averaged)

Expand Down
2 changes: 2 additions & 0 deletions narwhals/pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from narwhals.pandas_like.utils import evaluate_into_exprs
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import maybe_reset_indices
from narwhals.pandas_like.utils import translate_dtype
from narwhals.pandas_like.utils import validate_dataframe_comparand
from narwhals.utils import flatten_str
Expand Down Expand Up @@ -86,6 +87,7 @@ def select(
**named_exprs: IntoPandasExpr,
) -> Self:
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
new_series = maybe_reset_indices(new_series)
df = horizontal_concat(
[series._series for series in new_series],
implementation=self._implementation,
Expand Down
12 changes: 10 additions & 2 deletions narwhals/pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,16 @@ def n_unique(self) -> Self:
def unique(self) -> Self:
    """Build the ``unique`` expression by delegating to ``register_expression_call``."""
    method_name = "unique"
    return register_expression_call(self, method_name)

def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Self:
return register_expression_call(self, "sample", n, fraction, with_replacement)
def sample(
    self,
    n: int | None = None,
    fraction: float | None = None,
    *,
    with_replacement: bool = False,
) -> Self:
    """Build the ``sample`` expression via ``register_expression_call``.

    ``n`` is forwarded positionally; ``fraction`` and ``with_replacement``
    are forwarded as keyword arguments.
    """
    keyword_args = {"fraction": fraction, "with_replacement": with_replacement}
    return register_expression_call(self, "sample", n, **keyword_args)

def alias(self, name: str) -> Self:
# Define this one manually, so that we can
Expand Down
12 changes: 8 additions & 4 deletions narwhals/pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,15 @@ def zip_with(self, mask: PandasSeries, other: PandasSeries) -> PandasSeries:
ser = self._series
return self._from_series(ser.where(mask, other))

def sample(self, n: int, fraction: float, *, with_replacement: bool) -> PandasSeries:
def sample(
self,
n: int | None = None,
fraction: float | None = None,
*,
with_replacement: bool = False,
) -> PandasSeries:
ser = self._series
return self._from_series(
ser.sample(n=n, frac=fraction, with_replacement=with_replacement)
)
return self._from_series(ser.sample(n=n, frac=fraction, replace=with_replacement))

def unique(self) -> PandasSeries:
if self._implementation != "pandas":
Expand Down
12 changes: 12 additions & 0 deletions narwhals/pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,15 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
return "bool"
msg = f"Unknown dtype: {dtype}"
raise TypeError(msg)


def maybe_reset_indices(series: list[PandasSeries]) -> list[PandasSeries]:
    """Give every series a fresh default index when their indices disagree.

    The index of the first series is the reference. If any later series has
    an index that is not equal to it, every series is re-wrapped with its
    index dropped and reset; otherwise the input list is returned unchanged.

    Args:
        series: Wrapped series. Each one exposes the native series as
            ``_series`` and re-wraps a native series via ``_from_series``.

    Returns:
        The input list itself when it is empty or all indices match,
        otherwise a new list of re-wrapped series with reset indices.
    """
    if not series:
        # Nothing to compare against; indexing series[0] would raise IndexError.
        return series

    reference = series[0]._series.index
    # ``Index.equals`` returns False for indices of different lengths, whereas
    # an element-wise ``==`` raises ValueError in pandas in that case.
    # The identity check is kept as a cheap short-circuit.
    mismatch = any(
        s._series.index is not reference and not s._series.index.equals(reference)
        for s in series[1:]
    )
    if not mismatch:
        return series
    return [s._from_series(s._series.reset_index(drop=True)) for s in series]
10 changes: 8 additions & 2 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,15 @@ def zip_with(self, mask: Self, other: Self) -> Self:
self._series.zip_with(self._extract_native(mask), self._extract_native(other))
)

def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Self:
def sample(
self,
n: int | None = None,
fraction: float | None = None,
*,
with_replacement: bool = False,
) -> Self:
return self._from_series(
self._series.sample(n, fraction=fraction, with_replacement=with_replacement)
self._series.sample(n=n, fraction=fraction, with_replacement=with_replacement)
)

def to_numpy(self) -> Any:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# Shared fixtures: the same three-row frame as pandas, eager polars and lazy polars.
df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_lazy = pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
# Variants with a missing value in "a" and in "z", used by test_expr_na.
df_pandas_na = pd.DataFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]})
df_lazy_na = pl.LazyFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]})

if os.environ.get("CI", None):
import modin.pandas as mpd
Expand Down Expand Up @@ -321,3 +323,21 @@ def test_expr_min_max(df_raw: Any) -> None:
expected_max = {"a": [3], "b": [6], "z": [9]}
compare_dicts(result_min, expected_min)
compare_dicts(result_max, expected_max)


@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy])
def test_expr_sample(df_raw: Any) -> None:
    """Sampling two rows from columns ``a`` and ``b`` yields a 2x2 frame."""
    frame = nw.LazyFrame(df_raw)
    sampled = frame.select(nw.col("a", "b").sample(n=2)).collect()
    native = nw.to_native(sampled)
    assert native.shape == (2, 2)


@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na])
def test_expr_na(df_raw: Any) -> None:
    """Filtering on non-null ``a`` and non-null ``z`` keeps only the last row."""
    frame = nw.LazyFrame(df_raw)
    a_present = ~nw.col("a").is_null()
    z_present = ~nw.col("z").is_null()
    result_nna = nw.to_native(frame.filter(a_present & z_present))
    compare_dicts(result_nna, {"a": [2], "b": [6], "z": [9]})
Loading