SHOGUN  v3.2.0
ScatterSVM.cpp
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2009 Soeren Sonnenburg
 * Written (W) 2009 Marius Kloft
 * Copyright (C) 2009 TU Berlin and Max-Planck-Society
 */
#include <shogun/multiclass/ScatterSVM.h>

#ifdef USE_SVMLIGHT
#include <shogun/classifier/svm/SVMLightOneClass.h>
#endif //USE_SVMLIGHT

#include <shogun/kernel/Kernel.h>
#include <shogun/kernel/normalizer/ScatterKernelNormalizer.h>
#include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
#include <shogun/io/SGIO.h>

using namespace shogun;

CScatterSVM::CScatterSVM()
: CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(NO_BIAS_LIBSVM),
  model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
    SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n")
}

CScatterSVM::CScatterSVM(SCATTER_TYPE type)
: CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(type), model(NULL),
    norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}

CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab)
: CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL),
    norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}

CScatterSVM::~CScatterSVM()
{
    SG_FREE(norm_wc);
    SG_FREE(norm_wcw);
}

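/* Train the scatter SVM. Counts the examples per class, restricts
 * m_num_classes to the classes that actually occur, and then dispatches to
 * the solver selected by scatter_type: NO_BIAS_LIBSVM, NO_BIAS_SVMLIGHT
 * (only with USE_SVMLIGHT), or the TEST_RULE1/TEST_RULE2 variants, for which
 * the requested nu is first checked against the valid interval
 * [Nc/N, Nc*Nmin/N]. */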
bool CScatterSVM::train_machine(CFeatures* data)
{
    ASSERT(m_labels && m_labels->get_num_labels())
    ASSERT(m_labels->get_label_type() == LT_MULTICLASS)

    m_num_classes = m_multiclass_strategy->get_num_classes();
    int32_t num_vectors = m_labels->get_num_labels();

    if (data)
    {
        if (m_labels->get_num_labels() != data->get_num_vectors())
            SG_ERROR("Number of training vectors does not match number of labels\n")
        m_kernel->init(data, data);
    }

    int32_t* numc=SG_MALLOC(int32_t, m_num_classes);
    SGVector<int32_t>::fill_vector(numc, m_num_classes, 0);

    for (int32_t i=0; i<num_vectors; i++)
        numc[(int32_t) ((CMulticlassLabels*) m_labels)->get_int_label(i)]++;

    int32_t Nc=0;
    int32_t Nmin=num_vectors;
    for (int32_t i=0; i<m_num_classes; i++)
    {
        if (numc[i]>0)
        {
            Nc++;
            Nmin=CMath::min(Nmin, numc[i]);
        }
    }
    SG_FREE(numc);
    m_num_classes=Nc;

    bool result=false;

    if (scatter_type==NO_BIAS_LIBSVM)
    {
        result=train_no_bias_libsvm();
    }
#ifdef USE_SVMLIGHT
    else if (scatter_type==NO_BIAS_SVMLIGHT)
    {
        result=train_no_bias_svmlight();
    }
#endif //USE_SVMLIGHT
    else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2)
    {
        float64_t nu_min=((float64_t) Nc)/num_vectors;
        float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors;

        SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max)

        if (get_nu()<nu_min || get_nu()>nu_max)
            SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max)

        result=train_testrule12();
    }
    else
        SG_ERROR("Unknown Scatter type\n")

    return result;
}

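/* Bias-free training via the bundled LibSVM solver. All targets are set to
 * +1, the kernel is temporarily wrapped in a CScatterKernelNormalizer, and
 * the per-class solutions returned by svm_train() are copied into one CSVM
 * machine per class. */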
bool CScatterSVM::train_no_bias_libsvm()
{
    struct svm_node* x_space;

    problem.l=m_labels->get_num_labels();
    SG_INFO("%d trainlabels\n", problem.l)

    problem.y=SG_MALLOC(float64_t, problem.l);
    problem.x=SG_MALLOC(struct svm_node*, problem.l);
    x_space=SG_MALLOC(struct svm_node, 2*problem.l);

    for (int32_t i=0; i<problem.l; i++)
    {
        problem.y[i]=+1;
        problem.x[i]=&x_space[2*i];
        x_space[2*i].index=i;
        x_space[2*i+1].index=-1;
    }

    int32_t weights_label[2]={-1,+1};
    float64_t weights[2]={1.0,get_C()/get_C()}; // note: this ratio always evaluates to 1.0

    ASSERT(m_kernel && m_kernel->has_features())
    ASSERT(m_kernel->get_num_vec_lhs()==problem.l)

    param.svm_type=C_SVC; // C-SVC (no-bias scatter formulation)
    param.kernel_type = LINEAR;
    param.degree = 3;
    param.gamma = 0;    // 1/k
    param.coef0 = 0;
    param.nu = get_nu(); // Nu
    CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer();
    m_kernel->set_normalizer(new CScatterKernelNormalizer(
                m_num_classes-1, -1, m_labels, prev_normalizer));
    param.kernel=m_kernel;
    param.cache_size = m_kernel->get_cache_size();
    param.C = 0;
    param.eps = get_epsilon();
    param.p = 0.1;
    param.shrinking = 0;
    param.nr_weight = 2;
    param.weight_label = weights_label;
    param.weight = weights;
    param.nr_class=m_num_classes;
    param.use_bias = svm_proto()->get_bias_enabled();

    const char* error_msg = svm_check_parameter(&problem,&param);

    if(error_msg)
        SG_ERROR("Error: %s\n",error_msg)

    model = svm_train(&problem, &param);
    m_kernel->set_normalizer(prev_normalizer);
    SG_UNREF(prev_normalizer);

    if (model)
    {
        ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]))

        ASSERT(model->nr_class==m_num_classes)
        create_multiclass_svm(m_num_classes);

        rho=model->rho[0];

        SG_FREE(norm_wcw);
        norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());

        for (int32_t i=0; i<m_num_classes; i++)
        {
            int32_t num_sv=model->nSV[i];

            CSVM* svm=new CSVM(num_sv);
            svm->set_bias(model->rho[i+1]);
            norm_wcw[i]=model->normwcw[i];

            for (int32_t j=0; j<num_sv; j++)
            {
                svm->set_alpha(j, model->sv_coef[i][j]);
                svm->set_support_vector(j, model->SV[i][j].index);
            }

            set_svm(i, svm);
        }

        SG_FREE(problem.x);
        SG_FREE(problem.y);
        SG_FREE(x_space);
        for (int32_t i=0; i<m_num_classes; i++)
        {
            SG_FREE(model->SV[i]);
            model->SV[i]=NULL;
        }
        svm_destroy_model(model);

        if (scatter_type==TEST_RULE2)
            compute_norm_wc();

        model=NULL;
        return true;
    }
    else
        return false;
}

#ifdef USE_SVMLIGHT
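/* Bias-free training via SVMLight's one-class solver on the
 * scatter-normalized kernel; the resulting alphas and support vectors are
 * stored in the prototype SVM (svm_proto()). */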
bool CScatterSVM::train_no_bias_svmlight()
{
    CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer();
    CScatterKernelNormalizer* n=new CScatterKernelNormalizer(
                 m_num_classes-1, -1, m_labels, prev_normalizer);
    m_kernel->set_normalizer(n);
    m_kernel->init_normalizer();

    CSVMLightOneClass* light=new CSVMLightOneClass(get_C(), m_kernel);
    light->set_linadd_enabled(false);
    light->train();

    SG_FREE(norm_wcw);
    norm_wcw = SG_MALLOC(float64_t, m_num_classes);

    int32_t num_sv=light->get_num_support_vectors();
    svm_proto()->create_new_model(num_sv);

    for (int32_t i=0; i<num_sv; i++)
    {
        svm_proto()->set_alpha(i, light->get_alpha(i));
        svm_proto()->set_support_vector(i, light->get_support_vector(i));
    }

    m_kernel->set_normalizer(prev_normalizer);
    return true;
}
#endif //USE_SVMLIGHT

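/* Training for the TEST_RULE1/TEST_RULE2 variants: LibSVM is run in
 * NU_MULTICLASS_SVC mode on the original labels and the per-class models are
 * copied into individual CSVM machines, analogous to train_no_bias_libsvm(). */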
bool CScatterSVM::train_testrule12()
{
    struct svm_node* x_space;
    problem.l=m_labels->get_num_labels();
    SG_INFO("%d trainlabels\n", problem.l)

    problem.y=SG_MALLOC(float64_t, problem.l);
    problem.x=SG_MALLOC(struct svm_node*, problem.l);
    x_space=SG_MALLOC(struct svm_node, 2*problem.l);

    for (int32_t i=0; i<problem.l; i++)
    {
        problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
        problem.x[i]=&x_space[2*i];
        x_space[2*i].index=i;
        x_space[2*i+1].index=-1;
    }

    int32_t weights_label[2]={-1,+1};
    float64_t weights[2]={1.0,get_C()/get_C()}; // note: this ratio always evaluates to 1.0

    ASSERT(m_kernel && m_kernel->has_features())
    ASSERT(m_kernel->get_num_vec_lhs()==problem.l)

    param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM
    param.kernel_type = LINEAR;
    param.degree = 3;
    param.gamma = 0;    // 1/k
    param.coef0 = 0;
    param.nu = get_nu(); // Nu
    param.kernel=m_kernel;
    param.cache_size = m_kernel->get_cache_size();
    param.C = 0;
    param.eps = get_epsilon();
    param.p = 0.1;
    param.shrinking = 0;
    param.nr_weight = 2;
    param.weight_label = weights_label;
    param.weight = weights;
    param.nr_class=m_num_classes;
    param.use_bias = svm_proto()->get_bias_enabled();

    const char* error_msg = svm_check_parameter(&problem,&param);

    if(error_msg)
        SG_ERROR("Error: %s\n",error_msg)

    model = svm_train(&problem, &param);

    if (model)
    {
        ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]))

        ASSERT(model->nr_class==m_num_classes)
        create_multiclass_svm(m_num_classes);

        rho=model->rho[0];

        SG_FREE(norm_wcw);
        norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());

        for (int32_t i=0; i<m_num_classes; i++)
        {
            int32_t num_sv=model->nSV[i];

            CSVM* svm=new CSVM(num_sv);
            svm->set_bias(model->rho[i+1]);
            norm_wcw[i]=model->normwcw[i];

            for (int32_t j=0; j<num_sv; j++)
            {
                svm->set_alpha(j, model->sv_coef[i][j]);
                svm->set_support_vector(j, model->SV[i][j].index);
            }

            set_svm(i, svm);
        }

        SG_FREE(problem.x);
        SG_FREE(problem.y);
        SG_FREE(x_space);
        for (int32_t i=0; i<m_num_classes; i++)
        {
            SG_FREE(model->SV[i]);
            model->SV[i]=NULL;
        }
        svm_destroy_model(model);

        if (scatter_type==TEST_RULE2)
            compute_norm_wc();

        model=NULL;
        return true;
    }
    else
        return false;
}

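/* Compute the norm of each per-class weight vector in feature space,
 *   ||w_c|| = sqrt( sum_{i,j} alpha_i^c alpha_j^c k(x_i, x_j) ),
 * where the sums run over the support vectors of machine c. The result is
 * used to normalize the per-class outputs (computed under TEST_RULE2). */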
void CScatterSVM::compute_norm_wc()
{
    SG_FREE(norm_wc);
    norm_wc = SG_MALLOC(float64_t, m_machines->get_num_elements());
    for (int32_t i=0; i<m_machines->get_num_elements(); i++)
        norm_wc[i]=0;

    for (int32_t c=0; c<m_machines->get_num_elements(); c++)
    {
        CSVM* svm=get_svm(c);
        int32_t num_sv = svm->get_num_support_vectors();

        for (int32_t i=0; i<num_sv; i++)
        {
            int32_t ii=svm->get_support_vector(i);
            for (int32_t j=0; j<num_sv; j++)
            {
                int32_t jj=svm->get_support_vector(j);
                norm_wc[c]+=svm->get_alpha(i)*m_kernel->kernel(ii,jj)*svm->get_alpha(j);
            }
        }
    }

    for (int32_t i=0; i<m_machines->get_num_elements(); i++)
        norm_wc[i]=CMath::sqrt(norm_wc[i]);

    SGVector<float64_t>::display_vector(norm_wc, m_machines->get_num_elements(), "norm_wc");
}

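/* Classify all vectors on the kernel's rhs. TEST_RULE1 delegates to
 * apply_one(); NO_BIAS_SVMLIGHT accumulates the scatter scores directly from
 * the prototype SVM; otherwise each per-class SVM is applied and the class
 * with the largest norm_wc-normalized output wins. */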
CLabels* CScatterSVM::classify_one_vs_rest()
{
    CMulticlassLabels* output=NULL;
    if (!m_kernel)
    {
        SG_ERROR("SVM can not proceed without kernel!\n")
        return NULL;
    }

    if (!(m_kernel && m_kernel->get_num_vec_lhs() && m_kernel->get_num_vec_rhs()))
        return NULL;

    int32_t num_vectors=m_kernel->get_num_vec_rhs();

    output=new CMulticlassLabels(num_vectors);
    SG_REF(output);

    if (scatter_type == TEST_RULE1)
    {
        ASSERT(m_machines->get_num_elements()>0)
        for (int32_t i=0; i<num_vectors; i++)
            output->set_label(i, apply_one(i));
    }
#ifdef USE_SVMLIGHT
    else if (scatter_type == NO_BIAS_SVMLIGHT)
    {
        float64_t* outputs=SG_MALLOC(float64_t, num_vectors*m_num_classes);
        SGVector<float64_t>::fill_vector(outputs,num_vectors*m_num_classes,0.0);

        for (int32_t i=0; i<num_vectors; i++)
        {
            for (int32_t j=0; j<svm_proto()->get_num_support_vectors(); j++)
            {
                float64_t score=m_kernel->kernel(svm_proto()->get_support_vector(j), i)*svm_proto()->get_alpha(j);
                int32_t label=((CMulticlassLabels*) m_labels)->get_int_label(svm_proto()->get_support_vector(j));
                for (int32_t c=0; c<m_num_classes; c++)
                {
                    float64_t s= (label==c) ? (m_num_classes-1) : (-1);
                    outputs[c+i*m_num_classes]+=s*score;
                }
            }
        }

        for (int32_t i=0; i<num_vectors; i++)
        {
            int32_t winner=0;
            float64_t max_out=outputs[i*m_num_classes+0];

            for (int32_t j=1; j<m_num_classes; j++)
            {
                float64_t out=outputs[i*m_num_classes+j];

                if (out>max_out)
                {
                    winner=j;
                    max_out=out;
                }
            }

            output->set_label(i, winner);
        }

        SG_FREE(outputs);
    }
#endif //USE_SVMLIGHT
    else
    {
        ASSERT(m_machines->get_num_elements()>0)
        ASSERT(num_vectors==output->get_num_labels())
        CLabels** outputs=SG_MALLOC(CLabels*, m_machines->get_num_elements());

        for (int32_t i=0; i<m_machines->get_num_elements(); i++)
        {
            //SG_PRINT("svm %d\n", i)
            CSVM *svm = get_svm(i);
            ASSERT(svm)
            svm->set_kernel(m_kernel);
            svm->set_labels(m_labels);
            outputs[i]=svm->apply();
            SG_UNREF(svm);
        }

        for (int32_t i=0; i<num_vectors; i++)
        {
            int32_t winner=0;
            float64_t max_out=((CRegressionLabels*) outputs[0])->get_label(i)/norm_wc[0];

            for (int32_t j=1; j<m_machines->get_num_elements(); j++)
            {
                float64_t out=((CRegressionLabels*) outputs[j])->get_label(i)/norm_wc[j];

                if (out>max_out)
                {
                    winner=j;
                    max_out=out;
                }
            }

            output->set_label(i, winner);
        }

        for (int32_t i=0; i<m_machines->get_num_elements(); i++)
            SG_UNREF(outputs[i]);

        SG_FREE(outputs);
    }

    return output;
}

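/* Classify a single example. Under TEST_RULE1 the per-class outputs are
 * bias-corrected, centered and divided by norm_wcw; otherwise the per-class
 * SVM outputs are divided by norm_wc. Returns the index of the winning
 * class. */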
float64_t CScatterSVM::apply_one(int32_t num)
{
    ASSERT(m_machines->get_num_elements()>0)
    float64_t* outputs=SG_MALLOC(float64_t, m_machines->get_num_elements());
    int32_t winner=0;

    if (scatter_type == TEST_RULE1)
    {
        for (int32_t c=0; c<m_machines->get_num_elements(); c++)
            outputs[c]=get_svm(c)->get_bias()-rho;

        for (int32_t c=0; c<m_machines->get_num_elements(); c++)
        {
            float64_t v=0;

            for (int32_t i=0; i<get_svm(c)->get_num_support_vectors(); i++)
            {
                float64_t alpha=get_svm(c)->get_alpha(i);
                int32_t svidx=get_svm(c)->get_support_vector(i);
                v += alpha*m_kernel->kernel(svidx, num);
            }

            outputs[c] += v;
            for (int32_t j=0; j<m_machines->get_num_elements(); j++)
                outputs[j] -= v/m_machines->get_num_elements();
        }

        for (int32_t j=0; j<m_machines->get_num_elements(); j++)
            outputs[j]/=norm_wcw[j];

        float64_t max_out=outputs[0];
        for (int32_t j=0; j<m_machines->get_num_elements(); j++)
        {
            if (outputs[j]>max_out)
            {
                max_out=outputs[j];
                winner=j;
            }
        }
    }
#ifdef USE_SVMLIGHT
    else if (scatter_type == NO_BIAS_SVMLIGHT)
    {
        SG_ERROR("Use classify...\n")
    }
#endif //USE_SVMLIGHT
    else
    {
        float64_t max_out=get_svm(0)->apply_one(num)/norm_wc[0];

        for (int32_t i=1; i<m_machines->get_num_elements(); i++)
        {
            outputs[i]=get_svm(i)->apply_one(num)/norm_wc[i];
            if (outputs[i]>max_out)
            {
                winner=i;
                max_out=outputs[i];
            }
        }
    }

    SG_FREE(outputs);
    return winner;
}
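
A minimal usage sketch follows; it is not part of ScatterSVM.cpp. It assumes the surrounding SHOGUN 3.x API (init_shogun_with_defaults()/exit_shogun(), CDenseFeatures, CGaussianKernel, CMulticlassLabels, and CMulticlassMachine::apply_multiclass()); exact headers and signatures may differ between releases. It builds a small three-class toy problem and trains a CScatterSVM through the C/kernel/labels constructor defined above.

// Hedged usage sketch (assumed API noted above); not taken from this file.
#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/multiclass/ScatterSVM.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    // six 2-d points, two per class (classes 0, 1, 2), well separated
    SGMatrix<float64_t> x(2, 6);
    SGVector<float64_t> y(6);
    for (int32_t i=0; i<6; i++)
    {
        int32_t c=i/2;
        x(0, i)=10.0*c + 0.1*(i%2);
        x(1, i)=10.0*c - 0.1*(i%2);
        y[i]=c;
    }

    CDenseFeatures<float64_t>* feats=new CDenseFeatures<float64_t>(x);
    CMulticlassLabels* labels=new CMulticlassLabels(y);
    CGaussianKernel* kernel=new CGaussianKernel(feats, feats, 2.0);

    // C=1.0; the default scatter type is NO_BIAS_LIBSVM (see the constructor above)
    CScatterSVM* svm=new CScatterSVM(1.0, kernel, labels);
    SG_REF(svm);
    svm->train();

    CMulticlassLabels* pred=svm->apply_multiclass();
    for (int32_t i=0; i<pred->get_num_labels(); i++)
        SG_SPRINT("vector %d -> class %d\n", i, (int32_t) pred->get_label(i));

    SG_UNREF(pred);
    SG_UNREF(svm);
    exit_shogun();
    return 0;
}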