From 4c88de0e1d308555a10d3fea91fbd0ae08d986ce Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sun, 30 Jun 2024 20:34:29 +0100 Subject: [PATCH] fix: compatibility with old numpy versions (#364) --- .github/workflows/extremes.yml | 2 +- narwhals/_pandas_like/utils.py | 45 +++++++++++++++++++++++----------- tests/series/test_common.py | 2 +- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index 58c2adc2d..eba7114ce 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -29,7 +29,7 @@ jobs: - name: install-reqs run: python -m pip install --upgrade tox virtualenv setuptools pip -r requirements-dev.txt - name: install-modin - run: python -m pip install pandas==1.1.5 polars==0.20.3 "numpy<=1.21" "pyarrow==11.0.0" tzdata + run: python -m pip install pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata - name: Run pytest run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=50 --runslow - name: Run doctests diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index e1785d262..d679058b3 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -382,31 +382,48 @@ def translate_dtype(column: Any) -> DType: from narwhals import dtypes dtype = column.dtype - if dtype in ("int64", "Int64", "Int64[pyarrow]"): + if str(dtype) in ("int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"): return dtypes.Int64() - if dtype in ("int32", "Int32", "Int32[pyarrow]"): + if str(dtype) in ("int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"): return dtypes.Int32() - if dtype in ("int16", "Int16", "Int16[pyarrow]"): + if str(dtype) in ("int16", "Int16", "Int16[pyarrow]", "int16[pyarrow]"): return dtypes.Int16() - if dtype in ("int8", "Int8", "Int8[pyarrow]"): + if str(dtype) in ("int8", "Int8", "Int8[pyarrow]", "int8[pyarrow]"): return dtypes.Int8() - if dtype in ("uint64", "UInt64", "UInt64[pyarrow]"): + if str(dtype) in ("uint64", "UInt64", "UInt64[pyarrow]", "uint64[pyarrow]"): return dtypes.UInt64() - if dtype in ("uint32", "UInt32", "UInt32[pyarrow]"): + if str(dtype) in ("uint32", "UInt32", "UInt32[pyarrow]", "uint32[pyarrow]"): return dtypes.UInt32() - if dtype in ("uint16", "UInt16", "UInt16[pyarrow]"): + if str(dtype) in ("uint16", "UInt16", "UInt16[pyarrow]", "uint16[pyarrow]"): return dtypes.UInt16() - if dtype in ("uint8", "UInt8", "UInt8[pyarrow]"): + if str(dtype) in ("uint8", "UInt8", "UInt8[pyarrow]", "uint8[pyarrow]"): return dtypes.UInt8() - if dtype in ("float64", "Float64", "Float64[pyarrow]"): + if str(dtype) in ( + "float64", + "Float64", + "Float64[pyarrow]", + "float64[pyarrow]", + "double[pyarrow]", + ): return dtypes.Float64() - if dtype in ("float32", "Float32", "Float32[pyarrow]"): + if str(dtype) in ( + "float32", + "Float32", + "Float32[pyarrow]", + "float32[pyarrow]", + "float[pyarrow]", + ): return dtypes.Float32() - if dtype in ("string", "string[python]", "string[pyarrow]", "large_string[pyarrow]"): + if str(dtype) in ( + "string", + "string[python]", + "string[pyarrow]", + "large_string[pyarrow]", + ): return dtypes.String() - if dtype in ("bool", "boolean", "boolean[pyarrow]"): + if str(dtype) in ("bool", "boolean", "boolean[pyarrow]", "bool[pyarrow]"): return dtypes.Boolean() - if dtype in ("category",) or str(dtype).startswith("dictionary<"): + if str(dtype) in ("category",) or str(dtype).startswith("dictionary<"): return dtypes.Categorical() if str(dtype).startswith("datetime64"): # todo: different time units and time zones @@ -420,7 +437,7 @@ def translate_dtype(column: Any) -> DType: return dtypes.Datetime() if str(dtype) == "date32[day][pyarrow]": return dtypes.Date() - if dtype == "object": + if str(dtype) == "object": if (idx := column.first_valid_index()) is not None and isinstance( column.loc[idx], str ): diff --git a/tests/series/test_common.py b/tests/series/test_common.py index 8eb864f8d..38fb798f4 100644 --- a/tests/series/test_common.py +++ b/tests/series/test_common.py @@ -463,7 +463,7 @@ def test_cast_string() -> None: s = nw.from_native(s_pd, series_only=True) s = s.cast(nw.String) result = nw.to_native(s) - assert result.dtype in ("string", object) + assert str(result.dtype) in ("string", "object", "dtype('O')") df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})