SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #ifndef __FACTOR_RELATED_H__ 00012 #define __FACTOR_RELATED_H__ 00013 00014 #include <shogun/base/SGObject.h> 00015 #include <shogun/lib/SGVector.h> 00016 #include <shogun/lib/SGSparseVector.h> 00017 #include <shogun/structure/FactorType.h> 00018 00019 namespace shogun 00020 { 00021 00025 class CFactorDataSource : public CSGObject 00026 { 00027 public: 00029 CFactorDataSource(); 00030 00035 CFactorDataSource(SGVector<float64_t> dense); 00036 00041 CFactorDataSource(SGSparseVector<float64_t> sparse); 00042 00044 virtual ~CFactorDataSource(); 00045 00047 virtual const char* get_name() const { return "FactorDataSource"; } 00048 00050 virtual bool is_sparse() const; 00051 00053 virtual SGVector<float64_t> get_data() const; 00054 00056 virtual SGSparseVector<float64_t> get_data_sparse() const; 00057 00062 virtual void set_data(SGVector<float64_t> dense); 00063 00069 virtual void set_data_sparse(SGSparseVectorEntry<float64_t>* sparse, int32_t dlen); 00070 00071 private: 00073 void init(); 00074 00075 private: 00077 SGVector<float64_t> m_dense; 00078 00080 SGSparseVector<float64_t> m_sparse; 00081 }; 00082 00087 class CFactor : public CSGObject 00088 { 00089 public: 00091 CFactor(); 00092 00099 CFactor(CTableFactorType* ftype, SGVector<int32_t> var_index, SGVector<float64_t> data); 00100 00107 CFactor(CTableFactorType* ftype, SGVector<int32_t> var_index, 00108 SGSparseVector<float64_t> data_sparse); 00109 00116 CFactor(CTableFactorType* ftype, SGVector<int32_t> var_index, 00117 CFactorDataSource* data_source); 00118 00120 virtual ~CFactor(); 00121 00123 virtual const char* get_name() const { return "Factor"; } 00124 00126 CTableFactorType* get_factor_type() const; 00127 00132 void set_factor_type(CTableFactorType* ftype); 00133 00135 const SGVector<int32_t> get_variables() const; 00136 00141 void set_variables(SGVector<int32_t> vars); 00142 00144 const SGVector<int32_t> get_cardinalities() const; 00145 00147 SGVector<float64_t> get_data() const; 00148 00150 SGSparseVector<float64_t> get_data_sparse() const; 00151 00156 void set_data(SGVector<float64_t> data_dense); 00157 00163 void set_data_sparse(SGSparseVectorEntry<float64_t>* data_sparse, int32_t dlen); 00164 00166 bool is_data_dependent() const; 00167 00169 bool is_data_sparse() const; 00170 00174 SGVector<float64_t> get_energies() const; 00175 00180 float64_t get_energy(int32_t index) const; 00181 00185 void set_energies(SGVector<float64_t> ft_energies); 00186 00191 void set_energy(int32_t ei, float64_t value); 00192 00197 float64_t evaluate_energy(const SGVector<int32_t> state) const; 00198 00200 void compute_energies(); 00201 00208 void compute_gradients(const SGVector<float64_t> marginals, 00209 SGVector<float64_t>& parameter_gradient, double mult = 1.0) const; 00210 00211 protected: 00213 CTableFactorType* m_factor_type; 00214 00216 SGVector<int32_t> m_var_index; 00217 00219 SGVector<float64_t> m_energies; 00220 00222 CFactorDataSource* m_data_source; 00223 00225 SGVector<float64_t> m_data; 00226 00228 SGSparseVector<float64_t> m_data_sparse; 00229 00231 bool m_is_data_dep; 00232 00233 private: 00235 void init(); 00236 }; 00237 00238 } 00239 00240 #endif 00241