Skip to content

Commit

Permalink
ENH: Parameter struct defined, lapack call, init and release defined
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Jun 2, 2021
1 parent a436fb9 commit ca37297
Showing 1 changed file with 255 additions and 16 deletions.
271 changes: 255 additions & 16 deletions numpy/linalg/umath_linalg.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -3233,25 +3233,264 @@ static void
set_fp_invalid_or_clear(error_occurred);
}

static void
@TYPE@_qr(char **args, npy_intp const *dimensions, npy_intp const *steps,
void *NPY_UNUSED(func))
/**end repeat**/

/* -------------------------------------------------------------------------- */
/* qr */

typedef struct geqrf_params_struct
{
GELSD_PARAMS_t params;
int error_occurred = get_fp_invalid_and_clear();
fortran_int n, m;
fortran_int excess;
fortran_int M;
fortran_int N;
void *A;
fortran_int LDA;
void *TAU;
void *WORK;
fortran_int LWORK;
} GEQRF_PARAMS_t;

INIT_OUTER_LOOP_7

m = (fortran_int)dimensions[0];
n = (fortran_int)dimensions[1];
excess = m - n;
static inline void
dump_geqrf_params(const char *name,
GELSD_PARAMS_t *params)
{
TRACE_TXT("\n%s:\n"\

if (init_@lapack_func@(&params, m, n, nrhs)) {
LINEARIZE_DATA_t a_in, q_out, r_out;
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18p\n"\
"%14s: %18d\n",

name,

"A", params->A,
"TAU", params->TAU,
"WORK", params->WORK,

"M", (int)params->M,
"N", (int)params->N,
"LDA", (int)params->LDA,
"LWORK", (int)params->LWORK);
}


/**begin repeat
#TYPE=FLOAT,DOUBLE#
#lapack_func=dgeqrf#
#ftyp=fortran_real,fortran_doublereal#
*/

static inline fortran_int
call_@lapack_func@(GEQRF_PARAMS_t *params)
{
fortran_int rv;
LAPACK(@lapack_func@)(&params->M, &params->N, params->A, &params->LDA,
params->TAU, params->WORK,
&params->LWORK, &rv);
return rv;
}

static inline int
init_@lapack_func@(GEQRF_PARAMS_t *params,
fortran_int m,
fortran_int n)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *tau, *work;
fortran_int min_m_n = fortran_int_min(m, n);
size_t safe_min_m_n = min_m_n;
size_t safe_m = m;
size_t safe_n = n;

size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
size_t tau_size = safe_min_m_n * sizeof(@ftyp@);

fortran_int work_count;
size_t work_size;
fortran_int lda = fortran_int_max(1, m);

mem_buff = malloc(a_size + tau_size);

if (!mem_buff)
goto error;

a = mem_buff;
tau = a + a_size;

params->M = m;
params->N = n;
params->A = a;
params->TAU = tau;
params->LDA = lda;

{
/* compute optimal work size */
@ftyp@ work_size_query;
fortran_int iwork_size_query;

params->WORK = &work_size_query;
params->LWORK = -1;

if (call_@lapack_func@(params) != 0)
goto error;

work_count = (fortran_int)work_size_query;

work_size = (size_t) work_size_query * sizeof(@ftyp@);
}

mem_buff2 = malloc(work_size);
if (!mem_buff2)
goto error;

work = mem_buff2;

params->WORK = work;
params->LWORK = work_count;

return 1;
error:
TRACE_TXT("%s failed init\n", __FUNCTION__);
free(mem_buff);
free(mem_buff2);
memset(params, 0, sizeof(*params));

return 0;
}

/**end repeat**/

/**begin repeat
#TYPE=CFLOAT,CDOUBLE#
#ftyp=fortran_complex,fortran_doublecomplex#
#lapack_func=zgeqrf#
*/

static inline fortran_int
call_@lapack_func@(GEQRF_PARAMS_t *params)
{
fortran_int rv;
LAPACK(@lapack_func@)(&params->M, &params->N, params->A, &params->LDA,
params->TAU, params->WORK,
&params->LWORK, &rv);
return rv;
}

static inline int
init_@lapack_func@(GEQRF_PARAMS_t *params,
fortran_int m,
fortran_int n)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *tau, *work;
fortran_int min_m_n = fortran_int_min(m, n);
size_t safe_min_m_n = min_m_n;
size_t safe_m = m;
size_t safe_n = n;

size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
size_t tau_size = safe_min_m_n * sizeof(@ftyp@);

fortran_int work_count;
size_t work_size;
fortran_int lda = fortran_int_max(1, m);

mem_buff = malloc(a_size + tau_size);

if (!mem_buff)
goto error;

a = mem_buff;
tau = a + a_size;

params->M = m;
params->N = n;
params->A = a;
params->TAU = tau;
params->LDA = lda;

{
/* compute optimal work size */
@ftyp@ work_size_query;
fortran_int iwork_size_query;

params->WORK = &work_size_query;
params->LWORK = -1;

if (call_@lapack_func@(params) != 0)
goto error;

work_count = (fortran_int)work_size_query;

work_size = (size_t) work_size_query * sizeof(@ftyp@);
}

mem_buff2 = malloc(work_size);
if (!mem_buff2)
goto error;

work = mem_buff2;

params->WORK = work;
params->LWORK = work_count;

return 1;
error:
TRACE_TXT("%s failed init\n", __FUNCTION__);
free(mem_buff);
free(mem_buff2);
memset(params, 0, sizeof(*params));

return 0;
}

/**end repeat**/


/**begin repeat
#TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE#
#REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE#
#lapack_func=dgeqrf,zgeqrf#
#typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
#basetyp = npy_float, npy_double, npy_float, npy_double#
#ftyp = fortran_real, fortran_doublereal,
fortran_complex, fortran_doublecomplex#
#cmplx = 0, 1, 1#
*/
static inline void
release_@lapack_func@(GELSD_PARAMS_t* params)
{
/* A and WORK contain allocated blocks */
free(params->A);
free(params->WORK);
memset(params, 0, sizeof(*params));
}

/** Compute the squared l2 norm of a contiguous vector */
static @basetyp@
@TYPE@_abs2(@typ@ *p, npy_intp n) {
npy_intp i;
@basetyp@ res = 0;
for (i = 0; i < n; i++) {
@typ@ el = p[i];
#if @cmplx@
res += el.real*el.real + el.imag*el.imag;
#else
res += el*el;
#endif
}
return res;
}

static void
@TYPE@_qr(char **args, npy_intp const *dimensions, npy_intp const *steps,
void *NPY_UNUSED(func))
{

}

Expand Down Expand Up @@ -3403,10 +3642,10 @@ static char lstsq_types[] = {
};

static char qr_types[] = {
NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
NPY_FLOAT, NPY_FLOAT, NPY_FLOAT,
NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE,
NPY_CFLOAT, NPY_FLOAT, NPY_FLOAT,
NPY_CDOUBLE, NPY_DOUBLE, NPY_DOUBLE,
NPY_CFLOAT, NPY_CFLOAT, NPY_CFLOAT,
NPY_CDOUBLE, NPY_CDOUBLE, NPY_CDOUBLE,
};

typedef struct gufunc_descriptor_struct {
Expand Down

0 comments on commit ca37297

Please sign in to comment.