Skip to content

Commit

Permalink
MAINT: Fix computation of numpy.array_api.linalg.vector_norm (numpy#2…
Browse files Browse the repository at this point in the history
…1084)

* Fix computation of numpy.array_api.linalg.vector_norm

Various pieces were incorrect due to a lack of complete coverage of this
function in the array API test suite.

* Fix the output dtype nonstandard vector norm()

Previously it would always give float64 because an internal calculation
involved a NumPy scalar and a Python float. The fix is to use a 0-D array
instead of a NumPy scalar so that it type promotes with the float correctly.

Fixes numpy#21083

I don't have a test for this yet because I'm unclear how exactly to test it.

* Clean up the numpy.array_api.linalg.vector_norm code a little bit
  • Loading branch information
asmeurer committed Jun 16, 2022
1 parent 2d44524 commit 70026c4
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions numpy/array_api/linalg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

from ._dtypes import _floating_dtypes, _numeric_dtypes
from ._manipulation_functions import reshape
from ._array_object import Array

from ..core.numeric import normalize_axis_tuple

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ._typing import Literal, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -395,18 +398,38 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No
if x.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in norm')

# np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
# when axis=None and the input is 2-D, so to force a vector norm, we make
# it so the input is 1-D (for axis=None), or reshape so that norm is done
# on a single dimension.
a = x._array
if axis is None:
a = a.flatten()
axis = 0
# Note: np.linalg.norm() doesn't handle 0-D arrays
a = a.ravel()
_axis = 0
elif isinstance(axis, tuple):
# Note: The axis argument supports any number of axes, whereas norm()
# only supports a single axis for vector norm.
rest = tuple(i for i in range(a.ndim) if i not in axis)
# Note: The axis argument supports any number of axes, whereas
# np.linalg.norm() only supports a single axis for vector norm.
normalized_axis = normalize_axis_tuple(axis, x.ndim)
rest = tuple(i for i in range(a.ndim) if i not in normalized_axis)
newshape = axis + rest
a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest]))
axis = 0
return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord))
a = np.transpose(a, newshape).reshape(
(np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest]))
_axis = 0
else:
_axis = axis

res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord))

if keepdims:
# We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
# above to avoid matrix norm logic.
shape = list(x.shape)
_axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
for i in _axis:
shape[i] = 1
res = reshape(res, tuple(shape))

return res

__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']

0 comments on commit 70026c4

Please sign in to comment.