/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Shell Hu
 * Copyright (C) 2013 Shell Hu
 */

#include <shogun/mathematics/Math.h>
#include <shogun/structure/StochasticSOSVM.h>
#include <shogun/labels/LabelsFactory.h>

using namespace shogun;

CStochasticSOSVM::CStochasticSOSVM()
: CLinearStructuredOutputMachine()
{
    init();
}

CStochasticSOSVM::CStochasticSOSVM(
    CStructuredModel* model,
    CStructuredLabels* labs,
    bool do_weighted_averaging,
    bool verbose)
: CLinearStructuredOutputMachine(model, labs)
{
    REQUIRE(model != NULL && labs != NULL,
        "%s::CStochasticSOSVM(): model and labels cannot be NULL!\n", get_name());

    REQUIRE(labs->get_num_labels() > 0,
        "%s::CStochasticSOSVM(): number of labels should be greater than 0!\n", get_name());

    init();
    m_lambda = 1.0 / labs->get_num_labels();
    m_do_weighted_averaging = do_weighted_averaging;
    m_verbose = verbose;
}

void CStochasticSOSVM::init()
{
    SG_ADD(&m_lambda, "lambda", "Regularization constant", MS_NOT_AVAILABLE);
    SG_ADD(&m_num_iter, "num_iter", "Number of iterations", MS_NOT_AVAILABLE);
    SG_ADD(&m_do_weighted_averaging, "do_weighted_averaging", "Do weighted averaging", MS_NOT_AVAILABLE);
    SG_ADD(&m_debug_multiplier, "debug_multiplier", "Debug multiplier", MS_NOT_AVAILABLE);
    SG_ADD(&m_rand_seed, "rand_seed", "Random seed", MS_NOT_AVAILABLE);

    m_lambda = 1.0;
    m_num_iter = 50;
    m_do_weighted_averaging = true;
    m_debug_multiplier = 0;
    m_rand_seed = 1;
}

CStochasticSOSVM::~CStochasticSOSVM()
{
}

EMachineType CStochasticSOSVM::get_classifier_type()
{
    return CT_STOCHASTICSOSVM;
}

bool CStochasticSOSVM::train_machine(CFeatures* data)
{
    SG_DEBUG("Entering CStochasticSOSVM::train_machine.\n");
    if (data)
        set_features(data);

    // Initialize the model for training
    m_model->init_training();
    // Check that the scenario is correct to start training
    m_model->check_training_setup();
    SG_DEBUG("The training setup is correct.\n");

    // Dimensionality of the joint feature space
    int32_t M = m_model->get_dim();
    // Number of training examples
    int32_t N = CLabelsFactory::to_structured(m_labels)->get_num_labels();

    SG_DEBUG("M=%d, N=%d.\n", M, N);

    // Initialize the weight vector
    m_w = SGVector<float64_t>(M);
    m_w.zero();

    SGVector<float64_t> w_avg;
    if (m_do_weighted_averaging)
        w_avg = m_w.clone();

    // logging
    if (m_verbose)
    {
        if (m_helper != NULL)
            SG_UNREF(m_helper);

        m_helper = new CSOSVMHelper();
        SG_REF(m_helper);
    }

    int32_t debug_iter = 1;
    if (m_debug_multiplier == 0)
    {
        debug_iter = N;
        m_debug_multiplier = 100;
    }

    CMath::init_random(m_rand_seed);

    // Main loop
    int32_t k = 0;
    for (int32_t pi = 0; pi < m_num_iter; ++pi)
    {
        for (int32_t si = 0; si < N; ++si)
        {
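            // Each inner iteration performs one Pegasos-style stochastic
            // subgradient step on the regularized primal
            //     lambda/2 * ||w||^2 + 1/N * sum_i H_i(w),
            // where H_i is the structured hinge loss of example i.
            // With gamma = 1/(k+1) and w_s = psi_i/(lambda*N), the update
            // in step 5) is equivalent to
            //     w <- w - 1/(lambda*(k+1)) * (lambda*w - psi_i),
            // i.e. a subgradient step with decaying step size 1/(lambda*(k+1)).
            // Step 6) keeps a (t+1)-weighted running average of the iterates,
            // which replaces m_w at the end when m_do_weighted_averaging is set.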
            // 1) Picking random example
            int32_t i = CMath::random(0, N-1);

            // 2) solve the loss-augmented inference for point i
            CResultSet* result = m_model->argmax(m_w, i);

            // 3) get the subgradient
            // psi_i(y) := phi(x_i,y_i) - phi(x_i, y)
            SGVector<float64_t> psi_i(M);
            SGVector<float64_t> w_s(M);

            SGVector<float64_t>::add(psi_i.vector,
                1.0, result->psi_truth.vector, -1.0, result->psi_pred.vector, psi_i.vlen);

            w_s = psi_i.clone();
            w_s.scale(1.0 / (N*m_lambda));

            // 4) step-size gamma
            float64_t gamma = 1.0 / (k+1.0);

            // 5) finally update the weights
            SGVector<float64_t>::add(m_w.vector,
                1.0-gamma, m_w.vector, gamma*N, w_s.vector, m_w.vlen);

            // 6) Optionally, update the weighted average
            if (m_do_weighted_averaging)
            {
                float64_t rho = 2.0 / (k+2.0);
                SGVector<float64_t>::add(w_avg.vector,
                    1.0-rho, w_avg.vector, rho, m_w.vector, w_avg.vlen);
            }

            k += 1;
            SG_UNREF(result);

            // Debug: compute objective and training error
            if (m_verbose && k == debug_iter)
            {
                SGVector<float64_t> w_debug;
                if (m_do_weighted_averaging)
                    w_debug = w_avg.clone();
                else
                    w_debug = m_w.clone();

                float64_t primal = CSOSVMHelper::primal_objective(w_debug, m_model, m_lambda);
                float64_t train_error = CSOSVMHelper::average_loss(w_debug, m_model);

                SG_DEBUG("pass %d (iteration %d), SVM primal = %f, train_error = %f \n",
                    pi, k, primal, train_error);

                m_helper->add_debug_info(primal, (1.0*k) / N, train_error);

                debug_iter = CMath::min(debug_iter+N, debug_iter*(1+m_debug_multiplier/100));
            }
        }
    }

    if (m_do_weighted_averaging)
        m_w = w_avg.clone();

    if (m_verbose)
        m_helper->terminate();

    SG_DEBUG("Leaving CStochasticSOSVM::train_machine.\n");
    return true;
}

float64_t CStochasticSOSVM::get_lambda() const
{
    return m_lambda;
}

void CStochasticSOSVM::set_lambda(float64_t lbda)
{
    m_lambda = lbda;
}

int32_t CStochasticSOSVM::get_num_iter() const
{
    return m_num_iter;
}

void CStochasticSOSVM::set_num_iter(int32_t num_iter)
{
    m_num_iter = num_iter;
}

int32_t CStochasticSOSVM::get_debug_multiplier() const
{
    return m_debug_multiplier;
}

void CStochasticSOSVM::set_debug_multiplier(int32_t multiplier)
{
    m_debug_multiplier = multiplier;
}

uint32_t CStochasticSOSVM::get_rand_seed() const
{
    return m_rand_seed;
}

void CStochasticSOSVM::set_rand_seed(uint32_t rand_seed)
{
    m_rand_seed = rand_seed;
}
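
/* Usage sketch (illustrative): training this solver on a multiclass
 * structured problem.  It assumes a CMulticlassModel and CMulticlassSOLabels
 * built from dense features `feats` and a label vector `labs_vec`, as in
 * Shogun's multiclass SO-SVM examples; construction of those inputs and
 * reference-count handling of the intermediate objects are omitted.
 *
 *   CMulticlassSOLabels* labels = new CMulticlassSOLabels(labs_vec);
 *   CMulticlassModel* model = new CMulticlassModel(feats, labels);
 *
 *   CStochasticSOSVM* sosvm = new CStochasticSOSVM(model, labels, true, false);
 *   sosvm->set_num_iter(100);
 *   sosvm->set_lambda(0.01);
 *   sosvm->train();
 *
 *   CStructuredLabels* pred = CLabelsFactory::to_structured(sosvm->apply());
 *   SG_UNREF(pred);
 *   SG_UNREF(sosvm);
 */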