SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #include <set> 00012 #include <map> 00013 #include <vector> 00014 #include <algorithm> 00015 00016 #include <shogun/evaluation/ClusteringEvaluation.h> 00017 #include <shogun/labels/MulticlassLabels.h> 00018 #include <shogun/mathematics/munkres.h> 00019 00020 using namespace shogun; 00021 using namespace std; 00022 00023 int32_t CClusteringEvaluation::find_match_count(SGVector<int32_t> l1, int32_t m1, SGVector<int32_t> l2, int32_t m2) 00024 { 00025 int32_t match_count=0; 00026 for (int32_t i=l1.vlen-1; i >= 0; --i) 00027 { 00028 if (l1[i] == m1 && l2[i] == m2) 00029 match_count++; 00030 } 00031 00032 return match_count; 00033 } 00034 00035 int32_t CClusteringEvaluation::find_mismatch_count(SGVector<int32_t> l1, int32_t m1, SGVector<int32_t> l2, int32_t m2) 00036 { 00037 return l1.vlen - find_match_count(l1, m1, l2, m2); 00038 } 00039 00040 void CClusteringEvaluation::best_map(CLabels* predicted, CLabels* ground_truth) 00041 { 00042 ASSERT(predicted->get_num_labels() == ground_truth->get_num_labels()) 00043 ASSERT(predicted->get_label_type() == LT_MULTICLASS) 00044 ASSERT(ground_truth->get_label_type() == LT_MULTICLASS) 00045 00046 SGVector<float64_t> label_p=((CMulticlassLabels*) predicted)->get_unique_labels(); 00047 SGVector<float64_t> label_g=((CMulticlassLabels*) ground_truth)->get_unique_labels(); 00048 00049 SGVector<int32_t> predicted_ilabels=((CMulticlassLabels*) predicted)->get_int_labels(); 00050 SGVector<int32_t> groundtruth_ilabels=((CMulticlassLabels*) ground_truth)->get_int_labels(); 00051 00052 int32_t n_class=max(label_p.vlen, label_g.vlen); 00053 SGMatrix<float64_t> G(n_class, n_class); 00054 G.zero(); 00055 00056 for (int32_t i=0; i < label_g.vlen; ++i) 00057 { 00058 for (int32_t j=0; j < label_p.vlen; ++j) 00059 { 00060 G(i, j)=find_mismatch_count(groundtruth_ilabels, static_cast<int32_t>(label_g[i]), 00061 predicted_ilabels, static_cast<int32_t>(label_p[j])); 00062 } 00063 } 00064 00065 Munkres munkres_solver(G); 00066 munkres_solver.solve(); 00067 00068 std::map<int32_t, int32_t> label_map; 00069 for (int32_t i=0; i < label_p.vlen; ++i) 00070 { 00071 for (int32_t j=0; j < label_g.vlen; ++j) 00072 { 00073 if (G(j, i) == 0) 00074 { 00075 label_map.insert(make_pair(static_cast<int32_t>(label_p[i]), 00076 static_cast<int32_t>(label_g[j]))); 00077 break; 00078 } 00079 } 00080 } 00081 00082 for (int32_t i= 0; i < predicted_ilabels.vlen; ++i) 00083 ((CMulticlassLabels*) predicted)->set_int_label(i, label_map[predicted_ilabels[i]]); 00084 }