SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
StructuredOutputMachine.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Shell Hu
00008  * Written (W) 2013 Thoralf Klein
00009  * Written (W) 2012 Fernando José Iglesias García
00010  * Copyright (C) 2012 Fernando José Iglesias García
00011  */
00012 
00013 #include <shogun/machine/StructuredOutputMachine.h>
00014 
00015 using namespace shogun;
00016 
00017 CStructuredOutputMachine::CStructuredOutputMachine()
00018 : CMachine(), m_model(NULL), m_surrogate_loss(NULL)
00019 {
00020     register_parameters();
00021 }
00022 
00023 CStructuredOutputMachine::CStructuredOutputMachine(
00024         CStructuredModel*  model,
00025         CStructuredLabels* labs)
00026 : CMachine(), m_model(model), m_surrogate_loss(NULL)
00027 {
00028     SG_REF(m_model);
00029     set_labels(labs);
00030     register_parameters();
00031 }
00032 
00033 CStructuredOutputMachine::~CStructuredOutputMachine()
00034 {
00035     SG_UNREF(m_model);
00036     SG_UNREF(m_surrogate_loss);
00037     SG_UNREF(m_helper);
00038 }
00039 
00040 void CStructuredOutputMachine::set_model(CStructuredModel* model)
00041 {
00042     SG_REF(model);
00043     SG_UNREF(m_model);
00044     m_model = model;
00045 }
00046 
00047 CStructuredModel* CStructuredOutputMachine::get_model() const
00048 {
00049     SG_REF(m_model);
00050     return m_model;
00051 }
00052 
00053 void CStructuredOutputMachine::register_parameters()
00054 {
00055     SG_ADD((CSGObject**)&m_model, "m_model", "Structured model", MS_NOT_AVAILABLE);
00056     SG_ADD((CSGObject**)&m_surrogate_loss, "m_surrogate_loss", "Surrogate loss", MS_NOT_AVAILABLE);
00057     SG_ADD(&m_verbose, "verbose", "Verbosity flag", MS_NOT_AVAILABLE);
00058     SG_ADD((CSGObject**)&m_helper, "helper", "Training helper", MS_NOT_AVAILABLE);
00059 
00060     m_verbose = false;
00061     m_helper = NULL;
00062 }
00063 
00064 void CStructuredOutputMachine::set_labels(CLabels* lab)
00065 {
00066     CMachine::set_labels(lab);
00067     REQUIRE(m_model != NULL, "please call set_model() before set_labels()\n");
00068     m_model->set_labels(CLabelsFactory::to_structured(lab));
00069 }
00070 
00071 void CStructuredOutputMachine::set_features(CFeatures* f)
00072 {
00073     m_model->set_features(f);
00074 }
00075 
00076 CFeatures* CStructuredOutputMachine::get_features() const
00077 {
00078     return m_model->get_features();
00079 }
00080 
00081 void CStructuredOutputMachine::set_surrogate_loss(CLossFunction* loss)
00082 {
00083     SG_REF(loss);
00084     SG_UNREF(m_surrogate_loss);
00085     m_surrogate_loss = loss;
00086 }
00087 
00088 CLossFunction* CStructuredOutputMachine::get_surrogate_loss() const
00089 {
00090     SG_REF(m_surrogate_loss);
00091     return m_surrogate_loss;
00092 }
00093 
00094 float64_t CStructuredOutputMachine::risk_nslack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
00095 {
00096     int32_t dim = m_model->get_dim();
00097 
00098     int32_t from=0, to=0;
00099     CFeatures* features = get_features();
00100     if (info)
00101     {
00102         from = info->m_from;
00103         to = (info->m_N == 0) ? features->get_num_vectors() : from+info->m_N;
00104     }
00105     else
00106     {
00107         from = 0;
00108         to = features->get_num_vectors();
00109     }
00110     SG_UNREF(features);
00111 
00112     float64_t R = 0.0;
00113     for (int32_t i=0; i<dim; i++)
00114         subgrad[i] = 0;
00115 
00116     for (int32_t i=from; i<to; i++)
00117     {
00118         CResultSet* result = m_model->argmax(SGVector<float64_t>(W,dim,false), i, true);
00119         SGVector<float64_t> psi_pred = result->psi_pred;
00120         SGVector<float64_t> psi_truth = result->psi_truth;
00121         SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, 1.0, psi_pred.vector, dim);
00122         SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, -1.0, psi_truth.vector, dim);
00123         R += result->score;
00124         SG_UNREF(result);
00125     }
00126 
00127     return R;
00128 }
00129 
00130 float64_t CStructuredOutputMachine::risk_nslack_slack_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
00131 {
00132     SG_ERROR("%s::risk_nslack_slack_rescale() has not been implemented!\n", get_name());
00133     return 0.0;
00134 }
00135 
00136 float64_t CStructuredOutputMachine::risk_1slack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
00137 {
00138     SG_ERROR("%s::risk_1slack_margin_rescale() has not been implemented!\n", get_name());
00139     return 0.0;
00140 }
00141 
00142 float64_t CStructuredOutputMachine::risk_1slack_slack_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
00143 {
00144     SG_ERROR("%s::risk_1slack_slack_rescale() has not been implemented!\n", get_name());
00145     return 0.0;
00146 }
00147 
00148 float64_t CStructuredOutputMachine::risk_customized_formulation(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
00149 {
00150     SG_ERROR("%s::risk_customized_formulation() has not been implemented!\n", get_name());
00151     return 0.0;
00152 }
00153 
00154 float64_t CStructuredOutputMachine::risk(float64_t* subgrad, float64_t* W,
00155         TMultipleCPinfo* info, EStructRiskType rtype)
00156 {
00157     float64_t ret = 0.0;
00158     switch(rtype)
00159     {
00160         case N_SLACK_MARGIN_RESCALING:
00161             ret = risk_nslack_margin_rescale(subgrad, W, info);
00162             break;
00163         case N_SLACK_SLACK_RESCALING:
00164             ret = risk_nslack_slack_rescale(subgrad, W, info);
00165             break;
00166         case ONE_SLACK_MARGIN_RESCALING:
00167             ret = risk_1slack_margin_rescale(subgrad, W, info);
00168             break;
00169         case ONE_SLACK_SLACK_RESCALING:
00170             ret = risk_1slack_slack_rescale(subgrad, W, info);
00171             break;
00172         case CUSTOMIZED_RISK:
00173             ret = risk_customized_formulation(subgrad, W, info);
00174             break;
00175         default:
00176             SG_ERROR("%s::risk(): cannot recognize the risk type!\n", get_name());
00177             ret = -1;
00178             break;
00179     }
00180     return ret;
00181 }
00182 
00183 CSOSVMHelper* CStructuredOutputMachine::get_helper() const
00184 {
00185     if (m_helper == NULL)
00186     {
00187         SG_ERROR("%s::get_helper(): no helper has been created!"
00188             "Please set verbose before training!\n", get_name());
00189     }
00190 
00191     SG_REF(m_helper);
00192     return m_helper;
00193 }
00194 
00195 void CStructuredOutputMachine::set_verbose(bool verbose)
00196 {
00197     m_verbose = verbose;
00198 }
00199 
00200 bool CStructuredOutputMachine::get_verbose() const
00201 {
00202     return m_verbose;
00203 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation