SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Written (W) 2013 Shell Hu and Heiko Strathmann 00009 * Copyright (C) 2012 Chiyuan Zhang 00010 */ 00011 00012 #include <shogun/multiclass/MulticlassStrategy.h> 00013 00014 namespace shogun 00015 { 00016 00028 class CMulticlassOneVsOneStrategy: public CMulticlassStrategy 00029 { 00030 public: 00032 CMulticlassOneVsOneStrategy(); 00033 00037 CMulticlassOneVsOneStrategy(EProbHeuristicType prob_heuris); 00038 00040 virtual ~CMulticlassOneVsOneStrategy() {} 00041 00043 virtual void train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels); 00044 00046 virtual bool train_has_more(); 00047 00051 virtual SGVector<int32_t> train_prepare_next(); 00052 00056 virtual int32_t decide_label(SGVector<float64_t> outputs); 00057 00060 virtual int32_t get_num_machines() 00061 { 00062 return m_num_classes*(m_num_classes-1)/2; 00063 } 00064 00066 virtual const char* get_name() const 00067 { 00068 return "MulticlassOneVsOneStrategy"; 00069 }; 00070 00075 virtual void rescale_outputs(SGVector<float64_t> outputs); 00076 00081 void set_num_classes(int32_t num_classes) 00082 { 00083 CMulticlassStrategy::set_num_classes(num_classes); 00084 m_num_machines = m_num_classes*(m_num_classes-1)/2; 00085 } 00086 00087 protected: 00093 void rescale_heuris_price(SGVector<float64_t> outputs, 00094 const SGVector<int32_t> indx1, const SGVector<int32_t> indx2); 00095 00101 void rescale_heuris_hastie(SGVector<float64_t> outputs, 00102 const SGVector<int32_t> indx1, const SGVector<int32_t> indx2); 00103 00109 void rescale_heuris_hamamura(SGVector<float64_t> outputs, 00110 const SGVector<int32_t> indx1, const SGVector<int32_t> indx2); 00111 00112 private: 00114 void register_parameters(); 00115 00116 protected: 00117 int32_t m_num_machines; 00118 int32_t m_train_pair_idx_1; 00119 int32_t m_train_pair_idx_2; 00120 SGVector<int32_t> m_num_samples; 00121 }; 00122 00123 } // namespace shogun