SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Roman Votyakov 00008 * 00009 * Based on ideas from GAUSSIAN PROCESS REGRESSION AND CLASSIFICATION Toolbox 00010 * Copyright (C) 2005-2013 by Carl Edward Rasmussen & Hannes Nickisch under the 00011 * FreeBSD License 00012 * http://www.gaussianprocess.org/gpml/code/matlab/doc/ 00013 */ 00014 00015 #ifndef _EPINFERENCEMETHOD_H_ 00016 #define _EPINFERENCEMETHOD_H_ 00017 00018 #include <shogun/lib/config.h> 00019 00020 #ifdef HAVE_EIGEN3 00021 00022 #include <shogun/machine/gp/InferenceMethod.h> 00023 00024 namespace shogun 00025 { 00026 00034 class CEPInferenceMethod : public CInferenceMethod 00035 { 00036 public: 00038 CEPInferenceMethod(); 00039 00048 CEPInferenceMethod(CKernel* kernel, CFeatures* features, CMeanFunction* mean, 00049 CLabels* labels, CLikelihoodModel* model); 00050 00051 virtual ~CEPInferenceMethod(); 00052 00057 virtual EInferenceType get_inference_type() const { return INF_EP; } 00058 00063 virtual const char* get_name() const { return "EPInferenceMethod"; } 00064 00076 virtual float64_t get_negative_log_marginal_likelihood(); 00077 00100 virtual SGVector<float64_t> get_alpha(); 00101 00116 virtual SGMatrix<float64_t> get_cholesky(); 00117 00129 virtual SGVector<float64_t> get_diagonal_vector(); 00130 00151 virtual SGVector<float64_t> get_posterior_mean(); 00152 00172 virtual SGMatrix<float64_t> get_posterior_covariance(); 00173 00178 virtual float64_t get_tolerance() const { return m_tol; } 00179 00184 virtual void set_tolerance(const float64_t tol) { m_tol=tol; } 00185 00190 virtual uint32_t get_min_sweep() const { return m_min_sweep; } 00191 00196 virtual void set_min_sweep(const uint32_t min_sweep) { m_min_sweep=min_sweep; } 00197 00202 virtual uint32_t get_max_sweep() const { return m_max_sweep; } 00203 00208 virtual void set_max_sweep(const uint32_t max_sweep) { m_max_sweep=max_sweep; } 00209 00214 virtual bool supports_binary() const 00215 { 00216 check_members(); 00217 return m_model->supports_binary(); 00218 } 00219 00221 virtual void update(); 00222 00223 protected: 00225 virtual void update_alpha(); 00226 00228 virtual void update_chol(); 00229 00231 virtual void update_approx_cov(); 00232 00234 virtual void update_approx_mean(); 00235 00237 virtual void update_negative_ml(); 00238 00242 virtual void update_deriv(); 00243 00251 virtual SGVector<float64_t> get_derivative_wrt_inference_method( 00252 const TParameter* param); 00253 00261 virtual SGVector<float64_t> get_derivative_wrt_likelihood_model( 00262 const TParameter* param); 00263 00271 virtual SGVector<float64_t> get_derivative_wrt_kernel( 00272 const TParameter* param); 00273 00281 virtual SGVector<float64_t> get_derivative_wrt_mean( 00282 const TParameter* param); 00283 00284 private: 00285 void init(); 00286 00287 private: 00289 SGVector<float64_t> m_mu; 00290 00292 SGMatrix<float64_t> m_Sigma; 00293 00295 float64_t m_nlZ; 00296 00300 SGVector<float64_t> m_tnu; 00301 00305 SGVector<float64_t> m_ttau; 00306 00308 SGVector<float64_t> m_sttau; 00309 00311 float64_t m_tol; 00312 00314 uint32_t m_min_sweep; 00315 00317 uint32_t m_max_sweep; 00318 00319 SGMatrix<float64_t> m_F; 00320 }; 00321 } 00322 #endif /* HAVE_EIGEN3 */ 00323 #endif /* _EPINFERENCEMETHOD_H_ */