Skip to content

Commit

Permalink
ENH: Make get_dummies return ea booleans for ea inputs (pandas-dev#56291
Browse files Browse the repository at this point in the history
)

* ENH: Make get_dummies return ea booleans for ea inputs

* ENH: Make get_dummies return ea booleans for ea inputs

* Update

* Update pandas/tests/reshape/test_get_dummies.py

Co-authored-by: Thomas Baumann <thbaumann90@gmail.com>

* Update test_get_dummies.py

* Update test_get_dummies.py

* Fixup

---------

Co-authored-by: Thomas Baumann <thbaumann90@gmail.com>
  • Loading branch information
phofl and lopof committed Dec 10, 2023
1 parent 8aa7a96 commit 9b51ab2
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Expand Up @@ -218,6 +218,7 @@ Other enhancements

- :meth:`~DataFrame.to_sql` with method parameter set to ``multi`` works with Oracle on the backend
- :attr:`Series.attrs` / :attr:`DataFrame.attrs` now uses a deepcopy for propagating ``attrs`` (:issue:`54134`).
- :func:`get_dummies` now returning extension dtypes ``boolean`` or ``bool[pyarrow]`` that are compatible with the input dtype (:issue:`56273`)
- :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`)
- :func:`read_sas` returns ``datetime64`` dtypes with resolutions better matching those stored natively in SAS, and avoids returning object-dtype in cases that cannot be stored with ``datetime64[ns]`` dtype (:issue:`56127`)
- :func:`read_spss` now returns a :class:`DataFrame` that stores the metadata in :attr:`DataFrame.attrs`. (:issue:`54264`)
Expand Down
24 changes: 23 additions & 1 deletion pandas/core/reshape/encoding.py
Expand Up @@ -21,9 +21,14 @@
is_object_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
CategoricalDtype,
)

from pandas.core.arrays import SparseArray
from pandas.core.arrays.categorical import factorize_from_iterable
from pandas.core.arrays.string_ import StringDtype
from pandas.core.frame import DataFrame
from pandas.core.indexes.api import (
Index,
Expand Down Expand Up @@ -244,8 +249,25 @@ def _get_dummies_1d(
# Series avoids inconsistent NaN handling
codes, levels = factorize_from_iterable(Series(data, copy=False))

if dtype is None:
if dtype is None and hasattr(data, "dtype"):
input_dtype = data.dtype
if isinstance(input_dtype, CategoricalDtype):
input_dtype = input_dtype.categories.dtype

if isinstance(input_dtype, ArrowDtype):
import pyarrow as pa

dtype = ArrowDtype(pa.bool_()) # type: ignore[assignment]
elif (
isinstance(input_dtype, StringDtype)
and input_dtype.storage != "pyarrow_numpy"
):
dtype = pandas_dtype("boolean") # type: ignore[assignment]
else:
dtype = np.dtype(bool)
elif dtype is None:
dtype = np.dtype(bool)

_dtype = pandas_dtype(dtype)

if is_object_dtype(_dtype):
Expand Down
45 changes: 45 additions & 0 deletions pandas/tests/reshape/test_get_dummies.py
Expand Up @@ -4,13 +4,18 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas.core.dtypes.common import is_integer_dtype

import pandas as pd
from pandas import (
ArrowDtype,
Categorical,
CategoricalDtype,
CategoricalIndex,
DataFrame,
Index,
RangeIndex,
Series,
SparseDtype,
Expand All @@ -19,6 +24,11 @@
import pandas._testing as tm
from pandas.core.arrays.sparse import SparseArray

try:
import pyarrow as pa
except ImportError:
pa = None


class TestGetDummies:
@pytest.fixture
Expand Down Expand Up @@ -217,6 +227,7 @@ def test_dataframe_dummies_string_dtype(self, df):
},
dtype=bool,
)
expected[["B_b", "B_c"]] = expected[["B_b", "B_c"]].astype("boolean")
tm.assert_frame_equal(result, expected)

def test_dataframe_dummies_mix_default(self, df, sparse, dtype):
Expand Down Expand Up @@ -693,3 +704,37 @@ def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype):
dtype=any_numeric_ea_and_arrow_dtype,
)
tm.assert_frame_equal(result, expected)

@td.skip_if_no("pyarrow")
def test_get_dummies_ea_dtype(self):
# GH#56273
for dtype, exp_dtype in [
("string[pyarrow]", "boolean"),
("string[pyarrow_numpy]", "bool"),
(CategoricalDtype(Index(["a"], dtype="string[pyarrow]")), "boolean"),
(CategoricalDtype(Index(["a"], dtype="string[pyarrow_numpy]")), "bool"),
]:
df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1})
result = get_dummies(df)
expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)})
tm.assert_frame_equal(result, expected)

@td.skip_if_no("pyarrow")
def test_get_dummies_arrow_dtype(self):
# GH#56273
df = DataFrame({"name": Series(["a"], dtype=ArrowDtype(pa.string())), "x": 1})
result = get_dummies(df)
expected = DataFrame({"x": 1, "name_a": Series([True], dtype="bool[pyarrow]")})
tm.assert_frame_equal(result, expected)

df = DataFrame(
{
"name": Series(
["a"],
dtype=CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))),
),
"x": 1,
}
)
result = get_dummies(df)
tm.assert_frame_equal(result, expected)

0 comments on commit 9b51ab2

Please sign in to comment.