SHOGUN v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
CrossValidation.h
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011-2012 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #ifndef __CROSSVALIDATION_H_
00012 #define __CROSSVALIDATION_H_
00013 
00014 #include <shogun/evaluation/EvaluationResult.h>
00015 #include <shogun/evaluation/MachineEvaluation.h>
00016 
00017 namespace shogun
00018 {
00019 
00020 class CMachineEvaluation;
00021 class CCrossValidationOutput;
00022 class CList;
00023 
00029 class CCrossValidationResult : public CEvaluationResult
00030 {
00031     public:
00032         CCrossValidationResult()
00033         {
00034             SG_ADD(&mean, "mean", "Mean of results", MS_NOT_AVAILABLE);
00035             SG_ADD(&has_conf_int, "has_conf_int", "Has confidence intervals?",
00036                     MS_NOT_AVAILABLE);
00037             SG_ADD(&conf_int_low, "conf_int_low", "Lower confidence bound",
00038                     MS_NOT_AVAILABLE);
00039             SG_ADD(&conf_int_up, "conf_int_up", "Upper confidence bound",
00040                         MS_NOT_AVAILABLE);
00041 
00042             SG_ADD(&conf_int_alpha, "conf_int_alpha",
00043                     "Alpha of confidence interval", MS_NOT_AVAILABLE);
00044 
00045             mean = 0;
00046             has_conf_int = 0;
00047             conf_int_low = 0;
00048             conf_int_up = 0;
00049             conf_int_alpha = 0;
00050         }
00051 
00056         virtual EEvaluationResultType get_result_type() const
00057         {
00058             return CROSSVALIDATION_RESULT;
00059         }
00060 
00066         virtual const char* get_name() const { return "CrossValidationResult"; }
00067 
00072         static CCrossValidationResult* obtain_from_generic(
00073                 CEvaluationResult* eval_result)
00074         {
00075             if (!eval_result)
00076                 return NULL;
00077 
00078             REQUIRE(eval_result->get_result_type()==CROSSVALIDATION_RESULT,
00079                     "CrossValidationResult::obtain_from_generic(): argument is"
00080                     "of wrong type!\n");
00081 
00082             SG_REF(eval_result);
00083             return (CCrossValidationResult*) eval_result;
00084         }
00085 
00087         virtual void print_result()
00088         {
00089             if (has_conf_int)
00090             {
00091                 SG_SPRINT("[%f,%f] with alpha=%f, mean=%f\n", conf_int_low,
00092                         conf_int_up, conf_int_alpha, mean);
00093             }
00094             else
00095                 SG_SPRINT("%f\n", mean)
00096         }
00097 
00098     public:
00100         float64_t mean;
00102         bool has_conf_int;
00104         float64_t conf_int_low;
00106         float64_t conf_int_up;
00108         float64_t conf_int_alpha;
00109 
00110 };
00111 
00137 class CCrossValidation: public CMachineEvaluation
00138 {
00139 public:
00141     CCrossValidation();
00142 
00151     CCrossValidation(CMachine* machine, CFeatures* features, CLabels* labels,
00152             CSplittingStrategy* splitting_strategy,
00153             CEvaluation* evaluation_criterion, bool autolock=true);
00154 
00162     CCrossValidation(CMachine* machine, CLabels* labels,
00163             CSplittingStrategy* splitting_strategy,
00164             CEvaluation* evaluation_criterion, bool autolock=true);
00165 
00167     virtual ~CCrossValidation();
00168 
00170     void set_num_runs(int32_t num_runs);
00171 
00173     void set_conf_int_alpha(float64_t m_conf_int_alpha);
00174 
00176     virtual CEvaluationResult* evaluate();
00177 
00183     void add_cross_validation_output(
00184             CCrossValidationOutput* cross_validation_output);
00185 
00187     virtual const char* get_name() const
00188     {
00189         return "CrossValidation";
00190     }
00191 
00192 private:
00193     void init();
00194 
00195 protected:
00204     virtual float64_t evaluate_one_run();
00205 
00207     int32_t m_num_runs;
00209     float64_t m_conf_int_alpha;
00210 
00212     CList* m_xval_outputs;
00213 };
00214 
00215 }
00216 
00217 #endif /* __CROSSVALIDATION_H_ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation