SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
PrimalMosekSOSVM.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Fernando José Iglesias García
00008  * Copyright (C) 2012 Fernando José Iglesias García
00009  */
00010 
00011 #ifdef USE_MOSEK
00012 
00013 #include <shogun/lib/DynamicObjectArray.h>
00014 #include <shogun/lib/List.h>
00015 #include <shogun/mathematics/Math.h>
00016 #include <shogun/structure/PrimalMosekSOSVM.h>
00017 #include <shogun/loss/HingeLoss.h>
00018 
00019 using namespace shogun;
00020 
00021 CPrimalMosekSOSVM::CPrimalMosekSOSVM()
00022 : CLinearStructuredOutputMachine(),
00023     po_value(0.0)
00024 {
00025     init();
00026 }
00027 
00028 CPrimalMosekSOSVM::CPrimalMosekSOSVM(
00029         CStructuredModel*  model,
00030         CStructuredLabels* labs)
00031 : CLinearStructuredOutputMachine(model, labs),
00032     po_value(0.0)
00033 {
00034     init();
00035 }
00036 
00037 void CPrimalMosekSOSVM::init()
00038 {
00039     SG_ADD(&m_slacks, "slacks", "Slacks vector", MS_NOT_AVAILABLE);
00040     //FIXME model selection available for SO machines
00041     SG_ADD(&m_regularization, "regularization", "Regularization constant", MS_NOT_AVAILABLE);
00042     SG_ADD(&m_epsilon, "epsilon", "Violation tolerance", MS_NOT_AVAILABLE);
00043 
00044     m_regularization = 1.0;
00045     m_epsilon = 0.0;
00046 }
00047 
00048 CPrimalMosekSOSVM::~CPrimalMosekSOSVM()
00049 {
00050 }
00051 
00052 bool CPrimalMosekSOSVM::train_machine(CFeatures* data)
00053 {
00054     SG_DEBUG("Entering CPrimalMosekSOSVM::train_machine.\n");
00055     if (data)
00056         set_features(data);
00057 
00058     CFeatures* model_features = get_features();
00059     // Initialize the model for training
00060     m_model->init_training();
00061     // Check that the scenary is correct to start with training
00062     m_model->check_training_setup();
00063     SG_DEBUG("The training setup is correct.\n");
00064 
00065     // Dimensionality of the joint feature space
00066     int32_t M = m_model->get_dim();
00067     // Number of auxiliary variables in the optimization vector
00068     int32_t num_aux = m_model->get_num_aux();
00069     // Number of auxiliary constraints
00070     int32_t num_aux_con = m_model->get_num_aux_con();
00071     // Number of training examples
00072     int32_t N = model_features->get_num_vectors();
00073 
00074     SG_DEBUG("M=%d, N =%d, num_aux=%d, num_aux_con=%d.\n", M, N, num_aux, num_aux_con);
00075 
00076     // Interface with MOSEK
00077     CMosek* mosek = new CMosek(0, M+num_aux+N);
00078     SG_REF(mosek);
00079     REQUIRE(mosek->get_rescode() == MSK_RES_OK, "Mosek object could not be properly created in PrimalMosekSOSVM training.\n");
00080 
00081     // Initialize the terms of the optimization problem
00082     SGMatrix< float64_t > A, B, C;
00083     SGVector< float64_t > a, b, lb, ub;
00084     m_model->init_primal_opt(m_regularization, A, a, B, b, lb, ub, C);
00085 
00086     SG_DEBUG("Regularization used in PrimalMosekSOSVM equal to %.2f.\n", m_regularization);
00087 
00088     // Input terms of the problem that do not change between iterations
00089     REQUIRE(mosek->init_sosvm(M, N, num_aux, num_aux_con, C, lb, ub, A, b) == MSK_RES_OK,
00090         "Mosek error in PrimalMosekSOSVM initializing SO-SVM.\n")
00091 
00092     // Initialize the weight vector
00093     m_w = SGVector< float64_t >(M);
00094     m_w.zero();
00095 
00096     m_slacks = SGVector< float64_t >(N);
00097     m_slacks.zero();
00098 
00099     // Initialize the list of constraints
00100     // Each element in results is a list of CResultSet with the constraints
00101     // associated to each training example
00102     CDynamicObjectArray* results = new CDynamicObjectArray(N);
00103     SG_REF(results);
00104     for ( int32_t i = 0 ; i < N ; ++i )
00105     {
00106         CList* list = new CList(true);
00107         results->push_back(list);
00108     }
00109 
00110     // Initialize variables used in the loop
00111     int32_t     num_con     = num_aux_con;  // number of constraints
00112     int32_t     old_num_con = num_con;
00113     bool        exception   = false;
00114     index_t     iteration   = 0;
00115 
00116     SGVector< float64_t > sol(M+num_aux+N);
00117     sol.zero();
00118 
00119     SGVector< float64_t > aux(num_aux);
00120 
00121     do
00122     {
00123         SG_DEBUG("Iteration #%d: Cutting plane training with num_con=%d and old_num_con=%d.\n",
00124                 iteration, num_con, old_num_con);
00125 
00126         old_num_con = num_con;
00127 
00128         for ( int32_t i = 0 ; i < N ; ++i )
00129         {
00130             // Predict the result of the ith training example (loss-aug)
00131             CResultSet* result = m_model->argmax(m_w, i);
00132 
00133             // Compute the loss associated with the prediction (surrogate loss, max(0, \tilde{H}))
00134             float64_t slack = CHingeLoss().loss( compute_loss_arg(result) );
00135             CList* cur_list = (CList*) results->get_element(i);
00136 
00137             // Update the list of constraints
00138             if ( cur_list->get_num_elements() > 0 )
00139             {
00140                 // Find the maximum loss within the elements of
00141                 // the list of constraints
00142                 CResultSet* cur_res = (CResultSet*) cur_list->get_first_element();
00143                 float64_t max_slack = -CMath::INFTY;
00144 
00145                 while ( cur_res != NULL )
00146                 {
00147                     max_slack = CMath::max(max_slack, CHingeLoss().loss( compute_loss_arg(cur_res) ));
00148 
00149                     SG_UNREF(cur_res);
00150                     cur_res = (CResultSet*) cur_list->get_next_element();
00151                 }
00152 
00153                 if ( slack > max_slack + m_epsilon )
00154                 {
00155                     // The current training example is a
00156                     // violated constraint
00157                     if ( ! insert_result(cur_list, result) )
00158                     {
00159                         exception = true;
00160                         break;
00161                     }
00162 
00163                     add_constraint(mosek, result, num_con, i);
00164                     ++num_con;
00165                 }
00166             }
00167             else
00168             {
00169                 // First iteration of do ... while, add constraint
00170                 if ( ! insert_result(cur_list, result) )
00171                 {
00172                     exception = true;
00173                     break;
00174                 }
00175 
00176                 add_constraint(mosek, result, num_con, i);
00177                 ++num_con;
00178             }
00179 
00180             SG_UNREF(cur_list);
00181             SG_UNREF(result);
00182         }
00183 
00184         // Solve the QP
00185         SG_DEBUG("Entering Mosek QP solver.\n");
00186 
00187         mosek->optimize(sol);
00188         for ( int32_t i = 0 ; i < M+num_aux+N ; ++i )
00189         {
00190             if ( i < M )
00191                 m_w[i] = sol[i];
00192             else if ( i < M+num_aux )
00193                 aux[i-M] = sol[i];
00194             else
00195                 m_slacks[i-M-num_aux] = sol[i];
00196         }
00197 
00198         SG_DEBUG("QP solved. The primal objective value is %.4f.\n", mosek->get_primal_objective_value());
00199 
00200         ++iteration;
00201 
00202     } while ( old_num_con != num_con && ! exception );
00203 
00204     po_value = mosek->get_primal_objective_value();
00205 
00206     // Free resources
00207     SG_UNREF(results);
00208     SG_UNREF(mosek);
00209     SG_UNREF(model_features);
00210     return true;
00211 }
00212 
00213 float64_t CPrimalMosekSOSVM::compute_loss_arg(CResultSet* result) const
00214 {
00215     // Dimensionality of the joint feature space
00216     int32_t M = m_w.vlen;
00217 
00218     return  SGVector< float64_t >::dot(m_w.vector, result->psi_pred.vector, M) +
00219         result->delta -
00220         SGVector< float64_t >::dot(m_w.vector, result->psi_truth.vector, M);
00221 }
00222 
00223 bool CPrimalMosekSOSVM::insert_result(CList* result_list, CResultSet* result) const
00224 {
00225     bool succeed = result_list->insert_element(result);
00226 
00227     if ( ! succeed )
00228     {
00229         SG_PRINT("ResultSet could not be inserted in the list..."
00230              "aborting training of PrimalMosekSOSVM\n");
00231     }
00232 
00233     return succeed;
00234 }
00235 
00236 bool CPrimalMosekSOSVM::add_constraint(
00237         CMosek* mosek,
00238         CResultSet* result,
00239         index_t con_idx,
00240         index_t train_idx) const
00241 {
00242     int32_t M = m_model->get_dim();
00243     SGVector< float64_t > dPsi(M);
00244 
00245     for ( int i = 0 ; i < M ; ++i )
00246         dPsi[i] = result->psi_pred[i] - result->psi_truth[i]; // -dPsi(y)
00247 
00248     return ( mosek->add_constraint_sosvm(dPsi, con_idx, train_idx,
00249             m_model->get_num_aux(), -result->delta) == MSK_RES_OK );
00250 }
00251 
00252 
00253 float64_t CPrimalMosekSOSVM::compute_primal_objective() const
00254 {
00255     return po_value;
00256 }
00257 
00258 EMachineType CPrimalMosekSOSVM::get_classifier_type()
00259 {
00260     return CT_PRIMALMOSEKSOSVM;
00261 }
00262 
00263 void CPrimalMosekSOSVM::set_regularization(float64_t C)
00264 {
00265     m_regularization = C;
00266 }
00267 
00268 void CPrimalMosekSOSVM::set_epsilon(float64_t epsilon)
00269 {
00270     m_epsilon = epsilon;
00271 }
00272 
00273 #endif /* USE_MOSEK */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation