Skip to content

Commit

Permalink
fix: from_native was sometimes raising unnecessarily with `strict=F…
Browse files Browse the repository at this point in the history
…alse` (#1274)

* fix: `from_native` was sometimes raising unnecessarily with `strict=False`

* coverage
  • Loading branch information
MarcoGorelli authored Oct 28, 2024
1 parent e718d1f commit 31e17ee
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 50 deletions.
143 changes: 93 additions & 50 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,27 +395,35 @@ def _from_native_impl( # noqa: PLR0915
# Extensions
if hasattr(native_object, "__narwhals_dataframe__"):
if series_only:
msg = "Cannot only use `series_only` with dataframe"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with dataframe"
raise TypeError(msg)
return native_object
return DataFrame(
native_object.__narwhals_dataframe__(),
level="full",
)
elif hasattr(native_object, "__narwhals_lazyframe__"):
if series_only:
msg = "Cannot only use `series_only` with lazyframe"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with lazyframe"
raise TypeError(msg)
return native_object
if eager_only or eager_or_interchange_only:
msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with lazyframe"
raise TypeError(msg)
if strict:
msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with lazyframe"
raise TypeError(msg)
return native_object
return LazyFrame(
native_object.__narwhals_lazyframe__(),
level="full",
)
elif hasattr(native_object, "__narwhals_series__"):
if not allow_series:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
if strict:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
return native_object
return Series(
native_object.__narwhals_series__(),
level="full",
Expand All @@ -424,8 +432,10 @@ def _from_native_impl( # noqa: PLR0915
# Polars
elif is_polars_dataframe(native_object):
if series_only:
msg = "Cannot only use `series_only` with polars.DataFrame"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with polars.DataFrame"
raise TypeError(msg)
return native_object
pl = get_polars()
return DataFrame(
PolarsDataFrame(
Expand All @@ -437,11 +447,15 @@ def _from_native_impl( # noqa: PLR0915
)
elif is_polars_lazyframe(native_object):
if series_only:
msg = "Cannot only use `series_only` with polars.LazyFrame"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with polars.LazyFrame"
raise TypeError(msg)
return native_object
if eager_only or eager_or_interchange_only:
msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with polars.LazyFrame"
raise TypeError(msg)
if strict:
msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with polars.LazyFrame"
raise TypeError(msg)
return native_object
pl = get_polars()
return LazyFrame(
PolarsLazyFrame(
Expand All @@ -454,8 +468,10 @@ def _from_native_impl( # noqa: PLR0915
elif is_polars_series(native_object):
pl = get_polars()
if not allow_series:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
if strict:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
return native_object
return Series(
PolarsSeries(
native_object,
Expand All @@ -468,8 +484,10 @@ def _from_native_impl( # noqa: PLR0915
# pandas
elif is_pandas_dataframe(native_object):
if series_only:
msg = "Cannot only use `series_only` with dataframe"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with dataframe"
raise TypeError(msg)
return native_object
pd = get_pandas()
return DataFrame(
PandasLikeDataFrame(
Expand All @@ -482,8 +500,10 @@ def _from_native_impl( # noqa: PLR0915
)
elif is_pandas_series(native_object):
if not allow_series:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
if strict:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
return native_object
pd = get_pandas()
return Series(
PandasLikeSeries(
Expand All @@ -499,8 +519,10 @@ def _from_native_impl( # noqa: PLR0915
elif is_modin_dataframe(native_object): # pragma: no cover
mpd = get_modin()
if series_only:
msg = "Cannot only use `series_only` with modin.DataFrame"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with modin.DataFrame"
raise TypeError(msg)
return native_object
return DataFrame(
PandasLikeDataFrame(
native_object,
Expand All @@ -513,8 +535,10 @@ def _from_native_impl( # noqa: PLR0915
elif is_modin_series(native_object): # pragma: no cover
mpd = get_modin()
if not allow_series:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
if strict:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
return native_object
return Series(
PandasLikeSeries(
native_object,
Expand All @@ -529,8 +553,10 @@ def _from_native_impl( # noqa: PLR0915
elif is_cudf_dataframe(native_object): # pragma: no cover
cudf = get_cudf()
if series_only:
msg = "Cannot only use `series_only` with cudf.DataFrame"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with cudf.DataFrame"
raise TypeError(msg)
return native_object
return DataFrame(
PandasLikeDataFrame(
native_object,
Expand All @@ -543,8 +569,10 @@ def _from_native_impl( # noqa: PLR0915
elif is_cudf_series(native_object): # pragma: no cover
cudf = get_cudf()
if not allow_series:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
if strict:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
return native_object
return Series(
PandasLikeSeries(
native_object,
Expand All @@ -559,8 +587,10 @@ def _from_native_impl( # noqa: PLR0915
elif is_pyarrow_table(native_object):
pa = get_pyarrow()
if series_only:
msg = "Cannot only use `series_only` with arrow table"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with arrow table"
raise TypeError(msg)
return native_object
return DataFrame(
ArrowDataFrame(
native_object,
Expand All @@ -572,8 +602,10 @@ def _from_native_impl( # noqa: PLR0915
elif is_pyarrow_chunked_array(native_object):
pa = get_pyarrow()
if not allow_series:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
if strict:
msg = "Please set `allow_series=True`"
raise TypeError(msg)
return native_object
return Series(
ArrowSeries(
native_object,
Expand All @@ -587,11 +619,15 @@ def _from_native_impl( # noqa: PLR0915
# Dask
elif is_dask_dataframe(native_object):
if series_only:
msg = "Cannot only use `series_only` with dask DataFrame"
raise TypeError(msg)
if strict:
msg = "Cannot only use `series_only` with dask DataFrame"
raise TypeError(msg)
return native_object
if eager_only or eager_or_interchange_only:
msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with dask DataFrame"
raise TypeError(msg)
if strict:
msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with dask DataFrame"
raise TypeError(msg)
return native_object
if get_dask_expr() is None: # pragma: no cover
msg = "Please install dask-expr"
raise ImportError(msg)
Expand All @@ -607,10 +643,13 @@ def _from_native_impl( # noqa: PLR0915
# DuckDB
elif is_duckdb_relation(native_object):
if eager_only or series_only: # pragma: no cover
msg = (
"Cannot only use `series_only=True` or `eager_only=False` "
"with DuckDB Relation"
)
if strict:
msg = (
"Cannot only use `series_only=True` or `eager_only=False` "
"with DuckDB Relation"
)
else:
return native_object
raise TypeError(msg)
return DataFrame(
DuckDBInterchangeFrame(native_object, dtypes=dtypes),
Expand All @@ -620,11 +659,13 @@ def _from_native_impl( # noqa: PLR0915
# Ibis
elif is_ibis_table(native_object): # pragma: no cover
if eager_only or series_only:
msg = (
"Cannot only use `series_only=True` or `eager_only=False` "
"with Ibis table"
)
raise TypeError(msg)
if strict:
msg = (
"Cannot only use `series_only=True` or `eager_only=False` "
"with Ibis table"
)
raise TypeError(msg)
return native_object
return DataFrame(
IbisInterchangeFrame(native_object, dtypes=dtypes),
level="interchange",
Expand All @@ -633,11 +674,13 @@ def _from_native_impl( # noqa: PLR0915
# Interchange protocol
elif hasattr(native_object, "__dataframe__"):
if eager_only or series_only:
msg = (
"Cannot only use `series_only=True` or `eager_only=False` "
"with object which only implements __dataframe__"
)
raise TypeError(msg)
if strict:
msg = (
"Cannot only use `series_only=True` or `eager_only=False` "
"with object which only implements __dataframe__"
)
raise TypeError(msg)
return native_object
return DataFrame(
InterchangeFrame(native_object, dtypes=dtypes),
level="interchange",
Expand Down
20 changes: 20 additions & 0 deletions tests/translate/from_native_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def test_eager_only_lazy(dframe: Any, eager_only: Any, context: Any) -> None:
with context:
res = nw.from_native(dframe, eager_only=eager_only)
assert isinstance(res, nw.LazyFrame)
if eager_only:
assert nw.from_native(dframe, eager_only=eager_only, strict=False) is dframe


@pytest.mark.parametrize("dframe", eager_frames)
Expand All @@ -122,6 +124,9 @@ def test_series_only(obj: Any, context: Any) -> None:
with context:
res = nw.from_native(obj, series_only=True)
assert isinstance(res, nw.Series)
assert nw.from_native(obj, series_only=True, strict=False) is obj or isinstance(
res, nw.Series
)


@pytest.mark.parametrize("series", all_series)
Expand All @@ -136,6 +141,8 @@ def test_allow_series(series: Any, allow_series: Any, context: Any) -> None:
with context:
res = nw.from_native(series, allow_series=allow_series)
assert isinstance(res, nw.Series)
if not allow_series:
assert nw.from_native(series, allow_series=allow_series, strict=False) is series


def test_invalid_series_combination() -> None:
Expand Down Expand Up @@ -184,6 +191,7 @@ def test_series_only_dask() -> None:

with pytest.raises(TypeError, match="Cannot only use `series_only`"):
nw.from_native(dframe, series_only=True)
assert nw.from_native(dframe, series_only=True, strict=False) is dframe


@pytest.mark.parametrize(
Expand All @@ -203,6 +211,8 @@ def test_eager_only_lazy_dask(eager_only: Any, context: Any) -> None:
with context:
res = nw.from_native(dframe, eager_only=eager_only)
assert isinstance(res, nw.LazyFrame)
if eager_only:
assert nw.from_native(dframe, eager_only=eager_only, strict=False) is dframe


def test_from_native_strict_false_typing() -> None:
Expand All @@ -214,3 +224,13 @@ def test_from_native_strict_false_typing() -> None:
unstable_nw.from_native(df, strict=False)
unstable_nw.from_native(df, strict=False, eager_only=True)
unstable_nw.from_native(df, strict=False, eager_or_interchange_only=True)


def test_from_mock_interchange_protocol_non_strict() -> None:
class MockDf:
def __dataframe__(self) -> None: # pragma: no cover
pass

mockdf = MockDf()
result = nw.from_native(mockdf, eager_only=True, strict=False)
assert result is mockdf # type: ignore[comparison-overlap]

0 comments on commit 31e17ee

Please sign in to comment.