// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Roman Votyakov 00008 * Written (W) 2013 Heiko Strathmann 00009 * Copyright (C) 2012 Jacob Walker 00010 * Copyright (C) 2013 Roman Votyakov 00011 */ 00012 00013 #ifndef CINFERENCEMETHOD_H_ 00014 #define CINFERENCEMETHOD_H_ 00015 00016 #include <shogun/lib/config.h> 00017 00018 #ifdef HAVE_EIGEN3 00019 00020 #include <shogun/base/SGObject.h> 00021 #include <shogun/kernel/Kernel.h> 00022 #include <shogun/features/Features.h> 00023 #include <shogun/labels/Labels.h> 00024 #include <shogun/machine/gp/LikelihoodModel.h> 00025 #include <shogun/machine/gp/MeanFunction.h> 00026 #include <shogun/evaluation/DifferentiableFunction.h> 00027 00028 namespace shogun 00029 { 00030 00032 enum EInferenceType 00033 { 00034 INF_NONE=0, 00035 INF_EXACT=10, 00036 INF_FITC=20, 00037 INF_LAPLACIAN=30, 00038 INF_EP=40 00039 }; 00040 00049 class CInferenceMethod : public CDifferentiableFunction 00050 { 00051 public: 00053 CInferenceMethod(); 00054 00063 CInferenceMethod(CKernel* kernel, CFeatures* features, 00064 CMeanFunction* mean, CLabels* labels, CLikelihoodModel* model); 00065 00066 virtual ~CInferenceMethod(); 00067 00072 virtual EInferenceType get_inference_type() const { return INF_NONE; } 00073 00085 virtual float64_t get_negative_log_marginal_likelihood()=0; 00086 00122 float64_t get_marginal_likelihood_estimate(int32_t num_importance_samples=1, 00123 float64_t ridge_size=1e-15); 00124 00138 virtual CMap<TParameter*, SGVector<float64_t> >* get_negative_log_marginal_likelihood_derivatives( 00139 CMap<TParameter*, CSGObject*>* parameters); 00140 00151 virtual SGVector<float64_t> get_alpha()=0; 00152 00164 virtual SGMatrix<float64_t> get_cholesky()=0; 00165 00177 virtual 
SGVector<float64_t> get_diagonal_vector()=0; 00178 00194 virtual SGVector<float64_t> get_posterior_mean()=0; 00195 00211 virtual SGMatrix<float64_t> get_posterior_covariance()=0; 00212 00220 virtual CMap<TParameter*, SGVector<float64_t> >* get_gradient( 00221 CMap<TParameter*, CSGObject*>* parameters) 00222 { 00223 return get_negative_log_marginal_likelihood_derivatives(parameters); 00224 } 00225 00230 virtual SGVector<float64_t> get_value() 00231 { 00232 SGVector<float64_t> result(1); 00233 result[0]=get_negative_log_marginal_likelihood(); 00234 return result; 00235 } 00236 00241 virtual CFeatures* get_features() { SG_REF(m_features); return m_features; } 00242 00247 virtual void set_features(CFeatures* feat) 00248 { 00249 SG_REF(feat); 00250 SG_UNREF(m_features); 00251 m_features=feat; 00252 } 00253 00258 virtual CKernel* get_kernel() { SG_REF(m_kernel); return m_kernel; } 00259 00264 virtual void set_kernel(CKernel* kern) 00265 { 00266 SG_REF(kern); 00267 SG_UNREF(m_kernel); 00268 m_kernel=kern; 00269 } 00270 00275 virtual CMeanFunction* get_mean() { SG_REF(m_mean); return m_mean; } 00276 00281 virtual void set_mean(CMeanFunction* m) 00282 { 00283 SG_REF(m); 00284 SG_UNREF(m_mean); 00285 m_mean=m; 00286 } 00287 00292 virtual CLabels* get_labels() { SG_REF(m_labels); return m_labels; } 00293 00298 virtual void set_labels(CLabels* lab) 00299 { 00300 SG_REF(lab); 00301 SG_UNREF(m_labels); 00302 m_labels=lab; 00303 } 00304 00309 CLikelihoodModel* get_model() { SG_REF(m_model); return m_model; } 00310 00315 virtual void set_model(CLikelihoodModel* mod) 00316 { 00317 SG_REF(mod); 00318 SG_UNREF(m_model); 00319 m_model=mod; 00320 } 00321 00326 virtual float64_t get_scale() const { return m_scale; } 00327 00332 virtual void set_scale(float64_t scale) { m_scale=scale; } 00333 00339 virtual bool supports_regression() const { return false; } 00340 00346 virtual bool supports_binary() const { return false; } 00347 00353 virtual bool supports_multiclass() const { return 
false; } 00354 00356 virtual void update(); 00357 00358 protected: 00360 virtual void check_members() const; 00361 00363 virtual void update_alpha()=0; 00364 00366 virtual void update_chol()=0; 00367 00371 virtual void update_deriv()=0; 00372 00374 virtual void update_train_kernel(); 00375 00383 virtual SGVector<float64_t> get_derivative_wrt_inference_method( 00384 const TParameter* param)=0; 00385 00393 virtual SGVector<float64_t> get_derivative_wrt_likelihood_model( 00394 const TParameter* param)=0; 00395 00403 virtual SGVector<float64_t> get_derivative_wrt_kernel( 00404 const TParameter* param)=0; 00405 00413 virtual SGVector<float64_t> get_derivative_wrt_mean( 00414 const TParameter* param)=0; 00415 00419 static void* get_derivative_helper(void* p); 00420 00421 private: 00422 void init(); 00423 00424 protected: 00426 CKernel* m_kernel; 00427 00429 CMeanFunction* m_mean; 00430 00432 CLikelihoodModel* m_model; 00433 00435 CFeatures* m_features; 00436 00438 CLabels* m_labels; 00439 00441 SGVector<float64_t> m_alpha; 00442 00444 SGMatrix<float64_t> m_L; 00445 00447 float64_t m_scale; 00448 00450 SGMatrix<float64_t> m_ktrtr; 00451 }; 00452 } 00453 #endif /* HAVE_EIGEN3 */ 00454 #endif /* CINFERENCEMETHOD_H_ */