SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #include <shogun/structure/SOSVMHelper.h> 00012 #include <shogun/base/Parameter.h> 00013 #include <shogun/labels/StructuredLabels.h> 00014 00015 using namespace shogun; 00016 00017 CSOSVMHelper::CSOSVMHelper() : CSGObject() 00018 { 00019 init(); 00020 } 00021 00022 CSOSVMHelper::CSOSVMHelper(int32_t bufsize) : CSGObject() 00023 { 00024 m_bufsize = bufsize; 00025 init(); 00026 } 00027 00028 CSOSVMHelper::~CSOSVMHelper() 00029 { 00030 } 00031 00032 void CSOSVMHelper::init() 00033 { 00034 SG_ADD(&m_primal, "primal", "History of primal values", MS_NOT_AVAILABLE); 00035 SG_ADD(&m_dual, "dual", "History of dual values", MS_NOT_AVAILABLE); 00036 SG_ADD(&m_duality_gap, "duality_gap", "History of duality gaps", MS_NOT_AVAILABLE); 00037 SG_ADD(&m_eff_pass, "eff_pass", "Effective passes", MS_NOT_AVAILABLE); 00038 SG_ADD(&m_train_error, "train_error", "History of training errors", MS_NOT_AVAILABLE); 00039 SG_ADD(&m_tracker, "tracker", "Tracker of training progress", MS_NOT_AVAILABLE); 00040 SG_ADD(&m_bufsize, "bufsize", "Buffer size", MS_NOT_AVAILABLE); 00041 00042 m_tracker = 0; 00043 m_bufsize = 1000; 00044 m_primal = SGVector<float64_t>(m_bufsize); 00045 m_dual = SGVector<float64_t>(m_bufsize); 00046 m_duality_gap = SGVector<float64_t>(m_bufsize); 00047 m_eff_pass = SGVector<float64_t>(m_bufsize); 00048 m_train_error = SGVector<float64_t>(m_bufsize); 00049 m_primal.zero(); 00050 m_dual.zero(); 00051 m_duality_gap.zero(); 00052 m_eff_pass.zero(); 00053 m_train_error.zero(); 00054 } 00055 00056 float64_t CSOSVMHelper::primal_objective(SGVector<float64_t> w, CStructuredModel* model, float64_t lbda) 00057 { 00058 float64_t hinge_losses = 0.0; 00059 CStructuredLabels* labels = model->get_labels(); 00060 int32_t N = labels->get_num_labels(); 00061 SG_UNREF(labels); 00062 00063 for (int32_t i = 0; i < N; i++) 00064 { 00065 // solve the loss-augmented inference for point i 00066 CResultSet* result = model->argmax(w, i); 00067 00068 // hinge loss for point i 00069 float64_t hinge_loss_i = result->score; 00070 ASSERT(hinge_loss_i >= 0); 00071 00072 hinge_losses += hinge_loss_i; 00073 00074 SG_UNREF(result); 00075 } 00076 00077 return (lbda/2*SGVector<float64_t>::dot(w.vector, w.vector, w.vlen) + hinge_losses/N); 00078 } 00079 00080 float64_t CSOSVMHelper::dual_objective(SGVector<float64_t> w, float64_t b_alpha, float64_t lbda) 00081 { 00082 return (lbda/2*SGVector<float64_t>::dot(w.vector, w.vector, w.vlen) - b_alpha); 00083 } 00084 00085 float64_t CSOSVMHelper::average_loss(SGVector<float64_t> w, CStructuredModel* model) 00086 { 00087 float64_t loss = 0.0; 00088 CStructuredLabels* labels = model->get_labels(); 00089 int32_t N = labels->get_num_labels(); 00090 SG_UNREF(labels); 00091 00092 for (int32_t i = 0; i < N; i++) 00093 { 00094 // solve the standard inference for point i 00095 CResultSet* result = model->argmax(w, i, false); 00096 00097 loss += result->delta; 00098 00099 SG_UNREF(result); 00100 } 00101 00102 return loss / N; 00103 } 00104 00105 void CSOSVMHelper::add_debug_info(float64_t primal, float64_t eff_pass, float64_t train_error, 00106 float64_t dual, float64_t dgap) 00107 { 00108 if (m_tracker >= m_bufsize) 00109 { 00110 SG_PRINT("%s::add_debug_information(): Buffer overflows! No more values will be recorded!\n", 00111 get_name()); 00112 00113 return; 00114 } 00115 00116 m_primal[m_tracker] = primal; 00117 m_eff_pass[m_tracker] = eff_pass; 00118 m_train_error[m_tracker] = train_error; 00119 00120 if (dgap >= 0) 00121 { 00122 m_dual[m_tracker] = dual; 00123 m_duality_gap[m_tracker] = dgap; 00124 } 00125 00126 m_tracker++; 00127 } 00128 00129 SGVector<float64_t> CSOSVMHelper::get_primal_values() const 00130 { 00131 return m_primal; 00132 } 00133 00134 SGVector<float64_t> CSOSVMHelper::get_dual_values() const 00135 { 00136 return m_dual; 00137 } 00138 00139 SGVector<float64_t> CSOSVMHelper::get_duality_gaps() const 00140 { 00141 return m_duality_gap; 00142 } 00143 00144 SGVector<float64_t> CSOSVMHelper::get_eff_passes() const 00145 { 00146 return m_eff_pass; 00147 } 00148 00149 SGVector<float64_t> CSOSVMHelper::get_train_errors() const 00150 { 00151 return m_train_error; 00152 } 00153 00154 void CSOSVMHelper::terminate() 00155 { 00156 m_primal.resize_vector(m_tracker); 00157 m_dual.resize_vector(m_tracker); 00158 m_duality_gap.resize_vector(m_tracker); 00159 m_eff_pass.resize_vector(m_tracker); 00160 m_train_error.resize_vector(m_tracker); 00161 }