SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/classifier/svm/LibSVM.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/labels/BinaryLabels.h> 00014 00015 using namespace shogun; 00016 00017 CLibSVM::CLibSVM() 00018 : CSVM(), model(NULL), solver_type(LIBSVM_C_SVC) 00019 { 00020 } 00021 00022 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st) 00023 : CSVM(), model(NULL), solver_type(st) 00024 { 00025 } 00026 00027 00028 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab, LIBSVM_SOLVER_TYPE st) 00029 : CSVM(C, k, lab), model(NULL), solver_type(st) 00030 { 00031 problem = svm_problem(); 00032 } 00033 00034 CLibSVM::~CLibSVM() 00035 { 00036 } 00037 00038 00039 bool CLibSVM::train_machine(CFeatures* data) 00040 { 00041 struct svm_node* x_space; 00042 00043 ASSERT(m_labels && m_labels->get_num_labels()) 00044 ASSERT(m_labels->get_label_type() == LT_BINARY) 00045 00046 if (data) 00047 { 00048 if (m_labels->get_num_labels() != data->get_num_vectors()) 00049 { 00050 SG_ERROR("%s::train_machine(): Number of training vectors (%d) does" 00051 " not match number of labels (%d)\n", get_name(), 00052 data->get_num_vectors(), m_labels->get_num_labels()); 00053 } 00054 kernel->init(data, data); 00055 } 00056 00057 problem.l=m_labels->get_num_labels(); 00058 SG_INFO("%d trainlabels\n", problem.l) 00059 00060 // set linear term 00061 if (m_linear_term.vlen>0) 00062 { 00063 if (m_labels->get_num_labels()!=m_linear_term.vlen) 00064 SG_ERROR("Number of training vectors does not match length of linear term\n") 00065 00066 // set with linear term from base class 00067 problem.pv = get_linear_term_array(); 00068 } 00069 else 00070 { 00071 // fill with minus ones 00072 problem.pv = SG_MALLOC(float64_t, problem.l); 00073 00074 for (int i=0; i!=problem.l; i++) 00075 problem.pv[i] = -1.0; 00076 } 00077 00078 problem.y=SG_MALLOC(float64_t, problem.l); 00079 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00080 problem.C=SG_MALLOC(float64_t, problem.l); 00081 00082 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00083 00084 for (int32_t i=0; i<problem.l; i++) 00085 { 00086 problem.y[i]=((CBinaryLabels*) m_labels)->get_label(i); 00087 problem.x[i]=&x_space[2*i]; 00088 x_space[2*i].index=i; 00089 x_space[2*i+1].index=-1; 00090 } 00091 00092 int32_t weights_label[2]={-1,+1}; 00093 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00094 00095 ASSERT(kernel && kernel->has_features()) 00096 ASSERT(kernel->get_num_vec_lhs()==problem.l) 00097 00098 switch (solver_type) 00099 { 00100 case LIBSVM_C_SVC: 00101 param.svm_type=C_SVC; 00102 break; 00103 case LIBSVM_NU_SVC: 00104 param.svm_type=NU_SVC; 00105 break; 00106 default: 00107 SG_ERROR("%s::train_machine(): Unknown solver type!\n", get_name()); 00108 break; 00109 } 00110 00111 param.kernel_type = LINEAR; 00112 param.degree = 3; 00113 param.gamma = 0; // 1/k 00114 param.coef0 = 0; 00115 param.nu = get_nu(); 00116 param.kernel=kernel; 00117 param.cache_size = kernel->get_cache_size(); 00118 param.max_train_time = m_max_train_time; 00119 param.C = get_C1(); 00120 param.eps = epsilon; 00121 param.p = 0.1; 00122 param.shrinking = 1; 00123 param.nr_weight = 2; 00124 param.weight_label = weights_label; 00125 param.weight = weights; 00126 param.use_bias = get_bias_enabled(); 00127 00128 const char* error_msg = svm_check_parameter(&problem, ¶m); 00129 00130 if(error_msg) 00131 SG_ERROR("Error: %s\n",error_msg) 00132 00133 model = svm_train(&problem, ¶m); 00134 00135 if (model) 00136 { 00137 ASSERT(model->nr_class==2) 00138 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0])) 00139 00140 int32_t num_sv=model->l; 00141 00142 create_new_model(num_sv); 00143 CSVM::set_objective(model->objective); 00144 00145 float64_t sgn=model->label[0]; 00146 00147 set_bias(-sgn*model->rho[0]); 00148 00149 for (int32_t i=0; i<num_sv; i++) 00150 { 00151 set_support_vector(i, (model->SV[i])->index); 00152 set_alpha(i, sgn*model->sv_coef[0][i]); 00153 } 00154 00155 SG_FREE(problem.x); 00156 SG_FREE(problem.y); 00157 SG_FREE(problem.pv); 00158 SG_FREE(problem.C); 00159 00160 00161 SG_FREE(x_space); 00162 00163 svm_destroy_model(model); 00164 model=NULL; 00165 return true; 00166 } 00167 else 00168 return false; 00169 }