SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/lib/common.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/multiclass/MulticlassSVM.h> 00014 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00015 00016 using namespace shogun; 00017 00018 CMulticlassSVM::CMulticlassSVM() 00019 :CKernelMulticlassMachine(new CMulticlassOneVsRestStrategy(), NULL, new CSVM(0), NULL) 00020 { 00021 init(); 00022 } 00023 00024 CMulticlassSVM::CMulticlassSVM(CMulticlassStrategy *strategy) 00025 :CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL) 00026 { 00027 init(); 00028 } 00029 00030 CMulticlassSVM::CMulticlassSVM( 00031 CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab) 00032 : CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab) 00033 { 00034 init(); 00035 m_C=C; 00036 } 00037 00038 CMulticlassSVM::~CMulticlassSVM() 00039 { 00040 } 00041 00042 void CMulticlassSVM::init() 00043 { 00044 SG_ADD(&m_C, "C", "C regularization constant",MS_AVAILABLE); 00045 m_C=0; 00046 } 00047 00048 bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes) 00049 { 00050 if (num_classes>0) 00051 { 00052 int32_t num_svms=m_multiclass_strategy->get_num_machines(); 00053 00054 m_machines->reset_array(); 00055 for (index_t i=0; i<num_svms; ++i) 00056 m_machines->push_back(NULL); 00057 00058 return true; 00059 } 00060 return false; 00061 } 00062 00063 bool CMulticlassSVM::set_svm(int32_t num, CSVM* svm) 00064 { 00065 if (m_machines->get_num_elements()>0 && m_machines->get_num_elements()>num && num>=0 && svm) 00066 { 00067 m_machines->set_element(svm, num); 00068 return true; 00069 } 00070 return false; 00071 } 00072 00073 bool CMulticlassSVM::init_machines_for_apply(CFeatures* data) 00074 { 00075 if (is_data_locked()) 00076 { 00077 SG_ERROR("CKernelMachine::apply(CFeatures*) cannot be called when " 00078 "data_lock was called before. Call data_unlock to allow."); 00079 } 00080 00081 if (!m_kernel) 00082 SG_ERROR("No kernel assigned!\n") 00083 00084 CFeatures* lhs=m_kernel->get_lhs(); 00085 if (!lhs && m_kernel->get_kernel_type()!=K_COMBINED) 00086 SG_ERROR("%s: No left hand side specified\n", get_name()) 00087 00088 if (m_kernel->get_kernel_type()!=K_COMBINED && !lhs->get_num_vectors()) 00089 { 00090 SG_ERROR("%s: No vectors on left hand side (%s). This is probably due to" 00091 " an implementation error in %s, where it was forgotten to set " 00092 "the data (m_svs) indices\n", get_name(), 00093 data->get_name()); 00094 } 00095 00096 if (data && m_kernel->get_kernel_type()!=K_COMBINED) 00097 m_kernel->init(lhs, data); 00098 SG_UNREF(lhs); 00099 00100 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00101 { 00102 CSVM *the_svm = (CSVM *)m_machines->get_element(i); 00103 ASSERT(the_svm) 00104 the_svm->set_kernel(m_kernel); 00105 SG_UNREF(the_svm); 00106 } 00107 00108 return true; 00109 } 00110 00111 bool CMulticlassSVM::load(FILE* modelfl) 00112 { 00113 bool result=true; 00114 char char_buffer[1024]; 00115 int32_t int_buffer; 00116 float64_t double_buffer; 00117 int32_t line_number=1; 00118 int32_t svm_idx=-1; 00119 00120 SG_SET_LOCALE_C; 00121 00122 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF) 00123 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00124 else 00125 { 00126 char_buffer[15]='\0'; 00127 if (strcmp("%MultiClassSVM", char_buffer)!=0) 00128 SG_ERROR("error in multiclass svm file, line nr:%d\n", line_number) 00129 00130 line_number++; 00131 } 00132 00133 int_buffer=0; 00134 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1) 00135 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00136 00137 if (!feof(modelfl)) 00138 line_number++; 00139 00140 if (int_buffer < 2) 00141 SG_ERROR("less than 2 classes - how is this multiclass?\n") 00142 00143 create_multiclass_svm(int_buffer); 00144 00145 int_buffer=0; 00146 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1) 00147 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00148 00149 if (!feof(modelfl)) 00150 line_number++; 00151 00152 if (m_machines->get_num_elements() != int_buffer) 00153 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_machines->get_num_elements(), int_buffer) 00154 00155 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1) 00156 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00157 00158 if (!feof(modelfl)) 00159 line_number++; 00160 00161 for (int32_t n=0; n<m_machines->get_num_elements(); n++) 00162 { 00163 svm_idx=-1; 00164 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF) 00165 { 00166 result=false; 00167 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00168 } 00169 else 00170 { 00171 char_buffer[4]='\0'; 00172 if (strncmp("%SVM", char_buffer, 4)!=0) 00173 { 00174 result=false; 00175 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00176 } 00177 00178 if (svm_idx != n) 00179 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx) 00180 00181 line_number++; 00182 } 00183 00184 int_buffer=0; 00185 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2) 00186 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00187 00188 if (svm_idx != n) 00189 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx) 00190 00191 if (!feof(modelfl)) 00192 line_number++; 00193 00194 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx) 00195 CSVM* svm=new CSVM(int_buffer); 00196 00197 double_buffer=0; 00198 00199 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2) 00200 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00201 00202 if (svm_idx != n) 00203 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx) 00204 00205 if (!feof(modelfl)) 00206 line_number++; 00207 00208 svm->set_bias(double_buffer); 00209 00210 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1) 00211 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00212 00213 if (svm_idx != n) 00214 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx) 00215 00216 if (!feof(modelfl)) 00217 line_number++; 00218 00219 for (int32_t i=0; i<svm->get_num_support_vectors(); i++) 00220 { 00221 double_buffer=0; 00222 int_buffer=0; 00223 00224 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2) 00225 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00226 00227 if (!feof(modelfl)) 00228 line_number++; 00229 00230 svm->set_support_vector(i, int_buffer); 00231 svm->set_alpha(i, double_buffer); 00232 } 00233 00234 if (fscanf(modelfl,"%2s", char_buffer) == EOF) 00235 { 00236 result=false; 00237 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00238 } 00239 else 00240 { 00241 char_buffer[3]='\0'; 00242 if (strcmp("];", char_buffer)!=0) 00243 { 00244 result=false; 00245 SG_ERROR("error in svm file, line nr:%d\n", line_number) 00246 } 00247 line_number++; 00248 } 00249 00250 set_svm(n, svm); 00251 } 00252 00253 svm_proto()->svm_loaded=result; 00254 00255 SG_RESET_LOCALE; 00256 return result; 00257 } 00258 00259 bool CMulticlassSVM::save(FILE* modelfl) 00260 { 00261 SG_SET_LOCALE_C; 00262 00263 if (!m_kernel) 00264 SG_ERROR("Kernel not defined!\n") 00265 00266 if (m_machines->get_num_elements()<1) 00267 SG_ERROR("Multiclass SVM not trained!\n") 00268 00269 SG_INFO("Writing model file...") 00270 fprintf(modelfl,"%%MultiClassSVM\n"); 00271 fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes()); 00272 fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements()); 00273 fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name()); 00274 00275 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00276 { 00277 CSVM* svm=get_svm(i); 00278 ASSERT(svm) 00279 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_machines->get_num_elements()-1); 00280 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors()); 00281 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias()); 00282 00283 fprintf(modelfl, "alphas%d=[\n", i); 00284 00285 for(int32_t j=0; j<svm->get_num_support_vectors(); j++) 00286 { 00287 fprintf(modelfl,"\t[%+10.16e,%d];\n", 00288 svm->get_alpha(j), svm->get_support_vector(j)); 00289 } 00290 00291 fprintf(modelfl, "];\n"); 00292 } 00293 00294 SG_RESET_LOCALE; 00295 SG_DONE() 00296 return true ; 00297 }