SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #ifndef __BELIEF_PROPAGATION_H__ 00012 #define __BELIEF_PROPAGATION_H__ 00013 00014 #include <shogun/lib/SGVector.h> 00015 #include <shogun/structure/FactorGraph.h> 00016 #include <shogun/structure/MAPInference.h> 00017 00018 #include <vector> 00019 #include <set> 00020 00021 #ifdef HAVE_STD_UNORDERED_MAP 00022 #include <unordered_map> 00023 #else 00024 #include <tr1/unordered_map> 00025 #endif 00026 00027 #ifndef DOXYGEN_SHOULD_SKIP_THIS 00028 00029 namespace shogun 00030 { 00031 #define IGNORE_IN_CLASSLIST 00032 00033 enum ENodeType 00034 { 00035 VAR_NODE = 0, 00036 FAC_NODE = 1 00037 }; 00038 00039 enum EEdgeType 00040 { 00041 VAR_TO_FAC = 0, 00042 FAC_TO_VAR = 1 00043 }; 00044 00045 struct GraphNode 00046 { 00047 int32_t node_id; 00048 ENodeType node_type; // 1 var, 0 factor 00049 int32_t parent; // where came from 00050 00051 GraphNode(int32_t id, ENodeType type, int32_t pa) 00052 : node_id(id), node_type(type), parent(pa) { } 00053 ~GraphNode() { } 00054 }; 00055 00056 struct MessageEdge 00057 { 00058 EEdgeType mtype; // 1 var_to_factor, 0 factor_to_var 00059 int32_t child; 00060 int32_t parent; 00061 00062 MessageEdge(EEdgeType type, int32_t ch, int32_t pa) 00063 : mtype(type), child(ch), parent(pa) { } 00064 00065 ~MessageEdge() { } 00066 00067 inline int32_t get_var_node() 00068 { 00069 return mtype == VAR_TO_FAC ? child : parent; 00070 } 00071 00072 inline int32_t get_factor_node() 00073 { 00074 return mtype == VAR_TO_FAC ? parent : child; 00075 } 00076 }; 00077 00079 IGNORE_IN_CLASSLIST class CBeliefPropagation : public CMAPInferImpl 00080 { 00081 public: 00082 CBeliefPropagation(); 00083 CBeliefPropagation(CFactorGraph* fg); 00084 00085 virtual ~CBeliefPropagation(); 00086 00088 virtual const char* get_name() const { return "BeliefPropagation"; } 00089 00090 virtual float64_t inference(SGVector<int32_t> assignment); 00091 00092 protected: 00093 float64_t m_map_energy; 00094 }; 00095 00104 IGNORE_IN_CLASSLIST class CTreeMaxProduct : public CBeliefPropagation 00105 { 00106 #ifdef HAVE_STD_UNORDERED_MAP 00107 typedef std::unordered_map<uint32_t, uint32_t> msg_map_type; 00108 typedef std::unordered_map<uint32_t, std::set<uint32_t> > msgset_map_type; 00109 typedef std::unordered_multimap<int32_t, int32_t> var_factor_map_type; 00110 #else 00111 typedef std::tr1::unordered_map<uint32_t, uint32_t> msg_map_type; 00112 typedef std::tr1::unordered_map<uint32_t, std::set<uint32_t> > msgset_map_type; 00113 typedef std::tr1::unordered_multimap<int32_t, int32_t> var_factor_map_type; 00114 #endif 00115 00116 public: 00117 CTreeMaxProduct(); 00118 CTreeMaxProduct(CFactorGraph* fg); 00119 00120 virtual ~CTreeMaxProduct(); 00121 00123 virtual const char* get_name() const { return "TreeMaxProduct"; } 00124 00125 virtual float64_t inference(SGVector<int32_t> assignment); 00126 00127 protected: 00128 void bottom_up_pass(); 00129 void top_down_pass(); 00130 void get_message_order(std::vector<MessageEdge*>& order, std::vector<bool>& is_root) const; 00131 00132 private: 00133 void init(); 00134 00135 private: 00136 std::vector<MessageEdge*> m_msg_order; 00137 std::vector<bool> m_is_root; 00138 std::vector< std::vector<float64_t> > m_fw_msgs; 00139 std::vector< std::vector<float64_t> > m_bw_msgs; 00140 std::vector<int32_t> m_states; 00141 00142 msg_map_type m_msg_map_var; 00143 msg_map_type m_msg_map_fac; 00144 msgset_map_type m_msgset_map_var; 00145 }; 00146 00147 } 00148 00149 #endif /* DOXYGEN_SHOULD_SKIP_THIS */ 00150 00151 #endif