EpetraExt_MatrixMatrix.cpp

Go to the documentation of this file.
00001 //@HEADER
00002 // ***********************************************************************
00003 // 
00004 //     EpetraExt: Epetra Extended - Linear Algebra Services Package
00005 //                 Copyright (2001) Sandia Corporation
00006 // 
00007 // Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
00008 // license for use of this work by or on behalf of the U.S. Government.
00009 // 
00010 // This library is free software; you can redistribute it and/or modify
00011 // it under the terms of the GNU Lesser General Public License as
00012 // published by the Free Software Foundation; either version 2.1 of the
00013 // License, or (at your option) any later version.
00014 //  
00015 // This library is distributed in the hope that it will be useful, but
00016 // WITHOUT ANY WARRANTY; without even the implied warranty of
00017 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00018 // Lesser General Public License for more details.
00019 //  
00020 // You should have received a copy of the GNU Lesser General Public
00021 // License along with this library; if not, write to the Free Software
00022 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
00023 // USA
00024 // Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
00025 // 
00026 // ***********************************************************************
00027 //@HEADER
00028 
00029 #include <EpetraExt_ConfigDefs.h>
00030 #include <EpetraExt_MatrixMatrix.h>
00031 
00032 #include <EpetraExt_Transpose_RowMatrix.h>
00033 
00034 #include <Epetra_Export.h>
00035 #include <Epetra_Import.h>
00036 #include <Epetra_Util.h>
00037 #include <Epetra_Map.h>
00038 #include <Epetra_Comm.h>
00039 #include <Epetra_CrsMatrix.h>
00040 #include <Epetra_Directory.h>
00041 #include <Epetra_HashTable.h>
00042 #include <Epetra_Distributor.h>
00043 
00044 #ifdef HAVE_VECTOR
00045 #include <vector>
00046 #endif
00047 
00048 namespace EpetraExt {
00049 
//
//Internal helper: sparsedot computes the dot-product of two
//sparsely-populated 'vectors', each described by an array of values and a
//parallel array of indices. Only positions whose index occurs in both
//u_ind and v_ind contribute to the result.
//Precondition: the indices in u_ind and in v_ind are sorted ascending.
//
double sparsedot(double* u, int* u_ind, int u_len,
                 double* v, int* v_ind, int v_len)
{
  double sum = 0.0;

  int iu = 0;
  int iv = 0;

  //Merge-style walk over both sorted index lists.
  while (iu < u_len && iv < v_len) {
    const int uIndex = u_ind[iu];
    const int vIndex = v_ind[iv];

    if (uIndex == vIndex) {
      //matching position: accumulate the product and advance both sides.
      sum += u[iu] * v[iv];
      ++iu;
      ++iv;
    }
    else if (uIndex < vIndex) {
      ++iu;
    }
    else {
      ++iv;
    }
  }

  return sum;
}
00080 
00081 //struct that holds views of the contents of a CrsMatrix. These
00082 //contents may be a mixture of local and remote rows of the
00083 //original matrix. 
00084 class CrsMatrixStruct {
00085 public:
00086   CrsMatrixStruct()
00087     : numRows(0), numEntriesPerRow(NULL), indices(NULL), values(NULL),
00088       remote(NULL), numRemote(0), rowMap(NULL), colMap(NULL),
00089       domainMap(NULL), importColMap(NULL), importMatrix(NULL)
00090   {}
00091 
00092   virtual ~CrsMatrixStruct()
00093   {
00094     deleteContents();
00095   }
00096 
00097   void deleteContents()
00098   {
00099     numRows = 0;
00100     delete [] numEntriesPerRow; numEntriesPerRow = NULL;
00101     delete [] indices; indices = NULL;
00102     delete [] values; values = NULL;
00103     delete [] remote; remote = NULL;
00104     numRemote = 0;
00105     delete importMatrix;
00106   }
00107 
00108   int numRows;
00109   int* numEntriesPerRow;
00110   int** indices;
00111   double** values;
00112   bool* remote;
00113   int numRemote;
00114   const Epetra_Map* origRowMap;
00115   const Epetra_Map* rowMap;
00116   const Epetra_Map* colMap;
00117   const Epetra_Map* domainMap;
00118   const Epetra_Map* importColMap;
00119   Epetra_CrsMatrix* importMatrix;
00120 };
00121 
00122 int dumpCrsMatrixStruct(const CrsMatrixStruct& M)
00123 {
00124   cout << "proc " << M.rowMap->Comm().MyPID()<<endl;
00125   cout << "numRows: " << M.numRows<<endl;
00126   for(int i=0; i<M.numRows; ++i) {
00127     for(int j=0; j<M.numEntriesPerRow[i]; ++j) {
00128       if (M.remote[i]) {
00129   cout << "  *"<<M.rowMap->GID(i)<<"   "
00130        <<M.importColMap->GID(M.indices[i][j])<<"   "<<M.values[i][j]<<endl;
00131       }
00132       else {
00133   cout << "   "<<M.rowMap->GID(i)<<"   "
00134        <<M.colMap->GID(M.indices[i][j])<<"   "<<M.values[i][j]<<endl;
00135       }
00136     }
00137   }
00138   return(0);
00139 }
00140 
00141 //kernel method for computing the local portion of C = A*B
00142 int mult_A_B(CrsMatrixStruct& Aview,
00143        CrsMatrixStruct& Bview,
00144        Epetra_CrsMatrix& C)
00145 {
00146   int C_firstCol = Bview.colMap->MinLID();
00147   int C_lastCol = Bview.colMap->MaxLID();
00148 
00149   int C_firstCol_import = 0;
00150   int C_lastCol_import = -1;
00151 
00152   int* bcols = Bview.colMap->MyGlobalElements();
00153   int* bcols_import = NULL;
00154   if (Bview.importColMap != NULL) {
00155     C_firstCol_import = Bview.importColMap->MinLID();
00156     C_lastCol_import = Bview.importColMap->MaxLID();
00157 
00158     bcols_import = Bview.importColMap->MyGlobalElements();
00159   }
00160 
00161   int C_numCols = C_lastCol - C_firstCol + 1;
00162   int C_numCols_import = C_lastCol_import - C_firstCol_import + 1;
00163 
00164   if (C_numCols_import > C_numCols) C_numCols = C_numCols_import;
00165   double* dwork = new double[C_numCols];
00166   int* iwork = new int[C_numCols];
00167 
00168   double* C_row_i = dwork;
00169   int* C_cols = iwork;
00170 
00171   int C_row_i_length, i, j, k;
00172 
00173   //To form C = A*B we're going to execute this expression:
00174   //
00175   // C(i,j) = sum_k( A(i,k)*B(k,j) )
00176   //
00177   //Our goal, of course, is to navigate the data in A and B once, without
00178   //performing searches for column-indices, etc.
00179 
00180   bool C_filled = C.Filled();
00181 
00182   //loop over the rows of A.
00183   for(i=0; i<Aview.numRows; ++i) {
00184 
00185     //only navigate the local portion of Aview... (It's probable that we
00186     //imported more of A than we need for A*B, because other cases like A^T*B 
00187     //need the extra rows.)
00188     if (Aview.remote[i]) {
00189       continue;
00190     }
00191 
00192     int* Aindices_i = Aview.indices[i];
00193     double* Aval_i  = Aview.values[i];
00194 
00195     int global_row = Aview.rowMap->GID(i);
00196 
00197     //loop across the i-th row of A and for each corresponding row
00198     //in B, loop across colums and accumulate product
00199     //A(i,k)*B(k,j) into our partial sum quantities C_row_i. In other words,
00200     //as we stride across B(k,:) we're calculating updates for row i of the
00201     //result matrix C.
00202 
00203     for(k=0; k<Aview.numEntriesPerRow[i]; ++k) {
00204       int Ak = Bview.rowMap->LID(Aview.colMap->GID(Aindices_i[k]));
00205       double Aval = Aval_i[k];
00206 
00207       int* Bcol_inds = Bview.indices[Ak];
00208       double* Bvals_k = Bview.values[Ak];
00209 
00210       C_row_i_length = 0;
00211 
00212       if (Bview.remote[Ak]) {
00213   for(j=0; j<Bview.numEntriesPerRow[Ak]; ++j) {
00214     C_row_i[C_row_i_length] = Aval*Bvals_k[j];
00215           C_cols[C_row_i_length++] = bcols_import[Bcol_inds[j]];
00216   }
00217       }
00218       else {
00219   for(j=0; j<Bview.numEntriesPerRow[Ak]; ++j) {
00220     C_row_i[C_row_i_length] = Aval*Bvals_k[j];
00221           C_cols[C_row_i_length++] = bcols[Bcol_inds[j]];
00222   }
00223       }
00224 
00225       //
00226       //Now put the C_row_i values into C.
00227       //
00228 
00229       int err = C_filled ?
00230           C.SumIntoGlobalValues(global_row, C_row_i_length, C_row_i, C_cols)
00231           :
00232           C.InsertGlobalValues(global_row, C_row_i_length, C_row_i, C_cols);
00233  
00234       if (err < 0) {
00235         return(err);
00236       }
00237       if (err > 0) {
00238         if (C_filled) {
00239           //C is Filled, and doesn't
00240           //have all the necessary nonzero locations.
00241           return(err);
00242         }
00243       }
00244     }
00245   }
00246 
00247   delete [] dwork;
00248   delete [] iwork;
00249 
00250   return(0);
00251 }
00252 
00253 //kernel method for computing the local portion of C = A*B^T
00254 int mult_A_Btrans(CrsMatrixStruct& Aview,
00255       CrsMatrixStruct& Bview,
00256       Epetra_CrsMatrix& C)
00257 {
00258   int i, j, k;
00259   int returnValue = 0;
00260 
00261   int maxlen = 0;
00262   for(i=0; i<Aview.numRows; ++i) {
00263     if (Aview.numEntriesPerRow[i] > maxlen) maxlen = Aview.numEntriesPerRow[i];
00264   }
00265   for(i=0; i<Bview.numRows; ++i) {
00266     if (Bview.numEntriesPerRow[i] > maxlen) maxlen = Bview.numEntriesPerRow[i];
00267   }
00268 
00269   //cout << "Aview: " << endl;
00270   //dumpCrsMatrixStruct(Aview);
00271 
00272   //cout << "Bview: " << endl;
00273   //dumpCrsMatrixStruct(Bview);
00274 
00275   int numBcols = Bview.colMap->NumMyElements();
00276   int numBrows = Bview.numRows;
00277 
00278   int iworklen = maxlen*2 + numBcols;
00279   int* iwork = new int[iworklen];
00280 
00281   int* bcols = iwork+maxlen*2;
00282   int* bgids = Bview.colMap->MyGlobalElements();
00283   double* bvals = new double[maxlen*2];
00284   double* avals = bvals+maxlen;
00285 
00286   int max_all_b = Bview.colMap->MaxAllGID();
00287   int min_all_b = Bview.colMap->MinAllGID();
00288 
00289   //bcols will hold the GIDs from B's column-map for fast access
00290   //during the computations below
00291   for(i=0; i<numBcols; ++i) {
00292     int blid = Bview.colMap->LID(bgids[i]);
00293     bcols[blid] = bgids[i];
00294   }
00295 
00296   //next create arrays indicating the first and last column-index in
00297   //each row of B, so that we can know when to skip certain rows below.
00298   //This will provide a large performance gain for banded matrices, and
00299   //a somewhat smaller gain for *most* other matrices.
00300   int* b_firstcol = new int[2*numBrows];
00301   int* b_lastcol = b_firstcol+numBrows;
00302   int temp;
00303   for(i=0; i<numBrows; ++i) {
00304     b_firstcol[i] = max_all_b;
00305     b_lastcol[i] = min_all_b;
00306 
00307     int Blen_i = Bview.numEntriesPerRow[i];
00308     if (Blen_i < 1) continue;
00309     int* Bindices_i = Bview.indices[i];
00310 
00311     if (Bview.remote[i]) {
00312       for(k=0; k<Blen_i; ++k) {
00313         temp = Bview.importColMap->GID(Bindices_i[k]);
00314         if (temp < b_firstcol[i]) b_firstcol[i] = temp;
00315         if (temp > b_lastcol[i]) b_lastcol[i] = temp;
00316       }
00317     }
00318     else {
00319       for(k=0; k<Blen_i; ++k) {
00320         temp = bcols[Bindices_i[k]];
00321         if (temp < b_firstcol[i]) b_firstcol[i] = temp;
00322         if (temp > b_lastcol[i]) b_lastcol[i] = temp;
00323       }
00324     }
00325   }
00326 
00327   Epetra_Util util;
00328 
00329   int* Aind = iwork;
00330   int* Bind = iwork+maxlen;
00331 
00332   //To form C = A*B^T, we're going to execute this expression:
00333   //
00334   // C(i,j) = sum_k( A(i,k)*B(j,k) )
00335   //
00336   //This is the easiest case of all to code (easier than A*B, A^T*B, A^T*B^T).
00337   //But it requires the use of a 'sparsedot' function (we're simply forming
00338   //dot-products with row A_i and row B_j for all i and j).
00339 
00340   //loop over the rows of A.
00341   for(i=0; i<Aview.numRows; ++i) {
00342     if (Aview.remote[i]) {
00343       continue;
00344     }
00345 
00346     int* Aindices_i = Aview.indices[i];
00347     double* Aval_i  = Aview.values[i];
00348     int A_len_i = Aview.numEntriesPerRow[i];
00349     if (A_len_i < 1) {
00350       continue;
00351     }
00352 
00353     for(k=0; k<A_len_i; ++k) {
00354       Aind[k] = Aview.colMap->GID(Aindices_i[k]);
00355       avals[k] = Aval_i[k];
00356     }
00357 
00358     util.Sort(true, A_len_i, Aind, 1, &avals, 0, NULL);
00359 
00360     int mina = Aind[0];
00361     int maxa = Aind[A_len_i-1];
00362 
00363     if (mina > max_all_b || maxa < min_all_b) {
00364       continue;
00365     }
00366 
00367     int global_row = Aview.rowMap->GID(i);
00368 
00369     //loop over the rows of B and form results C_ij = dot(A(i,:),B(j,:))
00370     for(j=0; j<Bview.numRows; ++j) {
00371       if (b_firstcol[j] > maxa || b_lastcol[j] < mina) {
00372         continue;
00373       }
00374 
00375       int* Bindices_j = Bview.indices[j];
00376       int B_len_j = Bview.numEntriesPerRow[j];
00377       if (B_len_j < 1) {
00378         continue;
00379       }
00380 
00381       int tmp, Blen = 0;
00382 
00383       if (Bview.remote[j]) {
00384         for(k=0; k<B_len_j; ++k) {
00385     tmp = Bview.importColMap->GID(Bindices_j[k]);
00386           if (tmp < mina || tmp > maxa) {
00387             continue;
00388           }
00389 
00390           bvals[Blen] = Bview.values[j][k];
00391           Bind[Blen++] = tmp;
00392   }
00393       }
00394       else {
00395         for(k=0; k<B_len_j; ++k) {
00396     tmp = bcols[Bindices_j[k]];
00397           if (tmp < mina || tmp > maxa) {
00398             continue;
00399           }
00400 
00401           bvals[Blen] = Bview.values[j][k];
00402           Bind[Blen++] = tmp;
00403   }
00404       }
00405 
00406       if (Blen < 1) {
00407         continue;
00408       }
00409 
00410       util.Sort(true, Blen, Bind, 1, &bvals, 0, NULL);
00411 
00412       double C_ij = sparsedot(avals, Aind, A_len_i,
00413             bvals, Bind, Blen);
00414 
00415       if (C_ij == 0.0) {
00416   continue;
00417       }
00418       int global_col = Bview.rowMap->GID(j);
00419 
00420       int err = C.SumIntoGlobalValues(global_row, 1, &C_ij, &global_col);
00421       if (err < 0) {
00422   return(err);
00423       }
00424       if (err > 0) {
00425   err = C.InsertGlobalValues(global_row, 1, &C_ij, &global_col);
00426   if (err < 0) {
00427     //If we jump out here, it means C.Filled()==true, and C doesn't
00428     //have all the necessary nonzero locations, or that global_row
00429     //or global_col is out of range (less than 0 or non local).
00430     return(err);
00431   }
00432         if (err == 2) {
00433           cerr << "EpetraExt::MatrixMatrix::Multiply Warning: failed to insert"
00434               << " value in result matrix at position "<<global_row<<","
00435               <<global_col<<", possibly because result matrix has a column-map"
00436               <<" that doesn't include column "<<global_col<<" on this proc."
00437               <<endl;
00438           returnValue = err;
00439         }
00440       }
00441     }
00442   }
00443 
00444   delete [] iwork;
00445   delete [] bvals;
00446   delete [] b_firstcol;
00447 
00448   return(returnValue);
00449 }
00450 
00451 //kernel method for computing the local portion of C = A^T*B
00452 int mult_Atrans_B(CrsMatrixStruct& Aview,
00453       CrsMatrixStruct& Bview,
00454       Epetra_CrsMatrix& C)
00455 {
00456   int C_firstCol = Bview.colMap->MinLID();
00457   int C_lastCol = Bview.colMap->MaxLID();
00458 
00459   int C_firstCol_import = 0;
00460   int C_lastCol_import = -1;
00461 
00462   if (Bview.importColMap != NULL) {
00463     C_firstCol_import = Bview.importColMap->MinLID();
00464     C_lastCol_import = Bview.importColMap->MaxLID();
00465   }
00466 
00467   int C_numCols = C_lastCol - C_firstCol + 1;
00468   int C_numCols_import = C_lastCol_import - C_firstCol_import + 1;
00469 
00470   if (C_numCols_import > C_numCols) C_numCols = C_numCols_import;
00471 
00472   double* C_row_i = new double[C_numCols];
00473   int* C_colInds = new int[C_numCols];
00474 
00475   int i, j, k;
00476 
00477   for(j=0; j<C_numCols; ++j) {
00478     C_row_i[j] = 0.0;
00479     C_colInds[j] = 0;
00480   }
00481 
00482   //To form C = A^T*B, compute a series of outer-product updates.
00483   //
00484   // for (ith column of A^T) { 
00485   //   C_i = outer product of A^T(:,i) and B(i,:)
00486   // Where C_i is the ith matrix update,
00487   //       A^T(:,i) is the ith column of A^T, and
00488   //       B(i,:) is the ith row of B.
00489   //
00490 
00491   //dumpCrsMatrixStruct(Aview);
00492   //dumpCrsMatrixStruct(Bview);
00493   int localProc = Bview.colMap->Comm().MyPID();
00494 
00495   int* Arows = Aview.rowMap->MyGlobalElements();
00496 
00497   bool C_filled = C.Filled();
00498 
00499   //loop over the rows of A (which are the columns of A^T).
00500   for(i=0; i<Aview.numRows; ++i) {
00501 
00502     int* Aindices_i = Aview.indices[i];
00503     double* Aval_i  = Aview.values[i];
00504 
00505     //we'll need to get the row of B corresponding to Arows[i],
00506     //where Arows[i] is the GID of A's ith row.
00507     int Bi = Bview.rowMap->LID(Arows[i]);
00508     if (Bi<0) {
00509       cout << "mult_Atrans_B ERROR, proc "<<localProc<<" needs row "
00510      <<Arows[i]<<" of matrix B, but doesn't have it."<<endl;
00511       return(-1);
00512     }
00513 
00514     int* Bcol_inds = Bview.indices[Bi];
00515     double* Bvals_i = Bview.values[Bi];
00516 
00517     //for each column-index Aj in the i-th row of A, we'll update
00518     //global-row GID(Aj) of the result matrix C. In that row of C,
00519     //we'll update column-indices given by the column-indices in the
00520     //ith row of B that we're now holding (Bcol_inds).
00521 
00522     //First create a list of GIDs for the column-indices
00523     //that we'll be updating.
00524 
00525     int Blen = Bview.numEntriesPerRow[Bi];
00526     if (Bview.remote[Bi]) {
00527       for(j=0; j<Blen; ++j) {
00528         C_colInds[j] = Bview.importColMap->GID(Bcol_inds[j]);
00529       }
00530     }
00531     else {
00532       for(j=0; j<Blen; ++j) {
00533         C_colInds[j] = Bview.colMap->GID(Bcol_inds[j]);
00534       }
00535     }
00536 
00537     //loop across the i-th row of A (column of A^T)
00538     for(j=0; j<Aview.numEntriesPerRow[i]; ++j) {
00539 
00540       int Aj = Aindices_i[j];
00541       double Aval = Aval_i[j];
00542 
00543       int global_row;
00544       if (Aview.remote[i]) {
00545   global_row = Aview.importColMap->GID(Aj);
00546       }
00547       else {
00548   global_row = Aview.colMap->GID(Aj);
00549       }
00550 
00551       if (!C.RowMap().MyGID(global_row)) {
00552         continue;
00553       }
00554 
00555       for(k=0; k<Blen; ++k) {
00556         C_row_i[k] = Aval*Bvals_i[k];
00557       }
00558 
00559       //
00560       //Now add this row-update to C.
00561       //
00562 
00563       int err = C_filled ?
00564         C.SumIntoGlobalValues(global_row, Blen, C_row_i, C_colInds)
00565         :
00566         C.InsertGlobalValues(global_row, Blen, C_row_i, C_colInds);
00567 
00568       if (err < 0) {
00569         return(err);
00570       }
00571       if (err > 0) {
00572         if (C_filled) {
00573           //C is Filled, and doesn't have all the necessary nonzero locations.
00574           return(err);
00575         }
00576       }
00577     }
00578   }
00579 
00580   delete [] C_row_i;
00581   delete [] C_colInds;
00582 
00583   return(0);
00584 }
00585 
00586 //kernel method for computing the local portion of C = A^T*B^T
00587 int mult_Atrans_Btrans(CrsMatrixStruct& Aview,
00588            CrsMatrixStruct& Bview,
00589            Epetra_CrsMatrix& C)
00590 {
00591   int C_firstCol = Aview.rowMap->MinLID();
00592   int C_lastCol = Aview.rowMap->MaxLID();
00593 
00594   int C_firstCol_import = 0;
00595   int C_lastCol_import = -1;
00596 
00597   if (Aview.importColMap != NULL) {
00598     C_firstCol_import = Aview.importColMap->MinLID();
00599     C_lastCol_import = Aview.importColMap->MaxLID();
00600   }
00601 
00602   int C_numCols = C_lastCol - C_firstCol + 1;
00603   int C_numCols_import = C_lastCol_import - C_firstCol_import + 1;
00604 
00605   double* dwork = new double[C_numCols+C_numCols_import];
00606 
00607   double* C_col_j = dwork;
00608 
00609   double* C_col_j_import = dwork+C_numCols;
00610 
00611   //cout << "Aview: " << endl;
00612   //dumpCrsMatrixStruct(Aview);
00613 
00614   //cout << "Bview: " << endl;
00615   //dumpCrsMatrixStruct(Bview);
00616 
00617 
00618   int i, j, k;
00619 
00620   for(j=0; j<C_numCols; ++j) {
00621     C_col_j[j] = 0.0;
00622   }
00623 
00624   for(j=0; j<C_numCols_import; ++j) {
00625     C_col_j_import[j] = 0.0;
00626   }
00627 
00628   const Epetra_Map* Crowmap = &(C.RowMap());
00629 
00630   //To form C = A^T*B^T, we're going to execute this expression:
00631   //
00632   // C(i,j) = sum_k( A(k,i)*B(j,k) )
00633   //
00634   //Our goal, of course, is to navigate the data in A and B once, without
00635   //performing searches for column-indices, etc. In other words, we avoid
00636   //column-wise operations like the plague...
00637 
00638   int* Brows = Bview.rowMap->MyGlobalElements();
00639 
00640   std::vector<int> inds_i;
00641   inds_i.reserve(C_numCols);
00642 
00643   std::vector<int> inds_i_import;
00644   inds_i_import.reserve(C_numCols_import);
00645 
00646   //loop across the rows of B
00647   for(j=0; j<Bview.numRows; ++j) {
00648     int* Bindices_j = Bview.indices[j];
00649     double* Bvals_j = Bview.values[j];
00650 
00651     int global_col = Brows[j];
00652 
00653     //loop across columns in the j-th row of B and for each corresponding
00654     //row in A, loop across columns and accumulate product
00655     //A(k,i)*B(j,k) into our partial sum quantities in C_col_j. In other
00656     //words, as we stride across B(j,:), we use selected rows in A to
00657     //calculate updates for column j of the result matrix C.
00658 
00659     for(k=0; k<Bview.numEntriesPerRow[j]; ++k) {
00660       inds_i.resize(0);
00661       inds_i_import.resize(0);
00662 
00663       int bk = Bindices_j[k];
00664       double Bval = Bvals_j[k];
00665 
00666       int global_k;
00667       if (Bview.remote[j]) {
00668   global_k = Bview.importColMap->GID(bk);
00669       }
00670       else {
00671   global_k = Bview.colMap->GID(bk);
00672       }
00673 
00674       //get the corresponding row in A
00675       int ak = Aview.rowMap->LID(global_k);
00676       if (ak<0) {
00677   continue;
00678       }
00679 
00680       int* Aindices_k = Aview.indices[ak];
00681       double* Avals_k = Aview.values[ak];
00682 
00683       if (Aview.remote[ak]) {
00684   for(i=0; i<Aview.numEntriesPerRow[ak]; ++i) {
00685     int loc = Aindices_k[i] - C_firstCol_import;
00686     C_col_j_import[loc] += Avals_k[i]*Bval;
00687           inds_i_import.push_back(loc);
00688   }
00689       }
00690       else {
00691   for(i=0; i<Aview.numEntriesPerRow[ak]; ++i) {
00692     int loc = Aindices_k[i] - C_firstCol;
00693     C_col_j[loc] += Avals_k[i]*Bval;
00694           inds_i.push_back(loc);
00695   }
00696       }
00697 
00698       //Now loop across the C_col_j values and put non-zeros into C.
00699 
00700       std::vector<int>::const_iterator
00701         it = inds_i.begin(),
00702         it_end = inds_i.end();
00703 
00704       for(; it != it_end; ++it) {
00705         i = *it;
00706   if (C_col_j[i] == 0.0) continue;
00707 
00708   int global_row = Aview.colMap->GID(C_firstCol+i);
00709   if (!Crowmap->MyGID(global_row)) {
00710     continue;
00711   }
00712 
00713   int err = C.SumIntoGlobalValues(global_row, 1, &(C_col_j[i]),
00714           &global_col);
00715   if (err < 0) {
00716     return(err);
00717   }
00718   if (err > 0) {
00719     err = C.InsertGlobalValues(global_row, 1, &(C_col_j[i]),
00720              &global_col);
00721     if (err < 0) {
00722       return(err);
00723     }
00724   }
00725 
00726   C_col_j[i] = 0.0;
00727       }
00728 
00729       std::vector<int>::const_iterator
00730         iit = inds_i_import.begin(),
00731         iit_end = inds_i_import.end();
00732 
00733       for(; iit != iit_end; ++iit) {
00734         i = *iit;
00735   if (C_col_j_import[i] == 0.0) continue;
00736 
00737   int global_row = Aview.importColMap->GID(C_firstCol_import + i);
00738   if (!Crowmap->MyGID(global_row)) {
00739     continue;
00740   }
00741 
00742   int err = C.SumIntoGlobalValues(global_row, 1, &(C_col_j_import[i]),
00743           &global_col);
00744   if (err < 0) {
00745     return(err);
00746   }
00747   if (err > 0) {
00748     err = C.InsertGlobalValues(global_row, 1, &(C_col_j_import[i]),
00749              &global_col);
00750     if (err < 0) {
00751       return(err);
00752     }
00753   }
00754 
00755   C_col_j_import[i] = 0.0;
00756       }
00757     }
00758   }
00759 
00760   delete [] dwork;
00761 
00762   return(0);
00763 }
00764 
00765 int import_and_extract_views(const Epetra_CrsMatrix& M,
00766            const Epetra_Map& targetMap,
00767            CrsMatrixStruct& Mview)
00768 {
00769   //The goal of this method is to populate the 'Mview' struct with views of the
00770   //rows of M, including all rows that correspond to elements in 'targetMap'.
00771   //
00772   //If targetMap includes local elements that correspond to remotely-owned rows
00773   //of M, then those remotely-owned rows will be imported into
00774   //'Mview.importMatrix', and views of them will be included in 'Mview'.
00775 
00776   Mview.deleteContents();
00777 
00778   const Epetra_Map& Mrowmap = M.RowMap();
00779 
00780   int numProcs = Mrowmap.Comm().NumProc();
00781 
00782   Mview.numRows = targetMap.NumMyElements();
00783 
00784   int* Mrows = targetMap.MyGlobalElements();
00785 
00786   if (Mview.numRows > 0) {
00787     Mview.numEntriesPerRow = new int[Mview.numRows];
00788     Mview.indices = new int*[Mview.numRows];
00789     Mview.values = new double*[Mview.numRows];
00790     Mview.remote = new bool[Mview.numRows];
00791   }
00792 
00793   Mview.numRemote = 0;
00794 
00795   int i;
00796   for(i=0; i<Mview.numRows; ++i) {
00797     int mlid = Mrowmap.LID(Mrows[i]);
00798     if (mlid < 0) {
00799       Mview.remote[i] = true;
00800       ++Mview.numRemote;
00801     }
00802     else {
00803       EPETRA_CHK_ERR( M.ExtractMyRowView(mlid, Mview.numEntriesPerRow[i],
00804            Mview.values[i], Mview.indices[i]) );
00805       Mview.remote[i] = false;
00806     }
00807   }
00808 
00809   Mview.origRowMap = &(M.RowMap());
00810   Mview.rowMap = &targetMap;
00811   Mview.colMap = &(M.ColMap());
00812   Mview.domainMap = &(M.DomainMap());
00813   Mview.importColMap = NULL;
00814 
00815   if (numProcs < 2) {
00816     if (Mview.numRemote > 0) {
00817       cerr << "EpetraExt::MatrixMatrix::Multiply ERROR, numProcs < 2 but "
00818      << "attempting to import remote matrix rows."<<endl;
00819       return(-1);
00820     }
00821 
00822     //If only one processor we don't need to import any remote rows, so return.
00823     return(0);
00824   }
00825 
00826   //
00827   //Now we will import the needed remote rows of M, if the global maximum
00828   //value of numRemote is greater than 0.
00829   //
00830 
00831   int globalMaxNumRemote = 0;
00832   Mrowmap.Comm().MaxAll(&Mview.numRemote, &globalMaxNumRemote, 1);
00833 
00834   if (globalMaxNumRemote > 0) {
00835     //Create a map that describes the remote rows of M that we need.
00836 
00837     int* MremoteRows = Mview.numRemote>0 ? new int[Mview.numRemote] : NULL;
00838     int offset = 0;
00839     for(i=0; i<Mview.numRows; ++i) {
00840       if (Mview.remote[i]) {
00841   MremoteRows[offset++] = Mrows[i];
00842       }
00843     }
00844 
00845     Epetra_Map MremoteRowMap(-1, Mview.numRemote, MremoteRows,
00846            Mrowmap.IndexBase(), Mrowmap.Comm());
00847 
00848     //Create an importer with target-map MremoteRowMap and 
00849     //source-map Mrowmap.
00850     Epetra_Import importer(MremoteRowMap, Mrowmap);
00851 
00852     //Now create a new matrix into which we can import the remote rows of M
00853     //that we need.
00854     Mview.importMatrix = new Epetra_CrsMatrix(Copy, MremoteRowMap, 1);
00855 
00856     EPETRA_CHK_ERR( Mview.importMatrix->Import(M, importer, Insert) );
00857 
00858     EPETRA_CHK_ERR( Mview.importMatrix->FillComplete(M.DomainMap(), M.RangeMap()) );
00859 
00860     //Finally, use the freshly imported data to fill in the gaps in our views
00861     //of rows of M.
00862     for(i=0; i<Mview.numRows; ++i) {
00863       if (Mview.remote[i]) {
00864   int importLID = MremoteRowMap.LID(Mrows[i]);
00865   EPETRA_CHK_ERR( Mview.importMatrix->ExtractMyRowView(importLID,
00866               Mview.numEntriesPerRow[i],
00867               Mview.values[i],
00868               Mview.indices[i]) );
00869       }
00870     }
00871 
00872     Mview.importColMap = &(Mview.importMatrix->ColMap());
00873 
00874     delete [] MremoteRows;
00875   }
00876 
00877   return(0);
00878 }
00879 
00880 int distribute_list(const Epetra_Comm& Comm,
00881                     int lenSendList,
00882                     const int* sendList,
00883                     int& maxSendLen,
00884                     int*& recvList)
00885 {
00886   maxSendLen = 0; 
00887   Comm.MaxAll(&lenSendList, &maxSendLen, 1);
00888   int numProcs = Comm.NumProc();
00889   recvList = new int[numProcs*maxSendLen];
00890   int* send = new int[maxSendLen];
00891   for(int i=0; i<lenSendList; ++i) {
00892     send[i] = sendList[i];
00893   }
00894 
00895   Comm.GatherAll(send, recvList, maxSendLen);
00896   delete [] send;
00897 
00898   return(0);
00899 }
00900 
00901 Epetra_Map* create_map_from_imported_rows(const Epetra_Map* map,
00902             int totalNumSend,
00903             int* sendRows,
00904             int numProcs,
00905             int* numSendPerProc)
00906 {
00907   //Perform sparse all-to-all communication to send the row-GIDs
00908   //in sendRows to appropriate processors according to offset
00909   //information in numSendPerProc.
00910   //Then create and return a map containing the rows that we
00911   //received on the local processor.
00912 
00913   Epetra_Distributor* distributor = map->Comm().CreateDistributor();
00914 
00915   int* sendPIDs = totalNumSend>0 ? new int[totalNumSend] : NULL;
00916   int offset = 0;
00917   for(int i=0; i<numProcs; ++i) {
00918     for(int j=0; j<numSendPerProc[i]; ++j) {
00919       sendPIDs[offset++] = i;
00920     }
00921   }
00922 
00923   int numRecv = 0;
00924   int err = distributor->CreateFromSends(totalNumSend, sendPIDs,
00925            true, numRecv);
00926   assert( err == 0 );
00927 
00928   char* c_recv_objs = numRecv>0 ? new char[numRecv*sizeof(int)] : NULL;
00929   int num_c_recv = numRecv*(int)sizeof(int);
00930 
00931   err = distributor->Do(reinterpret_cast<char*>(sendRows),
00932       (int)sizeof(int), num_c_recv, c_recv_objs);
00933   assert( err == 0 );
00934 
00935   int* recvRows = reinterpret_cast<int*>(c_recv_objs);
00936 
00937   //Now create a map with the rows we've received from other processors.
00938   Epetra_Map* import_rows = new Epetra_Map(-1, numRecv, recvRows,
00939              map->IndexBase(), map->Comm());
00940 
00941   delete [] c_recv_objs;
00942   delete [] sendPIDs;
00943 
00944   delete distributor;
00945 
00946   return( import_rows );
00947 }
00948 
00949 int form_map_union(const Epetra_Map* map1,
00950        const Epetra_Map* map2,
00951        const Epetra_Map*& mapunion)
00952 {
00953   //form the union of two maps
00954 
00955   if (map1 == NULL) {
00956     mapunion = new Epetra_Map(*map2);
00957     return(0);
00958   }
00959 
00960   if (map2 == NULL) {
00961     mapunion = new Epetra_Map(*map1);
00962     return(0);
00963   }
00964 
00965   int map1_len       = map1->NumMyElements();
00966   int* map1_elements = map1->MyGlobalElements();
00967   int map2_len       = map2->NumMyElements();
00968   int* map2_elements = map2->MyGlobalElements();
00969 
00970   int* union_elements = new int[map1_len+map2_len];
00971 
00972   int map1_offset = 0, map2_offset = 0, union_offset = 0;
00973 
00974   while(map1_offset < map1_len && map2_offset < map2_len) {
00975     int map1_elem = map1_elements[map1_offset];
00976     int map2_elem = map2_elements[map2_offset];
00977 
00978     if (map1_elem < map2_elem) {
00979       union_elements[union_offset++] = map1_elem;
00980       ++map1_offset;
00981     }
00982     else if (map1_elem > map2_elem) {
00983       union_elements[union_offset++] = map2_elem;
00984       ++map2_offset;
00985     }
00986     else {
00987       union_elements[union_offset++] = map1_elem;
00988       ++map1_offset;
00989       ++map2_offset;
00990     }
00991   }
00992 
00993   int i;
00994   for(i=map1_offset; i<map1_len; ++i) {
00995     union_elements[union_offset++] = map1_elements[i];
00996   }
00997 
00998   for(i=map2_offset; i<map2_len; ++i) {
00999     union_elements[union_offset++] = map2_elements[i];
01000   }
01001 
01002   mapunion = new Epetra_Map(-1, union_offset, union_elements,
01003           map1->IndexBase(), map1->Comm());
01004 
01005   delete [] union_elements;
01006 
01007   return(0);
01008 }
01009 
01010 Epetra_Map* find_rows_containing_cols(const Epetra_CrsMatrix& M,
01011                                       const Epetra_Map* colmap)
01012 {
01013   //The goal of this function is to find all rows in the matrix M that contain
01014   //column-indices which are in 'colmap'. A map containing those rows is
01015   //returned.
01016 
01017   int numProcs = colmap->Comm().NumProc();
01018   int localProc = colmap->Comm().MyPID();
01019 
01020   if (numProcs < 2) {
01021     Epetra_Map* result_map = NULL;
01022 
01023     int err = form_map_union(&(M.RowMap()), NULL,
01024                              (const Epetra_Map*&)result_map);
01025     if (err != 0) {
01026       return(NULL);
01027     }
01028     return(result_map);
01029   }
01030 
01031   int MnumRows = M.NumMyRows();
01032   int numCols = colmap->NumMyElements();
01033 
01034   int* iwork = new int[numCols+2*numProcs+numProcs*MnumRows];
01035   int iworkOffset = 0;
01036 
01037   int* cols = &(iwork[iworkOffset]); iworkOffset += numCols;
01038 
01039   cols[0] = numCols;
01040   colmap->MyGlobalElements( &(cols[1]) );
01041 
01042   //cols are not necessarily sorted at this point, so we'll make sure
01043   //they are sorted.
01044   Epetra_Util util;
01045   util.Sort(true, numCols, &(cols[1]), 0, NULL, 0, NULL);
01046 
01047   int* all_proc_cols = NULL;
01048   
01049   int max_num_cols;
01050   distribute_list(colmap->Comm(), numCols+1, cols, max_num_cols, all_proc_cols);
01051 
01052   const Epetra_CrsGraph& Mgraph = M.Graph();
01053   const Epetra_Map& Mrowmap = M.RowMap();
01054   const Epetra_Map& Mcolmap = M.ColMap();
01055   int MminMyLID = Mrowmap.MinLID();
01056 
01057   int* procNumCols = &(iwork[iworkOffset]); iworkOffset += numProcs;
01058   int* procNumRows = &(iwork[iworkOffset]); iworkOffset += numProcs;
01059   int* procRows_1D = &(iwork[iworkOffset]);
01060   int** procCols = new int*[numProcs];
01061   int** procRows = new int*[numProcs];
01062   int i, err;
01063   int offset = 0;
01064   for(i=0; i<numProcs; ++i) {
01065     procNumCols[i] = all_proc_cols[offset];
01066     procCols[i] = &(all_proc_cols[offset+1]);
01067     offset += max_num_cols;
01068 
01069     procNumRows[i] = 0;
01070     procRows[i] = &(procRows_1D[i*MnumRows]);
01071   }
01072 
01073   int* Mindices;
01074 
01075   for(int row=0; row<MnumRows; ++row) {
01076     int localRow = MminMyLID+row;
01077     int globalRow = Mrowmap.GID(localRow);
01078     int MnumCols;
01079     err = Mgraph.ExtractMyRowView(localRow, MnumCols, Mindices);
01080     if (err != 0) {
01081       cerr << "proc "<<localProc<<", error in Mgraph.ExtractMyRowView, row "
01082            <<localRow<<endl;
01083       return(NULL);
01084     }
01085 
01086     for(int j=0; j<MnumCols; ++j) {
01087       int colGID = Mcolmap.GID(Mindices[j]);
01088 
01089       for(int p=0; p<numProcs; ++p) {
01090         if (p==localProc) continue;
01091 
01092         int insertPoint;
01093         int foundOffset = Epetra_Util_binary_search(colGID, procCols[p],
01094                                                     procNumCols[p], insertPoint);
01095         if (foundOffset > -1) {
01096           int numRowsP = procNumRows[p];
01097           int* prows = procRows[p];
01098           if (numRowsP < 1 || prows[numRowsP-1] < globalRow) {
01099             prows[numRowsP] = globalRow;
01100             procNumRows[p]++;
01101           }
01102         }
01103       }
01104     }
01105   }
01106 
01107   //Now make the contents of procRows occupy a contiguous section
01108   //of procRows_1D.
01109   offset = procNumRows[0];
01110   for(i=1; i<numProcs; ++i) {
01111     for(int j=0; j<procNumRows[i]; ++j) {
01112       procRows_1D[offset++] = procRows[i][j];
01113     }
01114   }
01115 
01116   int totalNumSend = offset;
01117   //Next we will do a sparse all-to-all communication to send the lists of rows
01118   //to the appropriate processors, and create a map with the rows we've received
01119   //from other processors.
01120   Epetra_Map* recvd_rows =
01121     create_map_from_imported_rows(&Mrowmap, totalNumSend,
01122                                   procRows_1D, numProcs, procNumRows);
01123 
01124   Epetra_Map* result_map = NULL;
01125 
01126   err = form_map_union(&(M.RowMap()), recvd_rows, (const Epetra_Map*&)result_map);
01127   if (err != 0) {
01128     return(NULL);
01129   }
01130 
01131   delete [] iwork;
01132   delete [] procCols;
01133   delete [] procRows;
01134   delete [] all_proc_cols;
01135   delete recvd_rows;
01136 
01137   return(result_map);
01138 }
01139 
01140 int MatrixMatrix::Multiply(const Epetra_CrsMatrix& A,
01141          bool transposeA,
01142          const Epetra_CrsMatrix& B,
01143          bool transposeB,
01144          Epetra_CrsMatrix& C)
01145 {
01146   //
01147   //This method forms the matrix-matrix product C = op(A) * op(B), where
01148   //op(A) == A   if transposeA is false,
01149   //op(A) == A^T if transposeA is true,
01150   //and similarly for op(B).
01151   //
01152 
01153   //A and B should already be Filled.
01154   //(Should we go ahead and call FillComplete() on them if necessary?
01155   // or error out? For now, we choose to error out.)
01156   if (!A.Filled() || !B.Filled()) {
01157     EPETRA_CHK_ERR(-1);
01158   }
01159 
01160   //We're going to refer to the different combinations of op(A) and op(B)
01161   //as scenario 1 through 4.
01162 
01163   int scenario = 1;//A*B
01164   if (transposeB && !transposeA) scenario = 2;//A*B^T
01165   if (transposeA && !transposeB) scenario = 3;//A^T*B
01166   if (transposeA && transposeB)  scenario = 4;//A^T*B^T
01167 
01168   //now check size compatibility
01169   int Aouter = transposeA ? A.NumGlobalCols() : A.NumGlobalRows();
01170   int Bouter = transposeB ? B.NumGlobalRows() : B.NumGlobalCols();
01171   int Ainner = transposeA ? A.NumGlobalRows() : A.NumGlobalCols();
01172   int Binner = transposeB ? B.NumGlobalCols() : B.NumGlobalRows();
01173   if (Ainner != Binner) {
01174     cerr << "MatrixMatrix::Multiply: ERROR, inner dimensions of op(A) and op(B) "
01175          << "must match for matrix-matrix product. op(A) is "
01176          <<Aouter<<"x"<<Ainner << ", op(B) is "<<Binner<<"x"<<Bouter<<endl;
01177     return(-1);
01178   }
01179 
01180   //The result matrix C must at least have a row-map that reflects the
01181   //correct row-size. Don't check the number of columns because rectangular
01182   //matrices which were constructed with only one map can still end up
01183   //having the correct capacity and dimensions when filled.
01184   if (Aouter > C.NumGlobalRows()) {
01185     cerr << "MatrixMatrix::Multiply: ERROR, dimensions of result C must "
01186          << "match dimensions of op(A) * op(B). C has "<<C.NumGlobalRows()
01187          << " rows, should have at least "<<Aouter << endl;
01188     return(-1);
01189   }
01190 
01191   //It doesn't matter whether C is already Filled or not. If it is already
01192   //Filled, it must have space allocated for the positions that will be
01193   //referenced in forming C = op(A)*op(B). If it doesn't have enough space,
01194   //we'll error out later when trying to store result values.
01195 
01196   //We're going to need to import remotely-owned sections of A and/or B
01197   //if more than 1 processor is performing this run, depending on the scenario.
01198   int numProcs = A.Comm().NumProc();
01199 
01200   //If we are to use the transpose of A and/or B, we'll need to be able to 
01201   //access, on the local processor, all rows that contain column-indices in
01202   //the domain-map.
01203   const Epetra_Map* domainMap_A = &(A.DomainMap());
01204   const Epetra_Map* domainMap_B = &(B.DomainMap());
01205 
01206   const Epetra_Map* rowmap_A = &(A.RowMap());
01207   const Epetra_Map* rowmap_B = &(B.RowMap());
01208 
01209   //Declare some 'work-space' maps which may be created depending on
01210   //the scenario, and which will be deleted before exiting this function.
01211   const Epetra_Map* workmap1 = NULL;
01212   const Epetra_Map* workmap2 = NULL;
01213   const Epetra_Map* mapunion1 = NULL;
01214 
01215   //Declare a couple of structs that will be used to hold views of the data
01216   //of A and B, to be used for fast access during the matrix-multiplication.
01217   CrsMatrixStruct Aview;
01218   CrsMatrixStruct Bview;
01219 
01220   const Epetra_Map* targetMap_A = rowmap_A;
01221   const Epetra_Map* targetMap_B = rowmap_B;
01222 
01223   if (numProcs > 1) {
01224     //If op(A) = A^T, find all rows of A that contain column-indices in the
01225     //local portion of the domain-map. (We'll import any remote rows
01226     //that fit this criteria onto the local processor.)
01227     if (transposeA) {
01228       workmap1 = find_rows_containing_cols(A, domainMap_A);
01229       targetMap_A = workmap1;
01230     }
01231   }
01232 
01233   //Now import any needed remote rows and populate the Aview struct.
01234   EPETRA_CHK_ERR( import_and_extract_views(A, *targetMap_A, Aview) );
01235 
01236   //We will also need local access to all rows of B that correspond to the
01237   //column-map of op(A).
01238   if (numProcs > 1) {
01239     const Epetra_Map* colmap_op_A = NULL;
01240     if (transposeA) {
01241       colmap_op_A = targetMap_A;
01242     }
01243     else {
01244       colmap_op_A = &(A.ColMap());
01245     }
01246 
01247     targetMap_B = colmap_op_A;
01248 
01249     //If op(B) = B^T, find all rows of B that contain column-indices in the
01250     //local-portion of the domain-map, or in the column-map of op(A).
01251     //We'll import any remote rows that fit this criteria onto the
01252     //local processor.
01253     if (transposeB) {
01254       EPETRA_CHK_ERR( form_map_union(colmap_op_A, domainMap_B, mapunion1) );
01255       workmap2 = find_rows_containing_cols(B, mapunion1);
01256       targetMap_B = workmap2;
01257     }
01258   }
01259 
01260   //Now import any needed remote rows and populate the Bview struct.
01261   EPETRA_CHK_ERR( import_and_extract_views(B, *targetMap_B, Bview) );
01262 
01263   //zero the result matrix before we start the calculations.
01264   EPETRA_CHK_ERR( C.PutScalar(0.0) );
01265 
01266 
01267   //Now call the appropriate method to perform the actual multiplication.
01268 
01269   switch(scenario) {
01270   case 1:    EPETRA_CHK_ERR( mult_A_B(Aview, Bview, C) );
01271     break;
01272   case 2:    EPETRA_CHK_ERR( mult_A_Btrans(Aview, Bview, C) );
01273     break;
01274   case 3:    EPETRA_CHK_ERR( mult_Atrans_B(Aview, Bview, C) );
01275     break;
01276   case 4:    EPETRA_CHK_ERR( mult_Atrans_Btrans(Aview, Bview, C) );
01277     break;
01278   }
01279 
01280 
01281   //We'll call FillComplete on the C matrix before we exit, and give
01282   //it a domain-map and a range-map.
01283   //The domain-map will be the domain-map of B, unless
01284   //op(B)==transpose(B), in which case the range-map of B will be used.
01285   //The range-map will be the range-map of A, unless
01286   //op(A)==transpose(A), in which case the domain-map of A will be used.
01287 
01288   const Epetra_Map* domainmap =
01289     transposeB ? &(B.RangeMap()) : &(B.DomainMap());
01290 
01291   const Epetra_Map* rangemap =
01292     transposeA ? &(A.DomainMap()) : &(A.RangeMap());
01293 
01294   if (!C.Filled()) {
01295     EPETRA_CHK_ERR( C.FillComplete(*domainmap, *rangemap) );
01296   }
01297 
01298 
01299   //Finally, delete the objects that were potentially created
01300   //during the course of importing remote sections of A and B.
01301 
01302   delete mapunion1; mapunion1 = NULL;
01303   delete workmap1; workmap1 = NULL;
01304   delete workmap2; workmap2 = NULL;
01305 
01306   return(0);
01307 }
01308 
01309 int MatrixMatrix::Add(const Epetra_CrsMatrix& A,
01310                       bool transposeA,
01311                       double scalarA,
01312                       Epetra_CrsMatrix& B,
01313                       double scalarB )
01314 {
01315   //
01316   //This method forms the matrix-matrix sum B = scalarA * op(A) + scalarB * B, where
01317 
01318   //A should already be Filled. It doesn't matter whether B is
01319   //already Filled, but if it is, then its graph must already contain
01320   //all nonzero locations that will be referenced in forming the
01321   //sum.
01322 
01323   if (!A.Filled() ) EPETRA_CHK_ERR(-1);
01324 
01325   //explicit tranpose A formed as necessary
01326   Epetra_CrsMatrix * Aprime = 0;
01327   EpetraExt::RowMatrix_Transpose * Atrans = 0;
01328   if( transposeA )
01329   {
01330     Atrans = new EpetraExt::RowMatrix_Transpose();
01331     Aprime = &(dynamic_cast<Epetra_CrsMatrix&>(((*Atrans)(const_cast<Epetra_CrsMatrix&>(A)))));
01332   }
01333   else
01334     Aprime = const_cast<Epetra_CrsMatrix*>(&A);
01335 
01336   //Initialize if B already filled
01337   if( B.Filled() )
01338     EPETRA_CHK_ERR( B.Scale( scalarB ) );
01339 
01340   //Loop over B's rows and sum into
01341   int MaxNumEntries = EPETRA_MAX( A.MaxNumEntries(), B.MaxNumEntries() );
01342   int NumEntries;
01343   int * Indices = new int[MaxNumEntries];
01344   double * Values = new double[MaxNumEntries];
01345 
01346   int NumMyRows = B.NumMyRows();
01347   int Row, err;
01348 
01349   if( scalarA )
01350   {
01351     for( int i = 0; i < NumMyRows; ++i )
01352     {
01353       Row = B.GRID(i);
01354       EPETRA_CHK_ERR( A.ExtractGlobalRowCopy( Row, MaxNumEntries, NumEntries, Values, Indices ) );
01355       if( scalarA != 1.0 )
01356         for( int j = 0; j < NumEntries; ++j ) Values[j] *= scalarA;
01357       if( B.Filled() ) {//Sum In Values
01358         err = B.SumIntoGlobalValues( Row, NumEntries, Values, Indices );
01359         assert( err == 0 );
01360       }
01361       else {
01362         err = B.InsertGlobalValues( Row, NumEntries, Values, Indices );
01363         assert( err == 0 || err == 1 || err == 3 );
01364       }
01365     }
01366   }
01367 
01368   delete [] Indices;
01369   delete [] Values;
01370 
01371   if( Atrans ) delete Atrans;
01372 
01373   if( !B.Filled() ) 
01374     EPETRA_CHK_ERR( B.FillComplete() );
01375 
01376   return(0);
01377 }
01378 
01379 } // namespace EpetraExt
01380 

Generated on Thu Sep 18 12:31:43 2008 for EpetraExt by doxygen 1.3.9.1