SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Sergey Lisitsyn 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/ROCEvaluation.h> 00012 #include <shogun/mathematics/Math.h> 00013 00014 using namespace shogun; 00015 00016 CROCEvaluation::~CROCEvaluation() 00017 { 00018 } 00019 00020 float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth) 00021 { 00022 return evaluate_roc(predicted,ground_truth); 00023 } 00024 00025 float64_t CROCEvaluation::evaluate_roc(CLabels* predicted, CLabels* ground_truth) 00026 { 00027 ASSERT(predicted && ground_truth) 00028 ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels()) 00029 ASSERT(predicted->get_label_type()==LT_BINARY) 00030 ASSERT(ground_truth->get_label_type()==LT_BINARY) 00031 ground_truth->ensure_valid(); 00032 00033 // assume threshold as negative infinity 00034 float64_t threshold = CMath::ALMOST_NEG_INFTY; 00035 // false positive rate 00036 float64_t fp = 0.0; 00037 // true positive rate 00038 float64_t tp=0.0; 00039 00040 int32_t i; 00041 // total number of positive labels in predicted 00042 int32_t pos_count=0; 00043 int32_t neg_count=0; 00044 00045 // initialize number of labels and labels 00046 SGVector<float64_t> orig_labels(predicted->get_num_labels()); 00047 int32_t length = orig_labels.vlen; 00048 for (i=0; i<length; i++) 00049 orig_labels[i] = predicted->get_value(i); 00050 float64_t* labels = SGVector<float64_t>::clone_vector(orig_labels.vector, length); 00051 00052 // get sorted indexes 00053 SGVector<int32_t> idxs(length); 00054 for(i=0; i<length; i++) 00055 idxs[i] = i; 00056 00057 CMath::qsort_backward_index(labels,idxs.vector,idxs.vlen); 00058 00059 // number of different predicted labels 00060 int32_t diff_count=1; 00061 00062 // get number of different labels 00063 for (i=0; i<length-1; i++) 00064 { 00065 if (labels[i] != labels[i+1]) 00066 diff_count++; 00067 } 00068 00069 SG_FREE(labels); 00070 00071 // initialize graph and auROC 00072 m_ROC_graph = SGMatrix<float64_t>(2,diff_count+1); 00073 m_thresholds = SGVector<float64_t>(length); 00074 m_auROC = 0.0; 00075 00076 // get total numbers of positive and negative labels 00077 for(i=0; i<length; i++) 00078 { 00079 if (ground_truth->get_value(i) >= 0) 00080 pos_count++; 00081 else 00082 neg_count++; 00083 } 00084 00085 // assure both number of positive and negative examples is >0 00086 REQUIRE(pos_count>0, "%s::evaluate_roc(): Number of positive labels is " 00087 "zero, ROC fails!\n", get_name()); 00088 REQUIRE(neg_count>0, "%s::evaluate_roc(): Number of negative labels is " 00089 "zero, ROC fails!\n", get_name()); 00090 00091 int32_t j = 0; 00092 float64_t label; 00093 00094 // create ROC curve and calculate auROC 00095 for(i=0; i<length; i++) 00096 { 00097 label = predicted->get_value(idxs[i]); 00098 00099 if (label != threshold) 00100 { 00101 threshold = label; 00102 m_ROC_graph[2*j] = fp/neg_count; 00103 m_ROC_graph[2*j+1] = tp/pos_count; 00104 j++; 00105 } 00106 00107 m_thresholds[i]=threshold; 00108 00109 if (ground_truth->get_value(idxs[i]) > 0) 00110 tp+=1.0; 00111 else 00112 fp+=1.0; 00113 } 00114 00115 // add (1,1) to ROC curve 00116 m_ROC_graph[2*diff_count] = 1.0; 00117 m_ROC_graph[2*diff_count+1] = 1.0; 00118 00119 // calc auROC using area under curve 00120 m_auROC = CMath::area_under_curve(m_ROC_graph.matrix,diff_count+1,false); 00121 00122 m_computed = true; 00123 00124 return m_auROC; 00125 } 00126 00127 SGMatrix<float64_t> CROCEvaluation::get_ROC() 00128 { 00129 if (!m_computed) 00130 SG_ERROR("Uninitialized, please call evaluate first") 00131 00132 return m_ROC_graph; 00133 } 00134 00135 SGVector<float64_t> CROCEvaluation::get_thresholds() 00136 { 00137 if (!m_computed) 00138 SG_ERROR("Uninitialized, please call evaluate first") 00139 00140 return m_thresholds; 00141 } 00142 00143 float64_t CROCEvaluation::get_auROC() 00144 { 00145 if (!m_computed) 00146 SG_ERROR("Uninitialized, please call evaluate first") 00147 00148 return m_auROC; 00149 }