SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Written (W) 2013 Thoralf Klein 00009 * Written (W) 2012 Fernando José Iglesias García 00010 * Copyright (C) 2012 Fernando José Iglesias García 00011 */ 00012 00013 #include <shogun/machine/StructuredOutputMachine.h> 00014 00015 using namespace shogun; 00016 00017 CStructuredOutputMachine::CStructuredOutputMachine() 00018 : CMachine(), m_model(NULL), m_surrogate_loss(NULL) 00019 { 00020 register_parameters(); 00021 } 00022 00023 CStructuredOutputMachine::CStructuredOutputMachine( 00024 CStructuredModel* model, 00025 CStructuredLabels* labs) 00026 : CMachine(), m_model(model), m_surrogate_loss(NULL) 00027 { 00028 SG_REF(m_model); 00029 set_labels(labs); 00030 register_parameters(); 00031 } 00032 00033 CStructuredOutputMachine::~CStructuredOutputMachine() 00034 { 00035 SG_UNREF(m_model); 00036 SG_UNREF(m_surrogate_loss); 00037 SG_UNREF(m_helper); 00038 } 00039 00040 void CStructuredOutputMachine::set_model(CStructuredModel* model) 00041 { 00042 SG_REF(model); 00043 SG_UNREF(m_model); 00044 m_model = model; 00045 } 00046 00047 CStructuredModel* CStructuredOutputMachine::get_model() const 00048 { 00049 SG_REF(m_model); 00050 return m_model; 00051 } 00052 00053 void CStructuredOutputMachine::register_parameters() 00054 { 00055 SG_ADD((CSGObject**)&m_model, "m_model", "Structured model", MS_NOT_AVAILABLE); 00056 SG_ADD((CSGObject**)&m_surrogate_loss, "m_surrogate_loss", "Surrogate loss", MS_NOT_AVAILABLE); 00057 SG_ADD(&m_verbose, "verbose", "Verbosity flag", MS_NOT_AVAILABLE); 00058 SG_ADD((CSGObject**)&m_helper, "helper", "Training helper", MS_NOT_AVAILABLE); 00059 00060 m_verbose = false; 00061 m_helper = NULL; 00062 } 00063 00064 void CStructuredOutputMachine::set_labels(CLabels* lab) 00065 { 00066 CMachine::set_labels(lab); 00067 REQUIRE(m_model != NULL, "please call set_model() before set_labels()\n"); 00068 m_model->set_labels(CLabelsFactory::to_structured(lab)); 00069 } 00070 00071 void CStructuredOutputMachine::set_features(CFeatures* f) 00072 { 00073 m_model->set_features(f); 00074 } 00075 00076 CFeatures* CStructuredOutputMachine::get_features() const 00077 { 00078 return m_model->get_features(); 00079 } 00080 00081 void CStructuredOutputMachine::set_surrogate_loss(CLossFunction* loss) 00082 { 00083 SG_REF(loss); 00084 SG_UNREF(m_surrogate_loss); 00085 m_surrogate_loss = loss; 00086 } 00087 00088 CLossFunction* CStructuredOutputMachine::get_surrogate_loss() const 00089 { 00090 SG_REF(m_surrogate_loss); 00091 return m_surrogate_loss; 00092 } 00093 00094 float64_t CStructuredOutputMachine::risk_nslack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info) 00095 { 00096 int32_t dim = m_model->get_dim(); 00097 00098 int32_t from=0, to=0; 00099 CFeatures* features = get_features(); 00100 if (info) 00101 { 00102 from = info->m_from; 00103 to = (info->m_N == 0) ? features->get_num_vectors() : from+info->m_N; 00104 } 00105 else 00106 { 00107 from = 0; 00108 to = features->get_num_vectors(); 00109 } 00110 SG_UNREF(features); 00111 00112 float64_t R = 0.0; 00113 for (int32_t i=0; i<dim; i++) 00114 subgrad[i] = 0; 00115 00116 for (int32_t i=from; i<to; i++) 00117 { 00118 CResultSet* result = m_model->argmax(SGVector<float64_t>(W,dim,false), i, true); 00119 SGVector<float64_t> psi_pred = result->psi_pred; 00120 SGVector<float64_t> psi_truth = result->psi_truth; 00121 SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, 1.0, psi_pred.vector, dim); 00122 SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, -1.0, psi_truth.vector, dim); 00123 R += result->score; 00124 SG_UNREF(result); 00125 } 00126 00127 return R; 00128 } 00129 00130 float64_t CStructuredOutputMachine::risk_nslack_slack_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info) 00131 { 00132 SG_ERROR("%s::risk_nslack_slack_rescale() has not been implemented!\n", get_name()); 00133 return 0.0; 00134 } 00135 00136 float64_t CStructuredOutputMachine::risk_1slack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info) 00137 { 00138 SG_ERROR("%s::risk_1slack_margin_rescale() has not been implemented!\n", get_name()); 00139 return 0.0; 00140 } 00141 00142 float64_t CStructuredOutputMachine::risk_1slack_slack_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info) 00143 { 00144 SG_ERROR("%s::risk_1slack_slack_rescale() has not been implemented!\n", get_name()); 00145 return 0.0; 00146 } 00147 00148 float64_t CStructuredOutputMachine::risk_customized_formulation(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info) 00149 { 00150 SG_ERROR("%s::risk_customized_formulation() has not been implemented!\n", get_name()); 00151 return 0.0; 00152 } 00153 00154 float64_t CStructuredOutputMachine::risk(float64_t* subgrad, float64_t* W, 00155 TMultipleCPinfo* info, EStructRiskType rtype) 00156 { 00157 float64_t ret = 0.0; 00158 switch(rtype) 00159 { 00160 case N_SLACK_MARGIN_RESCALING: 00161 ret = risk_nslack_margin_rescale(subgrad, W, info); 00162 break; 00163 case N_SLACK_SLACK_RESCALING: 00164 ret = risk_nslack_slack_rescale(subgrad, W, info); 00165 break; 00166 case ONE_SLACK_MARGIN_RESCALING: 00167 ret = risk_1slack_margin_rescale(subgrad, W, info); 00168 break; 00169 case ONE_SLACK_SLACK_RESCALING: 00170 ret = risk_1slack_slack_rescale(subgrad, W, info); 00171 break; 00172 case CUSTOMIZED_RISK: 00173 ret = risk_customized_formulation(subgrad, W, info); 00174 break; 00175 default: 00176 SG_ERROR("%s::risk(): cannot recognize the risk type!\n", get_name()); 00177 ret = -1; 00178 break; 00179 } 00180 return ret; 00181 } 00182 00183 CSOSVMHelper* CStructuredOutputMachine::get_helper() const 00184 { 00185 if (m_helper == NULL) 00186 { 00187 SG_ERROR("%s::get_helper(): no helper has been created!" 00188 "Please set verbose before training!\n", get_name()); 00189 } 00190 00191 SG_REF(m_helper); 00192 return m_helper; 00193 } 00194 00195 void CStructuredOutputMachine::set_verbose(bool verbose) 00196 { 00197 m_verbose = verbose; 00198 } 00199 00200 bool CStructuredOutputMachine::get_verbose() const 00201 { 00202 return m_verbose; 00203 }