diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 27c1425d11ac6..327cb042e4342 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -397,19 +397,19 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): if dtype is None: - dtype = np.int64 + dtype = np.bool_ dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies( sep, dtype ) - if len(labels) == 0: - return np.empty(shape=(0, 0), dtype=dtype), labels - dummies = np.vstack(dummies_pa.to_numpy()) _dtype = pandas_dtype(dtype) dummies_dtype: NpDtype if isinstance(_dtype, np.dtype): dummies_dtype = _dtype else: dummies_dtype = np.bool_ + if len(labels) == 0: + return np.empty(shape=(0, 0), dtype=dummies_dtype), labels + dummies = np.vstack(dummies_pa.to_numpy()) return dummies.astype(dummies_dtype, copy=False), labels def _convert_int_result(self, result): diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index e5b434edacc59..97055c583b8f2 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -2489,7 +2489,7 @@ def get_dummies( ---------- sep : str, default "|" String to split on. - dtype : dtype, default np.int64 + dtype : dtype, default bool Data type for new columns. Only a single dtype is allowed. Returns @@ -2505,27 +2505,48 @@ def get_dummies( Examples -------- >>> pd.Series(["a|b", "a", "a|c"]).str.get_dummies() - a b c - 0 1 1 0 - 1 1 0 0 - 2 1 0 1 + a b c + 0 True True False + 1 True False False + 2 True False True >>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies() + a b c + 0 True True False + 1 False False False + 2 True False True + + >>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=np.int64) a b c 0 1 1 0 1 0 0 0 2 1 0 1 - - >>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=bool) - a b c - 0 True True False - 1 False False False - 2 True False True """ from pandas.core.frame import DataFrame # we need to cast to Series of strings as only that has all # methods available for making the dummies... + input_dtype = self._data.dtype + if dtype is None and not isinstance(input_dtype, ArrowDtype): + from pandas.core.arrays.string_ import StringDtype + + if isinstance(input_dtype, CategoricalDtype): + input_dtype = input_dtype.categories.dtype + + if isinstance(input_dtype, ArrowDtype): + import pyarrow as pa + + dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment] + elif ( + isinstance(input_dtype, StringDtype) + and input_dtype.na_value is not np.nan + ): + from pandas.core.dtypes.common import pandas_dtype + + dtype = pandas_dtype("boolean") # type: ignore[assignment] + else: + dtype = np.bool_ + result, name = self._data.array._str_get_dummies(sep, dtype) if is_extension_array_dtype(dtype) or isinstance(dtype, ArrowDtype): return self._wrap_result( diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index 0268194e64d50..595bfac5f229c 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -416,7 +416,7 @@ def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): from pandas import Series if dtype is None: - dtype = np.int64 + dtype = np.bool_ arr = Series(self).fillna("") try: arr = sep + arr + sep diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 6dd1f3f15bc15..ab68fa74ae218 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2206,6 +2206,16 @@ def test_get_dummies(): ) tm.assert_frame_equal(result, expected) + ser = pd.Series( + ["a", "b"], + dtype=pd.CategoricalDtype(pd.Index(["a", "b"], dtype=ArrowDtype(pa.string()))), + ) + result = ser.str.get_dummies() + expected = pd.DataFrame( + [[True, False], [False, True]], dtype=ArrowDtype(pa.bool_()), columns=["a", "b"] + ) + tm.assert_frame_equal(result, expected) + def test_str_partition(): ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) diff --git a/pandas/tests/strings/test_get_dummies.py b/pandas/tests/strings/test_get_dummies.py index 3b989e284ca25..04805d14616eb 100644 --- a/pandas/tests/strings/test_get_dummies.py +++ b/pandas/tests/strings/test_get_dummies.py @@ -6,6 +6,8 @@ import pandas.util._test_decorators as td from pandas import ( + NA, + CategoricalDtype, DataFrame, Index, MultiIndex, @@ -22,19 +24,28 @@ def test_get_dummies(any_string_dtype): s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) result = s.str.get_dummies("|") - expected = DataFrame([[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc")) + exp_dtype = ( + "boolean" + if any_string_dtype == "string" and any_string_dtype.na_value is NA + else "bool" + ) + expected = DataFrame( + [[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=exp_dtype + ) tm.assert_frame_equal(result, expected) s = Series(["a;b", "a", 7], dtype=any_string_dtype) result = s.str.get_dummies(";") - expected = DataFrame([[0, 1, 1], [0, 1, 0], [1, 0, 0]], columns=list("7ab")) + expected = DataFrame( + [[0, 1, 1], [0, 1, 0], [1, 0, 0]], columns=list("7ab"), dtype=exp_dtype + ) tm.assert_frame_equal(result, expected) def test_get_dummies_index(): # GH9980, GH8028 idx = Index(["a|b", "a|c", "b|c"]) - result = idx.str.get_dummies("|") + result = idx.str.get_dummies("|", dtype=np.int64) expected = MultiIndex.from_tuples( [(1, 1, 0), (1, 0, 1), (0, 1, 1)], names=("a", "b", "c") @@ -125,3 +136,15 @@ def test_get_dummies_with_pa_str_dtype(any_string_dtype): dtype="str[pyarrow]", ) tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype_type", ["string", "category"]) +def test_get_dummies_ea_dtype(dtype_type, string_dtype_no_object): + dtype = string_dtype_no_object + exp_dtype = "boolean" if dtype.na_value is NA else "bool" + if dtype_type == "category": + dtype = CategoricalDtype(Index(["a", "b"], dtype)) + s = Series(["a", "b"], dtype=dtype) + result = s.str.get_dummies() + expected = DataFrame([[1, 0], [0, 1]], columns=list("ab"), dtype=exp_dtype) + tm.assert_frame_equal(result, expected)