Skip to content

Commit

Permalink
Add test that svd values are sorted
Browse files Browse the repository at this point in the history
This test will not pass NumPy without the changes in
numpy/numpy#20066 due to an update in indexing
behavior in the spec.
  • Loading branch information
cr313 committed Oct 26, 2021
1 parent 7d29cfe commit 43ff81d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,17 @@ def test_svd(x, kw):
assert s.dtype == x.dtype, "svd().s did not return the correct dtype"
assert vh.dtype == x.dtype, "svd().vh did not return the correct dtype"

assert s.shape == (*stack, K)
if full_matrices:
assert u.shape == (*stack, M, M)
assert vh.shape == (*stack, N, N)
else:
assert u.shape == (*stack, M, K)
assert vh.shape == (*stack, K, N)
assert s.shape == (*stack, K)

# The values of s must be sorted from largest to smallest
if K >= 1:
assert _array_module.all(s[..., :-1] >= s[..., 1:])

@pytest.mark.xp_extension('linalg')
@given(
Expand Down

0 comments on commit 43ff81d

Please sign in to comment.