SHOGUN
v3.2.0
/* SVM with stochastic gradient
   Copyright (C) 2007- Leon Bottou

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License as published by the Free Software Foundation; either
   version 2.1 of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU Lesser General Public License for more details.

   You should have received a copy of the GNU Lesser General Public
   License along with this library; if not, write to the Free Software
   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

   $Id: svmsgd.cpp,v 1.13 2007/10/02 20:40:06 cvs Exp $

   Shogun adjustments (w) 2008-2009 Soeren Sonnenburg
*/

#include <shogun/classifier/svm/SVMSGD.h>
#include <shogun/base/Parameter.h>
#include <shogun/lib/Signal.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/loss/HingeLoss.h>

using namespace shogun;

CSVMSGD::CSVMSGD()
: CLinearMachine()
{
	init();
}

CSVMSGD::CSVMSGD(float64_t C)
: CLinearMachine()
{
	init();

	C1=C;
	C2=C;
}

CSVMSGD::CSVMSGD(float64_t C, CDotFeatures* traindat, CLabels* trainlab)
: CLinearMachine()
{
	init();
	C1=C;
	C2=C;

	set_features(traindat);
	set_labels(trainlab);
}

CSVMSGD::~CSVMSGD()
{
	SG_UNREF(loss);
}

void CSVMSGD::set_loss_function(CLossFunction* loss_func)
{
	SG_REF(loss_func);
	SG_UNREF(loss);
	loss=loss_func;
}

bool CSVMSGD::train_machine(CFeatures* data)
{
	// allocate memory for w and initialize everything (w and bias) with 0
	ASSERT(m_labels)
	ASSERT(m_labels->get_label_type() == LT_BINARY)

	if (data)
	{
		if (!data->has_property(FP_DOT))
			SG_ERROR("Specified features are not of type CDotFeatures\n")

		set_features((CDotFeatures*) data);
	}

	ASSERT(features)

	int32_t num_train_labels=m_labels->get_num_labels();
	int32_t num_vec=features->get_num_vectors();

	ASSERT(num_vec==num_train_labels)
	ASSERT(num_vec>0)

	w=SGVector<float64_t>(features->get_dim_feature_space());
	w.zero();
	bias=0;

	float64_t lambda= 1.0/(C1*num_vec);

	// Shift t in order to have a
	// reasonable initial learning rate.
	// This assumes |x| \approx 1.
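	// With the eta = 1/(lambda*t) schedule used below, starting the clock
	// at t = 1/(eta0*lambda) makes the very first step size exactly eta0.
	// eta0 itself is sized so that one update of a "typical" weight
	// typw = sqrt(maxw), where maxw = 1/sqrt(lambda) is the classical bound
	// on the norm of the optimal SVM weight vector, does not overshoot.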
	float64_t maxw = 1.0 / sqrt(lambda);
	float64_t typw = sqrt(maxw);
	float64_t eta0 = typw / CMath::max(1.0, -loss->first_derivative(-typw, 1));
	t = 1 / (eta0 * lambda);

	SG_INFO("lambda=%f, epochs=%d, eta0=%f\n", lambda, epochs, eta0)

	// do the sgd
	calibrate();

	SG_INFO("Training on %d vectors\n", num_vec)
	CSignal::clear_cancel();

	ELossType loss_type = loss->get_loss_type();
	bool is_log_loss = false;
	if ((loss_type == L_LOGLOSS) || (loss_type == L_LOGLOSSMARGIN))
		is_log_loss = true;

	for (int32_t e=0; e<epochs && (!CSignal::cancel_computations()); e++)
	{
		count = skip;
		for (int32_t i=0; i<num_vec; i++)
		{
			float64_t eta = 1.0 / (lambda * t);
			float64_t y = ((CBinaryLabels*) m_labels)->get_label(i);
			float64_t z = y * (features->dense_dot(i, w.vector, w.vlen) + bias);

			if (z < 1 || is_log_loss)
			{
				float64_t etd = -eta * loss->first_derivative(z, 1);
				features->add_to_dense_vec(etd * y / wscale, i, w.vector, w.vlen);

				if (use_bias)
				{
					if (use_regularized_bias)
						bias *= 1 - eta * lambda * bscale;
					bias += etd * y * bscale;
				}
			}

			if (--count <= 0)
			{
				float64_t r = 1 - eta * lambda * skip;
				if (r < 0.8)
					r = pow(1 - eta * lambda, skip);
				SGVector<float64_t>::scale_vector(r, w.vector, w.vlen);
				count = skip;
			}
			t++;
		}
	}

	float64_t wnorm = SGVector<float64_t>::dot(w.vector, w.vector, w.vlen);
	SG_INFO("Norm: %.6f, Bias: %.6f\n", wnorm, bias)

	return true;
}
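// calibrate() estimates two data-dependent constants used by train_machine():
// bscale, the scaling applied to bias updates, and skip, the number of
// iterations between applications of the weight-decay factor. Shrinking w by
// (1 - eta*lambda) only once every `skip` examples, amortized into a single
// vector scaling, keeps the per-example cost proportional to the number of
// non-zero features rather than the full feature dimension.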
void CSVMSGD::calibrate()
{
	ASSERT(features)
	int32_t num_vec=features->get_num_vectors();
	int32_t c_dim=features->get_dim_feature_space();

	ASSERT(num_vec>0)
	ASSERT(c_dim>0)

	float64_t* c=SG_MALLOC(float64_t, c_dim);
	memset(c, 0, c_dim*sizeof(float64_t));

	SG_INFO("Estimating sparsity and bscale num_vec=%d num_feat=%d.\n", num_vec, c_dim)

	// compute average gradient size
	int32_t n = 0;
	float64_t m = 0;
	float64_t r = 0;

	for (int32_t j=0; j<num_vec && m<=1000; j++, n++)
	{
		r += features->get_nnz_features_for_vector(j);
		features->add_to_dense_vec(1, j, c, c_dim, true);

		// waste cpu cycles for readability
		// (only changed dims need checking)
		m = SGVector<float64_t>::max(c, c_dim);
	}

	// bias update scaling
	bscale = 0.5*m/n;

	// compute weight decay skip
	skip = (int32_t) ((16 * n * c_dim) / r);
	SG_INFO("using %d examples. skip=%d bscale=%.6f\n", n, skip, bscale)

	SG_FREE(c);
}

void CSVMSGD::init()
{
	t=1;
	C1=1;
	C2=1;
	wscale=1;
	bscale=1;
	epochs=5;
	skip=1000;
	count=1000;
	use_bias=true;
	use_regularized_bias=false;

	loss=new CHingeLoss();
	SG_REF(loss);

	m_parameters->add(&C1, "C1", "Cost constant 1.");
	m_parameters->add(&C2, "C2", "Cost constant 2.");
	m_parameters->add(&wscale, "wscale", "W scale");
	m_parameters->add(&bscale, "bscale", "b scale");
	m_parameters->add(&epochs, "epochs", "epochs");
	m_parameters->add(&skip, "skip", "skip");
	m_parameters->add(&count, "count", "count");
	m_parameters->add(&use_bias, "use_bias", "Indicates if bias is used.");
	m_parameters->add(&use_regularized_bias, "use_regularized_bias", "Indicates if bias is regularized.");
}
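For orientation, a minimal usage sketch for the class in this listing follows. It is illustrative only: the toy data and variable names are made up, and the entry points init_shogun_with_defaults(), exit_shogun(), CDenseFeatures and apply_binary() come from the general Shogun 3.x API rather than from this file, so check them against your installed headers.

#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/classifier/svm/SVMSGD.h>

using namespace shogun;

int main()
{
	init_shogun_with_defaults();

	// toy data: 4 two-dimensional vectors, one per column,
	// labelled by the sign of their first coordinate
	SGMatrix<float64_t> X(2, 4);
	X(0,0)=-1; X(1,0)=-1;
	X(0,1)=-1; X(1,1)= 1;
	X(0,2)= 1; X(1,2)=-1;
	X(0,3)= 1; X(1,3)= 1;

	SGVector<float64_t> y(4);
	y[0]=-1; y[1]=-1; y[2]=1; y[3]=1;

	CDenseFeatures<float64_t>* feats=new CDenseFeatures<float64_t>(X);
	CBinaryLabels* labels=new CBinaryLabels(y);

	// C=1.0, using the three-argument constructor from the listing above
	CSVMSGD* svm=new CSVMSGD(1.0, feats, labels);
	svm->train();

	// apply_binary() is inherited machinery, not defined in this file
	CBinaryLabels* out=svm->apply_binary(feats);
	SG_SPRINT("label of first vector: %f\n", out->get_label(0));

	SG_UNREF(out);
	SG_UNREF(svm);
	exit_shogun();
	return 0;
}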