From ad7612929a9f81b48f112a1e0ac3444ce28f462b Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Mon, 3 Jun 2024 21:48:20 +0100 Subject: [PATCH] Preserve dtype backend (#248) --- narwhals/_pandas_like/series.py | 2 +- narwhals/_pandas_like/utils.py | 127 ++++++++++++++++++++++++++------ narwhals/dependencies.py | 2 +- tests/test_series.py | 48 ++++++++++++ 4 files changed, 154 insertions(+), 25 deletions(-) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 140eaf110..ba8cfff69 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -134,7 +134,7 @@ def cast( dtype: Any, ) -> Self: ser = self._series - dtype = reverse_translate_dtype(dtype) + dtype = reverse_translate_dtype(dtype, ser.dtype, self._implementation) return self._from_series(ser.astype(dtype)) def item(self) -> Any: diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index df5e423e6..74dd96823 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -10,7 +10,6 @@ from narwhals.dependencies import get_modin from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_pyarrow from narwhals.utils import flatten from narwhals.utils import isinstance_or_issubclass from narwhals.utils import parse_version @@ -405,52 +404,134 @@ def translate_dtype(dtype: Any) -> DType: if str(dtype).startswith("datetime64"): # todo: different time units and time zones return dtypes.Datetime() + if str(dtype).startswith("timestamp["): + # pyarrow-backed datetime + # todo: different time units and time zones + return dtypes.Datetime() if dtype == "object": return dtypes.String() msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) -def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: - # Use the default pandas dtype here - # TODO: maybe this could be configurable? +def get_dtype_backend(dtype: Any, implementation: str) -> str: + if implementation == "pandas": + pd = get_pandas() + if hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype): + return "pyarrow-nullable" + + try: + if isinstance(dtype, pd.core.dtypes.dtypes.BaseMaskedDtype): + return "pandas-nullable" + except AttributeError: # pragma: no cover + # defensive check for old pandas versions + pass + return "numpy" + else: # pragma: no cover + return "numpy" + + +def reverse_translate_dtype( # noqa: PLR0915 + dtype: DType | type[DType], starting_dtype: Any, implementation: str +) -> Any: from narwhals import dtypes + dtype_backend = get_dtype_backend(starting_dtype, implementation) + if isinstance_or_issubclass(dtype, dtypes.Float64): - return "float64" + if dtype_backend == "pyarrow-nullable": + return "Float64[pyarrow]" + if dtype_backend == "pandas-nullable": + return "Float64" + else: + return "float64" if isinstance_or_issubclass(dtype, dtypes.Float32): - return "float32" + if dtype_backend == "pyarrow-nullable": + return "Float32[pyarrow]" + if dtype_backend == "pandas-nullable": + return "Float32" + else: + return "float32" if isinstance_or_issubclass(dtype, dtypes.Int64): - return "int64" + if dtype_backend == "pyarrow-nullable": + return "Int64[pyarrow]" + if dtype_backend == "pandas-nullable": + return "Int64" + else: + return "int64" if isinstance_or_issubclass(dtype, dtypes.Int32): - return "int32" + if dtype_backend == "pyarrow-nullable": + return "Int32[pyarrow]" + if dtype_backend == "pandas-nullable": + return "Int32" + else: + return "int32" if isinstance_or_issubclass(dtype, dtypes.Int16): - return "int16" + if dtype_backend == "pyarrow-nullable": + return "Int16[pyarrow]" + if dtype_backend == "pandas-nullable": + return "Int16" + else: + return "int16" if isinstance_or_issubclass(dtype, dtypes.Int8): - return "int8" + if dtype_backend == "pyarrow-nullable": + return "Int8[pyarrow]" + if dtype_backend == "pandas-nullable": + return "Int8" + else: + return "int8" if isinstance_or_issubclass(dtype, dtypes.UInt64): - return "uint64" + if dtype_backend == "pyarrow-nullable": + return "UInt64[pyarrow]" + if dtype_backend == "pandas-nullable": + return "UInt64" + else: + return "uint64" if isinstance_or_issubclass(dtype, dtypes.UInt32): - return "uint32" + if dtype_backend == "pyarrow-nullable": + return "UInt32[pyarrow]" + if dtype_backend == "pandas-nullable": + return "UInt32" + else: + return "uint32" if isinstance_or_issubclass(dtype, dtypes.UInt16): - return "uint16" + if dtype_backend == "pyarrow-nullable": + return "UInt16[pyarrow]" + if dtype_backend == "pandas-nullable": + return "UInt16" + else: + return "uint16" if isinstance_or_issubclass(dtype, dtypes.UInt8): - return "uint8" + if dtype_backend == "pyarrow-nullable": + return "UInt8[pyarrow]" + if dtype_backend == "pandas-nullable": + return "UInt8" + else: + return "uint8" if isinstance_or_issubclass(dtype, dtypes.String): - pd = get_pandas() - - if pd is not None and parse_version(pd.__version__) >= parse_version("2.0.0"): - if get_pyarrow() is not None: - return "string[pyarrow]" - return "string[python]" # pragma: no cover - return "object" # pragma: no cover + if dtype_backend == "pyarrow-nullable": + return "string[pyarrow]" + if dtype_backend == "pandas-nullable": + return "string" + else: + return object if isinstance_or_issubclass(dtype, dtypes.Boolean): - return "bool" + if dtype_backend == "pyarrow-nullable": + return "boolean[pyarrow]" + if dtype_backend == "pandas-nullable": + return "boolean" + else: + return "bool" if isinstance_or_issubclass(dtype, dtypes.Categorical): + # todo: is there no pyarrow-backed categorical? + # or at least, convert_dtypes(dtype_backend='pyarrow') doesn't + # convert to it? return "category" if isinstance_or_issubclass(dtype, dtypes.Datetime): # todo: different time units and time zones - return "datetime64[us]" + if dtype_backend == "pyarrow-nullable": + return "timestamp[ns][pyarrow]" + return "datetime64[ns]" msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 0a98cf568..95ffa8307 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -28,7 +28,7 @@ def get_cudf() -> Any: return sys.modules.get("cudf", None) -def get_pyarrow() -> Any: +def get_pyarrow() -> Any: # pragma: no cover """Get pyarrow module (if already imported - else return None).""" return sys.modules.get("pyarrow", None) diff --git a/tests/test_series.py b/tests/test_series.py index 0881fae7a..4a9fe7364 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -325,6 +325,46 @@ def test_cast() -> None: df["o"].cast(nw.Categorical), ).schema assert result == expected + df = nw.from_native(df.to_pandas().convert_dtypes()) # type: ignore[assignment] + result_pd = df.select( + df["a"].cast(nw.Int32), + df["b"].cast(nw.Int16), + df["c"].cast(nw.Int8), + df["d"].cast(nw.Int64), + df["e"].cast(nw.UInt32), + df["f"].cast(nw.UInt16), + df["g"].cast(nw.UInt8), + df["h"].cast(nw.UInt64), + df["i"].cast(nw.Float32), + df["j"].cast(nw.Float64), + df["k"].cast(nw.String), + df["l"].cast(nw.Datetime), + df["m"].cast(nw.Int8), + df["n"].cast(nw.Boolean), + df["o"].cast(nw.Categorical), + ).schema + assert result == expected + if parse_version(pd.__version__) < parse_version("2.0.0"): # pragma: no cover + return + df = nw.from_native(df.to_pandas().convert_dtypes(dtype_backend="pyarrow")) # type: ignore[assignment] + result_pd = df.select( + df["a"].cast(nw.Int32), + df["b"].cast(nw.Int16), + df["c"].cast(nw.Int8), + df["d"].cast(nw.Int64), + df["e"].cast(nw.UInt32), + df["f"].cast(nw.UInt16), + df["g"].cast(nw.UInt8), + df["h"].cast(nw.UInt64), + df["i"].cast(nw.Float32), + df["j"].cast(nw.Float64), + df["k"].cast(nw.String), + df["l"].cast(nw.Datetime), + df["m"].cast(nw.Int8), + df["n"].cast(nw.Boolean), + df["o"].cast(nw.Categorical), + ).schema + assert result == expected def test_to_numpy() -> None: @@ -456,3 +496,11 @@ def test_zip_with(df_raw: Any, mask: Any, expected: Any) -> None: result = series1.zip_with(mask, series2) expected = nw.Series(expected) assert result == expected + + +def test_cast_string() -> None: + s_pd = pd.Series([1, 2]).convert_dtypes() + s = nw.from_native(s_pd, series_only=True) + s = s.cast(nw.String) + result = nw.to_native(s) + assert result.dtype in ("string", object)