SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
StochasticSOSVM.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Shell Hu
00008  * Copyright (C) 2013 Shell Hu
00009  */
00010 
00011 #include <shogun/mathematics/Math.h>
00012 #include <shogun/structure/StochasticSOSVM.h>
00013 #include <shogun/labels/LabelsFactory.h>
00014 
00015 using namespace shogun;
00016 
/* Default constructor: delegates to the base machine and installs the
 * default hyper-parameters via init() (lambda=1.0, 50 iterations,
 * weighted averaging on, seed 1). */
CStochasticSOSVM::CStochasticSOSVM()
: CLinearStructuredOutputMachine()
{
	init();
}
00022 
00023 CStochasticSOSVM::CStochasticSOSVM(
00024         CStructuredModel*  model,
00025         CStructuredLabels* labs,
00026         bool do_weighted_averaging,
00027         bool verbose)
00028 : CLinearStructuredOutputMachine(model, labs)
00029 {
00030     REQUIRE(model != NULL && labs != NULL,
00031         "%s::CStochasticSOSVM(): model and labels cannot be NULL!\n", get_name());
00032 
00033     REQUIRE(labs->get_num_labels() > 0,
00034         "%s::CStochasticSOSVM(): number of labels should be greater than 0!\n", get_name());
00035 
00036     init();
00037     m_lambda = 1.0 / labs->get_num_labels();
00038     m_do_weighted_averaging = do_weighted_averaging;
00039     m_verbose = verbose;
00040 }
00041 
00042 void CStochasticSOSVM::init()
00043 {
00044     SG_ADD(&m_lambda, "lambda", "Regularization constant", MS_NOT_AVAILABLE);
00045     SG_ADD(&m_num_iter, "num_iter", "Number of iterations", MS_NOT_AVAILABLE);
00046     SG_ADD(&m_do_weighted_averaging, "do_weighted_averaging", "Do weighted averaging", MS_NOT_AVAILABLE);
00047     SG_ADD(&m_debug_multiplier, "debug_multiplier", "Debug multiplier", MS_NOT_AVAILABLE);
00048     SG_ADD(&m_rand_seed, "rand_seed", "Random seed", MS_NOT_AVAILABLE);
00049 
00050     m_lambda = 1.0;
00051     m_num_iter = 50;
00052     m_do_weighted_averaging = true;
00053     m_debug_multiplier = 0;
00054     m_rand_seed = 1;
00055 }
00056 
/* Destructor. Intentionally empty: this class adds only plain-value members;
 * ref-counted members (e.g. m_helper) are presumably released by the base
 * class destructor — TODO confirm against CLinearStructuredOutputMachine. */
CStochasticSOSVM::~CStochasticSOSVM()
{
}
00060 
/* @return the machine type tag identifying this solver (CT_STOCHASTICSOSVM). */
EMachineType CStochasticSOSVM::get_classifier_type()
{
	return CT_STOCHASTICSOSVM;
}
00065 
00066 bool CStochasticSOSVM::train_machine(CFeatures* data)
00067 {
00068     SG_DEBUG("Entering CStochasticSOSVM::train_machine.\n");
00069     if (data)
00070         set_features(data);
00071 
00072     // Initialize the model for training
00073     m_model->init_training();
00074     // Check that the scenary is correct to start with training
00075     m_model->check_training_setup();
00076     SG_DEBUG("The training setup is correct.\n");
00077 
00078     // Dimensionality of the joint feature space
00079     int32_t M = m_model->get_dim();
00080     // Number of training examples
00081     int32_t N = CLabelsFactory::to_structured(m_labels)->get_num_labels();
00082 
00083     SG_DEBUG("M=%d, N =%d.\n", M, N);
00084 
00085     // Initialize the weight vector
00086     m_w = SGVector<float64_t>(M);
00087     m_w.zero();
00088 
00089     SGVector<float64_t> w_avg;
00090     if (m_do_weighted_averaging)
00091         w_avg = m_w.clone();
00092 
00093     // logging
00094     if (m_verbose)
00095     {
00096         if (m_helper != NULL)
00097             SG_UNREF(m_helper);
00098 
00099         m_helper = new CSOSVMHelper();
00100         SG_REF(m_helper);
00101     }
00102 
00103     int32_t debug_iter = 1;
00104     if (m_debug_multiplier == 0)
00105     {
00106         debug_iter = N;
00107         m_debug_multiplier = 100;
00108     }
00109 
00110     CMath::init_random(m_rand_seed);
00111 
00112     // Main loop
00113     int32_t k = 0;
00114     for (int32_t pi = 0; pi < m_num_iter; ++pi)
00115     {
00116         for (int32_t si = 0; si < N; ++si)
00117         {
00118             // 1) Picking random example
00119             int32_t i = CMath::random(0, N-1);
00120 
00121             // 2) solve the loss-augmented inference for point i
00122             CResultSet* result = m_model->argmax(m_w, i);
00123 
00124             // 3) get the subgradient
00125             // psi_i(y) := phi(x_i,y_i) - phi(x_i, y)
00126             SGVector<float64_t> psi_i(M);
00127             SGVector<float64_t> w_s(M);
00128 
00129             SGVector<float64_t>::add(psi_i.vector,
00130                 1.0, result->psi_truth.vector, -1.0, result->psi_pred.vector, psi_i.vlen);
00131 
00132             w_s = psi_i.clone();
00133             w_s.scale(1.0 / (N*m_lambda));
00134 
00135             // 4) step-size gamma
00136             float64_t gamma = 1.0 / (k+1.0);
00137 
00138             // 5) finally update the weights
00139             SGVector<float64_t>::add(m_w.vector,
00140                 1.0-gamma, m_w.vector, gamma*N, w_s.vector, m_w.vlen);
00141 
00142             // 6) Optionally, update the weighted average
00143             if (m_do_weighted_averaging)
00144             {
00145                 float64_t rho = 2.0 / (k+2.0);
00146                 SGVector<float64_t>::add(w_avg.vector,
00147                     1.0-rho, w_avg.vector, rho, m_w.vector, w_avg.vlen);
00148             }
00149 
00150             k += 1;
00151             SG_UNREF(result);
00152 
00153             // Debug: compute objective and training error
00154             if (m_verbose && k == debug_iter)
00155             {
00156                 SGVector<float64_t> w_debug;
00157                 if (m_do_weighted_averaging)
00158                     w_debug = w_avg.clone();
00159                 else
00160                     w_debug = m_w.clone();
00161 
00162                 float64_t primal = CSOSVMHelper::primal_objective(w_debug, m_model, m_lambda);
00163                 float64_t train_error = CSOSVMHelper::average_loss(w_debug, m_model);
00164 
00165                 SG_DEBUG("pass %d (iteration %d), SVM primal = %f, train_error = %f \n",
00166                     pi, k, primal, train_error);
00167 
00168                 m_helper->add_debug_info(primal, (1.0*k) / N, train_error);
00169 
00170                 debug_iter = CMath::min(debug_iter+N, debug_iter*(1+m_debug_multiplier/100));
00171             }
00172         }
00173     }
00174 
00175     if (m_do_weighted_averaging)
00176         m_w = w_avg.clone();
00177 
00178     if (m_verbose)
00179         m_helper->terminate();
00180 
00181     SG_DEBUG("Leaving CStochasticSOSVM::train_machine.\n");
00182     return true;
00183 }
00184 
/* @return the regularization constant lambda. */
float64_t CStochasticSOSVM::get_lambda() const
{
	return m_lambda;
}
00189 
/* Set the regularization constant lambda (no validation is performed). */
void CStochasticSOSVM::set_lambda(float64_t lbda)
{
	m_lambda = lbda;
}
00194 
/* @return the number of training passes over the data. */
int32_t CStochasticSOSVM::get_num_iter() const
{
	return m_num_iter;
}
00199 
/* Set the number of training passes over the data (no validation is performed). */
void CStochasticSOSVM::set_num_iter(int32_t num_iter)
{
	m_num_iter = num_iter;
}
00204 
/* @return the debug multiplier controlling checkpoint spacing during training
 * (0 means once per pass; see train_machine). */
int32_t CStochasticSOSVM::get_debug_multiplier() const
{
	return m_debug_multiplier;
}
00209 
/* Set the debug multiplier controlling checkpoint spacing during training. */
void CStochasticSOSVM::set_debug_multiplier(int32_t multiplier)
{
	m_debug_multiplier = multiplier;
}
00214 
/* @return the random seed used to initialize CMath's RNG in train_machine. */
uint32_t CStochasticSOSVM::get_rand_seed() const
{
	return m_rand_seed;
}
00219 
/* Set the random seed used to initialize CMath's RNG in train_machine. */
void CStochasticSOSVM::set_rand_seed(uint32_t rand_seed)
{
	m_rand_seed = rand_seed;
}
00224 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation