SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/classifier/svm/MPDSVM.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/lib/common.h> 00014 #include <shogun/mathematics/Math.h> 00015 00016 using namespace shogun; 00017 00018 CMPDSVM::CMPDSVM() 00019 : CSVM() 00020 { 00021 } 00022 00023 CMPDSVM::CMPDSVM(float64_t C, CKernel* k, CLabels* lab) 00024 : CSVM(C, k, lab) 00025 { 00026 } 00027 00028 CMPDSVM::~CMPDSVM() 00029 { 00030 } 00031 00032 bool CMPDSVM::train_machine(CFeatures* data) 00033 { 00034 ASSERT(m_labels) 00035 ASSERT(m_labels->get_label_type() == LT_BINARY) 00036 ASSERT(kernel) 00037 00038 if (data) 00039 { 00040 if (m_labels->get_num_labels() != data->get_num_vectors()) 00041 SG_ERROR("Number of training vectors does not match number of labels\n") 00042 kernel->init(data, data); 00043 } 00044 ASSERT(kernel->has_features()) 00045 00046 //const float64_t nu=0.32; 00047 const float64_t alpha_eps=1e-12; 00048 const float64_t eps=get_epsilon(); 00049 const int64_t maxiter = 1L<<30; 00050 //const bool nustop=false; 00051 //const int32_t k=2; 00052 const int32_t n=m_labels->get_num_labels(); 00053 ASSERT(n>0) 00054 //const float64_t d = 1.0/n/nu; //NUSVC 00055 const float64_t d = get_C1(); //CSVC 00056 const float64_t primaleps=eps; 00057 const float64_t dualeps=eps*n; //heuristic 00058 int64_t niter=0; 00059 00060 kernel_cache = new CCache<KERNELCACHE_ELEM>(kernel->get_cache_size(), n, n); 00061 float64_t* alphas=SG_MALLOC(float64_t, n); 00062 float64_t* dalphas=SG_MALLOC(float64_t, n); 00063 //float64_t* hessres=SG_MALLOC(float64_t, 2*n); 00064 float64_t* hessres=SG_MALLOC(float64_t, n); 00065 //float64_t* F=SG_MALLOC(float64_t, 2*n); 00066 float64_t* F=SG_MALLOC(float64_t, n); 00067 00068 //float64_t hessest[2]={0,0}; 00069 //float64_t hstep[2]; 00070 //float64_t etas[2]={0,0}; 00071 //float64_t detas[2]={0,1}; //NUSVC 00072 float64_t etas=0; 00073 float64_t detas=0; //CSVC 00074 float64_t hessest=0; 00075 float64_t hstep; 00076 00077 const float64_t stopfac = 1; 00078 00079 bool primalcool; 00080 bool dualcool; 00081 00082 //if (nustop) 00083 //etas[1] = 1; 00084 00085 for (int32_t i=0; i<n; i++) 00086 { 00087 alphas[i]=0; 00088 F[i]=((CBinaryLabels*) m_labels)->get_label(i); 00089 //F[i+n]=-1; 00090 hessres[i]=((CBinaryLabels*) m_labels)->get_label(i); 00091 //hessres[i+n]=-1; 00092 //dalphas[i]=F[i+n]*etas[1]; //NUSVC 00093 dalphas[i]=-1; //CSVC 00094 } 00095 00096 // go ... 00097 while (niter++ < maxiter) 00098 { 00099 int32_t maxpidx=-1; 00100 float64_t maxpviol = -1; 00101 //float64_t maxdviol = CMath::abs(detas[0]); 00102 float64_t maxdviol = CMath::abs(detas); 00103 bool free_alpha=false; 00104 00105 //if (CMath::abs(detas[1])> maxdviol) 00106 //maxdviol=CMath::abs(detas[1]); 00107 00108 // compute kkt violations with correct sign ... 00109 for (int32_t i=0; i<n; i++) 00110 { 00111 float64_t v=CMath::abs(dalphas[i]); 00112 00113 if (alphas[i] > 0 && alphas[i] < d) 00114 free_alpha=true; 00115 00116 if ( (dalphas[i]==0) || 00117 (alphas[i]==0 && dalphas[i] >0) || 00118 (alphas[i]==d && dalphas[i] <0) 00119 ) 00120 v=0; 00121 00122 if (v > maxpviol) 00123 { 00124 maxpviol=v; 00125 maxpidx=i; 00126 } // if we cannot improve on maxpviol, we can still improve by choosing a cached element 00127 else if (v == maxpviol) 00128 { 00129 if (kernel_cache->is_cached(i)) 00130 maxpidx=i; 00131 } 00132 } 00133 00134 if (maxpidx<0 || maxdviol<0) 00135 SG_ERROR("no violation no convergence, should not happen!\n") 00136 00137 // ... and evaluate stopping conditions 00138 //if (nustop) 00139 //stopfac = CMath::max(etas[1], 1e-10); 00140 //else 00141 //stopfac = 1; 00142 00143 if (niter%10000 == 0) 00144 { 00145 float64_t obj=0; 00146 00147 for (int32_t i=0; i<n; i++) 00148 { 00149 obj-=alphas[i]; 00150 for (int32_t j=0; j<n; j++) 00151 obj+=0.5*((CBinaryLabels*) m_labels)->get_label(i)*((CBinaryLabels*) m_labels)->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j); 00152 } 00153 00154 SG_DEBUG("obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter) 00155 } 00156 00157 //for (int32_t i=0; i<n; i++) 00158 // SG_DEBUG("alphas:%f dalphas:%f\n", alphas[i], dalphas[i]) 00159 00160 primalcool = (maxpviol < primaleps*stopfac); 00161 dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha); 00162 00163 // done? 00164 if (primalcool && dualcool) 00165 { 00166 if (!free_alpha) 00167 SG_INFO(" no free alpha, stopping! #iter=%d\n", niter) 00168 else 00169 SG_INFO(" done! #iter=%d\n", niter) 00170 break; 00171 } 00172 00173 00174 ASSERT(maxpidx>=0 && maxpidx<n) 00175 // hessian updates 00176 hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx); 00177 //hstep[0]=-hessres[maxpidx]/(compute_H(maxpidx,maxpidx)+hessreg); 00178 //hstep[1]=-hessres[maxpidx+n]/(compute_H(maxpidx,maxpidx)+hessreg); 00179 00180 hessest-=F[maxpidx]*hstep; 00181 //hessest[0]-=F[maxpidx]*hstep[0]; 00182 //hessest[1]-=F[maxpidx+n]*hstep[1]; 00183 00184 // do primal updates .. 00185 float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx); 00186 00187 if (tmpalpha > d-alpha_eps) 00188 tmpalpha = d; 00189 00190 if (tmpalpha < 0+alpha_eps) 00191 tmpalpha = 0; 00192 00193 // update alphas & dalphas & detas ... 00194 float64_t alphachange = tmpalpha - alphas[maxpidx]; 00195 alphas[maxpidx] = tmpalpha; 00196 00197 KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx); 00198 for (int32_t i=0; i<n; i++) 00199 { 00200 hessres[i]+=h[i]*hstep; 00201 //hessres[i]+=h[i]*hstep[0]; 00202 //hessres[i+n]+=h[i]*hstep[1]; 00203 dalphas[i] +=h[i]*alphachange; 00204 } 00205 unlock_kernel_row(maxpidx); 00206 00207 detas+=F[maxpidx]*alphachange; 00208 //detas[0]+=F[maxpidx]*alphachange; 00209 //detas[1]+=F[maxpidx+n]*alphachange; 00210 00211 // if at primal minimum, do eta update ... 00212 if (primalcool) 00213 { 00214 //float64_t etachange[2] = { detas[0]/hessest[0] , detas[1]/hessest[1] }; 00215 float64_t etachange = detas/hessest; 00216 00217 etas+=etachange; 00218 //etas[0]+=etachange[0]; 00219 //etas[1]+=etachange[1]; 00220 00221 // update dalphas 00222 for (int32_t i=0; i<n; i++) 00223 dalphas[i]+= F[i] * etachange; 00224 //dalphas[i]+= F[i] * etachange[0] + F[i+n] * etachange[1]; 00225 } 00226 } 00227 00228 if (niter >= maxiter) 00229 SG_WARNING("increase maxiter ... \n") 00230 00231 00232 int32_t nsv=0; 00233 for (int32_t i=0; i<n; i++) 00234 { 00235 if (alphas[i]>0) 00236 nsv++; 00237 } 00238 00239 00240 create_new_model(nsv); 00241 //set_bias(etas[0]/etas[1]); 00242 set_bias(etas); 00243 00244 int32_t j=0; 00245 for (int32_t i=0; i<n; i++) 00246 { 00247 if (alphas[i]>0) 00248 { 00249 //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]); 00250 set_alpha(j, alphas[i]*((CBinaryLabels*) m_labels)->get_label(i)); 00251 set_support_vector(j, i); 00252 j++; 00253 } 00254 } 00255 compute_svm_dual_objective(); 00256 SG_INFO("obj = %.16f, rho = %.16f\n",get_objective(),get_bias()) 00257 SG_INFO("Number of SV: %ld\n", get_num_support_vectors()) 00258 00259 SG_FREE(alphas); 00260 SG_FREE(dalphas); 00261 SG_FREE(hessres); 00262 SG_FREE(F); 00263 delete kernel_cache; 00264 00265 return true; 00266 }