SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
BeliefPropagation.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Shell Hu
00008  * Copyright (C) 2013 Shell Hu
00009  */
00010 
00011 #include <shogun/structure/BeliefPropagation.h>
00012 #include <shogun/lib/DynamicObjectArray.h>
00013 #include <shogun/io/SGIO.h>
00014 #include <numeric>
00015 #include <algorithm>
00016 #include <functional>
00017 #include <stack>
00018 
00019 using namespace shogun;
00020 
00021 CBeliefPropagation::CBeliefPropagation()
00022     : CMAPInferImpl()
00023 {
00024     SG_UNSTABLE("CBeliefPropagation::CBeliefPropagation()", "\n");
00025 }
00026 
00027 CBeliefPropagation::CBeliefPropagation(CFactorGraph* fg)
00028     : CMAPInferImpl(fg)
00029 {
00030 }
00031 
00032 CBeliefPropagation::~CBeliefPropagation()
00033 {
00034 }
00035 
00036 float64_t CBeliefPropagation::inference(SGVector<int32_t> assignment)
00037 {
00038     SG_ERROR("%s::inference(): please use TreeMaxProduct or LoopyMaxProduct!\n", get_name());
00039     return 0;
00040 }
00041 
00042 // -----------------------------------------------------------------
00043 
00044 CTreeMaxProduct::CTreeMaxProduct()
00045     : CBeliefPropagation()
00046 {
00047     SG_UNSTABLE("CTreeMaxProduct::CTreeMaxProduct()", "\n");
00048 
00049     init();
00050 }
00051 
00052 CTreeMaxProduct::CTreeMaxProduct(CFactorGraph* fg)
00053     : CBeliefPropagation(fg)
00054 {
00055     ASSERT(m_fg != NULL);
00056 
00057     init();
00058 
00059     CDisjointSet* dset = m_fg->get_disjoint_set();
00060     bool is_connected = dset->get_connected();
00061     SG_UNREF(dset);
00062 
00063     if (!is_connected)
00064         m_fg->connect_components();
00065 
00066     get_message_order(m_msg_order, m_is_root);
00067 
00068     // calculate lookup tables for forward messages
00069     // a key is unique because a tree has only one root
00070     // a var or a factor has only one edge towards root
00071     for (uint32_t mi = 0; mi < m_msg_order.size(); mi++)
00072     {
00073         if (m_msg_order[mi]->mtype == VAR_TO_FAC) // var_to_factor
00074         {
00075             // <var_id, msg_id>
00076             m_msg_map_var[m_msg_order[mi]->child] = mi;
00077         }
00078         else // factor_to_var
00079         {
00080             // <fac_id, msg_id>
00081             m_msg_map_fac[m_msg_order[mi]->child] = mi;
00082             // collect incoming msgs for each var_id
00083             m_msgset_map_var[m_msg_order[mi]->parent].insert(mi);
00084         }
00085     }
00086 
00087 }
00088 
00089 CTreeMaxProduct::~CTreeMaxProduct()
00090 {
00091     if (!m_msg_order.empty())
00092     {
00093         for (std::vector<MessageEdge*>::iterator it = m_msg_order.begin(); it != m_msg_order.end(); ++it)
00094             delete *it;
00095     }
00096 }
00097 
00098 void CTreeMaxProduct::init()
00099 {
00100     m_msg_order = std::vector<MessageEdge*>(m_fg->get_num_edges(), (MessageEdge*) NULL);
00101     m_is_root = std::vector<bool>(m_fg->get_cardinalities().size(), false);
00102     m_fw_msgs = std::vector< std::vector<float64_t> >(m_msg_order.size(),
00103             std::vector<float64_t>());
00104     m_bw_msgs = std::vector< std::vector<float64_t> >(m_msg_order.size(),
00105             std::vector<float64_t>());
00106     m_states = std::vector<int32_t>(m_fg->get_cardinalities().size(), 0);
00107 
00108     m_msg_map_var = msg_map_type();
00109     m_msg_map_fac = msg_map_type();
00110     m_msgset_map_var = msgset_map_type();
00111 }
00112 
00113 void CTreeMaxProduct::get_message_order(std::vector<MessageEdge*>& order,
00114     std::vector<bool>& is_root) const
00115 {
00116     ASSERT(m_fg->is_acyclic_graph());
00117 
00118     // 1) pick up roots according to union process of disjoint sets
00119     CDisjointSet* dset = m_fg->get_disjoint_set();
00120     if (!dset->get_connected())
00121     {
00122         SG_UNREF(dset);
00123         SG_ERROR("%s::get_root_indicators(): run connect_components() first!\n", get_name());
00124     }
00125 
00126     int32_t num_vars = m_fg->get_cardinalities().size();
00127     if (is_root.size() != (uint32_t)num_vars)
00128         is_root.resize(num_vars);
00129 
00130     std::fill(is_root.begin(), is_root.end(), false);
00131 
00132     for (int32_t vi = 0; vi < num_vars; vi++)
00133         is_root[dset->find_set(vi)] = true;
00134 
00135     SG_UNREF(dset);
00136     ASSERT(std::accumulate(is_root.begin(), is_root.end(), 0) >= 1);
00137 
00138     // 2) caculate message order
00139     // <var_id, fac_id>
00140     var_factor_map_type vf_map;
00141     CDynamicObjectArray* facs = m_fg->get_factors();
00142 
00143     for (int32_t fi = 0; fi < facs->get_num_elements(); ++fi)
00144     {
00145         CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fi));
00146         SGVector<int32_t> vars = fac->get_variables();
00147         for (int32_t vi = 0; vi < vars.size(); vi++)
00148             vf_map.insert(var_factor_map_type::value_type(vars[vi], fi));
00149 
00150         SG_UNREF(fac);
00151     }
00152 
00153     std::stack<GraphNode*> node_stack;
00154     // init node_stack with root nodes
00155     for (uint32_t ni = 0; ni < is_root.size(); ni++)
00156     {
00157         if (is_root[ni])
00158         {
00159             // node_id = ni, node_type = variable, parent = none
00160             node_stack.push(new GraphNode(ni, VAR_NODE, -1));
00161         }
00162     }
00163 
00164     if (order.size() != (uint32_t)(m_fg->get_num_edges()))
00165         order.resize(m_fg->get_num_edges());
00166 
00167     // find reverse order
00168     int32_t eid = m_fg->get_num_edges() - 1;
00169     while (!node_stack.empty())
00170     {
00171         GraphNode* node = node_stack.top();
00172         node_stack.pop();
00173 
00174         if (node->node_type == VAR_NODE) // child: factor -> parent: var
00175         {
00176             typedef var_factor_map_type::const_iterator const_iter;
00177             std::pair<const_iter, const_iter> adj_factors = vf_map.equal_range(node->node_id);
00178             for (const_iter mi = adj_factors.first; mi != adj_factors.second; ++mi)
00179             {
00180                 int32_t adj_factor_id = mi->second;
00181                 if (adj_factor_id == node->parent)
00182                     continue;
00183 
00184                 order[eid--] = new MessageEdge(FAC_TO_VAR, adj_factor_id, node->node_id);
00185                 // add factor node to node_stack
00186                 node_stack.push(new GraphNode(adj_factor_id, FAC_NODE, node->node_id));
00187             }
00188         }
00189         else // child: var -> parent: factor
00190         {
00191             CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(node->node_id));
00192             SGVector<int32_t> vars = fac->get_variables();
00193             SG_UNREF(fac);
00194 
00195             for (int32_t vi = 0; vi < vars.size(); vi++)
00196             {
00197                 if (vars[vi] == node->parent)
00198                     continue;
00199 
00200                 order[eid--] = new MessageEdge(VAR_TO_FAC, vars[vi], node->node_id);
00201                 // add variable node to node_stack
00202                 node_stack.push(new GraphNode(vars[vi], VAR_NODE, node->node_id));
00203             }
00204         }
00205 
00206         delete node;
00207     }
00208 
00209     SG_UNREF(facs);
00210 }
00211 
00212 float64_t CTreeMaxProduct::inference(SGVector<int32_t> assignment)
00213 {
00214     REQUIRE(assignment.size() == m_fg->get_cardinalities().size(),
00215         "%s::inference(): the output assignment should be prepared as"
00216         "the same size as variables!\n", get_name());
00217 
00218     bottom_up_pass();
00219     top_down_pass();
00220 
00221     for (int32_t vi = 0; vi < assignment.size(); vi++)
00222         assignment[vi] = m_states[vi];
00223 
00224     SG_DEBUG("fg.evaluate_energy(assignment) = %f\n", m_fg->evaluate_energy(assignment));
00225     SG_DEBUG("minimized energy = %f\n", -m_map_energy);
00226 
00227     return -m_map_energy;
00228 }
00229 
/**
 * Forward pass: walks m_msg_order from first to last entry and fills
 * m_fw_msgs with one message per edge.
 *
 * A var -> factor message is the element-wise sum of the factor -> var
 * messages registered for that variable in m_msgset_map_var. A
 * factor -> var message is, for every state of the target variable, the
 * maximum over all factor table entries with that state of the negated
 * energy plus the forward messages of the factor's other variables.
 *
 * Afterwards m_map_energy is set to the sum, over all root variables, of the
 * maximum of their summed incoming forward messages.
 */
00230 void CTreeMaxProduct::bottom_up_pass()
00231 {
00232     SG_DEBUG("\n***enter bottom_up_pass().\n");
00233     CDynamicObjectArray* facs = m_fg->get_factors();
00234     SGVector<int32_t> cards = m_fg->get_cardinalities();
00235 
00236     // init forward msgs to 0
00237     m_fw_msgs.resize(m_msg_order.size());
00238     for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
00239     {
00240         // msg size is determined by var node of msg edge
00241         m_fw_msgs[mi].resize(cards[m_msg_order[mi]->get_var_node()]);
00242         std::fill(m_fw_msgs[mi].begin(), m_fw_msgs[mi].end(), 0);
00243     }
00244 
00245     // pass msgs along the order up to root
00246     // if var -> factor
00247     //   compute q_v2f
00248     // else factor -> var
00249     //   compute r_f2v
00250     // where q_v2f and r_f2v are beliefs of the edge collecting from neighborhoods
00251     // by one end, which will be sent to another end, read Eq.(3.19), Eq.(3.20)
00252     // on [Nowozin et al. 2011] for more detail.
00253     for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
00254     {
00255         SG_DEBUG("mi = %d, mtype: %d %d -> %d\n", mi,
00256             m_msg_order[mi]->mtype, m_msg_order[mi]->child, m_msg_order[mi]->parent);
00257 
00258         if (m_msg_order[mi]->mtype == VAR_TO_FAC) // var -> factor
00259         {
00260             uint32_t var_id = m_msg_order[mi]->child;
                  // indices of the factor -> var messages whose parent is var_id
                  // (filled by the constructor)
00261             const std::set<uint32_t>& msgset_var = m_msgset_map_var[var_id];
00262 
00263             // q_v2f = sum(r_f2v), i.e. sum all incoming f2v msgs
                  // presumably every *cit precedes mi in m_msg_order, so those
                  // messages are already final - TODO confirm
00264             for (std::set<uint32_t>::const_iterator cit = msgset_var.begin(); cit != msgset_var.end(); cit++)
00265             {
                      // m_fw_msgs[mi] += m_fw_msgs[*cit] (element-wise)
00266                 std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
00267                     m_fw_msgs[mi].begin(),
00268                     m_fw_msgs[mi].begin(),
00269                     std::plus<float64_t>());
00270             }
00271         }
00272         else // factor -> var
00273         {
00274             int32_t fac_id = m_msg_order[mi]->child;
00275             int32_t var_id = m_msg_order[mi]->parent;
00276 
00277             CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id));
00278             CTableFactorType* ftype = fac->get_factor_type();
00279             SGVector<int32_t> fvars = fac->get_variables();
00280             SGVector<float64_t> fenrgs = fac->get_energies();
00281             SG_UNREF(fac);
00282 
00283             // find index of var_id in the factor
00284             SGVector<int32_t> fvar_set = fvars.find(var_id);
00285             ASSERT(fvar_set.vlen == 1);
00286             int32_t var_id_index = fvar_set[0];
00287 
                  // r_f2v: one value per factor table entry;
                  // r_f2v_max: one value per state of var_id, starting at
                  // -infinity so the first candidate always replaces it
00288             std::vector<float64_t> r_f2v(fenrgs.size(), 0);
00289             std::vector<float64_t> r_f2v_max(cards[var_id],
00290                 -std::numeric_limits<float64_t>::infinity());
00291 
00292             // TODO: optimize with index_from_new_state()
00293             // marginalization
00294             // r_f2v = max(-fenrg + sum_{j!=var_id} q_v2f[adj_var_state])
00295             for (int32_t ei = 0; ei < fenrgs.size(); ei++)
00296             {
00297                 r_f2v[ei] = -fenrgs[ei];
00298 
00299                 for (int32_t vi = 0; vi < fvars.size(); vi++)
00300                 {
                          // the target variable itself contributes no message
00301                     if (vi == var_id_index)
00302                         continue;
00303 
                          // index of the var -> factor message sent by the
                          // adjacent variable (filled by the constructor)
00304                     uint32_t adj_msg = m_msg_map_var[fvars[vi]];
00305                     int32_t adj_var_state = ftype->state_from_index(ei, vi);
00306 
00307                     r_f2v[ei] += m_fw_msgs[adj_msg][adj_var_state];
00308                 }
00309 
00310                 // find max value for each state of var_id
00311                 int32_t var_state = ftype->state_from_index(ei, var_id_index);
00312                 if (r_f2v[ei] > r_f2v_max[var_state])
00313                     r_f2v_max[var_state] = r_f2v[ei];
00314             }
00315 
00316             // in max-product, final r_f2v = r_f2v_max
00317             for (int32_t si = 0; si < cards[var_id]; si++)
00318                 m_fw_msgs[mi][si] = r_f2v_max[si];
00319 
00320             SG_UNREF(ftype);
00321         }
00322     }
00323     SG_UNREF(facs);
00324 
00325     // -energy = max(sum_{f} r_f2root)
00326     m_map_energy = 0;
00327     for (uint32_t ri = 0; ri < m_is_root.size(); ri++)
00328     {
00329         if (!m_is_root[ri])
00330             continue;
00331 
00332         const std::set<uint32_t>& msgset_rt = m_msgset_map_var[ri];
              // rmarg: sum of all forward messages arriving at root ri
00333         std::vector<float64_t> rmarg(cards[ri], 0);
00334         for (std::set<uint32_t>::const_iterator cit = msgset_rt.begin(); cit != msgset_rt.end(); cit++)
00335         {
00336             std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
00337                 rmarg.begin(),
00338                 rmarg.begin(),
00339                 std::plus<float64_t>());
00340         }
00341 
              // each root contributes the best value over its own states
00342         m_map_energy += *std::max_element(rmarg.begin(), rmarg.end());
00343     }
00344     SG_DEBUG("***leave bottom_up_pass().\n");
00345 }
00346 
/**
 * Backward pass: decodes a state for every variable into m_states and fills
 * m_bw_msgs.
 *
 * Root variables are decoded first, as the argmax of their summed incoming
 * forward messages. m_msg_order is then walked in reverse (from the roots
 * down). On a factor -> var edge the backward message into the factor is
 * built from the already decoded state of the variable, and the factor's
 * still undecoded variables take their state from the best factor table
 * entry. On a var -> factor edge the backward factor -> var message is
 * computed by max-marginalization.
 *
 * Reads m_fw_msgs, so bottom_up_pass() has to run before this function.
 */
00347 void CTreeMaxProduct::top_down_pass()
00348 {
00349     SG_DEBUG("\n***enter top_down_pass().\n");
          // sentinel stored in m_states for "state not inferred yet"
00350     int32_t minf = std::numeric_limits<int32_t>::max();
00351     CDynamicObjectArray* facs = m_fg->get_factors();
00352     SGVector<int32_t> cards = m_fg->get_cardinalities();
00353 
00354     // init backward msgs to 0
00355     m_bw_msgs.resize(m_msg_order.size());
00356     for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi)
00357     {
00358         // msg size is determined by var node of msg edge
00359         m_bw_msgs[mi].resize(cards[m_msg_order[mi]->get_var_node()]);
00360         std::fill(m_bw_msgs[mi].begin(), m_bw_msgs[mi].end(), 0);
00361     }
00362 
00363     // init states to the "not inferred" sentinel (largest int32_t)
00364     m_states.resize(cards.size());
00365     std::fill(m_states.begin(), m_states.end(), minf);
00366 
00367     // infer states of roots first since marginal distributions of
00368     // root variables are ready after bottom-up pass
00369     for (uint32_t ri = 0; ri < m_is_root.size(); ri++)
00370     {
00371         if (!m_is_root[ri])
00372             continue;
00373 
00374         const std::set<uint32_t>& msgset_rt = m_msgset_map_var[ri];
00375         std::vector<float64_t> rmarg(cards[ri], 0);
00376         for (std::set<uint32_t>::const_iterator cit = msgset_rt.begin(); cit != msgset_rt.end(); cit++)
00377         {
00378             // rmarg += m_fw_msgs[*cit]
00379             std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(),
00380                 rmarg.begin(),
00381                 rmarg.begin(),
00382                 std::plus<float64_t>());
00383         }
00384 
00385         // argmax (std::max_element returns the first maximum on ties)
00386         m_states[ri] = static_cast<int32_t>(
00387             std::max_element(rmarg.begin(), rmarg.end())
00388             - rmarg.begin());
00389     }
00390 
00391     // pass msgs down to leaf
00392     // if factor <- var edge
00393     //   compute q_v2f
00394     //   compute marginal of f
00395     // else var <- factor edge
00396     //   compute r_f2v
          // iterate m_msg_order backwards, i.e. opposite to bottom_up_pass()
00397     for (int32_t mi = (int32_t)(m_msg_order.size()-1); mi >= 0; --mi)
00398     {
00399         SG_DEBUG("mi = %d, mtype: %d %d <- %d\n", mi,
00400             m_msg_order[mi]->mtype, m_msg_order[mi]->child, m_msg_order[mi]->parent);
00401 
00402         if (m_msg_order[mi]->mtype == FAC_TO_VAR) // factor <- var
00403         {
00404             int32_t fac_id = m_msg_order[mi]->child;
00405             int32_t var_id = m_msg_order[mi]->parent;
00406 
00407             CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id));
00408             CTableFactorType* ftype = fac->get_factor_type();
00409             SGVector<int32_t> fvars = fac->get_variables();
00410             SGVector<float64_t> fenrgs = fac->get_energies();
00411             SG_UNREF(fac);
00412 
00413             // find index of var_id in the factor
00414             SGVector<int32_t> fvar_set = fvars.find(var_id);
00415             ASSERT(fvar_set.vlen == 1);
00416             int32_t var_id_index = fvar_set[0];
00417 
00418             // q_v2f = r_bw_parent2v + sum_{child!=f} r_fw_child2v
00419             // make sure the state of var_id has been inferred (factor marginalization)
00420             // s.t. this msg computation will condition on the known state
00421             ASSERT(m_states[var_id] != minf);
00422 
00423             // parent msg: r_bw_parent2v
                  // only non-root variables have an edge towards a parent factor;
                  // every entry of the message is set to the parent's backward
                  // message evaluated at the decoded state of var_id
00424             if (m_is_root[var_id] == 0)
00425             {
00426                 uint32_t parent_msg = m_msg_map_var[var_id];
00427                 std::fill(m_bw_msgs[mi].begin(), m_bw_msgs[mi].end(),
00428                     m_bw_msgs[parent_msg][m_states[var_id]]);
00429             }
00430 
00431             // siblings: sum_{child!=f} r_fw_child2v
00432             const std::set<uint32_t>& msgset_var = m_msgset_map_var[var_id];
00433             for (std::set<uint32_t>::const_iterator cit = msgset_var.begin();
00434                 cit != msgset_var.end(); cit++)
00435             {
                      // skip the forward message that came from fac_id itself
00436                 if (m_msg_order[*cit]->child == fac_id)
00437                     continue;
00438 
                      // the same scalar (sibling message at the decoded state)
                      // is added to every entry
00439                 for (uint32_t xi = 0; xi < m_bw_msgs[mi].size(); xi++)
00440                     m_bw_msgs[mi][xi] += m_fw_msgs[*cit][m_states[var_id]];
00441             }
00442 
00443             // m_states from maximizing marginal distributions of fac_id
00444             // mu(f) = -E(v_state) + sum_v q_v2f
00445             std::vector<float64_t> marg(fenrgs.size(), 0);
00446             for (uint32_t ei = 0; ei < marg.size(); ei++)
00447             {
                      // presumably the table index of ei with the entry of var_id
                      // replaced by its decoded state - TODO confirm against
                      // CTableFactorType::index_from_new_state()
00448                 int32_t nei = ftype->index_from_new_state(ei, var_id_index, m_states[var_id]);
00449                 marg[ei] = -fenrgs[nei];
00450 
00451                 for (int32_t vi = 0; vi < fvars.size(); vi++)
00452                 {
00453                     if (vi == var_id_index)
00454                     {
00455                         int32_t var_id_state = ftype->state_from_index(ei, var_id_index);
                              // NOTE(review): the ASSERT above already requires
                              // m_states[var_id] != minf, so this looks always true
00456                         if (m_states[var_id] != minf)
00457                             var_id_state = m_states[var_id];
00458 
00459                         marg[ei] += m_bw_msgs[mi][var_id_state];
00460                     }
00461                     else
00462                     {
00463                         uint32_t adj_id = fvars[vi];
00464                         uint32_t adj_msg = m_msg_map_var[adj_id];
00465                         int32_t adj_id_state = ftype->state_from_index(ei, vi);
00466 
00467                         marg[ei] += m_fw_msgs[adj_msg][adj_id_state];
00468                     }
00469                 }
00470             }
00471 
                  // first table entry with the maximal marginal value
00472             int32_t ei_max = static_cast<int32_t>(
00473                 std::max_element(marg.begin(), marg.end())
00474                 - marg.begin());
00475 
00476             // infer states of neighboring vars of f
00477             for (int32_t vi = 0; vi < fvars.size(); vi++)
00478             {
00479                 int32_t nvar_id = fvars[vi];
00480                 // usually parent node has been inferred
                      // variables that already have a state keep it
00481                 if (m_states[nvar_id] != minf)
00482                     continue;
00483 
00484                 int32_t nvar_id_state = ftype->state_from_index(ei_max, vi);
00485                 m_states[nvar_id] = nvar_id_state;
00486             }
00487 
00488             SG_UNREF(ftype);
00489         }
00490         else // var <- factor
00491         {
00492             int32_t var_id = m_msg_order[mi]->child;
00493             int32_t fac_id = m_msg_order[mi]->parent;
00494 
00495             CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id));
00496             CTableFactorType* ftype = fac->get_factor_type();
00497             SGVector<int32_t> fvars = fac->get_variables();
00498             SGVector<float64_t> fenrgs = fac->get_energies();
00499             SG_UNREF(fac);
00500 
00501             // find index of var_id in the factor
00502             SGVector<int32_t> fvar_set = fvars.find(var_id);
00503             ASSERT(fvar_set.vlen == 1);
00504             int32_t var_id_index = fvar_set[0];
00505 
                  // msg_parent: index of the edge from fac_id to its parent
                  // variable (filled by the constructor); var_parent: that variable
00506             uint32_t msg_parent = m_msg_map_fac[fac_id];
00507             int32_t var_parent = m_msg_order[msg_parent]->parent;
00508 
                  // r_f2v_max starts at -infinity so the first candidate of each
                  // state always replaces it
00509             std::vector<float64_t> r_f2v(fenrgs.size());
00510             std::vector<float64_t> r_f2v_max(cards[var_id],
00511                 -std::numeric_limits<float64_t>::infinity());
00512 
00513             // r_f2v = max(-fenrg + sum_{j!=var_id} q_v2f[adj_var_state])
00514             for (int32_t ei = 0; ei < fenrgs.size(); ei++)
00515             {
00516                 r_f2v[ei] = -fenrgs[ei];
00517 
00518                 for (int32_t vi = 0; vi < fvars.size(); vi++)
00519                 {
00520                     if (vi == var_id_index)
00521                         continue;
00522 
00523                     if (fvars[vi] == var_parent)
00524                     {
                              // the parent side contributes its backward message
                              // at the parent's decoded state
00525                         ASSERT(m_states[var_parent] != minf);
00526                         r_f2v[ei] += m_bw_msgs[msg_parent][m_states[var_parent]];
00527                     }
00528                     else
00529                     {
00530                         int32_t adj_id = fvars[vi];
00531                         uint32_t adj_msg = m_msg_map_var[adj_id];
00532                         int32_t adj_var_state = ftype->state_from_index(ei, vi);
00533 
                              // an already decoded sibling overrides the state
                              // implied by the table index
00534                         if (m_states[adj_id] != minf)
00535                             adj_var_state = m_states[adj_id];
00536 
00537                         r_f2v[ei] += m_fw_msgs[adj_msg][adj_var_state];
00538                     }
00539                 }
00540 
00541                 // max marginalization
00542                 int32_t var_id_state = ftype->state_from_index(ei, var_id_index);
00543                 if (r_f2v[ei] > r_f2v_max[var_id_state])
00544                     r_f2v_max[var_id_state] = r_f2v[ei];
00545             }
00546 
00547             for (int32_t si = 0; si < cards[var_id]; si++)
00548                 m_bw_msgs[mi][si] = r_f2v_max[si];
00549 
00550             SG_UNREF(ftype);
00551         }
00552     } // end for msg edge
00553 
00554     SG_UNREF(facs);
00555     SG_DEBUG("***leave top_down_pass().\n");
00556 }
00557 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation