SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
EPInferenceMethod.h
Go to the documentation of this file.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Roman Votyakov
 *
 * Based on ideas from GAUSSIAN PROCESS REGRESSION AND CLASSIFICATION Toolbox
 * Copyright (C) 2005-2013 by Carl Edward Rasmussen & Hannes Nickisch under the
 * FreeBSD License
 * http://www.gaussianprocess.org/gpml/code/matlab/doc/
 */
00014 
00015 #ifndef _EPINFERENCEMETHOD_H_
00016 #define _EPINFERENCEMETHOD_H_
00017 
00018 #include <shogun/lib/config.h>
00019 
00020 #ifdef HAVE_EIGEN3
00021 
00022 #include <shogun/machine/gp/InferenceMethod.h>
00023 
00024 namespace shogun
00025 {
00026 
00034 class CEPInferenceMethod : public CInferenceMethod
00035 {
00036 public:
00038     CEPInferenceMethod();
00039 
00048     CEPInferenceMethod(CKernel* kernel, CFeatures* features, CMeanFunction* mean,
00049             CLabels* labels, CLikelihoodModel* model);
00050 
00051     virtual ~CEPInferenceMethod();
00052 
00057     virtual EInferenceType get_inference_type() const { return INF_EP; }
00058 
00063     virtual const char* get_name() const { return "EPInferenceMethod"; }
00064 
00076     virtual float64_t get_negative_log_marginal_likelihood();
00077 
00100     virtual SGVector<float64_t> get_alpha();
00101 
00116     virtual SGMatrix<float64_t> get_cholesky();
00117 
00129     virtual SGVector<float64_t> get_diagonal_vector();
00130 
00151     virtual SGVector<float64_t> get_posterior_mean();
00152 
00172     virtual SGMatrix<float64_t> get_posterior_covariance();
00173 
00178     virtual float64_t get_tolerance() const { return m_tol; }
00179 
00184     virtual void set_tolerance(const float64_t tol) { m_tol=tol; }
00185 
00190     virtual uint32_t get_min_sweep() const { return m_min_sweep; }
00191 
00196     virtual void set_min_sweep(const uint32_t min_sweep) { m_min_sweep=min_sweep; }
00197 
00202     virtual uint32_t get_max_sweep() const { return m_max_sweep; }
00203 
00208     virtual void set_max_sweep(const uint32_t max_sweep) { m_max_sweep=max_sweep; }
00209 
00214     virtual bool supports_binary() const
00215     {
00216         check_members();
00217         return m_model->supports_binary();
00218     }
00219 
00221     virtual void update();
00222 
00223 protected:
00225     virtual void update_alpha();
00226 
00228     virtual void update_chol();
00229 
00231     virtual void update_approx_cov();
00232 
00234     virtual void update_approx_mean();
00235 
00237     virtual void update_negative_ml();
00238 
00242     virtual void update_deriv();
00243 
00251     virtual SGVector<float64_t> get_derivative_wrt_inference_method(
00252             const TParameter* param);
00253 
00261     virtual SGVector<float64_t> get_derivative_wrt_likelihood_model(
00262             const TParameter* param);
00263 
00271     virtual SGVector<float64_t> get_derivative_wrt_kernel(
00272             const TParameter* param);
00273 
00281     virtual SGVector<float64_t> get_derivative_wrt_mean(
00282             const TParameter* param);
00283 
00284 private:
00285     void init();
00286 
00287 private:
00289     SGVector<float64_t> m_mu;
00290 
00292     SGMatrix<float64_t> m_Sigma;
00293 
00295     float64_t m_nlZ;
00296 
00300     SGVector<float64_t> m_tnu;
00301 
00305     SGVector<float64_t> m_ttau;
00306 
00308     SGVector<float64_t> m_sttau;
00309 
00311     float64_t m_tol;
00312 
00314     uint32_t m_min_sweep;
00315 
00317     uint32_t m_max_sweep;
00318 
00319     SGMatrix<float64_t> m_F;
00320 };
00321 }
00322 #endif /* HAVE_EIGEN3 */
00323 #endif /* _EPINFERENCEMETHOD_H_ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation