AbstractLinAlgPack_SuperLUSolver.cpp

Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 // Moocho: Multi-functional Object-Oriented arCHitecture for Optimization
00005 //                  Copyright (2003) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Roscoe A. Bartlett (rabartl@sandia.gov) 
00025 // 
00026 // ***********************************************************************
00027 // @HEADER
00028 
00029 #ifdef SPARSE_SOLVER_PACK_USE_SUPERLU
00030 
00031 #include <assert.h>
00032 #include <valarray>
00033 #include <stdexcept>
00034 
00035 #include "AbstractLinAlgPack_SuperLUSolver.hpp"
00036 #include "Teuchos_dyn_cast.hpp"
00037 #include "Teuchos_Workspace.hpp"
00038 #include "Teuchos_TestForException.hpp"
00039 
00040 // SuperLU
00041 #include "dsp_defs.h"
00042 #include "util.h"
00043 
00044 namespace {
00045 
00046 // Static SuperLU stuff
00047 
00048 int local_panel_size  = 0;
00049 int local_relax       = 0;
00050 
00051 class StaticSuperLUInit {
00052 public:
00053   StaticSuperLUInit()
00054     {
00055       local_panel_size = sp_ienv(1);
00056       local_relax      = sp_ienv(2);
00057       StatInit(local_panel_size,local_relax);
00058     }
00059   ~StaticSuperLUInit()
00060     {
00061       StatFree();
00062     }
00063 };
00064 
00065 StaticSuperLUInit static_super_lu_init; // Will be created early and destroyed late!
00066 
00067 // ToDo: RAB: 2002/10/14: We must find a better way to work with
00068 // SuperLU than this.  It should not be too hard
00069 // to do better in the future.
00070 
// Returns a non-const reference to the very same valarray object that was
// passed in by const reference (const is cast away, no copy is made).
//
// This lets callers write &cva(va)[0] to obtain a pointer to the first
// element of a const valarray; the original author notes that
// valarray<>::operator[]() const did not return a reference.  The caller is
// responsible for not modifying an object that really is const.
template <class T>
std::valarray<T>& cva(const std::valarray<T>& va )
{
  typedef std::valarray<T> array_t;
  array_t *mutable_va = const_cast<array_t*>(&va);
  return *mutable_va;
}
00078 
00079 } // end namespace
00080 
00081 namespace SuperLUPack {
00082 
00083 class SuperLUSolverImpl;
00084 
00089 class SuperLUSolverImpl : public SuperLUSolver {
00090 public:
00091 
00094 
00096   class FactorizationStructureImpl : public FactorizationStructure {
00097   public:
00098     friend class SuperLUSolverImpl;
00099   private:
00100     int                   rank_;        // For square basis
00101     int                   nz_;          // ...
00102     std::valarray<int>    perm_r_;      // ...
00103     std::valarray<int>    perm_c_;      // ...
00104     std::valarray<int>    etree_;       // ...
00105     int                   m_orig_;      // For original rectangular matrix
00106     int                   n_orig_;      // ...
00107     int                   nz_orig_;     // ...
00108     std::valarray<int>    perm_r_orig_; // ...
00109     std::valarray<int>    perm_c_orig_; // ...
00110   };
00111 
00113   class FactorizationNonzerosImpl : public FactorizationNonzeros {
00114   public:
00115     friend class SuperLUSolverImpl;
00116   private:
00117     SuperMatrix   L_;
00118     SuperMatrix   U_;
00119   };
00120 
00122 
00125 
00127   void analyze_and_factor(
00128     int                         m
00129     ,int                        n
00130     ,int                        nz
00131     ,const double               a_val[]
00132     ,const int                  a_row_i[]
00133     ,const int                  a_col_ptr[]
00134     ,FactorizationStructure     *fact_struct
00135     ,FactorizationNonzeros      *fact_nonzeros
00136     ,int                        row_perm[]
00137     ,int                        col_perm[]
00138     ,int                        *rank
00139     );
00141   void factor(
00142     int                             m
00143     ,int                            n
00144     ,int                            nz
00145     ,const double                   a_val[]
00146     ,const int                      a_row_i[]
00147     ,const int                      a_col_ptr[]
00148     ,const FactorizationStructure   &fact_struct
00149     ,FactorizationNonzeros          *fact_nonzeros
00150     );
00152   void solve(
00153     const FactorizationStructure    &fact_struct
00154     ,const FactorizationNonzeros    &fact_nonzeros
00155     ,bool                           transp
00156     ,int                            n
00157     ,int                            nrhs
00158     ,double                         rhs[]
00159     ,int                            ldrhs
00160     ) const;
00161 
00163 
00164 private:
00165 
00167   void copy_basis_nonzeros(
00168     int                             m_orig
00169     ,int                            n_orig
00170     ,int                            nz_orig
00171     ,const double                   a_orig_val[]
00172     ,const int                      a_orig_row_i[]
00173     ,const int                      a_orig_col_ptr[]
00174     ,const int                      a_orig_perm_r[]
00175     ,const int                      a_orig_perm_c[]
00176     ,const int                      rank
00177     ,double                         b_val[]
00178     ,int                            b_row_i[]
00179     ,int                            b_col_ptr[]
00180     ,int                            *b_nz
00181     ) const;
00182 
00183 }; // end class SuperLUSolver
00184 
00185 //
00186 // SuperLUSolver
00187 //
00188 
00189 Teuchos::RCP<SuperLUSolver>
00190 SuperLUSolver::create_solver()
00191 {
00192   return Teuchos::rcp(new SuperLUSolverImpl());
00193 }
00194 
00195 Teuchos::RCP<SuperLUSolver::FactorizationStructure>
00196 SuperLUSolver::create_fact_struct()
00197 {
00198   return Teuchos::rcp(new SuperLUSolverImpl::FactorizationStructureImpl());
00199 }
00200 
00201 Teuchos::RCP<SuperLUSolver::FactorizationNonzeros>
00202 SuperLUSolver::create_fact_nonzeros()
00203 {
00204   return Teuchos::rcp(new SuperLUSolverImpl::FactorizationNonzerosImpl());
00205 }
00206 
00207 //
00208 // SuperLUSolverImpl
00209 //
00210 
00211 // Overridden from SuperLUSolver
00212 
00213 void SuperLUSolverImpl::analyze_and_factor(
00214   int                         m
00215   ,int                        n
00216   ,int                        nz
00217   ,const double               a_val[]
00218   ,const int                  a_row_i[]
00219   ,const int                  a_col_ptr[]
00220   ,FactorizationStructure     *fact_struct
00221   ,FactorizationNonzeros      *fact_nonzeros
00222   ,int                        perm_r[]
00223   ,int                        perm_c[]
00224   ,int                        *rank
00225   )
00226 {
00227   using Teuchos::dyn_cast;
00228   using Teuchos::Workspace;
00229   Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00230 
00231   FactorizationStructureImpl
00232     &fs = dyn_cast<FactorizationStructureImpl>(*fact_struct);
00233   FactorizationNonzerosImpl
00234     &fn = dyn_cast<FactorizationNonzerosImpl>(*fact_nonzeros);
00235 
00236   char refact[] = "N";
00237 
00238   // Resize storage.
00239   // Note: if this function was called recursively for m>n on the last call
00240   // then m_orig, n_orig etc. will already be set and should not be
00241   // disturbed.
00242   fs.rank_ = n; // Assume this for now
00243   fs.nz_   = nz;
00244   fs.perm_r_.resize(m);
00245   fs.perm_c_.resize(n);
00246   fs.etree_.resize(n);
00247 
00248     // Create matrix A in the format expected by SuperLU
00249   SuperMatrix A;
00250   dCreate_CompCol_Matrix(
00251     &A, m, n, nz
00252     ,const_cast<double*>(a_val)
00253     ,const_cast<int*>(a_row_i)
00254     ,const_cast<int*>(a_col_ptr)
00255     ,NC, D_, GE
00256     );
00257 
00258   // Get the columm permutations
00259   int permc_spec = 0; // ToDo: Make this an external parameter
00260   get_perm_c(permc_spec, &A, &fs.perm_c_[0]);
00261 
00262   // Permute the columns of the matrix
00263   SuperMatrix AC;
00264   sp_preorder(refact,&A,&fs.perm_c_[0],&fs.etree_[0],&AC);
00265 
00266   int info = -1;
00267   dgstrf(
00268     refact
00269     ,&AC  
00270     ,1.0    /* diag_pivot_thresh */
00271     ,0.0    /* drop_tol */
00272     ,local_relax
00273     ,local_panel_size
00274     ,&fs.etree_[0]
00275     ,NULL   /* work */
00276     ,0      /* lwork */
00277     ,&fs.perm_r_[0]
00278     ,&fs.perm_c_[0]
00279     ,&fn.L_
00280     ,&fn.U_
00281     ,&info
00282     );
00283 
00284   TEST_FOR_EXCEPTION(
00285     info != 0, std::runtime_error
00286     ,"SuperLUSolverImpl::analyze_and_factor(...): Error, dgstrf(...) returned info = " << info
00287     );
00288 
00289   std::copy( &fs.perm_r_[0], &fs.perm_r_[0] + m, perm_r );
00290   std::copy( &fs.perm_c_[0], &fs.perm_c_[0] + n, perm_c );
00291   *rank = n; // We must assume this until I can figure out a way to do better!
00292 
00293   if(m > n) {
00294     // Now we must refactor the basis by only passing in the elements for the basis
00295     // determined by SuperLU.  This is wasteful but it is the easiest thing to do
00296     // for now.
00297     fs.rank_        = *rank;
00298     fs.m_orig_      = m;
00299     fs.n_orig_      = n;
00300     fs.nz_orig_     = nz;
00301     fs.perm_r_orig_ = fs.perm_r_;
00302     fs.perm_c_orig_ = fs.perm_c_;
00303     // Copy the nonzeros for the sqare factor into new storage
00304     Workspace<double>       b_val(wss,nz);
00305     Workspace<int>          b_row_i(wss,nz);
00306     Workspace<int>          b_col_ptr(wss,n+1);
00307     copy_basis_nonzeros(
00308       m,n,nz,a_val,a_row_i,a_col_ptr
00309       ,&fs.perm_r_orig_[0],&fs.perm_c_orig_[0],fs.rank_
00310       ,&b_val[0],&b_row_i[0],&b_col_ptr[0]
00311       ,&fs.nz_
00312       );
00313     // Analyze and factor the new matrix
00314     int b_rank = -1;
00315     analyze_and_factor(
00316       fs.rank_, fs.rank_, fs.nz_, &b_val[0], &b_row_i[0], &b_col_ptr[0]
00317       ,fact_struct, fact_nonzeros
00318       ,&fs.perm_r_[0], &fs.perm_c_[0], &b_rank
00319       );
00320     TEST_FOR_EXCEPTION(
00321       (b_rank != *rank), std::runtime_error
00322       ,"SuperLUSolverImpl::analyze_and_factor(...): Error, the rank determined by "
00323       "the factorization of the rectangular " << m << " x " << n << " matrix of "
00324       << (*rank) << " is not the same as the refactorization of the basis returned as "
00325       << b_rank << "!"
00326       );
00327   }
00328   else {
00329     fs.m_orig_  = m;
00330     fs.n_orig_  = n;
00331     fs.nz_orig_ = nz;
00332   }
00333 }
00334 
00335 void SuperLUSolverImpl::factor(
00336   int                             m
00337   ,int                            n
00338   ,int                            nz
00339   ,const double                   a_val[]
00340   ,const int                      a_row_i[]
00341   ,const int                      a_col_ptr[]
00342   ,const FactorizationStructure   &fact_struct
00343   ,FactorizationNonzeros          *fact_nonzeros
00344   )
00345 {
00346   using Teuchos::dyn_cast;
00347   using Teuchos::Workspace;
00348   Teuchos::WorkspaceStore* wss = Teuchos::get_default_workspace_store().get();
00349 
00350   const FactorizationStructureImpl
00351     &fs = dyn_cast<const FactorizationStructureImpl>(fact_struct);
00352   FactorizationNonzerosImpl
00353     &fn = dyn_cast<FactorizationNonzerosImpl>(*fact_nonzeros);
00354 
00355   char refact[] = "Y";
00356 
00357   // Copy the nonzeros for the sqare factor into new storage
00358   Workspace<double>       b_val(wss,fs.nz_);
00359   Workspace<int>          b_row_i(wss,fs.nz_);
00360   Workspace<int>          b_col_ptr(wss,fs.rank_+1);
00361   if(fs.m_orig_ > fs.n_orig_) {
00362     int b_nz = -1;
00363     copy_basis_nonzeros(
00364       m,n,nz,a_val,a_row_i,a_col_ptr
00365       ,&cva(fs.perm_r_orig_)[0],&cva(fs.perm_c_orig_)[0],fs.rank_
00366       ,&b_val[0],&b_row_i[0],&b_col_ptr[0]
00367       ,&b_nz
00368       );
00369     TEST_FOR_EXCEPTION(
00370       (b_nz != fs.nz_), std::runtime_error
00371       ,"SuperLUSolverImpl::factor(...): Error!"
00372       );
00373   }
00374   else {
00375     std::copy( a_val,     a_val     + nz,  &b_val[0]     );
00376     std::copy( a_row_i,   a_row_i   + nz,  &b_row_i[0]   );
00377     std::copy( a_col_ptr, a_col_ptr + n+1, &b_col_ptr[0] );
00378   }
00379 
00380     // Create matrix A in the format expected by SuperLU
00381   SuperMatrix A;
00382   dCreate_CompCol_Matrix(
00383     &A, fs.rank_, fs.rank_, fs.nz_
00384     ,&b_val[0]
00385     ,&b_row_i[0]
00386     ,&b_col_ptr[0]
00387     ,NC, D_, GE
00388     );
00389 
00390   // Permute the columns
00391   SuperMatrix AC;
00392   sp_preorder(
00393     refact,&A
00394     ,&cva(fs.perm_c_)[0]
00395     ,&cva(fs.etree_)[0]
00396     ,&AC
00397     );
00398 
00399   int info = -1;
00400   dgstrf(
00401     refact
00402     ,&AC  
00403     ,1.0    /* diag_pivot_thresh */
00404     ,0.0    /* drop_tol */
00405     ,local_relax
00406     ,local_panel_size
00407     ,const_cast<int*>(&cva(fs.etree_)[0])
00408     ,NULL   /* work */
00409     ,0      /* lwork */
00410     ,&cva(fs.perm_r_)[0]
00411     ,&cva(fs.perm_c_)[0]
00412     ,&fn.L_
00413     ,&fn.U_
00414     ,&info
00415     );
00416 
00417   TEST_FOR_EXCEPTION(
00418     info != 0, std::runtime_error
00419     ,"SuperLUSolverImpl::factor(...): Error, dgstrf(...) returned info = " << info
00420     );
00421 
00422 }
00423 
00424 void SuperLUSolverImpl::solve(
00425   const FactorizationStructure    &fact_struct
00426   ,const FactorizationNonzeros    &fact_nonzeros
00427   ,bool                           transp
00428   ,int                            n
00429   ,int                            nrhs
00430   ,double                         rhs[]
00431   ,int                            ldrhs
00432   ) const
00433 {
00434 
00435   using Teuchos::dyn_cast;
00436 
00437   const FactorizationStructureImpl
00438     &fs = dyn_cast<const FactorizationStructureImpl>(fact_struct);
00439   const FactorizationNonzerosImpl
00440     &fn = dyn_cast<const FactorizationNonzerosImpl>(fact_nonzeros);
00441 
00442   TEST_FOR_EXCEPTION(
00443     n != fs.rank_, std::runtime_error
00444     ,"SuperLUSolverImpl::solve(...): Error, the dimmensions n = " << n << " and fs.rank = " << fs.rank_
00445     << " do not match up!"
00446     );
00447 
00448   SuperMatrix B;
00449     dCreate_Dense_Matrix(&B, n, nrhs, rhs, ldrhs, DN, D_, GE);
00450 
00451   char transc[1];
00452   transc[0] = ( transp ? 'T' : 'N' );
00453 
00454   int info = -1;
00455     dgstrs(
00456     transc
00457     ,const_cast<SuperMatrix*>(&fn.L_)
00458     ,const_cast<SuperMatrix*>(&fn.U_)
00459     ,&cva(fs.perm_r_)[0]
00460     ,&cva(fs.perm_c_)[0]
00461     ,&B, &info
00462     );
00463 
00464   TEST_FOR_EXCEPTION(
00465     info != 0, std::runtime_error
00466     ,"SuperLUSolverImpl::solve(...): Error, dgssv(...) returned info = " << info
00467     );
00468 
00469 }
00470 
00471 // private
00472 
00473 void SuperLUSolverImpl::copy_basis_nonzeros(
00474   int                             m_orig
00475   ,int                            n_orig
00476   ,int                            nz_orig
00477   ,const double                   a_orig_val[]
00478   ,const int                      a_orig_row_i[]
00479   ,const int                      a_orig_col_ptr[]
00480   ,const int                      a_orig_perm_r[]
00481   ,const int                      a_orig_perm_c[]
00482   ,const int                      rank
00483   ,double                         b_val[]
00484   ,int                            b_row_i[]
00485   ,int                            b_col_ptr[]
00486   ,int                            *b_nz
00487   ) const
00488 {
00489   *b_nz = 0;
00490   b_col_ptr[0] = *b_nz;
00491   for( int j = 0; j < rank; ++j ) {
00492     const int col_start_k = a_orig_col_ptr[j];
00493     const int col_end_k   = a_orig_col_ptr[j+1];
00494     for( int k = col_start_k; k < col_end_k; ++k ) {
00495       const int i_orig = a_orig_row_i[k];
00496       if(i_orig < rank) {
00497         b_val[*b_nz]     = a_orig_val[k];
00498         b_row_i[*b_nz]   = a_orig_row_i[k];
00499         ++(*b_nz);
00500       }
00501     }
00502     b_col_ptr[j+1] = *b_nz;
00503   }
00504 }
00505 
00506 } // end namespace SuperLUPack
00507 
00508 #endif // SPARSE_SOLVER_PACK_USE_SUPERLU

Generated on Wed May 12 21:52:27 2010 for MOOCHO (Single Doxygen Collection) by  doxygen 1.4.7