From 31e17eeb11b3d1e6c87bfe5fbd9fe4d4b88fa882 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Mon, 28 Oct 2024 14:36:31 +0000 Subject: [PATCH] fix: `from_native` was sometimes raising unnecessarily with `strict=False` (#1274) * fix: `from_native` was sometimes raising unnecessarily with `strict=False` * coverage --- narwhals/translate.py | 143 ++++++++++++++++++---------- tests/translate/from_native_test.py | 20 ++++ 2 files changed, 113 insertions(+), 50 deletions(-) diff --git a/narwhals/translate.py b/narwhals/translate.py index 331b87d88..a1b0e2323 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -395,27 +395,35 @@ def _from_native_impl( # noqa: PLR0915 # Extensions if hasattr(native_object, "__narwhals_dataframe__"): if series_only: - msg = "Cannot only use `series_only` with dataframe" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with dataframe" + raise TypeError(msg) + return native_object return DataFrame( native_object.__narwhals_dataframe__(), level="full", ) elif hasattr(native_object, "__narwhals_lazyframe__"): if series_only: - msg = "Cannot only use `series_only` with lazyframe" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with lazyframe" + raise TypeError(msg) + return native_object if eager_only or eager_or_interchange_only: - msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with lazyframe" - raise TypeError(msg) + if strict: + msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with lazyframe" + raise TypeError(msg) + return native_object return LazyFrame( native_object.__narwhals_lazyframe__(), level="full", ) elif hasattr(native_object, "__narwhals_series__"): if not allow_series: - msg = "Please set `allow_series=True`" - raise TypeError(msg) + if strict: + msg = "Please set `allow_series=True`" + raise TypeError(msg) + return native_object return Series( native_object.__narwhals_series__(), level="full", @@ -424,8 +432,10 @@ def _from_native_impl( # noqa: PLR0915 # Polars elif is_polars_dataframe(native_object): if series_only: - msg = "Cannot only use `series_only` with polars.DataFrame" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with polars.DataFrame" + raise TypeError(msg) + return native_object pl = get_polars() return DataFrame( PolarsDataFrame( @@ -437,11 +447,15 @@ def _from_native_impl( # noqa: PLR0915 ) elif is_polars_lazyframe(native_object): if series_only: - msg = "Cannot only use `series_only` with polars.LazyFrame" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with polars.LazyFrame" + raise TypeError(msg) + return native_object if eager_only or eager_or_interchange_only: - msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with polars.LazyFrame" - raise TypeError(msg) + if strict: + msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with polars.LazyFrame" + raise TypeError(msg) + return native_object pl = get_polars() return LazyFrame( PolarsLazyFrame( @@ -454,8 +468,10 @@ def _from_native_impl( # noqa: PLR0915 elif is_polars_series(native_object): pl = get_polars() if not allow_series: - msg = "Please set `allow_series=True`" - raise TypeError(msg) + if strict: + msg = "Please set `allow_series=True`" + raise TypeError(msg) + return native_object return Series( PolarsSeries( native_object, @@ -468,8 +484,10 @@ def _from_native_impl( # noqa: PLR0915 # pandas elif is_pandas_dataframe(native_object): if series_only: - msg = "Cannot only use `series_only` with dataframe" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with dataframe" + raise TypeError(msg) + return native_object pd = get_pandas() return DataFrame( PandasLikeDataFrame( @@ -482,8 +500,10 @@ def _from_native_impl( # noqa: PLR0915 ) elif is_pandas_series(native_object): if not allow_series: - msg = "Please set `allow_series=True`" - raise TypeError(msg) + if strict: + msg = "Please set `allow_series=True`" + raise TypeError(msg) + return native_object pd = get_pandas() return Series( PandasLikeSeries( @@ -499,8 +519,10 @@ def _from_native_impl( # noqa: PLR0915 elif is_modin_dataframe(native_object): # pragma: no cover mpd = get_modin() if series_only: - msg = "Cannot only use `series_only` with modin.DataFrame" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with modin.DataFrame" + raise TypeError(msg) + return native_object return DataFrame( PandasLikeDataFrame( native_object, @@ -513,8 +535,10 @@ def _from_native_impl( # noqa: PLR0915 elif is_modin_series(native_object): # pragma: no cover mpd = get_modin() if not allow_series: - msg = "Please set `allow_series=True`" - raise TypeError(msg) + if strict: + msg = "Please set `allow_series=True`" + raise TypeError(msg) + return native_object return Series( PandasLikeSeries( native_object, @@ -529,8 +553,10 @@ def _from_native_impl( # noqa: PLR0915 elif is_cudf_dataframe(native_object): # pragma: no cover cudf = get_cudf() if series_only: - msg = "Cannot only use `series_only` with cudf.DataFrame" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with cudf.DataFrame" + raise TypeError(msg) + return native_object return DataFrame( PandasLikeDataFrame( native_object, @@ -543,8 +569,10 @@ def _from_native_impl( # noqa: PLR0915 elif is_cudf_series(native_object): # pragma: no cover cudf = get_cudf() if not allow_series: - msg = "Please set `allow_series=True`" - raise TypeError(msg) + if strict: + msg = "Please set `allow_series=True`" + raise TypeError(msg) + return native_object return Series( PandasLikeSeries( native_object, @@ -559,8 +587,10 @@ def _from_native_impl( # noqa: PLR0915 elif is_pyarrow_table(native_object): pa = get_pyarrow() if series_only: - msg = "Cannot only use `series_only` with arrow table" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with arrow table" + raise TypeError(msg) + return native_object return DataFrame( ArrowDataFrame( native_object, @@ -572,8 +602,10 @@ def _from_native_impl( # noqa: PLR0915 elif is_pyarrow_chunked_array(native_object): pa = get_pyarrow() if not allow_series: - msg = "Please set `allow_series=True`" - raise TypeError(msg) + if strict: + msg = "Please set `allow_series=True`" + raise TypeError(msg) + return native_object return Series( ArrowSeries( native_object, @@ -587,11 +619,15 @@ def _from_native_impl( # noqa: PLR0915 # Dask elif is_dask_dataframe(native_object): if series_only: - msg = "Cannot only use `series_only` with dask DataFrame" - raise TypeError(msg) + if strict: + msg = "Cannot only use `series_only` with dask DataFrame" + raise TypeError(msg) + return native_object if eager_only or eager_or_interchange_only: - msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with dask DataFrame" - raise TypeError(msg) + if strict: + msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with dask DataFrame" + raise TypeError(msg) + return native_object if get_dask_expr() is None: # pragma: no cover msg = "Please install dask-expr" raise ImportError(msg) @@ -607,10 +643,13 @@ def _from_native_impl( # noqa: PLR0915 # DuckDB elif is_duckdb_relation(native_object): if eager_only or series_only: # pragma: no cover - msg = ( - "Cannot only use `series_only=True` or `eager_only=False` " - "with DuckDB Relation" - ) + if strict: + msg = ( + "Cannot only use `series_only=True` or `eager_only=False` " + "with DuckDB Relation" + ) + else: + return native_object raise TypeError(msg) return DataFrame( DuckDBInterchangeFrame(native_object, dtypes=dtypes), @@ -620,11 +659,13 @@ def _from_native_impl( # noqa: PLR0915 # Ibis elif is_ibis_table(native_object): # pragma: no cover if eager_only or series_only: - msg = ( - "Cannot only use `series_only=True` or `eager_only=False` " - "with Ibis table" - ) - raise TypeError(msg) + if strict: + msg = ( + "Cannot only use `series_only=True` or `eager_only=False` " + "with Ibis table" + ) + raise TypeError(msg) + return native_object return DataFrame( IbisInterchangeFrame(native_object, dtypes=dtypes), level="interchange", @@ -633,11 +674,13 @@ def _from_native_impl( # noqa: PLR0915 # Interchange protocol elif hasattr(native_object, "__dataframe__"): if eager_only or series_only: - msg = ( - "Cannot only use `series_only=True` or `eager_only=False` " - "with object which only implements __dataframe__" - ) - raise TypeError(msg) + if strict: + msg = ( + "Cannot only use `series_only=True` or `eager_only=False` " + "with object which only implements __dataframe__" + ) + raise TypeError(msg) + return native_object return DataFrame( InterchangeFrame(native_object, dtypes=dtypes), level="interchange", diff --git a/tests/translate/from_native_test.py b/tests/translate/from_native_test.py index c996b2a2b..53a350878 100644 --- a/tests/translate/from_native_test.py +++ b/tests/translate/from_native_test.py @@ -99,6 +99,8 @@ def test_eager_only_lazy(dframe: Any, eager_only: Any, context: Any) -> None: with context: res = nw.from_native(dframe, eager_only=eager_only) assert isinstance(res, nw.LazyFrame) + if eager_only: + assert nw.from_native(dframe, eager_only=eager_only, strict=False) is dframe @pytest.mark.parametrize("dframe", eager_frames) @@ -122,6 +124,9 @@ def test_series_only(obj: Any, context: Any) -> None: with context: res = nw.from_native(obj, series_only=True) assert isinstance(res, nw.Series) + assert nw.from_native(obj, series_only=True, strict=False) is obj or isinstance( + res, nw.Series + ) @pytest.mark.parametrize("series", all_series) @@ -136,6 +141,8 @@ def test_allow_series(series: Any, allow_series: Any, context: Any) -> None: with context: res = nw.from_native(series, allow_series=allow_series) assert isinstance(res, nw.Series) + if not allow_series: + assert nw.from_native(series, allow_series=allow_series, strict=False) is series def test_invalid_series_combination() -> None: @@ -184,6 +191,7 @@ def test_series_only_dask() -> None: with pytest.raises(TypeError, match="Cannot only use `series_only`"): nw.from_native(dframe, series_only=True) + assert nw.from_native(dframe, series_only=True, strict=False) is dframe @pytest.mark.parametrize( @@ -203,6 +211,8 @@ def test_eager_only_lazy_dask(eager_only: Any, context: Any) -> None: with context: res = nw.from_native(dframe, eager_only=eager_only) assert isinstance(res, nw.LazyFrame) + if eager_only: + assert nw.from_native(dframe, eager_only=eager_only, strict=False) is dframe def test_from_native_strict_false_typing() -> None: @@ -214,3 +224,13 @@ def test_from_native_strict_false_typing() -> None: unstable_nw.from_native(df, strict=False) unstable_nw.from_native(df, strict=False, eager_only=True) unstable_nw.from_native(df, strict=False, eager_or_interchange_only=True) + + +def test_from_mock_interchange_protocol_non_strict() -> None: + class MockDf: + def __dataframe__(self) -> None: # pragma: no cover + pass + + mockdf = MockDf() + result = nw.from_native(mockdf, eager_only=True, strict=False) + assert result is mockdf # type: ignore[comparison-overlap]