Skip to content

Commit

Permalink
Avoid needless casting of masks
Browse files Browse the repository at this point in the history
  • Loading branch information
mhvk committed Jun 10, 2022
1 parent 3d25f10 commit 7d72e13
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions astropy/utils/masked/function_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,27 @@ def sort_complex(a):
return b


@apply_to_both
def concatenate(arrays, axis=0, out=None, **kwargs):
data, masks = _get_data_and_masks(*arrays)
return (data,), (masks,), dict(axis=axis, **kwargs), out
if NUMPY_LT_1_20:
@apply_to_both
def concatenate(arrays, axis=0, out=None):
data, masks = _get_data_and_masks(*arrays)
return (data,), (masks,), dict(axis=axis), out

else:
@dispatched_function
def concatenate(arrays, axis=0, out=None, dtype=None, casting='same_kind'):
data, masks = _get_data_and_masks(*arrays)
if out is None:
return (np.concatenate(data, axis=axis, dtype=dtype, casting=casting),
np.concatenate(masks, axis=axis),
None)
else:
from astropy.utils.masked import Masked
if not isinstance(out, Masked):
raise NotImplementedError
np.concatenate(masks, out=out.mask, axis=axis)
np.concatenate(data, out=out.unmasked, axis=axis, dtype=dtype, casting=casting)
return out


@apply_to_both
Expand Down

0 comments on commit 7d72e13

Please sign in to comment.