SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Roman Votyakov 00008 * Copyright (C) 2012 Jacob Walker 00009 * Copyright (C) 2013 Roman Votyakov 00010 * 00011 * Code adapted from Gaussian Process Machine Learning Toolbox 00012 * http://www.gaussianprocess.org/gpml/code/matlab/doc/ 00013 */ 00014 00015 #ifndef CFITCINFERENCEMETHOD_H_ 00016 #define CFITCINFERENCEMETHOD_H_ 00017 00018 #include <shogun/lib/config.h> 00019 00020 #ifdef HAVE_EIGEN3 00021 00022 #include <shogun/machine/gp/InferenceMethod.h> 00023 00024 namespace shogun 00025 { 00026 00040 class CFITCInferenceMethod: public CInferenceMethod 00041 { 00042 public: 00044 CFITCInferenceMethod(); 00045 00055 CFITCInferenceMethod(CKernel* kernel, CFeatures* features, 00056 CMeanFunction* mean, CLabels* labels, CLikelihoodModel* model, 00057 CFeatures* latent_features); 00058 00059 virtual ~CFITCInferenceMethod(); 00060 00065 virtual EInferenceType get_inference_type() const { return INF_FITC; } 00066 00071 virtual const char* get_name() const { return "FITCInferenceMethod"; } 00072 00078 static CFITCInferenceMethod* obtain_from_generic(CInferenceMethod* inference); 00079 00084 virtual void set_latent_features(CFeatures* feat) 00085 { 00086 SG_REF(feat); 00087 SG_UNREF(m_latent_features); 00088 m_latent_features=feat; 00089 } 00090 00095 virtual CFeatures* get_latent_features() 00096 { 00097 SG_REF(m_latent_features); 00098 return m_latent_features; 00099 } 00100 00112 virtual float64_t get_negative_log_marginal_likelihood(); 00113 00124 virtual SGVector<float64_t> get_alpha(); 00125 00137 virtual SGMatrix<float64_t> get_cholesky(); 00138 00150 virtual SGVector<float64_t> get_diagonal_vector(); 00151 00167 virtual SGVector<float64_t> get_posterior_mean(); 00168 00184 virtual SGMatrix<float64_t> get_posterior_covariance(); 00185 00190 virtual bool supports_regression() const 00191 { 00192 check_members(); 00193 return m_model->supports_regression(); 00194 } 00195 00197 virtual void update(); 00198 00199 protected: 00201 virtual void check_members() const; 00202 00204 virtual void update_alpha(); 00205 00207 virtual void update_chol(); 00208 00210 virtual void update_train_kernel(); 00211 00215 virtual void update_deriv(); 00216 00224 virtual SGVector<float64_t> get_derivative_wrt_inference_method( 00225 const TParameter* param); 00226 00234 virtual SGVector<float64_t> get_derivative_wrt_likelihood_model( 00235 const TParameter* param); 00236 00244 virtual SGVector<float64_t> get_derivative_wrt_kernel( 00245 const TParameter* param); 00246 00254 virtual SGVector<float64_t> get_derivative_wrt_mean( 00255 const TParameter* param); 00256 00257 private: 00258 void init(); 00259 00260 private: 00262 CFeatures* m_latent_features; 00263 00265 float64_t m_ind_noise; 00266 00268 SGMatrix<float64_t> m_chol_uu; 00269 00271 SGMatrix<float64_t> m_chol_utr; 00272 00274 SGMatrix<float64_t> m_kuu; 00275 00277 SGMatrix<float64_t> m_ktru; 00278 00282 SGVector<float64_t> m_dg; 00283 00285 SGVector<float64_t> m_r; 00286 00288 SGVector<float64_t> m_be; 00289 00290 SGVector<float64_t> m_al; 00291 00292 SGMatrix<float64_t> m_B; 00293 00294 SGVector<float64_t> m_w; 00295 00296 SGMatrix<float64_t> m_W; 00297 }; 00298 } 00299 #endif /* HAVE_EIGEN3 */ 00300 #endif /* CFITCINFERENCEMETHOD_H_ */