// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Roman Votyakov 00008 * Copyright (C) 2012 Jacob Walker 00009 * Copyright (C) 2013 Roman Votyakov 00010 * 00011 * Code adapted from Gaussian Process Machine Learning Toolbox 00012 * http://www.gaussianprocess.org/gpml/code/matlab/doc/ 00013 */ 00014 00015 #ifndef CLAPLACIANINFERENCEMETHOD_H_ 00016 #define CLAPLACIANINFERENCEMETHOD_H_ 00017 00018 #include <shogun/lib/config.h> 00019 00020 #ifdef HAVE_EIGEN3 00021 00022 #include <shogun/machine/gp/InferenceMethod.h> 00023 00024 namespace shogun 00025 { 00026 00042 class CLaplacianInferenceMethod: public CInferenceMethod 00043 { 00044 public: 00046 CLaplacianInferenceMethod(); 00047 00056 CLaplacianInferenceMethod(CKernel* kernel, CFeatures* features, 00057 CMeanFunction* mean, CLabels* labels, CLikelihoodModel* model); 00058 00059 virtual ~CLaplacianInferenceMethod(); 00060 00065 virtual EInferenceType get_inference_type() const { return INF_LAPLACIAN; } 00066 00071 virtual const char* get_name() const { return "LaplacianInferenceMethod"; } 00072 00084 virtual float64_t get_negative_log_marginal_likelihood(); 00085 00096 virtual SGVector<float64_t> get_alpha(); 00097 00109 virtual SGMatrix<float64_t> get_cholesky(); 00110 00122 virtual SGVector<float64_t> get_diagonal_vector(); 00123 00136 virtual SGVector<float64_t> get_posterior_mean(); 00137 00156 virtual SGMatrix<float64_t> get_posterior_covariance(); 00157 00162 virtual float64_t get_newton_tolerance() { return m_tolerance; } 00163 00168 virtual void set_newton_tolerance(float64_t tol) { m_tolerance=tol; } 00169 00174 virtual int32_t get_newton_iterations() { return m_iter; } 00175 00180 virtual void set_newton_iterations(int32_t iter) { m_iter=iter; } 
00181 00186 virtual float64_t get_minimization_tolerance() { return m_opt_tolerance; } 00187 00192 virtual void set_minimization_tolerance(float64_t tol) { m_opt_tolerance=tol; } 00193 00198 virtual float64_t get_minimization_max() { return m_opt_max; } 00199 00204 virtual void set_minimization_max(float64_t max) { m_opt_max=max; } 00205 00210 virtual bool supports_regression() const 00211 { 00212 check_members(); 00213 return m_model->supports_regression(); 00214 } 00215 00220 virtual bool supports_binary() const 00221 { 00222 check_members(); 00223 return m_model->supports_binary(); 00224 } 00225 00227 virtual void update(); 00228 00229 protected: 00231 virtual void update_alpha(); 00232 00234 virtual void update_chol(); 00235 00237 virtual void update_approx_cov(); 00238 00242 virtual void update_deriv(); 00243 00251 virtual SGVector<float64_t> get_derivative_wrt_inference_method( 00252 const TParameter* param); 00253 00261 virtual SGVector<float64_t> get_derivative_wrt_likelihood_model( 00262 const TParameter* param); 00263 00271 virtual SGVector<float64_t> get_derivative_wrt_kernel( 00272 const TParameter* param); 00273 00281 virtual SGVector<float64_t> get_derivative_wrt_mean( 00282 const TParameter* param); 00283 00284 private: 00285 void init(); 00286 00287 private: 00289 float64_t m_tolerance; 00290 00292 index_t m_iter; 00293 00295 float64_t m_opt_tolerance; 00296 00298 float64_t m_opt_max; 00299 00301 SGVector<float64_t> m_mu; 00302 00304 SGMatrix<float64_t> m_Sigma; 00305 00307 SGVector<float64_t> W; 00308 00310 SGVector<float64_t> sW; 00311 00313 SGVector<float64_t> dlp; 00314 00316 SGVector<float64_t> d2lp; 00317 00319 SGVector<float64_t> d3lp; 00320 00321 SGVector<float64_t> m_dfhat; 00322 00323 SGMatrix<float64_t> m_Z; 00324 00325 SGVector<float64_t> m_g; 00326 }; 00327 } 00328 #endif /* HAVE_EIGEN3 */ 00329 #endif /* CLAPLACIANINFERENCEMETHOD_H_ */