SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #ifndef __FACTOR_TYPE_H__ 00012 #define __FACTOR_TYPE_H__ 00013 00014 #include <shogun/base/SGObject.h> 00015 #include <shogun/lib/SGVector.h> 00016 00017 namespace shogun 00018 { 00019 00022 class CFactorType : public CSGObject 00023 { 00024 public: 00026 CFactorType(); 00027 00039 CFactorType(int32_t id, SGVector<int32_t> card, SGVector<float64_t> w); 00040 00042 virtual ~CFactorType(); 00043 00045 virtual const char* get_name() const { return "FactorType"; } 00046 00048 virtual int32_t get_type_id() const; 00049 00054 virtual void set_type_id(int32_t id); 00055 00057 virtual SGVector<float64_t> get_w(); 00058 00060 virtual const SGVector<float64_t> get_w() const; 00061 00066 void set_w(SGVector<float64_t> w); 00067 00069 virtual int32_t get_w_dim() const; 00070 00072 virtual const SGVector<int32_t> get_cardinalities() const; 00073 00078 virtual void set_cardinalities(SGVector<int32_t> cards); 00079 00081 virtual int32_t get_num_vars(); 00082 00084 virtual int32_t get_num_assignments() const; 00085 00087 virtual bool is_table() const { return false; } 00088 00089 protected: 00091 void init_card(); 00092 00093 private: 00095 void init(); 00096 00097 protected: 00099 int32_t m_type_id; 00100 00102 SGVector<int32_t> m_cards; 00103 00105 SGVector<int32_t> m_cumprod_cards; 00106 00108 int32_t m_num_assignments; 00109 00111 int32_t m_data_size; 00112 00114 SGVector<float64_t> m_w; 00115 }; 00116 00120 class CTableFactorType : public CFactorType 00121 { 00122 public: 00124 CTableFactorType(); 00125 00132 CTableFactorType(int32_t id, SGVector<int32_t> card, SGVector<float64_t> w); 00133 00135 virtual ~CTableFactorType(); 00136 00138 virtual const char* get_name() const { return "TableFactorType"; } 00139 00141 virtual bool is_table() const { return true; } 00142 00149 int32_t state_from_index(int32_t ei, int32_t var_index) const; 00150 00156 SGVector<int32_t> assignment_from_index(int32_t ei) const; 00157 00166 int32_t index_from_new_state(int32_t old_ei, int32_t var_index, int32_t var_state) const; 00167 00173 int32_t index_from_assignment(const SGVector<int32_t> assig) const; 00174 00181 int32_t index_from_universe_assignment(const SGVector<int32_t> assig, 00182 const SGVector<int32_t> var_index) const; 00183 00189 virtual void compute_energies(const SGVector<float64_t> factor_data, 00190 SGVector<float64_t>& energies) const; 00191 00197 virtual void compute_energies(const SGSparseVector<float64_t> factor_data_sparse, 00198 SGVector<float64_t>& energies) const; 00199 00207 virtual void compute_gradients(const SGVector<float64_t> factor_data, 00208 const SGVector<float64_t> marginals, 00209 SGVector<float64_t>& parameter_gradient, double mult) const; 00210 00218 virtual void compute_gradients(const SGSparseVector<float64_t> factor_data_sparse, 00219 const SGVector<float64_t> marginals, 00220 SGVector<float64_t>& parameter_gradient, double mult) const; 00221 00222 }; 00223 00224 } 00225 00226 #endif 00227