// SHOGUN v3.2.0 -- src/shogun/multiclass/QDA.cpp
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Fernando José Iglesias García 00008 * Copyright (C) 2012 Fernando José Iglesias García 00009 */ 00010 00011 #include <shogun/lib/common.h> 00012 00013 #ifdef HAVE_EIGEN3 00014 00015 #include <shogun/multiclass/QDA.h> 00016 #include <shogun/machine/NativeMulticlassMachine.h> 00017 #include <shogun/features/Features.h> 00018 #include <shogun/labels/Labels.h> 00019 #include <shogun/labels/MulticlassLabels.h> 00020 #include <shogun/mathematics/Math.h> 00021 00022 #include <shogun/mathematics/eigen3.h> 00023 00024 using namespace shogun; 00025 using namespace Eigen; 00026 00027 CQDA::CQDA(float64_t tolerance, bool store_covs) 00028 : CNativeMulticlassMachine(), m_tolerance(tolerance), 00029 m_store_covs(store_covs), m_num_classes(0), m_dim(0) 00030 { 00031 init(); 00032 } 00033 00034 CQDA::CQDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance, bool store_covs) 00035 : CNativeMulticlassMachine(), m_tolerance(tolerance), m_store_covs(store_covs), m_num_classes(0), m_dim(0) 00036 { 00037 init(); 00038 set_features(traindat); 00039 set_labels(trainlab); 00040 } 00041 00042 CQDA::~CQDA() 00043 { 00044 SG_UNREF(m_features); 00045 00046 cleanup(); 00047 } 00048 00049 void CQDA::init() 00050 { 00051 SG_ADD(&m_tolerance, "m_tolerance", "Tolerance member.", MS_AVAILABLE); 00052 SG_ADD(&m_store_covs, "m_store_covs", "Store covariances member", MS_NOT_AVAILABLE); 00053 SG_ADD((CSGObject**) &m_features, "m_features", "Feature object.", MS_NOT_AVAILABLE); 00054 SG_ADD(&m_means, "m_means", "Mean vectors list", MS_NOT_AVAILABLE); 00055 SG_ADD(&m_slog, "m_slog", "Vector used in classification", MS_NOT_AVAILABLE); 00056 00057 //TODO include SGNDArray 
objects for serialization 00058 00059 m_features = NULL; 00060 } 00061 00062 void CQDA::cleanup() 00063 { 00064 m_means=SGMatrix<float64_t>(); 00065 00066 m_num_classes = 0; 00067 } 00068 00069 CMulticlassLabels* CQDA::apply_multiclass(CFeatures* data) 00070 { 00071 if (data) 00072 { 00073 if (!data->has_property(FP_DOT)) 00074 SG_ERROR("Specified features are not of type CDotFeatures\n") 00075 00076 set_features((CDotFeatures*) data); 00077 } 00078 00079 if ( !m_features ) 00080 return NULL; 00081 00082 int32_t num_vecs = m_features->get_num_vectors(); 00083 ASSERT(num_vecs > 0) 00084 ASSERT( m_dim == m_features->get_dim_feature_space() ) 00085 00086 CDenseFeatures< float64_t >* rf = (CDenseFeatures< float64_t >*) m_features; 00087 00088 MatrixXd X(num_vecs, m_dim); 00089 MatrixXd A(num_vecs, m_dim); 00090 VectorXd norm2(num_vecs*m_num_classes); 00091 norm2.setZero(); 00092 00093 int32_t vlen; 00094 bool vfree; 00095 float64_t* vec; 00096 for (int k = 0; k < m_num_classes; k++) 00097 { 00098 // X = features - means 00099 for (int i = 0; i < num_vecs; i++) 00100 { 00101 vec = rf->get_feature_vector(i, vlen, vfree); 00102 ASSERT(vec) 00103 00104 Map< VectorXd > Evec(vec,vlen); 00105 Map< VectorXd > Em_means_col(m_means.get_column_vector(k), m_dim); 00106 00107 X.row(i) = Evec - Em_means_col; 00108 00109 rf->free_feature_vector(vec, i, vfree); 00110 } 00111 00112 Map< MatrixXd > Em_M(m_M.get_matrix(k), m_dim, m_dim); 00113 A = X*Em_M; 00114 00115 for (int i = 0; i < num_vecs; i++) 00116 norm2(i + k*num_vecs) = A.row(i).array().square().sum(); 00117 00118 #ifdef DEBUG_QDA 00119 SG_PRINT("\n>>> Displaying A ...\n") 00120 SGMatrix< float64_t >::display_matrix(A.data(), num_vecs, m_dim); 00121 #endif 00122 } 00123 00124 for (int i = 0; i < num_vecs; i++) 00125 for (int k = 0; k < m_num_classes; k++) 00126 { 00127 norm2[i + k*num_vecs] += m_slog[k]; 00128 norm2[i + k*num_vecs] *= -0.5; 00129 } 00130 00131 #ifdef DEBUG_QDA 00132 SG_PRINT("\n>>> Displaying norm2 ...\n") 
00133 SGMatrix< float64_t >::display_matrix(norm2.data(), num_vecs, m_num_classes); 00134 #endif 00135 00136 CMulticlassLabels* out = new CMulticlassLabels(num_vecs); 00137 00138 for (int i = 0 ; i < num_vecs; i++) 00139 out->set_label(i, SGVector<float64_t>::arg_max(norm2.data()+i, num_vecs, m_num_classes)); 00140 00141 return out; 00142 } 00143 00144 bool CQDA::train_machine(CFeatures* data) 00145 { 00146 if (!m_labels) 00147 SG_ERROR("No labels allocated in QDA training\n") 00148 00149 if ( data ) 00150 { 00151 if (!data->has_property(FP_DOT)) 00152 SG_ERROR("Speficied features are not of type CDotFeatures\n") 00153 00154 set_features((CDotFeatures*) data); 00155 } 00156 00157 if (!m_features) 00158 SG_ERROR("No features allocated in QDA training\n") 00159 00160 SGVector< int32_t > train_labels = ((CMulticlassLabels*) m_labels)->get_int_labels(); 00161 00162 if (!train_labels.vector) 00163 SG_ERROR("No train_labels allocated in QDA training\n") 00164 00165 cleanup(); 00166 00167 m_num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00168 m_dim = m_features->get_dim_feature_space(); 00169 int32_t num_vec = m_features->get_num_vectors(); 00170 00171 if (num_vec != train_labels.vlen) 00172 SG_ERROR("Dimension mismatch between features and labels in QDA training") 00173 00174 int32_t* class_idxs = SG_MALLOC(int32_t, num_vec*m_num_classes); // number of examples of each class 00175 int32_t* class_nums = SG_MALLOC(int32_t, m_num_classes); 00176 memset(class_nums, 0, m_num_classes*sizeof(int32_t)); 00177 int32_t class_idx; 00178 00179 for (int i = 0; i < train_labels.vlen; i++) 00180 { 00181 class_idx = train_labels.vector[i]; 00182 00183 if (class_idx < 0 || class_idx >= m_num_classes) 00184 { 00185 SG_ERROR("found label out of {0, 1, 2, ..., num_classes-1}...") 00186 return false; 00187 } 00188 else 00189 { 00190 class_idxs[ class_idx*num_vec + class_nums[class_idx]++ ] = i; 00191 } 00192 } 00193 00194 for (int i = 0; i < m_num_classes; i++) 00195 { 
00196 if (class_nums[i] <= 0) 00197 { 00198 SG_ERROR("What? One class with no elements\n") 00199 return false; 00200 } 00201 } 00202 00203 if (m_store_covs) 00204 { 00205 // cov_dims will be free in m_covs.destroy_ndarray() 00206 index_t * cov_dims = SG_MALLOC(index_t, 3); 00207 cov_dims[0] = m_dim; 00208 cov_dims[1] = m_dim; 00209 cov_dims[2] = m_num_classes; 00210 m_covs = SGNDArray< float64_t >(cov_dims, 3); 00211 } 00212 00213 m_means = SGMatrix< float64_t >(m_dim, m_num_classes, true); 00214 SGMatrix< float64_t > scalings = SGMatrix< float64_t >(m_dim, m_num_classes); 00215 00216 // rot_dims will be freed in rotations.destroy_ndarray() 00217 index_t* rot_dims = SG_MALLOC(index_t, 3); 00218 rot_dims[0] = m_dim; 00219 rot_dims[1] = m_dim; 00220 rot_dims[2] = m_num_classes; 00221 SGNDArray< float64_t > rotations = SGNDArray< float64_t >(rot_dims, 3); 00222 00223 CDenseFeatures< float64_t >* rf = (CDenseFeatures< float64_t >*) m_features; 00224 00225 m_means.zero(); 00226 00227 int32_t vlen; 00228 bool vfree; 00229 float64_t* vec; 00230 for (int k = 0; k < m_num_classes; k++) 00231 { 00232 MatrixXd buffer(class_nums[k], m_dim); 00233 Map< VectorXd > Em_means(m_means.get_column_vector(k), m_dim); 00234 for (int i = 0; i < class_nums[k]; i++) 00235 { 00236 vec = rf->get_feature_vector(class_idxs[k*num_vec + i], vlen, vfree); 00237 ASSERT(vec) 00238 00239 Map< VectorXd > Evec(vec, vlen); 00240 Em_means += Evec; 00241 buffer.row(i) = Evec; 00242 00243 rf->free_feature_vector(vec, class_idxs[k*num_vec + i], vfree); 00244 } 00245 00246 Em_means /= class_nums[k]; 00247 00248 for (int i = 0; i < class_nums[k]; i++) 00249 buffer.row(i) -= Em_means; 00250 00251 // SVD 00252 float64_t * col = scalings.get_column_vector(k); 00253 float64_t * rot_mat = rotations.get_matrix(k); 00254 00255 Eigen::JacobiSVD<MatrixXd> eSvd; 00256 eSvd.compute(buffer,Eigen::ComputeFullV); 00257 memcpy(col, eSvd.singularValues().data(), m_dim*sizeof(float64_t)); 00258 memcpy(rot_mat, 
eSvd.matrixV().data(), m_dim*m_dim*sizeof(float64_t)); 00259 00260 SGVector<float64_t>::vector_multiply(col, col, col, m_dim); 00261 SGVector<float64_t>::scale_vector(1.0/(class_nums[k]-1), col, m_dim); 00262 rotations.transpose_matrix(k); 00263 00264 if (m_store_covs) 00265 { 00266 SGMatrix< float64_t > M(m_dim ,m_dim); 00267 MatrixXd MEig = Map<MatrixXd>(rot_mat,m_dim,m_dim); 00268 for (int i = 0; i < m_dim; i++) 00269 for (int j = 0; j < m_dim; j++) 00270 M(i,j)*=scalings[k*m_dim + j]; 00271 MatrixXd rotE = Map<MatrixXd>(rot_mat,m_dim,m_dim); 00272 MatrixXd resE(m_dim,m_dim); 00273 resE = MEig * rotE.transpose(); 00274 memcpy(m_covs.get_matrix(k),resE.data(),m_dim*m_dim*sizeof(float64_t)); 00275 } 00276 } 00277 00278 /* Computation of terms required for classification */ 00279 SGVector< float32_t > sinvsqrt(m_dim); 00280 00281 // M_dims will be freed in m_M.destroy_ndarray() 00282 index_t* M_dims = SG_MALLOC(index_t, 3); 00283 M_dims[0] = m_dim; 00284 M_dims[1] = m_dim; 00285 M_dims[2] = m_num_classes; 00286 m_M = SGNDArray< float64_t >(M_dims, 3); 00287 00288 m_slog = SGVector< float32_t >(m_num_classes); 00289 m_slog.zero(); 00290 00291 index_t idx = 0; 00292 for (int k = 0; k < m_num_classes; k++) 00293 { 00294 for (int j = 0; j < m_dim; j++) 00295 { 00296 sinvsqrt[j] = 1.0 / CMath::sqrt(scalings[k*m_dim + j]); 00297 m_slog[k] += CMath::log(scalings[k*m_dim + j]); 00298 } 00299 00300 for (int i = 0; i < m_dim; i++) 00301 for (int j = 0; j < m_dim; j++) 00302 { 00303 idx = k*m_dim*m_dim + i + j*m_dim; 00304 m_M[idx] = rotations[idx] * sinvsqrt[j]; 00305 } 00306 } 00307 00308 #ifdef DEBUG_QDA 00309 SG_PRINT(">>> QDA machine trained with %d classes\n", m_num_classes) 00310 00311 SG_PRINT("\n>>> Displaying means ...\n") 00312 SGMatrix< float64_t >::display_matrix(m_means.matrix, m_dim, m_num_classes); 00313 00314 SG_PRINT("\n>>> Displaying scalings ...\n") 00315 SGMatrix< float64_t >::display_matrix(scalings.matrix, m_dim, m_num_classes); 00316 00317 
SG_PRINT("\n>>> Displaying rotations ... \n") 00318 for (int k = 0; k < m_num_classes; k++) 00319 SGMatrix< float64_t >::display_matrix(rotations.get_matrix(k), m_dim, m_dim); 00320 00321 SG_PRINT("\n>>> Displaying sinvsqrt ... \n") 00322 sinvsqrt.display_vector(); 00323 00324 SG_PRINT("\n>>> Diplaying m_M matrices ... \n") 00325 for (int k = 0; k < m_num_classes; k++) 00326 SGMatrix< float64_t >::display_matrix(m_M.get_matrix(k), m_dim, m_dim); 00327 00328 SG_PRINT("\n>>> Exit DEBUG_QDA\n") 00329 #endif 00330 00331 SG_FREE(class_idxs); 00332 SG_FREE(class_nums); 00333 return true; 00334 } 00335 00336 #endif /* HAVE_EIGEN3 */