SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MCLDA.h
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Kevin Hughes
00008  * Copyright (C) 2013 Kevin Hughes
00009  *
00010  * Thanks to Fernando José Iglesias García (shogun)
00011  *           and Matthieu Perrot (scikit-learn)
00012  */
00013 
00014 #ifndef _MCLDA_H__
00015 #define _MCLDA_H__
00016 
00017 #include <shogun/lib/config.h>
00018 
00019 #ifdef HAVE_EIGEN3
00020 
00021 #include <shogun/features/DotFeatures.h>
00022 #include <shogun/features/DenseFeatures.h>
00023 #include <shogun/machine/NativeMulticlassMachine.h>
00024 #include <shogun/lib/SGNDArray.h>
00025 
00026 namespace shogun
00027 {
00028 
00029 //#define DEBUG_MCLDA
00030 
00039 class CMCLDA : public CNativeMulticlassMachine
00040 {
00041     public:
00042         MACHINE_PROBLEM_TYPE(PT_MULTICLASS)
00043 
00044         
00049         CMCLDA(float64_t tolerance = 1e-4, bool store_cov = false);
00050 
00058         CMCLDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance = 1e-4, bool store_cov = false);
00059 
00060         virtual ~CMCLDA();
00061 
00067         virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
00068 
00073         inline void set_tolerance(float64_t tolerance) { m_tolerance = tolerance; }
00074 
00079         inline bool get_tolerance() { return m_tolerance; }
00080 
00085         virtual EMachineType get_classifier_type() { return CT_LDA; } // for now add to machine typers properly later
00086 
00091         virtual void set_features(CDotFeatures* feat)
00092         {
00093             if (feat->get_feature_class() != C_DENSE ||
00094                 feat->get_feature_type() != F_DREAL)
00095                 SG_ERROR("MCLDA requires SIMPLE REAL valued features\n")
00096 
00097             SG_REF(feat);
00098             SG_UNREF(m_features);
00099             m_features = feat;
00100         }
00101 
00106         virtual CDotFeatures* get_features() { SG_REF(m_features); return m_features; }
00107 
00112         virtual const char* get_name() const { return "MCLDA"; }
00113 
00120         inline SGVector< float64_t > get_mean(int32_t c) const
00121         {
00122             return SGVector< float64_t >(m_means.get_column_vector(c), m_dim, false);
00123         }
00124 
00129         inline SGMatrix< float64_t > get_cov() const
00130         {
00131             return m_cov;
00132         }
00133 
00134         protected:
00141         virtual bool train_machine(CFeatures* data = NULL);
00142 
00143     private:
00144         void init();
00145 
00146         void cleanup();
00147 
00148         private:
00150         CDotFeatures* m_features;
00151 
00153         float64_t m_tolerance;
00154 
00156         bool m_store_cov;
00157 
00159         int32_t m_num_classes;
00160 
00162         int32_t m_dim;
00163 
00167         SGMatrix< float64_t > m_cov;
00168 
00170         SGMatrix< float64_t > m_means;
00171 
00173         SGVector< float64_t > m_xbar;
00174 
00176         int32_t m_rank;
00177 
00179         SGMatrix< float64_t > m_scalings;
00180 
00182         SGMatrix< float64_t > m_coef;
00183 
00185         SGVector< float64_t > m_intercept;
00186 
00187 }; /* class MCLDA */
00188 }  /* namespace shogun */
00189 
00190 #endif /* HAVE_EIGEN3 */
00191 #endif /* _MCLDA_H__ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation