// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Written (W) 2012 Heiko Strathmann 00009 * 00010 */ 00011 00012 #ifndef __CROSSVALIDATIONOUTPUT_H_ 00013 #define __CROSSVALIDATIONOUTPUT_H_ 00014 00015 #include <shogun/base/SGObject.h> 00016 #include <shogun/lib/SGVector.h> 00017 00018 namespace shogun 00019 { 00020 00021 class CMachine; 00022 class CLabels; 00023 class CEvaluation; 00024 00041 class CCrossValidationOutput: public CSGObject 00042 { 00043 public: 00044 00046 CCrossValidationOutput() : CSGObject() 00047 { 00048 m_current_run_index=0; 00049 m_current_fold_index=0; 00050 m_num_runs=0; 00051 m_num_folds=0; 00052 } 00053 00055 virtual ~CCrossValidationOutput() {} 00056 00058 virtual const char* get_name() const=0; 00059 00065 virtual void init_num_runs(index_t num_runs, const char* prefix="") 00066 { 00067 m_num_runs=num_runs; 00068 } 00069 00074 virtual void init_num_folds(index_t num_folds, const char* prefix="") 00075 { 00076 m_num_folds=num_folds; 00077 } 00078 00082 virtual void init_expose_labels(CLabels* labels) { } 00083 00085 virtual void post_init() { } 00086 00092 virtual void update_run_index(index_t run_index, 00093 const char* prefix="") 00094 { 00095 m_current_run_index=run_index; 00096 } 00097 00103 virtual void update_fold_index(index_t fold_index, 00104 const char* prefix="") 00105 { 00106 m_current_fold_index=fold_index; 00107 } 00108 00114 virtual void update_train_indices(SGVector<index_t> indices, 00115 const char* prefix="") {} 00116 00122 virtual void update_test_indices(SGVector<index_t> indices, 00123 const char* prefix="") {} 00124 00130 virtual void update_trained_machine(CMachine* machine, 00131 const char* prefix="") {} 00132 00138 virtual 
void update_test_result(CLabels* results, 00139 const char* prefix="") {} 00140 00146 virtual void update_test_true_result(CLabels* results, 00147 const char* prefix="") {} 00148 00151 virtual void post_update_results() {} 00152 00158 virtual void update_evaluation_result(float64_t result, 00159 const char* prefix="") {} 00160 00161 protected: 00163 index_t m_current_run_index; 00164 00166 index_t m_current_fold_index; 00167 00169 index_t m_num_runs; 00170 00172 index_t m_num_folds; 00173 }; 00174 00175 } 00176 00177 #endif /* __CROSSVALIDATIONOUTPUT_H_ */