Marsyas
0.6.0-alpha
|
00001 /* 00002 ** Copyright (C) 1998-2010 George Tzanetakis <gtzan@cs.uvic.ca> 00003 ** 00004 ** This program is free software; you can redistribute it and/or modify 00005 ** it under the terms of the GNU General Public License as published by 00006 ** the Free Software Foundation; either version 2 of the License, or 00007 ** (at your option) any later version. 00008 ** 00009 ** This program is distributed in the hope that it will be useful, 00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of 00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00012 ** GNU General Public License for more details. 00013 ** 00014 ** You should have received a copy of the GNU General Public License 00015 ** along with this program; if not, write to the Free Software 00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00017 */ 00018 00024 #include "OneRClassifier.h" 00025 #include "../common_source.h" 00026 #include <cstddef> 00027 00028 using std::ostringstream; 00029 using std::cout; 00030 using std::endl; 00031 using std::vector; 00032 using std::size_t; 00033 00034 using namespace Marsyas; 00035 00036 OneRClassifier::OneRClassifier(const mrs_string name) : MarSystem("OneRClassifier", name) 00037 { 00038 addControls(); 00039 rule_ = NULL; 00040 lastModePredict_ = false; 00041 } 00042 00043 //Only thing needing destroying is the current rule. 00044 OneRClassifier::~OneRClassifier() 00045 { 00046 if(rule_ != NULL) 00047 delete rule_; 00048 } 00049 00050 MarSystem *OneRClassifier::clone() const 00051 { 00052 return new OneRClassifier(*this); 00053 } 00054 00055 void 00056 OneRClassifier::addControls() 00057 { 00058 addctrl("mrs_string/mode", "train"); 00059 addctrl("mrs_natural/nClasses", 1); 00060 setctrlState("mrs_natural/nClasses", true); 00061 } 00062 00063 void 00064 OneRClassifier::myUpdate(MarControlPtr sender) 00065 { 00066 (void) sender; //suppress warning of unused parameter(s) 00067 MRSDIAG("OneRClassifier.cpp - OneRClassifier:myUpdate"); 00068 ctrl_onSamples_->setValue(ctrl_inSamples_, NOUPDATE); 00069 setctrl("mrs_natural/onObservations", 2); 00070 ctrl_osrate_->setValue(ctrl_israte_->to<mrs_real>()); 00071 ctrl_onObsNames_->setValue("OneRClassifier_" 00072 + ctrl_inObsNames_->to<mrs_string>() , NOUPDATE); 00073 } 00074 00075 void 00076 OneRClassifier::myProcess(realvec& in, realvec& out) 00077 { 00078 cout << "OneRClassifier::myProcess" << endl; 00079 cout << "in.getCols() = " << in.getCols() << endl; 00080 cout << "in.getRows() = " << in.getRows() << endl; 00081 //get the current mode, either train of predict mode 00082 bool trainMode = (getctrl("mrs_string/mode")->to<mrs_string>() == "train"); 00083 row_.stretch(in.getRows()); 00084 if (trainMode) 00085 { 00086 if(lastModePredict_ || instances_.getCols()<=0) 00087 { 00088 mrs_natural nAttributes = getctrl("mrs_natural/inObservations")->to<mrs_natural>(); 00089 cout << "nAttributes = " << nAttributes << endl; 00090 instances_.Create(nAttributes); 00091 } 00092 00093 lastModePredict_ = false; 00094 00095 //get the incoming data and append it to the data table 00096 for (mrs_natural ii=0; ii< inSamples_; ++ii) 00097 { 00098 mrs_real label = in(inObservations_-1, ii); 00099 instances_.Append(in); 00100 out(0,ii) = label; 00101 out(1,ii) = label; 00102 }//for t 00103 }//if 00104 else 00105 { //predict mode 00106 00107 cout << "OneRClassifier::predict" << endl; 00108 if(!lastModePredict_) 00109 { 00110 //get the number of class labels and build the classifier 00111 mrs_natural nAttributes = getctrl("mrs_natural/inObservations")->to<mrs_natural>(); 00112 cout << "BUILD nAttributes = " << nAttributes << endl; 00113 Build(nAttributes); 00114 }//if 00115 lastModePredict_ = true; 00116 cout << "After lastModePredict" << endl; 00117 00118 00119 //foreach row of predict data, extract the actual class, then call the 00120 //classifier predict method. Output the actual and predicted classes. 00121 for (mrs_natural ii=0; ii<inSamples_; ++ii) 00122 { 00123 //extract the actual class 00124 mrs_natural label = (mrs_natural)in(inObservations_-1, ii); 00125 00126 //invoke the classifier predict method to predict the class 00127 in.getCol(ii,row_); 00128 mrs_natural prediction = Predict(row_); 00129 cout << "PREDICTION = " << prediction << endl; 00130 cout << "row_ " << row_ << endl; 00131 00132 //and output actual/predicted classes 00133 out(0,ii) = (mrs_real)prediction; 00134 out(1,ii) = (mrs_real)label; 00135 }//for t 00136 }//if 00137 00138 }//myProcess 00139 00140 //Create a new rule for this attribute. 00141 //Sorts the data table on this attribute and executes the OneR algorithm. 00142 OneRClassifier::OneRRule *OneRClassifier::newRule(mrs_natural attr, mrs_natural nClasses) 00143 { 00144 //create the counting variables 00145 vector<mrs_natural> classifications(instances_.size()); 00146 vector<mrs_real> breakpoints(instances_.size()); 00147 vector<mrs_natural> counts(nClasses); 00148 00149 //set correct count to 0 00150 mrs_natural correct = 0; 00151 mrs_natural lastInstance = (mrs_natural) instances_.size(); 00152 00153 //Sort the data table for this attribute 00154 instances_.Sort(attr); 00155 00156 mrs_natural ii = 0; 00157 mrs_natural cl = 0; //index of next bucket to create 00158 mrs_natural it = 0; 00159 00160 //scan thru all rows in table 00161 while(ii < lastInstance) 00162 { 00163 //zero the current counts 00164 for(mrs_natural jj=0; jj<(mrs_natural)counts.size(); jj++) counts[jj]=0; 00165 do 00166 { //fill it until is has enough of the majority class 00167 it = instances_.GetClass(++ii); 00168 counts[it]++; 00169 } while(counts[it] < minBucketSize_ && ii < lastInstance); 00170 00171 //while class remains the same, keep on filling 00172 while(ii < lastInstance && instances_.GetClass(ii) == it) 00173 { 00174 counts[it]++; 00175 ++ii; 00176 }//while 00177 00178 //keep on while attr value is the same 00179 while(ii < lastInstance && instances_.at(ii-1)->at(attr) == instances_.at(ii)->at(attr)) 00180 { 00181 mrs_natural index = instances_.GetClass(ii++); 00182 counts[index]++; 00183 }//while 00184 00185 for(mrs_natural jj=0; jj<nClasses; jj++) 00186 { 00187 if(counts[jj] > counts[it]) 00188 { 00189 it = jj; 00190 }//if 00191 }//for jj 00192 00193 if(cl > 0) 00194 { //can we coalesce with previous class? 00195 if(counts[classifications[cl-1]] == counts[it]) 00196 it = classifications[cl-1]; 00197 00198 if(it == classifications[cl-1]) 00199 cl--; 00200 }//if 00201 00202 correct += counts[it]; 00203 classifications[cl] = it; 00204 00205 if(ii < lastInstance) 00206 breakpoints[cl] = (((instances_.at(ii-1)->at(attr) + instances_.at(ii)->at(attr)) / 2.0)); 00207 00208 cl++; 00209 }//while 00210 00211 //create a new rule with cl branches 00212 OneRRule *rule = new OneRRule(attr, cl, correct); 00213 for(mrs_natural vv=0; vv<cl; vv++) 00214 { 00215 rule->getClassifications()[vv] = classifications[vv]; 00216 if(vv < (cl-1)) 00217 rule->getBreakpoints()[vv] = breakpoints[vv]; 00218 00219 }//for vv 00220 00221 return rule; 00222 }//newRule 00223 00224 //Build the classifier from the data table 00225 void 00226 OneRClassifier::Build(mrs_natural nClasses) 00227 { 00228 //make sure any previous rule is out 00229 if(rule_!=NULL) 00230 delete rule_; 00231 rule_ = NULL; 00232 00233 //scan through all the attributes(columns) of the table 00234 for(mrs_natural enu = 0; enu < instances_.getCols()-1; enu++) 00235 { 00236 //construct a new rule for this attribute 00237 OneRClassifier::OneRRule *r = newRule(enu, nClasses); 00238 00239 //if a current rule does not exist or this new rule is better, replace old rule 00240 if(!rule_ || r->getCorrect() > rule_->getCorrect()) 00241 { 00242 if(rule_!=NULL) 00243 delete rule_; 00244 00245 rule_ = r; 00246 }//if 00247 }//for enu 00248 }//Build 00249 00250 //Predict a class given a row of attribute data 00251 mrs_natural 00252 OneRClassifier::Predict(const realvec& in) 00253 { 00254 mrs_natural vv = 0; 00255 mrs_real instValue = in(rule_->getAttr()); 00256 00257 //find the breakpoint whose value exceeds the attribute value. 00258 while(vv < rule_->getnBreaks()-1 && instValue >= rule_->getBreakpoints()[vv]) 00259 { 00260 vv++; 00261 }//while 00262 00263 //return the class for this prediction. 00264 return rule_->getClassifications()[vv]; 00265 }//Predict