SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #ifndef _LINEARMULTICLASSMACHINE_H___ 00012 #define _LINEARMULTICLASSMACHINE_H___ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/features/DotFeatures.h> 00016 #include <shogun/machine/LinearMachine.h> 00017 #include <shogun/machine/MulticlassMachine.h> 00018 00019 namespace shogun 00020 { 00021 00022 class CDotFeatures; 00023 class CLinearMachine; 00024 class CMulticlassStrategy; 00025 00027 class CLinearMulticlassMachine : public CMulticlassMachine 00028 { 00029 public: 00031 CLinearMulticlassMachine() : CMulticlassMachine(), m_features(NULL) 00032 { 00033 SG_ADD((CSGObject**)&m_features, "m_features", "Feature object.", 00034 MS_NOT_AVAILABLE); 00035 } 00036 00043 CLinearMulticlassMachine(CMulticlassStrategy *strategy, CDotFeatures* features, CLinearMachine* machine, CLabels* labs) : 00044 CMulticlassMachine(strategy,(CMachine*)machine,labs), m_features(NULL) 00045 { 00046 set_features(features); 00047 SG_ADD((CSGObject**)&m_features, "m_features", "Feature object.", 00048 MS_NOT_AVAILABLE); 00049 } 00050 00052 virtual ~CLinearMulticlassMachine() 00053 { 00054 SG_UNREF(m_features); 00055 } 00056 00058 virtual const char* get_name() const 00059 { 00060 return "LinearMulticlassMachine"; 00061 } 00062 00067 void set_features(CDotFeatures* f) 00068 { 00069 SG_REF(f); 00070 SG_UNREF(m_features); 00071 m_features = f; 00072 00073 for (index_t i=0; i<m_machines->get_num_elements(); i++) 00074 { 00075 CLinearMachine* machine = (CLinearMachine* )m_machines->get_element(i); 00076 machine->set_features(f); 00077 SG_UNREF(machine); 00078 } 00079 } 00080 00085 CDotFeatures* get_features() const 00086 { 00087 SG_REF(m_features); 00088 return m_features; 00089 } 00090 00091 protected: 00092 00094 virtual bool init_machine_for_train(CFeatures* data) 00095 { 00096 if (!m_machine) 00097 SG_ERROR("No machine given in Multiclass constructor\n") 00098 00099 if (data) 00100 set_features((CDotFeatures*)data); 00101 00102 ((CLinearMachine*)m_machine)->set_features(m_features); 00103 00104 return true; 00105 } 00106 00108 virtual bool init_machines_for_apply(CFeatures* data) 00109 { 00110 if (data) 00111 set_features((CDotFeatures*)data); 00112 00113 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00114 { 00115 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i); 00116 ASSERT(m_features) 00117 ASSERT(machine) 00118 machine->set_features(m_features); 00119 SG_UNREF(machine); 00120 } 00121 00122 return true; 00123 } 00124 00126 virtual bool is_ready() 00127 { 00128 if (m_features) 00129 return true; 00130 00131 return false; 00132 } 00133 00135 virtual CMachine* get_machine_from_trained(CMachine* machine) 00136 { 00137 return new CLinearMachine((CLinearMachine*)machine); 00138 } 00139 00141 virtual int32_t get_num_rhs_vectors() 00142 { 00143 return m_features->get_num_vectors(); 00144 } 00145 00150 virtual void add_machine_subset(SGVector<index_t> subset) 00151 { 00152 /* changing the subset structure to use subset stacks. This might 00153 * have to be revised. Heiko Strathmann */ 00154 m_features->add_subset(subset); 00155 } 00156 00158 virtual void remove_machine_subset() 00159 { 00160 /* changing the subset structure to use subset stacks. This might 00161 * have to be revised. Heiko Strathmann */ 00162 m_features->remove_subset(); 00163 } 00164 00169 virtual void store_model_features() {} 00170 00171 protected: 00172 00174 CDotFeatures* m_features; 00175 }; 00176 } 00177 #endif