/* SHOGUN v3.2.0 */
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Written (W) 2009 Marius Kloft 00009 * Copyright (C) 2009 TU Berlin and Max-Planck-Society 00010 */ 00011 #include <shogun/multiclass/ScatterSVM.h> 00012 00013 #ifdef USE_SVMLIGHT 00014 #include <shogun/classifier/svm/SVMLightOneClass.h> 00015 #endif //USE_SVMLIGHT 00016 00017 #include <shogun/kernel/Kernel.h> 00018 #include <shogun/kernel/normalizer/ScatterKernelNormalizer.h> 00019 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00020 #include <shogun/io/SGIO.h> 00021 00022 using namespace shogun; 00023 00024 CScatterSVM::CScatterSVM() 00025 : CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(NO_BIAS_LIBSVM), 00026 model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00027 { 00028 SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n") 00029 } 00030 00031 CScatterSVM::CScatterSVM(SCATTER_TYPE type) 00032 : CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(type), model(NULL), 00033 norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00034 { 00035 } 00036 00037 CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab) 00038 : CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL), 00039 norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00040 { 00041 } 00042 00043 CScatterSVM::~CScatterSVM() 00044 { 00045 SG_FREE(norm_wc); 00046 SG_FREE(norm_wcw); 00047 } 00048 00049 bool CScatterSVM::train_machine(CFeatures* data) 00050 { 00051 ASSERT(m_labels && m_labels->get_num_labels()) 00052 ASSERT(m_labels->get_label_type() == LT_MULTICLASS) 00053 00054 m_num_classes = m_multiclass_strategy->get_num_classes(); 00055 int32_t 
num_vectors = m_labels->get_num_labels(); 00056 00057 if (data) 00058 { 00059 if (m_labels->get_num_labels() != data->get_num_vectors()) 00060 SG_ERROR("Number of training vectors does not match number of labels\n") 00061 m_kernel->init(data, data); 00062 } 00063 00064 int32_t* numc=SG_MALLOC(int32_t, m_num_classes); 00065 SGVector<int32_t>::fill_vector(numc, m_num_classes, 0); 00066 00067 for (int32_t i=0; i<num_vectors; i++) 00068 numc[(int32_t) ((CMulticlassLabels*) m_labels)->get_int_label(i)]++; 00069 00070 int32_t Nc=0; 00071 int32_t Nmin=num_vectors; 00072 for (int32_t i=0; i<m_num_classes; i++) 00073 { 00074 if (numc[i]>0) 00075 { 00076 Nc++; 00077 Nmin=CMath::min(Nmin, numc[i]); 00078 } 00079 00080 } 00081 SG_FREE(numc); 00082 m_num_classes=Nc; 00083 00084 bool result=false; 00085 00086 if (scatter_type==NO_BIAS_LIBSVM) 00087 { 00088 result=train_no_bias_libsvm(); 00089 } 00090 #ifdef USE_SVMLIGHT 00091 else if (scatter_type==NO_BIAS_SVMLIGHT) 00092 { 00093 result=train_no_bias_svmlight(); 00094 } 00095 #endif //USE_SVMLIGHT 00096 else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2) 00097 { 00098 float64_t nu_min=((float64_t) Nc)/num_vectors; 00099 float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors; 00100 00101 SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max) 00102 00103 if (get_nu()<nu_min || get_nu()>nu_max) 00104 SG_ERROR("nu out of valid range [%f ... 
%f]\n", nu_min, nu_max) 00105 00106 result=train_testrule12(); 00107 } 00108 else 00109 SG_ERROR("Unknown Scatter type\n") 00110 00111 return result; 00112 } 00113 00114 bool CScatterSVM::train_no_bias_libsvm() 00115 { 00116 struct svm_node* x_space; 00117 00118 problem.l=m_labels->get_num_labels(); 00119 SG_INFO("%d trainlabels\n", problem.l) 00120 00121 problem.y=SG_MALLOC(float64_t, problem.l); 00122 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00123 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00124 00125 for (int32_t i=0; i<problem.l; i++) 00126 { 00127 problem.y[i]=+1; 00128 problem.x[i]=&x_space[2*i]; 00129 x_space[2*i].index=i; 00130 x_space[2*i+1].index=-1; 00131 } 00132 00133 int32_t weights_label[2]={-1,+1}; 00134 float64_t weights[2]={1.0,get_C()/get_C()}; 00135 00136 ASSERT(m_kernel && m_kernel->has_features()) 00137 ASSERT(m_kernel->get_num_vec_lhs()==problem.l) 00138 00139 param.svm_type=C_SVC; // Nu MC SVM 00140 param.kernel_type = LINEAR; 00141 param.degree = 3; 00142 param.gamma = 0; // 1/k 00143 param.coef0 = 0; 00144 param.nu = get_nu(); // Nu 00145 CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer(); 00146 m_kernel->set_normalizer(new CScatterKernelNormalizer( 00147 m_num_classes-1, -1, m_labels, prev_normalizer)); 00148 param.kernel=m_kernel; 00149 param.cache_size = m_kernel->get_cache_size(); 00150 param.C = 0; 00151 param.eps = get_epsilon(); 00152 param.p = 0.1; 00153 param.shrinking = 0; 00154 param.nr_weight = 2; 00155 param.weight_label = weights_label; 00156 param.weight = weights; 00157 param.nr_class=m_num_classes; 00158 param.use_bias = svm_proto()->get_bias_enabled(); 00159 00160 const char* error_msg = svm_check_parameter(&problem,¶m); 00161 00162 if(error_msg) 00163 SG_ERROR("Error: %s\n",error_msg) 00164 00165 model = svm_train(&problem, ¶m); 00166 m_kernel->set_normalizer(prev_normalizer); 00167 SG_UNREF(prev_normalizer); 00168 00169 if (model) 00170 { 00171 ASSERT((model->l==0) || (model->l>0 && model->SV && 
model->sv_coef && model->sv_coef)) 00172 00173 ASSERT(model->nr_class==m_num_classes) 00174 create_multiclass_svm(m_num_classes); 00175 00176 rho=model->rho[0]; 00177 00178 SG_FREE(norm_wcw); 00179 norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements()); 00180 00181 for (int32_t i=0; i<m_num_classes; i++) 00182 { 00183 int32_t num_sv=model->nSV[i]; 00184 00185 CSVM* svm=new CSVM(num_sv); 00186 svm->set_bias(model->rho[i+1]); 00187 norm_wcw[i]=model->normwcw[i]; 00188 00189 00190 for (int32_t j=0; j<num_sv; j++) 00191 { 00192 svm->set_alpha(j, model->sv_coef[i][j]); 00193 svm->set_support_vector(j, model->SV[i][j].index); 00194 } 00195 00196 set_svm(i, svm); 00197 } 00198 00199 SG_FREE(problem.x); 00200 SG_FREE(problem.y); 00201 SG_FREE(x_space); 00202 for (int32_t i=0; i<m_num_classes; i++) 00203 { 00204 SG_FREE(model->SV[i]); 00205 model->SV[i]=NULL; 00206 } 00207 svm_destroy_model(model); 00208 00209 if (scatter_type==TEST_RULE2) 00210 compute_norm_wc(); 00211 00212 model=NULL; 00213 return true; 00214 } 00215 else 00216 return false; 00217 } 00218 00219 #ifdef USE_SVMLIGHT 00220 bool CScatterSVM::train_no_bias_svmlight() 00221 { 00222 CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer(); 00223 CScatterKernelNormalizer* n=new CScatterKernelNormalizer( 00224 m_num_classes-1, -1, m_labels, prev_normalizer); 00225 m_kernel->set_normalizer(n); 00226 m_kernel->init_normalizer(); 00227 00228 CSVMLightOneClass* light=new CSVMLightOneClass(get_C(), m_kernel); 00229 light->set_linadd_enabled(false); 00230 light->train(); 00231 00232 SG_FREE(norm_wcw); 00233 norm_wcw = SG_MALLOC(float64_t, m_num_classes); 00234 00235 int32_t num_sv=light->get_num_support_vectors(); 00236 svm_proto()->create_new_model(num_sv); 00237 00238 for (int32_t i=0; i<num_sv; i++) 00239 { 00240 svm_proto()->set_alpha(i, light->get_alpha(i)); 00241 svm_proto()->set_support_vector(i, light->get_support_vector(i)); 00242 } 00243 00244 m_kernel->set_normalizer(prev_normalizer); 00245 
return true; 00246 } 00247 #endif //USE_SVMLIGHT 00248 00249 bool CScatterSVM::train_testrule12() 00250 { 00251 struct svm_node* x_space; 00252 problem.l=m_labels->get_num_labels(); 00253 SG_INFO("%d trainlabels\n", problem.l) 00254 00255 problem.y=SG_MALLOC(float64_t, problem.l); 00256 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00257 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00258 00259 for (int32_t i=0; i<problem.l; i++) 00260 { 00261 problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i); 00262 problem.x[i]=&x_space[2*i]; 00263 x_space[2*i].index=i; 00264 x_space[2*i+1].index=-1; 00265 } 00266 00267 int32_t weights_label[2]={-1,+1}; 00268 float64_t weights[2]={1.0,get_C()/get_C()}; 00269 00270 ASSERT(m_kernel && m_kernel->has_features()) 00271 ASSERT(m_kernel->get_num_vec_lhs()==problem.l) 00272 00273 param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM 00274 param.kernel_type = LINEAR; 00275 param.degree = 3; 00276 param.gamma = 0; // 1/k 00277 param.coef0 = 0; 00278 param.nu = get_nu(); // Nu 00279 param.kernel=m_kernel; 00280 param.cache_size = m_kernel->get_cache_size(); 00281 param.C = 0; 00282 param.eps = get_epsilon(); 00283 param.p = 0.1; 00284 param.shrinking = 0; 00285 param.nr_weight = 2; 00286 param.weight_label = weights_label; 00287 param.weight = weights; 00288 param.nr_class=m_num_classes; 00289 param.use_bias = svm_proto()->get_bias_enabled(); 00290 00291 const char* error_msg = svm_check_parameter(&problem,¶m); 00292 00293 if(error_msg) 00294 SG_ERROR("Error: %s\n",error_msg) 00295 00296 model = svm_train(&problem, ¶m); 00297 00298 if (model) 00299 { 00300 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef)) 00301 00302 ASSERT(model->nr_class==m_num_classes) 00303 create_multiclass_svm(m_num_classes); 00304 00305 rho=model->rho[0]; 00306 00307 SG_FREE(norm_wcw); 00308 norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements()); 00309 00310 for (int32_t i=0; i<m_num_classes; i++) 00311 { 00312 
int32_t num_sv=model->nSV[i]; 00313 00314 CSVM* svm=new CSVM(num_sv); 00315 svm->set_bias(model->rho[i+1]); 00316 norm_wcw[i]=model->normwcw[i]; 00317 00318 00319 for (int32_t j=0; j<num_sv; j++) 00320 { 00321 svm->set_alpha(j, model->sv_coef[i][j]); 00322 svm->set_support_vector(j, model->SV[i][j].index); 00323 } 00324 00325 set_svm(i, svm); 00326 } 00327 00328 SG_FREE(problem.x); 00329 SG_FREE(problem.y); 00330 SG_FREE(x_space); 00331 for (int32_t i=0; i<m_num_classes; i++) 00332 { 00333 SG_FREE(model->SV[i]); 00334 model->SV[i]=NULL; 00335 } 00336 svm_destroy_model(model); 00337 00338 if (scatter_type==TEST_RULE2) 00339 compute_norm_wc(); 00340 00341 model=NULL; 00342 return true; 00343 } 00344 else 00345 return false; 00346 } 00347 00348 void CScatterSVM::compute_norm_wc() 00349 { 00350 SG_FREE(norm_wc); 00351 norm_wc = SG_MALLOC(float64_t, m_machines->get_num_elements()); 00352 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00353 norm_wc[i]=0; 00354 00355 00356 for (int c=0; c<m_machines->get_num_elements(); c++) 00357 { 00358 CSVM* svm=get_svm(c); 00359 int32_t num_sv = svm->get_num_support_vectors(); 00360 00361 for (int32_t i=0; i<num_sv; i++) 00362 { 00363 int32_t ii=svm->get_support_vector(i); 00364 for (int32_t j=0; j<num_sv; j++) 00365 { 00366 int32_t jj=svm->get_support_vector(j); 00367 norm_wc[c]+=svm->get_alpha(i)*m_kernel->kernel(ii,jj)*svm->get_alpha(j); 00368 } 00369 } 00370 } 00371 00372 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00373 norm_wc[i]=CMath::sqrt(norm_wc[i]); 00374 00375 SGVector<float64_t>::display_vector(norm_wc, m_machines->get_num_elements(), "norm_wc"); 00376 } 00377 00378 CLabels* CScatterSVM::classify_one_vs_rest() 00379 { 00380 CMulticlassLabels* output=NULL; 00381 if (!m_kernel) 00382 { 00383 SG_ERROR("SVM can not proceed without kernel!\n") 00384 return NULL; 00385 } 00386 00387 if (!( m_kernel && m_kernel->get_num_vec_lhs() && m_kernel->get_num_vec_rhs())) 00388 return NULL; 00389 00390 int32_t 
num_vectors=m_kernel->get_num_vec_rhs(); 00391 00392 output=new CMulticlassLabels(num_vectors); 00393 SG_REF(output); 00394 00395 if (scatter_type == TEST_RULE1) 00396 { 00397 ASSERT(m_machines->get_num_elements()>0) 00398 for (int32_t i=0; i<num_vectors; i++) 00399 output->set_label(i, apply_one(i)); 00400 } 00401 #ifdef USE_SVMLIGHT 00402 else if (scatter_type == NO_BIAS_SVMLIGHT) 00403 { 00404 float64_t* outputs=SG_MALLOC(float64_t, num_vectors*m_num_classes); 00405 SGVector<float64_t>::fill_vector(outputs,num_vectors*m_num_classes,0.0); 00406 00407 for (int32_t i=0; i<num_vectors; i++) 00408 { 00409 for (int32_t j=0; j<svm_proto()->get_num_support_vectors(); j++) 00410 { 00411 float64_t score=m_kernel->kernel(svm_proto()->get_support_vector(j), i)*svm_proto()->get_alpha(j); 00412 int32_t label=((CMulticlassLabels*) m_labels)->get_int_label(svm_proto()->get_support_vector(j)); 00413 for (int32_t c=0; c<m_num_classes; c++) 00414 { 00415 float64_t s= (label==c) ? (m_num_classes-1) : (-1); 00416 outputs[c+i*m_num_classes]+=s*score; 00417 } 00418 } 00419 } 00420 00421 for (int32_t i=0; i<num_vectors; i++) 00422 { 00423 int32_t winner=0; 00424 float64_t max_out=outputs[i*m_num_classes+0]; 00425 00426 for (int32_t j=1; j<m_num_classes; j++) 00427 { 00428 float64_t out=outputs[i*m_num_classes+j]; 00429 00430 if (out>max_out) 00431 { 00432 winner=j; 00433 max_out=out; 00434 } 00435 } 00436 00437 output->set_label(i, winner); 00438 } 00439 00440 SG_FREE(outputs); 00441 } 00442 #endif //USE_SVMLIGHT 00443 else 00444 { 00445 ASSERT(m_machines->get_num_elements()>0) 00446 ASSERT(num_vectors==output->get_num_labels()) 00447 CLabels** outputs=SG_MALLOC(CLabels*, m_machines->get_num_elements()); 00448 00449 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00450 { 00451 //SG_PRINT("svm %d\n", i) 00452 CSVM *svm = get_svm(i); 00453 ASSERT(svm) 00454 svm->set_kernel(m_kernel); 00455 svm->set_labels(m_labels); 00456 outputs[i]=svm->apply(); 00457 SG_UNREF(svm); 00458 } 
00459 00460 for (int32_t i=0; i<num_vectors; i++) 00461 { 00462 int32_t winner=0; 00463 float64_t max_out=((CRegressionLabels*) outputs[0])->get_label(i)/norm_wc[0]; 00464 00465 for (int32_t j=1; j<m_machines->get_num_elements(); j++) 00466 { 00467 float64_t out=((CRegressionLabels*) outputs[j])->get_label(i)/norm_wc[j]; 00468 00469 if (out>max_out) 00470 { 00471 winner=j; 00472 max_out=out; 00473 } 00474 } 00475 00476 output->set_label(i, winner); 00477 } 00478 00479 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00480 SG_UNREF(outputs[i]); 00481 00482 SG_FREE(outputs); 00483 } 00484 00485 return output; 00486 } 00487 00488 float64_t CScatterSVM::apply_one(int32_t num) 00489 { 00490 ASSERT(m_machines->get_num_elements()>0) 00491 float64_t* outputs=SG_MALLOC(float64_t, m_machines->get_num_elements()); 00492 int32_t winner=0; 00493 00494 if (scatter_type == TEST_RULE1) 00495 { 00496 for (int32_t c=0; c<m_machines->get_num_elements(); c++) 00497 outputs[c]=get_svm(c)->get_bias()-rho; 00498 00499 for (int32_t c=0; c<m_machines->get_num_elements(); c++) 00500 { 00501 float64_t v=0; 00502 00503 for (int32_t i=0; i<get_svm(c)->get_num_support_vectors(); i++) 00504 { 00505 float64_t alpha=get_svm(c)->get_alpha(i); 00506 int32_t svidx=get_svm(c)->get_support_vector(i); 00507 v += alpha*m_kernel->kernel(svidx, num); 00508 } 00509 00510 outputs[c] += v; 00511 for (int32_t j=0; j<m_machines->get_num_elements(); j++) 00512 outputs[j] -= v/m_machines->get_num_elements(); 00513 } 00514 00515 for (int32_t j=0; j<m_machines->get_num_elements(); j++) 00516 outputs[j]/=norm_wcw[j]; 00517 00518 float64_t max_out=outputs[0]; 00519 for (int32_t j=0; j<m_machines->get_num_elements(); j++) 00520 { 00521 if (outputs[j]>max_out) 00522 { 00523 max_out=outputs[j]; 00524 winner=j; 00525 } 00526 } 00527 } 00528 #ifdef USE_SVMLIGHT 00529 else if (scatter_type == NO_BIAS_SVMLIGHT) 00530 { 00531 SG_ERROR("Use classify...\n") 00532 } 00533 #endif //USE_SVMLIGHT 00534 else 00535 { 00536 
float64_t max_out=get_svm(0)->apply_one(num)/norm_wc[0]; 00537 00538 for (int32_t i=1; i<m_machines->get_num_elements(); i++) 00539 { 00540 outputs[i]=get_svm(i)->apply_one(num)/norm_wc[i]; 00541 if (outputs[i]>max_out) 00542 { 00543 winner=i; 00544 max_out=outputs[i]; 00545 } 00546 } 00547 } 00548 00549 SG_FREE(outputs); 00550 return winner; 00551 }