Marsyas
0.6.0-alpha
|
00001 /* 00002 ** Copyright (C) 1998-2010 George Tzanetakis <gtzan@cs.princeton.edu> 00003 ** 00004 ** This program is free software; you can redistribute it and/or modify 00005 ** it under the terms of the GNU General Public License as published by 00006 ** the Free Software Foundation; either version 2 of the License, or 00007 ** (at your option) any later version. 00008 ** 00009 ** This program is distributed in the hope that it will be useful, 00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of 00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00012 ** GNU General Public License for more details. 00013 ** 00014 ** You should have received a copy of the GNU General Public License 00015 ** along with this program; if not, write to the Free Software 00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00017 */ 00018 00036 #include "ClassificationReport.h" 00037 #include "../common_source.h" 00038 00039 00040 00041 using std::ostringstream; 00042 using std::vector; 00043 using std::cout; 00044 using std::endl; 00045 00046 00047 using namespace Marsyas; 00048 00049 ClassificationReport::ClassificationReport(mrs_string name) : MarSystem("ClassificationReport", name) 00050 { 00051 regCorr.sumClass = 0; 00052 regCorr.sumSqrClass = 0; 00053 regCorr.sumClassPredicted = 0; 00054 regCorr.sumPredicted = 0; 00055 regCorr.sumSqrPredicted = 0; 00056 regCorr.withClass = 0; 00057 addControls(); 00058 } 00059 00060 00061 ClassificationReport::~ClassificationReport() 00062 { 00063 } 00064 00065 MarSystem *ClassificationReport::clone() const 00066 { 00067 return new ClassificationReport(*this); 00068 } 00069 00070 void ClassificationReport::addControls() 00071 { 00072 addctrl("mrs_string/mode", "train"); 00073 setctrlState("mrs_string/mode", true); 00074 addctrl("mrs_natural/nClasses", 2); 00075 setctrlState("mrs_natural/nClasses", true); 00076 addctrl("mrs_string/classNames", "Music,Speech"); 00077 setctrlState("mrs_string/classNames", true); 00078 addctrl("mrs_bool/done", false); 00079 addctrl("mrs_bool/regression", false); 00080 } 00081 00082 void ClassificationReport::myUpdate(MarControlPtr sender) 00083 { 00084 (void) sender; //suppress warning of unused parameter(s) 00085 MRSDIAG("ClassificationReport.cpp - ClassificationReport:myUpdate"); 00086 00087 setctrl("mrs_natural/onSamples", getctrl("mrs_natural/inSamples")); 00088 setctrl("mrs_natural/onObservations", (mrs_natural)2); 00089 setctrl("mrs_real/osrate", getctrl("mrs_real/israte")); 00090 00091 mrs_natural nClasses = getctrl("mrs_natural/nClasses")->to<mrs_natural>(); 00092 if (confusionMatrix.getRows() != nClasses) 00093 { 00094 confusionMatrix.create(nClasses, nClasses); 00095 }//if 00096 classNames = getctrl("mrs_string/classNames")->to<mrs_string>(); 00097 00098 }//myUpdate 00099 00100 void ClassificationReport::myProcess(realvec& in, realvec& out) 00101 { 00102 00103 static int count = 0; 00104 00105 00106 mrs_natural t; 00107 mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>(); 00108 00109 //modified this code to check the done flag-dale 00110 bool done = getctrl("mrs_bool/done")->to<mrs_bool>(); 00111 00112 00113 00114 if ((mode == "train") && !done) 00115 { 00116 for (t=0; t < inSamples_; t++) 00117 { 00118 mrs_real label = in(inObservations_-1, t); 00119 out(0,t) = label; 00120 out(1,t) = label; 00121 }//for t 00122 }//if train 00123 else if ((mode == "predict") && !done) 00124 { 00125 count++; 00126 00127 for (t=0; t < inSamples_; t++) 00128 { 00129 if (getctrl("mrs_bool/regression")->isTrue()) { 00130 mrs_real prediction = in(0, t); //prediction 00131 mrs_real actual = in(1, t); //actual 00132 //cout<<prediction<<'\t'<<actual<<endl; 00133 regCorr.sumClass += actual; 00134 regCorr.sumSqrClass += actual*actual; 00135 regCorr.sumClassPredicted += actual*prediction; 00136 regCorr.sumPredicted += prediction; 00137 regCorr.sumSqrPredicted += prediction*prediction; 00138 regCorr.withClass += 1.0; 00139 out(0,t) = prediction; 00140 out(1,t) = actual; 00141 } else { 00142 //swapped the x and y values-dale 00143 mrs_natural prediction = (mrs_natural)in(0, t); //prediction 00144 mrs_natural actual = (mrs_natural)in(1, t); //actual 00145 00146 confusionMatrix(actual,prediction)++; 00147 //cout << "(y,x) (" << y << ","<< x << ")"<< endl; 00148 00149 out(0,t) = prediction; 00150 out(1,t) = actual; 00151 } 00152 } 00153 } 00154 else if(mode == "report" || done) 00155 { 00156 ostringstream stream; 00157 00158 if (getctrl("mrs_bool/regression")->isTrue()) { 00159 00160 mrs_real varActual = regCorr.sumSqrClass - 00161 (regCorr.sumClass*regCorr.sumClass) / 00162 regCorr.withClass; 00163 mrs_real varPredicted = regCorr.sumSqrPredicted - 00164 (regCorr.sumPredicted*regCorr.sumPredicted) / 00165 regCorr.withClass; 00166 mrs_real varProd = regCorr.sumClassPredicted - 00167 (regCorr.sumClass*regCorr.sumPredicted) / 00168 regCorr.withClass; 00169 00170 mrs_real correlation; 00171 if (varActual * varPredicted <= 0) { 00172 correlation = 0.0; 00173 } else { 00174 correlation = varProd / sqrt(varActual*varPredicted); 00175 } 00176 00177 mrs_real meanAbsoluteError = 0.0; 00178 mrs_real rootMeanSquaredError = 0.0; 00179 mrs_real relativeAbsoluteError = 0.0; 00180 mrs_real rootRelativeSquaredError = 0.0; 00181 mrs_real instances = 0; 00182 stream << "=== ClassificationReport ===" << endl << endl; 00183 stream << "Correlation coefficient" << "\t\t\t" << correlation << "\t" << endl; 00184 stream << "Mean absolute error" << "\t\t\t" << meanAbsoluteError << endl; 00185 stream << "Root mean squared error" << "\t\t\t" << rootMeanSquaredError << endl; 00186 stream << "Relative absolute error" << "\t\t\t" << relativeAbsoluteError << endl; 00187 stream << "Root relative squared error" << "\t\t" << rootRelativeSquaredError << endl; 00188 stream << "Total Number of Instances" << "\t\t" << instances << endl << endl; 00189 } else { 00190 00191 summaryStatistics stats = computeSummaryStatistics(confusionMatrix); 00192 stream << "=== ClassificationReport ===" << endl << endl; 00193 00194 stream << "Correctly Classified Instances" << "\t\t" << stats.correctInstances << "\t"; 00195 stream << (((mrs_real)stats.correctInstances / (mrs_real)stats.instances)*100.0); 00196 stream << " %" << endl; 00197 00198 stream << "Incorrectly Classified Instances" << "\t" << (stats.instances - stats.correctInstances) << "\t"; 00199 stream << (((mrs_real)(stats.instances - stats.correctInstances) / (mrs_real)stats.instances)*100.0); 00200 stream << " %" << endl; 00201 00202 stream << "Kappa statistic" << "\t\t\t\t" << stats.kappa << "\t" << endl; 00203 stream << "Mean absolute error" << "\t\t\t" << stats.meanAbsoluteError << endl; 00204 stream << "Root mean squared error" << "\t\t\t" << stats.rootMeanSquaredError << endl; 00205 stream << "Relative absolute error" << "\t\t\t" << stats.relativeAbsoluteError << endl; 00206 stream << "Root relative squared error" << "\t\t" << stats.rootRelativeSquaredError << endl; 00207 stream << "Total Number of Instances" << "\t\t" << stats.instances << endl << endl; 00208 00209 stream << "=== Confusion Matrix ==="; 00210 stream << endl; stream << endl; 00211 00212 if(!classNames.size()) 00213 classNames = ","; 00214 00215 mrs_string::size_type from = 0; 00216 mrs_string::size_type to = classNames.find(","); 00217 00218 mrs_natural correct = 0; 00219 mrs_natural total = 0; 00220 for (mrs_natural x = 0; x<confusionMatrix.getCols(); x++) 00221 stream << "\t" << (char)(x+'a'); 00222 stream << "\t" << "<-- classified as"; 00223 stream << endl; 00224 00225 for(mrs_natural y = 0; y<confusionMatrix.getRows(); y++) 00226 { 00227 for(mrs_natural x = 0; x<confusionMatrix.getCols(); x++) 00228 { 00229 mrs_natural value = (mrs_natural)confusionMatrix(y, x); 00230 total += value; 00231 if(x == y) 00232 correct += value; 00233 00234 stream << "\t" << value; 00235 }//for x 00236 stream << "\t" << "| "; 00237 if(from < classNames.size()) 00238 { 00239 stream << (char)(y+'a') << " = " << classNames.substr(from, to - from); 00240 from = to + 1; 00241 to = classNames.find(",", from); 00242 if(to == mrs_string::npos) 00243 to = classNames.size(); 00244 }//if 00245 stream << endl; 00246 }//for y 00247 stream << (total > 0 ? correct * 100 / total: 0) << "% classified correctly (" << correct << "/" << total << ")" << endl; 00248 } 00249 00250 MrsLog::mrsMessage(stream); 00251 00252 updControl("mrs_bool/done", true); 00253 }//if done 00254 }//myProcess 00255 00256 summaryStatistics ClassificationReport::computeSummaryStatistics(const realvec& mat) 00257 { 00258 MRSASSERT(mat.getCols()==mat.getRows()); 00259 00260 summaryStatistics stats; 00261 00262 mrs_natural size = mat.getCols(); 00263 00264 vector<mrs_natural>rowSums(size); 00265 for(int ii=0; ii<size; ++ii) rowSums[ii] = 0; 00266 vector<mrs_natural>colSums(size); 00267 for(int ii=0; ii<size; ++ii) colSums[ii] = 0; 00268 mrs_natural diagonalSum = 0; 00269 00270 mrs_natural instanceCount = 0; 00271 for(mrs_natural row=0; row<size; row++) 00272 { 00273 for(mrs_natural col=0; col<size; col++) 00274 { 00275 mrs_natural num = (mrs_natural)mat(row,col); 00276 instanceCount += num; 00277 00278 rowSums[row] += num; 00279 colSums[col] += num; 00280 00281 if(row==col) 00282 diagonalSum += num; 00283 } 00284 } 00285 //printf("row1 sum:%d\n",rowSums[0]); 00286 //printf("row2 sum:%d\n",rowSums[1]); 00287 //printf("col1 sum:%d\n",colSums[0]); 00288 //printf("col2 sum:%d\n",colSums[1]); 00289 //printf("diagonal sum:%d\n",diagonalSum); 00290 //printf("instanceCount:%d\n",instanceCount); 00291 00292 mrs_natural N = instanceCount; 00293 mrs_natural N2 = (N*N); 00294 stats.instances = instanceCount; 00295 stats.correctInstances = diagonalSum; 00296 00297 mrs_natural sum = 0; 00298 for(mrs_natural ii=0; ii<size; ++ii) 00299 { 00300 sum += (rowSums[ii] * colSums[ii]); 00301 } 00302 mrs_real PE = (mrs_real)sum / (mrs_real)N2; 00303 mrs_real PA = (mrs_real)diagonalSum / (mrs_real)N; 00304 stats.kappa = (PA - PE) / (1.0 - PE); 00305 00306 mrs_natural not_diagonal_sum = instanceCount - diagonalSum; 00307 mrs_real MeanAbsoluteError = (mrs_real)not_diagonal_sum / (mrs_real)instanceCount; 00308 //printf("MeanAbsoluteError:%f\n",MeanAbsoluteError); 00309 stats.meanAbsoluteError = MeanAbsoluteError; 00310 00311 mrs_real RootMeanSquaredError = sqrt(MeanAbsoluteError); 00312 //printf("RootMeanSquaredError:%f\n",RootMeanSquaredError); 00313 stats.rootMeanSquaredError = RootMeanSquaredError; 00314 00315 mrs_real RelativeAbsoluteError = (MeanAbsoluteError / 0.5) * 100.0; 00316 //printf("RelativeAbsoluteError:%f%%\n",RelativeAbsoluteError); 00317 stats.relativeAbsoluteError = RelativeAbsoluteError; 00318 00319 mrs_real RootRelativeSquaredError = (RootMeanSquaredError / (0.5)) * 100.0; 00320 //printf("RootRelativeSquaredError:%f%%\n",RootRelativeSquaredError); 00321 stats.rootRelativeSquaredError = RootRelativeSquaredError; 00322 00323 return stats; 00324 }//computeSummaryStatistics