SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Fernando J. Iglesias Garcia 00008 * Copyright (C) 2013 Fernando J. Iglesias Garcia 00009 */ 00010 00011 #ifndef LMNN_H_ 00012 #define LMNN_H_ 00013 00014 #include <shogun/lib/config.h> 00015 00016 #ifdef HAVE_EIGEN3 00017 #ifdef HAVE_LAPACK 00018 00019 #include <shogun/base/SGObject.h> 00020 #include <shogun/distance/CustomMahalanobisDistance.h> 00021 #include <shogun/features/DenseFeatures.h> 00022 #include <shogun/labels/MulticlassLabels.h> 00023 #include <shogun/lib/SGMatrix.h> 00024 00025 namespace shogun 00026 { 00027 00028 // Forward declaration 00029 class CLMNNStatistics; 00030 00038 class CLMNN : public CSGObject 00039 { 00040 public: 00042 CLMNN(); 00043 00050 CLMNN(CDenseFeatures<float64_t>* features, CMulticlassLabels* labels, int32_t k); 00051 00053 virtual ~CLMNN(); 00054 00056 virtual const char* get_name() const; 00057 00065 void train(SGMatrix<float64_t> init_transform=SGMatrix<float64_t>()); 00066 00071 SGMatrix<float64_t> get_linear_transform() const; 00072 00079 CCustomMahalanobisDistance* get_distance() const; 00080 00085 int32_t get_k() const; 00086 00091 void set_k(const int32_t k); 00092 00097 float64_t get_regularization() const; 00098 00103 void set_regularization(const float64_t regularization); 00104 00109 float64_t get_stepsize() const; 00110 00115 void set_stepsize(const float64_t stepsize); 00116 00121 float64_t get_stepsize_threshold() const; 00122 00127 void set_stepsize_threshold(const float64_t stepsize_threshold); 00128 00133 uint32_t get_maxiter() const; 00134 00139 void set_maxiter(const uint32_t maxiter); 00140 00145 uint32_t get_correction() const; 00146 00151 void set_correction(const uint32_t correction); 00152 00157 float64_t get_obj_threshold() const; 00158 00163 void set_obj_threshold(const float64_t obj_threshold); 00164 00169 bool get_diagonal() const; 00170 00175 void set_diagonal(const bool diagonal); 00176 00181 CLMNNStatistics* get_statistics() const; 00182 00183 private: 00185 void init(); 00186 00187 private: 00189 SGMatrix<float64_t> m_linear_transform; 00190 00192 CFeatures* m_features; 00193 00195 CLabels* m_labels; 00196 00201 float64_t m_regularization; 00202 00204 int32_t m_k; 00205 00210 float64_t m_stepsize; 00211 00217 float64_t m_stepsize_threshold; 00218 00220 uint32_t m_maxiter; 00221 00226 uint32_t m_correction; 00227 00234 float64_t m_obj_threshold; 00235 00240 bool m_diagonal; 00241 00243 CLMNNStatistics* m_statistics; 00244 00245 }; /* class CLMNN */ 00246 00251 class CLMNNStatistics : public CSGObject 00252 { 00253 public: 00255 CLMNNStatistics(); 00256 00258 virtual ~CLMNNStatistics(); 00259 00261 virtual const char* get_name() const; 00262 00269 void resize(int32_t size); 00270 00281 void set(index_t iter, float64_t obj_iter, float64_t stepsize_iter, uint32_t num_impostors_iter); 00282 00283 private: 00285 void init(); 00286 00287 public: 00289 SGVector<float64_t> obj; 00290 00292 SGVector<float64_t> stepsize; 00293 00295 SGVector<uint32_t> num_impostors; 00296 }; 00297 00298 } /* namespace shogun */ 00299 00300 #endif /* HAVE_LAPACK */ 00301 #endif /* HAVE_EIGEN3 */ 00302 00303 #endif /* LMNN_H_ */