SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Kevin Hughes 00008 * Copyright (C) 2013 Kevin Hughes 00009 * 00010 * Thanks to Fernando José Iglesias García (shogun) 00011 * and Matthieu Perrot (scikit-learn) 00012 */ 00013 00014 #ifndef _MCLDA_H__ 00015 #define _MCLDA_H__ 00016 00017 #include <shogun/lib/config.h> 00018 00019 #ifdef HAVE_EIGEN3 00020 00021 #include <shogun/features/DotFeatures.h> 00022 #include <shogun/features/DenseFeatures.h> 00023 #include <shogun/machine/NativeMulticlassMachine.h> 00024 #include <shogun/lib/SGNDArray.h> 00025 00026 namespace shogun 00027 { 00028 00029 //#define DEBUG_MCLDA 00030 00039 class CMCLDA : public CNativeMulticlassMachine 00040 { 00041 public: 00042 MACHINE_PROBLEM_TYPE(PT_MULTICLASS) 00043 00044 00049 CMCLDA(float64_t tolerance = 1e-4, bool store_cov = false); 00050 00058 CMCLDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance = 1e-4, bool store_cov = false); 00059 00060 virtual ~CMCLDA(); 00061 00067 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00068 00073 inline void set_tolerance(float64_t tolerance) { m_tolerance = tolerance; } 00074 00079 inline bool get_tolerance() { return m_tolerance; } 00080 00085 virtual EMachineType get_classifier_type() { return CT_LDA; } // for now add to machine typers properly later 00086 00091 virtual void set_features(CDotFeatures* feat) 00092 { 00093 if (feat->get_feature_class() != C_DENSE || 00094 feat->get_feature_type() != F_DREAL) 00095 SG_ERROR("MCLDA requires SIMPLE REAL valued features\n") 00096 00097 SG_REF(feat); 00098 SG_UNREF(m_features); 00099 m_features = feat; 00100 } 00101 00106 virtual CDotFeatures* get_features() { SG_REF(m_features); return m_features; } 00107 00112 virtual const char* get_name() const { return "MCLDA"; } 00113 00120 inline SGVector< float64_t > get_mean(int32_t c) const 00121 { 00122 return SGVector< float64_t >(m_means.get_column_vector(c), m_dim, false); 00123 } 00124 00129 inline SGMatrix< float64_t > get_cov() const 00130 { 00131 return m_cov; 00132 } 00133 00134 protected: 00141 virtual bool train_machine(CFeatures* data = NULL); 00142 00143 private: 00144 void init(); 00145 00146 void cleanup(); 00147 00148 private: 00150 CDotFeatures* m_features; 00151 00153 float64_t m_tolerance; 00154 00156 bool m_store_cov; 00157 00159 int32_t m_num_classes; 00160 00162 int32_t m_dim; 00163 00167 SGMatrix< float64_t > m_cov; 00168 00170 SGMatrix< float64_t > m_means; 00171 00173 SGVector< float64_t > m_xbar; 00174 00176 int32_t m_rank; 00177 00179 SGMatrix< float64_t > m_scalings; 00180 00182 SGMatrix< float64_t > m_coef; 00183 00185 SGVector< float64_t > m_intercept; 00186 00187 }; /* class MCLDA */ 00188 } /* namespace shogun */ 00189 00190 #endif /* HAVE_EIGEN3 */ 00191 #endif /* _MCLDA_H__ */