SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
ROCEvaluation.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Sergey Lisitsyn
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/mathematics/Math.h>
00013 
00014 using namespace shogun;
00015 
00016 CROCEvaluation::~CROCEvaluation()
00017 {
00018 }
00019 
00020 float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00021 {
00022     return evaluate_roc(predicted,ground_truth);
00023 }
00024 
00025 float64_t CROCEvaluation::evaluate_roc(CLabels* predicted, CLabels* ground_truth)
00026 {
00027     ASSERT(predicted && ground_truth)
00028     ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels())
00029     ASSERT(predicted->get_label_type()==LT_BINARY)
00030     ASSERT(ground_truth->get_label_type()==LT_BINARY)
00031     ground_truth->ensure_valid();
00032 
00033     // assume threshold as negative infinity
00034     float64_t threshold = CMath::ALMOST_NEG_INFTY;
00035     // false positive rate
00036     float64_t fp = 0.0;
00037     // true positive rate
00038     float64_t tp=0.0;
00039 
00040     int32_t i;
00041     // total number of positive labels in predicted
00042     int32_t pos_count=0;
00043     int32_t neg_count=0;
00044 
00045     // initialize number of labels and labels
00046     SGVector<float64_t> orig_labels(predicted->get_num_labels());
00047     int32_t length = orig_labels.vlen;
00048     for (i=0; i<length; i++)
00049         orig_labels[i] = predicted->get_value(i);
00050     float64_t* labels = SGVector<float64_t>::clone_vector(orig_labels.vector, length);
00051 
00052     // get sorted indexes
00053     SGVector<int32_t> idxs(length);
00054     for(i=0; i<length; i++)
00055         idxs[i] = i;
00056 
00057     CMath::qsort_backward_index(labels,idxs.vector,idxs.vlen);
00058 
00059     // number of different predicted labels
00060     int32_t diff_count=1;
00061 
00062     // get number of different labels
00063     for (i=0; i<length-1; i++)
00064     {
00065         if (labels[i] != labels[i+1])
00066             diff_count++;
00067     }
00068 
00069     SG_FREE(labels);
00070 
00071     // initialize graph and auROC
00072     m_ROC_graph = SGMatrix<float64_t>(2,diff_count+1);
00073     m_thresholds = SGVector<float64_t>(length);
00074     m_auROC = 0.0;
00075 
00076     // get total numbers of positive and negative labels
00077     for(i=0; i<length; i++)
00078     {
00079         if (ground_truth->get_value(i) >= 0)
00080             pos_count++;
00081         else
00082             neg_count++;
00083     }
00084 
00085     // assure both number of positive and negative examples is >0
00086     REQUIRE(pos_count>0, "%s::evaluate_roc(): Number of positive labels is "
00087             "zero, ROC fails!\n", get_name());
00088     REQUIRE(neg_count>0, "%s::evaluate_roc(): Number of negative labels is "
00089             "zero, ROC fails!\n", get_name());
00090 
00091     int32_t j = 0;
00092     float64_t label;
00093 
00094     // create ROC curve and calculate auROC
00095     for(i=0; i<length; i++)
00096     {
00097         label = predicted->get_value(idxs[i]);
00098 
00099         if (label != threshold)
00100         {
00101             threshold = label;
00102             m_ROC_graph[2*j] = fp/neg_count;
00103             m_ROC_graph[2*j+1] = tp/pos_count;
00104             j++;
00105         }
00106 
00107         m_thresholds[i]=threshold;
00108 
00109         if (ground_truth->get_value(idxs[i]) > 0)
00110             tp+=1.0;
00111         else
00112             fp+=1.0;
00113     }
00114 
00115     // add (1,1) to ROC curve
00116     m_ROC_graph[2*diff_count] = 1.0;
00117     m_ROC_graph[2*diff_count+1] = 1.0;
00118 
00119     // calc auROC using area under curve
00120     m_auROC = CMath::area_under_curve(m_ROC_graph.matrix,diff_count+1,false);
00121 
00122     m_computed = true;
00123 
00124     return m_auROC;
00125 }
00126 
00127 SGMatrix<float64_t> CROCEvaluation::get_ROC()
00128 {
00129     if (!m_computed)
00130         SG_ERROR("Uninitialized, please call evaluate first")
00131 
00132     return m_ROC_graph;
00133 }
00134 
00135 SGVector<float64_t> CROCEvaluation::get_thresholds()
00136 {
00137     if (!m_computed)
00138         SG_ERROR("Uninitialized, please call evaluate first")
00139 
00140     return m_thresholds;
00141 }
00142 
00143 float64_t CROCEvaluation::get_auROC()
00144 {
00145     if (!m_computed)
00146             SG_ERROR("Uninitialized, please call evaluate first")
00147 
00148     return m_auROC;
00149 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation