Skip to content

Commit

Permalink
Merge pull request #23 from raisadz/increase-coverage
Browse files Browse the repository at this point in the history
change sample method to work for pandas
  • Loading branch information
MarcoGorelli authored Mar 20, 2024
2 parents dd38333 + 1a9cb03 commit 47eea5a
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 8 deletions.
3 changes: 3 additions & 0 deletions narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def __ge__(self, other: Any) -> Expr:
)

# --- unary ---
def __invert__(self) -> Expr:
return self.__class__(lambda plx: self._call(plx).__invert__())

def mean(self) -> Expr:
return self.__class__(lambda plx: self._call(plx).mean())

Expand Down
2 changes: 2 additions & 0 deletions narwhals/pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from narwhals.pandas_like.utils import evaluate_into_exprs
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import maybe_reset_indices
from narwhals.pandas_like.utils import translate_dtype
from narwhals.pandas_like.utils import validate_dataframe_comparand
from narwhals.utils import flatten_str
Expand Down Expand Up @@ -86,6 +87,7 @@ def select(
**named_exprs: IntoPandasExpr,
) -> Self:
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
new_series = maybe_reset_indices(new_series)
df = horizontal_concat(
[series._series for series in new_series],
implementation=self._implementation,
Expand Down
12 changes: 10 additions & 2 deletions narwhals/pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,16 @@ def n_unique(self) -> Self:
def unique(self) -> Self:
return register_expression_call(self, "unique")

def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Self:
return register_expression_call(self, "sample", n, fraction, with_replacement)
def sample(
self,
n: int | None = None,
fraction: float | None = None,
*,
with_replacement: bool = False,
) -> Self:
return register_expression_call(
self, "sample", n, fraction=fraction, with_replacement=with_replacement
)

def alias(self, name: str) -> Self:
# Define this one manually, so that we can
Expand Down
12 changes: 8 additions & 4 deletions narwhals/pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,15 @@ def zip_with(self, mask: PandasSeries, other: PandasSeries) -> PandasSeries:
ser = self._series
return self._from_series(ser.where(mask, other))

def sample(self, n: int, fraction: float, *, with_replacement: bool) -> PandasSeries:
def sample(
self,
n: int | None = None,
fraction: float | None = None,
*,
with_replacement: bool = False,
) -> PandasSeries:
ser = self._series
return self._from_series(
ser.sample(n=n, frac=fraction, with_replacement=with_replacement)
)
return self._from_series(ser.sample(n=n, frac=fraction, replace=with_replacement))

def unique(self) -> PandasSeries:
if self._implementation != "pandas":
Expand Down
12 changes: 12 additions & 0 deletions narwhals/pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,15 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
return "bool"
msg = f"Unknown dtype: {dtype}"
raise TypeError(msg)


def maybe_reset_indices(series: list[PandasSeries]) -> list[PandasSeries]:
idx = series[0]._series.index
found_non_matching_index = False
for s in series[1:]:
if s._series.index is not idx and not (s._series.index == idx).all():
found_non_matching_index = True
break
if found_non_matching_index:
return [s._from_series(s._series.reset_index(drop=True)) for s in series]
return series
10 changes: 8 additions & 2 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,15 @@ def zip_with(self, mask: Self, other: Self) -> Self:
self._series.zip_with(self._extract_native(mask), self._extract_native(other))
)

def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Self:
def sample(
self,
n: int | None = None,
fraction: float | None = None,
*,
with_replacement: bool = False,
) -> Self:
return self._from_series(
self._series.sample(n, fraction=fraction, with_replacement=with_replacement)
self._series.sample(n=n, fraction=fraction, with_replacement=with_replacement)
)

def to_numpy(self) -> Any:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_lazy = pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
df_pandas_na = pd.DataFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]})
df_lazy_na = pl.LazyFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]})

if os.environ.get("CI", None):
import modin.pandas as mpd
Expand Down Expand Up @@ -321,3 +323,21 @@ def test_expr_min_max(df_raw: Any) -> None:
expected_max = {"a": [3], "b": [6], "z": [9]}
compare_dicts(result_min, expected_min)
compare_dicts(result_max, expected_max)


@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy])
def test_expr_sample(df_raw: Any) -> None:
df = nw.LazyFrame(df_raw)
result_shape = nw.to_native(df.select(nw.col("a", "b").sample(n=2)).collect()).shape
expected = (2, 2)
assert result_shape == expected


@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na])
def test_expr_na(df_raw: Any) -> None:
df = nw.LazyFrame(df_raw)
result_nna = nw.to_native(
df.filter((~nw.col("a").is_null()) & (~nw.col("z").is_null()))
)
expected = {"a": [2], "b": [6], "z": [9]}
compare_dicts(result_nna, expected)

0 comments on commit 47eea5a

Please sign in to comment.