LCOV - code coverage report
Current view: top level - /builds/ug4-project/ugcore/ug4-new/plugins/SuperLU6/external/superlu/SRC - dgstrs.c (source / functions) Coverage Total Hit
Test: coverage.info Lines: 0.0 % 103 0
Test Date: 2026-06-01 23:54:59 Functions: 0.0 % 2 0

            Line data    Source code
       1              : /*! \file
       2              : Copyright (c) 2003, The Regents of the University of California, through
       3              : Lawrence Berkeley National Laboratory (subject to receipt of any required 
       4              : approvals from U.S. Dept. of Energy) 
       5              : 
       6              : All rights reserved. 
       7              : 
       8              : The source code is distributed under BSD license, see the file License.txt
       9              : at the top-level directory.
      10              : */
      11              : 
      12              : /*! @file dgstrs.c
      13              :  * \brief Solves a system using LU factorization
      14              :  *
      15              :  * <pre>
      16              :  * -- SuperLU routine (version 3.0) --
      17              :  * Univ. of California Berkeley, Xerox Palo Alto Research Center,
      18              :  * and Lawrence Berkeley National Lab.
      19              :  * October 15, 2003
      20              :  *
      21              :  * Copyright (c) 1994 by Xerox Corporation.  All rights reserved.
      22              :  *
      23              :  * THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
      24              :  * EXPRESSED OR IMPLIED.  ANY USE IS AT YOUR OWN RISK.
      25              :  *
      26              :  * Permission is hereby granted to use or copy this program for any
      27              :  * purpose, provided the above notices are retained on all copies.
      28              :  * Permission to modify the code and to distribute modified code is
      29              :  * granted, provided the above notices are retained, and a notice that
      30              :  * the code was modified is included with the above copyright notice.
      31              :  * </pre>
      32              :  */
      33              : 
      34              : #include "slu_ddefs.h"
      35              : 
      36              : 
      37              : /*! \brief
      38              :  *
      39              :  * <pre>
      40              :  * Purpose
      41              :  * =======
      42              :  *
      43              :  * DGSTRS solves a system of linear equations A*X=B or A'*X=B
      44              :  * with A sparse and B dense, using the LU factorization computed by
      45              :  * DGSTRF.
      46              :  *
      47              :  * See supermatrix.h for the definition of 'SuperMatrix' structure.
      48              :  *
      49              :  * Arguments
      50              :  * =========
      51              :  *
      52              :  * trans   (input) trans_t
      53              :  *          Specifies the form of the system of equations:
      54              :  *          = NOTRANS: A * X = B  (No transpose)
      55              :  *          = TRANS:   A'* X = B  (Transpose)
      56              :  *          = CONJ:    A**H * X = B  (Conjugate transpose)
      57              :  *
      58              :  * L       (input) SuperMatrix*
      59              :  *         The factor L from the factorization Pr*A*Pc=L*U as computed by
      60              :  *         dgstrf(). Use compressed row subscripts storage for supernodes,
      61              :  *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_D, Mtype = SLU_TRLU.
      62              :  *
      63              :  * U       (input) SuperMatrix*
      64              :  *         The factor U from the factorization Pr*A*Pc=L*U as computed by
      65              :  *         dgstrf(). Use column-wise storage scheme, i.e., U has types:
      66              :  *         Stype = SLU_NC, Dtype = SLU_D, Mtype = SLU_TRU.
      67              :  *
      68              :  * perm_c  (input) int*, dimension (L->ncol)
      69              :  *         Column permutation vector, which defines the 
      70              :  *         permutation matrix Pc; perm_c[i] = j means column i of A is 
      71              :  *         in position j in A*Pc.
      72              :  *
      73              :  * perm_r  (input) int*, dimension (L->nrow)
      74              :  *         Row permutation vector, which defines the permutation matrix Pr; 
      75              :  *         perm_r[i] = j means row i of A is in position j in Pr*A.
      76              :  *
      77              :  * B       (input/output) SuperMatrix*
      78              :  *         B has types: Stype = SLU_DN, Dtype = SLU_D, Mtype = SLU_GE.
      79              :  *         On entry, the right hand side matrix.
      80              :  *         On exit, the solution matrix if info = 0;
      81              :  *
      82              :  * stat     (output) SuperLUStat_t*
      83              :  *          Record the statistics on runtime and floating-point operation count.
      84              :  *          See util.h for the definition of 'SuperLUStat_t'.
      85              :  *
      86              :  * info    (output) int*
      87              :  *         = 0: successful exit
      88              :  *         < 0: if info = -i, the i-th argument had an illegal value
      89              :  * </pre>
      90              :  */
      91              : 
      92              : void
      93            0 : dgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
      94              :         const int *perm_c, const int *perm_r, SuperMatrix *B,
      95              :         SuperLUStat_t *stat, int *info)
      96              : {
      97              : 
      98              : #ifdef _CRAY
      99              :     _fcd ftcs1, ftcs2, ftcs3, ftcs4;
     100              : #endif
     101              : #ifdef USE_VENDOR_BLAS
     102              :     double   alpha = 1.0, beta = 1.0;
     103              :     double   *work_col;
     104              : #endif
     105              :     DNformat *Bstore;
     106              :     double   *Bmat;
     107              :     SCformat *Lstore;
     108              :     NCformat *Ustore;
     109              :     double   *Lval, *Uval;
     110              :     int      fsupc, nrow, nsupr, nsupc, irow;
     111              :     int_t    i, j, k, luptr, istart, iptr;
     112              :     int      jcol, n, ldb, nrhs;
     113              :     double   *work, *rhs_work, *soln;
     114              :     flops_t  solve_ops;
     115              :     void dprint_soln(int n, int nrhs, const double *soln);
     116              : 
     117              :     /* Test input parameters ... */
     118            0 :     *info = 0;
     119            0 :     Bstore = B->Store;
     120            0 :     ldb = Bstore->lda;
     121            0 :     nrhs = B->ncol;
     122            0 :     if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
     123            0 :     else if ( L->nrow != L->ncol || L->nrow < 0 ||
     124            0 :               L->Stype != SLU_SC || L->Dtype != SLU_D || L->Mtype != SLU_TRLU )
     125            0 :         *info = -2;
     126            0 :     else if ( U->nrow != U->ncol || U->nrow < 0 ||
     127            0 :               U->Stype != SLU_NC || U->Dtype != SLU_D || U->Mtype != SLU_TRU )
     128            0 :         *info = -3;
     129            0 :     else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
     130            0 :               B->Stype != SLU_DN || B->Dtype != SLU_D || B->Mtype != SLU_GE )
     131            0 :         *info = -6;
     132            0 :     if ( *info ) {
     133            0 :         int ii = -(*info);
     134            0 :         input_error("dgstrs", &ii);
     135              :         return;
     136              :     }
     137              : 
     138            0 :     n = L->nrow;
     139            0 :     work = doubleCalloc((size_t) n * (size_t) nrhs);
     140            0 :     if ( !work ) ABORT("Malloc fails for local work[].");
     141            0 :     soln = doubleMalloc((size_t) n);
     142            0 :     if ( !soln ) ABORT("Malloc fails for local soln[].");
     143              : 
     144            0 :     Bmat = Bstore->nzval;
     145            0 :     Lstore = L->Store;
     146            0 :     Lval = Lstore->nzval;
     147            0 :     Ustore = U->Store;
     148            0 :     Uval = Ustore->nzval;
     149              :     solve_ops = 0;
     150              :     
     151            0 :     if ( trans == NOTRANS ) {
     152              :         /* Permute right hand sides to form Pr*B */
     153            0 :         for (i = 0; i < nrhs; i++) {
     154            0 :             rhs_work = &Bmat[(size_t)i * (size_t)ldb];
     155            0 :             for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
     156            0 :             for (k = 0; k < n; k++) rhs_work[k] = soln[k];
     157              :         }
     158              :         
     159              :         /* Forward solve PLy=Pb. */
     160            0 :         for (k = 0; k <= Lstore->nsuper; k++) {
     161            0 :             fsupc = L_FST_SUPC(k);
     162            0 :             istart = L_SUB_START(fsupc);
     163            0 :             nsupr = L_SUB_START(fsupc+1) - istart;
     164            0 :             nsupc = L_FST_SUPC(k+1) - fsupc;
     165            0 :             nrow = nsupr - nsupc;
     166              : 
     167            0 :             solve_ops += nsupc * (nsupc - 1) * nrhs;
     168            0 :             solve_ops += 2 * nrow * nsupc * nrhs;
     169              :             
     170            0 :             if ( nsupc == 1 ) {
     171            0 :                 for (j = 0; j < nrhs; j++) {
     172            0 :                     rhs_work = &Bmat[(size_t)j * (size_t)ldb];
     173            0 :                     luptr = L_NZ_START(fsupc);
     174            0 :                     for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
     175            0 :                         irow = L_SUB(iptr);
     176            0 :                         ++luptr;
     177            0 :                         rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
     178              :                     }
     179              :                 }
     180              :             } else {
     181            0 :                 luptr = L_NZ_START(fsupc);
     182              : #ifdef USE_VENDOR_BLAS
     183              : #ifdef _CRAY
     184              :                 ftcs1 = _cptofcd("L", strlen("L"));
     185              :                 ftcs2 = _cptofcd("N", strlen("N"));
     186              :                 ftcs3 = _cptofcd("U", strlen("U"));
     187              :                 STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
     188              :                        &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
     189              :                 
     190              :                 SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 
     191              :                         &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
     192              :                         &beta, &work[0], &n );
     193              : #else
     194              :                 dtrsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
     195              :                        &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
     196              :                 
     197              :                 dgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 
     198              :                         &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
     199              :                         &beta, &work[0], &n );
     200              : #endif
     201              :                 for (j = 0; j < nrhs; j++) {
     202              :                     rhs_work = &Bmat[(size_t)j * (size_t)ldb];
     203              :                     work_col = &work[(size_t)j * (size_t)n];
     204              :                     iptr = istart + nsupc;
     205              :                     for (i = 0; i < nrow; i++) {
     206              :                         irow = L_SUB(iptr);
     207              :                         rhs_work[irow] -= work_col[i]; /* Scatter */
     208              :                         work_col[i] = 0.0;
     209              :                         iptr++;
     210              :                     }
     211              :                 }
     212              : #else           
     213            0 :                 for (j = 0; j < nrhs; j++) {
     214            0 :                     rhs_work = &Bmat[(size_t)j * (size_t)ldb];
     215            0 :                     dlsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
     216            0 :                     dmatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
     217              :                             &rhs_work[fsupc], &work[0] );
     218              : 
     219            0 :                     iptr = istart + nsupc;
     220            0 :                     for (i = 0; i < nrow; i++) {
     221            0 :                         irow = L_SUB(iptr);
     222            0 :                         rhs_work[irow] -= work[i];
     223            0 :                         work[i] = 0.0;
     224            0 :                         iptr++;
     225              :                     }
     226              :                 }
     227              : #endif              
     228              :             } /* else ... */
     229              :         } /* for L-solve */
     230              : 
     231              : #if ( DEBUGlevel>=2 )
     232              :         printf("After L-solve: y=\n");
     233              :         dprint_soln(n, nrhs, Bmat);
     234              : #endif
     235              : 
     236              :         /*
     237              :          * Back solve Ux=y.
     238              :          */
     239            0 :         for (k = Lstore->nsuper; k >= 0; k--) {
     240            0 :             fsupc = L_FST_SUPC(k);
     241            0 :             istart = L_SUB_START(fsupc);
     242            0 :             nsupr = L_SUB_START(fsupc+1) - istart;
     243            0 :             nsupc = L_FST_SUPC(k+1) - fsupc;
     244            0 :             luptr = L_NZ_START(fsupc);
     245              : 
     246            0 :             solve_ops += nsupc * (nsupc + 1) * nrhs;
     247              : 
     248            0 :             if ( nsupc == 1 ) {
     249              :                 rhs_work = &Bmat[0];
     250            0 :                 for (j = 0; j < nrhs; j++) {
     251            0 :                     rhs_work[fsupc] /= Lval[luptr];
     252            0 :                     rhs_work += ldb;
     253              :                 }
     254              :             } else {
     255              : #ifdef USE_VENDOR_BLAS
     256              : #ifdef _CRAY
     257              :                 ftcs1 = _cptofcd("L", strlen("L"));
     258              :                 ftcs2 = _cptofcd("U", strlen("U"));
     259              :                 ftcs3 = _cptofcd("N", strlen("N"));
     260              :                 STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
     261              :                        &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
     262              : #else
     263              :                 dtrsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
     264              :                        &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
     265              : #endif
     266              : #else           
     267            0 :                 for (j = 0; j < nrhs; j++)
     268            0 :                     dusolve ( nsupr, nsupc, &Lval[luptr], &Bmat[(size_t)fsupc + (size_t)j * (size_t)ldb] );
     269              : #endif          
     270              :             }
     271              : 
     272            0 :             for (j = 0; j < nrhs; ++j) {
     273            0 :                 rhs_work = &Bmat[(size_t)j * (size_t)ldb];
     274            0 :                 for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
     275            0 :                     solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
     276            0 :                     for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
     277            0 :                         irow = U_SUB(i);
     278            0 :                         rhs_work[irow] -= rhs_work[jcol] * Uval[i];
     279              :                     }
     280              :                 }
     281              :             }
     282              :             
     283              :         } /* for U-solve */
     284              : 
     285              : #if ( DEBUGlevel>=2 )
     286              :         printf("After U-solve: x=\n");
     287              :         dprint_soln(n, nrhs, Bmat);
     288              : #endif
     289              : 
     290              :         /* Compute the final solution X := Pc*X. */
     291            0 :         for (i = 0; i < nrhs; i++) {
     292            0 :             rhs_work = &Bmat[(size_t)i * (size_t)ldb];
     293            0 :             for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
     294            0 :             for (k = 0; k < n; k++) rhs_work[k] = soln[k];
     295              :         }
     296              :         
     297            0 :         stat->ops[SOLVE] = solve_ops;
     298              : 
     299              :     } else { /* Solve A'*X=B or CONJ(A)*X=B */
     300              :         /* Permute right hand sides to form Pc'*B. */
     301            0 :         for (i = 0; i < nrhs; i++) {
     302            0 :             rhs_work = &Bmat[(size_t)i * (size_t)ldb];
     303            0 :             for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
     304            0 :             for (k = 0; k < n; k++) rhs_work[k] = soln[k];
     305              :         }
     306              : 
     307            0 :         stat->ops[SOLVE] = 0;
     308            0 :         for (k = 0; k < nrhs; ++k) {
     309              :             
     310              :             /* Multiply by inv(U'). */
     311            0 :             sp_dtrsv("U", "T", "N", L, U, &Bmat[(size_t)k * (size_t)ldb], stat, info);
     312              :             
     313              :             /* Multiply by inv(L'). */
     314            0 :             sp_dtrsv("L", "T", "U", L, U, &Bmat[(size_t)k * (size_t)ldb], stat, info);
     315              :             
     316              :         }
     317              :         /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
     318            0 :         for (i = 0; i < nrhs; i++) {
     319            0 :             rhs_work = &Bmat[(size_t)i * (size_t)ldb];
     320            0 :             for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
     321            0 :             for (k = 0; k < n; k++) rhs_work[k] = soln[k];
     322              :         }
     323              : 
     324              :     }
     325              : 
     326            0 :     SUPERLU_FREE(work);
     327            0 :     SUPERLU_FREE(soln);
     328              : }
     329              : 
     330              : /*
     331              :  * Diagnostic print of the solution vector 
     332              :  */
     333              : void
     334            0 : dprint_soln(int n, int nrhs, const double *soln)
     335              : {
     336              :     int i;
     337              : 
     338            0 :     for (i = 0; i < n; i++)
     339            0 :         printf("\t%d: %.4f\n", i, soln[i]);
     340            0 : }
        

Generated by: LCOV version 2.0-1