class Field:
    """
    Definition of a single field within a `Struct` DataType.

    Parameters
    ----------
    name
        The name of the field within its parent `Struct`.
    dtype
        The `DataType` of the field's values.
    """

    name: str
    dtype: DType

    def __init__(self, name: str, dtype: DType) -> None:
        self.name = name
        self.dtype = dtype

    def __eq__(self, other: object) -> bool:
        """Return True when `other` is a `Field` with an equal name and dtype."""
        if not isinstance(other, Field):
            # Defer to the reflected comparison instead of failing on a
            # missing `name`/`dtype` attribute.
            return NotImplemented
        return self.name == other.name and self.dtype == other.dtype

    def __hash__(self) -> int:
        return hash((self.name, self.dtype))

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}({self.name!r}, {self.dtype})"


class Struct(DType):
    """
    Struct composite type, made up of an ordered list of named fields.

    Parameters
    ----------
    fields
        The fields that make up the struct. Either a sequence of `Field`
        objects, or a mapping of field name to dtype (converted to `Field`
        objects in the mapping's iteration order).
    """

    fields: list[Field]

    def __init__(self, fields: Sequence[Field] | Mapping[str, DType]) -> None:
        if isinstance(fields, Mapping):
            self.fields = [Field(name, dtype) for name, dtype in fields.items()]
        else:
            # Copy so later changes to the caller's sequence do not leak in.
            self.fields = list(fields)

    def __eq__(self, other: DType | type[DType]) -> bool:  # type: ignore[override]
        """
        Compare against another dtype instance or a dtype class.

        Returns True when `other` is this class (or a subclass of it), or an
        instance of this class whose fields are equal; False otherwise.
        """
        # The comparison allows comparing objects to classes, and specific
        # inner types to those without (eg: inner=None). If one of the
        # arguments is not specific about its inner type we infer it
        # as being equal. (See the List type for more info).
        #
        # `issubclass` raises TypeError when its first argument is not a
        # class, so it must only be reached when `other` really is a class.
        if isinstance(other, type) and issubclass(other, self.__class__):
            return True
        elif isinstance(other, self.__class__):
            return self.fields == other.fields
        else:
            return False

    def __hash__(self) -> int:
        return hash((self.__class__, tuple(self.fields)))

    def __iter__(self) -> Iterator[tuple[str, DType]]:
        """Yield `(name, dtype)` pairs in field order."""
        for fld in self.fields:
            yield fld.name, fld.dtype

    def __reversed__(self) -> Iterator[tuple[str, DType]]:
        """Yield `(name, dtype)` pairs in reverse field order."""
        for fld in reversed(self.fields):
            yield fld.name, fld.dtype

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        # `dict(self)` consumes the `(name, dtype)` pairs from `__iter__`.
        return f"{class_name}({dict(self)})"

    def to_schema(self) -> OrderedDict[str, DType]:
        """Return Struct dtype as a schema dict."""
        return OrderedDict(self)