SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Sergey Lisitsyn 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/MulticlassAccuracy.h> 00012 #include <shogun/labels/Labels.h> 00013 #include <shogun/labels/MulticlassLabels.h> 00014 #include <shogun/mathematics/Math.h> 00015 00016 using namespace shogun; 00017 00018 float64_t CMulticlassAccuracy::evaluate(CLabels* predicted, CLabels* ground_truth) 00019 { 00020 ASSERT(predicted && ground_truth) 00021 ASSERT(predicted->get_num_labels() == ground_truth->get_num_labels()) 00022 ASSERT(predicted->get_label_type() == LT_MULTICLASS) 00023 ASSERT(ground_truth->get_label_type() == LT_MULTICLASS) 00024 int32_t length = predicted->get_num_labels(); 00025 int32_t correct = 0; 00026 if (m_ignore_rejects) 00027 { 00028 for (int32_t i=0; i<length; i++) 00029 { 00030 if (((CMulticlassLabels*) predicted)->get_int_label(i)==((CMulticlassLabels*) ground_truth)->get_int_label(i)) 00031 correct++; 00032 } 00033 return ((float64_t)correct)/length; 00034 } 00035 else 00036 { 00037 int32_t total = length; 00038 for (int32_t i=0; i<length; i++) 00039 { 00040 int32_t predicted_label = ((CMulticlassLabels*) predicted)->get_int_label(i); 00041 00042 if (predicted_label==((CMulticlassLabels*) predicted)->REJECTION_LABEL) 00043 total--; 00044 else if (predicted_label==((CMulticlassLabels*) ground_truth)->get_int_label(i)) 00045 correct++; 00046 } 00047 m_rejects_num = length-total; 00048 SG_DEBUG("correct=%d, total=%d, rejected=%d\n",correct,total,length-total) 00049 return ((float64_t)correct)/total; 00050 } 00051 return 0.0; 00052 } 00053 00054 SGMatrix<int32_t> CMulticlassAccuracy::get_confusion_matrix(CLabels* predicted, CLabels* ground_truth) 00055 { 00056 ASSERT(predicted->get_num_labels() == ground_truth->get_num_labels()) 00057 int32_t length = ground_truth->get_num_labels(); 00058 int32_t num_classes = ((CMulticlassLabels*) ground_truth)->get_num_classes(); 00059 SGMatrix<int32_t> confusion_matrix(num_classes, num_classes); 00060 memset(confusion_matrix.matrix,0,sizeof(int32_t)*num_classes*num_classes); 00061 for (int32_t i=0; i<length; i++) 00062 { 00063 int32_t predicted_label = ((CMulticlassLabels*) predicted)->get_int_label(i); 00064 int32_t ground_truth_label = ((CMulticlassLabels*) ground_truth)->get_int_label(i); 00065 00066 if (predicted_label==((CMulticlassLabels*) predicted)->REJECTION_LABEL) 00067 continue; 00068 00069 confusion_matrix[predicted_label*num_classes+ground_truth_label]++; 00070 } 00071 return confusion_matrix; 00072 } 00073