Marsyas
0.6.0-alpha
|
00001 /* 00002 ** Copyright (C) 1998-2006 George Tzanetakis <gtzan@cs.uvic.ca> 00003 ** 00004 ** This program is free software; you can redistribute it and/or modify 00005 ** it under the terms of the GNU General Public License as published by 00006 ** the Free Software Foundation; either version 2 of the License, or 00007 ** (at your option) any later version. 00008 ** 00009 ** This program is distributed in the hope that it will be useful, 00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of 00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00012 ** GNU General Public License for more details. 00013 ** 00014 ** You should have received a copy of the GNU General Public License 00015 ** along with this program; if not, write to the Free Software 00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00017 */ 00018 00019 #include "SMO.h" 00020 #include "../common_source.h" 00021 00022 using namespace std; 00023 using namespace Marsyas; 00024 00025 SMO::SMO(mrs_string name):MarSystem("SMO",name) 00026 { 00027 //type_ = "SMO"; 00028 //name_ = name; 00029 00030 addControls(); 00031 } 00032 00033 00034 SMO::~SMO() 00035 { 00036 } 00037 00038 00039 MarSystem* 00040 SMO::clone() const 00041 { 00042 return new SMO(*this); 00043 } 00044 00045 void 00046 SMO::addControls() 00047 { 00048 addctrl("mrs_string/mode", "train", modePtr_); 00049 addctrl("mrs_natural/nLabels", 1, nlabelsPtr_); 00050 setctrlState("mrs_natural/nLabels", true); 00051 weights_.create(1); 00052 addctrl("mrs_realvec/weights", weights_, weightsPtr_); 00053 addctrl("mrs_bool/done", false, donePtr_); 00054 setctrlState("mrs_bool/done", true); 00055 00056 } 00057 00058 00059 void 00060 SMO::myUpdate(MarControlPtr sender) 00061 { 00062 (void) sender; //suppress warning of unused parameter(s) 00063 MRSDIAG("SMO.cpp - SMO:myUpdate"); 00064 00065 ctrl_onSamples_->setValue(ctrl_inSamples_, NOUPDATE); 00066 ctrl_onObservations_->setValue(2, NOUPDATE); 00067 ctrl_osrate_->setValue(ctrl_israte_, NOUPDATE); 00068 00069 00070 mrs_natural inObservations = ctrl_inObservations_->to<mrs_natural>(); 00071 // FIXME This variable is being defined but (possibly) not used. 00072 // mrs_natural nlabels = getctrl("mrs_natural/nLabels")->to<mrs_natural>(); 00073 00074 00075 mrs_natural mcols = (getctrl("mrs_realvec/weights")->to<mrs_realvec>()).getCols(); 00076 mrs_natural ncols = weights_.getCols(); 00077 00078 00079 00080 if (inObservations != mcols) 00081 { 00082 weights_.create(inObservations); 00083 updControl("mrs_realvec/weights", weights_); 00084 } 00085 00086 00087 if (inObservations != ncols) 00088 { 00089 weights_.create(inObservations); 00090 } 00091 00092 mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>(); 00093 if (mode == "predict") 00094 { 00095 weights_ = getctrl("mrs_realvec/weights")->to<mrs_realvec>(); 00096 } 00097 } 00098 00099 00100 void 00101 SMO::myProcess(realvec& in, realvec& out) 00102 { 00103 mrs_natural t,o; 00104 mrs_string mode = modePtr_->to<mrs_string>(); 00105 mrs_natural prediction = 0; 00106 mrs_real label; 00107 mrs_real thres; 00108 00109 if (mode == "train") 00110 { 00111 for (t = 0; t < inSamples_; t++) 00112 { 00113 label = in(inObservations_-1, t); 00114 out(0,t) = (mrs_real) label; 00115 out(1,t) = (mrs_real) label; 00116 } 00117 00118 weights_(0) = 0.4122; 00119 weights_(1) = -4.599; 00120 weights_(2) = -14.0203; 00121 weights_(3) = -6.2503; 00122 weights_(4) = -0.8447; 00123 weights_(5) = -2.0753; 00124 weights_(6) = 0.9826; 00125 weights_(7) = -4.1159; 00126 weights_(8) = -1.6985; 00127 weights_(9) = -1.1419; 00128 weights_(10) = 3.5605; 00129 weights_(11) = 1.9987; 00130 weights_(12) = 1.3641; 00131 weights_(13) = -6.412; 00132 weights_(14) = 7.7704; 00133 weights_(15) = 0.6565; 00134 weights_(16) = -0.3749; 00135 weights_(17) = -0.3507; 00136 weights_(18) = 2.5022; 00137 weights_(19) = 0.8658; 00138 weights_(20) = -2.6361; 00139 weights_(21) = 3.9029; 00140 weights_(22) = 0.4051; 00141 weights_(23) = -2.8185; 00142 weights_(24) = 2.4864; 00143 weights_(25) = -1.8054; 00144 weights_(26) = -2.7731; 00145 weights_(27) = 2.2423; 00146 weights_(28) = -2.1786; 00147 weights_(29) = -1.0741; 00148 weights_(30) = -0.5614; 00149 weights_(31) = -3.5967; 00150 weights_(32) = 7.7832; 00151 00152 00153 00154 00155 /* weights_(0) = -1.252; 00156 weights_(1) = 6.796; 00157 weights_(2) = -3.9419; 00158 weights_(3) = 2.3463; 00159 weights_(4) = -3.6959; 00160 weights_(5) = -4.5353; 00161 weights_(6) = -3.5343; 00162 weights_(7) = 0.0114; 00163 weights_(8) = -5.0538; 00164 weights_(9) = -2.0138; 00165 weights_(10) = -1.8438; 00166 weights_(11) = 3.16; 00167 weights_(12) = 2.1316 ; 00168 weights_(13) = 1.6142 ; 00169 weights_(14) = -4.4765 ; 00170 weights_(15) = 7.5799 ; 00171 weights_(16) = 0.9734 ; 00172 weights_(17) = 0.5425 ; 00173 weights_(18) = -0.9018 ; 00174 weights_(19) = -0.1296 ; 00175 weights_(20) = -1.1898 ; 00176 weights_(21) = 2.7628 ; 00177 weights_(22) = -2.7207 ; 00178 weights_(23) = 3.5209 ; 00179 weights_(24) = 0.8888 ; 00180 weights_(25) = -3.8638 ; 00181 weights_(26) = 2.8184 ; 00182 weights_(27) = -2.656 ; 00183 weights_(28) = -2.7921 ; 00184 weights_(29) = 1.8606 ; 00185 weights_(30) = -2.5113 ; 00186 weights_(31) = -1.3537 ; 00187 weights_(32) = -1.1434 ; 00188 weights_(33) = -4.1955 ; 00189 weights_(34) = 3.9084; 00190 */ 00191 00192 /* weights_(0) = 0.9822; 00193 weights_(1) = 0.0614; 00194 weights_(2) = 1.74; 00195 weights_(3) = -1.0346; 00196 weights_(4) = 0.8395; 00197 weights_(5) = 1.2181; 00198 weights_(6) = 0.2218; 00199 weights_(7) = 1.092; 00200 weights_(8) = 0.2186; 00201 weights_(9) = -0.1633; 00202 weights_(10) = - 1.4334; 00203 */ 00204 00205 00206 00207 /*ights_(0) = 6.3694; 00208 weights_(1) = 1.3558; 00209 weights_(2) = -1.3705; 00210 weights_(3) = -7.7337; 00211 weights_(4) = 5.7892; 00212 weights_(5) = 2.2434; 00213 weights_(6) = 4.9089; 00214 weights_(7) = 2.2042; 00215 weights_(8) = -2.0502; 00216 */ 00217 00218 00219 } 00220 00221 00222 00223 if (mode == "predict") 00224 { 00225 for (t = 0; t < inSamples_; t++) 00226 { 00227 label = in(inObservations_-1, t); 00228 thres = 0.0; 00229 for (o = 0; o < inObservations_-1; o++) 00230 { 00231 00232 thres += (weights_(o) * in(o,t)); 00233 } 00234 thres += weights_(inObservations_-1); 00235 00236 if (thres <= 0 ) 00237 { 00238 prediction = 0; 00239 } 00240 00241 else 00242 { 00243 prediction = 1; 00244 } 00245 00246 out(0,t) = (mrs_real) prediction; 00247 out(1,t) = (mrs_real) label; 00248 } 00249 00250 00251 } 00252 00253 if (donePtr_->to<mrs_bool>()) 00254 { 00255 updControl(weightsPtr_, weights_); 00256 } 00257 00258 00259 } 00260 00261 00262 00263 00264 00265 00266 00267 00268 00269 00270 00271