SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009-2012 Vojtech Franc and Soeren Sonnenburg 00008 * Written (W) 2012 Sergey Lisitsyn 00009 * Copyright (C) 2009-2012 Vojtech Franc and Soeren Sonnenburg 00010 */ 00011 00012 #include <shogun/multiclass/MulticlassOCAS.h> 00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00014 #include <shogun/mathematics/Math.h> 00015 #include <shogun/labels/MulticlassLabels.h> 00016 00017 using namespace shogun; 00018 00019 struct mocas_data 00020 { 00021 CDotFeatures* features; 00022 float64_t* W; 00023 float64_t* oldW; 00024 float64_t* full_A; 00025 float64_t* data_y; 00026 float64_t* output_values; 00027 uint32_t nY; 00028 uint32_t nData; 00029 uint32_t nDim; 00030 float64_t* new_a; 00031 }; 00032 00033 CMulticlassOCAS::CMulticlassOCAS() : 00034 CLinearMulticlassMachine() 00035 { 00036 register_parameters(); 00037 set_C(1.0); 00038 set_epsilon(1e-2); 00039 set_max_iter(1000000); 00040 set_method(1); 00041 set_buf_size(5000); 00042 } 00043 00044 CMulticlassOCAS::CMulticlassOCAS(float64_t C, CDotFeatures* train_features, CLabels* train_labels) : 00045 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(), train_features, NULL, train_labels), m_C(C) 00046 { 00047 register_parameters(); 00048 set_epsilon(1e-2); 00049 set_max_iter(1000000); 00050 set_method(1); 00051 set_buf_size(5000); 00052 } 00053 00054 void CMulticlassOCAS::register_parameters() 00055 { 00056 SG_ADD(&m_C, "m_C", "regularization constant", MS_AVAILABLE); 00057 SG_ADD(&m_epsilon, "m_epsilon", "solver relative tolerance", MS_NOT_AVAILABLE); 00058 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations", MS_NOT_AVAILABLE); 00059 SG_ADD(&m_method, "m_method", "used solver method", MS_NOT_AVAILABLE); 00060 SG_ADD(&m_buf_size, "m_buf_size", "buffer size", MS_NOT_AVAILABLE); 00061 } 00062 00063 CMulticlassOCAS::~CMulticlassOCAS() 00064 { 00065 } 00066 00067 bool CMulticlassOCAS::train_machine(CFeatures* data) 00068 { 00069 if (data) 00070 set_features((CDotFeatures*)data); 00071 00072 ASSERT(m_features) 00073 ASSERT(m_labels) 00074 ASSERT(m_multiclass_strategy) 00075 00076 int32_t num_vectors = m_features->get_num_vectors(); 00077 int32_t num_classes = m_multiclass_strategy->get_num_classes(); 00078 int32_t num_features = m_features->get_dim_feature_space(); 00079 00080 float64_t C = m_C; 00081 SGVector<float64_t> labels = ((CMulticlassLabels*) m_labels)->get_labels(); 00082 uint32_t nY = num_classes; 00083 uint32_t nData = num_vectors; 00084 float64_t TolRel = m_epsilon; 00085 float64_t TolAbs = 0.0; 00086 float64_t QPBound = 0.0; 00087 float64_t MaxTime = m_max_train_time; 00088 uint32_t BufSize = m_buf_size; 00089 uint8_t Method = m_method; 00090 00091 mocas_data user_data; 00092 user_data.features = m_features; 00093 user_data.W = SG_CALLOC(float64_t, (int64_t)num_features*num_classes); 00094 user_data.oldW = SG_CALLOC(float64_t, (int64_t)num_features*num_classes); 00095 user_data.new_a = SG_CALLOC(float64_t, (int64_t)num_features*num_classes); 00096 user_data.full_A = SG_CALLOC(float64_t, (int64_t)num_features*num_classes*m_buf_size); 00097 user_data.output_values = SG_CALLOC(float64_t, num_vectors); 00098 user_data.data_y = labels.vector; 00099 user_data.nY = num_classes; 00100 user_data.nDim = num_features; 00101 user_data.nData = num_vectors; 00102 00103 ocas_return_value_T value = 00104 msvm_ocas_solver(C, labels.vector, nY, nData, TolRel, TolAbs, 00105 QPBound, MaxTime, BufSize, Method, 00106 &CMulticlassOCAS::msvm_full_compute_W, 00107 &CMulticlassOCAS::msvm_update_W, 00108 &CMulticlassOCAS::msvm_full_add_new_cut, 00109 &CMulticlassOCAS::msvm_full_compute_output, 00110 &CMulticlassOCAS::msvm_sort_data, 00111 &CMulticlassOCAS::msvm_print, 00112 &user_data); 00113 00114 SG_DEBUG("Number of iterations [nIter] = %d \n",value.nIter) 00115 SG_DEBUG("Number of cutting planes [nCutPlanes] = %d \n",value.nCutPlanes) 00116 SG_DEBUG("Number of non-zero alphas [nNZAlpha] = %d \n",value.nNZAlpha) 00117 SG_DEBUG("Number of training errors [trn_err] = %d \n",value.trn_err) 00118 SG_DEBUG("Primal objective value [Q_P] = %f \n",value.Q_P) 00119 SG_DEBUG("Dual objective value [Q_D] = %f \n",value.Q_D) 00120 SG_DEBUG("Output time [output_time] = %f \n",value.output_time) 00121 SG_DEBUG("Sort time [sort_time] = %f \n",value.sort_time) 00122 SG_DEBUG("Add time [add_time] = %f \n",value.add_time) 00123 SG_DEBUG("W time [w_time] = %f \n",value.w_time) 00124 SG_DEBUG("QP solver time [qp_solver_time] = %f \n",value.qp_solver_time) 00125 SG_DEBUG("OCAS time [ocas_time] = %f \n",value.ocas_time) 00126 SG_DEBUG("Print time [print_time] = %f \n",value.print_time) 00127 SG_DEBUG("QP exit flag [qp_exitflag] = %d \n",value.qp_exitflag) 00128 SG_DEBUG("Exit flag [exitflag] = %d \n",value.exitflag) 00129 00130 m_machines->reset_array(); 00131 for (int32_t i=0; i<num_classes; i++) 00132 { 00133 CLinearMachine* machine = new CLinearMachine(); 00134 machine->set_w(SGVector<float64_t>(&user_data.W[i*num_features],num_features,false).clone()); 00135 00136 m_machines->push_back(machine); 00137 } 00138 00139 SG_FREE(user_data.W); 00140 SG_FREE(user_data.oldW); 00141 SG_FREE(user_data.new_a); 00142 SG_FREE(user_data.full_A); 00143 SG_FREE(user_data.output_values); 00144 00145 return true; 00146 } 00147 00148 float64_t CMulticlassOCAS::msvm_update_W(float64_t t, void* user_data) 00149 { 00150 float64_t* W = ((mocas_data*)user_data)->W; 00151 float64_t* oldW = ((mocas_data*)user_data)->oldW; 00152 uint32_t nY = ((mocas_data*)user_data)->nY; 00153 uint32_t nDim = ((mocas_data*)user_data)->nDim; 00154 00155 for(uint32_t j=0; j < nY*nDim; j++) 00156 W[j] = oldW[j]*(1-t) + t*W[j]; 00157 00158 float64_t sq_norm_W = SGVector<float64_t>::dot(W,W,nDim*nY); 00159 00160 return sq_norm_W; 00161 } 00162 00163 void CMulticlassOCAS::msvm_full_compute_W(float64_t *sq_norm_W, float64_t *dp_WoldW, 00164 float64_t *alpha, uint32_t nSel, void* user_data) 00165 { 00166 float64_t* W = ((mocas_data*)user_data)->W; 00167 float64_t* oldW = ((mocas_data*)user_data)->oldW; 00168 float64_t* full_A = ((mocas_data*)user_data)->full_A; 00169 uint32_t nY = ((mocas_data*)user_data)->nY; 00170 uint32_t nDim = ((mocas_data*)user_data)->nDim; 00171 00172 uint32_t i,j; 00173 00174 memcpy(oldW, W, sizeof(float64_t)*nDim*nY); 00175 memset(W, 0, sizeof(float64_t)*nDim*nY); 00176 00177 for(i=0; i<nSel; i++) 00178 { 00179 if(alpha[i] > 0) 00180 { 00181 for(j=0; j<nDim*nY; j++) 00182 W[j] += alpha[i]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)]; 00183 } 00184 } 00185 00186 *sq_norm_W = SGVector<float64_t>::dot(W,W,nDim*nY); 00187 *dp_WoldW = SGVector<float64_t>::dot(W,oldW,nDim*nY); 00188 00189 return; 00190 } 00191 00192 int CMulticlassOCAS::msvm_full_add_new_cut(float64_t *new_col_H, uint32_t *new_cut, 00193 uint32_t nSel, void* user_data) 00194 { 00195 float64_t* full_A = ((mocas_data*)user_data)->full_A; 00196 float64_t* new_a = ((mocas_data*)user_data)->new_a; 00197 float64_t* data_y = ((mocas_data*)user_data)->data_y; 00198 uint32_t nY = ((mocas_data*)user_data)->nY; 00199 uint32_t nDim = ((mocas_data*)user_data)->nDim; 00200 uint32_t nData = ((mocas_data*)user_data)->nData; 00201 CDotFeatures* features = ((mocas_data*)user_data)->features; 00202 00203 float64_t sq_norm_a; 00204 uint32_t i, j, y, y2; 00205 00206 memset(new_a, 0, sizeof(float64_t)*nDim*nY); 00207 00208 for(i=0; i < nData; i++) 00209 { 00210 y = (uint32_t)(data_y[i]); 00211 y2 = (uint32_t)new_cut[i]; 00212 if(y2 != y) 00213 { 00214 features->add_to_dense_vec(1.0,i,&new_a[nDim*y],nDim); 00215 features->add_to_dense_vec(-1.0,i,&new_a[nDim*y2],nDim); 00216 } 00217 } 00218 00219 // compute new_a'*new_a and insert new_a to the last column of full_A 00220 sq_norm_a = SGVector<float64_t>::dot(new_a,new_a,nDim*nY); 00221 for(j=0; j < nDim*nY; j++ ) 00222 full_A[LIBOCAS_INDEX(j,nSel,nDim*nY)] = new_a[j]; 00223 00224 new_col_H[nSel] = sq_norm_a; 00225 for(i=0; i < nSel; i++) 00226 { 00227 float64_t tmp = 0; 00228 00229 for(j=0; j < nDim*nY; j++ ) 00230 tmp += new_a[j]*full_A[LIBOCAS_INDEX(j,i,nDim*nY)]; 00231 00232 new_col_H[i] = tmp; 00233 } 00234 00235 return 0; 00236 } 00237 00238 int CMulticlassOCAS::msvm_full_compute_output(float64_t *output, void* user_data) 00239 { 00240 float64_t* W = ((mocas_data*)user_data)->W; 00241 uint32_t nY = ((mocas_data*)user_data)->nY; 00242 uint32_t nDim = ((mocas_data*)user_data)->nDim; 00243 uint32_t nData = ((mocas_data*)user_data)->nData; 00244 float64_t* output_values = ((mocas_data*)user_data)->output_values; 00245 CDotFeatures* features = ((mocas_data*)user_data)->features; 00246 00247 uint32_t i, y; 00248 00249 for(y=0; y<nY; y++) 00250 { 00251 features->dense_dot_range(output_values,0,nData,NULL,&W[nDim*y],nDim,0.0); 00252 for (i=0; i<nData; i++) 00253 output[LIBOCAS_INDEX(y,i,nY)] = output_values[i]; 00254 } 00255 00256 return 0; 00257 } 00258 00259 int CMulticlassOCAS::msvm_sort_data(float64_t* vals, float64_t* data, uint32_t size) 00260 { 00261 CMath::qsort_index(vals, data, size); 00262 return 0; 00263 } 00264 00265 void CMulticlassOCAS::msvm_print(ocas_return_value_T value) 00266 { 00267 }