SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
LibSVM.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/classifier/svm/LibSVM.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/labels/BinaryLabels.h>
00014 
00015 using namespace shogun;
00016 
00017 CLibSVM::CLibSVM()
00018 : CSVM(), model(NULL), solver_type(LIBSVM_C_SVC)
00019 {
00020 }
00021 
00022 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st)
00023 : CSVM(), model(NULL), solver_type(st)
00024 {
00025 }
00026 
00027 
00028 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab, LIBSVM_SOLVER_TYPE st)
00029 : CSVM(C, k, lab), model(NULL), solver_type(st)
00030 {
00031     problem = svm_problem();
00032 }
00033 
00034 CLibSVM::~CLibSVM()
00035 {
00036 }
00037 
00038 
00039 bool CLibSVM::train_machine(CFeatures* data)
00040 {
00041     struct svm_node* x_space;
00042 
00043     ASSERT(m_labels && m_labels->get_num_labels())
00044     ASSERT(m_labels->get_label_type() == LT_BINARY)
00045 
00046     if (data)
00047     {
00048         if (m_labels->get_num_labels() != data->get_num_vectors())
00049         {
00050             SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
00051                     " not match number of labels (%d)\n", get_name(),
00052                     data->get_num_vectors(), m_labels->get_num_labels());
00053         }
00054         kernel->init(data, data);
00055     }
00056 
00057     problem.l=m_labels->get_num_labels();
00058     SG_INFO("%d trainlabels\n", problem.l)
00059 
00060     // set linear term
00061     if (m_linear_term.vlen>0)
00062     {
00063         if (m_labels->get_num_labels()!=m_linear_term.vlen)
00064             SG_ERROR("Number of training vectors does not match length of linear term\n")
00065 
00066         // set with linear term from base class
00067         problem.pv = get_linear_term_array();
00068     }
00069     else
00070     {
00071         // fill with minus ones
00072         problem.pv = SG_MALLOC(float64_t, problem.l);
00073 
00074         for (int i=0; i!=problem.l; i++)
00075             problem.pv[i] = -1.0;
00076     }
00077 
00078     problem.y=SG_MALLOC(float64_t, problem.l);
00079     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00080     problem.C=SG_MALLOC(float64_t, problem.l);
00081 
00082     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00083 
00084     for (int32_t i=0; i<problem.l; i++)
00085     {
00086         problem.y[i]=((CBinaryLabels*) m_labels)->get_label(i);
00087         problem.x[i]=&x_space[2*i];
00088         x_space[2*i].index=i;
00089         x_space[2*i+1].index=-1;
00090     }
00091 
00092     int32_t weights_label[2]={-1,+1};
00093     float64_t weights[2]={1.0,get_C2()/get_C1()};
00094 
00095     ASSERT(kernel && kernel->has_features())
00096     ASSERT(kernel->get_num_vec_lhs()==problem.l)
00097 
00098     switch (solver_type)
00099     {
00100     case LIBSVM_C_SVC:
00101         param.svm_type=C_SVC;
00102         break;
00103     case LIBSVM_NU_SVC:
00104         param.svm_type=NU_SVC;
00105         break;
00106     default:
00107         SG_ERROR("%s::train_machine(): Unknown solver type!\n", get_name());
00108         break;
00109     }
00110 
00111     param.kernel_type = LINEAR;
00112     param.degree = 3;
00113     param.gamma = 0;    // 1/k
00114     param.coef0 = 0;
00115     param.nu = get_nu();
00116     param.kernel=kernel;
00117     param.cache_size = kernel->get_cache_size();
00118     param.max_train_time = m_max_train_time;
00119     param.C = get_C1();
00120     param.eps = epsilon;
00121     param.p = 0.1;
00122     param.shrinking = 1;
00123     param.nr_weight = 2;
00124     param.weight_label = weights_label;
00125     param.weight = weights;
00126     param.use_bias = get_bias_enabled();
00127 
00128     const char* error_msg = svm_check_parameter(&problem, &param);
00129 
00130     if(error_msg)
00131         SG_ERROR("Error: %s\n",error_msg)
00132 
00133     model = svm_train(&problem, &param);
00134 
00135     if (model)
00136     {
00137         ASSERT(model->nr_class==2)
00138         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]))
00139 
00140         int32_t num_sv=model->l;
00141 
00142         create_new_model(num_sv);
00143         CSVM::set_objective(model->objective);
00144 
00145         float64_t sgn=model->label[0];
00146 
00147         set_bias(-sgn*model->rho[0]);
00148 
00149         for (int32_t i=0; i<num_sv; i++)
00150         {
00151             set_support_vector(i, (model->SV[i])->index);
00152             set_alpha(i, sgn*model->sv_coef[0][i]);
00153         }
00154 
00155         SG_FREE(problem.x);
00156         SG_FREE(problem.y);
00157         SG_FREE(problem.pv);
00158         SG_FREE(problem.C);
00159 
00160 
00161         SG_FREE(x_space);
00162 
00163         svm_destroy_model(model);
00164         model=NULL;
00165         return true;
00166     }
00167     else
00168         return false;
00169 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation