SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011-2012 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #ifndef __CROSSVALIDATION_H_ 00012 #define __CROSSVALIDATION_H_ 00013 00014 #include <shogun/evaluation/EvaluationResult.h> 00015 #include <shogun/evaluation/MachineEvaluation.h> 00016 00017 namespace shogun 00018 { 00019 00020 class CMachineEvaluation; 00021 class CCrossValidationOutput; 00022 class CList; 00023 00029 class CCrossValidationResult : public CEvaluationResult 00030 { 00031 public: 00032 CCrossValidationResult() 00033 { 00034 SG_ADD(&mean, "mean", "Mean of results", MS_NOT_AVAILABLE); 00035 SG_ADD(&has_conf_int, "has_conf_int", "Has confidence intervals?", 00036 MS_NOT_AVAILABLE); 00037 SG_ADD(&conf_int_low, "conf_int_low", "Lower confidence bound", 00038 MS_NOT_AVAILABLE); 00039 SG_ADD(&conf_int_up, "conf_int_up", "Upper confidence bound", 00040 MS_NOT_AVAILABLE); 00041 00042 SG_ADD(&conf_int_alpha, "conf_int_alpha", 00043 "Alpha of confidence interval", MS_NOT_AVAILABLE); 00044 00045 mean = 0; 00046 has_conf_int = 0; 00047 conf_int_low = 0; 00048 conf_int_up = 0; 00049 conf_int_alpha = 0; 00050 } 00051 00056 virtual EEvaluationResultType get_result_type() const 00057 { 00058 return CROSSVALIDATION_RESULT; 00059 } 00060 00066 virtual const char* get_name() const { return "CrossValidationResult"; } 00067 00072 static CCrossValidationResult* obtain_from_generic( 00073 CEvaluationResult* eval_result) 00074 { 00075 if (!eval_result) 00076 return NULL; 00077 00078 REQUIRE(eval_result->get_result_type()==CROSSVALIDATION_RESULT, 00079 "CrossValidationResult::obtain_from_generic(): argument is" 00080 "of wrong type!\n"); 00081 00082 SG_REF(eval_result); 00083 return (CCrossValidationResult*) eval_result; 00084 } 00085 00087 virtual void print_result() 00088 { 00089 if (has_conf_int) 00090 { 00091 SG_SPRINT("[%f,%f] with alpha=%f, mean=%f\n", conf_int_low, 00092 conf_int_up, conf_int_alpha, mean); 00093 } 00094 else 00095 SG_SPRINT("%f\n", mean) 00096 } 00097 00098 public: 00100 float64_t mean; 00102 bool has_conf_int; 00104 float64_t conf_int_low; 00106 float64_t conf_int_up; 00108 float64_t conf_int_alpha; 00109 00110 }; 00111 00137 class CCrossValidation: public CMachineEvaluation 00138 { 00139 public: 00141 CCrossValidation(); 00142 00151 CCrossValidation(CMachine* machine, CFeatures* features, CLabels* labels, 00152 CSplittingStrategy* splitting_strategy, 00153 CEvaluation* evaluation_criterion, bool autolock=true); 00154 00162 CCrossValidation(CMachine* machine, CLabels* labels, 00163 CSplittingStrategy* splitting_strategy, 00164 CEvaluation* evaluation_criterion, bool autolock=true); 00165 00167 virtual ~CCrossValidation(); 00168 00170 void set_num_runs(int32_t num_runs); 00171 00173 void set_conf_int_alpha(float64_t m_conf_int_alpha); 00174 00176 virtual CEvaluationResult* evaluate(); 00177 00183 void add_cross_validation_output( 00184 CCrossValidationOutput* cross_validation_output); 00185 00187 virtual const char* get_name() const 00188 { 00189 return "CrossValidation"; 00190 } 00191 00192 private: 00193 void init(); 00194 00195 protected: 00204 virtual float64_t evaluate_one_run(); 00205 00207 int32_t m_num_runs; 00209 float64_t m_conf_int_alpha; 00210 00212 CList* m_xval_outputs; 00213 }; 00214 00215 } 00216 00217 #endif /* __CROSSVALIDATION_H_ */