Skip to content

Commit

Permalink
ENH: Implemented general qr decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Jun 2, 2021
1 parent ca37297 commit 01a10c3
Showing 1 changed file with 79 additions and 40 deletions.
119 changes: 79 additions & 40 deletions numpy/linalg/umath_linalg.c.src
Expand Up @@ -162,6 +162,15 @@ FNAME(zgelsd)(fortran_int *m, fortran_int *n, fortran_int *nrhs,
double rwork[], fortran_int iwork[],
fortran_int *info);

extern fortran_int
FNAME(dgeqrf)(fortran_int *m, fortran_int *n, double a[], fortran_int *lda,
double tau[], double work[],
fortran_int *lwork, fortran_int *info);
extern fortran_int
FNAME(zgeqrf)(fortran_int *m, fortran_int *n, f2c_doublecomplex a[], fortran_int *lda,
f2c_doublecomplex tau[], f2c_doublecomplex work[],
fortran_int *lwork, fortran_int *info);

extern fortran_int
FNAME(sgesv)(fortran_int *n, fortran_int *nrhs,
float a[], fortran_int *lda,
Expand Down Expand Up @@ -3252,16 +3261,16 @@ typedef struct geqrf_params_struct

static inline void
dump_geqrf_params(const char *name,
GELSD_PARAMS_t *params)
GEQRF_PARAMS_t *params)
{
TRACE_TXT("\n%s:\n"\

"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18d\n"\
"%14s: %18d\n"\
"%14s: %18d\n"\
"%14s: %18d\n",

name,
Expand All @@ -3278,9 +3287,7 @@ dump_geqrf_params(const char *name,


/**begin repeat
#TYPE=FLOAT,DOUBLE#
#lapack_func=dgeqrf#
#ftyp=fortran_real,fortran_doublereal#
*/

static inline fortran_int
Expand All @@ -3293,6 +3300,14 @@ call_@lapack_func@(GEQRF_PARAMS_t *params)
return rv;
}

/**end repeat**/

/**begin repeat
#TYPE=DOUBLE#
#lapack_func=dgeqrf#
#ftyp=fortran_doublereal#
*/

static inline int
init_@lapack_func@(GEQRF_PARAMS_t *params,
fortran_int m,
Expand Down Expand Up @@ -3330,7 +3345,6 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
{
/* compute optimal work size */
@ftyp@ work_size_query;
fortran_int iwork_size_query;

params->WORK = &work_size_query;
params->LWORK = -1;
Expand Down Expand Up @@ -3365,8 +3379,6 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
/**end repeat**/

/**begin repeat
#TYPE=CFLOAT,CDOUBLE#
#ftyp=fortran_complex,fortran_doublecomplex#
#lapack_func=zgeqrf#
*/

Expand All @@ -3380,6 +3392,14 @@ call_@lapack_func@(GEQRF_PARAMS_t *params)
return rv;
}

/**end repeat**/

/**begin repeat
#TYPE=CDOUBLE#
#ftyp=fortran_doublecomplex#
#lapack_func=zgeqrf#
*/

static inline int
init_@lapack_func@(GEQRF_PARAMS_t *params,
fortran_int m,
Expand Down Expand Up @@ -3417,17 +3437,16 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,
{
/* compute optimal work size */
@ftyp@ work_size_query;
fortran_int iwork_size_query;

params->WORK = &work_size_query;
params->LWORK = -1;

if (call_@lapack_func@(params) != 0)
goto error;

work_count = (fortran_int)work_size_query;
work_count = (fortran_int)work_size_query.r;

work_size = (size_t) work_size_query * sizeof(@ftyp@);
work_size = (size_t) work_size_query.r * sizeof(@ftyp@);
}

mem_buff2 = malloc(work_size);
Expand All @@ -3453,45 +3472,60 @@ init_@lapack_func@(GEQRF_PARAMS_t *params,


/**begin repeat
#TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE#
#REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE#
#lapack_func=dgeqrf,zgeqrf#
#typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
#basetyp = npy_float, npy_double, npy_float, npy_double#
#ftyp = fortran_real, fortran_doublereal,
fortran_complex, fortran_doublecomplex#
#cmplx = 0, 1, 1#
*/
static inline void
release_@lapack_func@(GELSD_PARAMS_t* params)
release_@lapack_func@(GEQRF_PARAMS_t* params)
{
/* A and WORK contain allocated blocks */
free(params->A);
free(params->WORK);
memset(params, 0, sizeof(*params));
}

/** Compute the squared l2 norm of a contiguous vector */
static @basetyp@
@TYPE@_abs2(@typ@ *p, npy_intp n) {
npy_intp i;
@basetyp@ res = 0;
for (i = 0; i < n; i++) {
@typ@ el = p[i];
#if @cmplx@
res += el.real*el.real + el.imag*el.imag;
#else
res += el*el;
#endif
}
return res;
}
/**end repeat**/

/**begin repeat
#TYPE=DOUBLE,CDOUBLE#
#lapack_func=dgeqrf,zgeqrf#
*/

static void
@TYPE@_qr(char **args, npy_intp const *dimensions, npy_intp const *steps,
void *NPY_UNUSED(func))
{

GEQRF_PARAMS_t params;
int error_occurred = get_fp_invalid_and_clear();
fortran_int n, m, k;

INIT_OUTER_LOOP_7

m = (fortran_int)dimensions[0];
n = (fortran_int)dimensions[1];
k = fortran_int_min(n, m)

if (init_@lapack_func@(&params, m, n)) {
LINEARIZE_DATA_t a_in, q_out, r_out;

init_linearize_data(&a_in, m, n, steps[1], steps[0]);
init_linearize_data(&q_out, m, k, steps[3], steps[2]);
init_linearize_data(&r_out, k, n, steps[5], steps[4]);

BEGIN_OUTER_LOOP_7
int not_ok;
linearize_@TYPE@_matrix(params.A, args[0], &a_in);
not_ok = call_@lapack_func@(&params);
if (not_ok) {
error_occurred = 1;
nan_@TYPE@_matrix(args[1], &q_out);
nan_@TYPE@_matrix(args[2], &r_out);
}
END_OUTER_LOOP

release_@lapack_func@(&params);
}

set_fp_invalid_or_clear(error_occurred);
}

/**end repeat**/
Expand Down Expand Up @@ -3541,6 +3575,13 @@ static void *array_of_nulls[] = {
CDOUBLE_ ## NAME \
}

#define GUFUNC_FUNC_ARRAY_REAL_COMPLEX_QR(NAME) \
static PyUFuncGenericFunction \
FUNC_ARRAY_NAME(NAME)[] = { \
DOUBLE_ ## NAME, \
CDOUBLE_ ## NAME \
}

/* There are problems with eig in complex single precision.
* That kernel is disabled
*/
Expand All @@ -3567,7 +3608,7 @@ GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_N);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_S);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(svd_A);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(lstsq);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX(qr);
GUFUNC_FUNC_ARRAY_REAL_COMPLEX_QR(qr);
GUFUNC_FUNC_ARRAY_EIG(eig);
GUFUNC_FUNC_ARRAY_EIG(eigvals);

Expand Down Expand Up @@ -3642,9 +3683,7 @@ static char lstsq_types[] = {
};

static char qr_types[] = {
NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT,
NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
};

Expand Down Expand Up @@ -3859,7 +3898,7 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = {
},
{
"qr",
"(m,m),->(m,k),(k,m)",
"(m,m),->(m,m),(m,m)",
"QR decomposition of last two dimensions and broadcast to the rest.\n"\
"The input matrix must be square.\n",
4, 1, 2,
Expand Down

0 comments on commit 01a10c3

Please sign in to comment.