SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
ClusteringMutualInformation.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <shogun/lib/SGVector.h>
00012 #include <shogun/labels/MulticlassLabels.h>
00013 #include <shogun/evaluation/ClusteringMutualInformation.h>
00014 
00015 using namespace shogun;
00016 
00017 float64_t CClusteringMutualInformation::evaluate(CLabels* predicted, CLabels* ground_truth)
00018 {
00019     ASSERT(predicted && ground_truth)
00020     ASSERT(predicted->get_label_type() == LT_MULTICLASS)
00021     ASSERT(ground_truth->get_label_type() == LT_MULTICLASS)
00022     SGVector<float64_t> label_p=((CMulticlassLabels*) predicted)->get_unique_labels();
00023     SGVector<float64_t> label_g=((CMulticlassLabels*) ground_truth)->get_unique_labels();
00024 
00025     if (label_p.vlen != label_g.vlen)
00026         SG_ERROR("Number of classes are different\n")
00027     index_t n_class=label_p.vlen;
00028     float64_t n_label=predicted->get_num_labels();
00029 
00030     SGVector<int32_t> ilabels_p=((CMulticlassLabels*) predicted)->get_int_labels();
00031     SGVector<int32_t> ilabels_g=((CMulticlassLabels*) ground_truth)->get_int_labels();
00032 
00033     SGMatrix<float64_t> G(n_class, n_class);
00034     for (index_t i=0; i < n_class; ++i)
00035     {
00036         for (index_t j=0; j < n_class; ++j)
00037             G(i, j)=find_match_count(ilabels_g, label_g[i],
00038                 ilabels_p, label_p[j])/n_label;
00039     }
00040 
00041     SGVector<float64_t> G_rowsum(n_class);
00042     G_rowsum.zero();
00043     SGVector<float64_t> G_colsum(n_class);
00044     G_colsum.zero();
00045     for (index_t i=0; i < n_class; ++i)
00046     {
00047         for (index_t j=0; j < n_class; ++j)
00048         {
00049             G_rowsum[i] += G(i, j);
00050             G_colsum[i] += G(j, i);
00051         }
00052     }
00053 
00054     float64_t mutual_info = 0;
00055     for (index_t i=0; i < n_class; ++i)
00056     {
00057         for (index_t j=0; j < n_class; ++j)
00058         {
00059             if (G(i, j) != 0)
00060                 mutual_info += G(i, j) * log(G(i,j) /
00061                     (G_rowsum[i]*G_colsum[j]))/log(2.);
00062         }
00063     }
00064 
00065     float64_t entropy_p = 0;
00066     float64_t entropy_g = 0;
00067     for (index_t i=0; i < n_class; ++i)
00068     {
00069         entropy_g += -G_rowsum[i] * log(G_rowsum[i])/log(2.);
00070         entropy_p += -G_colsum[i] * log(G_colsum[i])/log(2.);
00071     }
00072 
00073     return mutual_info / CMath::max(entropy_g, entropy_p);
00074 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation