Skip to content

Commit

Permalink
feat: allow inspecting the inner type / length of nw.Array (narwhals-…
Browse files Browse the repository at this point in the history
…dev#1136)


---------

Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
DeaMariaLeon and MarcoGorelli authored Oct 6, 2024
1 parent 5aa4e12 commit 3e0405d
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 9 deletions.
4 changes: 3 additions & 1 deletion narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
if pa.types.is_list(dtype) or pa.types.is_large_list(dtype):
return dtypes.List(native_to_narwhals_dtype(dtype.value_type, dtypes))
if pa.types.is_fixed_size_list(dtype):
return dtypes.Array()
return dtypes.Array(
native_to_narwhals_dtype(dtype.value_type, dtypes), dtype.list_size
)
return dtypes.Unknown() # pragma: no cover


Expand Down
7 changes: 5 additions & 2 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,11 @@ def map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype: Any, dtypes: DTypes) -> DTy
return dtypes.Struct()
if match_ := re.match(r"(.*)\[\]$", duckdb_dtype):
return dtypes.List(map_duckdb_dtype_to_narwhals_dtype(match_.group(1), dtypes))
if re.match(r"\w+\[\d+\]", duckdb_dtype):
return dtypes.Array()
if match_ := re.match(r"(\w+)\[(\d+)\]", duckdb_dtype):
return dtypes.Array(
map_duckdb_dtype_to_narwhals_dtype(match_.group(1), dtypes),
int(match_.group(2)),
)
return dtypes.Unknown()


Expand Down
5 changes: 4 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,10 @@ def native_to_narwhals_dtype(column: Any, dtypes: DTypes) -> DType:
arrow_native_to_narwhals_dtype(column.dtype.pyarrow_dtype.value_type, dtypes)
)
if dtype.startswith("fixed_size_list"):
return dtypes.Array()
return dtypes.Array(
arrow_native_to_narwhals_dtype(column.dtype.pyarrow_dtype.value_type, dtypes),
column.dtype.pyarrow_dtype.list_size,
)
if dtype.startswith("struct"):
return dtypes.Struct()
if dtype == "object":
Expand Down
9 changes: 8 additions & 1 deletion narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from narwhals.dtypes import DType
from narwhals.typing import DTypes

from narwhals.utils import parse_version


def extract_native(obj: Any) -> Any:
from narwhals._polars.dataframe import PolarsDataFrame
Expand Down Expand Up @@ -77,7 +79,12 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
if dtype == pl.List:
return dtypes.List(native_to_narwhals_dtype(dtype.inner, dtypes))
if dtype == pl.Array:
return dtypes.Array()
if parse_version(pl.__version__) < (1, 0): # pragma: no cover
return dtypes.Array(
native_to_narwhals_dtype(dtype.inner, dtypes), dtype.width
)
else:
return dtypes.Array(native_to_narwhals_dtype(dtype.inner, dtypes), dtype.size)
return dtypes.Unknown()


Expand Down
30 changes: 29 additions & 1 deletion narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,35 @@ def __repr__(self) -> str:
return f"{class_name}({self.inner!r})"


class Array(DType): ...
class Array(DType):
    """Fixed-width array data type.

    Arguments:
        inner: The data type of the values held in each array.
        width: The number of elements in each array. Although the signature
            gives it a default of `None`, it is mandatory.

    Raises:
        TypeError: If `width` is not specified.
    """

    def __init__(self, inner: DType | type[DType], width: int | None = None) -> None:
        if width is None:
            error = "`width` must be specified when initializing an `Array`"
            raise TypeError(error)
        self.inner = inner
        self.width = width

    def __eq__(self, other: DType | type[DType]) -> bool:  # type: ignore[override]
        # This equality check allows comparison of type classes and type instances.
        # If a parent type is not specific about its inner type, we infer it as equal:
        # > array[i64, 2] == array[i64, 2] -> True
        # > array[i64, 2] == array[f32, 2] -> False
        # > array[i64, 2] == array[i64, 3] -> False
        # > array[i64, 2] == array         -> True

        # allow comparing object instances to class
        if type(other) is type and issubclass(other, self.__class__):
            return True
        if isinstance(other, self.__class__):
            # `width` takes part in `__hash__`, so it has to take part in
            # equality as well: two instances that compare equal must hash
            # equal, and arrays of different length are different dtypes.
            return self.inner == other.inner and self.width == other.width
        return False

    def __hash__(self) -> int:
        return hash((self.__class__, self.inner, self.width))

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}({self.inner!r}, {self.width})"


class Date(TemporalType): ...
37 changes: 37 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

Expand Down Expand Up @@ -67,6 +68,42 @@ def test_list_valid() -> None:
assert dtype in {nw.List(nw.List(nw.Int64))}


def test_array_valid() -> None:
    """Check equality, repr, hashing and construction errors of `nw.Array`."""
    flat = nw.Array(nw.Int64, 2)
    assert flat == nw.Array(nw.Int64, 2)
    # An instance also compares equal to the bare `nw.Array` class.
    assert flat == nw.Array
    assert flat != nw.Array(nw.Float32, 2)
    assert flat != nw.Duration
    assert repr(flat) == "Array(<class 'narwhals.dtypes.Int64'>, 2)"

    nested = nw.Array(nw.Array(nw.Int64, 2), 2)
    assert nested == nw.Array(nw.Array(nw.Int64, 2), 2)
    assert nested == nw.Array
    assert nested != nw.Array(nw.Array(nw.Float32, 2), 2)
    # Set membership exercises `__hash__` together with `__eq__`.
    assert nested in {nw.Array(nw.Array(nw.Int64, 2), 2)}

    # Omitting `width` must be rejected at construction time.
    missing_width = "`width` must be specified when initializing an `Array`"
    with pytest.raises(TypeError, match=missing_width):
        nw.Array(nw.Int64)


@pytest.mark.skipif(
    parse_version(pl.__version__) < (1,) or parse_version(pd.__version__) < (2, 2),
    reason="`shape` is only available after 1.0",
)
def test_polars_2d_array() -> None:
    """A (3, 2) polars Array is reported as a nested `nw.Array`.

    The same schema is expected from the polars frame itself, from its
    pyarrow conversion, and from its pyarrow-backed pandas conversion.
    """
    native = pl.DataFrame(
        {"a": [[[1, 2], [3, 4], [5, 6]]]}, schema={"a": pl.Array(pl.Int64, (3, 2))}
    )
    expected = nw.Array(nw.Array(nw.Int64, 2), 3)

    assert nw.from_native(native).collect_schema()["a"] == expected

    arrow_table = native.to_arrow()
    assert nw.from_native(arrow_table).collect_schema()["a"] == expected

    pandas_frame = native.to_pandas(use_pyarrow_extension_array=True)
    assert nw.from_native(pandas_frame).collect_schema()["a"] == expected


def test_second_time_unit() -> None:
s = pd.Series(np.array([np.datetime64("2020-01-01", "s")]))
result = nw.from_native(s, series_only=True)
Expand Down
7 changes: 4 additions & 3 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,22 +213,23 @@ def test_nested_dtypes() -> None:
schema_overrides={"b": pl.Array(pl.Int64, 2)},
).to_pandas(use_pyarrow_extension_array=True)
nwdf = nw.from_native(df)

assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
df = pl.DataFrame(
{"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
schema_overrides={"b": pl.Array(pl.Int64, 2)},
)
nwdf = nw.from_native(df)
assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
assert nwdf.schema == {"a": nw.List, "b": nw.Array(nw.Int64, 2), "c": nw.Struct}
df = pl.DataFrame(
{"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
schema_overrides={"b": pl.Array(pl.Int64, 2)},
).to_arrow()
nwdf = nw.from_native(df)
assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
assert nwdf.schema == {"a": nw.List, "b": nw.Array(nw.Int64, 2), "c": nw.Struct}
df = duckdb.sql("select * from df")
nwdf = nw.from_native(df)
assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
assert nwdf.schema == {"a": nw.List, "b": nw.Array(nw.Int64, 2), "c": nw.Struct}


def test_nested_dtypes_ibis() -> None: # pragma: no cover
Expand Down

0 comments on commit 3e0405d

Please sign in to comment.