SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SOSVMHelper.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Shell Hu
00008  * Copyright (C) 2013 Shell Hu
00009  */
00010 
00011 #include <shogun/structure/SOSVMHelper.h>
00012 #include <shogun/base/Parameter.h>
00013 #include <shogun/labels/StructuredLabels.h>
00014 
00015 using namespace shogun;
00016 
00017 CSOSVMHelper::CSOSVMHelper() : CSGObject()
00018 {
00019     init();
00020 }
00021 
00022 CSOSVMHelper::CSOSVMHelper(int32_t bufsize) : CSGObject()
00023 {
00024     m_bufsize = bufsize;
00025     init();
00026 }
00027 
00028 CSOSVMHelper::~CSOSVMHelper()
00029 {
00030 }
00031 
00032 void CSOSVMHelper::init()
00033 {
00034     SG_ADD(&m_primal, "primal", "History of primal values", MS_NOT_AVAILABLE);
00035     SG_ADD(&m_dual, "dual", "History of dual values", MS_NOT_AVAILABLE);
00036     SG_ADD(&m_duality_gap, "duality_gap", "History of duality gaps", MS_NOT_AVAILABLE);
00037     SG_ADD(&m_eff_pass, "eff_pass", "Effective passes", MS_NOT_AVAILABLE);
00038     SG_ADD(&m_train_error, "train_error", "History of training errors", MS_NOT_AVAILABLE);
00039     SG_ADD(&m_tracker, "tracker", "Tracker of training progress", MS_NOT_AVAILABLE);
00040     SG_ADD(&m_bufsize, "bufsize", "Buffer size", MS_NOT_AVAILABLE);
00041 
00042     m_tracker = 0;
00043     m_bufsize = 1000;
00044     m_primal = SGVector<float64_t>(m_bufsize);
00045     m_dual = SGVector<float64_t>(m_bufsize);
00046     m_duality_gap = SGVector<float64_t>(m_bufsize);
00047     m_eff_pass = SGVector<float64_t>(m_bufsize);
00048     m_train_error = SGVector<float64_t>(m_bufsize);
00049     m_primal.zero();
00050     m_dual.zero();
00051     m_duality_gap.zero();
00052     m_eff_pass.zero();
00053     m_train_error.zero();
00054 }
00055 
00056 float64_t CSOSVMHelper::primal_objective(SGVector<float64_t> w, CStructuredModel* model, float64_t lbda)
00057 {
00058     float64_t hinge_losses = 0.0;
00059     CStructuredLabels* labels = model->get_labels();
00060     int32_t N = labels->get_num_labels();
00061     SG_UNREF(labels);
00062 
00063     for (int32_t i = 0; i < N; i++)
00064     {
00065         // solve the loss-augmented inference for point i
00066         CResultSet* result = model->argmax(w, i);
00067 
00068         // hinge loss for point i
00069         float64_t hinge_loss_i = result->score;
00070         ASSERT(hinge_loss_i >= 0);
00071 
00072         hinge_losses += hinge_loss_i;
00073 
00074         SG_UNREF(result);
00075     }
00076 
00077     return (lbda/2*SGVector<float64_t>::dot(w.vector, w.vector, w.vlen) + hinge_losses/N);
00078 }
00079 
00080 float64_t CSOSVMHelper::dual_objective(SGVector<float64_t> w, float64_t b_alpha, float64_t lbda)
00081 {
00082     return (lbda/2*SGVector<float64_t>::dot(w.vector, w.vector, w.vlen) - b_alpha);
00083 }
00084 
00085 float64_t CSOSVMHelper::average_loss(SGVector<float64_t> w, CStructuredModel* model)
00086 {
00087     float64_t loss = 0.0;
00088     CStructuredLabels* labels = model->get_labels();
00089     int32_t N = labels->get_num_labels();
00090     SG_UNREF(labels);
00091 
00092     for (int32_t i = 0; i < N; i++)
00093     {
00094         // solve the standard inference for point i
00095         CResultSet* result = model->argmax(w, i, false);
00096 
00097         loss += result->delta;
00098 
00099         SG_UNREF(result);
00100     }
00101 
00102     return loss / N;
00103 }
00104 
00105 void CSOSVMHelper::add_debug_info(float64_t primal, float64_t eff_pass, float64_t train_error,
00106         float64_t dual, float64_t dgap)
00107 {
00108     if (m_tracker >= m_bufsize)
00109     {
00110         SG_PRINT("%s::add_debug_information(): Buffer overflows! No more values will be recorded!\n",
00111             get_name());
00112 
00113         return;
00114     }
00115 
00116     m_primal[m_tracker] = primal;
00117     m_eff_pass[m_tracker] = eff_pass;
00118     m_train_error[m_tracker] = train_error;
00119 
00120     if (dgap >= 0)
00121     {
00122         m_dual[m_tracker] = dual;
00123         m_duality_gap[m_tracker] = dgap;
00124     }
00125 
00126     m_tracker++;
00127 }
00128 
00129 SGVector<float64_t> CSOSVMHelper::get_primal_values() const
00130 {
00131     return m_primal;
00132 }
00133 
00134 SGVector<float64_t> CSOSVMHelper::get_dual_values() const
00135 {
00136     return m_dual;
00137 }
00138 
00139 SGVector<float64_t> CSOSVMHelper::get_duality_gaps() const
00140 {
00141     return m_duality_gap;
00142 }
00143 
00144 SGVector<float64_t> CSOSVMHelper::get_eff_passes() const
00145 {
00146     return m_eff_pass;
00147 }
00148 
00149 SGVector<float64_t> CSOSVMHelper::get_train_errors() const
00150 {
00151     return m_train_error;
00152 }
00153 
00154 void CSOSVMHelper::terminate()
00155 {
00156     m_primal.resize_vector(m_tracker);
00157     m_dual.resize_vector(m_tracker);
00158     m_duality_gap.resize_vector(m_tracker);
00159     m_eff_pass.resize_vector(m_tracker);
00160     m_train_error.resize_vector(m_tracker);
00161 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation