SparseMatrix.h
Upload User: hkdiguang
Upload Date: 2013-05-12
Package Size: 105k
Code Size: 6k
Development Platform:

Unix_Linux

  1. // -*- c++ -*-
  2. // Copyright (C) 2003 Suvrit Sra (suvrit@cs.utexas.edu)
  3. // This program is free software; you can redistribute it and/or
  4. // modify it under the terms of the GNU General Public License
  5. // as published by the Free Software Foundation; either version 2
  6. // of the License, or (at your option) any later version.
  7. // This program is distributed in the hope that it will be useful,
  8. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10. // GNU General Public License for more details.
  11. // You should have received a copy of the GNU General Public License
  12. // along with this program; if not, write to the Free Software
  13. // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
  14. #ifndef SPARSEMATRIX_H
  15. #define SPARSEMATRIX_H
  16. #include <iostream>
  17. #include <fstream>
  18. #include <string>
  19. #include <gsl/gsl_vector.h>
  20. #include <gsl/gsl_matrix.h>
  21. #include "objective.h"
  22. /**
  23.  * CLASS: SparseMatrix
  24.  * Author: Suvrit Sra
  25.  * Implemented to deal with CCS files that we use a lot here (cs.utexas)
  26.  * IDEA: I want to implement a function called 'apply' for the sparsematrix
  27.  * that allows a sort of extension to the sparsematrix functionality while
  28.  * still not sacrificing efficiency....
  29.  * for e.g. this->apply(compute_obj, arg) will apply the function
  30.  * compute_obj to each element of the sparse matrix. This is a fairly
  31.  * restricted function but can be useful without exposing all the details
  32.  * of the sparsematrix or without coming back and modifying the code of the
  33.  * sparsematrix class once it is stable. Still this idea will pose some
  34.  * implementation challenges so i will defer it for now.
  35.  *
  36.  */
  37. class SparseMatrix {
  38. private:
  39.   int m_rows, m_cols;
  40.   /* number of non-zeros */
  41.   int m_nz;
  42.      
  43.   /* file name associated with this matrix */
  44.   std::string fname;
  45.   std::string txx;
  46.      
  47.   /* The actual non-zeroes themselves */
  48.   double* m_val;
  49.   /* Colptrs for CCs structure */
  50.   long* m_colptrs;
  51.   
  52.   /* Row indices for CCs structure */
  53.   long* m_rowindx;
  54.   
  55.   /* Is the data in externally allocated arrays */
  56.   bool m_is_external;
  57.   /* Hack needed to do correct comp. of object. fn*/
  58.   bool first_time;
  59.   /* We compute norm only once to save time */
  60.   double fnormA;
  61.   /* Private function to carry out computation of fnorm of matrx */
  62.   void compute_fnorm();
  63.   bool   fnorm_avail; // is it avail or needs to be computed.
  64. public:
  65.   SparseMatrix() {
  66.     m_rows = m_cols = m_nz = 0; first_time = false; fnorm_avail = false;
  67.   }
  68.   SparseMatrix(int r, int c, int n) { 
  69.     //assert (r > 0 and c > 0 and n >= 0);
  70.     m_rows = r; 
  71.     m_cols = c;
  72.     m_nz   = n;
  73.     // Allocate the arrays
  74.     m_colptrs = new long[c+1];
  75.     m_rowindx = new long[n];
  76.     m_val     = new double[n];
  77.     m_is_external = false;
  78.     first_time = false;
  79.     fnorm_avail = false;
  80.     fname = "";
  81.     txx   = "";
  82.   }
  83.   
  84.   //SparseMatrix (int* c, int* r, double* v, int ro, int co, int nz) :
  85.   //  m_colptrs(c), m_rowindx(r), m_val(v),
  86.   //  fname("factorization"), txx(""), m_is_external(true)
  87.   //{ m_rows = ro; m_cols = co; m_nz = nz; }
  88.   ~SparseMatrix() {
  89.     //cout << "~SparseMatrix()" << endl;
  90.     delete[] m_colptrs;
  91.     delete[] m_rowindx;
  92.     delete[] m_val;
  93.   }
  94.   SparseMatrix* clone();
  95.   bool isExternal() const { return m_is_external;}
  96.   int  read_ccs_file(char*);
  97.   void makefull(gsl_matrix*);
  98.   // This operator multiplies This . x 
  99.   gsl_vector* operator * (const gsl_vector* x);
  100.   //  v'col dot prod.
  101.   double dot(int col, gsl_vector* v);
  102.   // do a saxpy
  103.   void saxpy(double alpha, int col, gsl_vector* v);
  104.   // This does a dot product of col i with col j
  105.   double col_dotprod(int i, int j);
  106.   // This computes the cosine between col i and and col j.
  107.   double col_cosine(int i, int j);
  108.   // This computes the cosine of col i with input vector v
  109.   double col_cosine2(int i, double* v);
  110.   // Calculate the norm of column (i)
  111.   double col_norm(int i);
  112.   
  113.   // return the frob norm of matrix
  114.   double fnorm()  { if (fnorm_avail) return fnormA; else compute_fnorm(); return fnormA;}
  115.   double col_diff(int i, int j);
  116.   double col_delta(int, double*);
  117.   std::string getFileName() const { return fname;}
  118.   std::string getTxx() const {return txx;}
  119.   // Fill input vector with column of A
  120.   void getcol(int, double*);
  121.   void output_matlab(std::ofstream&, float);
  122.   // Getting bulky with all these functions....well well
  123.   double compute_obj(gsl_matrix*, tObjtype);
  124.   // This calculates the transpose of this. Not yet implemented.
  125.   SparseMatrix transpose();
  126.      
  127.   // This multiplies A^T . x
  128.   gsl_vector* tran_times (const gsl_vector* x);
  129.   // This calculates the matrix product this . B
  130.   SparseMatrix& operator * (const SparseMatrix& B);
  131.   // Write out matrix to output stream
  132.   void print(std::ostream&);
  133.   void printmat(std::ostream&);
  134.   void dotdiv(SparseMatrix* l, SparseMatrix* r);
  135.   void dotdiv(gsl_matrix*, SparseMatrix* r);
  136.   // this * gsl_matrix
  137.   void lmult(gsl_matrix*, gsl_matrix*,
  138.     bool trans1=false, bool trans2=false); 
  139.   // gsl_matrix * this
  140.   void rmult(gsl_matrix*, gsl_matrix*,
  141.     bool trans1=false, bool trans2=false);
  142.   inline void set(int i, int j, double val) {
  143.     for (int t = m_colptrs[j]; t < m_colptrs[j+1]; t++) {
  144.       if (m_rowindx[t] == i) {
  145. m_val[t] = val;
  146. return;
  147.       }
  148.     }
  149.     // If we are here a 0 is being destroyed and that will make life
  150.     // bad. as of now just warn
  151.     std::cerr << "A zero being destroyed at (" << i << ", " << j << ")" << std::endl;
  152.   }
  153.   
  154.   void debuginfo();
  155.   inline double dataAt(int i, int j) const {
  156.     //assert(i >= 0 && i < numRows() && j >= 0 && j < numCols());
  157.     for (int t = m_colptrs[j]; t < m_colptrs[j+1]; t++) {
  158.       if (m_rowindx[t] == i) return m_val[t];
  159.     }
  160.     return 0;
  161.   }
  162.   inline double operator () (int i, int j) const  {
  163.     //assert(i >= 0 && i < numRows() && j >= 0 && j < numCols());
  164.     for (int t = m_colptrs[j]; t < m_colptrs[j+1]; t++)
  165.       if (m_rowindx[t] == i) return m_val[t];
  166.     return 0;
  167.   }
  168.   long* getPointr() { return m_colptrs;}
  169.   long* getIndx() { return m_rowindx;}
  170.   double* getVal() { return m_val;}
  171.   bool getObjFlag() const { return first_time;}
  172.   void setObjFlag(bool f) { first_time = f;}
  173.   int numRows() const { return m_rows;}
  174.   int numCols() const { return m_cols;}
  175.   int numNz  () const { return m_nz;  }
  176.   int read_ccs_file(char* fname, char* txx);
  177. };
  178. #endif