SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Roman Votyakov 00008 * Written (W) 2013 Heiko Strathmann 00009 * Copyright (C) 2012 Jacob Walker 00010 * Copyright (C) 2013 Roman Votyakov 00011 */ 00012 00013 #ifndef CLIKELIHOODMODEL_H_ 00014 #define CLIKELIHOODMODEL_H_ 00015 00016 #include <shogun/base/SGObject.h> 00017 #include <shogun/labels/Labels.h> 00018 00019 namespace shogun 00020 { 00021 00023 enum ELikelihoodModelType 00024 { 00025 LT_NONE=0, 00026 LT_GAUSSIAN=10, 00027 LT_STUDENTST=20, 00028 LT_LOGIT=30, 00029 LT_PROBIT=40 00030 }; 00031 00037 class CLikelihoodModel : public CSGObject 00038 { 00039 public: 00041 CLikelihoodModel(); 00042 00043 virtual ~CLikelihoodModel(); 00044 00072 virtual SGVector<float64_t> get_predictive_log_probabilities( 00073 SGVector<float64_t> mu, SGVector<float64_t> s2, 00074 const CLabels* lab=NULL); 00075 00090 virtual SGVector<float64_t> get_predictive_means(SGVector<float64_t> mu, 00091 SGVector<float64_t> s2, const CLabels* lab=NULL) const=0; 00092 00107 virtual SGVector<float64_t> get_predictive_variances(SGVector<float64_t> mu, 00108 SGVector<float64_t> s2, const CLabels* lab=NULL) const=0; 00109 00114 virtual ELikelihoodModelType get_model_type() const { return LT_NONE; } 00115 00127 virtual SGVector<float64_t> get_log_probability_f(const CLabels* lab, 00128 SGVector<float64_t> func) const=0; 00129 00142 virtual SGVector<float64_t> get_log_probability_fmatrix(const CLabels* lab, 00143 SGMatrix<float64_t> F) const; 00144 00155 virtual SGVector<float64_t> get_log_probability_derivative_f( 00156 const CLabels* lab, SGVector<float64_t> func, index_t i) const=0; 00157 00167 virtual SGVector<float64_t> get_first_derivative(const CLabels* lab, 00168 SGVector<float64_t> func, const TParameter* param) const 00169 { 00170 SG_ERROR("Can't compute derivative wrt %s parameter\n", param->m_name) 00171 return SGVector<float64_t>(); 00172 } 00173 00174 00185 virtual SGVector<float64_t> get_second_derivative(const CLabels* lab, 00186 SGVector<float64_t> func, const TParameter* param) const 00187 { 00188 SG_ERROR("Can't compute derivative wrt %s parameter\n", param->m_name) 00189 return SGVector<float64_t>(); 00190 } 00191 00202 virtual SGVector<float64_t> get_third_derivative(const CLabels* lab, 00203 SGVector<float64_t> func, const TParameter* param) const 00204 { 00205 SG_ERROR("Can't compute derivative wrt %s parameter\n", param->m_name) 00206 return SGVector<float64_t>(); 00207 } 00208 00225 virtual SGVector<float64_t> get_log_zeroth_moments(SGVector<float64_t> mu, 00226 SGVector<float64_t> s2, const CLabels* lab) const=0; 00227 00242 virtual float64_t get_first_moment(SGVector<float64_t> mu, 00243 SGVector<float64_t> s2, const CLabels* lab, index_t i) const=0; 00244 00258 virtual SGVector<float64_t> get_first_moments(SGVector<float64_t> mu, 00259 SGVector<float64_t> s2, const CLabels* lab) const; 00260 00275 virtual float64_t get_second_moment(SGVector<float64_t> mu, 00276 SGVector<float64_t> s2, const CLabels* lab, index_t i) const=0; 00277 00291 virtual SGVector<float64_t> get_second_moments(SGVector<float64_t> mu, 00292 SGVector<float64_t> s2, const CLabels* lab) const; 00293 00298 virtual bool supports_regression() const { return false; } 00299 00304 virtual bool supports_binary() const { return false; } 00305 00310 virtual bool supports_multiclass() const { return false; } 00311 }; 00312 } 00313 #endif /* CLIKELIHOODMODEL_H_ */