SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
LibSVR.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 2013 Heiko Strathmann
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include <shogun/regression/svr/LibSVR.h>
00013 #include <shogun/labels/RegressionLabels.h>
00014 #include <shogun/io/SGIO.h>
00015 
00016 using namespace shogun;
00017 
00018 CLibSVR::CLibSVR()
00019 : CSVM()
00020 {
00021     model=NULL;
00022     solver_type=LIBSVR_EPSILON_SVR;
00023 }
00024 
00025 CLibSVR::CLibSVR(float64_t C, float64_t svr_param, CKernel* k, CLabels* lab,
00026         LIBSVR_SOLVER_TYPE st)
00027 : CSVM()
00028 {
00029     model=NULL;
00030 
00031     set_C(C,C);
00032 
00033     switch (st)
00034     {
00035     case LIBSVR_EPSILON_SVR:
00036         set_tube_epsilon(svr_param);
00037         break;
00038     case LIBSVR_NU_SVR:
00039         set_nu(svr_param);
00040         break;
00041     default:
00042         SG_ERROR("CLibSVR::CLibSVR(): Unknown solver type!\n");
00043         break;
00044     }
00045 
00046     set_labels(lab);
00047     set_kernel(k);
00048     solver_type=st;
00049 }
00050 
00051 CLibSVR::~CLibSVR()
00052 {
00053     SG_FREE(model);
00054 }
00055 
00056 EMachineType CLibSVR::get_classifier_type()
00057 {
00058     return CT_LIBSVR;
00059 }
00060 
00061 bool CLibSVR::train_machine(CFeatures* data)
00062 {
00063     ASSERT(kernel)
00064     ASSERT(m_labels && m_labels->get_num_labels())
00065     ASSERT(m_labels->get_label_type() == LT_REGRESSION)
00066 
00067     if (data)
00068     {
00069         if (m_labels->get_num_labels() != data->get_num_vectors())
00070             SG_ERROR("Number of training vectors does not match number of labels\n")
00071         kernel->init(data, data);
00072     }
00073 
00074     SG_FREE(model);
00075 
00076     struct svm_node* x_space;
00077 
00078     problem.l=m_labels->get_num_labels();
00079     SG_INFO("%d trainlabels\n", problem.l)
00080 
00081     problem.y=SG_MALLOC(float64_t, problem.l);
00082     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00083     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00084 
00085     for (int32_t i=0; i<problem.l; i++)
00086     {
00087         problem.y[i]=((CRegressionLabels*) m_labels)->get_label(i);
00088         problem.x[i]=&x_space[2*i];
00089         x_space[2*i].index=i;
00090         x_space[2*i+1].index=-1;
00091     }
00092 
00093     int32_t weights_label[2]={-1,+1};
00094     float64_t weights[2]={1.0,get_C2()/get_C1()};
00095 
00096     switch (solver_type)
00097     {
00098     case LIBSVR_EPSILON_SVR:
00099         param.svm_type=EPSILON_SVR;
00100         break;
00101     case LIBSVR_NU_SVR:
00102         param.svm_type=NU_SVR;
00103         break;
00104     default:
00105         SG_ERROR("%s::train_machine(): Unknown solver type!\n", get_name());
00106         break;
00107     }
00108 
00109     param.kernel_type = LINEAR;
00110     param.degree = 3;
00111     param.gamma = 0;    // 1/k
00112     param.coef0 = 0;
00113     param.nu = nu;
00114     param.kernel=kernel;
00115     param.cache_size = kernel->get_cache_size();
00116     param.max_train_time = m_max_train_time;
00117     param.C = get_C1();
00118     param.eps = epsilon;
00119     param.p = tube_epsilon;
00120     param.shrinking = 1;
00121     param.nr_weight = 2;
00122     param.weight_label = weights_label;
00123     param.weight = weights;
00124     param.use_bias = get_bias_enabled();
00125 
00126     const char* error_msg = svm_check_parameter(&problem,&param);
00127 
00128     if(error_msg)
00129         SG_ERROR("Error: %s\n",error_msg)
00130 
00131     model = svm_train(&problem, &param);
00132 
00133     if (model)
00134     {
00135         ASSERT(model->nr_class==2)
00136         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]))
00137 
00138         int32_t num_sv=model->l;
00139 
00140         create_new_model(num_sv);
00141 
00142         CSVM::set_objective(model->objective);
00143 
00144         set_bias(-model->rho[0]);
00145 
00146         for (int32_t i=0; i<num_sv; i++)
00147         {
00148             set_support_vector(i, (model->SV[i])->index);
00149             set_alpha(i, model->sv_coef[0][i]);
00150         }
00151 
00152         SG_FREE(problem.x);
00153         SG_FREE(problem.y);
00154         SG_FREE(x_space);
00155 
00156         svm_destroy_model(model);
00157         model=NULL;
00158         return true;
00159     }
00160     else
00161         return false;
00162 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation