Skip to content

Commit

Permalink
Ensure that np.concatenate with dtype argument works on quantities an…
Browse files Browse the repository at this point in the history
…d masked data.
  • Loading branch information
mhvk committed Jun 10, 2022
1 parent 29ee076 commit 99e8f42
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 4 deletions.
4 changes: 2 additions & 2 deletions astropy/units/quantity_helper/function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,10 @@ def _iterable_helper(*args, out=None, **kwargs):


@function_helper
def concatenate(arrays, axis=0, out=None):
def concatenate(arrays, axis=0, out=None, **kwargs):
# TODO: make this smarter by creating an appropriately shaped
# empty output array and just filling it.
arrays, kwargs, unit, out = _iterable_helper(*arrays, out=out, axis=axis)
arrays, kwargs, unit, out = _iterable_helper(*arrays, out=out, axis=axis, **kwargs)
return (arrays,), kwargs, unit, out


Expand Down
4 changes: 4 additions & 0 deletions astropy/units/tests/test_quantity_non_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,10 @@ def check(self, func, *args, **kwargs):
def test_concatenate(self):
self.check(np.concatenate)
self.check(np.concatenate, axis=1)
if not NUMPY_LT_1_20:
# dtype argument only introduced in numpy 1.20
# regression test for gh-13322.
self.check(np.concatenate, dtype='f4')

self.check(np.concatenate, q_list=[np.zeros(self.q1.shape), self.q1, self.q2],
q_ref=self.q1)
Expand Down
4 changes: 2 additions & 2 deletions astropy/utils/masked/function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,9 @@ def sort_complex(a):


@apply_to_both
def concatenate(arrays, axis=0, out=None):
def concatenate(arrays, axis=0, out=None, **kwargs):
data, masks = _get_data_and_masks(*arrays)
return (data,), (masks,), dict(axis=axis), out
return (data,), (masks,), dict(axis=axis, **kwargs), out


@apply_to_both
Expand Down
3 changes: 3 additions & 0 deletions astropy/utils/masked/tests/test_function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ def test_concatenate(self):
self.check(np.concatenate)
self.check(np.concatenate, axis=1)
self.check(np.concatenate, ma_list=[self.a, self.ma])
if not NUMPY_LT_1_20:
# Check that we can accept a dtype argument (introduced in numpy 1.20)
self.check(np.concatenate, dtype='f4')

out = Masked(np.empty((4, 3)))
result = np.concatenate([self.ma, self.ma], out=out)
Expand Down
1 change: 1 addition & 0 deletions docs/changes/units/13323.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure that ``np.concatenate`` on quantities can take a ``dtype`` argument (added in numpy 1.20).
1 change: 1 addition & 0 deletions docs/changes/utils/13323.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Ensure that ``np.concatenate`` on masked data can take a ``dtype`` argument (added in numpy 1.20).

0 comments on commit 99e8f42

Please sign in to comment.