// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 2011-2012 Heiko Strathmann 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _MACHINE_H__ 00013 #define _MACHINE_H__ 00014 00015 #include <shogun/lib/common.h> 00016 #include <shogun/base/SGObject.h> 00017 #include <shogun/labels/Labels.h> 00018 #include <shogun/labels/BinaryLabels.h> 00019 #include <shogun/labels/RegressionLabels.h> 00020 #include <shogun/labels/MulticlassLabels.h> 00021 #include <shogun/labels/StructuredLabels.h> 00022 #include <shogun/labels/LatentLabels.h> 00023 #include <shogun/features/Features.h> 00024 00025 namespace shogun 00026 { 00027 00028 class CFeatures; 00029 class CLabels; 00030 class CMath; 00031 00033 enum EMachineType 00034 { 00035 CT_NONE = 0, 00036 CT_LIGHT = 10, 00037 CT_LIGHTONECLASS = 11, 00038 CT_LIBSVM = 20, 00039 CT_LIBSVMONECLASS=30, 00040 CT_LIBSVMMULTICLASS=40, 00041 CT_MPD = 50, 00042 CT_GPBT = 60, 00043 CT_CPLEXSVM = 70, 00044 CT_PERCEPTRON = 80, 00045 CT_KERNELPERCEPTRON = 90, 00046 CT_LDA = 100, 00047 CT_LPM = 110, 00048 CT_LPBOOST = 120, 00049 CT_KNN = 130, 00050 CT_SVMLIN=140, 00051 CT_KERNELRIDGEREGRESSION = 150, 00052 CT_GNPPSVM = 160, 00053 CT_GMNPSVM = 170, 00054 CT_SVMPERF = 200, 00055 CT_LIBSVR = 210, 00056 CT_SVRLIGHT = 220, 00057 CT_LIBLINEAR = 230, 00058 CT_KMEANS = 240, 00059 CT_HIERARCHICAL = 250, 00060 CT_SVMOCAS = 260, 00061 CT_WDSVMOCAS = 270, 00062 CT_SVMSGD = 280, 00063 CT_MKLMULTICLASS = 290, 00064 CT_MKLCLASSIFICATION = 300, 00065 CT_MKLONECLASS = 310, 00066 CT_MKLREGRESSION = 320, 00067 CT_SCATTERSVM = 330, 00068 CT_DASVM = 340, 00069 CT_LARANK = 350, 00070 CT_DASVMLINEAR = 360, 00071 
CT_GAUSSIANNAIVEBAYES = 370, 00072 CT_AVERAGEDPERCEPTRON = 380, 00073 CT_SGDQN = 390, 00074 CT_CONJUGATEINDEX = 400, 00075 CT_LINEARRIDGEREGRESSION = 410, 00076 CT_LEASTSQUARESREGRESSION = 420, 00077 CT_QDA = 430, 00078 CT_NEWTONSVM = 440, 00079 CT_GAUSSIANPROCESSREGRESSION = 450, 00080 CT_LARS = 460, 00081 CT_MULTICLASS = 470, 00082 CT_DIRECTORLINEAR = 480, 00083 CT_DIRECTORKERNEL = 490, 00084 CT_LIBQPSOSVM = 500, 00085 CT_PRIMALMOSEKSOSVM = 510, 00086 CT_CCSOSVM = 520, 00087 CT_GAUSSIANPROCESSBINARY = 530, 00088 CT_GAUSSIANPROCESSMULTICLASS = 540, 00089 CT_STOCHASTICSOSVM = 550, 00090 CT_BAGGING 00091 }; 00092 00094 enum ESolverType 00095 { 00096 ST_AUTO=0, 00097 ST_CPLEX=1, 00098 ST_GLPK=2, 00099 ST_NEWTON=3, 00100 ST_DIRECT=4, 00101 ST_ELASTICNET=5, 00102 ST_BLOCK_NORM=6 00103 }; 00104 00106 enum EProblemType 00107 { 00108 PT_BINARY = 0, 00109 PT_REGRESSION = 1, 00110 PT_MULTICLASS = 2, 00111 PT_STRUCTURED = 3, 00112 PT_LATENT = 4 00113 }; 00114 00115 #define MACHINE_PROBLEM_TYPE(PT) \ 00116 \ 00119 virtual EProblemType get_machine_problem_type() const { return PT; } 00120 00138 class CMachine : public CSGObject 00139 { 00140 public: 00142 CMachine(); 00143 00145 virtual ~CMachine(); 00146 00156 virtual bool train(CFeatures* data=NULL); 00157 00164 virtual CLabels* apply(CFeatures* data=NULL); 00165 00167 virtual CBinaryLabels* apply_binary(CFeatures* data=NULL); 00169 virtual CRegressionLabels* apply_regression(CFeatures* data=NULL); 00171 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00173 virtual CStructuredLabels* apply_structured(CFeatures* data=NULL); 00175 virtual CLatentLabels* apply_latent(CFeatures* data=NULL); 00176 00181 virtual void set_labels(CLabels* lab); 00182 00187 virtual CLabels* get_labels(); 00188 00193 void set_max_train_time(float64_t t); 00194 00199 float64_t get_max_train_time(); 00200 00205 virtual EMachineType get_classifier_type(); 00206 00211 void set_solver_type(ESolverType st); 00212 00217 ESolverType 
get_solver_type(); 00218 00224 virtual void set_store_model_features(bool store_model); 00225 00234 virtual bool train_locked(SGVector<index_t> indices) 00235 { 00236 SG_ERROR("train_locked(SGVector<index_t>) is not yet implemented " 00237 "for %s\n", get_name()); 00238 return false; 00239 } 00240 00242 virtual float64_t apply_one(int32_t i) 00243 { 00244 SG_NOTIMPLEMENTED 00245 return 0.0; 00246 } 00247 00253 virtual CLabels* apply_locked(SGVector<index_t> indices); 00254 00256 virtual CBinaryLabels* apply_locked_binary( 00257 SGVector<index_t> indices); 00259 virtual CRegressionLabels* apply_locked_regression( 00260 SGVector<index_t> indices); 00262 virtual CMulticlassLabels* apply_locked_multiclass( 00263 SGVector<index_t> indices); 00265 virtual CStructuredLabels* apply_locked_structured( 00266 SGVector<index_t> indices); 00268 virtual CLatentLabels* apply_locked_latent( 00269 SGVector<index_t> indices); 00270 00279 virtual void data_lock(CLabels* labs, CFeatures* features); 00280 00282 virtual void post_lock(CLabels* labs, CFeatures* features) { }; 00283 00285 virtual void data_unlock(); 00286 00288 virtual bool supports_locking() const { return false; } 00289 00291 bool is_data_locked() const { return m_data_locked; } 00292 00294 virtual EProblemType get_machine_problem_type() const 00295 { 00296 SG_NOTIMPLEMENTED 00297 return PT_BINARY; 00298 } 00299 00300 virtual const char* get_name() const { return "Machine"; } 00301 00302 protected: 00313 virtual bool train_machine(CFeatures* data=NULL) 00314 { 00315 SG_ERROR("train_machine is not yet implemented for %s!\n", 00316 get_name()); 00317 return false; 00318 } 00319 00330 virtual void store_model_features() 00331 { 00332 SG_ERROR("Model storage and therefore unlocked Cross-Validation and" 00333 " Model-Selection is not supported for %s. 
Locked may" 00334 " work though.\n", get_name()); 00335 } 00336 00343 virtual bool is_label_valid(CLabels *lab) const 00344 { 00345 return true; 00346 } 00347 00349 virtual bool train_require_labels() const { return true; } 00350 00351 protected: 00353 float64_t m_max_train_time; 00354 00356 CLabels* m_labels; 00357 00359 ESolverType m_solver_type; 00360 00362 bool m_store_model_features; 00363 00365 bool m_data_locked; 00366 }; 00367 } 00368 #endif // _MACHINE_H__