Skip to content

Commit

Permalink
Merge pull request #83 from MarcoGorelli/preserve-object
Browse files Browse the repository at this point in the history
preserve object dtype
  • Loading branch information
MarcoGorelli authored May 5, 2024
2 parents 620e599 + df85cad commit fcd637f
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 27 deletions.
23 changes: 1 addition & 22 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_dataframe_comparand
from narwhals._pandas_like.utils import validate_indices
from narwhals.dependencies import get_pyarrow
from narwhals.utils import flatten
from narwhals.utils import parse_version

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -36,7 +34,7 @@ def __init__(
implementation: str,
) -> None:
self._validate_columns(dataframe.columns)
self._dataframe = self._convert_object_dtypes(dataframe)
self._dataframe = dataframe
self._implementation = implementation

def __narwhals_dataframe__(self) -> Self:
Expand All @@ -50,25 +48,6 @@ def __narwhals_namespace__(self) -> PandasNamespace:

return PandasNamespace(self._implementation)

def _convert_object_dtypes(self, dataframe: Any) -> Any:
schema = dataframe.dtypes
if (schema != object).all():
return dataframe
replacements = {}
for col in dataframe.columns:
if schema[col] != object:
continue
import pandas as pd # todo: generalise across pandas-like implementations

if parse_version(pd.__version__) >= parse_version("2.0.0"):
if get_pyarrow() is not None:
replacements[col] = dataframe[col].astype("string[pyarrow]")
else: # pragma: no cover
replacements[col] = dataframe[col].astype("string[python]")
else: # pragma: no cover
pass
return dataframe.assign(**replacements)

def _validate_columns(self, columns: Sequence[str]) -> None:
if len(columns) != len(set(columns)):
counter = collections.Counter(columns)
Expand Down
6 changes: 1 addition & 5 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,7 @@ def translate_dtype(dtype: Any) -> DType:
if str(dtype).startswith("datetime64"):
# todo: different time units and time zones
return dtypes.Datetime()
if dtype == "object": # pragma: no cover
import pandas as pd

assert parse_version(pd.__version__) < parse_version("2.0.0")
# Should only happen for pandas pre 2.0.0
if dtype == "object":
return dtypes.String()
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
Expand Down
2 changes: 2 additions & 0 deletions tests/hypothesis/test_basic_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pandas as pd
import polars as pl
import pytest
from hypothesis import given
from hypothesis import strategies as st
from numpy.testing import assert_allclose
Expand All @@ -21,6 +22,7 @@
max_size=3,
),
) # type: ignore[misc]
@pytest.mark.slow()
def test_mean(
integer: st.SearchStrategy[list[int]],
floats: st.SearchStrategy[float],
Expand Down

0 comments on commit fcd637f

Please sign in to comment.