From 5e49378721545c619b02a3fdc6aeb1bc56c426bd Mon Sep 17 00:00:00 2001
From: Marco Edward Gorelli
Date: Fri, 27 Sep 2024 17:38:43 +0200
Subject: [PATCH] feat: Initial support for nested dtypes (List, Array, Struct) (#1083)

---
Review notes (commentary after the three-dash line; not part of the commit):
- narwhals_to_native_dtype (arrow, dask, pandas-like, polars): the new
  List/Struct/Array branches now raise NotImplementedError instead of
  returning the exception instance, matching the existing Enum branch in
  the pandas-like backend.
- map_duckdb_dtype_to_narwhals_dtype: the outermost type is decided from the
  suffix first, so nested types such as STRUCT(a BIGINT)[] or INTEGER[2][]
  are reported as List rather than Struct/Array.
- Hunk sizes and the diffstat are unchanged by these fixes.

 docs/api-reference/dtypes.md       |  3 ++
 narwhals/__init__.py               |  6 +++
 narwhals/_arrow/namespace.py       |  3 ++
 narwhals/_arrow/utils.py           | 15 +++++++
 narwhals/_dask/expr.py             |  4 +-
 narwhals/_dask/namespace.py        |  7 ++-
 narwhals/_dask/utils.py            | 11 ++++-
 narwhals/_duckdb/dataframe.py      |  8 ++++
 narwhals/_ibis/dataframe.py        |  4 ++
 narwhals/_pandas_like/namespace.py |  3 ++
 narwhals/_pandas_like/utils.py     | 15 +++++++
 narwhals/_polars/namespace.py      |  3 ++
 narwhals/_polars/utils.py          | 15 +++++++
 narwhals/dtypes.py                 |  9 ++++
 narwhals/stable/v1.py              |  6 +++
 tests/frame/schema_test.py         | 70 +++++++++++++++++++++++++++---
 16 files changed, 172 insertions(+), 10 deletions(-)

diff --git a/docs/api-reference/dtypes.md b/docs/api-reference/dtypes.md
index c21b5c766..eb96608a6 100644
--- a/docs/api-reference/dtypes.md
+++ b/docs/api-reference/dtypes.md
@@ -4,6 +4,9 @@
     handler: python
     options:
       members:
+        - Array
+        - List
+        - Struct
         - Int64
         - Int32
         - Int16
diff --git a/narwhals/__init__.py b/narwhals/__init__.py
index a5f95cf70..8b8529c06 100644
--- a/narwhals/__init__.py
+++ b/narwhals/__init__.py
@@ -3,6 +3,7 @@
 from narwhals import stable
 from narwhals.dataframe import DataFrame
 from narwhals.dataframe import LazyFrame
+from narwhals.dtypes import Array
 from narwhals.dtypes import Boolean
 from narwhals.dtypes import Categorical
 from narwhals.dtypes import Date
@@ -15,8 +16,10 @@
 from narwhals.dtypes import Int16
 from narwhals.dtypes import Int32
 from narwhals.dtypes import Int64
+from narwhals.dtypes import List
 from narwhals.dtypes import Object
 from narwhals.dtypes import String
+from narwhals.dtypes import Struct
 from narwhals.dtypes import UInt8
 from narwhals.dtypes import UInt16
 from narwhals.dtypes import UInt32
@@ -107,6 +110,9 @@
     "String",
     "Datetime",
     "Duration",
+    "Struct",
+    "Array",
+    "List",
     "Date",
     "narwhalify",
     "show_versions",
diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py
index 5a608f12e..6f699fa2a 100644
--- a/narwhals/_arrow/namespace.py
+++ b/narwhals/_arrow/namespace.py
@@ -44,6 +44,9 @@ class ArrowNamespace:
     Datetime = dtypes.Datetime
     Duration = dtypes.Duration
     Date = dtypes.Date
+    List = dtypes.List
+    Struct = dtypes.Struct
+    Array = dtypes.Array
 
     def _create_expr_from_callable(
         self,
diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py
index ddf7a8639..7d6844e57 100644
--- a/narwhals/_arrow/utils.py
+++ b/narwhals/_arrow/utils.py
@@ -54,6 +54,12 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
         return dtypes.Duration()
     if pa.types.is_dictionary(dtype):
         return dtypes.Categorical()
+    if pa.types.is_struct(dtype):
+        return dtypes.Struct()
+    if pa.types.is_list(dtype) or pa.types.is_large_list(dtype):
+        return dtypes.List()
+    if pa.types.is_fixed_size_list(dtype):
+        return dtypes.Array()
     return dtypes.Unknown()  # pragma: no cover
 
 
@@ -96,6 +102,15 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
         return pa.duration("us")
     if isinstance_or_issubclass(dtype, dtypes.Date):
         return pa.date32()
+    if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
+        msg = "Converting to List dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
+        msg = "Converting to Struct dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
+        msg = "Converting to Array dtype is not supported yet"
+        raise NotImplementedError(msg)
     msg = f"Unknown dtype: {dtype}"  # pragma: no cover
     raise AssertionError(msg)
 
diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py
index a6eb17566..eda0fd589 100644
--- a/narwhals/_dask/expr.py
+++ b/narwhals/_dask/expr.py
@@ -9,7 +9,7 @@
 
 from narwhals._dask.utils import add_row_index
 from narwhals._dask.utils import maybe_evaluate
-from narwhals._dask.utils import reverse_translate_dtype
+from narwhals._dask.utils import narwhals_to_native_dtype
 from narwhals.utils import generate_unique_token
 
 if TYPE_CHECKING:
@@ -700,7 +700,7 @@ def cast(
         dtype: DType | type[DType],
     ) -> Self:
         def func(_input: Any, dtype: DType | type[DType]) -> Any:
-            dtype = reverse_translate_dtype(dtype)
+            dtype = narwhals_to_native_dtype(dtype)
             return _input.astype(dtype)
 
         return self._from_call(
diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py
index a683b39b1..7d661f063 100644
--- a/narwhals/_dask/namespace.py
+++ b/narwhals/_dask/namespace.py
@@ -13,7 +13,7 @@
 from narwhals._dask.dataframe import DaskLazyFrame
 from narwhals._dask.expr import DaskExpr
 from narwhals._dask.selectors import DaskSelectorNamespace
-from narwhals._dask.utils import reverse_translate_dtype
+from narwhals._dask.utils import narwhals_to_native_dtype
 from narwhals._dask.utils import validate_comparand
 from narwhals._expression_parsing import combine_root_names
 from narwhals._expression_parsing import parse_into_exprs
@@ -45,6 +45,9 @@ class DaskNamespace:
     Datetime = dtypes.Datetime
     Duration = dtypes.Duration
     Date = dtypes.Date
+    List = dtypes.List
+    Struct = dtypes.Struct
+    Array = dtypes.Array
 
     @property
     def selectors(self) -> DaskSelectorNamespace:
@@ -83,7 +86,7 @@ def lit(self, value: Any, dtype: dtypes.DType | None) -> DaskExpr:
         def convert_if_dtype(
             series: dask_expr.Series, dtype: DType | type[DType]
         ) -> dask_expr.Series:
-            return series.astype(reverse_translate_dtype(dtype)) if dtype else series
+            return series.astype(narwhals_to_native_dtype(dtype)) if dtype else series
 
         return DaskExpr(
             lambda df: [
diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py
index 274044979..02dedab4e 100644
--- a/narwhals/_dask/utils.py
+++ b/narwhals/_dask/utils.py
@@ -83,7 +83,7 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:
         raise RuntimeError(msg)
 
 
-def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
+def narwhals_to_native_dtype(dtype: DType | type[DType]) -> Any:
     from narwhals import dtypes
 
     if isinstance_or_issubclass(dtype, dtypes.Float64):
@@ -122,6 +122,15 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
         return "datetime64[us]"
     if isinstance_or_issubclass(dtype, dtypes.Duration):
         return "timedelta64[ns]"
+    if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
+        msg = "Converting to List dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
+        msg = "Converting to Struct dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
+        msg = "Converting to Array dtype is not supported yet"
+        raise NotImplementedError(msg)
     msg = f"Unknown dtype: {dtype}"  # pragma: no cover
     raise AssertionError(msg)
 
diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py
index 2263c3fc7..099a91b72 100644
--- a/narwhals/_duckdb/dataframe.py
+++ b/narwhals/_duckdb/dataframe.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import re
 from typing import TYPE_CHECKING
 from typing import Any
 
@@ -17,6 +18,7 @@
 def map_duckdb_dtype_to_narwhals_dtype(
     duckdb_dtype: Any,
 ) -> dtypes.DType:
+    duckdb_dtype = str(duckdb_dtype)
     if duckdb_dtype == "BIGINT":
         return dtypes.Int64()
     if duckdb_dtype == "INTEGER":
@@ -47,6 +49,12 @@ def map_duckdb_dtype_to_narwhals_dtype(
         return dtypes.Boolean()
     if duckdb_dtype == "INTERVAL":
         return dtypes.Duration()
+    if duckdb_dtype.endswith("[]"):
+        return dtypes.List()
+    if re.search(r"\[\d+\]$", duckdb_dtype):
+        return dtypes.Array()
+    if duckdb_dtype.startswith("STRUCT"):
+        return dtypes.Struct()
     return dtypes.Unknown()
 
 
diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py
index fb7bedbf1..f0dc8f6eb 100644
--- a/narwhals/_ibis/dataframe.py
+++ b/narwhals/_ibis/dataframe.py
@@ -44,6 +44,10 @@ def map_ibis_dtype_to_narwhals_dtype(
         return dtypes.Date()
     if ibis_dtype.is_timestamp():
         return dtypes.Datetime()
+    if ibis_dtype.is_array():
+        return dtypes.List()
+    if ibis_dtype.is_struct():
+        return dtypes.Struct()
     return dtypes.Unknown()  # pragma: no cover
 
 
diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py
index b60d0dcce..357ef80ab 100644
--- a/narwhals/_pandas_like/namespace.py
+++ b/narwhals/_pandas_like/namespace.py
@@ -44,6 +44,9 @@ class PandasLikeNamespace:
     Datetime = dtypes.Datetime
     Duration = dtypes.Duration
     Date = dtypes.Date
+    List = dtypes.List
+    Struct = dtypes.Struct
+    Array = dtypes.Array
 
     @property
     def selectors(self) -> PandasSelectorNamespace:
diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py
index aadb438e2..5745ffd8a 100644
--- a/narwhals/_pandas_like/utils.py
+++ b/narwhals/_pandas_like/utils.py
@@ -256,6 +256,12 @@ def translate_dtype(column: Any) -> DType:
         return dtypes.Duration()
     if dtype == "date32[day][pyarrow]":
         return dtypes.Date()
+    if dtype.startswith(("large_list", "list")):
+        return dtypes.List()
+    if dtype.startswith("fixed_size_list"):
+        return dtypes.Array()
+    if dtype.startswith("struct"):
+        return dtypes.Struct()
     if dtype == "object":
         if (  # pragma: no cover TODO(unassigned): why does this show as uncovered?
             idx := getattr(column, "first_valid_index", lambda: None)()
@@ -423,6 +429,15 @@ def narwhals_to_native_dtype(  # noqa: PLR0915
     if isinstance_or_issubclass(dtype, dtypes.Enum):
         msg = "Converting to Enum is not (yet) supported"
         raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
+        msg = "Converting to List dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
+        msg = "Converting to Struct dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
+        msg = "Converting to Array dtype is not supported yet"
+        raise NotImplementedError(msg)
     msg = f"Unknown dtype: {dtype}"  # pragma: no cover
     raise AssertionError(msg)
 
diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py
index f25e8b81f..275c104fc 100644
--- a/narwhals/_polars/namespace.py
+++ b/narwhals/_polars/namespace.py
@@ -40,6 +40,9 @@ class PolarsNamespace:
     Datetime = dtypes.Datetime
     Duration = dtypes.Duration
     Date = dtypes.Date
+    List = dtypes.List
+    Struct = dtypes.Struct
+    Array = dtypes.Array
 
     def __init__(self, *, backend_version: tuple[int, ...]) -> None:
         self._backend_version = backend_version
diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py
index e6fb36859..4a9809fc4 100644
--- a/narwhals/_polars/utils.py
+++ b/narwhals/_polars/utils.py
@@ -65,6 +65,12 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
         return dtypes.Duration()
     if dtype == pl.Date:
         return dtypes.Date()
+    if dtype == pl.Struct:
+        return dtypes.Struct()
+    if dtype == pl.List:
+        return dtypes.List()
+    if dtype == pl.Array:
+        return dtypes.Array()
     return dtypes.Unknown()
 
 
@@ -110,6 +116,15 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
         return pl.Duration()
     if dtype == dtypes.Date:
         return pl.Date()
+    if dtype == dtypes.List:  # pragma: no cover
+        msg = "Converting to List dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if dtype == dtypes.Struct:  # pragma: no cover
+        msg = "Converting to Struct dtype is not supported yet"
+        raise NotImplementedError(msg)
+    if dtype == dtypes.Array:  # pragma: no cover
+        msg = "Converting to Array dtype is not supported yet"
+        raise NotImplementedError(msg)
     return pl.Unknown()  # pragma: no cover
 
 
diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py
index 4d8da4293..2d5de0f16 100644
--- a/narwhals/dtypes.py
+++ b/narwhals/dtypes.py
@@ -83,4 +83,13 @@ class Categorical(DType): ...
 class Enum(DType): ...
 
 
+class Struct(DType): ...
+
+
+class List(DType): ...
+
+
+class Array(DType): ...
+
+
 class Date(TemporalType): ...
diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py
index cc0a42bed..b54203ca2 100644
--- a/narwhals/stable/v1.py
+++ b/narwhals/stable/v1.py
@@ -14,6 +14,7 @@
 from narwhals import selectors
 from narwhals.dataframe import DataFrame as NwDataFrame
 from narwhals.dataframe import LazyFrame as NwLazyFrame
+from narwhals.dtypes import Array
 from narwhals.dtypes import Boolean
 from narwhals.dtypes import Categorical
 from narwhals.dtypes import Date
@@ -26,8 +27,10 @@
 from narwhals.dtypes import Int16
 from narwhals.dtypes import Int32
 from narwhals.dtypes import Int64
+from narwhals.dtypes import List
 from narwhals.dtypes import Object
 from narwhals.dtypes import String
+from narwhals.dtypes import Struct
 from narwhals.dtypes import UInt8
 from narwhals.dtypes import UInt16
 from narwhals.dtypes import UInt32
@@ -1970,6 +1973,9 @@ def from_dict(
     "String",
     "Datetime",
     "Duration",
+    "Struct",
+    "Array",
+    "List",
     "Date",
     "narwhalify",
     "show_versions",
diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py
index d7ba69ab2..3aa341e0f 100644
--- a/tests/frame/schema_test.py
+++ b/tests/frame/schema_test.py
@@ -4,6 +4,7 @@
 from datetime import timezone
 from typing import Any
 
+import duckdb
 import pandas as pd
 import polars as pl
 import pytest
@@ -95,6 +96,8 @@ def test_dtypes() -> None:
             "p": ["a"],
             "q": [timedelta(1)],
             "r": ["a"],
+            "s": [[1, 2]],
+            "u": [{"a": 1}],
         },
         schema={
             "a": pl.Int64,
@@ -115,6 +118,8 @@ def test_dtypes() -> None:
             "p": pl.Categorical,
             "q": pl.Duration,
             "r": pl.Enum(["a", "b"]),
+            "s": pl.List(pl.Int64),
+            "u": pl.Struct({"a": pl.Int64}),
         },
     )
     df_from_pl = nw.from_native(df_pl, eager_only=True)
@@ -137,6 +142,8 @@ def test_dtypes() -> None:
         "p": nw.Categorical,
         "q": nw.Duration,
         "r": nw.Enum,
+        "s": nw.List,
+        "u": nw.Struct,
     }
 
     assert df_from_pl.schema == df_from_pl.collect_schema()
@@ -164,11 +171,6 @@ def test_unknown_dtype() -> None:
     assert nw.from_native(df).schema == {"a": nw.Unknown}
 
 
-def test_unknown_dtype_polars() -> None:
-    df = pl.DataFrame({"a": [[1, 2, 3]]})
-    assert nw.from_native(df).schema == {"a": nw.Unknown}
-
-
 def test_hash() -> None:
     assert nw.Int64() in {nw.Int64, nw.Int32}
 
@@ -199,3 +201,61 @@ def test_from_non_hashable_column_name() -> None:
     df = nw.from_native(df, eager_only=True)
     assert df.columns == ["pizza", ["a", "b"]]
     assert df["pizza"].dtype == nw.Int64
+
+
+@pytest.mark.skipif(
+    parse_version(pd.__version__) < parse_version("2.2.0"),
+    reason="too old for pyarrow types",
+)
+def test_nested_dtypes() -> None:
+    df = pl.DataFrame(
+        {"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
+        schema_overrides={"b": pl.Array(pl.Int64, 2)},
+    ).to_pandas(use_pyarrow_extension_array=True)
+    nwdf = nw.from_native(df)
+    assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
+    df = pl.DataFrame(
+        {"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
+        schema_overrides={"b": pl.Array(pl.Int64, 2)},
+    )
+    nwdf = nw.from_native(df)
+    assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
+    df = pl.DataFrame(
+        {"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
+        schema_overrides={"b": pl.Array(pl.Int64, 2)},
+    ).to_arrow()
+    nwdf = nw.from_native(df)
+    assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
+    df = duckdb.sql("select * from df")
+    nwdf = nw.from_native(df)
+    assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
+
+
+def test_nested_dtypes_ibis() -> None:  # pragma: no cover
+    ibis = pytest.importorskip("ibis")
+    df = pl.DataFrame(
+        {"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
+        schema_overrides={"b": pl.Array(pl.Int64, 2)},
+    )
+    tbl = ibis.memtable(df[["a", "c"]])
+    nwdf = nw.from_native(tbl)
+    assert nwdf.schema == {"a": nw.List, "c": nw.Struct}
+
+
+@pytest.mark.skipif(
+    parse_version(pd.__version__) < parse_version("2.2.0"),
+    reason="too old for pyarrow types",
+)
+def test_nested_dtypes_dask() -> None:
+    pytest.importorskip("dask")
+    pytest.importorskip("dask_expr", exc_type=ImportError)
+    import dask.dataframe as dd
+
+    df = dd.from_pandas(
+        pl.DataFrame(
+            {"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
+            schema_overrides={"b": pl.Array(pl.Int64, 2)},
+        ).to_pandas(use_pyarrow_extension_array=True)
+    )
+    nwdf = nw.from_native(df)
+    assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}