/* SHOGUN v3.2.0 */
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Fernando José Iglesias García 00008 * Written (W) 2010,2012 Soeren Sonnenburg 00009 * Copyright (C) 2010 Berlin Institute of Technology 00010 * Copyright (C) 2012 Soeren Sonnenburg 00011 */ 00012 00013 #ifndef __SGSPARSEMATRIX_H__ 00014 #define __SGSPARSEMATRIX_H__ 00015 00016 #include <shogun/lib/common.h> 00017 #include <shogun/lib/DataType.h> 00018 #include <shogun/lib/SGSparseVector.h> 00019 #include <shogun/lib/SGReferencedData.h> 00020 #include <shogun/io/LibSVMFile.h> 00021 00022 namespace shogun 00023 { 00024 00025 template <class T> class SGSparseVector; 00026 template<class T> class SGMatrix; 00027 class CFile; 00028 class CLibSVMFile; 00029 class CRegressionLabels; 00030 00032 template <class T> class SGSparseMatrix : public SGReferencedData 00033 { 00034 public: 00036 SGSparseMatrix(); 00037 00039 SGSparseMatrix(SGSparseVector<T>* vecs, index_t num_feat, 00040 index_t num_vec, bool ref_counting=true); 00041 00043 SGSparseMatrix(index_t num_feat, index_t num_vec, bool ref_counting=true); 00044 00049 SGSparseMatrix(SGMatrix<T> dense); 00050 00052 SGSparseMatrix(const SGSparseMatrix &orig); 00053 00055 virtual ~SGSparseMatrix(); 00056 00058 inline const SGSparseVector<T>& operator[](index_t index) const 00059 { 00060 return sparse_matrix[index]; 00061 } 00062 00064 inline SGSparseVector<T>& operator[](index_t index) 00065 { 00066 return sparse_matrix[index]; 00067 } 00068 00074 inline SGSparseMatrix<T> get() 00075 { 00076 return *this; 00077 } 00078 00083 const SGVector<T> operator*(SGVector<T> v) const 00084 { 00085 SGVector<T> result(num_vectors); 00086 REQUIRE(v.vlen==num_features, 00087 "Dimension mismatch! 
%d vs %d\n", 00088 v.vlen, num_features); 00089 for (index_t i=0; i<num_vectors; ++i) 00090 result[i]=sparse_matrix[i].dense_dot(1.0, v.vector, v.vlen, 0.0); 00091 00092 return result; 00093 } 00094 00099 template<class ST> const SGVector<T> operator*(SGVector<ST> v) const; 00100 00105 inline const T operator()(index_t i_row, index_t i_col) const 00106 { 00107 REQUIRE(i_row>=0, "index %d negative!\n", i_row); 00108 REQUIRE(i_col>=0, "index %d negative!\n", i_col); 00109 REQUIRE(i_row<num_vectors, "index should be less than %d, %d provided!\n", 00110 num_vectors, i_row); 00111 REQUIRE(i_col<num_features, "index should be less than %d, %d provided!\n", 00112 num_features, i_col); 00113 00114 for (index_t i=0; i<sparse_matrix[i_row].num_feat_entries; ++i) 00115 { 00116 if (i_col==sparse_matrix[i_row].features[i].feat_index) 00117 return sparse_matrix[i_row].features[i].entry; 00118 } 00119 return 0; 00120 } 00121 00126 inline T& operator()(index_t i_row, index_t i_col) 00127 { 00128 REQUIRE(i_row>=0, "index %d negative!\n", i_row); 00129 REQUIRE(i_col>=0, "index %d negative!\n", i_col); 00130 REQUIRE(i_row<num_vectors, "index should be less than %d, %d provided!\n", 00131 num_vectors, i_row); 00132 REQUIRE(i_col<num_features, "index should be less than %d, %d provided!\n", 00133 num_features, i_col); 00134 00135 for (index_t i=0; i<sparse_matrix[i_row].num_feat_entries; ++i) 00136 { 00137 if (i_col==sparse_matrix[i_row].features[i].feat_index) 00138 return sparse_matrix[i_row].features[i].entry; 00139 } 00140 index_t j=sparse_matrix[i_row].num_feat_entries; 00141 sparse_matrix[i_row].num_feat_entries=j+1; 00142 sparse_matrix[i_row].features=SG_REALLOC(SGSparseVectorEntry<T>, 00143 sparse_matrix[i_row].features, j, j+1); 00144 sparse_matrix[i_row].features[j].feat_index=i_col; 00145 sparse_matrix[i_row].features[j].entry=static_cast<T>(0); 00146 00147 return sparse_matrix[i_row].features[j].entry; 00148 } 00149 00154 void load(CFile* loader); 00155 00163 
SGVector<float64_t> load_with_labels(CLibSVMFile* libsvm_file, bool do_sort_features=true); 00164 00169 void save(CFile* saver); 00170 00176 void save_with_labels(CLibSVMFile* saver, SGVector<float64_t> labels); 00177 00179 SGSparseMatrix<T> get_transposed(); 00180 00185 void from_dense(SGMatrix<T> full); 00186 00188 void sort_features(); 00189 00190 protected: 00191 00193 virtual void copy_data(const SGReferencedData& orig); 00194 00196 virtual void init_data(); 00197 00199 virtual void free_data(); 00200 00201 public: 00202 00204 index_t num_vectors; 00205 00207 index_t num_features; 00208 00210 SGSparseVector<T>* sparse_matrix; 00211 00212 }; 00213 } 00214 #endif // __SGSPARSEMATRIX_H__