SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MulticlassOneVsRestStrategy.h
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Written (W) 2013 Shell Hu and Heiko Strathmann
00009  * Copyright (C) 2012 Chiyuan Zhang
00010  */
00011 
00012 #ifndef MULTICLASSONEVSRESTSTRATEGY_H__
00013 #define MULTICLASSONEVSRESTSTRATEGY_H__
00014 
00015 #include <shogun/multiclass/MulticlassStrategy.h>
00016 
00017 namespace shogun
00018 {
00019 
00031 class CMulticlassOneVsRestStrategy: public CMulticlassStrategy
00032 {
00033 public:
00035     CMulticlassOneVsRestStrategy();
00036 
00040     CMulticlassOneVsRestStrategy(EProbHeuristicType prob_heuris);
00041 
00043     virtual ~CMulticlassOneVsRestStrategy() {}
00044 
00046     virtual void train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
00047     {
00048         CMulticlassStrategy::train_start(orig_labels, train_labels);
00049     }
00050 
00052     virtual bool train_has_more()
00053     {
00054         return m_train_iter < m_num_classes;
00055     }
00056 
00060     virtual SGVector<int32_t> train_prepare_next();
00061 
00065     virtual int32_t decide_label(SGVector<float64_t> outputs);
00066 
00071     virtual SGVector<index_t> decide_label_multiple_output(SGVector<float64_t> outputs, int32_t n_outputs);
00072 
00075     virtual int32_t get_num_machines()
00076     {
00077         return m_num_classes;
00078     }
00079 
00081     virtual const char* get_name() const
00082     {
00083         return "MulticlassOneVsRestStrategy";
00084     };
00085 
00089     virtual void rescale_outputs(SGVector<float64_t> outputs);
00090 
00097     virtual void rescale_outputs(SGVector<float64_t> outputs,
00098             const SGVector<float64_t> As, const SGVector<float64_t> Bs);
00099 
00100 protected:
00104     void rescale_heuris_norm(SGVector<float64_t> outputs);
00105 
00111     void rescale_heuris_softmax(SGVector<float64_t> outputs,
00112             const SGVector<float64_t> As, const SGVector<float64_t> Bs);
00113 
00114 };
00115 
00116 } // namespace shogun
00117 
00118 #endif /* end of include guard: MULTICLASSONEVSRESTSTRATEGY_H__ */
00119 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation