SHOGUN v3.2.0
MulticlassLibSVM.cpp source listing
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 1999-2009 Soeren Sonnenburg
 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
 */

#include <shogun/multiclass/MulticlassLibSVM.h>
#include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/io/SGIO.h>

using namespace shogun;

CMulticlassLibSVM::CMulticlassLibSVM(LIBSVM_SOLVER_TYPE st)
: CMulticlassSVM(new CMulticlassOneVsOneStrategy()), model(NULL), solver_type(st)
{
}

CMulticlassLibSVM::CMulticlassLibSVM(float64_t C, CKernel* k, CLabels* lab)
: CMulticlassSVM(new CMulticlassOneVsOneStrategy(), C, k, lab), model(NULL),
	solver_type(LIBSVM_C_SVC)
{
}

CMulticlassLibSVM::~CMulticlassLibSVM()
{
}

bool CMulticlassLibSVM::train_machine(CFeatures* data)
{
	struct svm_node* x_space;

	problem = svm_problem();

	ASSERT(m_labels && m_labels->get_num_labels())
	ASSERT(m_labels->get_label_type() == LT_MULTICLASS)
	int32_t num_classes = m_multiclass_strategy->get_num_classes();
	problem.l=m_labels->get_num_labels();
	SG_INFO("%d trainlabels, %d classes\n", problem.l, num_classes)

	if (data)
	{
		if (m_labels->get_num_labels() != data->get_num_vectors())
		{
			SG_ERROR("Number of training vectors does not match number of "
					"labels\n");
		}
		m_kernel->init(data, data);
	}

	// Build the libsvm problem. Since shogun supplies a precomputed kernel,
	// each "feature vector" is a single svm_node that holds its own index.
	problem.y=SG_MALLOC(float64_t, problem.l);
	problem.x=SG_MALLOC(struct svm_node*, problem.l);
	problem.pv=SG_MALLOC(float64_t, problem.l);
	problem.C=SG_MALLOC(float64_t, problem.l);

	x_space=SG_MALLOC(struct svm_node, 2*problem.l);

	for (int32_t i=0; i<problem.l; i++)
	{
		problem.pv[i]=-1.0;
		problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
		problem.x[i]=&x_space[2*i];
		x_space[2*i].index=i;
		x_space[2*i+1].index=-1;
	}

	ASSERT(m_kernel)

	param.svm_type=solver_type; // C_SVC or NU_SVC
	param.kernel_type = LINEAR;
	param.degree = 3;
	param.gamma = 0; // 1/k
	param.coef0 = 0;
	param.nu = get_nu(); // Nu
	param.kernel=m_kernel;
	param.cache_size = m_kernel->get_cache_size();
	param.max_train_time = m_max_train_time;
	param.C = get_C();
	param.eps = get_epsilon();
	param.p = 0.1;
	param.shrinking = 1;
	param.nr_weight = 0;
	param.weight_label = NULL;
	param.weight = NULL;
	param.use_bias = svm_proto()->get_bias_enabled();

	const char* error_msg = svm_check_parameter(&problem, &param);

	if (error_msg)
		SG_ERROR("Error: %s\n", error_msg)

	model = svm_train(&problem, &param);

	if (model)
	{
		if (model->nr_class!=num_classes)
		{
			SG_ERROR("LibSVM model->nr_class=%d while num_classes=%d\n",
					model->nr_class, num_classes);
		}
		ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef))
		create_multiclass_svm(num_classes);

		// offsets[i] is where class i's support vectors start in model->SV.
		int32_t* offsets=SG_MALLOC(int32_t, num_classes);
		offsets[0]=0;

		for (int32_t i=1; i<num_classes; i++)
			offsets[i] = offsets[i-1]+model->nSV[i-1];

		// Convert each of the num_classes*(num_classes-1)/2 one-vs-one
		// sub-problems trained by libsvm into a shogun CSVM.
		int32_t s=0;
		for (int32_t i=0; i<num_classes; i++)
		{
			for (int32_t j=i+1; j<num_classes; j++)
			{
				int32_t k, l;

				float64_t sgn=1;
				if (model->label[i]>model->label[j])
					sgn=-1;

				int32_t num_sv=model->nSV[i]+model->nSV[j];
				float64_t bias=-model->rho[s];

				ASSERT(num_sv>0)
				ASSERT(model->sv_coef[i] && model->sv_coef[j-1])

				CSVM* svm=new CSVM(num_sv);

				svm->set_bias(sgn*bias);

				// Copy support vectors and coefficients of class i ...
				int32_t sv_idx=0;
				for (k=0; k<model->nSV[i]; k++)
				{
					SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
							model->SV[offsets[i]+k]->index);
					svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index);
					svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]);
					sv_idx++;
				}

				// ... followed by those of class j.
				for (k=0; k<model->nSV[j]; k++)
				{
					SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
							model->SV[offsets[j]+k]->index);
					svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index);
					svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]);
					sv_idx++;
				}

				// Map the class pair (i,j) to the index of the corresponding
				// binary machine within the one-vs-one strategy.
				int32_t idx=0;

				if (num_classes > 3)
				{
					if (sgn>0)
					{
						for (k=0; k<model->label[i]; k++)
							idx+=num_classes-k-1;

						for (l=model->label[i]+1; l<model->label[j]; l++)
							idx++;
					}
					else
					{
						for (k=0; k<model->label[j]; k++)
							idx+=num_classes-k-1;

						for (l=model->label[j]+1; l<model->label[i]; l++)
							idx++;
					}
				}
				else if (num_classes == 3)
				{
					idx = model->label[j]+model->label[i] - 1;
				}
				else if (num_classes == 2)
				{
					idx = i;
				}
//
//				if (sgn>0)
//					idx=((num_classes-1)*model->label[i]+model->label[j])/2;
//				else
//					idx=((num_classes-1)*model->label[j]+model->label[i])/2;
//
				SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f "
						"label:(%d,%d) -> svm[%d]\n",
						s, num_sv, model->l, bias, model->label[i],
						model->label[j], idx);

				REQUIRE(set_svm(idx, svm), "SVM set failed")
				s++;
			}
		}

		set_objective(model->objective);

		SG_FREE(offsets);
		SG_FREE(problem.x);
		SG_FREE(problem.y);
		SG_FREE(x_space);
		SG_FREE(problem.pv);
		SG_FREE(problem.C);

		svm_destroy_model(model);
		model=NULL;

		return true;
	}
	else
		return false;
}
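A minimal usage sketch (not part of the listed file) showing how CMulticlassLibSVM::train_machine() above is normally reached through shogun's public C++ API. It assumes the shogun 3.x classes CDenseFeatures, CMulticlassLabels and CGaussianKernel and the apply_multiclass() call; the toy data, kernel width and regularization constant C are illustrative values only.

#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/multiclass/MulticlassLibSVM.h>

using namespace shogun;

int main()
{
	init_shogun_with_defaults();

	// Six 2-dimensional training vectors (columns) from three classes.
	SGMatrix<float64_t> feat_mat(2, 6);
	SGVector<float64_t> lab_vec(6);
	for (int32_t i=0; i<6; i++)
	{
		int32_t c=i/2;               // class index 0, 1 or 2
		feat_mat(0, i)=c*10.0;       // well separated clusters
		feat_mat(1, i)=c*10.0+i%2;
		lab_vec[i]=c;
	}

	CDenseFeatures<float64_t>* features=new CDenseFeatures<float64_t>(feat_mat);
	CMulticlassLabels* labels=new CMulticlassLabels(lab_vec);

	// Gaussian kernel with a 10 MB cache and width 2.0 (illustrative values).
	CGaussianKernel* kernel=new CGaussianKernel(10, 2.0);

	// C=1.0 regularization; train() dispatches to train_machine() above,
	// which runs libsvm's one-vs-one solver on the precomputed kernel.
	CMulticlassLibSVM* svm=new CMulticlassLibSVM(1.0, kernel, labels);
	svm->train(features);

	// Predict back on the training data and print the label vector.
	CMulticlassLabels* pred=svm->apply_multiclass(features);
	pred->get_labels().display_vector("predictions");

	SG_UNREF(pred);
	SG_UNREF(svm);
	exit_shogun();
	return 0;
}

After train() returns, the machine holds one CSVM per class pair, built exactly as in the conversion loop of the listing, and apply_multiclass() combines their outputs via the one-vs-one strategy.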