SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MAPInference.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Shell Hu
00008  * Copyright (C) 2013 Shell Hu
00009  */
00010 
00011 #include <shogun/structure/MAPInference.h>
00012 #include <shogun/structure/BeliefPropagation.h>
00013 #include <shogun/labels/FactorGraphLabels.h>
00014 
00015 using namespace shogun;
00016 
00017 CMAPInference::CMAPInference() : CSGObject()
00018 {
00019     SG_UNSTABLE("CMAPInference::CMAPInference()", "\n");
00020 
00021     init();
00022 }
00023 
00024 CMAPInference::CMAPInference(CFactorGraph* fg, EMAPInferType inference_method)
00025     : CSGObject()
00026 {
00027     init();
00028     m_fg = fg;
00029 
00030     REQUIRE(fg != NULL, "%s::CMAPInference(): fg cannot be NULL!\n", get_name());
00031 
00032     switch(inference_method)
00033     {
00034         case TREE_MAX_PROD:
00035             m_infer_impl = new CTreeMaxProduct(fg);
00036             break;
00037         case LOOPY_MAX_PROD:
00038             SG_ERROR("%s::CMAPInference(): LoopyMaxProduct has not been implemented!\n",
00039                 get_name());
00040             break;
00041         case LP_RELAXATION:
00042             SG_ERROR("%s::CMAPInference(): LPRelaxation has not been implemented!\n",
00043                 get_name());
00044             break;
00045         case TRWS_MAX_PROD:
00046             SG_ERROR("%s::CMAPInference(): TRW-S has not been implemented!\n",
00047                 get_name());
00048             break;
00049         case ITER_COND_MODE:
00050             SG_ERROR("%s::CMAPInference(): ICM has not been implemented!\n",
00051                 get_name());
00052             break;
00053         case NAIVE_MEAN_FIELD:
00054             SG_ERROR("%s::CMAPInference(): NaiveMeanField has not been implemented!\n",
00055                 get_name());
00056             break;
00057         case STRUCT_MEAN_FIELD:
00058             SG_ERROR("%s::CMAPInference(): StructMeanField has not been implemented!\n",
00059                 get_name());
00060             break;
00061         default:
00062             SG_ERROR("%s::CMAPInference(): unsupported inference method!\n",
00063                 get_name());
00064             break;
00065     }
00066 
00067     SG_REF(m_infer_impl);
00068     SG_REF(m_fg);
00069 }
00070 
00071 CMAPInference::~CMAPInference()
00072 {
00073     SG_UNREF(m_infer_impl);
00074     SG_UNREF(m_outputs);
00075     SG_UNREF(m_fg);
00076 }
00077 
00078 void CMAPInference::init()
00079 {
00080     SG_ADD((CSGObject**)&m_fg, "fg", "factor graph", MS_NOT_AVAILABLE);
00081     SG_ADD((CSGObject**)&m_outputs, "outputs", "Structured outputs", MS_NOT_AVAILABLE);
00082     SG_ADD((CSGObject**)&m_infer_impl, "infer_impl", "Inference implementation", MS_NOT_AVAILABLE);
00083     SG_ADD(&m_energy, "energy", "Minimized energy", MS_NOT_AVAILABLE);
00084 
00085     m_outputs = NULL;
00086     m_infer_impl = NULL;
00087     m_fg = NULL;
00088     m_energy = 0;
00089 }
00090 
00091 void CMAPInference::inference()
00092 {
00093     SGVector<int32_t> assignment(m_fg->get_num_vars());
00094     assignment.zero();
00095     m_energy = m_infer_impl->inference(assignment);
00096 
00097     // create structured output, with default normalized hamming loss
00098     SG_UNREF(m_outputs);
00099     SGVector<float64_t> loss_weights(m_fg->get_num_vars());
00100     SGVector<float64_t>::fill_vector(loss_weights.vector, loss_weights.vlen, 1.0 / loss_weights.vlen);
00101     m_outputs = new CFactorGraphObservation(assignment, loss_weights); // already ref() in constructor
00102     SG_REF(m_outputs);
00103 }
00104 
00105 CFactorGraphObservation* CMAPInference::get_structured_outputs() const
00106 {
00107     SG_REF(m_outputs);
00108     return m_outputs;
00109 }
00110 
00111 float64_t CMAPInference::get_energy() const
00112 {
00113     return m_energy;
00114 }
00115 
00116 //-----------------------------------------------------------------
00117 
00118 CMAPInferImpl::CMAPInferImpl() : CSGObject()
00119 {
00120     register_parameters();
00121 }
00122 
00123 CMAPInferImpl::CMAPInferImpl(CFactorGraph* fg)
00124     : CSGObject()
00125 {
00126     register_parameters();
00127     m_fg = fg;
00128 }
00129 
00130 CMAPInferImpl::~CMAPInferImpl()
00131 {
00132 }
00133 
00134 void CMAPInferImpl::register_parameters()
00135 {
00136     SG_ADD((CSGObject**)&m_fg, "fg",
00137         "Factor graph pointer", MS_NOT_AVAILABLE);
00138 
00139     m_fg = NULL;
00140 }
00141 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation