SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
FactorGraph.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Shell Hu
00008  * Copyright (C) 2013 Shell Hu
00009  */
00010 
#include <shogun/structure/FactorGraph.h>

#include <limits>

#include <shogun/labels/FactorGraphLabels.h>
00013 
00014 using namespace shogun;
00015 
00016 CFactorGraph::CFactorGraph()
00017     : CSGObject()
00018 {
00019     SG_UNSTABLE("CFactorGraph::CFactorGraph()", "\n");
00020 
00021     register_parameters();
00022     init();
00023 }
00024 
00025 CFactorGraph::CFactorGraph(SGVector<int32_t> card)
00026     : CSGObject()
00027 {
00028     m_cardinalities = card;
00029     register_parameters();
00030     init();
00031 }
00032 
00033 CFactorGraph::CFactorGraph(const CFactorGraph &fg)
00034     : CSGObject()
00035 {
00036     register_parameters();
00037     m_cardinalities = fg.get_cardinalities();
00038     // No need to unref and ref in this case
00039     m_factors = fg.get_factors();
00040     m_datasources = fg.get_factor_data_sources();
00041     m_dset = fg.get_disjoint_set();
00042     m_has_cycle = !(fg.is_acyclic_graph());
00043     m_num_edges = fg.get_num_edges();
00044 }
00045 
00046 CFactorGraph::~CFactorGraph()
00047 {
00048     SG_UNREF(m_factors);
00049     SG_UNREF(m_datasources);
00050     SG_UNREF(m_dset);
00051 
00052 #ifdef USE_REFERENCE_COUNTING
00053     if (m_factors != NULL)
00054         SG_DEBUG("CFactorGraph::~CFactorGraph(): m_factors->ref_count() = %d.\n", m_factors->ref_count());
00055 
00056     if (m_datasources != NULL)
00057         SG_DEBUG("CFactorGraph::~CFactorGraph(): m_datasources->ref_count() = %d.\n", m_datasources->ref_count());
00058 
00059     SG_DEBUG("CFactorGraph::~CFactorGraph(): this->ref_count() = %d.\n", this->ref_count());
00060 #endif
00061 }
00062 
00063 void CFactorGraph::register_parameters()
00064 {
00065     SG_ADD(&m_cardinalities, "cardinalities", "Cardinalities", MS_NOT_AVAILABLE);
00066     SG_ADD((CSGObject**)&m_factors, "factors", "Factors", MS_NOT_AVAILABLE);
00067     SG_ADD((CSGObject**)&m_datasources, "datasources", "Factor data sources", MS_NOT_AVAILABLE);
00068     SG_ADD((CSGObject**)&m_dset, "dset", "Disjoint set", MS_NOT_AVAILABLE);
00069     SG_ADD(&m_has_cycle, "has_cycle", "Whether has circle in graph", MS_NOT_AVAILABLE);
00070     SG_ADD(&m_num_edges, "num_edges", "Number of edges", MS_NOT_AVAILABLE);
00071 }
00072 
00073 void CFactorGraph::init()
00074 {
00075     m_has_cycle = false;
00076     m_num_edges = 0;
00077     m_factors = NULL;
00078     m_datasources = NULL;
00079     m_factors = new CDynamicObjectArray();
00080     m_datasources = new CDynamicObjectArray();
00081 
00082 #ifdef USE_REFERENCE_COUNTING
00083     if (m_factors != NULL)
00084         SG_DEBUG("CFactorGraph::init(): m_factors->ref_count() = %d.\n", m_factors->ref_count());
00085 #endif
00086 
00087     // NOTE m_cards cannot be empty
00088     m_dset = new CDisjointSet(m_cardinalities.size());
00089 
00090     SG_REF(m_factors);
00091     SG_REF(m_datasources);
00092     SG_REF(m_dset);
00093 }
00094 
00095 void CFactorGraph::add_factor(CFactor* factor)
00096 {
00097     m_factors->push_back(factor);
00098     m_num_edges += factor->get_variables().size();
00099 
00100     // graph structure changed after adding factors
00101     if (m_dset->get_connected())
00102         m_dset->set_connected(false);
00103 }
00104 
00105 void CFactorGraph::add_data_source(CFactorDataSource* datasource)
00106 {
00107     m_datasources->push_back(datasource);
00108 }
00109 
00110 CDynamicObjectArray* CFactorGraph::get_factors() const
00111 {
00112     SG_REF(m_factors);
00113     return m_factors;
00114 }
00115 
00116 CDynamicObjectArray* CFactorGraph::get_factor_data_sources() const
00117 {
00118     SG_REF(m_datasources);
00119     return m_datasources;
00120 }
00121 
00122 int32_t CFactorGraph::get_num_factors() const
00123 {
00124     return m_factors->get_num_elements();
00125 }
00126 
00127 SGVector<int32_t> CFactorGraph::get_cardinalities() const
00128 {
00129     return m_cardinalities;
00130 }
00131 
00132 void CFactorGraph::set_cardinalities(SGVector<int32_t> cards)
00133 {
00134     m_cardinalities = cards.clone();
00135 }
00136 
00137 CDisjointSet* CFactorGraph::get_disjoint_set() const
00138 {
00139     SG_REF(m_dset);
00140     return m_dset;
00141 }
00142 
00143 int32_t CFactorGraph::get_num_edges() const
00144 {
00145     return m_num_edges;
00146 }
00147 
00148 int32_t CFactorGraph::get_num_vars() const
00149 {
00150     return m_cardinalities.size();
00151 }
00152 
00153 void CFactorGraph::compute_energies()
00154 {
00155     for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi)
00156     {
00157         CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi));
00158         fac->compute_energies();
00159         SG_UNREF(fac);
00160     }
00161 }
00162 
00163 float64_t CFactorGraph::evaluate_energy(const SGVector<int32_t> state) const
00164 {
00165     ASSERT(state.size() == m_cardinalities.size());
00166 
00167     float64_t energy = 0.0;
00168     for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi)
00169     {
00170         CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi));
00171         energy += fac->evaluate_energy(state);
00172         SG_UNREF(fac);
00173     }
00174     return energy;
00175 }
00176 
00177 float64_t CFactorGraph::evaluate_energy(const CFactorGraphObservation* obs) const
00178 {
00179     return evaluate_energy(obs->get_data());
00180 }
00181 
00182 SGVector<float64_t> CFactorGraph::evaluate_energies() const
00183 {
00184     int num_assig = 1;
00185     SGVector<int32_t> cumprod_cards(m_cardinalities.size());
00186     for (int32_t n = 0; n < m_cardinalities.size(); ++n)
00187     {
00188         cumprod_cards[n] = num_assig;
00189         num_assig *= m_cardinalities[n];
00190     }
00191 
00192     SGVector<float64_t> etable(num_assig);
00193     for (int32_t ei = 0; ei < num_assig; ++ei)
00194     {
00195         SGVector<int32_t> assig(m_cardinalities.size());
00196         for (int32_t vi = 0; vi < m_cardinalities.size(); ++vi)
00197             assig[vi] = (ei / cumprod_cards[vi]) % m_cardinalities[vi];
00198 
00199         etable[ei] = evaluate_energy(assig);
00200 
00201         for (int32_t vi = 0; vi < m_cardinalities.size(); ++vi)
00202             SG_SPRINT("%d ", assig[vi]);
00203 
00204         SG_SPRINT("| %f\n", etable[ei]);
00205     }
00206 
00207     return etable;
00208 }
00209 
00210 void CFactorGraph::connect_components()
00211 {
00212     if (m_dset->get_connected())
00213         return;
00214 
00215     // need to be reset once factor graph is updated
00216     m_dset->make_sets();
00217     bool flag = false;
00218 
00219     for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi)
00220     {
00221         CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi));
00222         SGVector<int32_t> vars = fac->get_variables();
00223 
00224         int32_t r0 = m_dset->find_set(vars[0]);
00225         for (int32_t vi = 1; vi < vars.size(); vi++)
00226         {
00227             // for two nodes in a factor, should be an edge between them
00228             // but this time link() isn't performed, if they are linked already
00229             // means there is another path connected them, so cycle detected
00230             int32_t ri = m_dset->find_set(vars[vi]);
00231 
00232             if (r0 == ri)
00233             {
00234                 flag = true;
00235                 continue;
00236             }
00237 
00238             r0 = m_dset->link_set(r0, ri);
00239         }
00240 
00241         SG_UNREF(fac);
00242     }
00243     m_has_cycle = flag;
00244     m_dset->set_connected(true);
00245 }
00246 
00247 bool CFactorGraph::is_acyclic_graph() const
00248 {
00249     return !m_has_cycle;
00250 }
00251 
00252 bool CFactorGraph::is_connected_graph() const
00253 {
00254     return (m_dset->get_num_sets() == 1);
00255 }
00256 
00257 bool CFactorGraph::is_tree_graph() const
00258 {
00259     return (m_has_cycle == false && m_dset->get_num_sets() == 1);
00260 }
00261 
00262 void CFactorGraph::loss_augmentation(CFactorGraphObservation* gt)
00263 {
00264     loss_augmentation(gt->get_data(), gt->get_loss_weights());
00265 }
00266 
00267 void CFactorGraph::loss_augmentation(SGVector<int32_t> states_gt, SGVector<float64_t> loss)
00268 {
00269     if (loss.size() == 0)
00270     {
00271         loss.resize_vector(states_gt.size());
00272         SGVector<float64_t>::fill_vector(loss.vector, loss.vlen, 1.0 / states_gt.size());
00273     }
00274 
00275     int32_t num_vars = states_gt.size();
00276     ASSERT(num_vars == loss.size());
00277 
00278     SGVector<int32_t> var_flags(num_vars);
00279     var_flags.zero();
00280 
00281     // augment loss to incorrect states in the first factor containing the variable
00282     // since += L_i for each variable if it takes wrong state ever
00283     // TODO: augment unary factors
00284     for (int32_t fi = 0; fi < m_factors->get_num_elements(); ++fi)
00285     {
00286         CFactor* fac = dynamic_cast<CFactor*>(m_factors->get_element(fi));
00287         SGVector<int32_t> vars = fac->get_variables();
00288         for (int32_t vi = 0; vi < vars.size(); vi++)
00289         {
00290             int32_t vv = vars[vi];
00291             ASSERT(vv < num_vars);
00292             if (var_flags[vv])
00293                 continue;
00294 
00295             SGVector<float64_t> energies = fac->get_energies();
00296             for (int32_t ei = 0; ei < energies.size(); ei++)
00297             {
00298                 CTableFactorType* ftype = fac->get_factor_type();
00299                 int32_t vstate = ftype->state_from_index(ei, vi);
00300                 SG_UNREF(ftype);
00301 
00302                 if (states_gt[vv] == vstate)
00303                     continue;
00304 
00305                 // -delta(y_n, y_i_n)
00306                 fac->set_energy(ei, energies[ei] - loss[vv]);
00307             }
00308 
00309             var_flags[vv] = 1;
00310         }
00311 
00312         SG_UNREF(fac);
00313     }
00314 
00315     // make sure all variables have been checked
00316     int32_t min_var = SGVector<int32_t>::min(var_flags.vector, var_flags.vlen);
00317     ASSERT(min_var == 1);
00318 }
00319 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation