// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #ifndef __FACTORGRAPH_H__ 00012 #define __FACTORGRAPH_H__ 00013 00014 #include <shogun/lib/DynamicObjectArray.h> 00015 #include <shogun/lib/SGVector.h> 00016 #include <shogun/structure/Factor.h> 00017 #include <shogun/labels/FactorGraphLabels.h> 00018 #include <shogun/structure/DisjointSet.h> 00019 00020 namespace shogun 00021 { 00022 00025 class CFactorGraph : public CSGObject 00026 { 00027 00028 public: 00029 CFactorGraph(); 00030 00035 CFactorGraph(const SGVector<int32_t> card); 00036 00041 CFactorGraph(const CFactorGraph &fg); 00042 00044 ~CFactorGraph(); 00045 00047 virtual const char* get_name() const { return "FactorGraph"; } 00048 00053 void add_factor(CFactor* factor); 00054 00059 void add_data_source(CFactorDataSource* datasource); 00060 00062 CDynamicObjectArray* get_factors() const; 00063 00065 CDynamicObjectArray* get_factor_data_sources() const; 00066 00068 int32_t get_num_factors() const; 00069 00071 SGVector<int32_t> get_cardinalities() const; 00072 00077 void set_cardinalities(SGVector<int32_t> cards); 00078 00080 void compute_energies(); 00081 00086 float64_t evaluate_energy(const SGVector<int32_t> state) const; 00087 00092 float64_t evaluate_energy(const CFactorGraphObservation* obs) const; 00093 00095 SGVector<float64_t> evaluate_energies() const; 00096 00098 CDisjointSet* get_disjoint_set() const; 00099 00101 int32_t get_num_edges() const; 00102 00104 int32_t get_num_vars() const; 00105 00110 void connect_components(); 00111 00113 bool is_acyclic_graph() const; 00114 00116 bool is_connected_graph() const; 00117 00119 bool is_tree_graph() const; 00120 00125 virtual void 
loss_augmentation(CFactorGraphObservation* gt); 00126 00132 virtual void loss_augmentation(SGVector<int32_t> states_gt, \ 00133 SGVector<float64_t> loss = SGVector<float64_t>()); 00134 00135 private: 00137 void register_parameters(); 00138 00140 void init(); 00141 00142 protected: 00143 // TODO: FactorNode, VariableNode, such that they have IDs 00144 00146 SGVector<int32_t> m_cardinalities; 00147 00149 CDynamicObjectArray* m_factors; 00150 00152 CDynamicObjectArray* m_datasources; 00153 00155 CDisjointSet* m_dset; 00156 00158 bool m_has_cycle; 00159 00161 int32_t m_num_edges; 00162 00163 }; 00164 00165 } 00166 00167 #endif 00168