Skip to content

Commit

Permalink
Preserve dtype backend (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jun 3, 2024
1 parent 1dbfe40 commit ad76129
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 25 deletions.
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def cast(
dtype: Any,
) -> Self:
ser = self._series
dtype = reverse_translate_dtype(dtype)
dtype = reverse_translate_dtype(dtype, ser.dtype, self._implementation)
return self._from_series(ser.astype(dtype))

def item(self) -> Any:
Expand Down
127 changes: 104 additions & 23 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow
from narwhals.utils import flatten
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version
Expand Down Expand Up @@ -405,52 +404,134 @@ def translate_dtype(dtype: Any) -> DType:
if str(dtype).startswith("datetime64"):
# todo: different time units and time zones
return dtypes.Datetime()
if str(dtype).startswith("timestamp["):
# pyarrow-backed datetime
# todo: different time units and time zones
return dtypes.Datetime()
if dtype == "object":
return dtypes.String()
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)


def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
# Use the default pandas dtype here
# TODO: maybe this could be configurable?
def get_dtype_backend(dtype: Any, implementation: str) -> str:
    """Classify ``dtype`` by the storage backend it belongs to.

    Returns ``"pyarrow-nullable"`` when ``dtype`` is a ``pd.ArrowDtype``
    instance, ``"pandas-nullable"`` when it is an instance of pandas'
    ``BaseMaskedDtype``, and ``"numpy"`` in every other case. Only
    ``implementation == "pandas"`` is inspected; any other implementation is
    reported as ``"numpy"``.
    """
    if implementation != "pandas":  # pragma: no cover
        return "numpy"

    pd = get_pandas()
    arrow_dtype = getattr(pd, "ArrowDtype", None)
    if arrow_dtype is not None and isinstance(dtype, arrow_dtype):
        return "pyarrow-nullable"

    try:
        is_masked = isinstance(dtype, pd.core.dtypes.dtypes.BaseMaskedDtype)
    except AttributeError:  # pragma: no cover
        # defensive check for old pandas versions
        is_masked = False
    return "pandas-nullable" if is_masked else "numpy"


def reverse_translate_dtype(
    dtype: DType | type[DType], starting_dtype: Any, implementation: str
) -> Any:
    """Translate a narwhals dtype into a pandas-like dtype specifier.

    The result stays on the same dtype backend as ``starting_dtype`` (as
    reported by ``get_dtype_backend``): pyarrow-backed input yields
    ``"...[pyarrow]"`` names, pandas-nullable input yields the capitalised
    nullable names, and anything else yields the plain numpy names.

    Args:
        dtype: Target narwhals dtype (class or instance).
        starting_dtype: Native dtype of the data being cast; only used to
            detect the backend.
        implementation: Name of the pandas-like implementation, forwarded to
            ``get_dtype_backend``.

    Returns:
        A value suitable for passing to ``Series.astype``.

    Raises:
        AssertionError: If ``dtype`` is not one of the handled narwhals dtypes.
    """
    from narwhals import dtypes

    dtype_backend = get_dtype_backend(starting_dtype, implementation)

    # (narwhals dtype, numpy spelling, pandas-nullable spelling).
    # The pyarrow spelling is always the nullable one plus "[pyarrow]".
    # NOTE(review): the numpy spelling for String is the ``object`` type itself;
    # presumably ``astype(object)`` does not stringify values - confirm intent.
    backend_aware = (
        (dtypes.Float64, "float64", "Float64"),
        (dtypes.Float32, "float32", "Float32"),
        (dtypes.Int64, "int64", "Int64"),
        (dtypes.Int32, "int32", "Int32"),
        (dtypes.Int16, "int16", "Int16"),
        (dtypes.Int8, "int8", "Int8"),
        (dtypes.UInt64, "uint64", "UInt64"),
        (dtypes.UInt32, "uint32", "UInt32"),
        (dtypes.UInt16, "uint16", "UInt16"),
        (dtypes.UInt8, "uint8", "UInt8"),
        (dtypes.String, object, "string"),
        (dtypes.Boolean, "bool", "boolean"),
    )
    for nw_dtype, numpy_name, nullable_name in backend_aware:
        if isinstance_or_issubclass(dtype, nw_dtype):
            if dtype_backend == "pyarrow-nullable":
                return f"{nullable_name}[pyarrow]"
            if dtype_backend == "pandas-nullable":
                return nullable_name
            return numpy_name

    if isinstance_or_issubclass(dtype, dtypes.Categorical):
        # todo: is there no pyarrow-backed categorical?
        # or at least, convert_dtypes(dtype_backend='pyarrow') doesn't
        # convert to it?
        return "category"
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        # todo: different time units and time zones
        if dtype_backend == "pyarrow-nullable":
            return "timestamp[ns][pyarrow]"
        return "datetime64[ns]"
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)

Expand Down
2 changes: 1 addition & 1 deletion narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_cudf() -> Any:
return sys.modules.get("cudf", None)


def get_pyarrow() -> Any:  # pragma: no cover
    """Get pyarrow module (if already imported - else return None)."""
    # Look in sys.modules rather than importing, so this call never triggers
    # an import of pyarrow on its own.
    return sys.modules.get("pyarrow", None)

Expand Down
48 changes: 48 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,46 @@ def test_cast() -> None:
df["o"].cast(nw.Categorical),
).schema
assert result == expected
df = nw.from_native(df.to_pandas().convert_dtypes()) # type: ignore[assignment]
result_pd = df.select(
df["a"].cast(nw.Int32),
df["b"].cast(nw.Int16),
df["c"].cast(nw.Int8),
df["d"].cast(nw.Int64),
df["e"].cast(nw.UInt32),
df["f"].cast(nw.UInt16),
df["g"].cast(nw.UInt8),
df["h"].cast(nw.UInt64),
df["i"].cast(nw.Float32),
df["j"].cast(nw.Float64),
df["k"].cast(nw.String),
df["l"].cast(nw.Datetime),
df["m"].cast(nw.Int8),
df["n"].cast(nw.Boolean),
df["o"].cast(nw.Categorical),
).schema
assert result == expected
if parse_version(pd.__version__) < parse_version("2.0.0"): # pragma: no cover
return
df = nw.from_native(df.to_pandas().convert_dtypes(dtype_backend="pyarrow")) # type: ignore[assignment]
result_pd = df.select(
df["a"].cast(nw.Int32),
df["b"].cast(nw.Int16),
df["c"].cast(nw.Int8),
df["d"].cast(nw.Int64),
df["e"].cast(nw.UInt32),
df["f"].cast(nw.UInt16),
df["g"].cast(nw.UInt8),
df["h"].cast(nw.UInt64),
df["i"].cast(nw.Float32),
df["j"].cast(nw.Float64),
df["k"].cast(nw.String),
df["l"].cast(nw.Datetime),
df["m"].cast(nw.Int8),
df["n"].cast(nw.Boolean),
df["o"].cast(nw.Categorical),
).schema
assert result == expected


def test_to_numpy() -> None:
Expand Down Expand Up @@ -456,3 +496,11 @@ def test_zip_with(df_raw: Any, mask: Any, expected: Any) -> None:
result = series1.zip_with(mask, series2)
expected = nw.Series(expected)
assert result == expected


def test_cast_string() -> None:
    """Casting a ``convert_dtypes()`` integer series to String gives a string/object dtype."""
    native = pd.Series([1, 2]).convert_dtypes()
    casted = nw.from_native(native, series_only=True).cast(nw.String)
    assert nw.to_native(casted).dtype in ("string", object)

0 comments on commit ad76129

Please sign in to comment.