Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC, feat: infer datetime format for pyarrow backend #1195

Merged
merged 10 commits into from
Oct 29, 2024
4 changes: 2 additions & 2 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from narwhals._arrow.utils import floordiv_compat
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import parse_datetime_format
from narwhals._arrow.utils import validate_column_comparand
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
Expand Down Expand Up @@ -1115,8 +1116,7 @@ def to_datetime(self: Self, format: str | None) -> ArrowSeries: # noqa: A002
import pyarrow.compute as pc # ignore-banned-import()

if format is None:
msg = "`format` is required for pyarrow backend."
raise ValueError(msg)
format = parse_datetime_format(self._arrow_series._native_series)

return self._arrow_series._from_native_series(
pc.strptime(self._arrow_series._native_series, format=format, unit="us")
Expand Down
85 changes: 85 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,88 @@ def convert_str_slice_to_int_slice(
stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None
step = str_slice.step
return (start, stop, step)


# Building blocks used to split a datetime string into its date, separator,
# time and timezone components. Each one exposes a single named group.
DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4})"
SEP_RE = r"(?P<sep>\s|T)"
# 24-hour `HH:MM:SS` only: AM/PM markers and fractional seconds are not matched.
TIME_RE = r"(?P<time>\d{2}:\d{2}:\d{2})"
# Matches 'Z', '+02:00' and '+0200'. An hour-only offset such as '+02' is NOT matched.
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})"
# Anchored at both ends (like the per-layout date regexes below) so that a value
# with leading garbage, e.g. "x2020-01-01", is rejected instead of having its
# tail silently accepted.
FULL_RE = rf"^{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"

# Separate regexes for different date layouts. The year group accepts either a
# 2-digit year or a 4-digit year starting with 1 or 2.
YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"

# (pattern, strptime template) pairs, tried in this order; the "-" in each
# template is a placeholder for the separator actually found in the data.
DATE_FORMATS = (
    (YMD_RE, "%Y-%m-%d"),
    (DMY_RE, "%d-%m-%Y"),
    (MDY_RE, "%m-%d-%Y"),
)


def parse_datetime_format(arr: pa.StringArray) -> str:
    """Infer a strptime-style format string from the values of ``arr``.

    Every value is split with ``FULL_RE`` into date, separator, time and
    timezone components; the date and time parts are then resolved by
    ``_parse_date_format`` and ``_parse_time_format``.

    Raises:
        NotImplementedError: If any element does not produce a valid match
            against ``FULL_RE``.
        ValueError: If more than one distinct separator or more than one
            distinct timezone value is found.
    """
    import pyarrow as pa  # ignore-banned-import
    import pyarrow.compute as pc  # ignore-banned-import

    # converts from ChunkedArray to StructArray
    extracted = pc.extract_regex(arr, pattern=FULL_RE)
    components = pa.concat_arrays(extracted.chunks)

    every_value_matched = pc.all(components.is_valid()).as_py()
    if not every_value_matched:
        msg = (
            "Unable to infer datetime format, provided format is not supported. "
            "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
        )
        raise NotImplementedError(msg)

    date_part = components.field("date")
    sep_part = components.field("sep")
    time_part = components.field("time")
    tz_part = components.field("tz")

    # A single format string can only describe one separator and one timezone.
    n_separators = pc.count(pc.unique(sep_part)).as_py()
    if n_separators > 1:
        msg = "Found multiple separator values while inferring datetime format."
        raise ValueError(msg)

    n_timezones = pc.count(pc.unique(tz_part)).as_py()
    if n_timezones > 1:
        msg = "Found multiple timezone values while inferring datetime format."
        raise ValueError(msg)

    date_directive = _parse_date_format(date_part)
    time_directive = _parse_time_format(time_part)

    # The first element is representative: uniqueness was checked above.
    sep_literal = sep_part[0].as_py()
    if tz_part[0].as_py():
        tz_directive = "%z"
    else:
        tz_directive = ""

    return f"{date_directive}{sep_literal}{time_directive}{tz_directive}"


def _parse_date_format(arr: pa.Array) -> str:
    """Infer the strptime directive string for the date component.

    Each ``(pattern, template)`` pair in ``DATE_FORMATS`` is tried in order. A
    layout is accepted when every element of ``arr`` matches its pattern, both
    separators are the same single character across all elements, and all
    years have the same number of digits.

    Args:
        arr: Array of date strings (the ``date`` group extracted by ``FULL_RE``).

    Returns:
        The matching template with ``-`` replaced by the separator found in the
        data, using ``%y`` instead of ``%Y`` when the years have two digits.

    Raises:
        ValueError: If no layout in ``DATE_FORMATS`` fits all elements.
    """
    import pyarrow.compute as pc  # ignore-banned-import

    for date_rgx, date_fmt in DATE_FORMATS:
        matches = pc.extract_regex(arr, pattern=date_rgx)
        if not pc.all(matches.is_valid()).as_py():
            continue

        sep1 = matches.field("sep1")
        sep2 = matches.field("sep2")
        if (
            pc.count(pc.unique(sep1)).as_py() != 1
            or pc.count(pc.unique(sep2)).as_py() != 1
        ):
            continue

        date_sep_value = sep1[0].as_py()
        if date_sep_value != sep2[0].as_py():
            continue

        # The layout regexes accept both 2- and 4-digit years, but `%Y` only
        # describes a 4-digit year: pick the directive from the observed width
        # and reject columns that mix the two.
        year_widths = pc.unique(pc.utf8_length(matches.field("year")))
        if pc.count(year_widths).as_py() != 1:
            continue
        year_directive = "%Y" if year_widths[0].as_py() == 4 else "%y"

        return date_fmt.replace("-", date_sep_value).replace("%Y", year_directive)

    msg = (
        "Unable to infer datetime format. "
        "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
    )
    raise ValueError(msg)


def _parse_time_format(arr: pa.Array) -> str:
    """Return ``"%H:%M:%S"`` if every element of ``arr`` matches ``TIME_RE``, else ``""``."""
    import pyarrow.compute as pc  # ignore-banned-import

    extracted = pc.extract_regex(arr, pattern=TIME_RE)
    if pc.all(extracted.is_valid()).as_py():
        return "%H:%M:%S"
    return ""
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
67 changes: 55 additions & 12 deletions tests/expr_and_series/str/to_datetime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from typing import TYPE_CHECKING

import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from narwhals._arrow.utils import parse_datetime_format

if TYPE_CHECKING:
from tests.utils import Constructor
from tests.utils import ConstructorEager

data = {"a": ["2020-01-01T12:34:56"]}


Expand Down Expand Up @@ -42,12 +45,7 @@ def test_to_datetime_series(constructor_eager: ConstructorEager) -> None:
assert str(result) == expected


def test_to_datetime_infer_fmt(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_to_datetime_infer_fmt(constructor: Constructor) -> None:
if "cudf" in str(constructor): # pragma: no cover
expected = "2020-01-01T12:34:56.000000000"
else:
Expand All @@ -63,12 +61,7 @@ def test_to_datetime_infer_fmt(
assert str(result) == expected


def test_to_datetime_series_infer_fmt(
request: pytest.FixtureRequest, constructor_eager: ConstructorEager
) -> None:
if "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)

def test_to_datetime_series_infer_fmt(constructor_eager: ConstructorEager) -> None:
if "cudf" in str(constructor_eager): # pragma: no cover
expected = "2020-01-01T12:34:56.000000000"
else:
Expand All @@ -78,3 +71,53 @@ def test_to_datetime_series_infer_fmt(
nw.from_native(constructor_eager(data), eager_only=True)["a"].str.to_datetime()
).item(0)
assert str(result) == expected


def test_to_datetime_infer_fmt_from_date(constructor: Constructor) -> None:
    """Date-only strings are parsed by `str.to_datetime()` without an explicit format."""
    data = {"z": ["2020-01-01", "2020-01-02"]}

    expected = "2020-01-01 00:00:00"
    if "cudf" in str(constructor):  # pragma: no cover
        expected = "2020-01-01T00:00:00.000000000"

    lazy_frame = nw.from_native(constructor(data)).lazy()
    parsed = lazy_frame.select(y=nw.col("z").str.to_datetime()).collect()
    result = parsed.item(row=0, column="y")

    assert str(result) == expected


@pytest.mark.parametrize("data", [["2024-01-01", "abc"], ["2024-01-01", None]])
def test_pyarrow_infer_datetime_raise_invalid(data: list[str | None]) -> None:
    """A non-datetime string or a null value makes format inference raise."""
    values = pa.chunked_array([data])
    expected_msg = "Unable to infer datetime format, provided format is not supported."

    with pytest.raises(NotImplementedError, match=expected_msg):
        parse_datetime_format(values)


@pytest.mark.parametrize(
    ("data", "duplicate"),
    [
        (["2024-01-01T00:00:00", "2024-01-01 01:00:00"], "separator"),
        (["2024-01-01 00:00:00+01:00", "2024-01-01 01:00:00+02:00"], "timezone"),
    ],
)
def test_pyarrow_infer_datetime_raise_not_unique(
    data: list[str | None], duplicate: str
) -> None:
    """Mixed separators or mixed timezone values are reported with a ValueError."""
    values = pa.chunked_array([data])
    expected_msg = (
        f"Found multiple {duplicate} values while inferring datetime format."
    )

    with pytest.raises(ValueError, match=expected_msg):
        parse_datetime_format(values)


@pytest.mark.parametrize("data", [["2024-01-01", "2024-12-01", "02-02-2024"]])
def test_pyarrow_infer_datetime_raise_inconsistent_date_fmt(
    data: list[str | None],
) -> None:
    """Dates that do not all share one layout cannot be inferred."""
    values = pa.chunked_array([data])
    expected_msg = "Unable to infer datetime format. "

    with pytest.raises(ValueError, match=expected_msg):
        parse_datetime_format(values)
Loading