// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2011 Soeren Sonnenburg 00008 * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn 00009 * Written (W) 2013 Shell Hu and Heiko Strathmann 00010 * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia 00011 */ 00012 00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00014 #include <shogun/machine/LinearMachine.h> 00015 #include <shogun/machine/KernelMachine.h> 00016 #include <shogun/machine/MulticlassMachine.h> 00017 #include <shogun/base/Parameter.h> 00018 #include <shogun/labels/MulticlassLabels.h> 00019 #include <shogun/labels/RegressionLabels.h> 00020 #include <shogun/mathematics/Statistics.h> 00021 00022 using namespace shogun; 00023 00024 CMulticlassMachine::CMulticlassMachine() 00025 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()), 00026 m_machine(NULL) 00027 { 00028 SG_REF(m_multiclass_strategy); 00029 register_parameters(); 00030 } 00031 00032 CMulticlassMachine::CMulticlassMachine( 00033 CMulticlassStrategy *strategy, 00034 CMachine* machine, CLabels* labs) 00035 : CBaseMulticlassMachine(), m_multiclass_strategy(strategy) 00036 { 00037 SG_REF(strategy); 00038 set_labels(labs); 00039 SG_REF(machine); 00040 m_machine = machine; 00041 register_parameters(); 00042 00043 if (labs) 00044 init_strategy(); 00045 } 00046 00047 CMulticlassMachine::~CMulticlassMachine() 00048 { 00049 SG_UNREF(m_multiclass_strategy); 00050 SG_UNREF(m_machine); 00051 } 00052 00053 void CMulticlassMachine::set_labels(CLabels* lab) 00054 { 00055 CMachine::set_labels(lab); 00056 if (lab) 00057 init_strategy(); 00058 } 00059 00060 void CMulticlassMachine::register_parameters() 00061 { 00062 
SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE); 00063 SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE); 00064 } 00065 00066 void CMulticlassMachine::init_strategy() 00067 { 00068 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00069 m_multiclass_strategy->set_num_classes(num_classes); 00070 } 00071 00072 CBinaryLabels* CMulticlassMachine::get_submachine_outputs(int32_t i) 00073 { 00074 CMachine *machine = (CMachine*)m_machines->get_element(i); 00075 ASSERT(machine) 00076 CBinaryLabels* output = machine->apply_binary(); 00077 SG_UNREF(machine); 00078 return output; 00079 } 00080 00081 float64_t CMulticlassMachine::get_submachine_output(int32_t i, int32_t num) 00082 { 00083 CMachine *machine = get_machine(i); 00084 float64_t output = 0.0; 00085 // dirty hack 00086 if (dynamic_cast<CLinearMachine*>(machine)) 00087 output = ((CLinearMachine*)machine)->apply_one(num); 00088 if (dynamic_cast<CKernelMachine*>(machine)) 00089 output = ((CKernelMachine*)machine)->apply_one(num); 00090 SG_UNREF(machine); 00091 return output; 00092 } 00093 00094 CMulticlassLabels* CMulticlassMachine::apply_multiclass(CFeatures* data) 00095 { 00096 SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n", 00097 get_name(), data ? data->get_name() : "NULL", data); 00098 00099 CMulticlassLabels* return_labels=NULL; 00100 00101 if (data) 00102 init_machines_for_apply(data); 00103 else 00104 init_machines_for_apply(NULL); 00105 00106 if (is_ready()) 00107 { 00108 /* num vectors depends on whether data is provided */ 00109 int32_t num_vectors=data ? 
data->get_num_vectors() : 00110 get_num_rhs_vectors(); 00111 00112 int32_t num_machines=m_machines->get_num_elements(); 00113 if (num_machines <= 0) 00114 SG_ERROR("num_machines = %d, did you train your machine?", num_machines) 00115 00116 CMulticlassLabels* result=new CMulticlassLabels(num_vectors); 00117 00118 // if outputs are prob, only one confidence for each class 00119 int32_t num_classes=m_multiclass_strategy->get_num_classes(); 00120 EProbHeuristicType heuris = get_prob_heuris(); 00121 00122 if (heuris!=PROB_HEURIS_NONE) 00123 result->allocate_confidences_for(num_classes); 00124 else 00125 result->allocate_confidences_for(num_machines); 00126 00127 CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines); 00128 SGVector<float64_t> As(num_machines); 00129 SGVector<float64_t> Bs(num_machines); 00130 00131 for (int32_t i=0; i<num_machines; ++i) 00132 { 00133 outputs[i] = (CBinaryLabels*) get_submachine_outputs(i); 00134 00135 if (heuris==OVA_SOFTMAX) 00136 { 00137 CStatistics::SigmoidParamters params = CStatistics::fit_sigmoid(outputs[i]->get_values()); 00138 As[i] = params.a; 00139 Bs[i] = params.b; 00140 } 00141 00142 if (heuris!=PROB_HEURIS_NONE && heuris!=OVA_SOFTMAX) 00143 outputs[i]->scores_to_probabilities(0,0); 00144 } 00145 00146 SGVector<float64_t> output_for_i(num_machines); 00147 SGVector<float64_t> r_output_for_i(num_machines); 00148 if (heuris!=PROB_HEURIS_NONE) 00149 r_output_for_i.resize_vector(num_classes); 00150 00151 for (int32_t i=0; i<num_vectors; i++) 00152 { 00153 for (int32_t j=0; j<num_machines; j++) 00154 output_for_i[j] = outputs[j]->get_value(i); 00155 00156 if (heuris==PROB_HEURIS_NONE) 00157 { 00158 r_output_for_i = output_for_i; 00159 } 00160 else 00161 { 00162 if (heuris==OVA_SOFTMAX) 00163 m_multiclass_strategy->rescale_outputs(output_for_i,As,Bs); 00164 else 00165 m_multiclass_strategy->rescale_outputs(output_for_i); 00166 00167 // only first num_classes are returned 00168 for (int32_t r=0; r<num_classes; r++) 00169 
r_output_for_i[r] = output_for_i[r]; 00170 00171 SG_DEBUG("%s::apply_multiclass(): sum(r_output_for_i) = %f\n", 00172 get_name(), SGVector<float64_t>::sum(r_output_for_i.vector,num_classes)); 00173 } 00174 00175 // use rescaled outputs for label decision 00176 result->set_label(i, m_multiclass_strategy->decide_label(r_output_for_i)); 00177 result->set_multiclass_confidences(i, r_output_for_i); 00178 } 00179 00180 for (int32_t i=0; i < num_machines; ++i) 00181 SG_UNREF(outputs[i]); 00182 00183 SG_FREE(outputs); 00184 00185 return_labels=result; 00186 } 00187 else 00188 SG_ERROR("Not ready") 00189 00190 00191 SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n", 00192 get_name(), data ? data->get_name() : "NULL", data); 00193 return return_labels; 00194 } 00195 00196 CMulticlassMultipleOutputLabels* CMulticlassMachine::apply_multiclass_multiple_output(CFeatures* data, int32_t n_outputs) 00197 { 00198 CMulticlassMultipleOutputLabels* return_labels=NULL; 00199 00200 if (data) 00201 init_machines_for_apply(data); 00202 else 00203 init_machines_for_apply(NULL); 00204 00205 if (is_ready()) 00206 { 00207 /* num vectors depends on whether data is provided */ 00208 int32_t num_vectors=data ? 
data->get_num_vectors() : 00209 get_num_rhs_vectors(); 00210 00211 int32_t num_machines=m_machines->get_num_elements(); 00212 if (num_machines <= 0) 00213 SG_ERROR("num_machines = %d, did you train your machine?", num_machines) 00214 REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available") 00215 00216 CMulticlassMultipleOutputLabels* result=new CMulticlassMultipleOutputLabels(num_vectors); 00217 CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines); 00218 00219 for (int32_t i=0; i < num_machines; ++i) 00220 outputs[i] = (CBinaryLabels*) get_submachine_outputs(i); 00221 00222 SGVector<float64_t> output_for_i(num_machines); 00223 for (int32_t i=0; i<num_vectors; i++) 00224 { 00225 for (int32_t j=0; j<num_machines; j++) 00226 output_for_i[j] = outputs[j]->get_value(i); 00227 00228 result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs)); 00229 } 00230 00231 for (int32_t i=0; i < num_machines; ++i) 00232 SG_UNREF(outputs[i]); 00233 00234 SG_FREE(outputs); 00235 00236 return_labels=result; 00237 } 00238 else 00239 SG_ERROR("Not ready") 00240 00241 return return_labels; 00242 } 00243 00244 bool CMulticlassMachine::train_machine(CFeatures* data) 00245 { 00246 ASSERT(m_multiclass_strategy) 00247 00248 if ( !data && !is_ready() ) 00249 SG_ERROR("Please provide training data.\n") 00250 else 00251 init_machine_for_train(data); 00252 00253 m_machines->reset_array(); 00254 CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors()); 00255 SG_REF(train_labels); 00256 m_machine->set_labels(train_labels); 00257 00258 m_multiclass_strategy->train_start(CLabelsFactory::to_multiclass(m_labels), train_labels); 00259 while (m_multiclass_strategy->train_has_more()) 00260 { 00261 SGVector<index_t> subset=m_multiclass_strategy->train_prepare_next(); 00262 if (subset.vlen) 00263 { 00264 train_labels->add_subset(subset); 00265 add_machine_subset(subset); 00266 } 00267 00268 m_machine->train(); 00269 
m_machines->push_back(get_machine_from_trained(m_machine)); 00270 00271 if (subset.vlen) 00272 { 00273 train_labels->remove_subset(); 00274 remove_machine_subset(); 00275 } 00276 } 00277 00278 m_multiclass_strategy->train_stop(); 00279 SG_UNREF(train_labels); 00280 00281 return true; 00282 } 00283 00284 float64_t CMulticlassMachine::apply_one(int32_t vec_idx) 00285 { 00286 init_machines_for_apply(NULL); 00287 00288 ASSERT(m_machines->get_num_elements()>0) 00289 SGVector<float64_t> outputs(m_machines->get_num_elements()); 00290 00291 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00292 outputs[i] = get_submachine_output(i, vec_idx); 00293 00294 float64_t result = m_multiclass_strategy->decide_label(outputs); 00295 00296 return result; 00297 }