SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #include <shogun/structure/FactorGraph.h> 00012 #include <shogun/labels/FactorGraphLabels.h> 00013 00014 using namespace shogun; 00015 00016 CFactorGraph::CFactorGraph() 00017 : CSGObject() 00018 { 00019 SG_UNSTABLE("CFactorGraph::CFactorGraph()", "\n"); 00020 00021 register_parameters(); 00022 init(); 00023 } 00024 00025 CFactorGraph::CFactorGraph(SGVector<int32_t> card) 00026 : CSGObject() 00027 { 00028 m_cardinalities = card; 00029 register_parameters(); 00030 init(); 00031 } 00032 00033 CFactorGraph::CFactorGraph(const CFactorGraph &fg) 00034 : CSGObject() 00035 { 00036 register_parameters(); 00037 m_cardinalities = fg.get_cardinalities(); 00038 // No need to unref and ref in this case 00039 m_factors = fg.get_factors(); 00040 m_datasources = fg.get_factor_data_sources(); 00041 m_dset = fg.get_disjoint_set(); 00042 m_has_cycle = !(fg.is_acyclic_graph()); 00043 m_num_edges = fg.get_num_edges(); 00044 } 00045 00046 CFactorGraph::~CFactorGraph() 00047 { 00048 SG_UNREF(m_factors); 00049 SG_UNREF(m_datasources); 00050 SG_UNREF(m_dset); 00051 00052 #ifdef USE_REFERENCE_COUNTING 00053 if (m_factors != NULL) 00054 SG_DEBUG("CFactorGraph::~CFactorGraph(): m_factors->ref_count() = %d.\n", m_factors->ref_count()); 00055 00056 if (m_datasources != NULL) 00057 SG_DEBUG("CFactorGraph::~CFactorGraph(): m_datasources->ref_count() = %d.\n", m_datasources->ref_count()); 00058 00059 SG_DEBUG("CFactorGraph::~CFactorGraph(): this->ref_count() = %d.\n", this->ref_count()); 00060 #endif 00061 } 00062 00063 void CFactorGraph::register_parameters() 00064 { 00065 SG_ADD(&m_cardinalities, "cardinalities", "Cardinalities", MS_NOT_AVAILABLE); 00066 SG_ADD((CSGObject**)&m_factors, "factors", "Factors", MS_NOT_AVAILABLE); 00067 SG_ADD((CSGObject**)&m_datasources, "datasources", "Factor data sources", MS_NOT_AVAILABLE); 00068 SG_ADD((CSGObject**)&m_dset, "dset", "Disjoint set", MS_NOT_AVAILABLE); 00069 SG_ADD(&m_has_cycle, "has_cycle", "Whether has circle in graph", MS_NOT_AVAILABLE); 00070 SG_ADD(&m_num_edges, "num_edges", "Number of edges", MS_NOT_AVAILABLE); 00071 } 00072 00073 void CFactorGraph::init() 00074 { 00075 m_has_cycle = false; 00076 m_num_edges = 0; 00077 m_factors = NULL; 00078 m_datasources = NULL; 00079 m_factors = new CDynamicObjectArray(); 00080 m_datasources = new CDynamicObjectArray(); 00081 00082 #ifdef USE_REFERENCE_COUNTING 00083 if (m_factors != NULL) 00084 SG_DEBUG("CFactorGraph::init(): m_factors->ref_count() = %d.\n", m_factors->ref_count()); 00085 #endif 00086 00087 // NOTE m_cards cannot be empty 00088 m_dset = new CDisjointSet(m_cardinalities.size()); 00089 00090 SG_REF(m_factors); 00091 SG_REF(m_datasources); 00092 SG_REF(m_dset); 00093 } 00094 00095 void CFactorGraph::add_factor(CFactor* factor) 00096 { 00097 m_factors->push_back(factor); 00098 m_num_edges += factor->get_variables().size(); 00099 00100 // graph structure changed after adding factors 00101 if (m_dset->get_connected()) 00102 m_dset->set_connected(false); 00103 } 00104 00105 void CFactorGraph::add_data_source(CFactorDataSource* datasource) 00106 { 00107 m_datasources->push_back(datasource); 00108 } 00109 00110 CDynamicObjectArray* CFactorGraph::get_factors() const 00111 { 00112 SG_REF(m_factors); 00113 return m_factors; 00114 } 00115 00116 CDynamicObjectArray* CFactorGraph::get_factor_data_sources() const 00117 { 00118 SG_REF(m_datasources); 00119 return m_datasources; 00120 } 00121 00122 int32_t CFactorGraph::get_num_factors() const 00123 { 00124 return m_factors->get_num_elements(); 00125 } 00126 00127 SGVector<int32_t> CFactorGraph::get_cardinalities() const 00128 { 00129 return m_cardinalities; 00130 } 00131 00132 void CFactorGraph::set_cardinalities(SGVector<int32_t> cards) 00133 { 00134 m_cardinalities = cards.clone(); 00135 } 00136 00137 CDisjointSet* CFactorGraph::get_disjoint_set() const 00138 { 00139 SG_REF(m_dset); 00140 return m_dset; 00141 } 00142 00143 int32_t CFactorGraph::get_num_edges() const 00144 { 00145 return m_num_edges; 00146 } 00147 00148 int32_t CFactorGraph::get_num_vars() const 00149 { 00150 return m_cardinalities.size(); 00151 } 00152 00153 void CFactorGraph::compute_energies() 00154 { 00155 for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi) 00156 { 00157 CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi)); 00158 fac->compute_energies(); 00159 SG_UNREF(fac); 00160 } 00161 } 00162 00163 float64_t CFactorGraph::evaluate_energy(const SGVector<int32_t> state) const 00164 { 00165 ASSERT(state.size() == m_cardinalities.size()); 00166 00167 float64_t energy = 0.0; 00168 for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi) 00169 { 00170 CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi)); 00171 energy += fac->evaluate_energy(state); 00172 SG_UNREF(fac); 00173 } 00174 return energy; 00175 } 00176 00177 float64_t CFactorGraph::evaluate_energy(const CFactorGraphObservation* obs) const 00178 { 00179 return evaluate_energy(obs->get_data()); 00180 } 00181 00182 SGVector<float64_t> CFactorGraph::evaluate_energies() const 00183 { 00184 int num_assig = 1; 00185 SGVector<int32_t> cumprod_cards(m_cardinalities.size()); 00186 for (int32_t n = 0; n < m_cardinalities.size(); ++n) 00187 { 00188 cumprod_cards[n] = num_assig; 00189 num_assig *= m_cardinalities[n]; 00190 } 00191 00192 SGVector<float64_t> etable(num_assig); 00193 for (int32_t ei = 0; ei < num_assig; ++ei) 00194 { 00195 SGVector<int32_t> assig(m_cardinalities.size()); 00196 for (int32_t vi = 0; vi < m_cardinalities.size(); ++vi) 00197 assig[vi] = (ei / cumprod_cards[vi]) % m_cardinalities[vi]; 00198 00199 etable[ei] = evaluate_energy(assig); 00200 00201 for (int32_t vi = 0; vi < m_cardinalities.size(); ++vi) 00202 SG_SPRINT("%d ", assig[vi]); 00203 00204 SG_SPRINT("| %f\n", etable[ei]); 00205 } 00206 00207 return etable; 00208 } 00209 00210 void CFactorGraph::connect_components() 00211 { 00212 if (m_dset->get_connected()) 00213 return; 00214 00215 // need to be reset once factor graph is updated 00216 m_dset->make_sets(); 00217 bool flag = false; 00218 00219 for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi) 00220 { 00221 CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi)); 00222 SGVector<int32_t> vars = fac->get_variables(); 00223 00224 int32_t r0 = m_dset->find_set(vars[0]); 00225 for (int32_t vi = 1; vi < vars.size(); vi++) 00226 { 00227 // for two nodes in a factor, should be an edge between them 00228 // but this time link() isn't performed, if they are linked already 00229 // means there is another path connected them, so cycle detected 00230 int32_t ri = m_dset->find_set(vars[vi]); 00231 00232 if (r0 == ri) 00233 { 00234 flag = true; 00235 continue; 00236 } 00237 00238 r0 = m_dset->link_set(r0, ri); 00239 } 00240 00241 SG_UNREF(fac); 00242 } 00243 m_has_cycle = flag; 00244 m_dset->set_connected(true); 00245 } 00246 00247 bool CFactorGraph::is_acyclic_graph() const 00248 { 00249 return !m_has_cycle; 00250 } 00251 00252 bool CFactorGraph::is_connected_graph() const 00253 { 00254 return (m_dset->get_num_sets() == 1); 00255 } 00256 00257 bool CFactorGraph::is_tree_graph() const 00258 { 00259 return (m_has_cycle == false && m_dset->get_num_sets() == 1); 00260 } 00261 00262 void CFactorGraph::loss_augmentation(CFactorGraphObservation* gt) 00263 { 00264 loss_augmentation(gt->get_data(), gt->get_loss_weights()); 00265 } 00266 00267 void CFactorGraph::loss_augmentation(SGVector<int32_t> states_gt, SGVector<float64_t> loss) 00268 { 00269 if (loss.size() == 0) 00270 { 00271 loss.resize_vector(states_gt.size()); 00272 SGVector<float64_t>::fill_vector(loss.vector, loss.vlen, 1.0 / states_gt.size()); 00273 } 00274 00275 int32_t num_vars = states_gt.size(); 00276 ASSERT(num_vars == loss.size()); 00277 00278 SGVector<int32_t> var_flags(num_vars); 00279 var_flags.zero(); 00280 00281 // augment loss to incorrect states in the first factor containing the variable 00282 // since += L_i for each variable if it takes wrong state ever 00283 // TODO: augment unary factors 00284 for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi) 00285 { 00286 CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi)); 00287 SGVector<int32_t> vars = fac->get_variables(); 00288 for (int32_t vi = 0; vi < vars.size(); vi++) 00289 { 00290 int32_t vv = vars[vi]; 00291 ASSERT(vv < num_vars); 00292 if (var_flags[vv]) 00293 continue; 00294 00295 SGVector<float64_t> energies = fac->get_energies(); 00296 for (int32_t ei = 0; ei < energies.size(); ei++) 00297 { 00298 CTableFactorType* ftype = fac->get_factor_type(); 00299 int32_t vstate = ftype->state_from_index(ei, vi); 00300 SG_UNREF(ftype); 00301 00302 if (states_gt[vv] == vstate) 00303 continue; 00304 00305 // -delta(y_n, y_i_n) 00306 fac->set_energy(ei, energies[ei] - loss[vv]); 00307 } 00308 00309 var_flags[vv] = 1; 00310 } 00311 00312 SG_UNREF(fac); 00313 } 00314 00315 // make sure all variables have been checked 00316 int32_t min_var = SGVector<int32_t>::min(var_flags.vector, var_flags.vlen); 00317 ASSERT(min_var == 1); 00318 } 00319