SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _LINEARHMM_H__ 00013 #define _LINEARHMM_H__ 00014 00015 #include <shogun/features/StringFeatures.h> 00016 #include <shogun/labels/Labels.h> 00017 #include <shogun/distributions/Distribution.h> 00018 00019 namespace shogun 00020 { 00039 class CLinearHMM : public CDistribution 00040 { 00041 public: 00043 CLinearHMM(); 00044 00049 CLinearHMM(CStringFeatures<uint16_t>* f); 00050 00056 CLinearHMM(int32_t p_num_features, int32_t p_num_symbols); 00057 00058 virtual ~CLinearHMM(); 00059 00068 virtual bool train(CFeatures* data=NULL); 00069 00077 bool train( 00078 const int32_t* indizes, int32_t num_indizes, 00079 float64_t pseudo_count); 00080 00087 float64_t get_log_likelihood_example(uint16_t* vector, int32_t len); 00088 00095 float64_t get_likelihood_example(uint16_t* vector, int32_t len); 00096 00102 float64_t get_likelihood_example(int32_t num_example); 00103 00109 virtual float64_t get_log_likelihood_example(int32_t num_example); 00110 00117 virtual float64_t get_log_derivative( 00118 int32_t num_param, int32_t num_example); 00119 00126 virtual float64_t get_log_derivative_obsolete( 00127 uint16_t obs, int32_t pos) 00128 { 00129 return 1.0/transition_probs[pos*num_symbols+obs]; 00130 } 00131 00138 virtual float64_t get_derivative_obsolete( 00139 uint16_t* vector, int32_t len, int32_t pos) 00140 { 00141 ASSERT(pos<len) 00142 return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]]; 00143 } 00144 00149 virtual int32_t get_sequence_length() { return sequence_length; } 00150 00155 virtual int32_t get_num_symbols() { return num_symbols; } 00156 00161 virtual int32_t get_num_model_parameters() { return num_params; } 00162 00169 virtual float64_t get_positional_log_parameter( 00170 uint16_t obs, int32_t position) 00171 { 00172 return log_transition_probs[position*num_symbols+obs]; 00173 } 00174 00180 virtual float64_t get_log_model_parameter(int32_t num_param) 00181 { 00182 ASSERT(log_transition_probs) 00183 ASSERT(num_param<num_params) 00184 00185 return log_transition_probs[num_param]; 00186 } 00187 00192 virtual SGVector<float64_t> get_log_transition_probs(); 00193 00199 virtual bool set_log_transition_probs(const SGVector<float64_t> probs); 00200 00205 virtual SGVector<float64_t> get_transition_probs(); 00206 00212 virtual bool set_transition_probs(const SGVector<float64_t> probs); 00213 00215 virtual const char* get_name() const { return "LinearHMM"; } 00216 00217 protected: 00218 virtual void load_serializable_post() throw (ShogunException); 00219 00220 private: 00221 void init(); 00222 00223 protected: 00225 int32_t sequence_length; 00227 int32_t num_symbols; 00229 int32_t num_params; 00231 float64_t* transition_probs; 00233 float64_t* log_transition_probs; 00234 }; 00235 } 00236 #endif