Skip to content

Commit

Permalink
preserve index in list accessor (pandas-dev#58438)
Browse files Browse the repository at this point in the history
* preserve index in list accessor

* gh reference

* explode fix

* cleanup

* improve test

* Update v3.0.0.rst

Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>

* f

---------

Co-authored-by: Rohan Jain <rohanjain@microsoft.com>
Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
  • Loading branch information
3 people authored and pmhatre1 committed May 7, 2024
1 parent c5b69b9 commit 6199cbd
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Expand Up @@ -471,6 +471,7 @@ Other
- Bug in :meth:`DataFrame.where` where using a non-bool type array in the function would return a ``ValueError`` instead of a ``TypeError`` (:issue:`56330`)
- Bug in :meth:`Index.sort_values` when passing a key function that turns values into tuples, e.g. ``key=natsort.natsort_key``, would raise ``TypeError`` (:issue:`56081`)
- Bug in Dataframe Interchange Protocol implementation was returning incorrect results for data buffers' associated dtype, for string and datetime columns (:issue:`54781`)
- Bug in ``Series.list`` methods not preserving the original :class:`Index`. (:issue:`58425`)

.. ***DO NOT USE THIS SECTION***
Expand Down
22 changes: 14 additions & 8 deletions pandas/core/arrays/arrow/accessors.py
Expand Up @@ -110,7 +110,9 @@ def len(self) -> Series:
from pandas import Series

value_lengths = pc.list_value_length(self._pa_array)
return Series(value_lengths, dtype=ArrowDtype(value_lengths.type))
return Series(
value_lengths, dtype=ArrowDtype(value_lengths.type), index=self._data.index
)

def __getitem__(self, key: int | slice) -> Series:
"""
Expand Down Expand Up @@ -149,7 +151,9 @@ def __getitem__(self, key: int | slice) -> Series:
# if key < 0:
# key = pc.add(key, pc.list_value_length(self._pa_array))
element = pc.list_element(self._pa_array, key)
return Series(element, dtype=ArrowDtype(element.type))
return Series(
element, dtype=ArrowDtype(element.type), index=self._data.index
)
elif isinstance(key, slice):
if pa_version_under11p0:
raise NotImplementedError(
Expand All @@ -167,7 +171,7 @@ def __getitem__(self, key: int | slice) -> Series:
if step is None:
step = 1
sliced = pc.list_slice(self._pa_array, start, stop, step)
return Series(sliced, dtype=ArrowDtype(sliced.type))
return Series(sliced, dtype=ArrowDtype(sliced.type), index=self._data.index)
else:
raise ValueError(f"key must be an int or slice, got {type(key).__name__}")

Expand Down Expand Up @@ -195,15 +199,17 @@ def flatten(self) -> Series:
... )
>>> s.list.flatten()
0 1
1 2
2 3
3 3
0 2
0 3
1 3
dtype: int64[pyarrow]
"""
from pandas import Series

flattened = pc.list_flatten(self._pa_array)
return Series(flattened, dtype=ArrowDtype(flattened.type))
counts = pa.compute.list_value_length(self._pa_array)
flattened = pa.compute.list_flatten(self._pa_array)
index = self._data.index.repeat(counts.fill_null(pa.scalar(0, counts.type)))
return Series(flattened, dtype=ArrowDtype(flattened.type), index=index)


class StructAccessor(ArrowAccessor):
Expand Down
25 changes: 22 additions & 3 deletions pandas/tests/series/accessors/test_list_accessor.py
Expand Up @@ -31,10 +31,23 @@ def test_list_getitem(list_dtype):
tm.assert_series_equal(actual, expected)


def test_list_getitem_index():
# GH 58425
ser = Series(
[[1, 2, 3], [4, None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
)
actual = ser.list[1]
expected = Series([2, None, None], dtype="int64[pyarrow]", index=[1, 3, 7])
tm.assert_series_equal(actual, expected)


def test_list_getitem_slice():
ser = Series(
[[1, 2, 3], [4, None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
)
if pa_version_under11p0:
with pytest.raises(
Expand All @@ -44,7 +57,9 @@ def test_list_getitem_slice():
else:
actual = ser.list[1:None:None]
expected = Series(
[[2, 3], [None, 5], None], dtype=ArrowDtype(pa.list_(pa.int64()))
[[2, 3], [None, 5], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
index=[1, 3, 7],
)
tm.assert_series_equal(actual, expected)

Expand All @@ -61,11 +76,15 @@ def test_list_len():

def test_list_flatten():
ser = Series(
[[1, 2, 3], [4, None], None],
[[1, 2, 3], None, [4, None], [], [7, 8]],
dtype=ArrowDtype(pa.list_(pa.int64())),
)
actual = ser.list.flatten()
expected = Series([1, 2, 3, 4, None], dtype=ArrowDtype(pa.int64()))
expected = Series(
[1, 2, 3, 4, None, 7, 8],
dtype=ArrowDtype(pa.int64()),
index=[0, 0, 0, 2, 2, 4, 4],
)
tm.assert_series_equal(actual, expected)


Expand Down

0 comments on commit 6199cbd

Please sign in to comment.