SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 2013 Heiko Strathmann 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #include <shogun/regression/svr/LibSVR.h> 00013 #include <shogun/labels/RegressionLabels.h> 00014 #include <shogun/io/SGIO.h> 00015 00016 using namespace shogun; 00017 00018 CLibSVR::CLibSVR() 00019 : CSVM() 00020 { 00021 model=NULL; 00022 solver_type=LIBSVR_EPSILON_SVR; 00023 } 00024 00025 CLibSVR::CLibSVR(float64_t C, float64_t svr_param, CKernel* k, CLabels* lab, 00026 LIBSVR_SOLVER_TYPE st) 00027 : CSVM() 00028 { 00029 model=NULL; 00030 00031 set_C(C,C); 00032 00033 switch (st) 00034 { 00035 case LIBSVR_EPSILON_SVR: 00036 set_tube_epsilon(svr_param); 00037 break; 00038 case LIBSVR_NU_SVR: 00039 set_nu(svr_param); 00040 break; 00041 default: 00042 SG_ERROR("CLibSVR::CLibSVR(): Unknown solver type!\n"); 00043 break; 00044 } 00045 00046 set_labels(lab); 00047 set_kernel(k); 00048 solver_type=st; 00049 } 00050 00051 CLibSVR::~CLibSVR() 00052 { 00053 SG_FREE(model); 00054 } 00055 00056 EMachineType CLibSVR::get_classifier_type() 00057 { 00058 return CT_LIBSVR; 00059 } 00060 00061 bool CLibSVR::train_machine(CFeatures* data) 00062 { 00063 ASSERT(kernel) 00064 ASSERT(m_labels && m_labels->get_num_labels()) 00065 ASSERT(m_labels->get_label_type() == LT_REGRESSION) 00066 00067 if (data) 00068 { 00069 if (m_labels->get_num_labels() != data->get_num_vectors()) 00070 SG_ERROR("Number of training vectors does not match number of labels\n") 00071 kernel->init(data, data); 00072 } 00073 00074 SG_FREE(model); 00075 00076 struct svm_node* x_space; 00077 00078 problem.l=m_labels->get_num_labels(); 00079 SG_INFO("%d trainlabels\n", problem.l) 00080 00081 problem.y=SG_MALLOC(float64_t, problem.l); 00082 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00083 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00084 00085 for (int32_t i=0; i<problem.l; i++) 00086 { 00087 problem.y[i]=((CRegressionLabels*) m_labels)->get_label(i); 00088 problem.x[i]=&x_space[2*i]; 00089 x_space[2*i].index=i; 00090 x_space[2*i+1].index=-1; 00091 } 00092 00093 int32_t weights_label[2]={-1,+1}; 00094 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00095 00096 switch (solver_type) 00097 { 00098 case LIBSVR_EPSILON_SVR: 00099 param.svm_type=EPSILON_SVR; 00100 break; 00101 case LIBSVR_NU_SVR: 00102 param.svm_type=NU_SVR; 00103 break; 00104 default: 00105 SG_ERROR("%s::train_machine(): Unknown solver type!\n", get_name()); 00106 break; 00107 } 00108 00109 param.kernel_type = LINEAR; 00110 param.degree = 3; 00111 param.gamma = 0; // 1/k 00112 param.coef0 = 0; 00113 param.nu = nu; 00114 param.kernel=kernel; 00115 param.cache_size = kernel->get_cache_size(); 00116 param.max_train_time = m_max_train_time; 00117 param.C = get_C1(); 00118 param.eps = epsilon; 00119 param.p = tube_epsilon; 00120 param.shrinking = 1; 00121 param.nr_weight = 2; 00122 param.weight_label = weights_label; 00123 param.weight = weights; 00124 param.use_bias = get_bias_enabled(); 00125 00126 const char* error_msg = svm_check_parameter(&problem,¶m); 00127 00128 if(error_msg) 00129 SG_ERROR("Error: %s\n",error_msg) 00130 00131 model = svm_train(&problem, ¶m); 00132 00133 if (model) 00134 { 00135 ASSERT(model->nr_class==2) 00136 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0])) 00137 00138 int32_t num_sv=model->l; 00139 00140 create_new_model(num_sv); 00141 00142 CSVM::set_objective(model->objective); 00143 00144 set_bias(-model->rho[0]); 00145 00146 for (int32_t i=0; i<num_sv; i++) 00147 { 00148 set_support_vector(i, (model->SV[i])->index); 00149 set_alpha(i, model->sv_coef[0][i]); 00150 } 00151 00152 SG_FREE(problem.x); 00153 SG_FREE(problem.y); 00154 SG_FREE(x_space); 00155 00156 svm_destroy_model(model); 00157 model=NULL; 00158 return true; 00159 } 00160 else 00161 return false; 00162 }