Marsyas  0.6.0-alpha
/usr/src/RPM/BUILD/marsyas-0.6.0/src/marsyas/marsystems/ClassificationReport.cpp
Go to the documentation of this file.
00001 /*
00002 ** Copyright (C) 1998-2010 George Tzanetakis <gtzan@cs.princeton.edu>
00003 **
00004 ** This program is free software; you can redistribute it and/or modify
00005 ** it under the terms of the GNU General Public License as published by
00006 ** the Free Software Foundation; either version 2 of the License, or
00007 ** (at your option) any later version.
00008 **
00009 ** This program is distributed in the hope that it will be useful,
00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00012 ** GNU General Public License for more details.
00013 **
00014 ** You should have received a copy of the GNU General Public License
00015 ** along with this program; if not, write to the Free Software
00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00017 */
00018 
00036 #include "ClassificationReport.h"
00037 #include "../common_source.h"
00038 
00039 
00040 
00041 using std::ostringstream;
00042 using std::vector;
00043 using std::cout;
00044 using std::endl;
00045 
00046 
00047 using namespace Marsyas;
00048 
00049 ClassificationReport::ClassificationReport(mrs_string name) : MarSystem("ClassificationReport", name)
00050 {
00051   regCorr.sumClass = 0;
00052   regCorr.sumSqrClass = 0;
00053   regCorr.sumClassPredicted = 0;
00054   regCorr.sumPredicted = 0;
00055   regCorr.sumSqrPredicted = 0;
00056   regCorr.withClass = 0;
00057   addControls();
00058 }
00059 
00060 
00061 ClassificationReport::~ClassificationReport()
00062 {
00063 }
00064 
00065 MarSystem *ClassificationReport::clone() const
00066 {
00067   return new ClassificationReport(*this);
00068 }
00069 
00070 void ClassificationReport::addControls()
00071 {
00072   addctrl("mrs_string/mode", "train");
00073   setctrlState("mrs_string/mode", true);
00074   addctrl("mrs_natural/nClasses", 2);
00075   setctrlState("mrs_natural/nClasses", true);
00076   addctrl("mrs_string/classNames", "Music,Speech");
00077   setctrlState("mrs_string/classNames", true);
00078   addctrl("mrs_bool/done", false);
00079   addctrl("mrs_bool/regression", false);
00080 }
00081 
00082 void ClassificationReport::myUpdate(MarControlPtr sender)
00083 {
00084   (void) sender;  //suppress warning of unused parameter(s)
00085   MRSDIAG("ClassificationReport.cpp - ClassificationReport:myUpdate");
00086 
00087   setctrl("mrs_natural/onSamples", getctrl("mrs_natural/inSamples"));
00088   setctrl("mrs_natural/onObservations", (mrs_natural)2);
00089   setctrl("mrs_real/osrate", getctrl("mrs_real/israte"));
00090 
00091   mrs_natural nClasses = getctrl("mrs_natural/nClasses")->to<mrs_natural>();
00092   if (confusionMatrix.getRows() != nClasses)
00093   {
00094     confusionMatrix.create(nClasses, nClasses);
00095   }//if
00096   classNames = getctrl("mrs_string/classNames")->to<mrs_string>();
00097 
00098 }//myUpdate
00099 
00100 void ClassificationReport::myProcess(realvec& in, realvec& out)
00101 {
00102 
00103   static int count = 0;
00104 
00105 
00106   mrs_natural t;
00107   mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>();
00108 
00109   //modified this code to check the done flag-dale
00110   bool done = getctrl("mrs_bool/done")->to<mrs_bool>();
00111 
00112 
00113 
00114   if ((mode == "train") && !done)
00115   {
00116     for (t=0; t < inSamples_; t++)
00117     {
00118       mrs_real label = in(inObservations_-1, t);
00119       out(0,t) = label;
00120       out(1,t) = label;
00121     }//for t
00122   }//if train
00123   else if ((mode == "predict") && !done)
00124   {
00125     count++;
00126 
00127     for (t=0; t < inSamples_; t++)
00128     {
00129       if (getctrl("mrs_bool/regression")->isTrue()) {
00130         mrs_real prediction = in(0, t); //prediction
00131         mrs_real actual = in(1, t); //actual
00132         //cout<<prediction<<'\t'<<actual<<endl;
00133         regCorr.sumClass += actual;
00134         regCorr.sumSqrClass += actual*actual;
00135         regCorr.sumClassPredicted += actual*prediction;
00136         regCorr.sumPredicted += prediction;
00137         regCorr.sumSqrPredicted += prediction*prediction;
00138         regCorr.withClass += 1.0;
00139         out(0,t) = prediction;
00140         out(1,t) = actual;
00141       } else {
00142         //swapped the x and y values-dale
00143         mrs_natural prediction = (mrs_natural)in(0, t); //prediction
00144         mrs_natural actual = (mrs_natural)in(1, t); //actual
00145 
00146         confusionMatrix(actual,prediction)++;
00147         //cout << "(y,x) (" << y << ","<< x << ")"<< endl;
00148 
00149         out(0,t) = prediction;
00150         out(1,t) = actual;
00151       }
00152     }
00153   }
00154   else if(mode == "report" || done)
00155   {
00156     ostringstream stream;
00157 
00158     if (getctrl("mrs_bool/regression")->isTrue()) {
00159 
00160       mrs_real varActual = regCorr.sumSqrClass -
00161                            (regCorr.sumClass*regCorr.sumClass) /
00162                            regCorr.withClass;
00163       mrs_real varPredicted = regCorr.sumSqrPredicted -
00164                               (regCorr.sumPredicted*regCorr.sumPredicted) /
00165                               regCorr.withClass;
00166       mrs_real varProd = regCorr.sumClassPredicted -
00167                          (regCorr.sumClass*regCorr.sumPredicted) /
00168                          regCorr.withClass;
00169 
00170       mrs_real correlation;
00171       if (varActual * varPredicted <= 0) {
00172         correlation = 0.0;
00173       } else {
00174         correlation = varProd / sqrt(varActual*varPredicted);
00175       }
00176 
00177       mrs_real meanAbsoluteError = 0.0;
00178       mrs_real rootMeanSquaredError = 0.0;
00179       mrs_real relativeAbsoluteError = 0.0;
00180       mrs_real rootRelativeSquaredError = 0.0;
00181       mrs_real instances = 0;
00182       stream << "=== ClassificationReport ===" << endl << endl;
00183       stream << "Correlation coefficient" << "\t\t\t" << correlation << "\t" << endl;
00184       stream << "Mean absolute error" << "\t\t\t" << meanAbsoluteError << endl;
00185       stream << "Root mean squared error" << "\t\t\t" << rootMeanSquaredError << endl;
00186       stream << "Relative absolute error" << "\t\t\t" << relativeAbsoluteError << endl;
00187       stream << "Root relative squared error" << "\t\t" << rootRelativeSquaredError << endl;
00188       stream << "Total Number of Instances" << "\t\t" << instances << endl << endl;
00189     } else {
00190 
00191       summaryStatistics stats = computeSummaryStatistics(confusionMatrix);
00192       stream << "=== ClassificationReport ===" << endl << endl;
00193 
00194       stream << "Correctly Classified Instances" << "\t\t" << stats.correctInstances << "\t";
00195       stream << (((mrs_real)stats.correctInstances / (mrs_real)stats.instances)*100.0);
00196       stream << " %" << endl;
00197 
00198       stream << "Incorrectly Classified Instances" << "\t" << (stats.instances - stats.correctInstances) << "\t";
00199       stream << (((mrs_real)(stats.instances - stats.correctInstances) / (mrs_real)stats.instances)*100.0);
00200       stream << " %" << endl;
00201 
00202       stream << "Kappa statistic" << "\t\t\t\t" << stats.kappa << "\t" << endl;
00203       stream << "Mean absolute error" << "\t\t\t" << stats.meanAbsoluteError << endl;
00204       stream << "Root mean squared error" << "\t\t\t" << stats.rootMeanSquaredError << endl;
00205       stream << "Relative absolute error" << "\t\t\t" << stats.relativeAbsoluteError << endl;
00206       stream << "Root relative squared error" << "\t\t" << stats.rootRelativeSquaredError << endl;
00207       stream << "Total Number of Instances" << "\t\t" << stats.instances << endl << endl;
00208 
00209       stream << "=== Confusion Matrix ===";
00210       stream << endl; stream << endl;
00211 
00212       if(!classNames.size())
00213         classNames = ",";
00214 
00215       mrs_string::size_type from = 0;
00216       mrs_string::size_type to = classNames.find(",");
00217 
00218       mrs_natural correct = 0;
00219       mrs_natural total = 0;
00220       for (mrs_natural x = 0; x<confusionMatrix.getCols(); x++)
00221         stream << "\t" << (char)(x+'a');
00222       stream << "\t" << "<-- classified as";
00223       stream << endl;
00224 
00225       for(mrs_natural y = 0; y<confusionMatrix.getRows(); y++)
00226       {
00227         for(mrs_natural x = 0; x<confusionMatrix.getCols(); x++)
00228         {
00229           mrs_natural value = (mrs_natural)confusionMatrix(y, x);
00230           total += value;
00231           if(x == y)
00232             correct += value;
00233 
00234           stream << "\t" << value;
00235         }//for x
00236         stream << "\t" << "| ";
00237         if(from < classNames.size())
00238         {
00239           stream << (char)(y+'a') << " = " << classNames.substr(from, to - from);
00240           from = to + 1;
00241           to = classNames.find(",", from);
00242           if(to == mrs_string::npos)
00243             to = classNames.size();
00244         }//if
00245         stream << endl;
00246       }//for y
00247       stream << (total > 0 ? correct * 100 / total: 0) << "% classified correctly (" << correct << "/" << total << ")" << endl;
00248     }
00249 
00250     MrsLog::mrsMessage(stream);
00251 
00252     updControl("mrs_bool/done", true);
00253   }//if done
00254 }//myProcess
00255 
00256 summaryStatistics ClassificationReport::computeSummaryStatistics(const realvec& mat)
00257 {
00258   MRSASSERT(mat.getCols()==mat.getRows());
00259 
00260   summaryStatistics stats;
00261 
00262   mrs_natural size = mat.getCols();
00263 
00264   vector<mrs_natural>rowSums(size);
00265   for(int ii=0; ii<size; ++ii) rowSums[ii] = 0;
00266   vector<mrs_natural>colSums(size);
00267   for(int ii=0; ii<size; ++ii) colSums[ii] = 0;
00268   mrs_natural diagonalSum = 0;
00269 
00270   mrs_natural instanceCount = 0;
00271   for(mrs_natural row=0; row<size; row++)
00272   {
00273     for(mrs_natural col=0; col<size; col++)
00274     {
00275       mrs_natural num = (mrs_natural)mat(row,col);
00276       instanceCount += num;
00277 
00278       rowSums[row] += num;
00279       colSums[col] += num;
00280 
00281       if(row==col)
00282         diagonalSum += num;
00283     }
00284   }
00285   //printf("row1 sum:%d\n",rowSums[0]);
00286   //printf("row2 sum:%d\n",rowSums[1]);
00287   //printf("col1 sum:%d\n",colSums[0]);
00288   //printf("col2 sum:%d\n",colSums[1]);
00289   //printf("diagonal sum:%d\n",diagonalSum);
00290   //printf("instanceCount:%d\n",instanceCount);
00291 
00292   mrs_natural N = instanceCount;
00293   mrs_natural N2 = (N*N);
00294   stats.instances = instanceCount;
00295   stats.correctInstances = diagonalSum;
00296 
00297   mrs_natural sum = 0;
00298   for(mrs_natural ii=0; ii<size; ++ii)
00299   {
00300     sum += (rowSums[ii] * colSums[ii]);
00301   }
00302   mrs_real PE = (mrs_real)sum / (mrs_real)N2;
00303   mrs_real PA = (mrs_real)diagonalSum / (mrs_real)N;
00304   stats.kappa = (PA - PE) / (1.0 - PE);
00305 
00306   mrs_natural not_diagonal_sum = instanceCount - diagonalSum;
00307   mrs_real MeanAbsoluteError = (mrs_real)not_diagonal_sum / (mrs_real)instanceCount;
00308   //printf("MeanAbsoluteError:%f\n",MeanAbsoluteError);
00309   stats.meanAbsoluteError = MeanAbsoluteError;
00310 
00311   mrs_real RootMeanSquaredError = sqrt(MeanAbsoluteError);
00312   //printf("RootMeanSquaredError:%f\n",RootMeanSquaredError);
00313   stats.rootMeanSquaredError = RootMeanSquaredError;
00314 
00315   mrs_real RelativeAbsoluteError = (MeanAbsoluteError / 0.5) * 100.0;
00316   //printf("RelativeAbsoluteError:%f%%\n",RelativeAbsoluteError);
00317   stats.relativeAbsoluteError = RelativeAbsoluteError;
00318 
00319   mrs_real RootRelativeSquaredError = (RootMeanSquaredError / (0.5)) * 100.0;
00320   //printf("RootRelativeSquaredError:%f%%\n",RootRelativeSquaredError);
00321   stats.rootRelativeSquaredError = RootRelativeSquaredError;
00322 
00323   return stats;
00324 }//computeSummaryStatistics