From 01a10c3176154e3ce0ea544d5c1935698eed3051 Mon Sep 17 00:00:00 2001 From: czgdp1807 Date: Wed, 2 Jun 2021 14:55:14 +0530 Subject: [PATCH] ENH: Implemented general qr decomposition --- numpy/linalg/umath_linalg.c.src | 119 +++++++++++++++++++++----------- 1 file changed, 79 insertions(+), 40 deletions(-) diff --git a/numpy/linalg/umath_linalg.c.src b/numpy/linalg/umath_linalg.c.src index c1b0ac332a9d..60d07b6d1a8c 100644 --- a/numpy/linalg/umath_linalg.c.src +++ b/numpy/linalg/umath_linalg.c.src @@ -162,6 +162,15 @@ FNAME(zgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs, double rwork[], fortran_int iwork[], fortran_int *info); +extern fortran_int +FNAME(dgeqrf)(fortran_int *m, fortran_int *n, double a[], fortran_int *lda, + double tau[], double work[], + fortran_int *lwork, fortran_int *info); +extern fortran_int +FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int *lda, + f2c_doublecomplex tau[], f2c_doublecomplex work[], + fortran_int *lwork, fortran_int *info); + extern fortran_int FNAME(sgesv)(fortran_int *n, fortran_int *nrhs, float a[], fortran_int *lda, @@ -3252,16 +3261,16 @@ typedef struct geqrf_params_struct static inline void dump_geqrf_params(const char *name, - GELSD_PARAMS_t *params) + GEQRF_PARAMS_t *params) { TRACE_TXT("\n%s:\n"\ "%14s: %18p\n"\ "%14s: %18p\n"\ "%14s: %18p\n"\ - "%14s: %18p\n"\ - "%14s: %18p\n"\ - "%14s: %18p\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ + "%14s: %18d\n"\ "%14s: %18d\n", name, @@ -3278,9 +3287,7 @@ dump_geqrf_params(const char *name, /**begin repeat - #TYPE=FLOAT,DOUBLE# #lapack_func=dgeqrf# - #ftyp=fortran_real,fortran_doublereal# */ static inline fortran_int @@ -3293,6 +3300,14 @@ call_@lapack_func@(GEQRF_PARAMS_t *params) return rv; } +/**end repeat**/ + +/**begin repeat + #TYPE=DOUBLE# + #lapack_func=dgeqrf# + #ftyp=fortran_doublereal# + */ + static inline int init_@lapack_func@(GEQRF_PARAMS_t *params, fortran_int m, @@ -3330,7 +3345,6 @@ init_@lapack_func@(GEQRF_PARAMS_t 
*params, { /* compute optimal work size */ @ftyp@ work_size_query; - fortran_int iwork_size_query; params->WORK = &work_size_query; params->LWORK = -1; @@ -3365,8 +3379,6 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, /**end repeat**/ /**begin repeat - #TYPE=CFLOAT,CDOUBLE# - #ftyp=fortran_complex,fortran_doublecomplex# #lapack_func=zgeqrf# */ @@ -3380,6 +3392,14 @@ call_@lapack_func@(GEQRF_PARAMS_t *params) return rv; } +/**end repeat**/ + +/**begin repeat + #TYPE=CDOUBLE# + #ftyp=fortran_doublecomplex# + #lapack_func=zgeqrf# + */ + static inline int init_@lapack_func@(GEQRF_PARAMS_t *params, fortran_int m, @@ -3417,7 +3437,6 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, { /* compute optimal work size */ @ftyp@ work_size_query; - fortran_int iwork_size_query; params->WORK = &work_size_query; params->LWORK = -1; @@ -3425,9 +3444,9 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, if (call_@lapack_func@(params) != 0) goto error; - work_count = (fortran_int)work_size_query; + work_count = (fortran_int)work_size_query.r; - work_size = (size_t) work_size_query * sizeof(@ftyp@); + work_size = (size_t) work_size_query.r * sizeof(@ftyp@); } mem_buff2 = malloc(work_size); @@ -3453,17 +3472,10 @@ init_@lapack_func@(GEQRF_PARAMS_t *params, /**begin repeat - #TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE# - #REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE# #lapack_func=dgeqrf,zgeqrf# - #typ = npy_float, npy_double, npy_cfloat, npy_cdouble# - #basetyp = npy_float, npy_double, npy_float, npy_double# - #ftyp = fortran_real, fortran_doublereal, - fortran_complex, fortran_doublecomplex# - #cmplx = 0, 1, 1# */ static inline void -release_@lapack_func@(GELSD_PARAMS_t* params) +release_@lapack_func@(GEQRF_PARAMS_t* params) { /* A and WORK contain allocated blocks */ free(params->A); @@ -3471,27 +3483,49 @@ release_@lapack_func@(GELSD_PARAMS_t* params) memset(params, 0, sizeof(*params)); } -/** Compute the squared l2 norm of a contiguous vector */ -static @basetyp@ -@TYPE@_abs2(@typ@ *p, npy_intp n) { - npy_intp i; 
- @basetyp@ res = 0; - for (i = 0; i < n; i++) { - @typ@ el = p[i]; -#if @cmplx@ - res += el.real*el.real + el.imag*el.imag; -#else - res += el*el; -#endif - } - return res; -} +/**end repeat**/ + +/**begin repeat + #TYPE=DOUBLE,CDOUBLE# + #lapack_func=dgeqrf,zgeqrf# + */ static void @TYPE@_qr(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)) { - + GEQRF_PARAMS_t params; + int error_occurred = get_fp_invalid_and_clear(); + fortran_int n, m, k; + + INIT_OUTER_LOOP_7 + + m = (fortran_int)dimensions[0]; + n = (fortran_int)dimensions[1]; + k = fortran_int_min(n, m) + + if (init_@lapack_func@(&params, m, n)) { + LINEARIZE_DATA_t a_in, q_out, r_out; + + init_linearize_data(&a_in, m, n, steps[1], steps[0]); + init_linearize_data(&q_out, m, k, steps[3], steps[2]); + init_linearize_data(&r_out, k, n, steps[5], steps[4]); + + BEGIN_OUTER_LOOP_7 + int not_ok; + linearize_@TYPE@_matrix(params.A, args[0], &a_in); + not_ok = call_@lapack_func@(&params); + if (not_ok) { + error_occurred = 1; + nan_@TYPE@_matrix(args[1], &q_out); + nan_@TYPE@_matrix(args[2], &r_out); + } + END_OUTER_LOOP + + release_@lapack_func@(&params); + } + + set_fp_invalid_or_clear(error_occurred); } /**end repeat**/ @@ -3541,6 +3575,13 @@ static void *array_of_nulls[] = { CDOUBLE_ ## NAME \ } +#define GUFUNC_FUNC_ARRAY_REAL_COMPLEX_QR(NAME) \ + static PyUFuncGenericFunction \ + FUNC_ARRAY_NAME(NAME)[] = { \ + DOUBLE_ ## NAME, \ + CDOUBLE_ ## NAME \ + } + /* There are problems with eig in complex single precision. 
* That kernel is disabled */ @@ -3567,7 +3608,7 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_N); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_S); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A); GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq); -GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr); +GUFUNC_FUNC_ARRAY_REAL_COMPLEX_QR(qr); GUFUNC_FUNC_ARRAY_EIG(eig); GUFUNC_FUNC_ARRAY_EIG(eigvals); @@ -3642,9 +3683,7 @@ static char lstsq_types[] = { }; static char qr_types[] = { - NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, - NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT, NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE, }; @@ -3859,7 +3898,7 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = { }, { "qr", - "(m,m),->(m,k),(k,m)", + "(m,m),->(m,m),(m,m)", "QR decomposition of last two dimensions and broadcast to the rest.\n"\ "The input matrix must be square.\n", 4, 1, 2,