// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Heiko Strathmann, Sergey Lisitsyn 00008 * 00009 */ 00010 00011 #ifndef CROSSVALIDATIONMULTICLASSSTORAGE_H_ 00012 #define CROSSVALIDATIONMULTICLASSSTORAGE_H_ 00013 00014 #include <shogun/evaluation/CrossValidationOutput.h> 00015 #include <shogun/evaluation/BinaryClassEvaluation.h> 00016 #include <shogun/labels/MulticlassLabels.h> 00017 #include <shogun/lib/SGMatrix.h> 00018 #include <shogun/lib/DynamicObjectArray.h> 00019 00020 namespace shogun 00021 { 00022 00023 class CMachine; 00024 class CLabels; 00025 class CEvaluation; 00026 00031 class CCrossValidationMulticlassStorage: public CCrossValidationOutput 00032 { 00033 public: 00034 00040 CCrossValidationMulticlassStorage(bool compute_ROC=true, bool compute_PRC=false, bool compute_conf_matrices=false); 00041 00043 virtual ~CCrossValidationMulticlassStorage(); 00044 00052 SGMatrix<float64_t> get_fold_ROC(int32_t run, int32_t fold, int32_t c) 00053 { 00054 ASSERT(0<=run) 00055 ASSERT(run<m_num_runs) 00056 ASSERT(0<=fold) 00057 ASSERT(fold<m_num_folds) 00058 ASSERT(0<=c) 00059 ASSERT(c<m_num_classes) 00060 REQUIRE(m_compute_ROC, "ROC computation was not enabled\n") 00061 return m_fold_ROC_graphs[run*m_num_folds*m_num_classes+fold*m_num_classes+c]; 00062 } 00063 00071 SGMatrix<float64_t> get_fold_PRC(int32_t run, int32_t fold, int32_t c) 00072 { 00073 ASSERT(0<=run) 00074 ASSERT(run<m_num_runs) 00075 ASSERT(0<=fold) 00076 ASSERT(fold<m_num_folds) 00077 ASSERT(0<=c) 00078 ASSERT(c<m_num_classes) 00079 REQUIRE(m_compute_PRC, "PRC computation was not enabled\n") 00080 return m_fold_PRC_graphs[run*m_num_folds*m_num_classes+fold*m_num_classes+c]; 00081 } 00082 00087 void 
append_binary_evaluation(CBinaryClassEvaluation* evaluation) 00088 { 00089 m_binary_evaluations->push_back(evaluation); 00090 } 00091 00096 CBinaryClassEvaluation* get_binary_evaluation(int32_t idx) 00097 { 00098 return (CBinaryClassEvaluation*)m_binary_evaluations->get_element_safe(idx); 00099 } 00100 00108 float64_t get_fold_evaluation_result(int32_t run, int32_t fold, int32_t c, int32_t e) 00109 { 00110 ASSERT(0<=run) 00111 ASSERT(run<m_num_runs) 00112 ASSERT(0<=fold) 00113 ASSERT(fold<m_num_folds) 00114 ASSERT(0<=c) 00115 ASSERT(c<m_num_classes) 00116 ASSERT(0<=e) 00117 int32_t n_evals = m_binary_evaluations->get_num_elements(); 00118 ASSERT(e<n_evals) 00119 return m_evaluations_results[run*m_num_folds*m_num_classes*n_evals+fold*m_num_classes*n_evals+c*n_evals+e]; 00120 } 00121 00126 float64_t get_fold_accuracy(int32_t run, int32_t fold) 00127 { 00128 ASSERT(0<=run) 00129 ASSERT(run<m_num_runs) 00130 ASSERT(0<=fold) 00131 ASSERT(fold<m_num_folds) 00132 return m_accuracies[run*m_num_folds+fold]; 00133 } 00134 00139 SGMatrix<int32_t> get_fold_conf_matrix(int32_t run, int32_t fold) 00140 { 00141 ASSERT(0<=run) 00142 ASSERT(run<m_num_runs) 00143 ASSERT(0<=fold) 00144 ASSERT(fold<m_num_folds) 00145 REQUIRE(m_compute_conf_matrices, "Confusion matrices computation was not enabled\n") 00146 return m_conf_matrices[run*m_num_folds+fold]; 00147 } 00148 00150 virtual void post_init(); 00151 00153 virtual void post_update_results(); 00154 00158 virtual void init_expose_labels(CLabels* labels); 00159 00165 virtual void update_test_result(CLabels* results, 00166 const char* prefix=""); 00167 00173 virtual void update_test_true_result(CLabels* results, 00174 const char* prefix=""); 00175 00177 virtual const char* get_name() const { return "CrossValidationMulticlassStorage"; } 00178 00179 protected: 00180 00182 bool m_initialized; 00183 00185 CDynamicObjectArray* m_binary_evaluations; 00186 00188 SGVector<float64_t> m_evaluations_results; 00189 00191 SGVector<float64_t> 
m_accuracies; 00192 00194 bool m_compute_ROC; 00195 00197 SGMatrix<float64_t>* m_fold_ROC_graphs; 00198 00200 bool m_compute_PRC; 00201 00203 SGMatrix<float64_t>* m_fold_PRC_graphs; 00204 00206 bool m_compute_conf_matrices; 00207 00209 SGMatrix<int32_t>* m_conf_matrices; 00210 00212 CMulticlassLabels* m_pred_labels; 00213 00215 CMulticlassLabels* m_true_labels; 00216 00218 int32_t m_num_classes; 00219 00220 }; 00221 00222 } 00223 00224 #endif /* CROSSVALIDATIONMULTICLASSSTORAGE_H_ */