/* SHOGUN v3.2.0 */
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Shell Hu 00008 * Copyright (C) 2013 Shell Hu 00009 */ 00010 00011 #ifndef __FACTOR_GRAPH_MODEL_H__ 00012 #define __FACTOR_GRAPH_MODEL_H__ 00013 00014 #include <shogun/lib/SGString.h> 00015 #include <shogun/lib/DynamicObjectArray.h> 00016 #include <shogun/structure/StructuredModel.h> 00017 #include <shogun/structure/FactorType.h> 00018 #include <shogun/structure/MAPInference.h> 00019 00020 namespace shogun 00021 { 00022 00030 class CFactorGraphModel : public CStructuredModel 00031 { 00032 public: 00034 CFactorGraphModel(); 00035 00044 CFactorGraphModel(CFeatures* features, CStructuredLabels* labels, 00045 EMAPInferType inf_type = TREE_MAX_PROD, bool verbose = false); 00046 00048 ~CFactorGraphModel(); 00049 00051 virtual const char* get_name() const { return "FactorGraphModel"; } 00052 00059 void add_factor_type(CFactorType* ftype); 00060 00065 void del_factor_type(const int32_t ftype_id); 00066 00068 CDynamicObjectArray* get_factor_types() const; 00069 00074 CFactorType* get_factor_type(const int32_t ftype_id) const; 00075 00077 SGVector<int32_t> get_global_params_mapping() const; 00078 00083 SGVector<int32_t> get_params_mapping(const int32_t ftype_id); 00084 00086 SGVector<float64_t> fparams_to_w(); 00087 00092 void w_to_fparams(SGVector<float64_t> w); 00093 00106 virtual SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, CStructuredData* y); 00107 00121 virtual CResultSet* argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training = true); 00122 00130 virtual float64_t delta_loss(CStructuredData* y1, CStructuredData* y2); 00131 00136 virtual void init_training(); 00137 00149 virtual void init_primal_opt( 00150 float64_t 
regularization, 00151 SGMatrix< float64_t > & A, SGVector< float64_t > a, 00152 SGMatrix< float64_t > B, SGVector< float64_t > & b, 00153 SGVector< float64_t > lb, SGVector< float64_t > ub, 00154 SGMatrix < float64_t > & C); 00155 00160 virtual int32_t get_dim() const; 00161 00162 private: 00164 void init(); 00165 00166 protected: 00168 CDynamicObjectArray* m_factor_types; 00169 00171 SGVector<int32_t> m_w_map; 00172 00174 SGVector<float64_t> m_w_cache; 00175 00177 EMAPInferType m_inf_type; 00178 00180 bool m_verbose; 00181 }; 00182 00183 } 00184 00185 #endif 00186