Skip to content

Commit

Permalink
fix: cuDF compat (#794)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 15, 2024
1 parent 3f73eb4 commit 7a4a16e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def select(
) -> Self:
if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
# This is a simple slice => fastpath!
return self._from_native_frame(self._native_frame.loc[:, exprs])
return self._from_native_frame(self._native_frame.loc[:, list(exprs)])
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
if not new_series:
# return empty dataframe, like Polars does
Expand Down
15 changes: 14 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,20 @@ def translate_dtype(column: Any) -> DType:
# which is inferred by default.
return dtypes.String()
else:
return dtypes.Object()
df = column.to_frame()
if hasattr(df, "__dataframe__"):
from narwhals._interchange.dataframe import (
map_interchange_dtype_to_narwhals_dtype,
)

try:
return map_interchange_dtype_to_narwhals_dtype(
df.__dataframe__().get_column(0).dtype
)
except Exception: # noqa: BLE001
return dtypes.Object()
else: # pragma: no cover
return dtypes.Object()
return dtypes.Unknown()


Expand Down
6 changes: 4 additions & 2 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def test_schema_comparison() -> None:


def test_object() -> None:
df = pd.DataFrame({"a": [1, 2, 3]}).astype(object)
class Foo: ...

df = pd.DataFrame({"a": [Foo()]}).astype(object)
result = nw.from_native(df).schema
assert result["a"] == nw.Object

Expand All @@ -57,7 +59,7 @@ def test_string_disguised_as_object() -> None:


def test_actual_object(request: Any, constructor_eager: Any) -> None:
if "pyarrow_table" in str(constructor_eager):
if any(x in str(constructor_eager) for x in ("modin", "pyarrow_table")):
request.applymarker(pytest.mark.xfail)

class Foo: ...
Expand Down

0 comments on commit 7a4a16e

Please sign in to comment.