SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Sergey Lisitsyn 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #ifndef CONTINGENCYTABLEEVALUATION_H_ 00012 #define CONTINGENCYTABLEEVALUATION_H_ 00013 00014 #include <shogun/evaluation/BinaryClassEvaluation.h> 00015 #include <shogun/labels/Labels.h> 00016 #include <shogun/mathematics/Math.h> 00017 #include <shogun/io/SGIO.h> 00018 00019 namespace shogun 00020 { 00021 00022 class CLabels; 00023 00025 enum EContingencyTableMeasureType 00026 { 00027 ACCURACY = 0, 00028 ERROR_RATE = 10, 00029 BAL = 20, 00030 WRACC = 30, 00031 F1 = 40, 00032 CROSS_CORRELATION = 50, 00033 RECALL = 60, 00034 PRECISION = 70, 00035 SPECIFICITY = 80, 00036 CUSTOM = 999 00037 }; 00038 00070 class CContingencyTableEvaluation: public CBinaryClassEvaluation 00071 { 00072 00073 public: 00074 00076 CContingencyTableEvaluation() : 00077 CBinaryClassEvaluation(), m_type(ACCURACY), m_computed(false) {}; 00078 00082 CContingencyTableEvaluation(EContingencyTableMeasureType type) : 00083 CBinaryClassEvaluation(), m_type(type), m_computed(false) {}; 00084 00086 virtual ~CContingencyTableEvaluation() {}; 00087 00093 virtual float64_t evaluate(CLabels* predicted, CLabels* ground_truth); 00094 00095 virtual EEvaluationDirection get_evaluation_direction() const; 00096 00098 virtual const char* get_name() const 00099 { 00100 return "ContingencyTableEvaluation"; 00101 } 00102 00106 inline float64_t get_accuracy() const 00107 { 00108 if (!m_computed) 00109 SG_ERROR("Uninitialized, please call evaluate first") 00110 00111 return (m_TP+m_TN)/m_N; 00112 }; 00113 00117 inline float64_t get_error_rate() const 00118 { 00119 if (!m_computed) 00120 SG_ERROR("Uninitialized, please call evaluate first") 00121 00122 return (m_FP + m_FN)/m_N; 00123 }; 00124 00128 inline float64_t get_BAL() const 00129 { 00130 if (!m_computed) 00131 SG_ERROR("Uninitialized, please call evaluate first") 00132 00133 return 0.5*(m_FN/(m_FN + m_TP) + m_FP/(m_FP + m_TN)); 00134 }; 00135 00139 inline float64_t get_WRACC() const 00140 { 00141 if (!m_computed) 00142 SG_ERROR("Uninitialized, please call evaluate first") 00143 00144 return m_TP/(m_FN + m_TP) - m_FP/(m_FP + m_TN); 00145 }; 00146 00150 inline float64_t get_F1() const 00151 { 00152 if (!m_computed) 00153 SG_ERROR("Uninitialized, please call evaluate first") 00154 00155 return (2*m_TP)/(2*m_TP + m_FP + m_FN); 00156 }; 00157 00161 inline float64_t get_cross_correlation() const 00162 { 00163 if (!m_computed) 00164 SG_ERROR("Uninitialized, please call evaluate first") 00165 00166 return (m_TP*m_TN-m_FP*m_FN)/CMath::sqrt((m_TP+m_FP)*(m_TP+m_FN)*(m_TN+m_FP)*(m_TN+m_FN)); 00167 }; 00168 00172 inline float64_t get_recall() const 00173 { 00174 if (!m_computed) 00175 SG_ERROR("Uninitialized, please call evaluate first") 00176 00177 return m_TP/(m_TP+m_FN); 00178 }; 00179 00183 inline float64_t get_precision() const 00184 { 00185 if (!m_computed) 00186 SG_ERROR("Uninitialized, please call evaluate first") 00187 00188 return m_TP/(m_TP+m_FP); 00189 }; 00190 00194 inline float64_t get_specificity() const 00195 { 00196 if (!m_computed) 00197 SG_ERROR("Uninitialized, please call evaluate first") 00198 00199 return m_TN/(m_TN+m_FP); 00200 }; 00201 00205 float64_t get_TP() const 00206 { 00207 return m_TP; 00208 } 00212 float64_t get_FP() const 00213 { 00214 return m_FP; 00215 } 00219 float64_t get_TN() const 00220 { 00221 return m_TN; 00222 } 00226 float64_t get_FN() const 00227 { 00228 return m_FN; 00229 } 00230 00234 virtual float64_t get_custom_score() 00235 { 00236 SG_NOTIMPLEMENTED 00237 return 0.0; 00238 } 00239 00243 virtual EEvaluationDirection get_custom_direction() const 00244 { 00245 SG_NOTIMPLEMENTED 00246 return ED_MAXIMIZE; 00247 } 00248 00249 protected: 00250 00252 void compute_scores(CBinaryLabels* predicted, CBinaryLabels* ground_truth); 00253 00255 EContingencyTableMeasureType m_type; 00256 00258 bool m_computed; 00259 00261 int32_t m_N; 00262 00264 float64_t m_TP; 00265 00267 float64_t m_FP; 00268 00270 float64_t m_TN; 00271 00273 float64_t m_FN; 00274 }; 00275 00285 class CAccuracyMeasure: public CContingencyTableEvaluation 00286 { 00287 public: 00288 /* constructor */ 00289 CAccuracyMeasure() : CContingencyTableEvaluation(ACCURACY) {}; 00290 /* virtual destructor */ 00291 virtual ~CAccuracyMeasure() {}; 00292 /* name */ 00293 virtual const char* get_name() const { return "AccuracyMeasure"; }; 00294 }; 00295 00305 class CErrorRateMeasure: public CContingencyTableEvaluation 00306 { 00307 public: 00308 /* constructor */ 00309 CErrorRateMeasure() : CContingencyTableEvaluation(ERROR_RATE) {}; 00310 /* virtual destructor */ 00311 virtual ~CErrorRateMeasure() {}; 00312 /* name */ 00313 virtual const char* get_name() const { return "ErrorRateMeasure"; }; 00314 }; 00315 00325 class CBALMeasure: public CContingencyTableEvaluation 00326 { 00327 public: 00328 /* constructor */ 00329 CBALMeasure() : CContingencyTableEvaluation(BAL) {}; 00330 /* virtual destructor */ 00331 virtual ~CBALMeasure() {}; 00332 /* name */ 00333 virtual const char* get_name() const { return "BALMeasure"; }; 00334 }; 00335 00345 class CWRACCMeasure: public CContingencyTableEvaluation 00346 { 00347 public: 00348 /* constructor */ 00349 CWRACCMeasure() : CContingencyTableEvaluation(WRACC) {}; 00350 /* virtual destructor */ 00351 virtual ~CWRACCMeasure() {}; 00352 /* name */ 00353 virtual const char* get_name() const { return "WRACCMeasure"; }; 00354 }; 00355 00365 class CF1Measure: public CContingencyTableEvaluation 00366 { 00367 public: 00368 /* constructor */ 00369 CF1Measure() : CContingencyTableEvaluation(F1) {}; 00370 /* virtual destructor */ 00371 virtual ~CF1Measure() {}; 00372 /* name */ 00373 virtual const char* get_name() const { return "F1Measure"; }; 00374 }; 00375 00385 class CCrossCorrelationMeasure: public CContingencyTableEvaluation 00386 { 00387 public: 00388 /* constructor */ 00389 CCrossCorrelationMeasure() : CContingencyTableEvaluation(CROSS_CORRELATION) {}; 00390 /* virtual destructor */ 00391 virtual ~CCrossCorrelationMeasure() {}; 00392 /* name */ 00393 virtual const char* get_name() const { return "CrossCorrelationMeasure"; }; 00394 }; 00395 00405 class CRecallMeasure: public CContingencyTableEvaluation 00406 { 00407 public: 00408 /* constructor */ 00409 CRecallMeasure() : CContingencyTableEvaluation(RECALL) {}; 00410 /* virtual destructor */ 00411 virtual ~CRecallMeasure() {}; 00412 /* name */ 00413 virtual const char* get_name() const { return "RecallMeasure"; }; 00414 }; 00415 00425 class CPrecisionMeasure: public CContingencyTableEvaluation 00426 { 00427 public: 00428 /* constructor */ 00429 CPrecisionMeasure() : CContingencyTableEvaluation(PRECISION) {}; 00430 /* virtual destructor */ 00431 virtual ~CPrecisionMeasure() {}; 00432 /* name */ 00433 virtual const char* get_name() const { return "PrecisionMeasure"; }; 00434 }; 00435 00445 class CSpecificityMeasure: public CContingencyTableEvaluation 00446 { 00447 public: 00448 /* constructor */ 00449 CSpecificityMeasure() : CContingencyTableEvaluation(SPECIFICITY) {}; 00450 /* virtual destructor */ 00451 virtual ~CSpecificityMeasure() {}; 00452 /* name */ 00453 virtual const char* get_name() const { return "SpecificityMeasure"; }; 00454 }; 00455 } 00456 #endif /* CONTINGENCYTABLEEVALUATION_H_ */