Marsyas  0.6.0-alpha
/usr/src/RPM/BUILD/marsyas-0.6.0/src/marsyas/marsystems/KNNClassifier.cpp
Go to the documentation of this file.
00001 /*
00002 ** Copyright (C) 1998-2006 George Tzanetakis <gtzan@cs.princeton.edu>
00003 **
00004 ** This program is free software; you can redistribute it and/or modify
00005 ** it under the terms of the GNU General Public License as published by
00006 ** the Free Software Foundation; either version 2 of the License, or
00007 ** (at your option) any later version.
00008 **
00009 ** This program is distributed in the hope that it will be useful,
00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00012 ** GNU General Public License for more details.
00013 **
00014 ** You should have received a copy of the GNU General Public License
00015 ** along with this program; if not, write to the Free Software
00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00017 */
00018 
#include "KNNClassifier.h"
#include "../common_source.h"

#include <algorithm>
#include <vector>
00021 
00022 using std::ostringstream;
00023 using namespace Marsyas;
00024 
00025 KNNClassifier::KNNClassifier(mrs_string name):MarSystem("KNNClassifier",name)
00026 {
00027   prev_mode_ = "predict";
00028   addControls();
00029 }
00030 
00031 
00032 KNNClassifier::~KNNClassifier()
00033 {
00034 }
00035 
00036 
00037 MarSystem*
00038 KNNClassifier::clone() const
00039 {
00040   return new KNNClassifier(*this);
00041 }
00042 
00043 void
00044 KNNClassifier::addControls()
00045 {
00046   addctrl("mrs_string/mode", "train");
00047   addctrl("mrs_natural/nLabels", 1);
00048   setctrlState("mrs_natural/nLabels", true);
00049   trainSet_.create((mrs_natural)1,(mrs_natural)1);
00050   addctrl("mrs_natural/grow", 1);
00051   addctrl("mrs_natural/k", 1);
00052   k_ = 1;
00053   addctrl("mrs_realvec/trainSet", trainSet_);
00054   addctrl("mrs_natural/nPoints", 0);
00055   addctrl("mrs_bool/done", false);
00056   addctrl("mrs_natural/nPredictions", 1);
00057   setctrlState("mrs_natural/nPredictions", true);
00058   setctrlState("mrs_bool/done", true);
00059 }
00060 
00061 
00062 void
00063 KNNClassifier::myUpdate(MarControlPtr sender)
00064 {
00065   (void) sender;  //suppress warning of unused parameter(s)
00066   MRSDIAG("KNNClassifier.cpp - KNNClassifier:myUpdate");
00067 
00068   nPredictions_ = getctrl("mrs_natural/nPredictions")->to<mrs_natural>();
00069   setctrl("mrs_natural/onSamples", getctrl("mrs_natural/inSamples"));
00070   setctrl("mrs_natural/onObservations", (mrs_natural) nPredictions_ + 1);
00071   setctrl("mrs_real/osrate", getctrl("mrs_real/israte"));
00072 
00073   inObservations_ = getctrl("mrs_natural/inObservations")->to<mrs_natural>();
00074   grow_ = getctrl("mrs_natural/grow")->to<mrs_natural>();
00075   nPoints_ = getctrl("mrs_natural/nPoints")->to<mrs_natural>();
00076   k_ = getctrl("mrs_natural/k")->to<mrs_natural>();
00077   mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>();
00078 
00079   if (mode == "train")
00080   {
00081     if (inObservations_ != trainSet_.getCols())
00082     {
00083       trainSet_.stretch(1, getctrl("mrs_natural/inObservations")->to<mrs_natural>());
00084       setctrl("mrs_realvec/trainSet", trainSet_);
00085     }
00086   }
00087 
00088 
00089 
00090   if (mode == "predict")
00091   {
00092     trainSet_.create(getctrl("mrs_realvec/trainSet")->to<mrs_realvec>().getRows(),
00093                      getctrl("mrs_realvec/trainSet")->to<mrs_realvec>().getCols());
00094     trainSet_ = getctrl("mrs_realvec/trainSet")->to<mrs_realvec>();
00095   }
00096 
00097 
00098   if (getctrl("mrs_bool/done")->to<mrs_bool>())
00099   {
00100     setctrl("mrs_bool/done", false);
00101     setctrl("mrs_realvec/trainSet", trainSet_);
00102   }
00103 }
00104 
00105 
00106 void
00107 KNNClassifier::myProcess(realvec& in, realvec& out)
00108 {
00109   //checkFlow(in,out);
00110 
00111   mrs_real v;
00112   mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>();
00113   mrs_real label;
00114   mrs_natural nlabels = getctrl("mrs_natural/nLabels")->to<mrs_natural>();
00115   mrs_natural prediction;
00116   int x, y;
00117   int p;
00118   mrs_natural o,t;
00119 
00120   if ((prev_mode_ == "predict")&&(mode == "train"))
00121   {
00122 
00123     // reset
00124     for (p = 0; p < nPoints_; p++)
00125     {
00126       for (o=0; o < inObservations_-1; o++)
00127         trainSet_(p,o) = 0.0;
00128     }
00129     nPoints_ = 0;
00130   }
00131 
00132 
00133   if (mode == "train")
00134   {
00135     for (t = 0; t < inSamples_; t++)
00136     {
00137       label = in(inObservations_-1, t);
00138 
00139       if (nPoints_ == grow_)
00140       {
00141 
00142         // expontentially stretch trainSet_ realvec
00143         grow_ = 2*grow_;
00144         trainSet_.stretch(grow_, inObservations_);
00145         updControl("mrs_natural/grow", grow_);
00146       }
00147 
00148       for (o=0; o < inObservations_; o++)
00149       {
00150         // store all observations for instance t
00151         trainSet_(nPoints_,o) = in(o,t);
00152       }
00153       out(0,t) = label;
00154       out(1,t) = label;
00155 
00156       // update number of points
00157       nPoints_= nPoints_ +1;
00158       updControl("mrs_natural/nPoints", nPoints_);
00159     }
00160   }
00161 
00162 
00163 
00164   if (mode == "predict")
00165   {
00166 
00167     // Calculate Distances for each Point
00168     for (t = 0; t < inSamples_; t++)
00169     {
00170       label = in(inObservations_-1, t);
00171 
00172       realvec Distance;
00173       Distance.create(nPoints_);
00174 
00175       realvec kMin;
00176       kMin.create(k_,2);
00177 
00178       realvec kSmallest;
00179       kSmallest.create(nlabels);
00180 
00181       for (p = 0; p < nPoints_; p++)
00182       {
00183         mrs_real sum = 0;
00184         for (o=0; o < inObservations_-1; o++)
00185         {
00186           v = in(o,t);
00187           v = (v - trainSet_(p,o));
00188           sum += v*v;
00189         }
00190         Distance(p) = sum;
00191       }
00192 
00193 
00194       // Find k smallest distances
00195 
00196       // initialize kMin RealVec
00197       mrs_real kmaxV = Distance(0); // max value initialization
00198       int kmaxI = 0; // max Index initialization
00199 
00200       for (x=0; x < k_; x++)
00201       {
00202         kMin(x, 0) = Distance(0); // Distance Value
00203         kMin(x, 1) = 0; // Label
00204       }
00205 
00206 
00207       for (y=0; y < nPoints_; y++)
00208       {
00209 
00210         if (Distance(y) < kmaxV)
00211         {
00212           mrs_real kmaxV_t = 0.0;
00213           int kmaxI_t = 1;
00214 
00215           kMin(kmaxI,0) = Distance(y); // value
00216           kMin(kmaxI,1) = trainSet_(y, inObservations_-1); // label
00217 
00218           // Now find Max Value in kMin RealVec
00219           for (x=0; x < k_; x++)
00220           {
00221             kmaxV_t = kMin(0,0);
00222             kmaxI_t = 0;
00223             if (kMin(x,0) > kmaxV_t)
00224             {
00225               kmaxV_t = kMin(x,0);
00226               kmaxI_t = x;
00227             }
00228           }
00229           kmaxV = kmaxV_t;
00230           kmaxI = kmaxI_t;
00231         }
00232       }
00233 
00234 
00235       // Now find biggest number of values for label in kMin
00236       for (x=0; x< k_; x++)
00237       {
00238         // based on the label of the min point, incement kSmallest
00239         kSmallest((int)kMin(x, 1))++;
00240       }
00241 
00242 
00243       // find largest value in kSmallest
00244       mrs_real max = kSmallest(0);
00245       int maxI = 0;
00246       for (x=0; x<nlabels; x++)
00247       {
00248         if (kSmallest(x) > max)
00249         {
00250           max = kSmallest(x);
00251           maxI = x;
00252         }
00253       }
00254       prediction = maxI;
00255       out(0,t) = (mrs_real)prediction;
00256       if (nPredictions_ >= 1)
00257         for (x=0; x < nPredictions_; x++)
00258           out(x,t) = kMin(x,1);
00259 
00260       out(onObservations_-1,t) = label;
00261     }
00262   }
00263 
00264 
00265 
00266   prev_mode_ = mode;
00267 }
00268 
00269 
00270 
00271 
00272 
00273 
00274 
00275 
00276 
00277