SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Roman Votyakov 00008 * Copyright (C) 2012 Jacob Walker 00009 * Copyright (C) 2013 Roman Votyakov 00010 */ 00011 00012 #ifndef CEXACTINFERENCEMETHOD_H_ 00013 #define CEXACTINFERENCEMETHOD_H_ 00014 00015 #include <shogun/lib/config.h> 00016 00017 #ifdef HAVE_EIGEN3 00018 00019 #include <shogun/machine/gp/InferenceMethod.h> 00020 00021 namespace shogun 00022 { 00023 00047 class CExactInferenceMethod: public CInferenceMethod 00048 { 00049 public: 00051 CExactInferenceMethod(); 00052 00061 CExactInferenceMethod(CKernel* kernel, CFeatures* features, 00062 CMeanFunction* mean, CLabels* labels, CLikelihoodModel* model); 00063 00064 virtual ~CExactInferenceMethod(); 00065 00070 virtual EInferenceType get_inference_type() const { return INF_EXACT; } 00071 00076 virtual const char* get_name() const { return "ExactInferenceMethod"; } 00077 00089 virtual float64_t get_negative_log_marginal_likelihood(); 00090 00101 virtual SGVector<float64_t> get_alpha(); 00102 00114 virtual SGMatrix<float64_t> get_cholesky(); 00115 00127 virtual SGVector<float64_t> get_diagonal_vector(); 00128 00138 virtual SGVector<float64_t> get_posterior_mean(); 00139 00149 virtual SGMatrix<float64_t> get_posterior_covariance(); 00150 00155 virtual bool supports_regression() const 00156 { 00157 check_members(); 00158 return m_model->supports_regression(); 00159 } 00160 00162 virtual void update(); 00163 00164 protected: 00166 virtual void check_members() const; 00167 00169 virtual void update_alpha(); 00170 00172 virtual void update_chol(); 00173 00175 virtual void update_mean(); 00176 00178 virtual void update_cov(); 00179 00183 virtual void update_deriv(); 00184 00192 virtual SGVector<float64_t> get_derivative_wrt_inference_method( 00193 const TParameter* param); 00194 00202 virtual SGVector<float64_t> get_derivative_wrt_likelihood_model( 00203 const TParameter* param); 00204 00212 virtual SGVector<float64_t> get_derivative_wrt_kernel( 00213 const TParameter* param); 00214 00222 virtual SGVector<float64_t> get_derivative_wrt_mean( 00223 const TParameter* param); 00224 00225 private: 00227 SGMatrix<float64_t> m_Sigma; 00228 00230 SGVector<float64_t> m_mu; 00231 00232 SGMatrix<float64_t> m_Q; 00233 }; 00234 } 00235 #endif /* HAVE_EIGEN3 */ 00236 #endif /* CEXACTINFERENCEMETHOD_H_ */