Line data Source code
1 : /*! \file
2 : Copyright (c) 2003, The Regents of the University of California, through
3 : Lawrence Berkeley National Laboratory (subject to receipt of any required
4 : approvals from U.S. Dept. of Energy)
5 :
6 : All rights reserved.
7 :
8 : The source code is distributed under BSD license, see the file License.txt
9 : at the top-level directory.
10 : */
11 :
12 : /*! @file dgstrs.c
13 : * \brief Solves a system using LU factorization
14 : *
15 : * <pre>
16 : * -- SuperLU routine (version 3.0) --
17 : * Univ. of California Berkeley, Xerox Palo Alto Research Center,
18 : * and Lawrence Berkeley National Lab.
19 : * October 15, 2003
20 : *
21 : * Copyright (c) 1994 by Xerox Corporation. All rights reserved.
22 : *
23 : * THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
24 : * EXPRESSED OR IMPLIED. ANY USE IS AT YOUR OWN RISK.
25 : *
26 : * Permission is hereby granted to use or copy this program for any
27 : * purpose, provided the above notices are retained on all copies.
28 : * Permission to modify the code and to distribute modified code is
29 : * granted, provided the above notices are retained, and a notice that
30 : * the code was modified is included with the above copyright notice.
31 : * </pre>
32 : */
33 :
34 : #include "slu_ddefs.h"
35 :
36 :
37 : /*! \brief
38 : *
39 : * <pre>
40 : * Purpose
41 : * =======
42 : *
43 : * DGSTRS solves a system of linear equations A*X=B or A'*X=B
44 : * with A sparse and B dense, using the LU factorization computed by
45 : * DGSTRF.
46 : *
47 : * See supermatrix.h for the definition of 'SuperMatrix' structure.
48 : *
49 : * Arguments
50 : * =========
51 : *
52 : * trans (input) trans_t
53 : * Specifies the form of the system of equations:
54 : * = NOTRANS: A * X = B (No transpose)
55 : * = TRANS: A'* X = B (Transpose)
56 : * = CONJ: A**H * X = B (Conjugate transpose)
57 : *
58 : * L (input) SuperMatrix*
59 : * The factor L from the factorization Pr*A*Pc=L*U as computed by
60 : * dgstrf(). Use compressed row subscripts storage for supernodes,
61 : * i.e., L has types: Stype = SLU_SC, Dtype = SLU_D, Mtype = SLU_TRLU.
62 : *
63 : * U (input) SuperMatrix*
64 : * The factor U from the factorization Pr*A*Pc=L*U as computed by
65 : * dgstrf(). Use column-wise storage scheme, i.e., U has types:
66 : * Stype = SLU_NC, Dtype = SLU_D, Mtype = SLU_TRU.
67 : *
68 : * perm_c (input) int*, dimension (L->ncol)
69 : * Column permutation vector, which defines the
70 : * permutation matrix Pc; perm_c[i] = j means column i of A is
71 : * in position j in A*Pc.
72 : *
73 : * perm_r (input) int*, dimension (L->nrow)
74 : * Row permutation vector, which defines the permutation matrix Pr;
75 : * perm_r[i] = j means row i of A is in position j in Pr*A.
76 : *
77 : * B (input/output) SuperMatrix*
78 : * B has types: Stype = SLU_DN, Dtype = SLU_D, Mtype = SLU_GE.
79 : * On entry, the right hand side matrix.
80 : * On exit, the solution matrix if info = 0;
81 : *
82 : * stat (output) SuperLUStat_t*
83 : * Record the statistics on runtime and floating-point operation count.
84 : * See util.h for the definition of 'SuperLUStat_t'.
85 : *
86 : * info (output) int*
87 : * = 0: successful exit
88 : * < 0: if info = -i, the i-th argument had an illegal value
89 : * </pre>
90 : */
91 :
92 : void
93 0 : dgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
94 : const int *perm_c, const int *perm_r, SuperMatrix *B,
95 : SuperLUStat_t *stat, int *info)
96 : {
97 :
98 : #ifdef _CRAY
99 : _fcd ftcs1, ftcs2, ftcs3, ftcs4;
100 : #endif
101 : #ifdef USE_VENDOR_BLAS
102 : double alpha = 1.0, beta = 1.0;
103 : double *work_col;
104 : #endif
105 : DNformat *Bstore;
106 : double *Bmat;
107 : SCformat *Lstore;
108 : NCformat *Ustore;
109 : double *Lval, *Uval;
110 : int fsupc, nrow, nsupr, nsupc, irow;
111 : int_t i, j, k, luptr, istart, iptr;
112 : int jcol, n, ldb, nrhs;
113 : double *work, *rhs_work, *soln;
114 : flops_t solve_ops;
115 : void dprint_soln(int n, int nrhs, const double *soln);
116 :
117 : /* Test input parameters ... */
118 0 : *info = 0;
119 0 : Bstore = B->Store;
120 0 : ldb = Bstore->lda;
121 0 : nrhs = B->ncol;
122 0 : if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
123 0 : else if ( L->nrow != L->ncol || L->nrow < 0 ||
124 0 : L->Stype != SLU_SC || L->Dtype != SLU_D || L->Mtype != SLU_TRLU )
125 0 : *info = -2;
126 0 : else if ( U->nrow != U->ncol || U->nrow < 0 ||
127 0 : U->Stype != SLU_NC || U->Dtype != SLU_D || U->Mtype != SLU_TRU )
128 0 : *info = -3;
129 0 : else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
130 0 : B->Stype != SLU_DN || B->Dtype != SLU_D || B->Mtype != SLU_GE )
131 0 : *info = -6;
132 0 : if ( *info ) {
133 0 : int ii = -(*info);
134 0 : input_error("dgstrs", &ii);
135 : return;
136 : }
137 :
138 0 : n = L->nrow;
139 0 : work = doubleCalloc((size_t) n * (size_t) nrhs);
140 0 : if ( !work ) ABORT("Malloc fails for local work[].");
141 0 : soln = doubleMalloc((size_t) n);
142 0 : if ( !soln ) ABORT("Malloc fails for local soln[].");
143 :
144 0 : Bmat = Bstore->nzval;
145 0 : Lstore = L->Store;
146 0 : Lval = Lstore->nzval;
147 0 : Ustore = U->Store;
148 0 : Uval = Ustore->nzval;
149 : solve_ops = 0;
150 :
151 0 : if ( trans == NOTRANS ) {
152 : /* Permute right hand sides to form Pr*B */
153 0 : for (i = 0; i < nrhs; i++) {
154 0 : rhs_work = &Bmat[(size_t)i * (size_t)ldb];
155 0 : for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
156 0 : for (k = 0; k < n; k++) rhs_work[k] = soln[k];
157 : }
158 :
159 : /* Forward solve PLy=Pb. */
160 0 : for (k = 0; k <= Lstore->nsuper; k++) {
161 0 : fsupc = L_FST_SUPC(k);
162 0 : istart = L_SUB_START(fsupc);
163 0 : nsupr = L_SUB_START(fsupc+1) - istart;
164 0 : nsupc = L_FST_SUPC(k+1) - fsupc;
165 0 : nrow = nsupr - nsupc;
166 :
167 0 : solve_ops += nsupc * (nsupc - 1) * nrhs;
168 0 : solve_ops += 2 * nrow * nsupc * nrhs;
169 :
170 0 : if ( nsupc == 1 ) {
171 0 : for (j = 0; j < nrhs; j++) {
172 0 : rhs_work = &Bmat[(size_t)j * (size_t)ldb];
173 0 : luptr = L_NZ_START(fsupc);
174 0 : for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
175 0 : irow = L_SUB(iptr);
176 0 : ++luptr;
177 0 : rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
178 : }
179 : }
180 : } else {
181 0 : luptr = L_NZ_START(fsupc);
182 : #ifdef USE_VENDOR_BLAS
183 : #ifdef _CRAY
184 : ftcs1 = _cptofcd("L", strlen("L"));
185 : ftcs2 = _cptofcd("N", strlen("N"));
186 : ftcs3 = _cptofcd("U", strlen("U"));
187 : STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
188 : &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
189 :
190 : SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha,
191 : &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
192 : &beta, &work[0], &n );
193 : #else
194 : dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
195 : &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
196 :
197 : dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha,
198 : &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb,
199 : &beta, &work[0], &n );
200 : #endif
201 : for (j = 0; j < nrhs; j++) {
202 : rhs_work = &Bmat[(size_t)j * (size_t)ldb];
203 : work_col = &work[(size_t)j * (size_t)n];
204 : iptr = istart + nsupc;
205 : for (i = 0; i < nrow; i++) {
206 : irow = L_SUB(iptr);
207 : rhs_work[irow] -= work_col[i]; /* Scatter */
208 : work_col[i] = 0.0;
209 : iptr++;
210 : }
211 : }
212 : #else
213 0 : for (j = 0; j < nrhs; j++) {
214 0 : rhs_work = &Bmat[(size_t)j * (size_t)ldb];
215 0 : dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
216 0 : dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
217 : &rhs_work[fsupc], &work[0] );
218 :
219 0 : iptr = istart + nsupc;
220 0 : for (i = 0; i < nrow; i++) {
221 0 : irow = L_SUB(iptr);
222 0 : rhs_work[irow] -= work[i];
223 0 : work[i] = 0.0;
224 0 : iptr++;
225 : }
226 : }
227 : #endif
228 : } /* else ... */
229 : } /* for L-solve */
230 :
231 : #if ( DEBUGlevel>=2 )
232 : printf("After L-solve: y=\n");
233 : dprint_soln(n, nrhs, Bmat);
234 : #endif
235 :
236 : /*
237 : * Back solve Ux=y.
238 : */
239 0 : for (k = Lstore->nsuper; k >= 0; k--) {
240 0 : fsupc = L_FST_SUPC(k);
241 0 : istart = L_SUB_START(fsupc);
242 0 : nsupr = L_SUB_START(fsupc+1) - istart;
243 0 : nsupc = L_FST_SUPC(k+1) - fsupc;
244 0 : luptr = L_NZ_START(fsupc);
245 :
246 0 : solve_ops += nsupc * (nsupc + 1) * nrhs;
247 :
248 0 : if ( nsupc == 1 ) {
249 : rhs_work = &Bmat[0];
250 0 : for (j = 0; j < nrhs; j++) {
251 0 : rhs_work[fsupc] /= Lval[luptr];
252 0 : rhs_work += ldb;
253 : }
254 : } else {
255 : #ifdef USE_VENDOR_BLAS
256 : #ifdef _CRAY
257 : ftcs1 = _cptofcd("L", strlen("L"));
258 : ftcs2 = _cptofcd("U", strlen("U"));
259 : ftcs3 = _cptofcd("N", strlen("N"));
260 : STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
261 : &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
262 : #else
263 : dtrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
264 : &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
265 : #endif
266 : #else
267 0 : for (j = 0; j < nrhs; j++)
268 0 : dusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[(size_t)fsupc + (size_t)j * (size_t)ldb] );
269 : #endif
270 : }
271 :
272 0 : for (j = 0; j < nrhs; ++j) {
273 0 : rhs_work = &Bmat[(size_t)j * (size_t)ldb];
274 0 : for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
275 0 : solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
276 0 : for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
277 0 : irow = U_SUB(i);
278 0 : rhs_work[irow] -= rhs_work[jcol] * Uval[i];
279 : }
280 : }
281 : }
282 :
283 : } /* for U-solve */
284 :
285 : #if ( DEBUGlevel>=2 )
286 : printf("After U-solve: x=\n");
287 : dprint_soln(n, nrhs, Bmat);
288 : #endif
289 :
290 : /* Compute the final solution X := Pc*X. */
291 0 : for (i = 0; i < nrhs; i++) {
292 0 : rhs_work = &Bmat[(size_t)i * (size_t)ldb];
293 0 : for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
294 0 : for (k = 0; k < n; k++) rhs_work[k] = soln[k];
295 : }
296 :
297 0 : stat->ops[SOLVE] = solve_ops;
298 :
299 : } else { /* Solve A'*X=B or CONJ(A)*X=B */
300 : /* Permute right hand sides to form Pc'*B. */
301 0 : for (i = 0; i < nrhs; i++) {
302 0 : rhs_work = &Bmat[(size_t)i * (size_t)ldb];
303 0 : for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
304 0 : for (k = 0; k < n; k++) rhs_work[k] = soln[k];
305 : }
306 :
307 0 : stat->ops[SOLVE] = 0;
308 0 : for (k = 0; k < nrhs; ++k) {
309 :
310 : /* Multiply by inv(U'). */
311 0 : sp_dtrsv("U", "T", "N", L, U, &Bmat[(size_t)k * (size_t)ldb], stat, info);
312 :
313 : /* Multiply by inv(L'). */
314 0 : sp_dtrsv("L", "T", "U", L, U, &Bmat[(size_t)k * (size_t)ldb], stat, info);
315 :
316 : }
317 : /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
318 0 : for (i = 0; i < nrhs; i++) {
319 0 : rhs_work = &Bmat[(size_t)i * (size_t)ldb];
320 0 : for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
321 0 : for (k = 0; k < n; k++) rhs_work[k] = soln[k];
322 : }
323 :
324 : }
325 :
326 0 : SUPERLU_FREE(work);
327 0 : SUPERLU_FREE(soln);
328 : }
329 :
/*
 * Diagnostic print of a solution vector.
 *
 * Writes soln[0] .. soln[n-1], one entry per line, formatted as
 * "<TAB>index: value" with four decimals.  Only the first column is
 * printed: nrhs is accepted for interface symmetry but not used, and no
 * leading dimension is passed in, so further columns cannot be located.
 */
void
dprint_soln(int n, int nrhs, const double *soln)
{
    const double *entry = soln;
    int row;

    (void) nrhs;   /* unused; see comment above */
    for (row = 0; row < n; ++row, ++entry)
        printf("\t%d: %.4f\n", row, *entry);
}
|