SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Written (W) 2013 Shell Hu and Heiko Strathmann 00009 * Copyright (C) 2012 Chiyuan Zhang 00010 */ 00011 00012 #ifndef MULTICLASSSTRATEGY_H__ 00013 #define MULTICLASSSTRATEGY_H__ 00014 00015 #include <shogun/base/SGObject.h> 00016 #include <shogun/labels/BinaryLabels.h> 00017 #include <shogun/labels/MulticlassLabels.h> 00018 #include <shogun/multiclass/RejectionStrategy.h> 00019 #include <shogun/mathematics/Statistics.h> 00020 00021 namespace shogun 00022 { 00023 00034 enum EProbHeuristicType 00035 { 00036 PROB_HEURIS_NONE = 0, 00037 OVA_NORM = 1, 00038 OVA_SOFTMAX = 2, 00039 OVO_PRICE = 3, 00040 OVO_HASTIE = 4, 00041 OVO_HAMAMURA = 5 00042 }; 00043 00047 class CMulticlassStrategy: public CSGObject 00048 { 00049 public: 00051 CMulticlassStrategy(); 00052 00056 CMulticlassStrategy(EProbHeuristicType prob_heuris); 00057 00059 virtual ~CMulticlassStrategy() {} 00060 00062 virtual const char* get_name() const 00063 { 00064 return "MulticlassStrategy"; 00065 }; 00066 00068 void set_num_classes(int32_t num_classes) 00069 { 00070 m_num_classes = num_classes; 00071 } 00072 00074 int32_t get_num_classes() const 00075 { 00076 return m_num_classes; 00077 } 00078 00080 CRejectionStrategy *get_rejection_strategy() 00081 { 00082 SG_REF(m_rejection_strategy); 00083 return m_rejection_strategy; 00084 } 00085 00087 void set_rejection_strategy(CRejectionStrategy *rejection_strategy) 00088 { 00089 SG_REF(rejection_strategy); 00090 SG_UNREF(m_rejection_strategy); 00091 m_rejection_strategy = rejection_strategy; 00092 } 00093 00095 virtual void train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels); 00096 00098 virtual bool train_has_more()=0; 00099 00103 virtual SGVector<int32_t> train_prepare_next(); 00104 00106 virtual void train_stop(); 00107 00111 virtual int32_t decide_label(SGVector<float64_t> outputs)=0; 00112 00117 virtual SGVector<index_t> decide_label_multiple_output(SGVector<float64_t> outputs, int32_t n_outputs) 00118 { 00119 SG_NOTIMPLEMENTED 00120 return SGVector<index_t>(); 00121 } 00122 00125 virtual int32_t get_num_machines()=0; 00126 00128 EProbHeuristicType get_prob_heuris_type() 00129 { 00130 return m_prob_heuris; 00131 } 00132 00136 void set_prob_heuris_type(EProbHeuristicType prob_heuris) 00137 { 00138 m_prob_heuris = prob_heuris; 00139 } 00140 00146 virtual void rescale_outputs(SGVector<float64_t> outputs) 00147 { 00148 SG_NOTIMPLEMENTED 00149 } 00150 00158 virtual void rescale_outputs(SGVector<float64_t> outputs, 00159 const SGVector<float64_t> As, const SGVector<float64_t> Bs) 00160 { 00161 SG_NOTIMPLEMENTED 00162 } 00163 00164 private: 00166 void init(); 00167 00168 protected: 00169 00170 CRejectionStrategy* m_rejection_strategy; 00171 CBinaryLabels *m_train_labels; 00172 CMulticlassLabels *m_orig_labels; 00173 int32_t m_train_iter; 00174 int32_t m_num_classes; 00175 EProbHeuristicType m_prob_heuris; 00176 }; 00177 00178 } // namespace shogun 00179 00180 #endif /* end of include guard: MULTICLASSSTRATEGY_H__ */ 00181