Skip to content

Commit

Permalink
added arrow
Browse files — browse the repository at this point in the history
  • Loading branch information
DeaMariaLeon committed Oct 13, 2024
1 parent eed4d7b commit 3f86f54
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 20 deletions.
8 changes: 7 additions & 1 deletion narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
if pa.types.is_dictionary(dtype):
return dtypes.Categorical()
if pa.types.is_struct(dtype):
return dtypes.Struct()
return dtypes.Struct(
[
dtypes.Field(dtype.field(i).name, native_to_narwhals_dtype(dtype.field(i).type, dtypes))
for i in range(dtype.num_fields)
]
)

if pa.types.is_list(dtype) or pa.types.is_large_list(dtype):
return dtypes.List(native_to_narwhals_dtype(dtype.value_type, dtypes))
if pa.types.is_fixed_size_list(dtype):
Expand Down
5 changes: 4 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ def native_to_narwhals_dtype(column: Any, dtypes: DTypes) -> DType:
column.dtype.pyarrow_dtype.list_size,
)
if dtype.startswith("struct"):
return dtypes.Struct()

return dtypes.Struct(

)
if dtype == "object":
if ( # pragma: no cover TODO(unassigned): why does this show as uncovered?
idx := getattr(column, "first_valid_index", lambda: None)()
Expand Down
9 changes: 4 additions & 5 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def __repr__(self) -> str:

class Struct(DType):
fields: list[Field]

def __init__(
self, fields: Sequence[Field] | Mapping[str, DType | type[DType]]
) -> None:
Expand All @@ -220,7 +219,7 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
# inner types to those without (eg: inner=None). if one of the
# arguments is not specific about its inner type we infer it
# as being equal. (See the List type for more info).
if type(other) and issubclass(other, self.__class__):
if type(other) is type and issubclass(other, self.__class__):
return True
elif isinstance(other, self.__class__):
return self.fields == other.fields
Expand All @@ -230,19 +229,19 @@ def __eq__(self, other: DType | type[DType]) -> bool: # type: ignore[override]
def __hash__(self) -> int:
return hash((self.__class__, tuple(self.fields)))

def __iter__(self) -> Iterator[tuple[str, DType]]:
def __iter__(self) -> Iterator[tuple[str, DType | type[DType]]]:
for fld in self.fields:
yield fld.name, fld.dtype

def __reversed__(self) -> Iterator[tuple[str, DType]]:
def __reversed__(self) -> Iterator[tuple[str, DType | type[DType]]]:
for fld in reversed(self.fields):
yield fld.name, fld.dtype

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}({dict(self)})"

def to_schema(self) -> OrderedDict[str, DType]:
def to_schema(self) -> OrderedDict[str, DType | type[DType]]:
"""Return Struct dtype as a schema dict."""
return OrderedDict(self)

Expand Down
26 changes: 13 additions & 13 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,21 +210,21 @@ def test_nested_dtypes() -> None:
df = pl.DataFrame(
{"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
schema_overrides={"b": pl.Array(pl.Int64, 2)},
) # .to_pandas(use_pyarrow_extension_array=True)
)#.to_pandas(use_pyarrow_extension_array=True)
nwdf = nw.from_native(df)
assert nwdf.schema == {"a": nw.List, "b": nw.Array, "c": nw.Struct}
# df = pl.DataFrame(
# {"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
# schema_overrides={"b": pl.Array(pl.Int64, 2)},
# )
# nwdf = nw.from_native(df)
# assert nwdf.schema == {"a": nw.List, "b": nw.Array(nw.Int64, 2), "c": nw.Struct}
# df = pl.DataFrame(
# {"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
# schema_overrides={"b": pl.Array(pl.Int64, 2)},
# ).to_arrow()
# nwdf = nw.from_native(df)
# assert nwdf.schema == {"a": nw.List, "b": nw.Array(nw.Int64, 2), "c": nw.Struct}
df = pl.DataFrame(
{"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
schema_overrides={"b": pl.Array(pl.Int64, 2)},
)
nwdf = nw.from_native(df)
assert nwdf.schema == {"a": nw.List, "b": nw.Array(nw.Int64, 2), "c": nw.Struct}
df = pl.DataFrame(
{"a": [[1, 2]], "b": [[1, 2]], "c": [{"a": 1}]},
schema_overrides={"b": pl.Array(pl.Int64, 2)},
).to_arrow()
nwdf = nw.from_native(df)
assert nwdf.schema == {"a": nw.List, "b": nw.Array(nw.Int64, 2), "c": nw.Struct}
# df = duckdb.sql("select * from df")
# nwdf = nw.from_native(df)
# assert nwdf.schema == {"a": nw.List, "b": nw.Array(nw.Int64, 2), "c": nw.Struct}
Expand Down

0 comments on commit 3f86f54

Please sign in to comment.