Skip to content

Commit

Permalink
Add test that svd values are sorted
Browse files Browse the repository at this point in the history
This test will not pass NumPy without the changes in
numpy/numpy#20066 due to an update in indexing
behavior in the spec.
  • Loading branch information
asmeurer committed Oct 26, 2021
1 parent e17273f commit 139e83e
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,17 @@ def test_svd(x, kw):
assert s.dtype == x.dtype, "svd().s did not return the correct dtype"
assert vh.dtype == x.dtype, "svd().vh did not return the correct dtype"

assert s.shape == (*stack, K)
if full_matrices:
assert u.shape == (*stack, M, M)
assert vh.shape == (*stack, N, N)
else:
assert u.shape == (*stack, M, K)
assert vh.shape == (*stack, K, N)
assert s.shape == (*stack, K)

# The values of s must be sorted from largest to smallest
if K >= 1:
assert _array_module.all(s[..., :-1] >= s[..., 1:])

@pytest.mark.xp_extension('linalg')
@given(
Expand Down

0 comments on commit 139e83e

Please sign in to comment.