SHOGUN
v3.2.0
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Viktor Gal
 * Copyright (C) 2013 Viktor Gal
 */

#include <shogun/machine/BaggingMachine.h>
#include <shogun/base/Parameter.h>

using namespace shogun;

CBaggingMachine::CBaggingMachine()
    : CMachine()
{
    init();
    register_parameters();
}

CBaggingMachine::CBaggingMachine(CFeatures* features, CLabels* labels)
    : CMachine()
{
    init();
    register_parameters();

    set_labels(labels);

    SG_REF(features);
    m_features = features;
}

CBaggingMachine::~CBaggingMachine()
{
    SG_UNREF(m_machine);
    SG_UNREF(m_features);
    SG_UNREF(m_combination_rule);
    SG_UNREF(m_bags);
    SG_UNREF(m_oob_indices);
}

CBinaryLabels* CBaggingMachine::apply_binary(CFeatures* data)
{
    SGVector<float64_t> combined_vector = apply_get_outputs(data);

    CBinaryLabels* pred = new CBinaryLabels(combined_vector);
    return pred;
}

CMulticlassLabels* CBaggingMachine::apply_multiclass(CFeatures* data)
{
    SGVector<float64_t> combined_vector = apply_get_outputs(data);

    CMulticlassLabels* pred = new CMulticlassLabels(combined_vector);
    return pred;
}

CRegressionLabels* CBaggingMachine::apply_regression(CFeatures* data)
{
    SGVector<float64_t> combined_vector = apply_get_outputs(data);

    CRegressionLabels* pred = new CRegressionLabels(combined_vector);

    return pred;
}

SGVector<float64_t> CBaggingMachine::apply_get_outputs(CFeatures* data)
{
    ASSERT(data != NULL);
    REQUIRE(m_combination_rule != NULL, "Combination rule is not set!");
    ASSERT(m_num_bags == m_bags->get_num_elements());

    SGMatrix<float64_t> output(data->get_num_vectors(), m_num_bags);
    output.zero();

    #pragma omp parallel for num_threads(parallel->get_num_threads())
    for (int32_t i = 0; i < m_num_bags; ++i)
    {
        CMachine* m = dynamic_cast<CMachine*>(m_bags->get_element(i));
        CLabels* l = m->apply(data);
        SGVector<float64_t> lv = l->get_values();
        float64_t* bag_results = output.get_column_vector(i);
        memcpy(bag_results, lv.vector, lv.vlen*sizeof(float64_t));

        SG_UNREF(l);
        SG_UNREF(m);
    }

    SGVector<float64_t> combined = m_combination_rule->combine(output);

    return combined;
}

bool CBaggingMachine::train_machine(CFeatures* data)
{
    REQUIRE(m_machine != NULL, "Machine is not set!");
    REQUIRE(m_bag_size > 0, "Bag size is not set!");
    REQUIRE(m_num_bags > 0, "Number of bags is not set!");

    if (data)
    {
        SG_REF(data);
        SG_UNREF(m_features);
        m_features = data;

        ASSERT(m_features->get_num_vectors() == m_labels->get_num_labels());
    }

    // bag size must be smaller than the number of feature vectors
    ASSERT(m_bag_size < m_features->get_num_vectors());

    // clear the array, if previously trained
    m_bags->reset_array();

    // reset the oob index vector
    m_all_oob_idx = SGVector<bool>(m_features->get_num_vectors());
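    // m_all_oob_idx[i] is flipped to true by get_oob_indices() whenever vector i
    // is left out of at least one bag; get_oob_error() later evaluates only on
    // those vectors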
    m_all_oob_idx.zero();

    SG_UNREF(m_oob_indices);
    m_oob_indices = new CDynamicObjectArray();

    /*
     TODO: enable multi-threaded learning. This requires views support
     on CFeatures
    #pragma omp parallel for num_threads(parallel->get_num_threads())
    */
    for (int32_t i = 0; i < m_num_bags; ++i)
    {
        CMachine* c = dynamic_cast<CMachine*>(m_machine->clone());
        ASSERT(c != NULL);
        SGVector<index_t> idx(m_bag_size);
        idx.random(0, m_features->get_num_vectors()-1);
        m_labels->add_subset(idx);
        /* TODO:
           if it's a binary labeling ensure that
           there's always samples of both classes
        if ((m_labels->get_label_type() == LT_BINARY))
        {
            while (true) {
                if (!m_labels->ensure_valid()) {
                    m_labels->remove_subset();
                    idx.random(0, m_features->get_num_vectors());
                    m_labels->add_subset(idx);
                    continue;
                }
                break;
            }
        }
        */
        m_features->add_subset(idx);
        c->set_labels(m_labels);
        c->train(m_features);
        m_features->remove_subset();
        m_labels->remove_subset();

        // get out of bag indexes
        CDynamicArray<index_t>* oob = get_oob_indices(idx);
        m_oob_indices->push_back(oob);

        // add trained machine to bag array
        m_bags->append_element(c);
    }

    return true;
}

void CBaggingMachine::register_parameters()
{
    SG_ADD((CSGObject**)&m_features, "features", "Train features for bagging",
        MS_NOT_AVAILABLE);
    SG_ADD(&m_num_bags, "num_bags", "Number of bags", MS_AVAILABLE);
    SG_ADD(&m_bag_size, "bag_size", "Number of vectors per bag", MS_AVAILABLE);
    SG_ADD((CSGObject**)&m_bags, "bags", "Bags array", MS_NOT_AVAILABLE);
    SG_ADD((CSGObject**)&m_combination_rule, "combination_rule",
        "Combination rule to use for aggregating", MS_AVAILABLE);
    SG_ADD(&m_all_oob_idx, "all_oob_idx", "Indices of all oob vectors",
        MS_NOT_AVAILABLE);
    SG_ADD((CSGObject**)&m_oob_indices, "oob_indices",
        "OOB indices for each machine", MS_NOT_AVAILABLE);
}

void CBaggingMachine::set_num_bags(int32_t num_bags)
{
    m_num_bags = num_bags;
}

int32_t CBaggingMachine::get_num_bags() const
{
    return m_num_bags;
}

void CBaggingMachine::set_bag_size(int32_t bag_size)
{
    m_bag_size = bag_size;
}

int32_t CBaggingMachine::get_bag_size() const
{
    return m_bag_size;
}

CMachine* CBaggingMachine::get_machine() const
{
    SG_REF(m_machine);
    return m_machine;
}

void CBaggingMachine::set_machine(CMachine* machine)
{
    SG_REF(machine);
    SG_UNREF(m_machine);
    m_machine = machine;
}

void CBaggingMachine::init()
{
    m_bags = new CDynamicObjectArray();
    m_machine = NULL;
    m_features = NULL;
    m_combination_rule = NULL;
    m_labels = NULL;
    m_num_bags = 0;
    m_bag_size = 0;
    m_all_oob_idx = SGVector<bool>();
    m_oob_indices = NULL;
}

void CBaggingMachine::set_combination_rule(CCombinationRule* rule)
{
    SG_REF(rule);
    SG_UNREF(m_combination_rule);
    m_combination_rule = rule;
}

CCombinationRule* CBaggingMachine::get_combination_rule() const
{
    SG_REF(m_combination_rule);
    return m_combination_rule;
}
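/*
 * Out-of-bag (OOB) error: each trained bag is applied only to the vectors it
 * did not see during training (its OOB set, recorded in m_oob_indices). The
 * per-bag outputs are collected into a num_vectors x num_bags matrix (entries
 * of in-bag vectors keep their initial value: NAN, or 0 for regression),
 * combined with the combination rule, and the combined predictions are
 * evaluated against the true labels restricted to vectors that were OOB in
 * at least one bag.
 */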
float64_t CBaggingMachine::get_oob_error(CEvaluation* eval) const
{
    REQUIRE(m_combination_rule != NULL, "Combination rule is not set!");
    REQUIRE(m_bags->get_num_elements() > 0, "BaggingMachine is not trained!");

    SGMatrix<float64_t> output(m_features->get_num_vectors(), m_bags->get_num_elements());
    if (m_labels->get_label_type() == LT_REGRESSION)
        output.zero();
    else
        output.set_const(NAN);

    /* TODO: add parallel support of applying the OOBs
       only possible when add_subset is thread-safe
    #pragma omp parallel for num_threads(parallel->get_num_threads())
    */
    for (index_t i = 0; i < m_bags->get_num_elements(); i++)
    {
        CMachine* m = dynamic_cast<CMachine*>(m_bags->get_element(i));
        CDynamicArray<index_t>* current_oob
            = dynamic_cast<CDynamicArray<index_t>*>(m_oob_indices->get_element(i));

        SGVector<index_t> oob(current_oob->get_array(), current_oob->get_num_elements(), false);
        oob.display_vector();
        m_features->add_subset(oob);

        CLabels* l = m->apply(m_features);
        SGVector<float64_t> lv = l->get_values();

        // write this bag's predictions for its out-of-bag vectors;
        // in-bag entries keep their initial value (NAN, or 0 for regression)
        for (index_t j = 0; j < oob.vlen; j++)
            output(oob[j], i) = lv[j];

        m_features->remove_subset();
        SG_UNREF(current_oob);
        SG_UNREF(m);
        SG_UNREF(l);
    }
    output.display_matrix();

    DynArray<index_t> idx;
    for (index_t i = 0; i < m_features->get_num_vectors(); i++)
    {
        if (m_all_oob_idx[i])
            idx.push_back(i);
    }

    SGVector<float64_t> combined = m_combination_rule->combine(output);
    CLabels* predicted = NULL;
    switch (m_labels->get_label_type())
    {
        case LT_BINARY:
            predicted = new CBinaryLabels(combined);
            break;

        case LT_MULTICLASS:
            predicted = new CMulticlassLabels(combined);
            break;

        case LT_REGRESSION:
            predicted = new CRegressionLabels(combined);
            break;

        default:
            SG_ERROR("Unsupported label type\n");
    }

    m_labels->add_subset(SGVector<index_t>(idx.get_array(), idx.get_num_elements(), false));
    float64_t res = eval->evaluate(predicted, m_labels);
    m_labels->remove_subset();

    return res;
}

CDynamicArray<index_t>* CBaggingMachine::get_oob_indices(const SGVector<index_t>& in_bag)
{
    SGVector<bool> out_of_bag(m_features->get_num_vectors());
    out_of_bag.set_const(true);

    // mark the ones that are in_bag
    index_t oob_count = m_features->get_num_vectors();
    for (index_t i = 0; i < in_bag.vlen; i++)
    {
        if (out_of_bag[in_bag[i]])
        {
            out_of_bag[in_bag[i]] = false;
            oob_count--;
        }
    }

    CDynamicArray<index_t>* oob = new CDynamicArray<index_t>();
    // store the indices of vectors that are out of the bag
    for (index_t i = 0; i < out_of_bag.vlen; i++)
    {
        if (out_of_bag[i])
        {
            oob->push_back(i);
            m_all_oob_idx[i] = true;
        }
    }

    return oob;
}
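/*
 * Usage sketch (illustrative only, not part of the class implementation
 * above): how a CBaggingMachine is typically wired together and trained.
 * The concrete choices (CLibLinear as the bagged learner, CMajorityVote as
 * the combination rule, CAccuracyMeasure for the OOB error) are assumptions
 * for illustration; any CMachine, CCombinationRule and CEvaluation shipped
 * with this Shogun version can be substituted.
 *
 *   CDenseFeatures<float64_t>* feats = new CDenseFeatures<float64_t>(mat);
 *   CBinaryLabels* labels = new CBinaryLabels(lab);
 *
 *   CBaggingMachine* bm = new CBaggingMachine(feats, labels);
 *   bm->set_machine(new CLibLinear());
 *   bm->set_combination_rule(new CMajorityVote());
 *   bm->set_num_bags(10);
 *   bm->set_bag_size(feats->get_num_vectors() / 2);
 *   bm->train();
 *
 *   CBinaryLabels* pred = bm->apply_binary(feats);
 *   CAccuracyMeasure* acc = new CAccuracyMeasure();
 *   float64_t oob_err = bm->get_oob_error(acc);
 *
 *   SG_UNREF(acc);
 *   SG_UNREF(pred);
 *   SG_UNREF(bm);
 */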