SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
LMNN.h
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Fernando J. Iglesias Garcia
00008  * Copyright (C) 2013 Fernando J. Iglesias Garcia
00009  */
00010 
00011 #ifndef LMNN_H_
00012 #define LMNN_H_
00013 
00014 #include <shogun/lib/config.h>
00015 
00016 #ifdef HAVE_EIGEN3
00017 #ifdef HAVE_LAPACK
00018 
00019 #include <shogun/base/SGObject.h>
00020 #include <shogun/distance/CustomMahalanobisDistance.h>
00021 #include <shogun/features/DenseFeatures.h>
00022 #include <shogun/labels/MulticlassLabels.h>
00023 #include <shogun/lib/SGMatrix.h>
00024 
00025 namespace shogun
00026 {
00027 
00028 // Forward declaration
00029 class CLMNNStatistics;
00030 
00038 class CLMNN : public CSGObject // Metric learner (LMNN, presumably Large Margin Nearest Neighbour — Weinberger & Saul): learns a linear transform of dense float64 features from multiclass labels via train(); the result is exposed as an SGMatrix (get_linear_transform) or a CCustomMahalanobisDistance (get_distance)
00039 {
00040     public:
00042         CLMNN(); //!< Default constructor. NOTE(review): this class declares no setter for features/labels, so an instance built this way presumably exists for serialization only — confirm in LMNN.cpp
00043 
00050         CLMNN(CDenseFeatures<float64_t>* features, CMulticlassLabels* labels, int32_t k); //!< @param features training vectors, @param labels their class labels, @param k presumably the number of target neighbours per example (presumably kept in m_features / m_labels / m_k — confirm)
00051 
00053         virtual ~CLMNN(); //!< Destructor. NOTE(review): m_features, m_labels and m_statistics are raw CSGObject pointers — presumably released (SG_UNREF) here; confirm
00054 
00056         virtual const char* get_name() const; //!< @return the class name reported through the CSGObject interface
00057 
00065         void train(SGMatrix<float64_t> init_transform=SGMatrix<float64_t>()); //!< Learn the linear transform from the stored features/labels. @param init_transform starting transform; defaults to an empty SGMatrix (presumably meaning "let train() choose an initial transform" — confirm)
00066 
00071         SGMatrix<float64_t> get_linear_transform() const; //!< @return the linear transform learned by train() (presumably m_linear_transform)
00072 
00079         CCustomMahalanobisDistance* get_distance() const; //!< @return a Mahalanobis distance built on the learned transform. NOTE(review): whether the caller receives an owned (SG_REF'd) reference is not visible here — check before unref'ing
00080 
00085         int32_t get_k() const; //!< @return k (m_k), presumably the number of target neighbours
00086 
00091         void set_k(const int32_t k); //!< @param k new value for m_k. NOTE(review): any validation (e.g. k > 0) is not visible in this header
00092 
00097         float64_t get_regularization() const; //!< @return the regularization weight of the training objective (m_regularization)
00098 
00103         void set_regularization(const float64_t regularization); //!< @param regularization new value for m_regularization
00104 
00109         float64_t get_stepsize() const; //!< @return the (initial) optimisation step size (m_stepsize)
00110 
00115         void set_stepsize(const float64_t stepsize); //!< @param stepsize new value for m_stepsize
00116 
00121         float64_t get_stepsize_threshold() const; //!< @return m_stepsize_threshold (presumably a lower bound on the step size used as a stopping criterion)
00122 
00127         void set_stepsize_threshold(const float64_t stepsize_threshold); //!< @param stepsize_threshold new value for m_stepsize_threshold
00128 
00133         uint32_t get_maxiter() const; //!< @return the maximum number of training iterations (m_maxiter)
00134 
00139         void set_maxiter(const uint32_t maxiter); //!< @param maxiter new value for m_maxiter
00140 
00145         uint32_t get_correction() const; //!< @return m_correction (presumably how often, in iterations, the impostor set is fully recomputed — confirm)
00146 
00151         void set_correction(const uint32_t correction); //!< @param correction new value for m_correction
00152 
00157         float64_t get_obj_threshold() const; //!< @return m_obj_threshold (presumably a convergence tolerance on the objective — confirm)
00158 
00163         void set_obj_threshold(const float64_t obj_threshold); //!< @param obj_threshold new value for m_obj_threshold
00164 
00169         bool get_diagonal() const; //!< @return m_diagonal (presumably true when the learned transform is restricted to a diagonal matrix — confirm)
00170 
00175         void set_diagonal(const bool diagonal); //!< @param diagonal new value for m_diagonal
00176 
00181         CLMNNStatistics* get_statistics() const; //!< @return per-iteration training statistics (m_statistics). NOTE(review): reference-count handling of the returned pointer is not visible here
00182 
00183     private:
00185         void init(); //!< Shared constructor helper; presumably sets default parameter values and registers members with the CSGObject parameter framework — confirm in LMNN.cpp
00186 
00187     private:
00189         SGMatrix<float64_t> m_linear_transform; //!< linear transform learned by train()
00190 
00192         CFeatures* m_features; //!< training features (held through the CFeatures base pointer)
00193 
00195         CLabels* m_labels; //!< training labels (held through the CLabels base pointer)
00196 
00201         float64_t m_regularization; //!< regularization weight of the objective (exact role defined in LMNN.cpp)
00202 
00204         int32_t m_k; //!< presumably the number of target neighbours per example
00205 
00210         float64_t m_stepsize; //!< optimisation step size (presumably adapted during training — confirm)
00211 
00217         float64_t m_stepsize_threshold; //!< presumably: training stops once the step size drops below this value — confirm
00218 
00220         uint32_t m_maxiter; //!< maximum number of training iterations
00221 
00226         uint32_t m_correction; //!< presumably the period (in iterations) of full impostor recomputation — confirm
00227 
00234         float64_t m_obj_threshold; //!< presumably the convergence tolerance on the objective — confirm
00235 
00240         bool m_diagonal; //!< presumably: learn only a diagonal transform when true — confirm
00241 
00243         CLMNNStatistics* m_statistics; //!< statistics collected during the last train() call
00244 
00245 }; /* class CLMNN */
00246 
00251 class CLMNNStatistics : public CSGObject // Per-iteration record of an LMNN training run: objective value, step size and number of impostors, filled through set()
00252 {
00253     public:
00255         CLMNNStatistics(); //!< default constructor
00256 
00258         virtual ~CLMNNStatistics(); //!< destructor
00259 
00261         virtual const char* get_name() const; //!< @return the class name reported through the CSGObject interface
00262 
00269         void resize(int32_t size); //!< @param size new length of obj, stepsize and num_impostors (presumably the number of iterations; whether existing entries survive is not visible here)
00270 
00281         void set(index_t iter, float64_t obj_iter, float64_t stepsize_iter, uint32_t num_impostors_iter); //!< Record iteration iter: presumably obj[iter] = obj_iter, stepsize[iter] = stepsize_iter, num_impostors[iter] = num_impostors_iter. NOTE(review): bounds checking against the size passed to resize() is not visible here
00282 
00283     private:
00285         void init(); //!< shared constructor helper (presumably registers the vectors below as parameters — confirm)
00286 
00287     public: // NOTE(review): the statistics vectors are public data members, so callers read (and can write) them directly
00289         SGVector<float64_t> obj; //!< objective value per iteration
00290 
00292         SGVector<float64_t> stepsize; //!< step size per iteration
00293 
00295         SGVector<uint32_t> num_impostors; //!< number of impostors per iteration
00296 };
00297 
00298 } /* namespace shogun */
00299 
00300 #endif /* HAVE_LAPACK */
00301 #endif /* HAVE_EIGEN3 */
00302 
00303 #endif /* LMNN_H_ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation