Skip to content

Commit

Permalink
feat: add cat.get_categories, make test numpy2.0 compatible (#329)
Browse files Browse the repository at this point in the history
* feat: add cat.get_categories

* docstring
  • Loading branch information
MarcoGorelli authored Jun 22, 2024
1 parent f4a9925 commit 714e877
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 19 deletions.
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ nav:
- API Reference:
- api-reference/dataframe.md
- api-reference/expressions.md
- api-reference/expressions_cat.md
- api-reference/expressions_dt.md
- api-reference/expressions_str.md
- api-reference/group_by.md
- api-reference/lazyframe.md
- api-reference/series.md
- api-reference/series_cat.md
- api-reference/series_dt.md
- api-reference/series_str.md
- api-reference/dependencies.md
Expand Down
16 changes: 16 additions & 0 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,22 @@ def str(self) -> PandasExprStringNamespace:
def dt(self) -> PandasExprDateTimeNamespace:
return PandasExprDateTimeNamespace(self)

@property
def cat(self) -> PandasExprCatNamespace:
    """Return a `PandasExprCatNamespace` wrapping this expression."""
    return PandasExprCatNamespace(self)


class PandasExprCatNamespace:
    """`cat` namespace for a pandas-like expression."""

    def __init__(self, expr: PandasExpr) -> None:
        # The parent expression every namespace method is applied to.
        self._expr = expr

    def get_categories(self) -> PandasExpr:
        """Return an expression for `cat.get_categories`.

        The work is handed to `reuse_series_namespace_implementation`,
        called with this expression, the namespace name ``"cat"`` and the
        method name ``"get_categories"``.
        """
        namespace, method = "cat", "get_categories"
        return reuse_series_namespace_implementation(self._expr, namespace, method)


class PandasExprStringNamespace:
def __init__(self, expr: PandasExpr) -> None:
Expand Down
13 changes: 13 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,19 @@ def str(self) -> PandasSeriesStringNamespace:
def dt(self) -> PandasSeriesDateTimeNamespace:
return PandasSeriesDateTimeNamespace(self)

@property
def cat(self) -> PandasSeriesCatNamespace:
    """Return a `PandasSeriesCatNamespace` wrapping this series."""
    return PandasSeriesCatNamespace(self)


class PandasSeriesCatNamespace:
    """`cat` namespace for a pandas-like series wrapper."""

    def __init__(self, series: PandasSeries) -> None:
        # The wrapper whose native series is read from `._series`.
        self._series = series

    def get_categories(self) -> PandasSeries:
        """Return the categories of the wrapped series as a new series.

        The categories are placed in a fresh instance of the native series
        class, carrying the original series' name, and then re-wrapped via
        the wrapper's `_from_series`.
        """
        wrapper = self._series
        native = wrapper._series
        # Use the native class itself so the result stays in the same
        # library as the input.
        categories = native.__class__(native.cat.categories, name=native.name)
        return wrapper._from_series(categories)


class PandasSeriesStringNamespace:
def __init__(self, series: PandasSeries) -> None:
Expand Down
8 changes: 4 additions & 4 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1641,11 +1641,11 @@ def item(self: Self, row: int | None = None, column: int | str | None = None) ->
We can then pass either pandas or Polars to `func`:
>>> func(df_pd, 1, 1), func(df_pl, 1, 1)
(5, 5)
>>> func(df_pd, 1, 1), func(df_pd, 2, "b") # doctest:+SKIP
(5, 6)
>>> func(df_pd, 2, "b"), func(df_pl, 2, "b")
(6, 6)
>>> func(df_pl, 1, 1), func(df_pl, 2, "b")
(5, 6)
"""
return self._dataframe.item(row=row, column=column)

Expand Down
51 changes: 51 additions & 0 deletions narwhals/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,57 @@ def str(self) -> ExprStringNamespace:
def dt(self) -> ExprDateTimeNamespace:
return ExprDateTimeNamespace(self)

@property
def cat(self) -> ExprCatNamespace:
    """Return an `ExprCatNamespace` wrapping this expression."""
    return ExprCatNamespace(self)


class ExprCatNamespace:
    """`cat` namespace for an `Expr`."""

    def __init__(self, expr: Expr) -> None:
        # The parent expression whose `_call` is wrapped by each method.
        self._expr = expr

    def get_categories(self) -> Expr:
        """
        Get unique categories from column.

        Examples:
            Let's create some dataframes:

            >>> import pandas as pd
            >>> import polars as pl
            >>> import narwhals as nw
            >>> data = {"fruits": ["apple", "mango", "mango"]}
            >>> df_pd = pd.DataFrame(data, dtype="category")
            >>> df_pl = pl.DataFrame(data, schema={"fruits": pl.Categorical})

            We define a dataframe-agnostic function to get unique categories
            from column 'fruits':

            >>> @nw.narwhalify
            ... def func(df):
            ...     return df.select(nw.col("fruits").cat.get_categories())

            We can then pass either pandas or Polars to `func`:

            >>> func(df_pd)
              fruits
            0  apple
            1  mango
            >>> func(df_pl)
            shape: (2, 1)
            ┌────────┐
            │ fruits │
            │ ---    │
            │ str    │
            ╞════════╡
            │ apple  │
            │ mango  │
            └────────┘
        """

        def call(plx):
            # Evaluate the parent expression first, then apply the
            # categorical operation to its result.
            return self._expr._call(plx).cat.get_categories()

        return self._expr.__class__(call)


class ExprStringNamespace:
def __init__(self, expr: Expr) -> None:
Expand Down
68 changes: 57 additions & 11 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def mean(self) -> Any:
We can then pass either pandas or Polars to `func`:
>>> func(s_pd)
>>> func(s_pd) # doctest:+SKIP
2.0
>>> func(s_pl)
2.0
Expand Down Expand Up @@ -347,7 +347,7 @@ def any(self) -> Any:
We can then pass either pandas or Polars to `func`:
>>> func(s_pd)
>>> func(s_pd) # doctest:+SKIP
True
>>> func(s_pl)
True
Expand All @@ -374,7 +374,7 @@ def all(self) -> Any:
We can then pass either pandas or Polars to `func`:
>>> func(s_pd)
>>> func(s_pd) # doctest:+SKIP
False
>>> func(s_pl)
False
Expand Down Expand Up @@ -402,7 +402,7 @@ def min(self) -> Any:
We can then pass either pandas or Polars to `func`:
>>> func(s_pd)
>>> func(s_pd) # doctest:+SKIP
1
>>> func(s_pl)
1
Expand All @@ -429,7 +429,7 @@ def max(self) -> Any:
We can then pass either pandas or Polars to `func`:
>>> func(s_pd)
>>> func(s_pd) # doctest:+SKIP
3
>>> func(s_pl)
3
Expand All @@ -456,7 +456,7 @@ def sum(self) -> Any:
We can then pass either pandas or Polars to `func`:
>>> func(s_pd)
>>> func(s_pd) # doctest:+SKIP
6
>>> func(s_pl)
6
Expand Down Expand Up @@ -487,7 +487,7 @@ def std(self, *, ddof: int = 1) -> Any:
We can then pass either pandas or Polars to `func`:
>>> func(s_pd)
>>> func(s_pd) # doctest:+SKIP
1.0
>>> func(s_pl)
1.0
Expand Down Expand Up @@ -1362,7 +1362,7 @@ def null_count(self: Self) -> int:
... return series.null_count()
We can then pass either pandas or Polars to `func`:
>>> func(s_pd)
>>> func(s_pd) # doctest:+SKIP
1
>>> func(s_pl)
2
Expand Down Expand Up @@ -1567,7 +1567,7 @@ def quantile(
We can then pass either pandas or Polars to `func`:
>>> func(s_pd) # doctest: +NORMALIZE_WHITESPACE
>>> func(s_pd) # doctest: +SKIP
[5, 12, 24, 37, 44]
>>> func(s_pl) # doctest: +NORMALIZE_WHITESPACE
Expand Down Expand Up @@ -1644,10 +1644,10 @@ def item(self: Self, index: int | None = None) -> Any:
We can then pass either pandas or Polars to `func`:
>>> func(pl.Series("a", [1]), None), func(pd.Series([1]), None)
>>> func(pl.Series("a", [1]), None), func(pd.Series([1]), None) # doctest:+SKIP
(1, 1)
>>> func(pl.Series("a", [9, 8, 7]), -1), func(pd.Series([9, 8, 7]), -2)
>>> func(pl.Series("a", [9, 8, 7]), -1), func(pl.Series([9, 8, 7]), -2)
(7, 8)
"""
return self._series.item(index=index)
Expand Down Expand Up @@ -1791,6 +1791,52 @@ def str(self) -> SeriesStringNamespace:
def dt(self) -> SeriesDateTimeNamespace:
return SeriesDateTimeNamespace(self)

@property
def cat(self) -> SeriesCatNamespace:
    """Return a `SeriesCatNamespace` wrapping this series."""
    return SeriesCatNamespace(self)


class SeriesCatNamespace:
    """`cat` namespace for a `Series`."""

    def __init__(self, series: Series) -> None:
        # The parent series; its backend object is read from `._series`.
        self._series = series

    def get_categories(self) -> Series:
        """
        Get unique categories from a categorical series.

        Examples:
            Let's create some series:

            >>> import pandas as pd
            >>> import polars as pl
            >>> import narwhals as nw
            >>> data = ["apple", "mango", "mango"]
            >>> s_pd = pd.Series(data, dtype="category")
            >>> s_pl = pl.Series(data, dtype=pl.Categorical)

            We define a series-agnostic function to get the unique
            categories of a series:

            >>> @nw.narwhalify(series_only=True)
            ... def func(s):
            ...     return s.cat.get_categories()

            We can then pass either pandas or Polars to `func`:

            >>> func(s_pd)
            0    apple
            1    mango
            dtype: object
            >>> func(s_pl)  # doctest: +NORMALIZE_WHITESPACE
            shape: (2,)
            Series: '' [str]
            [
               "apple"
               "mango"
            ]
        """
        parent = self._series
        # Run the backend's `cat.get_categories`, then wrap the result in
        # the same class as the parent series.
        categories = parent._series.cat.get_categories()
        return parent.__class__(categories)


class SeriesStringNamespace:
def __init__(self, series: Series) -> None:
Expand Down
7 changes: 6 additions & 1 deletion narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ def _is_iterable(arg: Any | Iterable[Any]) -> bool:
if (pl := get_polars()) is not None and isinstance(
arg, (pl.Series, pl.Expr, pl.DataFrame, pl.LazyFrame)
):
msg = f"Expected Narwhals class or scalar, got: {type(arg)}. Perhaps you forgot a `nw.from_native` somewhere?"
msg = (
f"Expected Narwhals class or scalar, got: {type(arg)}.\n\n"
"Hint: Perhaps you\n"
"- forgot a `nw.from_native` somewhere?\n"
"- used `pl.col` instead of `nw.col`?"
)
raise TypeError(msg)

return isinstance(arg, Iterable) and not isinstance(arg, (str, bytes, Series))
Expand Down
Empty file added tests/cat/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions tests/cat/get_categories_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw
from tests.utils import compare_dicts

data = {"a": ["one", "two", "two"]}


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
def test_get_categories(constructor: Any) -> None:
    """Check `cat.get_categories` through both the expression and series APIs."""
    native = constructor(data)
    df = nw.from_native(native, eager_only=True)
    df = df.select(nw.col("a").cast(nw.Categorical))
    expected = {"a": ["one", "two"]}

    # Expression path: nw.col(...).cat.get_categories()
    from_expr = df.select(nw.col("a").cat.get_categories())
    compare_dicts(from_expr, expected)

    # Series path: df[...].cat.get_categories()
    from_series = df.select(df["a"].cat.get_categories())
    compare_dicts(from_series, expected)
2 changes: 1 addition & 1 deletion tests/test_invalid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_native_vs_non_native() -> None:
nw.from_native(df).filter(s > 1)
s = pl.Series([1, 2, 3])
df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
with pytest.raises(TypeError, match="Perhaps you forgot"):
with pytest.raises(TypeError, match="Perhaps you\n- forgot"):
nw.from_native(df).filter(s > 1)


Expand Down
12 changes: 10 additions & 2 deletions utils/check_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@
for i in content.splitlines()
if i.startswith(" - ")
]
if missing := set(top_level_functions).difference(documented).difference({"dt", "str"}):
if (
missing := set(top_level_functions)
.difference(documented)
.difference({"dt", "str", "cat"})
):
print("Series: not documented") # noqa: T201
print(missing) # noqa: T201
ret = 1
Expand All @@ -106,7 +110,11 @@
for i in content.splitlines()
if i.startswith(" - ")
]
if missing := set(top_level_functions).difference(documented).difference({"str", "dt"}):
if (
missing := set(top_level_functions)
.difference(documented)
.difference({"cat", "str", "dt"})
):
print("Expr: not documented") # noqa: T201
print(missing) # noqa: T201
ret = 1
Expand Down

0 comments on commit 714e877

Please sign in to comment.