SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #include <shogun/structure/MAPInference.h> 00012 #include <shogun/structure/BeliefPropagation.h> 00013 #include <shogun/labels/FactorGraphLabels.h> 00014 00015 using namespace shogun; 00016 00017 CMAPInference::CMAPInference() : CSGObject() 00018 { 00019 SG_UNSTABLE("CMAPInference::CMAPInference()", "\n"); 00020 00021 init(); 00022 } 00023 00024 CMAPInference::CMAPInference(CFactorGraph* fg, EMAPInferType inference_method) 00025 : CSGObject() 00026 { 00027 init(); 00028 m_fg = fg; 00029 00030 REQUIRE(fg != NULL, "%s::CMAPInference(): fg cannot be NULL!\n", get_name()); 00031 00032 switch(inference_method) 00033 { 00034 case TREE_MAX_PROD: 00035 m_infer_impl = new CTreeMaxProduct(fg); 00036 break; 00037 case LOOPY_MAX_PROD: 00038 SG_ERROR("%s::CMAPInference(): LoopyMaxProduct has not been implemented!\n", 00039 get_name()); 00040 break; 00041 case LP_RELAXATION: 00042 SG_ERROR("%s::CMAPInference(): LPRelaxation has not been implemented!\n", 00043 get_name()); 00044 break; 00045 case TRWS_MAX_PROD: 00046 SG_ERROR("%s::CMAPInference(): TRW-S has not been implemented!\n", 00047 get_name()); 00048 break; 00049 case ITER_COND_MODE: 00050 SG_ERROR("%s::CMAPInference(): ICM has not been implemented!\n", 00051 get_name()); 00052 break; 00053 case NAIVE_MEAN_FIELD: 00054 SG_ERROR("%s::CMAPInference(): NaiveMeanField has not been implemented!\n", 00055 get_name()); 00056 break; 00057 case STRUCT_MEAN_FIELD: 00058 SG_ERROR("%s::CMAPInference(): StructMeanField has not been implemented!\n", 00059 get_name()); 00060 break; 00061 default: 00062 SG_ERROR("%s::CMAPInference(): unsupported inference method!\n", 00063 get_name()); 00064 break; 00065 } 00066 00067 SG_REF(m_infer_impl); 00068 SG_REF(m_fg); 00069 } 00070 00071 CMAPInference::~CMAPInference() 00072 { 00073 SG_UNREF(m_infer_impl); 00074 SG_UNREF(m_outputs); 00075 SG_UNREF(m_fg); 00076 } 00077 00078 void CMAPInference::init() 00079 { 00080 SG_ADD((CSGObject**)&m_fg, "fg", "factor graph", MS_NOT_AVAILABLE); 00081 SG_ADD((CSGObject**)&m_outputs, "outputs", "Structured outputs", MS_NOT_AVAILABLE); 00082 SG_ADD((CSGObject**)&m_infer_impl, "infer_impl", "Inference implementation", MS_NOT_AVAILABLE); 00083 SG_ADD(&m_energy, "energy", "Minimized energy", MS_NOT_AVAILABLE); 00084 00085 m_outputs = NULL; 00086 m_infer_impl = NULL; 00087 m_fg = NULL; 00088 m_energy = 0; 00089 } 00090 00091 void CMAPInference::inference() 00092 { 00093 SGVector<int32_t> assignment(m_fg->get_num_vars()); 00094 assignment.zero(); 00095 m_energy = m_infer_impl->inference(assignment); 00096 00097 // create structured output, with default normalized hamming loss 00098 SG_UNREF(m_outputs); 00099 SGVector<float64_t> loss_weights(m_fg->get_num_vars()); 00100 SGVector<float64_t>::fill_vector(loss_weights.vector, loss_weights.vlen, 1.0 / loss_weights.vlen); 00101 m_outputs = new CFactorGraphObservation(assignment, loss_weights); // already ref() in constructor 00102 SG_REF(m_outputs); 00103 } 00104 00105 CFactorGraphObservation* CMAPInference::get_structured_outputs() const 00106 { 00107 SG_REF(m_outputs); 00108 return m_outputs; 00109 } 00110 00111 float64_t CMAPInference::get_energy() const 00112 { 00113 return m_energy; 00114 } 00115 00116 //----------------------------------------------------------------- 00117 00118 CMAPInferImpl::CMAPInferImpl() : CSGObject() 00119 { 00120 register_parameters(); 00121 } 00122 00123 CMAPInferImpl::CMAPInferImpl(CFactorGraph* fg) 00124 : CSGObject() 00125 { 00126 register_parameters(); 00127 m_fg = fg; 00128 } 00129 00130 CMAPInferImpl::~CMAPInferImpl() 00131 { 00132 } 00133 00134 void CMAPInferImpl::register_parameters() 00135 { 00136 SG_ADD((CSGObject**)&m_fg, "fg", 00137 "Factor graph pointer", MS_NOT_AVAILABLE); 00138 00139 m_fg = NULL; 00140 } 00141