/*
   SVM with stochastic gradient
   Copyright (C) 2007- Leon Bottou

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License as published by the Free Software Foundation; either
   version 2.1 of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   You should have received a copy of the GNU Lesser General Public
   License along with this library; if not, write to the Free Software
   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

   $Id: svmsgd.cpp,v 1.13 2007/10/02 20:40:06 cvs Exp $

   Shogun adjustments (w) 2008-2009 Soeren Sonnenburg
*/

#include <shogun/classifier/svm/OnlineSVMSGD.h>
#include <shogun/base/Parameter.h>
#include <shogun/lib/Signal.h>
#include <shogun/loss/HingeLoss.h>

using namespace shogun;

COnlineSVMSGD::COnlineSVMSGD()
: COnlineLinearMachine()
{
	init();
}

COnlineSVMSGD::COnlineSVMSGD(float64_t C)
: COnlineLinearMachine()
{
	init();

	C1=C;
	C2=C;
}

COnlineSVMSGD::COnlineSVMSGD(float64_t C, CStreamingDotFeatures* traindat)
: COnlineLinearMachine()
{
	init();

	C1=C;
	C2=C;

	set_features(traindat);
}

COnlineSVMSGD::~COnlineSVMSGD()
{
	SG_UNREF(loss);
}

void COnlineSVMSGD::set_loss_function(CLossFunction* loss_func)
{
	SG_REF(loss_func);
	SG_UNREF(loss);
	loss=loss_func;
}

bool COnlineSVMSGD::train(CFeatures* data)
{
	if (data)
	{
		if (!data->has_property(FP_STREAMING_DOT))
			SG_ERROR("Specified features are not of type CStreamingDotFeatures\n")
		set_features((CStreamingDotFeatures*) data);
	}

	ASSERT(features)
	ASSERT(features->get_has_labels())

	features->start_parser();

	// allocate memory for w and initialize w and bias with 0
	// (SG_CALLOC zero-fills and matches the SG_FREE used on w elsewhere)
	if (w)
		SG_FREE(w);
	w_dim=1;
	w=SG_CALLOC(float32_t, 1);
	bias=0;

	// Shift t in order to have a
	// reasonable initial learning rate.
	// This assumes |x| \approx 1.
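	// The shift works as follows: for this lambda-regularized objective the
	// optimal |w| is bounded by roughly 1/sqrt(lambda) (maxw below), so
	// typw = sqrt(maxw) is a typical weight scale. eta0 is sized so the first
	// update moves w by roughly typw, and starting the counter at
	// t = 1/(eta0*lambda) makes the schedule eta_t = 1/(lambda*t) begin at eta0.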
	float64_t maxw = 1.0 / sqrt(lambda);
	float64_t typw = sqrt(maxw);
	float64_t eta0 = typw / CMath::max(1.0, -loss->first_derivative(-typw,1));
	t = 1 / (eta0 * lambda);

	SG_INFO("lambda=%f, epochs=%d, eta0=%f\n", lambda, epochs, eta0)

	// do the sgd
	calibrate();
	if (features->is_seekable())
		features->reset_stream();

	CSignal::clear_cancel();

	ELossType loss_type = loss->get_loss_type();
	bool is_log_loss = false;
	if ((loss_type == L_LOGLOSS) || (loss_type == L_LOGLOSSMARGIN))
		is_log_loss = true;

	int32_t vec_count;
	for (int32_t e=0; e<epochs && (!CSignal::cancel_computations()); e++)
	{
		vec_count=0;
		count = skip;
		while (features->get_next_example())
		{
			vec_count++;
			// Expand w vector if more features are seen in this example
			features->expand_if_required(w, w_dim);

			float64_t eta = 1.0 / (lambda * t);
			float64_t y = features->get_label();
			float64_t z = y * (features->dense_dot(w, w_dim) + bias);

			// hinge-type losses only update on margin violations (z < 1);
			// log-type losses have a nonzero gradient everywhere
			if (z < 1 || is_log_loss)
			{
				float64_t etd = -eta * loss->first_derivative(z,1);
				features->add_to_dense_vec(etd * y / wscale, w, w_dim);

				if (use_bias)
				{
					if (use_regularized_bias)
						bias *= 1 - eta * lambda * bscale;
					bias += etd * y * bscale;
				}
			}

			// apply the l2 shrinkage only every `skip` iterations; use the
			// exact power when the linearized factor would be too crude
			if (--count <= 0)
			{
				float32_t r = 1 - eta * lambda * skip;
				if (r < 0.8)
					r = pow(1 - eta * lambda, skip);
				SGVector<float32_t>::scale_vector(r, w, w_dim);
				count = skip;
			}
			t++;

			features->release_example();
		}

		// If the stream is seekable, reset the stream to the first
		// example (for epochs > 1)
		if (features->is_seekable() && e < epochs-1)
			features->reset_stream();
		else
			break;
	}

	features->end_parser();
	float64_t wnorm = SGVector<float32_t>::dot(w, w, w_dim);
	SG_INFO("Norm: %.6f, Bias: %.6f\n", wnorm, bias)

	return true;
}

void COnlineSVMSGD::calibrate(int32_t max_vec_num)
{
	int32_t c_dim=1;
	// zero-initialized accumulator, freed with SG_FREE below
	float32_t* c=SG_CALLOC(float32_t, 1);

	// compute average gradient size
	int32_t n = 0;
	float64_t m = 0;
	float64_t r = 0;

	while (features->get_next_example())
	{
		// Expand c if more features are seen in this example
		features->expand_if_required(c, c_dim);

		r += features->get_nnz_features_for_vector();
		features->add_to_dense_vec(1, c, c_dim, true);

		// waste cpu cycles for readability
		// (only changed dims need checking)
		m=SGVector<float32_t>::max(c, c_dim);
		n++;

		features->release_example();
		if (n>=max_vec_num || m > 1000)
			break;
	}

	SG_PRINT("Online SGD calibrated using %d vectors.\n", n)

	// bias update scaling
	bscale = 0.5*m/n;

	// compute weight decay skip
	skip = (int32_t) ((16 * n * c_dim) / r);

	SG_INFO("using %d examples. skip=%d bscale=%.6f\n", n, skip, bscale)

	SG_FREE(c);
}

void COnlineSVMSGD::init()
{
	t=1;
	C1=1;
	C2=1;
	lambda=1e-4;
	wscale=1;
	bscale=1;
	epochs=1;
	skip=1000;
	count=1000;
	use_bias=true;
	use_regularized_bias=false;

	loss=new CHingeLoss();
	SG_REF(loss);

	m_parameters->add(&C1, "C1", "Cost constant 1.");
	m_parameters->add(&C2, "C2", "Cost constant 2.");
	m_parameters->add(&lambda, "lambda", "Regularization parameter.");
	m_parameters->add(&wscale, "wscale", "W scale");
	m_parameters->add(&bscale, "bscale", "b scale");
	m_parameters->add(&epochs, "epochs", "epochs");
	m_parameters->add(&skip, "skip", "skip");
	m_parameters->add(&count, "count", "count");
	m_parameters->add(&use_bias, "use_bias", "Indicates if bias is used.");
	m_parameters->add(&use_regularized_bias, "use_regularized_bias",
		"Indicates if bias is regularized.");
}
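
For context, a minimal usage sketch follows; it is not part of the file above. It assumes the Shogun 3.x streaming classes CStreamingAsciiFile and CStreamingDenseFeatures (header locations have moved between releases) and a placeholder input file "train.dat" containing labelled dense ASCII examples.

#include <shogun/base/init.h>
#include <shogun/io/streaming/StreamingAsciiFile.h>
#include <shogun/features/streaming/StreamingDenseFeatures.h>
#include <shogun/classifier/svm/OnlineSVMSGD.h>

using namespace shogun;

int main()
{
	init_shogun_with_defaults();

	// stream labelled examples from an ASCII file ("train.dat" is a placeholder)
	CStreamingAsciiFile* file = new CStreamingAsciiFile("train.dat");
	CStreamingDenseFeatures<float64_t>* features =
		new CStreamingDenseFeatures<float64_t>(file, true, 1024);
	SG_REF(features);

	// C=1.0: the two-argument constructor sets C1=C2=C and attaches the stream
	COnlineSVMSGD* svm = new COnlineSVMSGD(1.0, features);
	SG_REF(svm);
	svm->train();

	SG_UNREF(svm);
	SG_UNREF(features);
	exit_shogun();
	return 0;
}

With the default epochs=1 set in init(), a non-seekable stream suffices; multiple epochs require a seekable source, since train() resets the stream between passes.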