SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MulticlassSVM.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/common.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/multiclass/MulticlassSVM.h>
00014 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00015 
00016 using namespace shogun;
00017 
00018 CMulticlassSVM::CMulticlassSVM()
00019     :CKernelMulticlassMachine(new CMulticlassOneVsRestStrategy(), NULL, new CSVM(0), NULL)
00020 {
00021     init();
00022 }
00023 
00024 CMulticlassSVM::CMulticlassSVM(CMulticlassStrategy *strategy)
00025     :CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL)
00026 {
00027     init();
00028 }
00029 
00030 CMulticlassSVM::CMulticlassSVM(
00031     CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab)
00032     : CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab)
00033 {
00034     init();
00035     m_C=C;
00036 }
00037 
00038 CMulticlassSVM::~CMulticlassSVM()
00039 {
00040 }
00041 
00042 void CMulticlassSVM::init()
00043 {
00044     SG_ADD(&m_C, "C", "C regularization constant",MS_AVAILABLE);
00045     m_C=0;
00046 }
00047 
00048 bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
00049 {
00050     if (num_classes>0)
00051     {
00052         int32_t num_svms=m_multiclass_strategy->get_num_machines();
00053 
00054         m_machines->reset_array();
00055         for (index_t i=0; i<num_svms; ++i)
00056             m_machines->push_back(NULL);
00057 
00058         return true;
00059     }
00060     return false;
00061 }
00062 
00063 bool CMulticlassSVM::set_svm(int32_t num, CSVM* svm)
00064 {
00065     if (m_machines->get_num_elements()>0 && m_machines->get_num_elements()>num && num>=0 && svm)
00066     {
00067         m_machines->set_element(svm, num);
00068         return true;
00069     }
00070     return false;
00071 }
00072 
00073 bool CMulticlassSVM::init_machines_for_apply(CFeatures* data)
00074 {
00075     if (is_data_locked())
00076     {
00077         SG_ERROR("CKernelMachine::apply(CFeatures*) cannot be called when "
00078                 "data_lock was called before. Call data_unlock to allow.");
00079     }
00080 
00081     if (!m_kernel)
00082         SG_ERROR("No kernel assigned!\n")
00083 
00084     CFeatures* lhs=m_kernel->get_lhs();
00085     if (!lhs && m_kernel->get_kernel_type()!=K_COMBINED)
00086         SG_ERROR("%s: No left hand side specified\n", get_name())
00087 
00088     if (m_kernel->get_kernel_type()!=K_COMBINED && !lhs->get_num_vectors())
00089     {
00090         SG_ERROR("%s: No vectors on left hand side (%s). This is probably due to"
00091                 " an implementation error in %s, where it was forgotten to set "
00092                 "the data (m_svs) indices\n", get_name(),
00093                 data->get_name());
00094     }
00095 
00096     if (data && m_kernel->get_kernel_type()!=K_COMBINED)
00097         m_kernel->init(lhs, data);
00098     SG_UNREF(lhs);
00099 
00100     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00101     {
00102         CSVM *the_svm = (CSVM *)m_machines->get_element(i);
00103         ASSERT(the_svm)
00104         the_svm->set_kernel(m_kernel);
00105         SG_UNREF(the_svm);
00106     }
00107 
00108     return true;
00109 }
00110 
00111 bool CMulticlassSVM::load(FILE* modelfl)
00112 {
00113     bool result=true;
00114     char char_buffer[1024];
00115     int32_t int_buffer;
00116     float64_t double_buffer;
00117     int32_t line_number=1;
00118     int32_t svm_idx=-1;
00119 
00120     SG_SET_LOCALE_C;
00121 
00122     if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00123         SG_ERROR("error in svm file, line nr:%d\n", line_number)
00124     else
00125     {
00126         char_buffer[15]='\0';
00127         if (strcmp("%MultiClassSVM", char_buffer)!=0)
00128             SG_ERROR("error in multiclass svm file, line nr:%d\n", line_number)
00129 
00130         line_number++;
00131     }
00132 
00133     int_buffer=0;
00134     if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00135         SG_ERROR("error in svm file, line nr:%d\n", line_number)
00136 
00137     if (!feof(modelfl))
00138         line_number++;
00139 
00140     if (int_buffer < 2)
00141         SG_ERROR("less than 2 classes - how is this multiclass?\n")
00142 
00143     create_multiclass_svm(int_buffer);
00144 
00145     int_buffer=0;
00146     if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00147         SG_ERROR("error in svm file, line nr:%d\n", line_number)
00148 
00149     if (!feof(modelfl))
00150         line_number++;
00151 
00152     if (m_machines->get_num_elements() != int_buffer)
00153         SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_machines->get_num_elements(), int_buffer)
00154 
00155     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00156         SG_ERROR("error in svm file, line nr:%d\n", line_number)
00157 
00158     if (!feof(modelfl))
00159         line_number++;
00160 
00161     for (int32_t n=0; n<m_machines->get_num_elements(); n++)
00162     {
00163         svm_idx=-1;
00164         if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00165         {
00166             result=false;
00167             SG_ERROR("error in svm file, line nr:%d\n", line_number)
00168         }
00169         else
00170         {
00171             char_buffer[4]='\0';
00172             if (strncmp("%SVM", char_buffer, 4)!=0)
00173             {
00174                 result=false;
00175                 SG_ERROR("error in svm file, line nr:%d\n", line_number)
00176             }
00177 
00178             if (svm_idx != n)
00179                 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx)
00180 
00181             line_number++;
00182         }
00183 
00184         int_buffer=0;
00185         if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00186             SG_ERROR("error in svm file, line nr:%d\n", line_number)
00187 
00188         if (svm_idx != n)
00189             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx)
00190 
00191         if (!feof(modelfl))
00192             line_number++;
00193 
00194         SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx)
00195         CSVM* svm=new CSVM(int_buffer);
00196 
00197         double_buffer=0;
00198 
00199         if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00200             SG_ERROR("error in svm file, line nr:%d\n", line_number)
00201 
00202         if (svm_idx != n)
00203             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx)
00204 
00205         if (!feof(modelfl))
00206             line_number++;
00207 
00208         svm->set_bias(double_buffer);
00209 
00210         if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00211             SG_ERROR("error in svm file, line nr:%d\n", line_number)
00212 
00213         if (svm_idx != n)
00214             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx)
00215 
00216         if (!feof(modelfl))
00217             line_number++;
00218 
00219         for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00220         {
00221             double_buffer=0;
00222             int_buffer=0;
00223 
00224             if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00225                 SG_ERROR("error in svm file, line nr:%d\n", line_number)
00226 
00227             if (!feof(modelfl))
00228                 line_number++;
00229 
00230             svm->set_support_vector(i, int_buffer);
00231             svm->set_alpha(i, double_buffer);
00232         }
00233 
00234         if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00235         {
00236             result=false;
00237             SG_ERROR("error in svm file, line nr:%d\n", line_number)
00238         }
00239         else
00240         {
00241             char_buffer[3]='\0';
00242             if (strcmp("];", char_buffer)!=0)
00243             {
00244                 result=false;
00245                 SG_ERROR("error in svm file, line nr:%d\n", line_number)
00246             }
00247             line_number++;
00248         }
00249 
00250         set_svm(n, svm);
00251     }
00252 
00253     svm_proto()->svm_loaded=result;
00254 
00255     SG_RESET_LOCALE;
00256     return result;
00257 }
00258 
00259 bool CMulticlassSVM::save(FILE* modelfl)
00260 {
00261     SG_SET_LOCALE_C;
00262 
00263     if (!m_kernel)
00264         SG_ERROR("Kernel not defined!\n")
00265 
00266     if (m_machines->get_num_elements()<1)
00267         SG_ERROR("Multiclass SVM not trained!\n")
00268 
00269     SG_INFO("Writing model file...")
00270     fprintf(modelfl,"%%MultiClassSVM\n");
00271     fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
00272     fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
00273     fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());
00274 
00275     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00276     {
00277         CSVM* svm=get_svm(i);
00278         ASSERT(svm)
00279         fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_machines->get_num_elements()-1);
00280         fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00281         fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00282 
00283         fprintf(modelfl, "alphas%d=[\n", i);
00284 
00285         for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00286         {
00287             fprintf(modelfl,"\t[%+10.16e,%d];\n",
00288                     svm->get_alpha(j), svm->get_support_vector(j));
00289         }
00290 
00291         fprintf(modelfl, "];\n");
00292     }
00293 
00294     SG_RESET_LOCALE;
00295     SG_DONE()
00296     return true ;
00297 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation