SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #include <shogun/lib/config.h> 00012 #include <shogun/multiclass/MulticlassLibLinear.h> 00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00014 #include <shogun/mathematics/Math.h> 00015 #include <shogun/lib/v_array.h> 00016 #include <shogun/lib/Signal.h> 00017 #include <shogun/labels/MulticlassLabels.h> 00018 00019 using namespace shogun; 00020 00021 CMulticlassLibLinear::CMulticlassLibLinear() : 00022 CLinearMulticlassMachine() 00023 { 00024 init_defaults(); 00025 } 00026 00027 CMulticlassLibLinear::CMulticlassLibLinear(float64_t C, CDotFeatures* features, CLabels* labs) : 00028 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),features,NULL,labs) 00029 { 00030 init_defaults(); 00031 set_C(C); 00032 } 00033 00034 void CMulticlassLibLinear::init_defaults() 00035 { 00036 set_C(1.0); 00037 set_epsilon(1e-2); 00038 set_max_iter(10000); 00039 set_use_bias(false); 00040 set_save_train_state(false); 00041 m_train_state = NULL; 00042 } 00043 00044 void CMulticlassLibLinear::register_parameters() 00045 { 00046 SG_ADD(&m_C, "m_C", "regularization constant",MS_AVAILABLE); 00047 SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE); 00048 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE); 00049 SG_ADD(&m_use_bias, "m_use_bias", "indicates whether bias should be used",MS_NOT_AVAILABLE); 00050 SG_ADD(&m_save_train_state, "m_save_train_state", "indicates whether bias should be used",MS_NOT_AVAILABLE); 00051 } 00052 00053 CMulticlassLibLinear::~CMulticlassLibLinear() 00054 { 00055 reset_train_state(); 00056 } 00057 00058 SGVector<int32_t> CMulticlassLibLinear::get_support_vectors() const 00059 { 00060 if (!m_train_state) 00061 SG_ERROR("Please enable save_train_state option and train machine.\n") 00062 00063 ASSERT(m_labels && m_labels->get_label_type() == LT_MULTICLASS) 00064 00065 int32_t num_vectors = m_features->get_num_vectors(); 00066 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00067 00068 v_array<int32_t> nz_idxs; 00069 nz_idxs.reserve(num_vectors); 00070 00071 for (int32_t i=0; i<num_vectors; i++) 00072 { 00073 for (int32_t y=0; y<num_classes; y++) 00074 { 00075 if (CMath::abs(m_train_state->alpha[i*num_classes+y])>1e-6) 00076 { 00077 nz_idxs.push(i); 00078 break; 00079 } 00080 } 00081 } 00082 int32_t num_nz = nz_idxs.index(); 00083 nz_idxs.reserve(num_nz); 00084 return SGVector<int32_t>(nz_idxs.begin,num_nz); 00085 } 00086 00087 SGMatrix<float64_t> CMulticlassLibLinear::obtain_regularizer_matrix() const 00088 { 00089 return SGMatrix<float64_t>(); 00090 } 00091 00092 bool CMulticlassLibLinear::train_machine(CFeatures* data) 00093 { 00094 if (data) 00095 set_features((CDotFeatures*)data); 00096 00097 ASSERT(m_features) 00098 ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS) 00099 ASSERT(m_multiclass_strategy) 00100 00101 int32_t num_vectors = m_features->get_num_vectors(); 00102 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00103 int32_t bias_n = m_use_bias ? 1 : 0; 00104 00105 liblinear_problem mc_problem; 00106 mc_problem.l = num_vectors; 00107 mc_problem.n = m_features->get_dim_feature_space() + bias_n; 00108 mc_problem.y = SG_MALLOC(float64_t, mc_problem.l); 00109 for (int32_t i=0; i<num_vectors; i++) 00110 mc_problem.y[i] = ((CMulticlassLabels*) m_labels)->get_int_label(i); 00111 00112 mc_problem.x = m_features; 00113 mc_problem.use_bias = m_use_bias; 00114 00115 SGMatrix<float64_t> w0 = obtain_regularizer_matrix(); 00116 00117 if (!m_train_state) 00118 m_train_state = new mcsvm_state(); 00119 00120 float64_t* C = SG_MALLOC(float64_t, num_vectors); 00121 for (int32_t i=0; i<num_vectors; i++) 00122 C[i] = m_C; 00123 00124 CSignal::clear_cancel(); 00125 00126 Solver_MCSVM_CS solver(&mc_problem,num_classes,C,w0.matrix,m_epsilon, 00127 m_max_iter,m_max_train_time,m_train_state); 00128 solver.solve(); 00129 00130 m_machines->reset_array(); 00131 for (int32_t i=0; i<num_classes; i++) 00132 { 00133 CLinearMachine* machine = new CLinearMachine(); 00134 SGVector<float64_t> cw(mc_problem.n-bias_n); 00135 00136 for (int32_t j=0; j<mc_problem.n-bias_n; j++) 00137 cw[j] = m_train_state->w[j*num_classes+i]; 00138 00139 machine->set_w(cw); 00140 00141 if (m_use_bias) 00142 machine->set_bias(m_train_state->w[(mc_problem.n-bias_n)*num_classes+i]); 00143 00144 m_machines->push_back(machine); 00145 } 00146 00147 if (!m_save_train_state) 00148 reset_train_state(); 00149 00150 SG_FREE(C); 00151 SG_FREE(mc_problem.y); 00152 00153 return true; 00154 }