Skip to content

Commit

Permalink
feat: Initial support for nested dtypes (List, Array, Struct) (#1083)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Sep 27, 2024
1 parent a9963a6 commit 5e49378
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 10 deletions.
3 changes: 3 additions & 0 deletions docs/api-reference/dtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
handler: python
options:
members:
- Array
- List
- Struct
- Int64
- Int32
- Int16
Expand Down
6 changes: 6 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from narwhals import stable
from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import Array
from narwhals.dtypes import Boolean
from narwhals.dtypes import Categorical
from narwhals.dtypes import Date
Expand All @@ -15,8 +16,10 @@
from narwhals.dtypes import Int16
from narwhals.dtypes import Int32
from narwhals.dtypes import Int64
from narwhals.dtypes import List
from narwhals.dtypes import Object
from narwhals.dtypes import String
from narwhals.dtypes import Struct
from narwhals.dtypes import UInt8
from narwhals.dtypes import UInt16
from narwhals.dtypes import UInt32
Expand Down Expand Up @@ -107,6 +110,9 @@
"String",
"Datetime",
"Duration",
"Struct",
"Array",
"List",
"Date",
"narwhalify",
"show_versions",
Expand Down
3 changes: 3 additions & 0 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class ArrowNamespace:
Datetime = dtypes.Datetime
Duration = dtypes.Duration
Date = dtypes.Date
List = dtypes.List
Struct = dtypes.Struct
Array = dtypes.Array

def _create_expr_from_callable(
self,
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
return dtypes.Duration()
if pa.types.is_dictionary(dtype):
return dtypes.Categorical()
if pa.types.is_struct(dtype):
return dtypes.Struct()
if pa.types.is_list(dtype) or pa.types.is_large_list(dtype):
return dtypes.List()
if pa.types.is_fixed_size_list(dtype):
return dtypes.Array()
return dtypes.Unknown() # pragma: no cover


Expand Down Expand Up @@ -96,6 +102,15 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
return pa.duration("us")
if isinstance_or_issubclass(dtype, dtypes.Date):
return pa.date32()
# Nested dtypes cannot be converted to a native pyarrow type yet.
# The exception must be *raised*: returning it would hand the caller an
# exception instance in place of a dtype and defer the failure to a
# confusing error further downstream.
if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
    msg = "Converting to List dtype is not supported yet"
    raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
    msg = "Converting to Struct dtype is not supported yet"
    raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
    msg = "Converting to Array dtype is not supported yet"
    raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
4 changes: 2 additions & 2 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate
from narwhals._dask.utils import reverse_translate_dtype
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals.utils import generate_unique_token

if TYPE_CHECKING:
Expand Down Expand Up @@ -700,7 +700,7 @@ def cast(
dtype: DType | type[DType],
) -> Self:
def func(_input: Any, dtype: DType | type[DType]) -> Any:
dtype = reverse_translate_dtype(dtype)
dtype = narwhals_to_native_dtype(dtype)
return _input.astype(dtype)

return self._from_call(
Expand Down
7 changes: 5 additions & 2 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import reverse_translate_dtype
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._dask.utils import validate_comparand
from narwhals._expression_parsing import combine_root_names
from narwhals._expression_parsing import parse_into_exprs
Expand Down Expand Up @@ -45,6 +45,9 @@ class DaskNamespace:
Datetime = dtypes.Datetime
Duration = dtypes.Duration
Date = dtypes.Date
List = dtypes.List
Struct = dtypes.Struct
Array = dtypes.Array

@property
def selectors(self) -> DaskSelectorNamespace:
Expand Down Expand Up @@ -83,7 +86,7 @@ def lit(self, value: Any, dtype: dtypes.DType | None) -> DaskExpr:
def convert_if_dtype(
series: dask_expr.Series, dtype: DType | type[DType]
) -> dask_expr.Series:
return series.astype(reverse_translate_dtype(dtype)) if dtype else series
return series.astype(narwhals_to_native_dtype(dtype)) if dtype else series

return DaskExpr(
lambda df: [
Expand Down
11 changes: 10 additions & 1 deletion narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:
raise RuntimeError(msg)


def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
def narwhals_to_native_dtype(dtype: DType | type[DType]) -> Any:
from narwhals import dtypes

if isinstance_or_issubclass(dtype, dtypes.Float64):
Expand Down Expand Up @@ -122,6 +122,15 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
return "datetime64[us]"
if isinstance_or_issubclass(dtype, dtypes.Duration):
return "timedelta64[ns]"
# Nested dtypes have no supported native conversion here yet.
# Raise (rather than return) the exception so an exception object is never
# passed on to `astype` as if it were a dtype.
if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
    msg = "Converting to List dtype is not supported yet"
    raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
    msg = "Converting to Struct dtype is not supported yet"
    raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
    msg = "Converting to Array dtype is not supported yet"
    raise NotImplementedError(msg)

msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
8 changes: 8 additions & 0 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING
from typing import Any

Expand All @@ -17,6 +18,7 @@
def map_duckdb_dtype_to_narwhals_dtype(
duckdb_dtype: Any,
) -> dtypes.DType:
duckdb_dtype = str(duckdb_dtype)
if duckdb_dtype == "BIGINT":
return dtypes.Int64()
if duckdb_dtype == "INTEGER":
Expand Down Expand Up @@ -47,6 +49,12 @@ def map_duckdb_dtype_to_narwhals_dtype(
return dtypes.Boolean()
if duckdb_dtype == "INTERVAL":
return dtypes.Duration()
if duckdb_dtype.startswith("STRUCT"):
return dtypes.Struct()
if re.match(r"\w+\[\]", duckdb_dtype):
return dtypes.List()
if re.match(r"\w+\[\d+\]", duckdb_dtype):
return dtypes.Array()
return dtypes.Unknown()


Expand Down
4 changes: 4 additions & 0 deletions narwhals/_ibis/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def map_ibis_dtype_to_narwhals_dtype(
return dtypes.Date()
if ibis_dtype.is_timestamp():
return dtypes.Datetime()
if ibis_dtype.is_array():
return dtypes.List()
if ibis_dtype.is_struct():
return dtypes.Struct()
return dtypes.Unknown() # pragma: no cover


Expand Down
3 changes: 3 additions & 0 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class PandasLikeNamespace:
Datetime = dtypes.Datetime
Duration = dtypes.Duration
Date = dtypes.Date
List = dtypes.List
Struct = dtypes.Struct
Array = dtypes.Array

@property
def selectors(self) -> PandasSelectorNamespace:
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ def translate_dtype(column: Any) -> DType:
return dtypes.Duration()
if dtype == "date32[day][pyarrow]":
return dtypes.Date()
if dtype.startswith(("large_list", "list")):
return dtypes.List()
if dtype.startswith("fixed_size_list"):
return dtypes.Array()
if dtype.startswith("struct"):
return dtypes.Struct()
if dtype == "object":
if ( # pragma: no cover TODO(unassigned): why does this show as uncovered?
idx := getattr(column, "first_valid_index", lambda: None)()
Expand Down Expand Up @@ -423,6 +429,15 @@ def narwhals_to_native_dtype( # noqa: PLR0915
if isinstance_or_issubclass(dtype, dtypes.Enum):
msg = "Converting to Enum is not (yet) supported"
raise NotImplementedError(msg)
# Nested dtypes cannot be converted to a native dtype yet. Raise the error,
# consistent with the Enum branch above; returning the exception instance
# would let it be used as if it were a dtype.
if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
    msg = "Converting to List dtype is not supported yet"
    raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
    msg = "Converting to Struct dtype is not supported yet"
    raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
    msg = "Converting to Array dtype is not supported yet"
    raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
3 changes: 3 additions & 0 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class PolarsNamespace:
Datetime = dtypes.Datetime
Duration = dtypes.Duration
Date = dtypes.Date
List = dtypes.List
Struct = dtypes.Struct
Array = dtypes.Array

def __init__(self, *, backend_version: tuple[int, ...]) -> None:
self._backend_version = backend_version
Expand Down
15 changes: 15 additions & 0 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def translate_dtype(dtype: Any) -> dtypes.DType:
return dtypes.Duration()
if dtype == pl.Date:
return dtypes.Date()
if dtype == pl.Struct:
return dtypes.Struct()
if dtype == pl.List:
return dtypes.List()
if dtype == pl.Array:
return dtypes.Array()
return dtypes.Unknown()


Expand Down Expand Up @@ -110,6 +116,15 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
return pl.Duration()
if dtype == dtypes.Date:
return pl.Date()
# Nested dtypes cannot be converted to a native polars dtype yet.
# Raise the exception instead of returning it, so callers never receive an
# exception instance where a polars dtype is expected.
if dtype == dtypes.List:  # pragma: no cover
    msg = "Converting to List dtype is not supported yet"
    raise NotImplementedError(msg)
if dtype == dtypes.Struct:  # pragma: no cover
    msg = "Converting to Struct dtype is not supported yet"
    raise NotImplementedError(msg)
if dtype == dtypes.Array:  # pragma: no cover
    msg = "Converting to Array dtype is not supported yet"
    raise NotImplementedError(msg)
return pl.Unknown() # pragma: no cover


Expand Down
9 changes: 9 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,13 @@ class Categorical(DType): ...
class Enum(DType): ...


class Struct(DType):
    """Nested dtype marker for struct columns.

    Carries no field names or inner dtypes; it only identifies the column
    as a struct.
    """


class List(DType):
    """Nested dtype marker for variable-length list columns.

    Carries no inner dtype; it only identifies the column as a list.
    """


class Array(DType):
    """Nested dtype marker for fixed-size list (array) columns.

    Carries no inner dtype or width; it only identifies the column as a
    fixed-size array.
    """


class Date(TemporalType): ...
6 changes: 6 additions & 0 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from narwhals import selectors
from narwhals.dataframe import DataFrame as NwDataFrame
from narwhals.dataframe import LazyFrame as NwLazyFrame
from narwhals.dtypes import Array
from narwhals.dtypes import Boolean
from narwhals.dtypes import Categorical
from narwhals.dtypes import Date
Expand All @@ -26,8 +27,10 @@
from narwhals.dtypes import Int16
from narwhals.dtypes import Int32
from narwhals.dtypes import Int64
from narwhals.dtypes import List
from narwhals.dtypes import Object
from narwhals.dtypes import String
from narwhals.dtypes import Struct
from narwhals.dtypes import UInt8
from narwhals.dtypes import UInt16
from narwhals.dtypes import UInt32
Expand Down Expand Up @@ -1970,6 +1973,9 @@ def from_dict(
"String",
"Datetime",
"Duration",
"Struct",
"Array",
"List",
"Date",
"narwhalify",
"show_versions",
Expand Down
Loading

0 comments on commit 5e49378

Please sign in to comment.