SHOGUN v3.2.0
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2011-2012 Heiko Strathmann
 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
 */

#include <shogun/evaluation/CrossValidation.h>
#include <shogun/machine/Machine.h>
#include <shogun/evaluation/Evaluation.h>
#include <shogun/evaluation/SplittingStrategy.h>
#include <shogun/base/Parameter.h>
#include <shogun/base/ParameterMap.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/evaluation/CrossValidationOutput.h>
#include <shogun/lib/List.h>

using namespace shogun;

CCrossValidation::CCrossValidation() : CMachineEvaluation()
{
	init();
}

CCrossValidation::CCrossValidation(CMachine* machine, CFeatures* features,
		CLabels* labels, CSplittingStrategy* splitting_strategy,
		CEvaluation* evaluation_criterion, bool autolock) :
		CMachineEvaluation(machine, features, labels, splitting_strategy,
				evaluation_criterion, autolock)
{
	init();
}

CCrossValidation::CCrossValidation(CMachine* machine, CLabels* labels,
		CSplittingStrategy* splitting_strategy,
		CEvaluation* evaluation_criterion, bool autolock) :
		CMachineEvaluation(machine, labels, splitting_strategy,
				evaluation_criterion, autolock)
{
	init();
}

CCrossValidation::~CCrossValidation()
{
	SG_UNREF(m_xval_outputs);
}

void CCrossValidation::init()
{
	m_num_runs=1;
	m_conf_int_alpha=0;

	/* do reference counting for output objects */
	m_xval_outputs=new CList(true);

	SG_ADD(&m_num_runs, "num_runs", "Number of repetitions",
			MS_NOT_AVAILABLE);
	SG_ADD(&m_conf_int_alpha, "conf_int_alpha", "alpha-value "
			"of confidence interval", MS_NOT_AVAILABLE);
	SG_ADD((CSGObject**)&m_xval_outputs, "m_xval_outputs", "List of output "
			"classes for intermediate cross-validation results",
			MS_NOT_AVAILABLE);
}

CEvaluationResult* CCrossValidation::evaluate()
{
	SG_DEBUG("entering %s::evaluate()\n", get_name())

	REQUIRE(m_machine, "%s::evaluate() is only possible if a machine is "
			"attached\n", get_name());

	REQUIRE(m_features, "%s::evaluate() is only possible if features are "
			"attached\n", get_name());

	REQUIRE(m_labels, "%s::evaluate() is only possible if labels are "
			"attached\n", get_name());

	/* if for some reason the do_unlock flag is set, unlock */
	if (m_do_unlock)
	{
		m_machine->data_unlock();
		m_do_unlock=false;
	}

	/* set labels in any case (no locking needs this) */
	m_machine->set_labels(m_labels);

	if (m_autolock)
	{
		/* if machine supports locking try to do so */
		if (m_machine->supports_locking())
		{
			/* only lock if machine is not yet locked */
			if (!m_machine->is_data_locked())
			{
				m_machine->data_lock(m_labels, m_features);
				m_do_unlock=true;
			}
		}
		else
		{
			SG_WARNING("%s does not support locking. Autolocking is skipped. "
					"Set autolock flag to false to get rid of warning.\n",
					m_machine->get_name());
		}
	}

	SGVector<float64_t> results(m_num_runs);

	/* possibly update cross-validation output classes */
	CCrossValidationOutput* current=(CCrossValidationOutput*)
			m_xval_outputs->get_first_element();
	while (current)
	{
		current->init_num_runs(m_num_runs);
		current->init_num_folds(m_splitting_strategy->get_num_subsets());
		current->init_expose_labels(m_labels);
		current->post_init();
		SG_UNREF(current);
		current=(CCrossValidationOutput*)
				m_xval_outputs->get_next_element();
	}

	/* perform all the x-val runs */
	SG_DEBUG("starting %d runs of cross-validation\n", m_num_runs)
	for (index_t i=0; i<m_num_runs; ++i)
	{
		/* possibly update cross-validation output classes */
		current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
		while (current)
		{
			current->update_run_index(i);
			SG_UNREF(current);
			current=(CCrossValidationOutput*)
					m_xval_outputs->get_next_element();
		}

		SG_DEBUG("entering cross-validation run %d\n", i)
		results[i]=evaluate_one_run();
		SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i])
	}

	/* construct evaluation result */
	CCrossValidationResult* result=new CCrossValidationResult();
	result->has_conf_int=m_conf_int_alpha != 0;
	result->conf_int_alpha=m_conf_int_alpha;

	if (result->has_conf_int)
	{
		result->conf_int_alpha=m_conf_int_alpha;
		result->mean=CStatistics::confidence_intervals_mean(results,
				result->conf_int_alpha, result->conf_int_low,
				result->conf_int_up);
	}
	else
	{
		result->mean=CStatistics::mean(results);
		result->conf_int_low=0;
		result->conf_int_up=0;
	}

	/* unlock machine if it was locked in this method */
	if (m_machine->is_data_locked() && m_do_unlock)
	{
		m_machine->data_unlock();
		m_do_unlock=false;
	}

	SG_DEBUG("leaving %s::evaluate()\n", get_name())

	SG_REF(result);
	return result;
}

void CCrossValidation::set_conf_int_alpha(float64_t conf_int_alpha)
{
	if (conf_int_alpha<0 || conf_int_alpha>=1)
	{
		SG_ERROR("%f is an illegal alpha-value for confidence interval of "
				"cross-validation\n", conf_int_alpha);
	}

	if (m_num_runs==1)
	{
		SG_WARNING("Confidence interval for cross-validation is only possible"
				" when number of runs is >1, ignoring.\n");
	}
	else
		m_conf_int_alpha=conf_int_alpha;
}

void CCrossValidation::set_num_runs(int32_t num_runs)
{
	if (num_runs<1)
		SG_ERROR("%d is an illegal number of repetitions\n", num_runs)

	m_num_runs=num_runs;
}

float64_t CCrossValidation::evaluate_one_run()
{
	SG_DEBUG("entering %s::evaluate_one_run()\n", get_name())
	index_t num_subsets=m_splitting_strategy->get_num_subsets();

	SG_DEBUG("building index sets for %d-fold cross-validation\n", num_subsets)

	/* build index sets */
	m_splitting_strategy->build_subsets();

	/* results array */
	SGVector<float64_t> results(num_subsets);

	/* different behavior whether data is locked or not */
	if (m_machine->is_data_locked())
	{
		SG_DEBUG("starting locked evaluation\n")

		/* do actual cross-validation */
		for (index_t i=0; i<num_subsets; ++i)
		{
			/* possibly update cross-validation output classes */
			CCrossValidationOutput* current=(CCrossValidationOutput*)
					m_xval_outputs->get_first_element();
			while (current)
			{
				current->update_fold_index(i);
				SG_UNREF(current);
				current=(CCrossValidationOutput*)
						m_xval_outputs->get_next_element();
			}

			/* index subset for training, will be freed below */
			SGVector<index_t> inverse_subset_indices=
					m_splitting_strategy->generate_subset_inverse(i);

			/* train machine on training features */
			m_machine->train_locked(inverse_subset_indices);

			/* feature subset for testing */
			SGVector<index_t> subset_indices=
					m_splitting_strategy->generate_subset_indices(i);

			/* possibly update cross-validation output classes */
			current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
			while (current)
			{
				current->update_train_indices(inverse_subset_indices, "\t");
				current->update_trained_machine(m_machine, "\t");
				SG_UNREF(current);
				current=(CCrossValidationOutput*)
						m_xval_outputs->get_next_element();
			}

			/* produce output for desired indices */
			CLabels* result_labels=m_machine->apply_locked(subset_indices);
			SG_REF(result_labels);

			/* set subset for testing labels */
			m_labels->add_subset(subset_indices);

			/* evaluate against own labels */
			m_evaluation_criterion->set_indices(subset_indices);
			results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);

			/* possibly update cross-validation output classes */
			current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
			while (current)
			{
				current->update_test_indices(subset_indices, "\t");
				current->update_test_result(result_labels, "\t");
				current->update_test_true_result(m_labels, "\t");
				current->post_update_results();
				current->update_evaluation_result(results[i], "\t");
				SG_UNREF(current);
				current=(CCrossValidationOutput*)
						m_xval_outputs->get_next_element();
			}

			/* remove subset to prevent side effects */
			m_labels->remove_subset();

			/* clean up */
			SG_UNREF(result_labels);
		}

		SG_DEBUG("done locked evaluation\n")
	}
	else
	{
		SG_DEBUG("starting unlocked evaluation\n")

		/* tell machine to store model internally (otherwise changing the
		 * feature subset would invalidate the trained classifier) */
		m_machine->set_store_model_features(true);

		/* do actual cross-validation */
		for (index_t i=0; i<num_subsets; ++i)
		{
			/* possibly update cross-validation output classes */
			CCrossValidationOutput* current=(CCrossValidationOutput*)
					m_xval_outputs->get_first_element();
			while (current)
			{
				current->update_fold_index(i);
				SG_UNREF(current);
				current=(CCrossValidationOutput*)
						m_xval_outputs->get_next_element();
			}

			/* set feature subset for training */
			SGVector<index_t> inverse_subset_indices=
					m_splitting_strategy->generate_subset_inverse(i);
			m_features->add_subset(inverse_subset_indices);
			for (index_t p=0; p<m_features->get_num_preprocessors(); p++)
			{
				CPreprocessor* preprocessor=m_features->get_preprocessor(p);
				preprocessor->init(m_features);
				SG_UNREF(preprocessor);
			}

			/* set label subset for training */
			m_labels->add_subset(inverse_subset_indices);

			SG_DEBUG("training set %d:\n", i)
			if (io->get_loglevel()==MSG_DEBUG)
			{
				SGVector<index_t>::display_vector(inverse_subset_indices.vector,
						inverse_subset_indices.vlen, "training indices");
			}

			/* train machine on training features and remove subset */
			SG_DEBUG("starting training\n")
			m_machine->train(m_features);
			SG_DEBUG("finished training\n")

			/* possibly update cross-validation output classes */
			current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
			while (current)
			{
				current->update_train_indices(inverse_subset_indices, "\t");
				current->update_trained_machine(m_machine, "\t");
				SG_UNREF(current);
				current=(CCrossValidationOutput*)
						m_xval_outputs->get_next_element();
			}

			m_features->remove_subset();
			m_labels->remove_subset();

			/* set feature subset for testing (subset method that stores pointer) */
			SGVector<index_t> subset_indices=
					m_splitting_strategy->generate_subset_indices(i);
			m_features->add_subset(subset_indices);

			/* set label subset for testing */
			m_labels->add_subset(subset_indices);

			SG_DEBUG("test set %d:\n", i)
			if (io->get_loglevel()==MSG_DEBUG)
			{
				SGVector<index_t>::display_vector(subset_indices.vector,
						subset_indices.vlen, "test indices");
			}

			/* apply machine to test features and remove subset */
			SG_DEBUG("starting evaluation\n")
			SG_DEBUG("%p\n", m_features)
			CLabels* result_labels=m_machine->apply(m_features);
			SG_DEBUG("finished evaluation\n")
			m_features->remove_subset();
			SG_REF(result_labels);

			/* evaluate */
			results[i]=m_evaluation_criterion->evaluate(result_labels, m_labels);
			SG_DEBUG("result on fold %d is %f\n", i, results[i])

			/* possibly update cross-validation output classes */
			current=(CCrossValidationOutput*)m_xval_outputs->get_first_element();
			while (current)
			{
				current->update_test_indices(subset_indices, "\t");
				current->update_test_result(result_labels, "\t");
				current->update_test_true_result(m_labels, "\t");
				current->post_update_results();
				current->update_evaluation_result(results[i], "\t");
				SG_UNREF(current);
				current=(CCrossValidationOutput*)
						m_xval_outputs->get_next_element();
			}

			/* clean up, remove subsets */
			SG_UNREF(result_labels);
			m_labels->remove_subset();
		}

		SG_DEBUG("done unlocked evaluation\n")
	}

	/* build arithmetic mean of results */
	float64_t mean=CStatistics::mean(results);

	SG_DEBUG("leaving %s::evaluate_one_run()\n", get_name())
	return mean;
}

void CCrossValidation::add_cross_validation_output(
		CCrossValidationOutput* cross_validation_output)
{
	m_xval_outputs->append_element(cross_validation_output);
}
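
/* Usage sketch (illustrative only, not part of the original file): shows how
 * the class above is typically wired together. CCrossValidation,
 * set_num_runs(), set_conf_int_alpha() and CCrossValidationResult come from
 * this translation unit; the splitting strategy, evaluation criterion and
 * label type used below (CStratifiedCrossValidationSplitting, CAccuracyMeasure,
 * CBinaryLabels) and their headers are assumptions about the wider Shogun 3.x
 * API and may need adapting to the concrete machine and feature types in use. */
#include <shogun/evaluation/ContingencyTableEvaluation.h>
#include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
#include <shogun/labels/BinaryLabels.h>

void cross_validation_usage_sketch(CMachine* machine, CFeatures* features,
		CBinaryLabels* labels)
{
	/* 5-fold stratified splitting over the given labels (assumed class) */
	CStratifiedCrossValidationSplitting* splitting=
			new CStratifiedCrossValidationSplitting(labels, 5);

	/* per-fold evaluation criterion: classification accuracy (assumed class) */
	CAccuracyMeasure* criterion=new CAccuracyMeasure();

	/* no autolocking in this sketch; see the autolock handling in evaluate() */
	CCrossValidation* cross=new CCrossValidation(machine, features, labels,
			splitting, criterion, false);

	/* set the number of runs before the alpha value, since set_conf_int_alpha()
	 * above ignores the alpha while m_num_runs is still 1 */
	cross->set_num_runs(10);
	cross->set_conf_int_alpha(0.05);

	/* run cross-validation; evaluate() returns a SG_REF'ed result object */
	CCrossValidationResult* result=
			(CCrossValidationResult*)cross->evaluate();
	SG_SPRINT("mean: %f, confidence interval: [%f, %f]\n", result->mean,
			result->conf_int_low, result->conf_int_up);

	SG_UNREF(result);
	SG_UNREF(cross);
}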