SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2011 Soeren Sonnenburg 00008 * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn 00009 * Written (W) 2013 Shell Hu and Heiko Strathmann 00010 * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia 00011 */ 00012 00013 #ifndef _MULTICLASSMACHINE_H___ 00014 #define _MULTICLASSMACHINE_H___ 00015 00016 #include <shogun/machine/BaseMulticlassMachine.h> 00017 #include <shogun/lib/DynamicObjectArray.h> 00018 #include <shogun/multiclass/MulticlassStrategy.h> 00019 #include <shogun/labels/MulticlassLabels.h> 00020 #include <shogun/labels/MulticlassMultipleOutputLabels.h> 00021 00022 namespace shogun 00023 { 00024 00025 class CFeatures; 00026 class CLabels; 00027 00029 class CMulticlassMachine : public CBaseMulticlassMachine 00030 { 00031 public: 00033 CMulticlassMachine(); 00034 00040 CMulticlassMachine(CMulticlassStrategy* strategy, CMachine* machine, CLabels* labels); 00041 00043 virtual ~CMulticlassMachine(); 00044 00049 virtual void set_labels(CLabels* lab); 00050 00057 inline bool set_machine(int32_t num, CMachine* machine) 00058 { 00059 ASSERT(num<m_machines->get_num_elements() && num>=0) 00060 if (machine != NULL && !is_acceptable_machine(machine)) 00061 SG_ERROR("Machine %s is not acceptable by %s", machine->get_name(), this->get_name()) 00062 00063 m_machines->set_element(machine, num); 00064 return true; 00065 } 00066 00072 inline CMachine* get_machine(int32_t num) const 00073 { 00074 return (CMachine*) m_machines->get_element_safe(num); 00075 } 00076 00081 virtual CBinaryLabels* get_submachine_outputs(int32_t i); 00082 00088 virtual float64_t get_submachine_output(int32_t i, int32_t num); 00089 00094 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00095 00100 virtual CMulticlassMultipleOutputLabels* apply_multiclass_multiple_output(CFeatures* data=NULL, int32_t n_outputs=5); 00101 00106 virtual float64_t apply_one(int32_t vec_idx); 00107 00112 inline CMulticlassStrategy* get_multiclass_strategy() const 00113 { 00114 SG_REF(m_multiclass_strategy); 00115 return m_multiclass_strategy; 00116 } 00117 00122 inline CRejectionStrategy* get_rejection_strategy() const 00123 { 00124 return m_multiclass_strategy->get_rejection_strategy(); 00125 } 00126 00131 inline void set_rejection_strategy(CRejectionStrategy* rejection_strategy) 00132 { 00133 m_multiclass_strategy->set_rejection_strategy(rejection_strategy); 00134 } 00135 00137 virtual const char* get_name() const 00138 { 00139 return "MulticlassMachine"; 00140 } 00141 00143 inline EProbHeuristicType get_prob_heuris() 00144 { 00145 return m_multiclass_strategy->get_prob_heuris_type(); 00146 } 00147 00151 inline void set_prob_heuris(EProbHeuristicType prob_heuris) 00152 { 00153 m_multiclass_strategy->set_prob_heuris_type(prob_heuris); 00154 } 00155 00156 protected: 00158 void init_strategy(); 00159 00161 void clear_machines(); 00162 00164 virtual bool train_machine(CFeatures* data = NULL); 00165 00167 virtual bool init_machine_for_train(CFeatures* data) = 0; 00168 00170 virtual bool init_machines_for_apply(CFeatures* data) = 0; 00171 00173 virtual bool is_ready() = 0; 00174 00176 virtual CMachine* get_machine_from_trained(CMachine* machine) = 0; 00177 00179 virtual int32_t get_num_rhs_vectors() = 0; 00180 00185 virtual void add_machine_subset(SGVector<index_t> subset) = 0; 00186 00188 virtual void remove_machine_subset() = 0; 00189 00191 virtual bool is_acceptable_machine(CMachine *machine) 00192 { 00193 return true; 00194 } 00195 00196 private: 00197 00199 void register_parameters(); 00200 00201 protected: 00203 CMulticlassStrategy *m_multiclass_strategy; 00204 00206 CMachine* m_machine; 00207 }; 00208 } 00209 #endif