Marsyas
0.6.0-alpha
|
00001 /* 00002 ** Copyright (C) 1998-2006 George Tzanetakis <gtzan@cs.princeton.edu> 00003 ** 00004 ** This program is free software; you can redistribute it and/or modify 00005 ** it under the terms of the GNU General Public License as published by 00006 ** the Free Software Foundation; either version 2 of the License, or 00007 ** (at your option) any later version. 00008 ** 00009 ** This program is distributed in the hope that it will be useful, 00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of 00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00012 ** GNU General Public License for more details. 00013 ** 00014 ** You should have received a copy of the GNU General Public License 00015 ** along with this program; if not, write to the Free Software 00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00017 */ 00018 00019 #include "KNNClassifier.h" 00020 #include "../common_source.h" 00021 00022 using std::ostringstream; 00023 using namespace Marsyas; 00024 00025 KNNClassifier::KNNClassifier(mrs_string name):MarSystem("KNNClassifier",name) 00026 { 00027 prev_mode_ = "predict"; 00028 addControls(); 00029 } 00030 00031 00032 KNNClassifier::~KNNClassifier() 00033 { 00034 } 00035 00036 00037 MarSystem* 00038 KNNClassifier::clone() const 00039 { 00040 return new KNNClassifier(*this); 00041 } 00042 00043 void 00044 KNNClassifier::addControls() 00045 { 00046 addctrl("mrs_string/mode", "train"); 00047 addctrl("mrs_natural/nLabels", 1); 00048 setctrlState("mrs_natural/nLabels", true); 00049 trainSet_.create((mrs_natural)1,(mrs_natural)1); 00050 addctrl("mrs_natural/grow", 1); 00051 addctrl("mrs_natural/k", 1); 00052 k_ = 1; 00053 addctrl("mrs_realvec/trainSet", trainSet_); 00054 addctrl("mrs_natural/nPoints", 0); 00055 addctrl("mrs_bool/done", false); 00056 addctrl("mrs_natural/nPredictions", 1); 00057 setctrlState("mrs_natural/nPredictions", true); 00058 setctrlState("mrs_bool/done", true); 00059 } 00060 00061 00062 void 00063 KNNClassifier::myUpdate(MarControlPtr sender) 00064 { 00065 (void) sender; //suppress warning of unused parameter(s) 00066 MRSDIAG("KNNClassifier.cpp - KNNClassifier:myUpdate"); 00067 00068 nPredictions_ = getctrl("mrs_natural/nPredictions")->to<mrs_natural>(); 00069 setctrl("mrs_natural/onSamples", getctrl("mrs_natural/inSamples")); 00070 setctrl("mrs_natural/onObservations", (mrs_natural) nPredictions_ + 1); 00071 setctrl("mrs_real/osrate", getctrl("mrs_real/israte")); 00072 00073 inObservations_ = getctrl("mrs_natural/inObservations")->to<mrs_natural>(); 00074 grow_ = getctrl("mrs_natural/grow")->to<mrs_natural>(); 00075 nPoints_ = getctrl("mrs_natural/nPoints")->to<mrs_natural>(); 00076 k_ = getctrl("mrs_natural/k")->to<mrs_natural>(); 00077 mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>(); 00078 00079 if (mode == "train") 00080 { 00081 if (inObservations_ != trainSet_.getCols()) 00082 { 00083 trainSet_.stretch(1, getctrl("mrs_natural/inObservations")->to<mrs_natural>()); 00084 setctrl("mrs_realvec/trainSet", trainSet_); 00085 } 00086 } 00087 00088 00089 00090 if (mode == "predict") 00091 { 00092 trainSet_.create(getctrl("mrs_realvec/trainSet")->to<mrs_realvec>().getRows(), 00093 getctrl("mrs_realvec/trainSet")->to<mrs_realvec>().getCols()); 00094 trainSet_ = getctrl("mrs_realvec/trainSet")->to<mrs_realvec>(); 00095 } 00096 00097 00098 if (getctrl("mrs_bool/done")->to<mrs_bool>()) 00099 { 00100 setctrl("mrs_bool/done", false); 00101 setctrl("mrs_realvec/trainSet", trainSet_); 00102 } 00103 } 00104 00105 00106 void 00107 KNNClassifier::myProcess(realvec& in, realvec& out) 00108 { 00109 //checkFlow(in,out); 00110 00111 mrs_real v; 00112 mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>(); 00113 mrs_real label; 00114 mrs_natural nlabels = getctrl("mrs_natural/nLabels")->to<mrs_natural>(); 00115 mrs_natural prediction; 00116 int x, y; 00117 int p; 00118 mrs_natural o,t; 00119 00120 if ((prev_mode_ == "predict")&&(mode == "train")) 00121 { 00122 00123 // reset 00124 for (p = 0; p < nPoints_; p++) 00125 { 00126 for (o=0; o < inObservations_-1; o++) 00127 trainSet_(p,o) = 0.0; 00128 } 00129 nPoints_ = 0; 00130 } 00131 00132 00133 if (mode == "train") 00134 { 00135 for (t = 0; t < inSamples_; t++) 00136 { 00137 label = in(inObservations_-1, t); 00138 00139 if (nPoints_ == grow_) 00140 { 00141 00142 // expontentially stretch trainSet_ realvec 00143 grow_ = 2*grow_; 00144 trainSet_.stretch(grow_, inObservations_); 00145 updControl("mrs_natural/grow", grow_); 00146 } 00147 00148 for (o=0; o < inObservations_; o++) 00149 { 00150 // store all observations for instance t 00151 trainSet_(nPoints_,o) = in(o,t); 00152 } 00153 out(0,t) = label; 00154 out(1,t) = label; 00155 00156 // update number of points 00157 nPoints_= nPoints_ +1; 00158 updControl("mrs_natural/nPoints", nPoints_); 00159 } 00160 } 00161 00162 00163 00164 if (mode == "predict") 00165 { 00166 00167 // Calculate Distances for each Point 00168 for (t = 0; t < inSamples_; t++) 00169 { 00170 label = in(inObservations_-1, t); 00171 00172 realvec Distance; 00173 Distance.create(nPoints_); 00174 00175 realvec kMin; 00176 kMin.create(k_,2); 00177 00178 realvec kSmallest; 00179 kSmallest.create(nlabels); 00180 00181 for (p = 0; p < nPoints_; p++) 00182 { 00183 mrs_real sum = 0; 00184 for (o=0; o < inObservations_-1; o++) 00185 { 00186 v = in(o,t); 00187 v = (v - trainSet_(p,o)); 00188 sum += v*v; 00189 } 00190 Distance(p) = sum; 00191 } 00192 00193 00194 // Find k smallest distances 00195 00196 // initialize kMin RealVec 00197 mrs_real kmaxV = Distance(0); // max value initialization 00198 int kmaxI = 0; // max Index initialization 00199 00200 for (x=0; x < k_; x++) 00201 { 00202 kMin(x, 0) = Distance(0); // Distance Value 00203 kMin(x, 1) = 0; // Label 00204 } 00205 00206 00207 for (y=0; y < nPoints_; y++) 00208 { 00209 00210 if (Distance(y) < kmaxV) 00211 { 00212 mrs_real kmaxV_t = 0.0; 00213 int kmaxI_t = 1; 00214 00215 kMin(kmaxI,0) = Distance(y); // value 00216 kMin(kmaxI,1) = trainSet_(y, inObservations_-1); // label 00217 00218 // Now find Max Value in kMin RealVec 00219 for (x=0; x < k_; x++) 00220 { 00221 kmaxV_t = kMin(0,0); 00222 kmaxI_t = 0; 00223 if (kMin(x,0) > kmaxV_t) 00224 { 00225 kmaxV_t = kMin(x,0); 00226 kmaxI_t = x; 00227 } 00228 } 00229 kmaxV = kmaxV_t; 00230 kmaxI = kmaxI_t; 00231 } 00232 } 00233 00234 00235 // Now find biggest number of values for label in kMin 00236 for (x=0; x< k_; x++) 00237 { 00238 // based on the label of the min point, incement kSmallest 00239 kSmallest((int)kMin(x, 1))++; 00240 } 00241 00242 00243 // find largest value in kSmallest 00244 mrs_real max = kSmallest(0); 00245 int maxI = 0; 00246 for (x=0; x<nlabels; x++) 00247 { 00248 if (kSmallest(x) > max) 00249 { 00250 max = kSmallest(x); 00251 maxI = x; 00252 } 00253 } 00254 prediction = maxI; 00255 out(0,t) = (mrs_real)prediction; 00256 if (nPredictions_ >= 1) 00257 for (x=0; x < nPredictions_; x++) 00258 out(x,t) = kMin(x,1); 00259 00260 out(onObservations_-1,t) = label; 00261 } 00262 } 00263 00264 00265 00266 prev_mode_ = mode; 00267 } 00268 00269 00270 00271 00272 00273 00274 00275 00276 00277