
Batched QR solver #4986

Closed
leofang opened this issue Mar 27, 2021 · 2 comments · Fixed by #5583

leofang (Member) commented Mar 27, 2021

This is for both NumPy compatibility and also the Python Array API standard (#4789, WIP proposal: data-apis/array-api#126).

Currently cupy.linalg.qr has this check:

# TODO(Saito): Current implementation only accepts two-dimensional arrays
_util._assert_cupy_array(a)
_util._assert_rank2(a)

so only a single 2-D matrix is supported. We might be able to follow the approach used for batched svd and pinv to support batched QR factorization. I can look into this later.
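To pin down the semantics a batched QR would need (a stack of matrices in the last two axes, as in the Array API draft), here is a minimal NumPy sketch. The helper name batched_qr is hypothetical, and it is a naive loop over the batch, not a reproduction of whatever CuPy's batched svd/pinv do internally. It only covers the "reduced" and "complete" modes.

```python
import numpy as np


def batched_qr(a, mode="reduced"):
    """Hypothetical helper: QR-factorize each matrix in a stack.

    `a` has shape (..., M, N). With K = min(M, N), mode="reduced" returns
    Q of shape (..., M, K) and R of shape (..., K, N); mode="complete"
    returns Q of shape (..., M, M) and R of shape (..., M, N).
    Assumes the batch contains at least one matrix.
    """
    if mode not in ("reduced", "complete"):
        raise ValueError("only 'reduced' and 'complete' modes are sketched here")
    a = np.asarray(a)
    if a.ndim < 2:
        raise ValueError("expected an array with at least two dimensions")
    *batch, m, n = a.shape
    # Collapse all leading (batch) axes into one so we can loop over matrices.
    flat = a.reshape(-1, m, n)
    # Factor each 2-D matrix with the existing (2-D only) routine.
    qs, rs = zip(*(np.linalg.qr(mat, mode=mode) for mat in flat))
    # Stack the per-matrix factors and restore the original batch shape.
    q = np.stack(qs).reshape(*batch, *qs[0].shape)
    r = np.stack(rs).reshape(*batch, *rs[0].shape)
    return q, r
```

For example, an input of shape (2, 16, 16) yields Q and R of shape (2, 16, 16) each, with q @ r reconstructing the input slice by slice. A real GPU implementation would presumably replace the Python loop with a batched kernel, but the input/output shape contract would stay the same.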

leofang (Member, Author) commented Mar 27, 2021

Oops, I read it too quickly: NumPy does not currently support it, so this is only for the Array API.

>>> import numpy as np
>>> a = np.random.random((2,16,16))
>>> b = np.linalg.qr(a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<__array_function__ internals>", line 6, in qr
  File "/home/leofang/miniconda3/envs/cupy_cuda112_dev/lib/python3.7/site-packages/numpy/linalg/linalg.py", line 907, in qr
    _assert_2d(a)
  File "/home/leofang/miniconda3/envs/cupy_cuda112_dev/lib/python3.7/site-packages/numpy/linalg/linalg.py", line 191, in _assert_2d
    'two-dimensional' % a.ndim)
numpy.linalg.LinAlgError: 3-dimensional array given. Array must be two-dimensional

leofang changed the title from "Batched QR solver (cupy.linalg.qr())" to "Batched QR solver" on Mar 27, 2021
kmaehashi added the cat:feature (New features/APIs) and prio:medium labels on Mar 30, 2021
leofang (Member, Author) commented Jul 21, 2021

Update: batched QR will be supported in the next NumPy release (see numpy/numpy#19151), so we should be prepared.
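One way to be prepared is a test that checks stacked input against the properties a QR factorization must satisfy, and against factoring each matrix on its own. The sketch below assumes a NumPy release that includes numpy/numpy#19151 (1.22 or later, as far as I know); the helper name check_batched_qr is hypothetical. The same checks could later be pointed at a batched cupy.linalg.qr.

```python
import numpy as np


def check_batched_qr(a):
    """Hypothetical test helper: validate np.linalg.qr on a stack of matrices.

    Assumes NumPy >= 1.22, where np.linalg.qr accepts (..., M, N) input.
    Returns True if every check passes.
    """
    a = np.asarray(a)
    q, r = np.linalg.qr(a)  # default mode="reduced"
    m, n = a.shape[-2:]
    k = min(m, n)
    batch = a.shape[:-2]
    # Reduced mode: Q is (..., M, K) and R is (..., K, N).
    if q.shape != batch + (m, k) or r.shape != batch + (k, n):
        return False
    # Q @ R reconstructs A, Q has orthonormal columns, R is upper triangular.
    qt = np.swapaxes(q, -1, -2)
    if not (np.allclose(q @ r, a)
            and np.allclose(qt @ q, np.eye(k))
            and np.allclose(r, np.triu(r))):
        return False
    # Each slice must match the factorization of that matrix on its own.
    for idx in np.ndindex(batch):
        q_i, r_i = np.linalg.qr(a[idx])
        if not (np.allclose(q[idx], q_i) and np.allclose(r[idx], r_i)):
            return False
    return True
```

Comparing against the per-slice result, not only reconstruction, also catches sign-convention differences between a batched path and the 2-D path.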
