Skip to content

Commit

Permalink
feat: support Date dtype in Narwhals (#341)
Browse files Browse the repository at this point in the history
* feat: support Date dtype in Narwhals

* skip on old pandas versions
  • Loading branch information
MarcoGorelli authored Jun 27, 2024
1 parent f77cd21 commit 351681e
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 0 deletions.
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import Boolean
from narwhals.dtypes import Categorical
from narwhals.dtypes import Date
from narwhals.dtypes import Datetime
from narwhals.dtypes import Float32
from narwhals.dtypes import Float64
Expand Down Expand Up @@ -75,6 +76,7 @@
"Categorical",
"String",
"Datetime",
"Date",
"narwhalify",
"narwhalify_method",
"show_versions",
Expand Down
1 change: 1 addition & 0 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class PandasNamespace:
Categorical = dtypes.Categorical
String = dtypes.String
Datetime = dtypes.Datetime
Date = dtypes.Date

@property
def selectors(self) -> PandasSelectorNamespace:
Expand Down
7 changes: 7 additions & 0 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ def translate_dtype(dtype: Any) -> DType:
# pyarrow-backed datetime
# todo: different time units and time zones
return dtypes.Datetime()
if str(dtype) == "date32[day][pyarrow]":
return dtypes.Date()
if dtype == "object":
return dtypes.String()
msg = f"Unknown dtype: {dtype}" # pragma: no cover
Expand Down Expand Up @@ -532,6 +534,11 @@ def reverse_translate_dtype( # noqa: PLR0915
if dtype_backend == "pyarrow-nullable":
return "timestamp[ns][pyarrow]"
return "datetime64[ns]"
if isinstance_or_issubclass(dtype, dtypes.Date):
if dtype_backend == "pyarrow-nullable":
return "date32[pyarrow]"
msg = "Date dtype only supported for pyarrow-backed data types in pandas"
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
4 changes: 4 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def translate_dtype(plx: Any, dtype: DType) -> Any:
return plx.Categorical
if dtype == Datetime:
return plx.Datetime
if dtype == Date:
return plx.Date
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down Expand Up @@ -147,5 +149,7 @@ def to_narwhals_dtype(dtype: Any, *, is_polars: bool) -> DType:
return Categorical()
if dtype == pl.Datetime:
return Datetime()
if dtype == pl.Date:
return Date()
msg = f"Unexpected dtype, got: {type(dtype)}" # pragma: no cover
raise AssertionError(msg)
68 changes: 68 additions & 0 deletions tests/series/cast_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from datetime import date
from datetime import datetime

import pandas as pd
import polars as pl
import pytest
from polars.testing import assert_frame_equal

import narwhals as nw
from narwhals.utils import parse_version


def test_cast_253() -> None:
Expand All @@ -16,3 +22,65 @@ def test_cast_253() -> None:
nw.col("a").cast(nw.String) + "hi"
)["a"][0]
assert result == "1hi"


def test_cast_date_datetime_polars() -> None:
# polars: date to datetime
dfpl = pl.DataFrame({"a": [date(2020, 1, 1), date(2020, 1, 2)]})
df = nw.from_native(dfpl)
df = df.select(nw.col("a").cast(nw.Datetime))
result = nw.to_native(df)
expected = pl.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]})
assert_frame_equal(result, expected)

# polars: datetime to date
dfpl = pl.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]})
df = nw.from_native(dfpl)
df = df.select(nw.col("a").cast(nw.Date))
result = nw.to_native(df)
expected = pl.DataFrame({"a": [date(2020, 1, 1), date(2020, 1, 2)]})
assert_frame_equal(result, expected)
assert df.schema == {"a": nw.Date}


@pytest.mark.skipif(
parse_version(pd.__version__) < parse_version("2.0.0"),
reason="pyarrow dtype not available",
)
def test_cast_date_datetime_pandas() -> None:
# pandas: pyarrow date to datetime
dfpd = pd.DataFrame({"a": [date(2020, 1, 1), date(2020, 1, 2)]}).astype(
{"a": "date32[pyarrow]"}
)
df = nw.from_native(dfpd)
df = df.select(nw.col("a").cast(nw.Datetime))
result = nw.to_native(df)
expected = pd.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}).astype(
{"a": "timestamp[ns][pyarrow]"}
)
pd.testing.assert_frame_equal(result, expected)

# pandas: pyarrow datetime to date
dfpd = pd.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]}).astype(
{"a": "timestamp[ns][pyarrow]"}
)
df = nw.from_native(dfpd)
df = df.select(nw.col("a").cast(nw.Date))
result = nw.to_native(df)
expected = pd.DataFrame({"a": [date(2020, 1, 1), date(2020, 1, 2)]}).astype(
{"a": "date32[pyarrow]"}
)
pd.testing.assert_frame_equal(result, expected)
assert df.schema == {"a": nw.Date}


@pytest.mark.skipif(
parse_version(pd.__version__) < parse_version("2.0.0"),
reason="pyarrow dtype not available",
)
def test_cast_date_datetime_invalid() -> None:
# pandas: pyarrow datetime to date
dfpd = pd.DataFrame({"a": [datetime(2020, 1, 1), datetime(2020, 1, 2)]})
df = nw.from_native(dfpd)
with pytest.raises(NotImplementedError, match="pyarrow"):
df.select(nw.col("a").cast(nw.Date))

0 comments on commit 351681e

Please sign in to comment.