Skip to content

Commit

Permalink
enh: get_native_namespace accepts native objects (#1520)
Browse files Browse the repository at this point in the history
* enh: more flexible get_native_namespace

* add test
  • Loading branch information
MarcoGorelli authored Dec 6, 2024
1 parent 0bb191c commit 3bad94d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
34 changes: 32 additions & 2 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
from narwhals.utils import Version

if TYPE_CHECKING:
import pandas as pd
import polars as pl
import pyarrow as pa

from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.series import Series
Expand Down Expand Up @@ -749,7 +753,18 @@ def _from_native_impl( # noqa: PLR0915
return native_object


def get_native_namespace(obj: DataFrame[Any] | LazyFrame[Any] | Series[Any]) -> Any:
def get_native_namespace(
obj: DataFrame[Any]
| LazyFrame[Any]
| Series[Any]
| pd.DataFrame
| pd.Series
| pl.DataFrame
| pl.LazyFrame
| pl.Series
| pa.Table
| pa.ChunkedArray,
) -> Any:
"""Get native namespace from object.
Arguments:
Expand All @@ -769,7 +784,22 @@ def get_native_namespace(obj: DataFrame[Any] | LazyFrame[Any] | Series[Any]) ->
>>> nw.get_native_namespace(df)
<module 'polars'...>
"""
return obj.__native_namespace__()
if hasattr(obj, "__native_namespace__"):
return obj.__native_namespace__()
if is_pandas_dataframe(obj) or is_pandas_series(obj):
return get_pandas()
if is_modin_dataframe(obj) or is_modin_series(obj): # pragma: no cover
return get_modin()
if is_pyarrow_table(obj) or is_pyarrow_chunked_array(obj):
return get_pyarrow()
if is_cudf_dataframe(obj) or is_cudf_series(obj): # pragma: no cover
return get_cudf()
if is_dask_dataframe(obj): # pragma: no cover
return get_dask()
if is_polars_dataframe(obj) or is_polars_lazyframe(obj) or is_polars_series(obj):
return get_polars()
msg = f"Could not get native namespace from object of type: {type(obj)}"
raise TypeError(msg)


def narwhalify(
Expand Down
13 changes: 13 additions & 0 deletions tests/translate/get_native_namespace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,27 @@
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw


def test_native_namespace() -> None:
df = nw.from_native(pl.DataFrame({"a": [1, 2, 3]}))
assert nw.get_native_namespace(df) is pl
assert nw.get_native_namespace(df.to_native()) is pl
assert nw.get_native_namespace(df.lazy().to_native()) is pl
assert nw.get_native_namespace(df["a"].to_native()) is pl
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}))
assert nw.get_native_namespace(df) is pd
assert nw.get_native_namespace(df.to_native()) is pd
assert nw.get_native_namespace(df["a"].to_native()) is pd
df = nw.from_native(pa.table({"a": [1, 2, 3]}))
assert nw.get_native_namespace(df) is pa
assert nw.get_native_namespace(df.to_native()) is pa
assert nw.get_native_namespace(df["a"].to_native()) is pa


def test_get_native_namespace_invalid() -> None:
with pytest.raises(TypeError, match="Could not get native namespace"):
nw.get_native_namespace(1)

0 comments on commit 3bad94d

Please sign in to comment.