Marsyas
0.6.0-alpha
|
00001 /* 00002 ** Copyright (C) 1998-2006 George Tzanetakis <gtzan@cs.uvic.ca> 00003 ** 00004 ** This program is free software; you can redistribute it and/or modify 00005 ** it under the terms of the GNU General Public License as published by 00006 ** the Free Software Foundation; either version 2 of the License, or 00007 ** (at your option) any later version. 00008 ** 00009 ** This program is distributed in the hope that it will be useful, 00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of 00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00012 ** GNU General Public License for more details. 00013 ** 00014 ** You should have received a copy of the GNU General Public License 00015 ** along with this program; if not, write to the Free Software 00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00017 */ 00018 00019 #include "../common_source.h" 00020 #include "SVMClassifier.h" 00021 #define Malloc(type,n) (type *)malloc((n)*sizeof(type)) 00022 00023 using namespace std; 00024 using namespace Marsyas; 00025 00026 SVMClassifier::SVMClassifier(mrs_string name) : 00027 MarSystem("SVMClassifier", name) { 00028 training_ = true; 00029 was_training_ = false; 00030 trained_ = false; 00031 kernel_ = LINEAR; 00032 svm_ = C_SVC; 00033 svm_model_ = NULL; 00034 num_nodes = 0; 00035 svm_prob_.y = NULL; 00036 svm_prob_.x = NULL; 00037 00038 addControls(); 00039 } 00040 00041 SVMClassifier::SVMClassifier(const SVMClassifier& a) : 00042 MarSystem(a) { 00043 training_ = true; 00044 was_training_ = false; 00045 trained_ = false; 00046 kernel_ = LINEAR; 00047 svm_ = C_SVC; 00048 svm_model_ = NULL; 00049 num_nodes = 0; 00050 svm_prob_.y = NULL; 00051 svm_prob_.x = NULL; 00052 00053 ctrl_nClasses_ = getctrl("mrs_natural/nClasses"); 00054 ctrl_sv_coef_ = getctrl("mrs_realvec/sv_coef"); 00055 ctrl_sv_indices_ = getctrl("mrs_realvec/sv_indices"); 00056 ctrl_SV_ = getctrl("mrs_realvec/SV"); 00057 ctrl_rho_ = getctrl("mrs_realvec/rho"); 00058 ctrl_probA_ = getctrl("mrs_realvec/probA"); 00059 ctrl_probB_ = getctrl("mrs_realvec/probB"); 00060 ctrl_label_ = getctrl("mrs_realvec/label"); 00061 ctrl_nSV_ = getctrl("mrs_realvec/nSV"); 00062 ctrl_nr_class_ = getctrl("mrs_natural/nr_class"); 00063 ctrl_weight_ = getctrl("mrs_realvec/weight"); 00064 ctrl_weight_label_ = getctrl("mrs_realvec/weight_label"); 00065 ctrl_minimums_ = getctrl("mrs_realvec/minimums"); 00066 ctrl_maximums_ = getctrl("mrs_realvec/maximums"); 00067 ctrl_mode_ = getctrl("mrs_string/mode"); 00068 ctrl_l_ = getctrl("mrs_natural/l"); 00069 ctrl_svm_ = getctrl("mrs_string/svm"); 00070 ctrl_kernel_ = getctrl("mrs_string/kernel"); 00071 ctrl_degree_ = getctrl("mrs_natural/degree"); 00072 ctrl_gamma_ = getctrl("mrs_natural/gamma"); 00073 ctrl_coef0_ = getctrl("mrs_natural/coef0"); 00074 ctrl_nu_ = getctrl("mrs_real/nu"); 00075 ctrl_cache_size_ = getctrl("mrs_natural/cache_size"); 00076 ctrl_C_ = getctrl("mrs_real/C"); 00077 ctrl_eps_ = getctrl("mrs_real/eps"); 00078 ctrl_p_ = getctrl("mrs_real/p"); 00079 ctrl_shrinking_ = getctrl("mrs_bool/shrinking"); 00080 ctrl_probability_ = getctrl("mrs_bool/probability"); 00081 ctrl_nr_weight_ = getctrl("mrs_natural/nr_weight"); 00082 ctrl_classPerms_ = getctrl("mrs_realvec/classPerms"); 00083 00084 } 00085 00086 SVMClassifier::~SVMClassifier() { 00087 ensure_free_svm_model(); 00088 ensure_free_svm_prob_xy(); 00089 } 00090 00091 void SVMClassifier::ensure_free_svm_model() { 00092 if (svm_model_ != NULL) 00093 { 00094 svm_free_and_destroy_model(&svm_model_); 00095 } 00096 } 00097 00098 void SVMClassifier::ensure_free_svm_prob_xy() { 00099 if (svm_prob_.x != NULL) { 00100 for (int i=0; i < num_svm_prob_x; ++i) { 00101 if (svm_prob_.x[i] != NULL) { 00102 delete [] svm_prob_.x[i]; 00103 svm_prob_.x[i] = NULL; 00104 } 00105 } 00106 delete [] svm_prob_.x; 00107 svm_prob_.x = NULL; 00108 } 00109 if (svm_prob_.y != NULL) { 00110 delete [] svm_prob_.y; 00111 svm_prob_.y = NULL; 00112 } 00113 } 00114 00115 MarSystem* SVMClassifier::clone() const { 00116 return new SVMClassifier(*this); 00117 } 00118 00119 void SVMClassifier::addControls() { 00120 addctrl("mrs_string/mode", "train", ctrl_mode_); 00121 setctrlState("mrs_string/mode", true); 00122 00123 addctrl("mrs_natural/nClasses", 1, ctrl_nClasses_); 00124 setctrlState("mrs_natural/nClasses", true); 00125 00126 addctrl("mrs_realvec/minimums", realvec(), ctrl_minimums_); 00127 addctrl("mrs_realvec/maximums", realvec(), ctrl_maximums_); 00128 addctrl("mrs_realvec/sv_coef", realvec(), ctrl_sv_coef_); 00129 addctrl("mrs_realvec/sv_indices", realvec(), ctrl_sv_indices_); 00130 addctrl("mrs_realvec/SV", realvec(), ctrl_SV_); 00131 addctrl("mrs_realvec/rho", realvec(), ctrl_rho_); 00132 addctrl("mrs_realvec/probA", realvec(), ctrl_probA_); 00133 addctrl("mrs_realvec/probB", realvec(), ctrl_probB_); 00134 addctrl("mrs_realvec/label", realvec(), ctrl_label_); 00135 addctrl("mrs_realvec/nSV", realvec(), ctrl_nSV_); 00136 addctrl("mrs_natural/nr_class", (mrs_natural)0, ctrl_nr_class_); 00137 addctrl("mrs_natural/l", (mrs_natural)0, ctrl_l_); 00138 addctrl("mrs_realvec/weight_label", realvec(), ctrl_weight_label_); 00139 addctrl("mrs_realvec/weight", realvec(), ctrl_weight_); 00140 addctrl("mrs_string/svm", "C_SVC", ctrl_svm_); 00141 setctrlState("mrs_string/svm", true); 00142 addctrl("mrs_string/kernel", "LINEAR", ctrl_kernel_);; 00143 setctrlState("mrs_string/kernel", true); 00144 addctrl("mrs_natural/degree", (mrs_natural)3, ctrl_degree_); 00145 addctrl("mrs_natural/gamma", (mrs_natural)4, ctrl_gamma_); 00146 addctrl("mrs_natural/coef0", (mrs_natural)0, ctrl_coef0_); 00147 addctrl("mrs_real/nu", (mrs_real)0.5, ctrl_nu_); 00148 addctrl("mrs_natural/cache_size", (mrs_natural)100, ctrl_cache_size_); 00149 addctrl("mrs_real/C", (mrs_real)1.0, ctrl_C_); 00150 addctrl("mrs_real/eps", (mrs_real)0.001, ctrl_eps_); 00151 addctrl("mrs_real/p", (mrs_real)0.1, ctrl_p_); 00152 addctrl("mrs_bool/shrinking", true, ctrl_shrinking_); 00153 addctrl("mrs_bool/probability", true, ctrl_probability_); 00154 addctrl("mrs_natural/nr_weight", (mrs_natural)0, ctrl_nr_weight_); 00155 addctrl("mrs_realvec/classPerms", realvec(), ctrl_classPerms_); 00156 00157 // turn off for regression 00158 addctrl("mrs_bool/output_classPerms", true); 00159 } 00160 00161 void SVMClassifier::myUpdate(MarControlPtr sender) { 00162 (void) sender; //suppress warning of unused parameter(s) 00163 MRSDIAG("SVMClassifier.cpp - SVMClassifier:myUpdate"); 00164 00165 ctrl_onSamples_->setValue(ctrl_inSamples_, NOUPDATE); 00166 mrs_natural nClasses = getctrl("mrs_natural/nClasses")->to<mrs_natural>(); 00167 ctrl_onObservations_->setValue(2 + nClasses, NOUPDATE); 00168 00169 if (ctrl_mode_->to<mrs_string>() == "train") { 00170 training_ = true; 00171 } else if (ctrl_mode_->to<mrs_string>() == "predict") { 00172 training_ = false; 00173 } else { 00174 cerr << "SVMClassifier.cpp, mode not supported"<<endl; 00175 exit(1); 00176 } 00177 00178 00179 if (ctrl_svm_->to<mrs_string>() == "C_SVC") 00180 svm_ = C_SVC; 00181 else if (ctrl_svm_->to<mrs_string>() == "ONE_CLASS") 00182 svm_ = ONE_CLASS; 00183 else if (ctrl_svm_->to<mrs_string>() == "EPSILON_SVR") 00184 svm_ = EPSILON_SVR; 00185 else if (ctrl_svm_->to<mrs_string>() == "NU_SVR") 00186 svm_ = NU_SVR; 00187 else 00188 { 00189 cerr << "SVMClassifier.cpp, SVM not supported"<<endl; 00190 exit(1); 00191 } 00192 00193 if (ctrl_kernel_->to<mrs_string>() == "LINEAR") 00194 kernel_ = LINEAR; 00195 else if (ctrl_kernel_->to<mrs_string>() == "POLY") 00196 kernel_ = POLY; 00197 else if (ctrl_kernel_->to<mrs_string>() == "RBF") 00198 kernel_ = RBF; 00199 else if (ctrl_kernel_->to<mrs_string>() == "SIGMOID") 00200 kernel_ = SIGMOID; 00201 else if (ctrl_kernel_->to<mrs_string>() == "PRECOMPUTED") 00202 kernel_ = PRECOMPUTED; 00203 else 00204 { 00205 cerr << "SVMClassifier.cpp, kernel not supported"<<endl; 00206 exit(1); 00207 } 00208 00209 if (!training_) { 00210 if (!trained_ && was_training_) { 00211 00212 // When network is switched from "train" mode to "predict", 00213 // we process the data instances which have been stored 00214 // in WekaData and pass them onto libsvm classes for actual training. 00215 00216 svm_param_.svm_type = svm_; 00217 svm_param_.kernel_type = kernel_; 00218 svm_param_.degree = ctrl_degree_->to<mrs_natural>(); 00219 svm_param_.gamma = ctrl_gamma_->to<mrs_natural>(); 00220 svm_param_.coef0 = ctrl_coef0_->to<mrs_natural>(); 00221 svm_param_.nu = ctrl_nu_->to<mrs_real>(); 00222 svm_param_.cache_size = ctrl_cache_size_->to<mrs_natural>(); 00223 svm_param_.C = ctrl_C_->to<mrs_real>(); 00224 svm_param_.eps = ctrl_eps_->to<mrs_real>(); 00225 svm_param_.p = ctrl_p_->to<mrs_real>(); 00226 svm_param_.shrinking = ctrl_shrinking_->to<mrs_bool>(); 00227 svm_param_.probability = ctrl_probability_->to<mrs_bool>(); 00228 svm_param_.nr_weight = ctrl_nr_weight_->to<mrs_natural>(); 00229 00230 if (svm_param_.nr_weight) { 00231 svm_param_.weight_label = Malloc(int,svm_param_.nr_weight); 00232 svm_param_.weight = Malloc(double,svm_param_.nr_weight); 00233 for (int i=0; i < svm_param_.nr_weight-1; ++i) { 00234 svm_param_.weight_label[i] 00235 = (int)ctrl_weight_label_->to<realvec>()(i); 00236 svm_param_.weight[i] 00237 = (double)ctrl_weight_->to<realvec>()(i); 00238 } 00239 } else { 00240 svm_param_.weight_label = NULL; 00241 svm_param_.weight = NULL; 00242 } 00243 00244 // normalize data 00245 instances_.NormMaxMin(); 00246 00247 // transfer training data instances into svm_problem 00248 mrs_natural nInstances = instances_.getRows(); 00249 svm_prob_.l = nInstances; 00250 00251 ensure_free_svm_prob_xy(); 00252 svm_prob_.y = new double[svm_prob_.l]; 00253 svm_prob_.x = new svm_node*[nInstances]; 00254 num_svm_prob_x = nInstances; 00255 00256 for (int i=0; i < nInstances; ++i) { 00257 svm_prob_.x[i] = NULL; 00258 } 00259 int l; 00260 mrs_bool seen; 00261 00262 for (int i=0; i < nInstances; ++i) 00263 { 00264 // set class (as number) for each of the instances 00265 l = instances_.GetClass(i); 00266 svm_prob_.y[i] = l; 00267 00268 // store all distinct classes in classPerms_ 00269 seen = false; 00270 for (size_t j=0; j < classPerms_.size(); j++) 00271 { 00272 if (l == classPerms_[j]) 00273 seen = true; 00274 } 00275 if (!seen) 00276 classPerms_.push_back(l); 00277 } 00278 00279 00280 { 00281 MarControlAccessor acc_classPerms(ctrl_classPerms_); 00282 realvec& classPerms = acc_classPerms.to<mrs_realvec>(); // ? 00283 classPerms.create((mrs_natural)classPerms_.size()); 00284 00285 for (mrs_natural i=0; i < (mrs_natural)classPerms_.size(); ++i) 00286 { 00287 classPerms(i) = classPerms_[i]; 00288 } 00289 } 00290 00291 // load each instance data into svm_nodes and store in svm_problem 00292 for (int i=0; i < nInstances; ++i) { 00293 svm_prob_.x[i] = new svm_node[inObservations_]; 00294 for (int j=0; j < inObservations_; j++) { 00295 if (j < inObservations_ -1) { 00296 svm_prob_.x[i][j].index = j+1; 00297 svm_prob_.x[i][j].value = instances_.at(i)->at(j); 00298 } else { 00299 svm_prob_.x[i][j].index = -1; 00300 svm_prob_.x[i][j].value = 0.0; 00301 } 00302 } 00303 } 00304 00305 const char *error_msg; 00306 error_msg = svm_check_parameter(&svm_prob_, &svm_param_); 00307 if (error_msg) { 00308 cerr << "SVMClassifier.cpp libsvm error: " << error_msg 00309 << endl; 00310 exit(1); 00311 } 00312 00313 ensure_free_svm_model(); 00314 svm_model_ = svm_train(&svm_prob_, &svm_param_); 00315 00316 trained_ = true; 00317 00318 MRSDEBUG ("SVMCLassifier train ... done"); 00319 MRSDEBUG ("svm_model_->nr_class = " << svm_model_->nr_class); 00320 MRSDEBUG ("svm_model_->l = " << svm_model_->l); 00321 MRSDEBUG ("svm_model_->free_sv = " << svm_model_->free_sv); 00322 MRSDEBUG ("svm_model_->SV = " << svm_model_->SV); 00323 00324 int n = 0; 00325 00326 00328 ctrl_minimums_->setValue(instances_.GetMinimums(), NOUPDATE); 00329 ctrl_maximums_->setValue(instances_.GetMaximums(), NOUPDATE); 00330 00331 00333 MarControlAccessor acc_sv_coef(ctrl_sv_coef_, NOUPDATE); 00334 realvec& sv_coef = acc_sv_coef.to<mrs_realvec>(); 00335 MarControlAccessor acc_sv_indices(ctrl_sv_indices_, NOUPDATE); 00336 realvec& sv_indices = acc_sv_indices.to<mrs_realvec>(); 00337 MarControlAccessor acc_SV(ctrl_SV_, NOUPDATE); 00338 realvec& SV = acc_SV.to<mrs_realvec>(); 00339 n = svm_model_->l; 00340 sv_coef.stretch(svm_model_->nr_class-1,n); 00341 sv_indices.stretch(n); 00342 SV.stretch(n, (inObservations_-1)); 00343 00344 for (int i=0; i<n; ++i) { 00345 for (int j=0; j<svm_model_->nr_class-1; j++) // for every class 00346 sv_coef(j, i)=svm_model_->sv_coef[j][i]; // copy coeff to sv_coef MarControl 00347 const svm_node *p = svm_model_->SV[i]; 00348 int ind = 0; 00349 while (p->index != -1) { // for every observation in the vector 00350 SV(i, ind)=p->value; // copy to SV MarControl 00351 p++; 00352 ind++; 00353 } 00354 } 00355 00357 00358 // rho 00359 { 00360 MarControlAccessor acc_rho(ctrl_rho_, NOUPDATE); 00361 realvec& rho = acc_rho.to<mrs_realvec>(); 00362 n = svm_model_->nr_class*(svm_model_->nr_class-1)/2; 00363 rho.stretch(n); 00364 for (int i=0; i<n; ++i) 00365 rho(i)=svm_model_->rho[i]; 00366 } 00367 00368 // probA 00369 if (svm_model_->probA) { 00370 MarControlAccessor acc_probA(ctrl_probA_, NOUPDATE); 00371 realvec& probA = acc_probA.to<mrs_realvec>(); 00372 n = svm_model_->nr_class*(svm_model_->nr_class-1)/2; 00373 probA.stretch(n); 00374 for (int i=0; i<n; ++i) 00375 probA(i)=svm_model_->probA[i]; 00376 } 00377 00378 // probB 00379 if (svm_model_->probB) { 00380 MarControlAccessor acc_probB(ctrl_probB_, NOUPDATE); 00381 realvec& probB = acc_probB.to<mrs_realvec>(); 00382 n = svm_model_->nr_class*(svm_model_->nr_class-1)/2; 00383 probB.stretch(n+1); 00384 for (int i=0; i<n; ++i) 00385 probB(i)=svm_model_->probB[i]; 00386 } 00387 00388 // label 00389 if (svm_model_->label) { 00390 MarControlAccessor acc_label(ctrl_label_, NOUPDATE); 00391 realvec& label = acc_label.to<mrs_realvec>(); 00392 n = svm_model_->nr_class; 00393 label.stretch(n); 00394 for (int i=0; i<n; ++i) 00395 label(i)=svm_model_->label[i]; 00396 } 00397 00398 // nSV 00399 if (svm_model_->nSV) { 00400 MarControlAccessor acc_nSV(ctrl_nSV_, NOUPDATE); 00401 realvec& nSV = acc_nSV.to<mrs_realvec>(); 00402 n = svm_model_->nr_class; 00403 nSV.stretch(n); 00404 for (int i=0; i<n; ++i) 00405 nSV(i)=svm_model_->nSV[i]; 00406 } 00407 00408 // if (svm_model_->nSV) { 00409 // n = svm_model_->nr_class; 00410 // realvec nSV(n); 00411 // for (int i=0; i<n; ++i) 00412 // nSV(i)=svm_model_->nSV[i]; 00413 // 00414 // ctrl_nSV_->setValue(nSV, NOUPDATE); 00415 // } 00416 00417 // nr_class 00418 ctrl_nr_class_->setValue(svm_model_->nr_class, NOUPDATE); 00419 00420 // l 00421 ctrl_l_->setValue(svm_model_->l, NOUPDATE); 00422 } 00423 } 00424 } 00425 00426 00427 00428 00429 00430 void SVMClassifier::myProcess(realvec& in, realvec& out) 00431 { 00432 00433 if (training_) { 00434 // Training here means simply inserting all input vectors to 00435 // the WekaData instances_. The actual training of libsvm classes 00436 // happens when we update the mode control to "predict". 00437 00438 if (!was_training_) { 00439 instances_.Create(inObservations_); 00440 trained_ = false; 00441 00442 } 00443 00444 instances_.Append(in); 00445 out(0,0) = in(inObservations_-1, 0); 00446 out(1,0) = in(inObservations_-1, 0); 00447 00448 00449 } else { // predict 00450 00451 if (!trained_) { 00452 if (was_training_) { 00453 ; 00454 } else { 00455 // Init libsvm structures and load data from 00456 // network controls into libsvm in cased they had been stored 00457 00458 svm_prob_.y = NULL; 00459 svm_prob_.x = NULL; 00460 svm_model_ = Malloc(svm_model,1); 00461 svm_model_->param.svm_type = svm_; 00462 svm_model_->param.weight_label = NULL; 00463 svm_model_->param.weight = NULL; 00464 svm_model_->param.kernel_type = kernel_; 00465 svm_model_->param.degree = ctrl_degree_->to<mrs_natural>(); 00466 svm_model_->param.gamma = ctrl_gamma_->to<mrs_natural>(); 00467 svm_model_->param.coef0 = ctrl_coef0_->to<mrs_natural>(); 00468 svm_model_->param.nu = ctrl_nu_->to<mrs_real>(); 00469 svm_model_->param.cache_size = ctrl_cache_size_->to<mrs_natural>(); 00470 svm_model_->param.C = ctrl_C_->to<mrs_real>(); 00471 svm_model_->param.eps = ctrl_eps_->to<mrs_real>(); 00472 svm_model_->param.p = ctrl_p_->to<mrs_real>(); 00473 svm_model_->param.shrinking = ctrl_shrinking_->to<mrs_bool>(); 00474 svm_model_->param.probability = ctrl_probability_->to<mrs_bool>(); 00475 svm_model_->param.nr_weight = ctrl_nr_weight_->to<mrs_natural>(); 00476 00477 { 00478 MarControlAccessor acc_classPerms(ctrl_classPerms_); 00479 realvec& classPerms = acc_classPerms.to<mrs_realvec>(); 00480 classPerms_.clear(); 00481 for (mrs_natural i=0; i < classPerms.getSize(); ++i) 00482 { 00483 classPerms_.push_back((mrs_natural)classPerms(i)); 00484 } 00485 } 00486 00487 MRSDEBUG ("svm_model_->param.svm_type = " << svm_model_->param.svm_type); 00488 MRSDEBUG ("svm_model_->param.kernel_type = " << svm_model_->param.kernel_type); 00489 MRSDEBUG ("svm_model_->param.degree = " << svm_model_->param.degree); 00490 MRSDEBUG ("svm_model_->param.gamma = " << svm_model_->param.gamma); 00491 MRSDEBUG ("svm_model_->param.coef0 = " << svm_model_->param.coef0); 00492 MRSDEBUG ("svm_model_->param.nu = " << svm_model_->param.nu); 00493 MRSDEBUG ("svm_model_->param.cache_size = " << svm_model_->param.cache_size); 00494 MRSDEBUG ("svm_model_->param.C = " << svm_model_->param.C); 00495 MRSDEBUG ("svm_model_->param.eps = " << svm_model_->param.eps); 00496 MRSDEBUG ("svm_model_->param.p = " << svm_model_->param.p); 00497 MRSDEBUG ("svm_model_->param.shrinking = " << svm_model_->param.shrinking); 00498 MRSDEBUG ("svm_model_->param.probability = " << svm_model_->param.probability); 00499 MRSDEBUG ("svm_model_->param.nr_weight = " << svm_model_->param.nr_weight); 00500 MRSDEBUG ("svm_model_->param.weight_label = " << svm_model_->param.weight_label); 00501 MRSDEBUG ("svm_model_->param.weight = " << svm_model_->param.weight); 00502 00503 int n = ctrl_nr_class_->to<mrs_natural>(); 00504 int l = ctrl_l_->to<mrs_natural>(); 00505 int m = n*(n-1)/2; 00506 00507 svm_model_->nr_class = n; 00508 svm_model_->l = l; 00509 00510 MRSDEBUG ("svm_model_->nr_class = " << svm_model_->nr_class); 00511 MRSDEBUG ("svm_model_->l = " << svm_model_->l); 00512 00513 00514 00516 00517 if (ctrl_rho_->to<realvec>().getSize()) //rho 00518 { 00519 svm_model_->rho = Malloc(double,m); 00520 for (int i=0; i<m; ++i) 00521 svm_model_->rho[i]= ctrl_rho_->to<realvec>()(i); 00522 } 00523 else 00524 svm_model_->rho = NULL; 00525 00526 if (ctrl_probA_->to<realvec>().getSize()) //probA 00527 { 00528 svm_model_->probA = Malloc(double,m); 00529 for (int i=0; i<m; ++i) 00530 svm_model_->probA[i] = ctrl_probA_->to<realvec>()(i); 00531 } 00532 else 00533 svm_model_->probA = NULL; 00534 00535 if (ctrl_probB_->to<realvec>().getSize()) //probB 00536 { 00537 svm_model_->probB = Malloc(double,m); 00538 for (int i=0; i<m; ++i) 00539 svm_model_->probB[i]=ctrl_probB_->to<realvec>()(i); 00540 } 00541 else 00542 svm_model_->probB = NULL; 00543 00544 if (ctrl_label_->to<realvec>().getSize()) //label 00545 { 00546 svm_model_->label = Malloc(int,n); 00547 for (int i=0; i<n; ++i) 00548 svm_model_->label[i] 00549 = (int)ctrl_label_->to<realvec>()(i); 00550 } 00551 else 00552 svm_model_->label = NULL; 00553 00554 if (ctrl_nSV_->to<realvec>().getSize()) //nr_sv 00555 { 00556 svm_model_->nSV = Malloc(int,n); 00557 for (int i=0; i<n; ++i) 00558 svm_model_->nSV[i]=(int)ctrl_nSV_->to<realvec>()(i); 00559 } 00560 else 00561 svm_model_->nSV = NULL; 00562 00563 #ifdef MARSYAS_LOG_DEBUGS 00564 if (svm_model_->rho) { 00565 MRSDEBUG("svm_model_->rho ="); 00566 for (int i=0; i<m; ++i) 00567 cout << " "<< svm_model_->rho[i]; 00568 cout << endl;; 00569 } 00570 00571 if (svm_model_->probA) { 00572 MRSDEBUG("svm_model_->probA ="); 00573 for (int i=0; i<m; ++i) 00574 cout << " " << svm_model_->probA[i]; 00575 cout << endl;; 00576 } 00577 00578 if (svm_model_->probB) { 00579 MRSDEBUG("svm_model_->probB ="); 00580 for (int i=0; i<m; ++i) 00581 cout << " " << svm_model_->probB[i]; 00582 cout << endl;; 00583 } 00584 00585 if (svm_model_->label) { 00586 MRSDEBUG("svm_model_->label ="); 00587 for (int i=0; i<n; ++i) 00588 cout << " " << svm_model_->label[i]; 00589 cout << endl;; 00590 } 00591 00592 if (svm_model_->nSV) { 00593 MRSDEBUG("svm_model_->nSV ="); 00594 for (int i=0; i<n; ++i) 00595 cout << " " << svm_model_->nSV[i]; 00596 cout << endl;; 00597 } 00598 00599 #endif 00600 --n; 00601 m = ctrl_SV_->to<realvec>().getCols(); 00602 00603 if (ctrl_sv_coef_->to<realvec>().getSize()) // sv_coef 00604 { 00605 svm_model_->sv_coef = Malloc(double *,n); 00606 for (int i=0; i<n; ++i) 00607 svm_model_->sv_coef[i] = Malloc(double,l); 00608 for (int i=0; i<l; ++i) 00609 for (int k=0; k<n; k++) 00610 svm_model_->sv_coef[k][i] 00611 =ctrl_sv_coef_->to<realvec>()(k, i); 00612 } 00613 svm_model_->sv_indices = Malloc(int, n); 00614 for (int i=0; i<l; i++) { 00615 // FIXME: interface with libsvm 00616 } 00617 00618 if (ctrl_SV_->to<realvec>().getSize()) // SV 00619 { 00620 svm_model_->SV = Malloc(svm_node*,l); 00621 svm_node *x_space=NULL; 00622 if (l>0) { 00623 x_space = Malloc(svm_node, 2*m*l); 00624 num_nodes++; 00625 } 00626 int j=0; 00627 for (int i=0; i<l; ++i) { 00628 svm_model_->SV[i] = &x_space[j]; 00629 for (int k = 0; k < m; k++) { 00630 x_space[j].index = k+1; 00631 x_space[j].value = ctrl_SV_->to<realvec>()(i, k); 00632 ++j; 00633 } 00634 x_space[j++].index = -1; 00635 } 00636 } 00637 00638 #ifdef MARSYAS_LOG_DEBUGS 00639 00640 MRSDEBUG ("svm_model_->SV = "); 00641 00642 { 00643 for (int i=0; i<l; ++i) { 00644 for (int j=0; j<n; j++) { 00645 if (svm_model_->sv_coef) 00646 cout <<svm_model_->sv_coef[j][i] << " "; 00647 if (svm_model_->SV) { 00648 const svm_node *p = svm_model_->SV[i]; 00649 while (p->index != -1) { 00650 cout << p->index << ":"; 00651 cout << p->value << " "; 00652 p++; 00653 } 00654 } 00655 cout << endl;; 00656 } 00657 } 00658 } 00659 #endif 00660 svm_model_->free_sv = 1; 00661 MRSDEBUG ("svm_model_->free_sv = " << svm_model_->free_sv); 00662 trained_ = true; 00663 00664 #ifdef MARSYAS_LOG_DEBUGS 00665 { 00666 MRSDEBUG ("mini/maxi : "); 00667 realvec mini = ctrl_minimums_->to<mrs_realvec>(); 00668 realvec maxi = ctrl_maximums_->to<mrs_realvec>(); 00669 for (int i=0; i<inObservations_ -1; ++i) 00670 cout << "mini(" << i << ")" << mini(i) << " maxi(" 00671 << i << ")" << maxi(i) << endl; 00672 } 00673 #endif 00674 } // if(!was_training) 00675 }// if(!trained) 00676 00677 00679 00680 struct svm_node* xv = new svm_node[inObservations_]; 00681 double* probs = new double[svm_model_->nr_class]; 00682 00683 // Get the minimum and maximum values to which the 00684 // training set was normalized. 00685 realvec mini = ctrl_minimums_->to<mrs_realvec>(); 00686 realvec maxi = ctrl_maximums_->to<mrs_realvec>(); 00687 00688 // Scale our input the same as our trainingset 00689 for (int i=0; i<inObservations_ -1; ++i) 00690 in(i, 0) = (in(i, 0) - mini(i)) / (maxi(i) - mini(i)); 00691 00692 // Copy our input to an SV structure 00693 for (int j=0; j < inObservations_; ++j) { 00694 if (j < inObservations_ -1) { 00695 xv[j].index = j+1; 00696 xv[j].value = in(j, 0); 00697 } else { 00698 // The last index value in the SV is always set to -1 00699 xv[j].index = -1; 00700 xv[j].value = 0.0; 00701 } 00702 } 00703 00704 double prediction = 0.0; 00705 00706 if (ctrl_probability_->to<mrs_bool>()) 00707 prediction = svm_predict_probability(svm_model_, xv, probs); 00708 else 00709 prediction = svm_predict(svm_model_, xv); 00710 00711 00712 // Output 00713 if (getctrl("mrs_bool/output_classPerms")->isTrue()) { 00714 for (int i=0; i < svm_model_->nr_class; ++i) { 00715 out(2 + classPerms_[i], 0) = probs[i]; 00716 } 00717 } 00718 00719 out(0,0) = (mrs_real)prediction; 00720 out(1,0) = in(inObservations_-1,0); 00721 00722 00723 // Cleanup 00724 delete [] xv; 00725 delete [] probs; 00726 } 00727 was_training_ = training_; 00728 }