Tpetra Matrix/Vector Services Version of the Day
TpetraExt_MatrixMatrix_def.hpp
Go to the documentation of this file.
00001 // @HEADER
00002 // ***********************************************************************
00003 // 
00004 //          Tpetra: Templated Linear Algebra Services Package
00005 //                 Copyright (2008) Sandia Corporation
00006 // 
00007 // Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
00008 // the U.S. Government retains certain rights in this software.
00009 // 
00010 // Redistribution and use in source and binary forms, with or without
00011 // modification, are permitted provided that the following conditions are
00012 // met:
00013 //
00014 // 1. Redistributions of source code must retain the above copyright
00015 // notice, this list of conditions and the following disclaimer.
00016 //
00017 // 2. Redistributions in binary form must reproduce the above copyright
00018 // notice, this list of conditions and the following disclaimer in the
00019 // documentation and/or other materials provided with the distribution.
00020 //
00021 // 3. Neither the name of the Corporation nor the names of the
00022 // contributors may be used to endorse or promote products derived from
00023 // this software without specific prior written permission.
00024 //
00025 // THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
00026 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00027 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00028 // PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
00029 // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
00030 // EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
00031 // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00032 // PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
00033 // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
00034 // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
00035 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00036 //
00037 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00038 // 
00039 // ************************************************************************
00040 // @HEADER
00041 
00042 #ifndef TPETRA_MATRIXMATRIX_DEF_HPP
00043 #define TPETRA_MATRIXMATRIX_DEF_HPP
00044 
00045 #include "TpetraExt_MatrixMatrix_decl.hpp"
00046 #include "Teuchos_VerboseObject.hpp"
00047 #include "Teuchos_Array.hpp"
00048 #include "Tpetra_Util.hpp"
00049 #include "Tpetra_ConfigDefs.hpp"
00050 #include "Tpetra_CrsMatrix.hpp"
00051 #include "TpetraExt_MMHelpers_def.hpp"
00052 #include "Tpetra_RowMatrixTransposer.hpp"
00053 #include "Tpetra_ConfigDefs.hpp"
00054 #include "Tpetra_Map.hpp"
00055 #include <algorithm>
00056 #include "Teuchos_FancyOStream.hpp"
00057 
00058 
00064 namespace Tpetra {
00065 
00066 
00067 namespace MatrixMatrix{
00068 
00069 template <class Scalar, 
00070            class LocalOrdinal,
00071            class GlobalOrdinal,
00072            class Node,
00073            class SpMatOps >
00074 void Multiply(
00075   const CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& A,
00076   bool transposeA,
00077   const CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& B,
00078   bool transposeB,
00079   CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& C,
00080   bool call_FillComplete_on_result)
00081 {
00082   typedef CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps> Matrix_t;
00083   //
00084   //This method forms the matrix-matrix product C = op(A) * op(B), where
00085   //op(A) == A   if transposeA is false,
00086   //op(A) == A^T if transposeA is true,
00087   //and similarly for op(B).
00088   //
00089 
00090   //A and B should already be Filled.
00091   //(Should we go ahead and call FillComplete() on them if necessary?
00092   // or error out? For now, we choose to error out.)
00093   TEUCHOS_TEST_FOR_EXCEPTION(!A.isFillComplete(), std::runtime_error,
00094     "Uh oh. Looks like there's a bit of a problem here. No worries though. We'll help you figure it out. You're "
00095     "a fantastic programer and this just a minor bump in the road! Maybe the information below can help you out a bit."
00096     "\n\n MatrixMatrix::Multiply(): Matrix A is not fill complete.");
00097   TEUCHOS_TEST_FOR_EXCEPTION(!B.isFillComplete(), std::runtime_error,
00098     "Uh oh. Looks like there's a bit of a problem here. No worries though. We'll help you figure it out. You're "
00099     "a fantastic programer and this just a minor bump in the road! Maybe the information below can help you out a bit."
00100     "\n\n MatrixMatrix::Multiply(): Matrix B is not fill complete.");
00101 
00102   //Convience typedefs
00103   typedef CrsMatrixStruct<
00104     Scalar, 
00105     LocalOrdinal,
00106     GlobalOrdinal,
00107     Node,
00108     SpMatOps> CrsMatrixStruct_t;
00109   typedef Map<LocalOrdinal, GlobalOrdinal, Node> Map_t;
00110 
00111   RCP<const Matrix_t > Aprime = null;
00112   RCP<const Matrix_t > Bprime = null;
00113   if(transposeA){
00114     RowMatrixTransposer<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>  at(A);
00115     Aprime = at.createTranspose();
00116   }
00117   else{
00118     Aprime = rcpFromRef(A);
00119   }
00120   if(transposeB){
00121     RowMatrixTransposer<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>  bt(B);
00122     Bprime=bt.createTranspose();
00123   }
00124   else{
00125     Bprime = rcpFromRef(B);
00126   }
00127     
00128 
00129   //now check size compatibility
00130   global_size_t numACols = A.getDomainMap()->getGlobalNumElements();
00131   global_size_t numBCols = B.getDomainMap()->getGlobalNumElements();
00132   global_size_t Aouter = transposeA ? numACols : A.getGlobalNumRows();
00133   global_size_t Bouter = transposeB ? B.getGlobalNumRows() : numBCols;
00134   global_size_t Ainner = transposeA ? A.getGlobalNumRows() : numACols;
00135   global_size_t Binner = transposeB ? numBCols : B.getGlobalNumRows();
00136   TEUCHOS_TEST_FOR_EXCEPTION(!A.isFillComplete(), std::runtime_error,
00137     "MatrixMatrix::Multiply: ERROR, inner dimensions of op(A) and op(B) "
00138     "must match for matrix-matrix product. op(A) is "
00139     <<Aouter<<"x"<<Ainner << ", op(B) is "<<Binner<<"x"<<Bouter<<std::endl);
00140 
00141   //The result matrix C must at least have a row-map that reflects the
00142   //correct row-size. Don't check the number of columns because rectangular
00143   //matrices which were constructed with only one map can still end up
00144   //having the correct capacity and dimensions when filled.
00145   TEUCHOS_TEST_FOR_EXCEPTION(Aouter > C.getGlobalNumRows(), std::runtime_error,
00146     "MatrixMatrix::Multiply: ERROR, dimensions of result C must "
00147     "match dimensions of op(A) * op(B). C has "<<C.getGlobalNumRows()
00148      << " rows, should have at least "<<Aouter << std::endl);
00149 
00150   //It doesn't matter whether C is already Filled or not. If it is already
00151   //Filled, it must have space allocated for the positions that will be
00152   //referenced in forming C = op(A)*op(B). If it doesn't have enough space,
00153   //we'll error out later when trying to store result values.
00154   
00155   // CGB: However, matrix must be in active-fill
00156   TEUCHOS_TEST_FOR_EXCEPT( C.isFillActive() == false );
00157 
00158   //We're going to need to import remotely-owned sections of A and/or B
00159   //if more than 1 processor is performing this run, depending on the scenario.
00160   int numProcs = A.getComm()->getSize();
00161 
00162   //Declare a couple of structs that will be used to hold views of the data
00163   //of A and B, to be used for fast access during the matrix-multiplication.
00164   CrsMatrixStruct_t Aview;
00165   CrsMatrixStruct_t Bview;
00166 
00167   RCP<const Map_t > targetMap_A = Aprime->getRowMap();
00168   RCP<const Map_t > targetMap_B = Bprime->getRowMap();
00169 
00170   //Now import any needed remote rows and populate the Aview struct.
00171   MMdetails::import_and_extract_views(*Aprime, targetMap_A, Aview);
00172  
00173 
00174   //We will also need local access to all rows of B that correspond to the
00175   //column-map of op(A).
00176   if (numProcs > 1) {
00177     targetMap_B = Aprime->getColMap(); //colmap_op_A;
00178   }
00179 
00180   //Now import any needed remote rows and populate the Bview struct.
00181   MMdetails::import_and_extract_views(*Bprime, targetMap_B, Bview);
00182 
00183 
00184   //If the result matrix C is not already FillComplete'd, we will do a
00185   //preprocessing step to create the nonzero structure,
00186   if (!C.isFillComplete()) {
00187     CrsWrapper_GraphBuilder<Scalar, LocalOrdinal, GlobalOrdinal, Node> crsgraphbuilder(C.getRowMap());
00188 
00189     //pass the graph-builder object to the multiplication kernel to fill in all
00190     //the nonzero positions that will be used in the result matrix.
00191     MMdetails::mult_A_B(Aview, Bview, crsgraphbuilder, true);
00192 
00193     //now insert all of the nonzero positions into the result matrix.
00194     insert_matrix_locations(crsgraphbuilder, C);
00195 
00196 
00197     if (call_FillComplete_on_result) {
00198       C.fillComplete(Bprime->getDomainMap(), Aprime->getRangeMap());
00199       call_FillComplete_on_result = false;
00200     }
00201   }
00202 
00203   //Now call the appropriate method to perform the actual multiplication.
00204 
00205   CrsWrapper_CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node> crsmat(C);
00206 
00207   MMdetails::mult_A_B(Aview, Bview, crsmat);
00208 
00209   if (call_FillComplete_on_result) {
00210     //We'll call FillComplete on the C matrix before we exit, and give
00211     //it a domain-map and a range-map.
00212     //The domain-map will be the domain-map of B, unless
00213     //op(B)==transpose(B), in which case the range-map of B will be used.
00214     //The range-map will be the range-map of A, unless
00215     //op(A)==transpose(A), in which case the domain-map of A will be used.
00216     if (!C.isFillComplete()) {
00217       C.fillComplete(Bprime->getDomainMap(), Aprime->getRangeMap());
00218     }
00219   }
00220 
00221 }
00222 
00223 template <class Scalar, 
00224           class LocalOrdinal,
00225           class GlobalOrdinal,
00226           class Node,
00227           class SpMatOps >
00228 void Add(
00229   const CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& A,
00230   bool transposeA,
00231   Scalar scalarA,
00232   CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& B,
00233   Scalar scalarB )
00234 {
00235   TEUCHOS_TEST_FOR_EXCEPTION(!A.isFillComplete(), std::runtime_error,
00236     "MatrixMatrix::Add ERROR, input matrix A.isFillComplete() is false; it is required to be true. (Result matrix B is not required to be isFillComplete()).");
00237   TEUCHOS_TEST_FOR_EXCEPTION(B.isFillComplete() , std::runtime_error,
00238     "MatrixMatrix::Add ERROR, input matrix B must not be fill complete!");
00239   TEUCHOS_TEST_FOR_EXCEPTION(B.getProfileType()!=DynamicProfile, std::runtime_error,
00240     "MatrixMatrix::Add ERROR, input matrix B must have a dynamic profile!");
00241   //Convience typedef
00242   typedef CrsMatrix<
00243     Scalar, 
00244     LocalOrdinal,
00245     GlobalOrdinal,
00246     Node,
00247     SpMatOps> CrsMatrix_t;
00248   RCP<const CrsMatrix_t> Aprime = null;
00249   if( transposeA ){
00250     RowMatrixTransposer<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps> theTransposer(A);
00251     Aprime = theTransposer.createTranspose(DoOptimizeStorage); 
00252   }
00253   else{
00254     Aprime = rcpFromRef(A);
00255   }
00256   size_t a_numEntries;
00257   Array<GlobalOrdinal> a_inds(A.getNodeMaxNumRowEntries());
00258   Array<Scalar> a_vals(A.getNodeMaxNumRowEntries());
00259   GlobalOrdinal row;
00260 
00261   if(scalarB != ScalarTraits<Scalar>::one()){
00262     B.scale(scalarB);
00263   }
00264 
00265   bool bFilled = B.isFillComplete();
00266   size_t numMyRows = B.getNodeNumRows();
00267   if(scalarA != ScalarTraits<Scalar>::zero()){
00268     for(LocalOrdinal i = 0; (size_t)i < numMyRows; ++i){
00269       row = B.getRowMap()->getGlobalElement(i);
00270       Aprime->getGlobalRowCopy(row, a_inds(), a_vals(), a_numEntries);
00271       if(scalarA != ScalarTraits<Scalar>::one()){
00272         for(size_t j =0; j<a_numEntries; ++j){
00273           a_vals[j] *= scalarA;
00274         }
00275       }
00276       if(bFilled){
00277         B.sumIntoGlobalValues(row, a_inds(0,a_numEntries), a_vals(0,a_numEntries));
00278       }
00279       else{
00280         B.insertGlobalValues(row, a_inds(0,a_numEntries), a_vals(0,a_numEntries));
00281       }
00282 
00283     }
00284   }
00285 }
00286 
00287 template <class Scalar, 
00288           class LocalOrdinal,
00289           class GlobalOrdinal,
00290           class Node,
00291           class SpMatOps>
00292 void Add(
00293   const CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& A,
00294   bool transposeA,
00295   Scalar scalarA,
00296   const CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& B,
00297   bool transposeB,
00298   Scalar scalarB,
00299   RCP<CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps> > C)
00300 {
00301   //
00302   //This method forms the matrix-matrix sum C = scalarA * op(A) + scalarB * op(B), where
00303 
00304   //Convience typedef
00305   typedef CrsMatrix<
00306     Scalar, 
00307     LocalOrdinal,
00308     GlobalOrdinal,
00309     Node,
00310     SpMatOps> CrsMatrix_t;
00311 
00312   //A and B should already be Filled. C should be an empty pointer.
00313 
00314 
00315   TEUCHOS_TEST_FOR_EXCEPTION(!A.isFillComplete() || !B.isFillComplete(), std::runtime_error,
00316     "EpetraExt::MatrixMatrix::Add ERROR, input matrix A.Filled() or B.Filled() is false,"
00317     "they are required to be true. (Result matrix C should be an empty pointer)" << std::endl);
00318 
00319 
00320   RCP<const CrsMatrix_t> Aprime = null;
00321   RCP<const CrsMatrix_t> Bprime = null;
00322 
00323 
00324   //explicit tranpose A formed as necessary
00325   if( transposeA ) {
00326     RowMatrixTransposer<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps> theTransposer(A);
00327     Aprime = theTransposer.createTranspose(DoOptimizeStorage);
00328   }
00329   else{
00330     Aprime = rcpFromRef(A);
00331   }
00332 
00333   //explicit tranpose B formed as necessary
00334   if( transposeB ) {
00335     RowMatrixTransposer<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps> theTransposer(B);
00336     Bprime = theTransposer.createTranspose(DoOptimizeStorage);
00337   }
00338   else{
00339     Bprime = rcpFromRef(B);
00340   }
00341 
00342   // allocate or zero the new matrix
00343   if(C != null)
00344      C->setAllToScalar(ScalarTraits<Scalar>::zero());
00345   else
00346     C = rcp(new CrsMatrix_t(Aprime->getRowMap(), null));
00347 
00348   Array<RCP<const CrsMatrix_t> > Mat = 
00349     Teuchos::tuple<RCP<const CrsMatrix_t> >(Aprime, Bprime);
00350   Array<Scalar> scalar = Teuchos::tuple<Scalar>(scalarA, scalarB);
00351 
00352   // do a loop over each matrix to add: A reordering might be more efficient
00353   for(int k=0;k<2;++k) {
00354     size_t NumEntries;
00355     Array<GlobalOrdinal> Indices;
00356     Array<Scalar> Values;
00357    
00358      size_t NumMyRows = Mat[k]->getNodeNumRows();
00359      GlobalOrdinal Row;
00360    
00361      //Loop over rows and sum into C
00362      for( size_t i = OrdinalTraits<size_t>::zero(); i < NumMyRows; ++i ) {
00363         Row = Mat[k]->getRowMap()->getGlobalElement(i);
00364         NumEntries = Mat[k]->getNumEntriesInGlobalRow(Row);
00365         if(NumEntries == OrdinalTraits<global_size_t>::zero()){
00366           continue;
00367         }
00368 
00369         Indices.resize(NumEntries);
00370         Values.resize(NumEntries);
00371         Mat[k]->getGlobalRowCopy(Row, Indices(), Values(), NumEntries);
00372    
00373         if( scalar[k] != ScalarTraits<Scalar>::one() )
00374            for( size_t j = OrdinalTraits<size_t>::zero(); j < NumEntries; ++j ) Values[j] *= scalar[k];
00375    
00376         if(C->isFillComplete()) { // Sum in values
00377            C->sumIntoGlobalValues( Row, Indices, Values);
00378         } else { // just add it to the unfilled CRS Matrix
00379            C->insertGlobalValues( Row, Indices, Values);
00380         }
00381      }
00382   }
00383 }
00384 
00385 } //End namespace MatrixMatrix
00386 
00387 namespace MMdetails{
00388 
00389 
00390 //kernel method for computing the local portion of C = A*B
00391 template<class Scalar, 
00392          class LocalOrdinal, 
00393          class GlobalOrdinal, 
00394          class Node, 
00395          class SpMatOps>
00396 void mult_A_B(
00397   CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& Aview, 
00398   CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& Bview, 
00399   CrsWrapper<Scalar, LocalOrdinal, GlobalOrdinal, Node>& C,
00400   bool onlyCalculateStructure)
00401 {
00402   LocalOrdinal C_firstCol = Bview.colMap->getMinLocalIndex();
00403   LocalOrdinal C_lastCol = Bview.colMap->getMaxLocalIndex();
00404 
00405   LocalOrdinal C_firstCol_import = OrdinalTraits<LocalOrdinal>::zero();
00406   LocalOrdinal C_lastCol_import = OrdinalTraits<LocalOrdinal>::invalid();
00407 
00408   ArrayView<const GlobalOrdinal> bcols =Bview.colMap->getNodeElementList();
00409   ArrayView<const GlobalOrdinal> bcols_import = null;
00410   if (Bview.importColMap != null) {
00411     C_firstCol_import = Bview.importColMap->getMinLocalIndex();
00412     C_lastCol_import = Bview.importColMap->getMaxLocalIndex();
00413 
00414     bcols_import = Bview.importColMap->getNodeElementList();
00415   }
00416 
00417   size_t C_numCols = C_lastCol - C_firstCol + OrdinalTraits<LocalOrdinal>::one();
00418   size_t C_numCols_import = C_lastCol_import - C_firstCol_import + OrdinalTraits<LocalOrdinal>::one();
00419 
00420   if (C_numCols_import > C_numCols) C_numCols = C_numCols_import;
00421 
00422   Array<Scalar> dwork = onlyCalculateStructure ? Array<Scalar>() : Array<Scalar>(C_numCols);
00423   Array<GlobalOrdinal> iwork = Array<GlobalOrdinal>(C_numCols);
00424 
00425   Array<Scalar> C_row_i = dwork;
00426   Array<GlobalOrdinal> C_cols = iwork;
00427 
00428   size_t C_row_i_length, j, k;
00429 
00430   // Run through all the hash table lookups once and for all
00431   Array<LocalOrdinal> Acol2Brow(Aview.colMap->getNodeNumElements());
00432   if(Aview.colMap->isSameAs(*Bview.rowMap)){
00433     // Maps are the same: Use local IDs as the hash
00434     for(LocalOrdinal i=Aview.colMap->getMinLocalIndex();i<=Aview.colMap->getMaxLocalIndex();i++)
00435       Acol2Brow[i]=i;       
00436   }
00437   else {
00438     // Maps are not the same:  Use the map's hash
00439     for(LocalOrdinal i=Aview.colMap->getMinLocalIndex();i<=Aview.colMap->getMaxLocalIndex();i++)
00440       Acol2Brow[i]=Bview.rowMap->getLocalElement(Aview.colMap->getGlobalElement(i));
00441   }
00442 
00443   //To form C = A*B we're going to execute this expression:
00444   //
00445   // C(i,j) = sum_k( A(i,k)*B(k,j) )
00446   //
00447   //Our goal, of course, is to navigate the data in A and B once, without
00448   //performing searches for column-indices, etc.
00449 
00450   bool C_filled = C.isFillComplete();
00451 
00452   //loop over the rows of A.
00453   for(size_t i=0; i<Aview.numRows; ++i) {
00454 
00455     //only navigate the local portion of Aview... (It's probable that we
00456     //imported more of A than we need for A*B, because other cases like A^T*B 
00457     //need the extra rows.)
00458     if (Aview.remote[i]) {
00459       continue;
00460     }
00461 
00462     ArrayView<const LocalOrdinal> Aindices_i = Aview.indices[i];
00463     ArrayView<const Scalar> Aval_i  = onlyCalculateStructure ? null : Aview.values[i];
00464 
00465     GlobalOrdinal global_row = Aview.rowMap->getGlobalElement(i);
00466 
00467 
00468     //loop across the i-th row of A and for each corresponding row
00469     //in B, loop across colums and accumulate product
00470     //A(i,k)*B(k,j) into our partial sum quantities C_row_i. In other words,
00471     //as we stride across B(k,:) we're calculating updates for row i of the
00472     //result matrix C.
00473 
00474 
00475 
00476     for(k=OrdinalTraits<size_t>::zero(); k<Aview.numEntriesPerRow[i]; ++k) {
00477       LocalOrdinal Ak=Acol2Brow[Aindices_i[k]];
00478       Scalar Aval = onlyCalculateStructure ? Teuchos::as<Scalar>(0) : Aval_i[k];
00479 
00480       ArrayView<const LocalOrdinal> Bcol_inds = Bview.indices[Ak];
00481       ArrayView<const Scalar> Bvals_k = onlyCalculateStructure ? null : Bview.values[Ak];
00482 
00483       C_row_i_length = OrdinalTraits<size_t>::zero();
00484 
00485       if (Bview.remote[Ak]) {
00486         for(j=OrdinalTraits<size_t>::zero(); j<Bview.numEntriesPerRow[Ak]; ++j) {
00487           if(!onlyCalculateStructure){
00488             C_row_i[C_row_i_length] = Aval*Bvals_k[j];
00489           }
00490           C_cols[C_row_i_length++] = bcols_import[Bcol_inds[j]];
00491         }
00492       }
00493       else {
00494         for(j=OrdinalTraits<size_t>::zero(); j<Bview.numEntriesPerRow[Ak]; ++j) {
00495           if(!onlyCalculateStructure){
00496             C_row_i[C_row_i_length] = Aval*Bvals_k[j];
00497           }
00498           C_cols[C_row_i_length++] = bcols[Bcol_inds[j]];
00499         }
00500       }
00501 
00502       //
00503       //Now put the C_row_i values into C.
00504       //
00505 
00506       C_filled ?
00507         C.sumIntoGlobalValues(
00508           global_row, 
00509           C_cols.view(OrdinalTraits<size_t>::zero(), C_row_i_length), 
00510           onlyCalculateStructure ? null : C_row_i.view(OrdinalTraits<size_t>::zero(), C_row_i_length))
00511         :
00512         C.insertGlobalValues(
00513           global_row, 
00514           C_cols.view(OrdinalTraits<size_t>::zero(), C_row_i_length), 
00515           onlyCalculateStructure ? null : C_row_i.view(OrdinalTraits<size_t>::zero(), C_row_i_length));
00516 
00517     }
00518   }
00519 
00520 }
00521 
00522 template<class Scalar,
00523          class LocalOrdinal, 
00524          class GlobalOrdinal, 
00525          class Node,
00526          class SpMatOps>
00527 void setMaxNumEntriesPerRow(
00528   CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& Mview)
00529 {
00530   typedef typename Array<ArrayView<const LocalOrdinal> >::size_type  local_length_size;
00531   Mview.maxNumRowEntries = OrdinalTraits<local_length_size>::zero();
00532   if(Mview.indices.size() > OrdinalTraits<local_length_size>::zero() ){
00533     Mview.maxNumRowEntries = Mview.indices[0].size();
00534     for(local_length_size i = 1; i<Mview.indices.size(); ++i){
00535       if(Mview.indices[i].size() > Mview.maxNumRowEntries){
00536         Mview.maxNumRowEntries = Mview.indices[i].size();
00537       }
00538     }
00539   }
00540 }
00541 
00542 template<class Scalar,
00543          class LocalOrdinal, 
00544          class GlobalOrdinal, 
00545          class Node,
00546          class SpMatOps>
00547 void import_and_extract_views(
00548   const CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& M,
00549   RCP<const Map<LocalOrdinal, GlobalOrdinal, Node> > targetMap,
00550   CrsMatrixStruct<Scalar, LocalOrdinal, GlobalOrdinal, Node, SpMatOps>& Mview)
00551 {
00552   //Convience typedef
00553   typedef Map<LocalOrdinal, GlobalOrdinal, Node> Map_t;
00554   // The goal of this method is to populate the 'Mview' struct with views of the
00555   // rows of M, including all rows that correspond to elements in 'targetMap'.
00556   // 
00557   // If targetMap includes local elements that correspond to remotely-owned rows
00558   // of M, then those remotely-owned rows will be imported into
00559   // 'Mview.importMatrix', and views of them will be included in 'Mview'.
00560   Mview.deleteContents();
00561 
00562   RCP<const Map_t> Mrowmap = M.getRowMap();
00563 
00564   const int numProcs = Mrowmap->getComm()->getSize();
00565 
00566   ArrayView<const GlobalOrdinal> Mrows = targetMap->getNodeElementList();
00567 
00568   Mview.numRemote = 0;
00569   Mview.numRows = targetMap->getNodeNumElements();
00570   Mview.numEntriesPerRow.resize(Mview.numRows);
00571   Mview.indices.resize(         Mview.numRows);
00572   Mview.values.resize(          Mview.numRows);
00573   Mview.remote.resize(          Mview.numRows);
00574 
00575 
00576   Mview.origRowMap = M.getRowMap();
00577   Mview.rowMap = targetMap;
00578   Mview.colMap = M.getColMap();
00579   Mview.domainMap = M.getDomainMap();
00580   Mview.importColMap = null;
00581 
00582 
00583   // mark each row in targetMap as local or remote, and go ahead and get a view for the local rows
00584 
00585   for(size_t i=0; i < Mview.numRows; ++i) 
00586   {
00587     const LocalOrdinal mlid = Mrowmap->getLocalElement(Mrows[i]);
00588 
00589     if (mlid == OrdinalTraits<LocalOrdinal>::invalid()) {
00590       Mview.remote[i] = true;
00591       ++Mview.numRemote;
00592     }
00593     else {
00594       Mview.remote[i] = false;
00595       M.getLocalRowView(mlid, Mview.indices[i], Mview.values[i]);
00596       Mview.numEntriesPerRow[i] = Mview.indices[i].size();
00597     }
00598   }
00599 
00600   if (numProcs < 2) {
00601     TEUCHOS_TEST_FOR_EXCEPTION(Mview.numRemote > 0, std::runtime_error,
00602       "MatrixMatrix::import_and_extract_views ERROR, numProcs < 2 but attempting to import remote matrix rows." <<std::endl);
00603     setMaxNumEntriesPerRow(Mview);
00604     //If only one processor we don't need to import any remote rows, so return.
00605     return;
00606   }
00607 
00608   //
00609   // Now we will import the needed remote rows of M, if the global maximum
00610   // value of numRemote is greater than 0.
00611   //
00612 
00613   global_size_t globalMaxNumRemote = 0;
00614   Teuchos::reduceAll(*(Mrowmap->getComm()) , Teuchos::REDUCE_MAX, Mview.numRemote, Teuchos::outArg(globalMaxNumRemote) );
00615 
00616   if (globalMaxNumRemote > 0) {
00617     // Create a map that describes the remote rows of M that we need.
00618 
00619     Array<GlobalOrdinal> MremoteRows(Mview.numRemote);
00620 
00621 
00622     global_size_t offset = 0;
00623     for(size_t i=0; i < Mview.numRows; ++i) {
00624       if (Mview.remote[i]) {
00625         MremoteRows[offset++] = Mrows[i];
00626       }
00627     }
00628 
00629     RCP<const Map_t> MremoteRowMap = rcp(new Map_t(
00630       OrdinalTraits<GlobalOrdinal>::invalid(), 
00631       MremoteRows(), 
00632       Mrowmap->getIndexBase(), 
00633       Mrowmap->getComm(), 
00634       Mrowmap->getNode()));
00635 
00636     // Create an importer with target-map MremoteRowMap and source-map Mrowmap.
00637     Import<LocalOrdinal, GlobalOrdinal, Node> importer(Mrowmap, MremoteRowMap);
00638 
00639     // Now create a new matrix into which we can import the remote rows of M that we need.
00640     Mview.importMatrix = rcp(new CrsMatrix<Scalar,LocalOrdinal, GlobalOrdinal, Node, SpMatOps>( MremoteRowMap, 1 ));
00641     Mview.importMatrix->doImport(M, importer, INSERT);
00642     Mview.importMatrix->fillComplete(M.getDomainMap(), M.getRangeMap());
00643 
00644     // Save the column map of the imported matrix, so that we can convert indices back to global for arithmetic later
00645     Mview.importColMap = Mview.importMatrix->getColMap();
00646 
00647     // Finally, use the freshly imported data to fill in the gaps in our views of rows of M.
00648     for(size_t i=0; i < Mview.numRows; ++i) 
00649     {
00650       if (Mview.remote[i]) {
00651         const LocalOrdinal importLID = MremoteRowMap->getLocalElement(Mrows[i]);
00652         Mview.importMatrix->getLocalRowView(importLID,
00653                                              Mview.indices[i],
00654                                              Mview.values[i]);
00655         Mview.numEntriesPerRow[i] = Mview.indices[i].size();
00656       }
00657     }
00658   }
00659   setMaxNumEntriesPerRow(Mview);
00660 }
00661 
00662 } //End namepsace MMdetails
00663 
00664 } //End namespace Tpetra
//
// TPETRA_MATRIXMATRIX_INSTANT(SCALAR,LO,GO,NODE)
//
// Explicitly instantiates MatrixMatrix::Multiply and both MatrixMatrix::Add
// overloads for one combination of scalar, ordinal and node types.
//
// The names are written relative to Tpetra, so the macro must be expanded
// from within the Tpetra namespace!
//

#define TPETRA_MATRIXMATRIX_INSTANT(SCALAR,LO,GO,NODE) \
  template void MatrixMatrix::Multiply( \
    const CrsMatrix< SCALAR , LO , GO , NODE >& A, \
    bool transposeA, \
    const CrsMatrix< SCALAR , LO , GO , NODE >& B, \
    bool transposeB, \
    CrsMatrix< SCALAR , LO , GO , NODE >& C, \
    bool call_FillComplete_on_result); \
  \
  template void MatrixMatrix::Add( \
    const CrsMatrix< SCALAR , LO , GO , NODE >& A, \
    bool transposeA, \
    SCALAR scalarA, \
    const CrsMatrix< SCALAR , LO , GO , NODE >& B, \
    bool transposeB, \
    SCALAR scalarB, \
    RCP<CrsMatrix< SCALAR , LO , GO , NODE > > C); \
  \
  template void MatrixMatrix::Add( \
    const CrsMatrix< SCALAR , LO , GO , NODE >& A, \
    bool transposeA, \
    SCALAR scalarA, \
    CrsMatrix< SCALAR , LO , GO , NODE >& B, \
    SCALAR scalarB );
00701 
00702 #endif // TPETRA_MATRIXMATRIX_DEF_HPP
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines