/* Blender V2.61 - r43446 */
00001 00005 /* 00006 * -- SuperLU routine (version 3.0) -- 00007 * Univ. of California Berkeley, Xerox Palo Alto Research Center, 00008 * and Lawrence Berkeley National Lab. 00009 * October 15, 2003 00010 * 00011 */ 00012 /* 00013 Copyright (c) 1994 by Xerox Corporation. All rights reserved. 00014 00015 THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY 00016 EXPRESSED OR IMPLIED. ANY USE IS AT YOUR OWN RISK. 00017 00018 Permission is hereby granted to use or copy this program for any 00019 purpose, provided the above notices are retained on all copies. 00020 Permission to modify the code and to distribute modified code is 00021 granted, provided the above notices are retained, and a notice that 00022 the code was modified is included with the above copyright notice. 00023 */ 00024 00025 #include "ssp_defs.h" 00026 00027 00028 /* 00029 * Function prototypes 00030 */ 00031 void susolve(int, int, float*, float*); 00032 void slsolve(int, int, float*, float*); 00033 void smatvec(int, int, int, float*, float*, float*); 00034 void sprint_soln(int , float *); 00035 00036 void 00037 sgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U, 00038 int *perm_c, int *perm_r, SuperMatrix *B, 00039 SuperLUStat_t *stat, int *info) 00040 { 00041 /* 00042 * Purpose 00043 * ======= 00044 * 00045 * SGSTRS solves a system of linear equations A*X=B or A'*X=B 00046 * with A sparse and B dense, using the LU factorization computed by 00047 * SGSTRF. 00048 * 00049 * See supermatrix.h for the definition of 'SuperMatrix' structure. 00050 * 00051 * Arguments 00052 * ========= 00053 * 00054 * trans (input) trans_t 00055 * Specifies the form of the system of equations: 00056 * = NOTRANS: A * X = B (No transpose) 00057 * = TRANS: A'* X = B (Transpose) 00058 * = CONJ: A**H * X = B (Conjugate transpose) 00059 * 00060 * L (input) SuperMatrix* 00061 * The factor L from the factorization Pr*A*Pc=L*U as computed by 00062 * sgstrf(). 
Use compressed row subscripts storage for supernodes, 00063 * i.e., L has types: Stype = SLU_SC, Dtype = SLU_S, Mtype = SLU_TRLU. 00064 * 00065 * U (input) SuperMatrix* 00066 * The factor U from the factorization Pr*A*Pc=L*U as computed by 00067 * sgstrf(). Use column-wise storage scheme, i.e., U has types: 00068 * Stype = SLU_NC, Dtype = SLU_S, Mtype = SLU_TRU. 00069 * 00070 * perm_c (input) int*, dimension (L->ncol) 00071 * Column permutation vector, which defines the 00072 * permutation matrix Pc; perm_c[i] = j means column i of A is 00073 * in position j in A*Pc. 00074 * 00075 * perm_r (input) int*, dimension (L->nrow) 00076 * Row permutation vector, which defines the permutation matrix Pr; 00077 * perm_r[i] = j means row i of A is in position j in Pr*A. 00078 * 00079 * B (input/output) SuperMatrix* 00080 * B has types: Stype = SLU_DN, Dtype = SLU_S, Mtype = SLU_GE. 00081 * On entry, the right hand side matrix. 00082 * On exit, the solution matrix if info = 0; 00083 * 00084 * stat (output) SuperLUStat_t* 00085 * Record the statistics on runtime and floating-point operation count. 00086 * See util.h for the definition of 'SuperLUStat_t'. 00087 * 00088 * info (output) int* 00089 * = 0: successful exit 00090 * < 0: if info = -i, the i-th argument had an illegal value 00091 * 00092 */ 00093 #ifdef _CRAY 00094 _fcd ftcs1, ftcs2, ftcs3, ftcs4; 00095 #endif 00096 #ifdef USE_VENDOR_BLAS 00097 float alpha = 1.0, beta = 1.0; 00098 float *work_col; 00099 #endif 00100 DNformat *Bstore; 00101 float *Bmat; 00102 SCformat *Lstore; 00103 NCformat *Ustore; 00104 float *Lval, *Uval; 00105 int fsupc, nrow, nsupr, nsupc, luptr, istart, irow; 00106 int i, j, k, iptr, jcol, n, ldb, nrhs; 00107 float *work, *rhs_work, *soln; 00108 flops_t solve_ops; 00109 void sprint_soln(); 00110 00111 /* Test input parameters ... 
*/ 00112 *info = 0; 00113 Bstore = B->Store; 00114 ldb = Bstore->lda; 00115 nrhs = B->ncol; 00116 if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1; 00117 else if ( L->nrow != L->ncol || L->nrow < 0 || 00118 L->Stype != SLU_SC || L->Dtype != SLU_S || L->Mtype != SLU_TRLU ) 00119 *info = -2; 00120 else if ( U->nrow != U->ncol || U->nrow < 0 || 00121 U->Stype != SLU_NC || U->Dtype != SLU_S || U->Mtype != SLU_TRU ) 00122 *info = -3; 00123 else if ( ldb < SUPERLU_MAX(0, L->nrow) || 00124 B->Stype != SLU_DN || B->Dtype != SLU_S || B->Mtype != SLU_GE ) 00125 *info = -6; 00126 if ( *info ) { 00127 i = -(*info); 00128 xerbla_("sgstrs", &i); 00129 return; 00130 } 00131 00132 n = L->nrow; 00133 work = floatCalloc(n * nrhs); 00134 if ( !work ) ABORT("Malloc fails for local work[]."); 00135 soln = floatMalloc(n); 00136 if ( !soln ) ABORT("Malloc fails for local soln[]."); 00137 00138 Bmat = Bstore->nzval; 00139 Lstore = L->Store; 00140 Lval = Lstore->nzval; 00141 Ustore = U->Store; 00142 Uval = Ustore->nzval; 00143 solve_ops = 0; 00144 00145 if ( trans == NOTRANS ) { 00146 /* Permute right hand sides to form Pr*B */ 00147 for (i = 0; i < nrhs; i++) { 00148 rhs_work = &Bmat[i*ldb]; 00149 for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k]; 00150 for (k = 0; k < n; k++) rhs_work[k] = soln[k]; 00151 } 00152 00153 /* Forward solve PLy=Pb. 
*/ 00154 for (k = 0; k <= Lstore->nsuper; k++) { 00155 fsupc = L_FST_SUPC(k); 00156 istart = L_SUB_START(fsupc); 00157 nsupr = L_SUB_START(fsupc+1) - istart; 00158 nsupc = L_FST_SUPC(k+1) - fsupc; 00159 nrow = nsupr - nsupc; 00160 00161 solve_ops += nsupc * (nsupc - 1) * nrhs; 00162 solve_ops += 2 * nrow * nsupc * nrhs; 00163 00164 if ( nsupc == 1 ) { 00165 for (j = 0; j < nrhs; j++) { 00166 rhs_work = &Bmat[j*ldb]; 00167 luptr = L_NZ_START(fsupc); 00168 for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){ 00169 irow = L_SUB(iptr); 00170 ++luptr; 00171 rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr]; 00172 } 00173 } 00174 } else { 00175 luptr = L_NZ_START(fsupc); 00176 #ifdef USE_VENDOR_BLAS 00177 #ifdef _CRAY 00178 ftcs1 = _cptofcd("L", strlen("L")); 00179 ftcs2 = _cptofcd("N", strlen("N")); 00180 ftcs3 = _cptofcd("U", strlen("U")); 00181 STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha, 00182 &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb); 00183 00184 SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 00185 &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 00186 &beta, &work[0], &n ); 00187 #else 00188 strsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha, 00189 &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb); 00190 00191 sgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 00192 &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 00193 &beta, &work[0], &n ); 00194 #endif 00195 for (j = 0; j < nrhs; j++) { 00196 rhs_work = &Bmat[j*ldb]; 00197 work_col = &work[j*n]; 00198 iptr = istart + nsupc; 00199 for (i = 0; i < nrow; i++) { 00200 irow = L_SUB(iptr); 00201 rhs_work[irow] -= work_col[i]; /* Scatter */ 00202 work_col[i] = 0.0; 00203 iptr++; 00204 } 00205 } 00206 #else 00207 for (j = 0; j < nrhs; j++) { 00208 rhs_work = &Bmat[j*ldb]; 00209 slsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]); 00210 smatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc], 00211 &rhs_work[fsupc], &work[0] ); 00212 00213 iptr = istart + nsupc; 00214 for (i = 0; i < nrow; i++) { 00215 irow = 
L_SUB(iptr); 00216 rhs_work[irow] -= work[i]; 00217 work[i] = 0.0; 00218 iptr++; 00219 } 00220 } 00221 #endif 00222 } /* else ... */ 00223 } /* for L-solve */ 00224 00225 #ifdef DEBUG 00226 printf("After L-solve: y=\n"); 00227 sprint_soln(n, Bmat); 00228 #endif 00229 00230 /* 00231 * Back solve Ux=y. 00232 */ 00233 for (k = Lstore->nsuper; k >= 0; k--) { 00234 fsupc = L_FST_SUPC(k); 00235 istart = L_SUB_START(fsupc); 00236 nsupr = L_SUB_START(fsupc+1) - istart; 00237 nsupc = L_FST_SUPC(k+1) - fsupc; 00238 luptr = L_NZ_START(fsupc); 00239 00240 solve_ops += nsupc * (nsupc + 1) * nrhs; 00241 00242 if ( nsupc == 1 ) { 00243 rhs_work = &Bmat[0]; 00244 for (j = 0; j < nrhs; j++) { 00245 rhs_work[fsupc] /= Lval[luptr]; 00246 rhs_work += ldb; 00247 } 00248 } else { 00249 #ifdef USE_VENDOR_BLAS 00250 #ifdef _CRAY 00251 ftcs1 = _cptofcd("L", strlen("L")); 00252 ftcs2 = _cptofcd("U", strlen("U")); 00253 ftcs3 = _cptofcd("N", strlen("N")); 00254 STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha, 00255 &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb); 00256 #else 00257 strsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha, 00258 &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb); 00259 #endif 00260 #else 00261 for (j = 0; j < nrhs; j++) 00262 susolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] ); 00263 #endif 00264 } 00265 00266 for (j = 0; j < nrhs; ++j) { 00267 rhs_work = &Bmat[j*ldb]; 00268 for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) { 00269 solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol)); 00270 for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){ 00271 irow = U_SUB(i); 00272 rhs_work[irow] -= rhs_work[jcol] * Uval[i]; 00273 } 00274 } 00275 } 00276 00277 } /* for U-solve */ 00278 00279 #ifdef DEBUG 00280 printf("After U-solve: x=\n"); 00281 sprint_soln(n, Bmat); 00282 #endif 00283 00284 /* Compute the final solution X := Pc*X. 
*/ 00285 for (i = 0; i < nrhs; i++) { 00286 rhs_work = &Bmat[i*ldb]; 00287 for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]]; 00288 for (k = 0; k < n; k++) rhs_work[k] = soln[k]; 00289 } 00290 00291 stat->ops[SOLVE] = solve_ops; 00292 00293 } else { /* Solve A'*X=B or CONJ(A)*X=B */ 00294 /* Permute right hand sides to form Pc'*B. */ 00295 for (i = 0; i < nrhs; i++) { 00296 rhs_work = &Bmat[i*ldb]; 00297 for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k]; 00298 for (k = 0; k < n; k++) rhs_work[k] = soln[k]; 00299 } 00300 00301 stat->ops[SOLVE] = 0; 00302 for (k = 0; k < nrhs; ++k) { 00303 00304 /* Multiply by inv(U'). */ 00305 sp_strsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info); 00306 00307 /* Multiply by inv(L'). */ 00308 sp_strsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info); 00309 00310 } 00311 /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */ 00312 for (i = 0; i < nrhs; i++) { 00313 rhs_work = &Bmat[i*ldb]; 00314 for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]]; 00315 for (k = 0; k < n; k++) rhs_work[k] = soln[k]; 00316 } 00317 00318 } 00319 00320 SUPERLU_FREE(work); 00321 SUPERLU_FREE(soln); 00322 } 00323 00324 /* 00325 * Diagnostic print of the solution vector 00326 */ 00327 void 00328 sprint_soln(int n, float *soln) 00329 { 00330 int i; 00331 00332 for (i = 0; i < n; i++) 00333 printf("\t%d: %.4f\n", i, soln[i]); 00334 }