Skip to content

Commit

Permalink
GH-43683: [Python] Use pandas StringDtype when enabled (pandas 3+) (#…
Browse files Browse the repository at this point in the history
…44195)

### Rationale for this change

With pandas' [PDEP-14](https://pandas.pydata.org/pdeps/0014-string-dtype.html) proposal, pandas is planning to introduce a default string dtype in pandas 3.0 (instead of the current object dtype).

This will become the default in pandas 3.0, and can be enabled with an option in the upcoming pandas 2.3 (`pd.options.future.infer_string = True`). To prepare for that, we should start using that string dtype in `to_pandas()` conversions when that option is enabled.

### What changes are included in this PR?

- If pandas >= 3.0 is used or the pandas option is enabled, ensure that `to_pandas()` calls use the default string dtype of pandas for string-like columns (string, large_string, string_view)

### Are these changes tested?

It is tested in the pandas-nightly crossbow build.

There is still one failure that is because of a bug on the pandas side (pandas-dev/pandas#59879)

### Are there any user-facing changes?

**This PR includes breaking changes to public APIs.** Depending on the version of pandas, `to_pandas()` will change to use pandas' string dtype instead of object dtype. This is a breaking user-facing change, but essentially just following the equivalent change in default dtype on the pandas side.

* GitHub Issue: #43683

Lead-authored-by: Joris Van den Bossche <[email protected]>
Co-authored-by: Raúl Cumplido <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
  • Loading branch information
2 people authored and amoeba committed Jan 11, 2025
1 parent fde843a commit fa0b2d7
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 35 deletions.
6 changes: 6 additions & 0 deletions dev/tasks/tasks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,12 @@ tasks:
# ensure we have at least one build with parquet encryption disabled
PARQUET_REQUIRE_ENCRYPTION: "OFF"
{% endif %}
{% if pandas_version == "nightly" %}
# TODO can be removed once this is enabled by default in pandas >= 3
# This is to enable the Pandas feature.
# See: https://github.com/pandas-dev/pandas/pull/58459
PANDAS_FUTURE_INFER_STRING: "1"
{% endif %}
{% if not cache_leaf %}
# use the latest pandas release, so prevent reusing any cached layers
flags: --no-leaf-cache
Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,7 @@ services:
PYTEST_ARGS: # inherit
HYPOTHESIS_PROFILE: # inherit
PYARROW_TEST_HYPOTHESIS: # inherit
PANDAS_FUTURE_INFER_STRING: # inherit
volumes: *conda-volumes
command: *python-conda-command

Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def _handle_arrow_array_protocol(obj, type, mask, size):
"return a pyarrow Array or ChunkedArray.")
if isinstance(res, ChunkedArray) and res.num_chunks==1:
res = res.chunk(0)
if type is not None and res.type != type:
res = res.cast(type)
return res


Expand Down
17 changes: 16 additions & 1 deletion python/pyarrow/pandas-shim.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ cdef class _PandasAPIShim(object):
object _array_like_types, _is_extension_array_dtype, _lock
bint has_sparse
bint _pd024
bint _is_v1, _is_ge_v21, _is_ge_v3
bint _is_v1, _is_ge_v21, _is_ge_v3, _is_ge_v3_strict

def __init__(self):
self._lock = Lock()
Expand Down Expand Up @@ -80,6 +80,7 @@ cdef class _PandasAPIShim(object):
self._is_v1 = self._loose_version < Version('2.0.0')
self._is_ge_v21 = self._loose_version >= Version('2.1.0')
self._is_ge_v3 = self._loose_version >= Version('3.0.0.dev0')
self._is_ge_v3_strict = self._loose_version >= Version('3.0.0')

self._compat_module = pdcompat
self._data_frame = pd.DataFrame
Expand Down Expand Up @@ -174,6 +175,20 @@ cdef class _PandasAPIShim(object):
self._check_import()
return self._is_ge_v3

def is_ge_v3_strict(self):
self._check_import()
return self._is_ge_v3_strict

def uses_string_dtype(self):
if self.is_ge_v3_strict():
return True
try:
if self.pd.options.future.infer_string:
return True
except:
pass
return False

@property
def categorical_type(self):
self._check_import()
Expand Down
62 changes: 54 additions & 8 deletions python/pyarrow/pandas_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def get_column_metadata(column, name, arrow_type, field_name):
}
string_dtype = 'object'

if name is not None and not isinstance(name, str):
if (
name is not None
and not (isinstance(name, float) and np.isnan(name))
and not isinstance(name, str)
):
raise TypeError(
'Column name must be a string. Got column {} of type {}'.format(
name, type(name).__name__
Expand Down Expand Up @@ -340,8 +344,8 @@ def _column_name_to_strings(name):
return str(tuple(map(_column_name_to_strings, name)))
elif isinstance(name, Sequence):
raise TypeError("Unsupported type for MultiIndex level")
elif name is None:
return None
elif name is None or (isinstance(name, float) and np.isnan(name)):
return name
return str(name)


Expand Down Expand Up @@ -790,10 +794,12 @@ def table_to_dataframe(
table, index = _reconstruct_index(table, index_descriptors,
all_columns, types_mapper)
ext_columns_dtypes = _get_extension_dtypes(
table, all_columns, types_mapper)
table, all_columns, types_mapper, options, categories)
else:
index = _pandas_api.pd.RangeIndex(table.num_rows)
ext_columns_dtypes = _get_extension_dtypes(table, [], types_mapper)
ext_columns_dtypes = _get_extension_dtypes(
table, [], types_mapper, options, categories
)

_check_data_column_metadata_consistency(all_columns)
columns = _deserialize_column_index(table, all_columns, column_indexes)
Expand Down Expand Up @@ -838,7 +844,7 @@ def table_to_dataframe(
}


def _get_extension_dtypes(table, columns_metadata, types_mapper=None):
def _get_extension_dtypes(table, columns_metadata, types_mapper, options, categories):
"""
Based on the stored column pandas metadata and the extension types
in the arrow schema, infer which columns should be converted to a
Expand All @@ -851,6 +857,9 @@ def _get_extension_dtypes(table, columns_metadata, types_mapper=None):
and then we can check if this dtype supports conversion from arrow.
"""
strings_to_categorical = options["strings_to_categorical"]
categories = categories or []

ext_columns = {}

# older pandas version that does not yet support extension dtypes
Expand Down Expand Up @@ -889,9 +898,32 @@ def _get_extension_dtypes(table, columns_metadata, types_mapper=None):
# that are certainly numpy dtypes
pandas_dtype = _pandas_api.pandas_dtype(dtype)
if isinstance(pandas_dtype, _pandas_api.extension_dtype):
if isinstance(pandas_dtype, _pandas_api.pd.StringDtype):
# when the metadata indicate to use the string dtype,
# ignore this in case:
# - it is specified to convert strings / this column to categorical
# - the column itself is dictionary encoded and would otherwise be
# converted to categorical
if strings_to_categorical or name in categories:
continue
try:
if pa.types.is_dictionary(table.schema.field(name).type):
continue
except KeyError:
pass
if hasattr(pandas_dtype, "__from_arrow__"):
ext_columns[name] = pandas_dtype

# for pandas 3.0+, use pandas' new default string dtype
if _pandas_api.uses_string_dtype() and not strings_to_categorical:
for field in table.schema:
if field.name not in ext_columns and (
pa.types.is_string(field.type)
or pa.types.is_large_string(field.type)
or pa.types.is_string_view(field.type)
) and field.name not in categories:
ext_columns[field.name] = _pandas_api.pd.StringDtype(na_value=np.nan)

return ext_columns


Expand Down Expand Up @@ -1049,9 +1081,9 @@ def get_pandas_logical_type_map():
'date': 'datetime64[D]',
'datetime': 'datetime64[ns]',
'datetimetz': 'datetime64[ns]',
'unicode': np.str_,
'unicode': 'str',
'bytes': np.bytes_,
'string': np.str_,
'string': 'str',
'integer': np.int64,
'floating': np.float64,
'decimal': np.object_,
Expand Down Expand Up @@ -1142,6 +1174,20 @@ def _reconstruct_columns_from_metadata(columns, column_indexes):
# GH-41503: if the column index was decimal, restore to decimal
elif pandas_dtype == "decimal":
level = _pandas_api.pd.Index([decimal.Decimal(i) for i in level])
elif (
level.dtype == "str" and numpy_dtype == "object"
and ("mixed" in pandas_dtype or pandas_dtype in ["unicode", "string"])
):
# the metadata indicate that the original dataframe used object dtype,
# but ignore this and keep string dtype if:
# - the original columns used mixed types -> we don't attempt to faithfully
# roundtrip in this case, but keep the column names as strings
# - the original columns were inferred to be strings but stored in object
# dtype -> we don't restore the object dtype because all metadata
# generated using pandas < 3 will have this case by default, and
# for pandas >= 3 we want to use the default string dtype for .columns
new_levels.append(level)
continue
elif level.dtype != dtype:
level = level.astype(dtype)
# ARROW-9096: if original DataFrame was upcast we keep that
Expand Down
19 changes: 10 additions & 9 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,7 @@ def test_replace_slice():
offsets = range(-3, 4)

arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd', 'abcde'])
series = arr.to_pandas()
series = arr.to_pandas().astype(object).replace({np.nan: None})
for start in offsets:
for stop in offsets:
expected = series.str.slice_replace(start, stop, 'XX')
Expand All @@ -1031,7 +1031,7 @@ def test_replace_slice():
assert pc.binary_replace_slice(arr, start, stop, 'XX') == actual

arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd', 'πbθde'])
series = arr.to_pandas()
series = arr.to_pandas().astype(object).replace({np.nan: None})
for start in offsets:
for stop in offsets:
expected = series.str.slice_replace(start, stop, 'XX')
Expand Down Expand Up @@ -2132,50 +2132,51 @@ def test_strftime():
for fmt in formats:
options = pc.StrftimeOptions(fmt)
result = pc.strftime(tsa, options=options)
expected = pa.array(ts.strftime(fmt))
# cast to the same type as result to ignore string vs large_string
expected = pa.array(ts.strftime(fmt)).cast(result.type)
assert result.equals(expected)

fmt = "%Y-%m-%dT%H:%M:%S"

# Default format
tsa = pa.array(ts, type=pa.timestamp("s", timezone))
result = pc.strftime(tsa, options=pc.StrftimeOptions())
expected = pa.array(ts.strftime(fmt))
expected = pa.array(ts.strftime(fmt)).cast(result.type)
assert result.equals(expected)

# Default format plus timezone
tsa = pa.array(ts, type=pa.timestamp("s", timezone))
result = pc.strftime(tsa, options=pc.StrftimeOptions(fmt + "%Z"))
expected = pa.array(ts.strftime(fmt + "%Z"))
expected = pa.array(ts.strftime(fmt + "%Z")).cast(result.type)
assert result.equals(expected)

# Pandas %S is equivalent to %S in arrow for unit="s"
tsa = pa.array(ts, type=pa.timestamp("s", timezone))
options = pc.StrftimeOptions("%S")
result = pc.strftime(tsa, options=options)
expected = pa.array(ts.strftime("%S"))
expected = pa.array(ts.strftime("%S")).cast(result.type)
assert result.equals(expected)

# Pandas %S.%f is equivalent to %S in arrow for unit="us"
tsa = pa.array(ts, type=pa.timestamp("us", timezone))
options = pc.StrftimeOptions("%S")
result = pc.strftime(tsa, options=options)
expected = pa.array(ts.strftime("%S.%f"))
expected = pa.array(ts.strftime("%S.%f")).cast(result.type)
assert result.equals(expected)

# Test setting locale
tsa = pa.array(ts, type=pa.timestamp("s", timezone))
options = pc.StrftimeOptions(fmt, locale="C")
result = pc.strftime(tsa, options=options)
expected = pa.array(ts.strftime(fmt))
expected = pa.array(ts.strftime(fmt)).cast(result.type)
assert result.equals(expected)

# Test timestamps without timezone
fmt = "%Y-%m-%dT%H:%M:%S"
ts = pd.to_datetime(times)
tsa = pa.array(ts, type=pa.timestamp("s"))
result = pc.strftime(tsa, options=pc.StrftimeOptions(fmt))
expected = pa.array(ts.strftime(fmt))
expected = pa.array(ts.strftime(fmt)).cast(result.type)

# Positional format
assert pc.strftime(tsa, fmt) == result
Expand Down
6 changes: 5 additions & 1 deletion python/pyarrow/tests/test_feather.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,11 @@ def test_empty_strings(version):
@pytest.mark.pandas
def test_all_none(version):
df = pd.DataFrame({'all_none': [None] * 10})
_check_pandas_roundtrip(df, version=version)
if version == 1 and pa.pandas_compat._pandas_api.uses_string_dtype():
expected = df.astype("str")
else:
expected = df
_check_pandas_roundtrip(df, version=version, expected=expected)


@pytest.mark.pandas
Expand Down
Loading

0 comments on commit fa0b2d7

Please sign in to comment.