// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann 00008 */ 00009 00010 #include <shogun/evaluation/CrossValidationMulticlassStorage.h> 00011 #include <shogun/evaluation/ROCEvaluation.h> 00012 #include <shogun/evaluation/PRCEvaluation.h> 00013 #include <shogun/evaluation/MulticlassAccuracy.h> 00014 00015 using namespace shogun; 00016 00017 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) : 00018 CCrossValidationOutput() 00019 { 00020 m_initialized = false; 00021 m_compute_ROC = compute_ROC; 00022 m_compute_PRC = compute_PRC; 00023 m_compute_conf_matrices = compute_conf_matrices; 00024 m_pred_labels = NULL; 00025 m_true_labels = NULL; 00026 m_num_classes = 0; 00027 m_binary_evaluations = new CDynamicObjectArray(); 00028 00029 m_fold_ROC_graphs=NULL; 00030 m_conf_matrices=NULL; 00031 } 00032 00033 00034 CCrossValidationMulticlassStorage::~CCrossValidationMulticlassStorage() 00035 { 00036 if (m_compute_ROC) 00037 { 00038 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) 00039 m_fold_ROC_graphs[i].~SGMatrix<float64_t>(); 00040 00041 SG_FREE(m_fold_ROC_graphs); 00042 } 00043 00044 if (m_compute_PRC) 00045 { 00046 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) 00047 m_fold_PRC_graphs[i].~SGMatrix<float64_t>(); 00048 00049 SG_FREE(m_fold_PRC_graphs); 00050 } 00051 00052 if (m_compute_conf_matrices) 00053 { 00054 for (int32_t i=0; i<m_num_folds*m_num_runs; i++) 00055 m_conf_matrices[i].~SGMatrix<int32_t>(); 00056 00057 SG_FREE(m_conf_matrices); 00058 } 00059 00060 SG_UNREF(m_binary_evaluations); 00061 }; 00062 00063 00064 void 
CCrossValidationMulticlassStorage::post_init() 00065 { 00066 if (m_initialized) 00067 SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n") 00068 00069 if (m_compute_ROC) 00070 { 00071 SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes) 00072 m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes); 00073 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) 00074 new (&m_fold_ROC_graphs[i]) SGMatrix<float64_t>(); 00075 } 00076 00077 if (m_compute_PRC) 00078 { 00079 SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes) 00080 m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes); 00081 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) 00082 new (&m_fold_PRC_graphs[i]) SGMatrix<float64_t>(); 00083 } 00084 00085 if (m_binary_evaluations->get_num_elements()) 00086 m_evaluations_results = SGVector<float64_t>(m_num_folds*m_num_runs*m_num_classes*m_binary_evaluations->get_num_elements()); 00087 00088 m_accuracies = SGVector<float64_t>(m_num_folds*m_num_runs); 00089 00090 if (m_compute_conf_matrices) 00091 { 00092 m_conf_matrices = SG_MALLOC(SGMatrix<int32_t>, m_num_folds*m_num_runs); 00093 for (int32_t i=0; i<m_num_folds*m_num_runs; i++) 00094 new (&m_conf_matrices[i]) SGMatrix<int32_t>(); 00095 } 00096 00097 m_initialized = true; 00098 } 00099 00100 void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels) 00101 { 00102 ASSERT((CMulticlassLabels*)labels) 00103 m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes(); 00104 } 00105 00106 void CCrossValidationMulticlassStorage::post_update_results() 00107 { 00108 CROCEvaluation eval_ROC; 00109 CPRCEvaluation eval_PRC; 00110 int32_t n_evals = m_binary_evaluations->get_num_elements(); 00111 for (int32_t c=0; c<m_num_classes; c++) 00112 { 00113 SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c) 
00114 CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c); 00115 CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c); 00116 if (m_compute_ROC) 00117 { 00118 eval_ROC.evaluate(pred_labels_binary, true_labels_binary); 00119 m_fold_ROC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = 00120 eval_ROC.get_ROC(); 00121 } 00122 if (m_compute_PRC) 00123 { 00124 eval_PRC.evaluate(pred_labels_binary, true_labels_binary); 00125 m_fold_PRC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = 00126 eval_PRC.get_PRC(); 00127 } 00128 00129 for (int32_t i=0; i<n_evals; i++) 00130 { 00131 CBinaryClassEvaluation* evaluator = (CBinaryClassEvaluation*)m_binary_evaluations->get_element_safe(i); 00132 m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] = 00133 evaluator->evaluate(pred_labels_binary, true_labels_binary); 00134 SG_UNREF(evaluator); 00135 } 00136 00137 SG_UNREF(pred_labels_binary); 00138 SG_UNREF(true_labels_binary); 00139 } 00140 CMulticlassAccuracy accuracy; 00141 00142 m_accuracies[m_current_run_index*m_num_folds+m_current_fold_index] = accuracy.evaluate(m_pred_labels, m_true_labels); 00143 00144 if (m_compute_conf_matrices) 00145 { 00146 m_conf_matrices[m_current_run_index*m_num_folds+m_current_fold_index] = CMulticlassAccuracy::get_confusion_matrix(m_pred_labels, m_true_labels); 00147 } 00148 } 00149 00150 void CCrossValidationMulticlassStorage::update_test_result(CLabels* results, const char* prefix) 00151 { 00152 m_pred_labels = (CMulticlassLabels*)results; 00153 } 00154 00155 void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results, const char* prefix) 00156 { 00157 m_true_labels = (CMulticlassLabels*)results; 00158 } 00159