SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #include <shogun/structure/BeliefPropagation.h> 00012 #include <shogun/lib/DynamicObjectArray.h> 00013 #include <shogun/io/SGIO.h> 00014 #include <numeric> 00015 #include <algorithm> 00016 #include <functional> 00017 #include <stack> 00018 00019 using namespace shogun; 00020 00021 CBeliefPropagation::CBeliefPropagation() 00022 : CMAPInferImpl() 00023 { 00024 SG_UNSTABLE("CBeliefPropagation::CBeliefPropagation()", "\n"); 00025 } 00026 00027 CBeliefPropagation::CBeliefPropagation(CFactorGraph* fg) 00028 : CMAPInferImpl(fg) 00029 { 00030 } 00031 00032 CBeliefPropagation::~CBeliefPropagation() 00033 { 00034 } 00035 00036 float64_t CBeliefPropagation::inference(SGVector<int32_t> assignment) 00037 { 00038 SG_ERROR("%s::inference(): please use TreeMaxProduct or LoopyMaxProduct!\n", get_name()); 00039 return 0; 00040 } 00041 00042 // ----------------------------------------------------------------- 00043 00044 CTreeMaxProduct::CTreeMaxProduct() 00045 : CBeliefPropagation() 00046 { 00047 SG_UNSTABLE("CTreeMaxProduct::CTreeMaxProduct()", "\n"); 00048 00049 init(); 00050 } 00051 00052 CTreeMaxProduct::CTreeMaxProduct(CFactorGraph* fg) 00053 : CBeliefPropagation(fg) 00054 { 00055 ASSERT(m_fg != NULL); 00056 00057 init(); 00058 00059 CDisjointSet* dset = m_fg->get_disjoint_set(); 00060 bool is_connected = dset->get_connected(); 00061 SG_UNREF(dset); 00062 00063 if (!is_connected) 00064 m_fg->connect_components(); 00065 00066 get_message_order(m_msg_order, m_is_root); 00067 00068 // calculate lookup tables for forward messages 00069 // a key is unique because a tree has only one root 00070 // a var or a factor has only one edge towards root 00071 for (uint32_t mi = 0; mi < m_msg_order.size(); mi++) 00072 { 00073 if (m_msg_order[mi]->mtype == VAR_TO_FAC) // var_to_factor 00074 { 00075 // <var_id, msg_id> 00076 m_msg_map_var[m_msg_order[mi]->child] = mi; 00077 } 00078 else // factor_to_var 00079 { 00080 // <fac_id, msg_id> 00081 m_msg_map_fac[m_msg_order[mi]->child] = mi; 00082 // collect incoming msgs for each var_id 00083 m_msgset_map_var[m_msg_order[mi]->parent].insert(mi); 00084 } 00085 } 00086 00087 } 00088 00089 CTreeMaxProduct::~CTreeMaxProduct() 00090 { 00091 if (!m_msg_order.empty()) 00092 { 00093 for (std::vector<MessageEdge*>::iterator it = m_msg_order.begin(); it != m_msg_order.end(); ++it) 00094 delete *it; 00095 } 00096 } 00097 00098 void CTreeMaxProduct::init() 00099 { 00100 m_msg_order = std::vector<MessageEdge*>(m_fg->get_num_edges(), (MessageEdge*) NULL); 00101 m_is_root = std::vector<bool>(m_fg->get_cardinalities().size(), false); 00102 m_fw_msgs = std::vector< std::vector<float64_t> >(m_msg_order.size(), 00103 std::vector<float64_t>()); 00104 m_bw_msgs = std::vector< std::vector<float64_t> >(m_msg_order.size(), 00105 std::vector<float64_t>()); 00106 m_states = std::vector<int32_t>(m_fg->get_cardinalities().size(), 0); 00107 00108 m_msg_map_var = msg_map_type(); 00109 m_msg_map_fac = msg_map_type(); 00110 m_msgset_map_var = msgset_map_type(); 00111 } 00112 00113 void CTreeMaxProduct::get_message_order(std::vector<MessageEdge*>& order, 00114 std::vector<bool>& is_root) const 00115 { 00116 ASSERT(m_fg->is_acyclic_graph()); 00117 00118 // 1) pick up roots according to union process of disjoint sets 00119 CDisjointSet* dset = m_fg->get_disjoint_set(); 00120 if (!dset->get_connected()) 00121 { 00122 SG_UNREF(dset); 00123 SG_ERROR("%s::get_root_indicators(): run connect_components() first!\n", get_name()); 00124 } 00125 00126 int32_t num_vars = m_fg->get_cardinalities().size(); 00127 if (is_root.size() != (uint32_t)num_vars) 00128 is_root.resize(num_vars); 00129 00130 std::fill(is_root.begin(), is_root.end(), false); 00131 00132 for (int32_t vi = 0; vi < num_vars; vi++) 00133 is_root[dset->find_set(vi)] = true; 00134 00135 SG_UNREF(dset); 00136 ASSERT(std::accumulate(is_root.begin(), is_root.end(), 0) >= 1); 00137 00138 // 2) caculate message order 00139 // <var_id, fac_id> 00140 var_factor_map_type vf_map; 00141 CDynamicObjectArray* facs = m_fg->get_factors(); 00142 00143 for (int32_t fi = 0; fi < facs->get_num_elements(); ++fi) 00144 { 00145 CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fi)); 00146 SGVector<int32_t> vars = fac->get_variables(); 00147 for (int32_t vi = 0; vi < vars.size(); vi++) 00148 vf_map.insert(var_factor_map_type::value_type(vars[vi], fi)); 00149 00150 SG_UNREF(fac); 00151 } 00152 00153 std::stack<GraphNode*> node_stack; 00154 // init node_stack with root nodes 00155 for (uint32_t ni = 0; ni < is_root.size(); ni++) 00156 { 00157 if (is_root[ni]) 00158 { 00159 // node_id = ni, node_type = variable, parent = none 00160 node_stack.push(new GraphNode(ni, VAR_NODE, -1)); 00161 } 00162 } 00163 00164 if (order.size() != (uint32_t)(m_fg->get_num_edges())) 00165 order.resize(m_fg->get_num_edges()); 00166 00167 // find reverse order 00168 int32_t eid = m_fg->get_num_edges() - 1; 00169 while (!node_stack.empty()) 00170 { 00171 GraphNode* node = node_stack.top(); 00172 node_stack.pop(); 00173 00174 if (node->node_type == VAR_NODE) // child: factor -> parent: var 00175 { 00176 typedef var_factor_map_type::const_iterator const_iter; 00177 std::pair<const_iter, const_iter> adj_factors = vf_map.equal_range(node->node_id); 00178 for (const_iter mi = adj_factors.first; mi != adj_factors.second; ++mi) 00179 { 00180 int32_t adj_factor_id = mi->second; 00181 if (adj_factor_id == node->parent) 00182 continue; 00183 00184 order[eid--] = new MessageEdge(FAC_TO_VAR, adj_factor_id, node->node_id); 00185 // add factor node to node_stack 00186 node_stack.push(new GraphNode(adj_factor_id, FAC_NODE, node->node_id)); 00187 } 00188 } 00189 else // child: var -> parent: factor 00190 { 00191 CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(node->node_id)); 00192 SGVector<int32_t> vars = fac->get_variables(); 00193 SG_UNREF(fac); 00194 00195 for (int32_t vi = 0; vi < vars.size(); vi++) 00196 { 00197 if (vars[vi] == node->parent) 00198 continue; 00199 00200 order[eid--] = new MessageEdge(VAR_TO_FAC, vars[vi], node->node_id); 00201 // add variable node to node_stack 00202 node_stack.push(new GraphNode(vars[vi], VAR_NODE, node->node_id)); 00203 } 00204 } 00205 00206 delete node; 00207 } 00208 00209 SG_UNREF(facs); 00210 } 00211 00212 float64_t CTreeMaxProduct::inference(SGVector<int32_t> assignment) 00213 { 00214 REQUIRE(assignment.size() == m_fg->get_cardinalities().size(), 00215 "%s::inference(): the output assignment should be prepared as" 00216 "the same size as variables!\n", get_name()); 00217 00218 bottom_up_pass(); 00219 top_down_pass(); 00220 00221 for (int32_t vi = 0; vi < assignment.size(); vi++) 00222 assignment[vi] = m_states[vi]; 00223 00224 SG_DEBUG("fg.evaluate_energy(assignment) = %f\n", m_fg->evaluate_energy(assignment)); 00225 SG_DEBUG("minimized energy = %f\n", -m_map_energy); 00226 00227 return -m_map_energy; 00228 } 00229 00230 void CTreeMaxProduct::bottom_up_pass() 00231 { 00232 SG_DEBUG("\n***enter bottom_up_pass().\n"); 00233 CDynamicObjectArray* facs = m_fg->get_factors(); 00234 SGVector<int32_t> cards = m_fg->get_cardinalities(); 00235 00236 // init forward msgs to 0 00237 m_fw_msgs.resize(m_msg_order.size()); 00238 for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi) 00239 { 00240 // msg size is determined by var node of msg edge 00241 m_fw_msgs[mi].resize(cards[m_msg_order[mi]->get_var_node()]); 00242 std::fill(m_fw_msgs[mi].begin(), m_fw_msgs[mi].end(), 0); 00243 } 00244 00245 // pass msgs along the order up to root 00246 // if var -> factor 00247 // compute q_v2f 00248 // else factor -> var 00249 // compute r_f2v 00250 // where q_v2f and r_f2v are beliefs of the edge collecting from neighborhoods 00251 // by one end, which will be sent to another end, read Eq.(3.19), Eq.(3.20) 00252 // on [Nowozin et al. 2011] for more detail. 00253 for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi) 00254 { 00255 SG_DEBUG("mi = %d, mtype: %d %d -> %d\n", mi, 00256 m_msg_order[mi]->mtype, m_msg_order[mi]->child, m_msg_order[mi]->parent); 00257 00258 if (m_msg_order[mi]->mtype == VAR_TO_FAC) // var -> factor 00259 { 00260 uint32_t var_id = m_msg_order[mi]->child; 00261 const std::set<uint32_t>& msgset_var = m_msgset_map_var[var_id]; 00262 00263 // q_v2f = sum(r_f2v), i.e. sum all incoming f2v msgs 00264 for (std::set<uint32_t>::const_iterator cit = msgset_var.begin(); cit != msgset_var.end(); cit++) 00265 { 00266 std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(), 00267 m_fw_msgs[mi].begin(), 00268 m_fw_msgs[mi].begin(), 00269 std::plus<float64_t>()); 00270 } 00271 } 00272 else // factor -> var 00273 { 00274 int32_t fac_id = m_msg_order[mi]->child; 00275 int32_t var_id = m_msg_order[mi]->parent; 00276 00277 CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id)); 00278 CTableFactorType* ftype = fac->get_factor_type(); 00279 SGVector<int32_t> fvars = fac->get_variables(); 00280 SGVector<float64_t> fenrgs = fac->get_energies(); 00281 SG_UNREF(fac); 00282 00283 // find index of var_id in the factor 00284 SGVector<int32_t> fvar_set = fvars.find(var_id); 00285 ASSERT(fvar_set.vlen == 1); 00286 int32_t var_id_index = fvar_set[0]; 00287 00288 std::vector<float64_t> r_f2v(fenrgs.size(), 0); 00289 std::vector<float64_t> r_f2v_max(cards[var_id], 00290 -std::numeric_limits<float64_t>::infinity()); 00291 00292 // TODO: optimize with index_from_new_state() 00293 // marginalization 00294 // r_f2v = max(-fenrg + sum_{j!=var_id} q_v2f[adj_var_state]) 00295 for (int32_t ei = 0; ei < fenrgs.size(); ei++) 00296 { 00297 r_f2v[ei] = -fenrgs[ei]; 00298 00299 for (int32_t vi = 0; vi < fvars.size(); vi++) 00300 { 00301 if (vi == var_id_index) 00302 continue; 00303 00304 uint32_t adj_msg = m_msg_map_var[fvars[vi]]; 00305 int32_t adj_var_state = ftype->state_from_index(ei, vi); 00306 00307 r_f2v[ei] += m_fw_msgs[adj_msg][adj_var_state]; 00308 } 00309 00310 // find max value for each state of var_id 00311 int32_t var_state = ftype->state_from_index(ei, var_id_index); 00312 if (r_f2v[ei] > r_f2v_max[var_state]) 00313 r_f2v_max[var_state] = r_f2v[ei]; 00314 } 00315 00316 // in max-product, final r_f2v = r_f2v_max 00317 for (int32_t si = 0; si < cards[var_id]; si++) 00318 m_fw_msgs[mi][si] = r_f2v_max[si]; 00319 00320 SG_UNREF(ftype); 00321 } 00322 } 00323 SG_UNREF(facs); 00324 00325 // -energy = max(sum_{f} r_f2root) 00326 m_map_energy = 0; 00327 for (uint32_t ri = 0; ri < m_is_root.size(); ri++) 00328 { 00329 if (!m_is_root[ri]) 00330 continue; 00331 00332 const std::set<uint32_t>& msgset_rt = m_msgset_map_var[ri]; 00333 std::vector<float64_t> rmarg(cards[ri], 0); 00334 for (std::set<uint32_t>::const_iterator cit = msgset_rt.begin(); cit != msgset_rt.end(); cit++) 00335 { 00336 std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(), 00337 rmarg.begin(), 00338 rmarg.begin(), 00339 std::plus<float64_t>()); 00340 } 00341 00342 m_map_energy += *std::max_element(rmarg.begin(), rmarg.end()); 00343 } 00344 SG_DEBUG("***leave bottom_up_pass().\n"); 00345 } 00346 00347 void CTreeMaxProduct::top_down_pass() 00348 { 00349 SG_DEBUG("\n***enter top_down_pass().\n"); 00350 int32_t minf = std::numeric_limits<int32_t>::max(); 00351 CDynamicObjectArray* facs = m_fg->get_factors(); 00352 SGVector<int32_t> cards = m_fg->get_cardinalities(); 00353 00354 // init backward msgs to 0 00355 m_bw_msgs.resize(m_msg_order.size()); 00356 for (uint32_t mi = 0; mi < m_msg_order.size(); ++mi) 00357 { 00358 // msg size is determined by var node of msg edge 00359 m_bw_msgs[mi].resize(cards[m_msg_order[mi]->get_var_node()]); 00360 std::fill(m_bw_msgs[mi].begin(), m_bw_msgs[mi].end(), 0); 00361 } 00362 00363 // init states to max infinity 00364 m_states.resize(cards.size()); 00365 std::fill(m_states.begin(), m_states.end(), minf); 00366 00367 // infer states of roots first since marginal distributions of 00368 // root variables are ready after bottom-up pass 00369 for (uint32_t ri = 0; ri < m_is_root.size(); ri++) 00370 { 00371 if (!m_is_root[ri]) 00372 continue; 00373 00374 const std::set<uint32_t>& msgset_rt = m_msgset_map_var[ri]; 00375 std::vector<float64_t> rmarg(cards[ri], 0); 00376 for (std::set<uint32_t>::const_iterator cit = msgset_rt.begin(); cit != msgset_rt.end(); cit++) 00377 { 00378 // rmarg += m_fw_msgs[*cit] 00379 std::transform(m_fw_msgs[*cit].begin(), m_fw_msgs[*cit].end(), 00380 rmarg.begin(), 00381 rmarg.begin(), 00382 std::plus<float64_t>()); 00383 } 00384 00385 // argmax 00386 m_states[ri] = static_cast<int32_t>( 00387 std::max_element(rmarg.begin(), rmarg.end()) 00388 - rmarg.begin()); 00389 } 00390 00391 // pass msgs down to leaf 00392 // if factor <- var edge 00393 // compute q_v2f 00394 // compute marginal of f 00395 // else var <- factor edge 00396 // compute r_f2v 00397 for (int32_t mi = (int32_t)(m_msg_order.size()-1); mi >= 0; --mi) 00398 { 00399 SG_DEBUG("mi = %d, mtype: %d %d <- %d\n", mi, 00400 m_msg_order[mi]->mtype, m_msg_order[mi]->child, m_msg_order[mi]->parent); 00401 00402 if (m_msg_order[mi]->mtype == FAC_TO_VAR) // factor <- var 00403 { 00404 int32_t fac_id = m_msg_order[mi]->child; 00405 int32_t var_id = m_msg_order[mi]->parent; 00406 00407 CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id)); 00408 CTableFactorType* ftype = fac->get_factor_type(); 00409 SGVector<int32_t> fvars = fac->get_variables(); 00410 SGVector<float64_t> fenrgs = fac->get_energies(); 00411 SG_UNREF(fac); 00412 00413 // find index of var_id in the factor 00414 SGVector<int32_t> fvar_set = fvars.find(var_id); 00415 ASSERT(fvar_set.vlen == 1); 00416 int32_t var_id_index = fvar_set[0]; 00417 00418 // q_v2f = r_bw_parent2v + sum_{child!=f} r_fw_child2v 00419 // make sure the state of var_id has been inferred (factor marginalization) 00420 // s.t. this msg computation will condition on the known state 00421 ASSERT(m_states[var_id] != minf); 00422 00423 // parent msg: r_bw_parent2v 00424 if (m_is_root[var_id] == 0) 00425 { 00426 uint32_t parent_msg = m_msg_map_var[var_id]; 00427 std::fill(m_bw_msgs[mi].begin(), m_bw_msgs[mi].end(), 00428 m_bw_msgs[parent_msg][m_states[var_id]]); 00429 } 00430 00431 // siblings: sum_{child!=f} r_fw_child2v 00432 const std::set<uint32_t>& msgset_var = m_msgset_map_var[var_id]; 00433 for (std::set<uint32_t>::const_iterator cit = msgset_var.begin(); 00434 cit != msgset_var.end(); cit++) 00435 { 00436 if (m_msg_order[*cit]->child == fac_id) 00437 continue; 00438 00439 for (uint32_t xi = 0; xi < m_bw_msgs[mi].size(); xi++) 00440 m_bw_msgs[mi][xi] += m_fw_msgs[*cit][m_states[var_id]]; 00441 } 00442 00443 // m_states from maximizing marginal distributions of fac_id 00444 // mu(f) = -E(v_state) + sum_v q_v2f 00445 std::vector<float64_t> marg(fenrgs.size(), 0); 00446 for (uint32_t ei = 0; ei < marg.size(); ei++) 00447 { 00448 int32_t nei = ftype->index_from_new_state(ei, var_id_index, m_states[var_id]); 00449 marg[ei] = -fenrgs[nei]; 00450 00451 for (int32_t vi = 0; vi < fvars.size(); vi++) 00452 { 00453 if (vi == var_id_index) 00454 { 00455 int32_t var_id_state = ftype->state_from_index(ei, var_id_index); 00456 if (m_states[var_id] != minf) 00457 var_id_state = m_states[var_id]; 00458 00459 marg[ei] += m_bw_msgs[mi][var_id_state]; 00460 } 00461 else 00462 { 00463 uint32_t adj_id = fvars[vi]; 00464 uint32_t adj_msg = m_msg_map_var[adj_id]; 00465 int32_t adj_id_state = ftype->state_from_index(ei, vi); 00466 00467 marg[ei] += m_fw_msgs[adj_msg][adj_id_state]; 00468 } 00469 } 00470 } 00471 00472 int32_t ei_max = static_cast<int32_t>( 00473 std::max_element(marg.begin(), marg.end()) 00474 - marg.begin()); 00475 00476 // infer states of neiboring vars of f 00477 for (int32_t vi = 0; vi < fvars.size(); vi++) 00478 { 00479 int32_t nvar_id = fvars[vi]; 00480 // usually parent node has been inferred 00481 if (m_states[nvar_id] != minf) 00482 continue; 00483 00484 int32_t nvar_id_state = ftype->state_from_index(ei_max, vi); 00485 m_states[nvar_id] = nvar_id_state; 00486 } 00487 00488 SG_UNREF(ftype); 00489 } 00490 else // var <- factor 00491 { 00492 int32_t var_id = m_msg_order[mi]->child; 00493 int32_t fac_id = m_msg_order[mi]->parent; 00494 00495 CFactor* fac = dynamic_cast<CFactor*>(facs->get_element(fac_id)); 00496 CTableFactorType* ftype = fac->get_factor_type(); 00497 SGVector<int32_t> fvars = fac->get_variables(); 00498 SGVector<float64_t> fenrgs = fac->get_energies(); 00499 SG_UNREF(fac); 00500 00501 // find index of var_id in the factor 00502 SGVector<int32_t> fvar_set = fvars.find(var_id); 00503 ASSERT(fvar_set.vlen == 1); 00504 int32_t var_id_index = fvar_set[0]; 00505 00506 uint32_t msg_parent = m_msg_map_fac[fac_id]; 00507 int32_t var_parent = m_msg_order[msg_parent]->parent; 00508 00509 std::vector<float64_t> r_f2v(fenrgs.size()); 00510 std::vector<float64_t> r_f2v_max(cards[var_id], 00511 -std::numeric_limits<float64_t>::infinity()); 00512 00513 // r_f2v = max(-fenrg + sum_{j!=var_id} q_v2f[adj_var_state]) 00514 for (int32_t ei = 0; ei < fenrgs.size(); ei++) 00515 { 00516 r_f2v[ei] = -fenrgs[ei]; 00517 00518 for (int32_t vi = 0; vi < fvars.size(); vi++) 00519 { 00520 if (vi == var_id_index) 00521 continue; 00522 00523 if (fvars[vi] == var_parent) 00524 { 00525 ASSERT(m_states[var_parent] != minf); 00526 r_f2v[ei] += m_bw_msgs[msg_parent][m_states[var_parent]]; 00527 } 00528 else 00529 { 00530 int32_t adj_id = fvars[vi]; 00531 uint32_t adj_msg = m_msg_map_var[adj_id]; 00532 int32_t adj_var_state = ftype->state_from_index(ei, vi); 00533 00534 if (m_states[adj_id] != minf) 00535 adj_var_state = m_states[adj_id]; 00536 00537 r_f2v[ei] += m_fw_msgs[adj_msg][adj_var_state]; 00538 } 00539 } 00540 00541 // max marginalization 00542 int32_t var_id_state = ftype->state_from_index(ei, var_id_index); 00543 if (r_f2v[ei] > r_f2v_max[var_id_state]) 00544 r_f2v_max[var_id_state] = r_f2v[ei]; 00545 } 00546 00547 for (int32_t si = 0; si < cards[var_id]; si++) 00548 m_bw_msgs[mi][si] = r_f2v_max[si]; 00549 00550 SG_UNREF(ftype); 00551 } 00552 } // end for msg edge 00553 00554 SG_UNREF(facs); 00555 SG_DEBUG("***leave top_down_pass().\n"); 00556 } 00557