Marsyas 0.6.0-alpha
/usr/src/RPM/BUILD/marsyas-0.6.0/src/marsyas/marsystems/SVMLinearClassifier.cpp
Go to the documentation of this file.
00001 /*
00002 ** Copyright (C) 1998-2006 George Tzanetakis <gtzan@cs.uvic.ca>
00003 **
00004 ** This program is free software; you can redistribute it and/or modify
00005 ** it under the terms of the GNU General Public License as published by
00006 ** the Free Software Foundation; either version 2 of the License, or
00007 ** (at your option) any later version.
00008 **
00009 ** This program is distributed in the hope that it will be useful,
00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00012 ** GNU General Public License for more details.
00013 **
00014 ** You should have received a copy of the GNU General Public License
00015 ** along with this program; if not, write to the Free Software
00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00017 */
00018 
00019 #include "../common_source.h"
00020 #include "SVMLinearClassifier.h"
00021 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
00022 
00023 using namespace std;
00024 using namespace Marsyas;
00025 
00026 SVMClassifier::SVMClassifier(mrs_string name) :
00027   MarSystem("SVMClassifier", name) {
00028   training_ = true;
00029   was_training_ = false;
00030   trained_ = false;
00031   kernel_ = LINEAR;
00032   svm_ = C_SVC;
00033   svm_model_ = NULL;
00034   num_nodes = 0;
00035   svm_prob_.y = NULL;
00036   svm_prob_.x = NULL;
00037 
00038   addControls();
00039 }
00040 
00041 SVMClassifier::SVMClassifier(const SVMClassifier& a) :
00042   MarSystem(a) {
00043   training_ = true;
00044   was_training_ = false;
00045   trained_ = false;
00046   kernel_ = LINEAR;
00047   svm_ = C_SVC;
00048   svm_model_ = NULL;
00049   num_nodes = 0;
00050   svm_prob_.y = NULL;
00051   svm_prob_.x = NULL;
00052 
00053   ctrl_nClasses_ = getctrl("mrs_natural/nClasses");
00054   ctrl_sv_coef_ = getctrl("mrs_realvec/sv_coef");
00055   ctrl_sv_indices_ = getctrl("mrs_realvec/sv_indices");
00056   ctrl_SV_ = getctrl("mrs_realvec/SV");
00057   ctrl_rho_ = getctrl("mrs_realvec/rho");
00058   ctrl_probA_ = getctrl("mrs_realvec/probA");
00059   ctrl_probB_ = getctrl("mrs_realvec/probB");
00060   ctrl_label_ = getctrl("mrs_realvec/label");
00061   ctrl_nSV_ = getctrl("mrs_realvec/nSV");
00062   ctrl_nr_class_ = getctrl("mrs_natural/nr_class");
00063   ctrl_weight_ = getctrl("mrs_realvec/weight");
00064   ctrl_weight_label_ = getctrl("mrs_realvec/weight_label");
00065   ctrl_minimums_ = getctrl("mrs_realvec/minimums");
00066   ctrl_maximums_ = getctrl("mrs_realvec/maximums");
00067   ctrl_mode_ = getctrl("mrs_string/mode");
00068   ctrl_l_ = getctrl("mrs_natural/l");
00069   ctrl_svm_ = getctrl("mrs_string/svm");
00070   ctrl_kernel_ = getctrl("mrs_string/kernel");
00071   ctrl_degree_ = getctrl("mrs_natural/degree");
00072   ctrl_gamma_ = getctrl("mrs_natural/gamma");
00073   ctrl_coef0_ = getctrl("mrs_natural/coef0");
00074   ctrl_nu_ = getctrl("mrs_real/nu");
00075   ctrl_cache_size_ = getctrl("mrs_natural/cache_size");
00076   ctrl_C_ = getctrl("mrs_real/C");
00077   ctrl_eps_ = getctrl("mrs_real/eps");
00078   ctrl_p_ = getctrl("mrs_real/p");
00079   ctrl_shrinking_ = getctrl("mrs_bool/shrinking");
00080   ctrl_probability_ = getctrl("mrs_bool/probability");
00081   ctrl_nr_weight_ = getctrl("mrs_natural/nr_weight");
00082   ctrl_classPerms_ = getctrl("mrs_realvec/classPerms");
00083 
00084 }
00085 
00086 SVMClassifier::~SVMClassifier() {
00087   ensure_free_svm_model();
00088   ensure_free_svm_prob_xy();
00089 }
00090 
00091 void SVMClassifier::ensure_free_svm_model() {
00092   if (svm_model_ != NULL)
00093   {
00094     svm_free_and_destroy_model(&svm_model_);
00095   }
00096 }
00097 
00098 void SVMClassifier::ensure_free_svm_prob_xy() {
00099   if (svm_prob_.x != NULL) {
00100     for (int i=0; i < num_svm_prob_x; ++i) {
00101       if (svm_prob_.x[i] != NULL) {
00102         delete [] svm_prob_.x[i];
00103         svm_prob_.x[i] = NULL;
00104       }
00105     }
00106     delete [] svm_prob_.x;
00107     svm_prob_.x = NULL;
00108   }
00109   if (svm_prob_.y != NULL) {
00110     delete [] svm_prob_.y;
00111     svm_prob_.y = NULL;
00112   }
00113 }
00114 
00115 MarSystem* SVMClassifier::clone() const {
00116   return new SVMClassifier(*this);
00117 }
00118 
00119 void SVMClassifier::addControls() {
00120   addctrl("mrs_string/mode", "train", ctrl_mode_);
00121   setctrlState("mrs_string/mode", true);
00122 
00123   addctrl("mrs_natural/nClasses", 1, ctrl_nClasses_);
00124   setctrlState("mrs_natural/nClasses", true);
00125 
00126   addctrl("mrs_realvec/minimums", realvec(), ctrl_minimums_);
00127   addctrl("mrs_realvec/maximums", realvec(), ctrl_maximums_);
00128   addctrl("mrs_realvec/sv_coef", realvec(), ctrl_sv_coef_);
00129   addctrl("mrs_realvec/sv_indices", realvec(), ctrl_sv_indices_);
00130   addctrl("mrs_realvec/SV", realvec(), ctrl_SV_);
00131   addctrl("mrs_realvec/rho", realvec(), ctrl_rho_);
00132   addctrl("mrs_realvec/probA", realvec(), ctrl_probA_);
00133   addctrl("mrs_realvec/probB", realvec(), ctrl_probB_);
00134   addctrl("mrs_realvec/label", realvec(), ctrl_label_);
00135   addctrl("mrs_realvec/nSV", realvec(), ctrl_nSV_);
00136   addctrl("mrs_natural/nr_class", (mrs_natural)0, ctrl_nr_class_);
00137   addctrl("mrs_natural/l", (mrs_natural)0, ctrl_l_);
00138   addctrl("mrs_realvec/weight_label", realvec(), ctrl_weight_label_);
00139   addctrl("mrs_realvec/weight", realvec(), ctrl_weight_);
00140   addctrl("mrs_string/svm", "C_SVC", ctrl_svm_);
00141   setctrlState("mrs_string/svm", true);
00142   addctrl("mrs_string/kernel", "LINEAR", ctrl_kernel_);;
00143   setctrlState("mrs_string/kernel", true);
00144   addctrl("mrs_natural/degree", (mrs_natural)3, ctrl_degree_);
00145   addctrl("mrs_natural/gamma", (mrs_natural)4, ctrl_gamma_);
00146   addctrl("mrs_natural/coef0", (mrs_natural)0, ctrl_coef0_);
00147   addctrl("mrs_real/nu", (mrs_real)0.5, ctrl_nu_);
00148   addctrl("mrs_natural/cache_size", (mrs_natural)100, ctrl_cache_size_);
00149   addctrl("mrs_real/C", (mrs_real)1.0, ctrl_C_);
00150   addctrl("mrs_real/eps", (mrs_real)0.001, ctrl_eps_);
00151   addctrl("mrs_real/p", (mrs_real)0.1, ctrl_p_);
00152   addctrl("mrs_bool/shrinking", true, ctrl_shrinking_);
00153   addctrl("mrs_bool/probability", true, ctrl_probability_);
00154   addctrl("mrs_natural/nr_weight", (mrs_natural)0, ctrl_nr_weight_);
00155   addctrl("mrs_realvec/classPerms", realvec(), ctrl_classPerms_);
00156 
00157   // turn off for regression
00158   addctrl("mrs_bool/output_classPerms", true);
00159 }
00160 
00161 void SVMClassifier::myUpdate(MarControlPtr sender) {
00162   (void) sender;  //suppress warning of unused parameter(s)
00163   MRSDIAG("SVMClassifier.cpp - SVMClassifier:myUpdate");
00164 
00165   ctrl_onSamples_->setValue(ctrl_inSamples_, NOUPDATE);
00166   mrs_natural nClasses = getctrl("mrs_natural/nClasses")->to<mrs_natural>();
00167   ctrl_onObservations_->setValue(2 + nClasses, NOUPDATE);
00168 
00169   if (ctrl_mode_->to<mrs_string>() == "train") {
00170     training_ = true;
00171   } else if (ctrl_mode_->to<mrs_string>() == "predict") {
00172     training_ = false;
00173   } else {
00174     cerr << "SVMClassifier.cpp, mode not supported"<<endl;
00175     exit(1);
00176   }
00177 
00178 
00179   if (ctrl_svm_->to<mrs_string>() == "C_SVC")
00180     svm_ = C_SVC;
00181   else if (ctrl_svm_->to<mrs_string>() == "ONE_CLASS")
00182     svm_ = ONE_CLASS;
00183   else if (ctrl_svm_->to<mrs_string>() == "EPSILON_SVR")
00184     svm_ = EPSILON_SVR;
00185   else if (ctrl_svm_->to<mrs_string>() == "NU_SVR")
00186     svm_ = NU_SVR;
00187   else
00188   {
00189     cerr << "SVMClassifier.cpp, SVM not supported"<<endl;
00190     exit(1);
00191   }
00192 
00193   if (ctrl_kernel_->to<mrs_string>() == "LINEAR")
00194     kernel_ = LINEAR;
00195   else if (ctrl_kernel_->to<mrs_string>() == "POLY")
00196     kernel_ = POLY;
00197   else if (ctrl_kernel_->to<mrs_string>() == "RBF")
00198     kernel_ = RBF;
00199   else if (ctrl_kernel_->to<mrs_string>() == "SIGMOID")
00200     kernel_ = SIGMOID;
00201   else if (ctrl_kernel_->to<mrs_string>() == "PRECOMPUTED")
00202     kernel_ = PRECOMPUTED;
00203   else
00204   {
00205     cerr << "SVMClassifier.cpp, kernel not supported"<<endl;
00206     exit(1);
00207   }
00208 
00209   if (!training_) {
00210     if (!trained_ && was_training_) {
00211 
00212       // When network is switched from "train" mode to "predict",
00213       // we process the data instances which have been stored
00214       // in WekaData and pass them onto libsvm classes for actual training.
00215 
00216       svm_param_.svm_type = svm_;
00217       svm_param_.kernel_type = kernel_;
00218       svm_param_.degree = ctrl_degree_->to<mrs_natural>();
00219       svm_param_.gamma = ctrl_gamma_->to<mrs_natural>();
00220       svm_param_.coef0 = ctrl_coef0_->to<mrs_natural>();
00221       svm_param_.nu = ctrl_nu_->to<mrs_real>();
00222       svm_param_.cache_size = ctrl_cache_size_->to<mrs_natural>();
00223       svm_param_.C = ctrl_C_->to<mrs_real>();
00224       svm_param_.eps = ctrl_eps_->to<mrs_real>();
00225       svm_param_.p = ctrl_p_->to<mrs_real>();
00226       svm_param_.shrinking = ctrl_shrinking_->to<mrs_bool>();
00227       svm_param_.probability = ctrl_probability_->to<mrs_bool>();
00228       svm_param_.nr_weight = ctrl_nr_weight_->to<mrs_natural>();
00229 
00230       if (svm_param_.nr_weight) {
00231         svm_param_.weight_label = Malloc(int,svm_param_.nr_weight);
00232         svm_param_.weight = Malloc(double,svm_param_.nr_weight);
00233         for (int i=0; i < svm_param_.nr_weight-1; ++i) {
00234           svm_param_.weight_label[i]
00235           = (int)ctrl_weight_label_->to<realvec>()(i);
00236           svm_param_.weight[i]
00237           = (double)ctrl_weight_->to<realvec>()(i);
00238         }
00239       } else {
00240         svm_param_.weight_label = NULL;
00241         svm_param_.weight = NULL;
00242       }
00243 
00244       // normalize data
00245       instances_.NormMaxMin();
00246 
00247       // transfer training data instances into svm_problem
00248       mrs_natural nInstances = instances_.getRows();
00249       svm_prob_.l = nInstances;
00250 
00251       ensure_free_svm_prob_xy();
00252       svm_prob_.y = new double[svm_prob_.l];
00253       svm_prob_.x = new svm_node*[nInstances];
00254       num_svm_prob_x = nInstances;
00255 
00256       for (int i=0; i < nInstances; ++i) {
00257         svm_prob_.x[i] = NULL;
00258       }
00259       int l;
00260       mrs_bool seen;
00261 
00262       for (int i=0; i < nInstances; ++i)
00263       {
00264         // set class (as number) for each of the instances
00265         l = instances_.GetClass(i);
00266         svm_prob_.y[i] = l;
00267 
00268         // store all distinct classes in classPerms_
00269         seen = false;
00270         for (size_t j=0; j < classPerms_.size(); j++)
00271         {
00272           if (l == classPerms_[j])
00273             seen = true;
00274         }
00275         if (!seen)
00276           classPerms_.push_back(l);
00277       }
00278 
00279 
00280       {
00281         MarControlAccessor acc_classPerms(ctrl_classPerms_);
00282         realvec& classPerms = acc_classPerms.to<mrs_realvec>(); // ?
00283         classPerms.create(classPerms_.size());
00284 
00285         for (size_t i=0; i < classPerms_.size(); ++i)
00286         {
00287           classPerms(i) = classPerms_[i];
00288         }
00289       }
00290 
00291       // load each instance data into svm_nodes and store in svm_problem
00292       for (int i=0; i < nInstances; ++i) {
00293         svm_prob_.x[i] = new svm_node[inObservations_];
00294         for (int j=0; j < inObservations_; j++) {
00295           if (j < inObservations_ -1) {
00296             svm_prob_.x[i][j].index = j+1;
00297             svm_prob_.x[i][j].value = instances_.at(i)->at(j);
00298           } else {
00299             svm_prob_.x[i][j].index = -1;
00300             svm_prob_.x[i][j].value = 0.0;
00301           }
00302         }
00303       }
00304 
00305       const char *error_msg;
00306       error_msg = svm_check_parameter(&svm_prob_, &svm_param_);
00307       if (error_msg) {
00308         cerr << "SVMClassifier.cpp libsvm error: " << error_msg
00309              << endl;
00310         exit(1);
00311       }
00312 
00313       ensure_free_svm_model();
00314       svm_model_ = svm_train(&svm_prob_, &svm_param_);
00315 
00316       trained_ = true;
00317 
00318       MRSDEBUG ("SVMCLassifier train ... done");
00319       MRSDEBUG ("svm_model_->nr_class = " << svm_model_->nr_class);
00320       MRSDEBUG ("svm_model_->l = " << svm_model_->l);
00321       MRSDEBUG ("svm_model_->free_sv = " << svm_model_->free_sv);
00322       MRSDEBUG ("svm_model_->SV = " << svm_model_->SV);
00323 
00324       int n = 0;
00325 
00326 
00328       ctrl_minimums_->setValue(instances_.GetMinimums(), NOUPDATE);
00329       ctrl_maximums_->setValue(instances_.GetMaximums(), NOUPDATE);
00330 
00331 
00333       MarControlAccessor acc_sv_coef(ctrl_sv_coef_, NOUPDATE);
00334       realvec& sv_coef = acc_sv_coef.to<mrs_realvec>();
00335       MarControlAccessor acc_sv_indices(ctrl_sv_indices_, NOUPDATE);
00336       realvec& sv_indices = acc_sv_indices.to<mrs_realvec>();
00337       MarControlAccessor acc_SV(ctrl_SV_, NOUPDATE);
00338       realvec& SV = acc_SV.to<mrs_realvec>();
00339       n = svm_model_->l;
00340       sv_coef.stretch(svm_model_->nr_class-1,n);
00341       sv_indices.stretch(n);
00342       SV.stretch(n, (inObservations_-1));
00343 
00344       for (int i=0; i<n; ++i) {
00345         for (int j=0; j<svm_model_->nr_class-1; j++) // for every class
00346           sv_coef(j, i)=svm_model_->sv_coef[j][i]; // copy coeff to sv_coef MarControl
00347         const svm_node *p = svm_model_->SV[i];
00348         int ind = 0;
00349         while (p->index != -1) { // for every observation in the vector
00350           SV(i, ind)=p->value; // copy to SV MarControl
00351           p++;
00352           ind++;
00353         }
00354       }
00355 
00357 
00358       // rho
00359       {
00360         MarControlAccessor acc_rho(ctrl_rho_, NOUPDATE);
00361         realvec& rho = acc_rho.to<mrs_realvec>();
00362         n = svm_model_->nr_class*(svm_model_->nr_class-1)/2;
00363         rho.stretch(n);
00364         for (int i=0; i<n; ++i)
00365           rho(i)=svm_model_->rho[i];
00366       }
00367 
00368       // probA
00369       if (svm_model_->probA) {
00370         MarControlAccessor acc_probA(ctrl_probA_, NOUPDATE);
00371         realvec& probA = acc_probA.to<mrs_realvec>();
00372         n = svm_model_->nr_class*(svm_model_->nr_class-1)/2;
00373         probA.stretch(n);
00374         for (int i=0; i<n; ++i)
00375           probA(i)=svm_model_->probA[i];
00376       }
00377 
00378       // probB
00379       if (svm_model_->probB) {
00380         MarControlAccessor acc_probB(ctrl_probB_, NOUPDATE);
00381         realvec& probB = acc_probB.to<mrs_realvec>();
00382         n = svm_model_->nr_class*(svm_model_->nr_class-1)/2;
00383         probB.stretch(n+1);
00384         for (int i=0; i<n; ++i)
00385           probB(i)=svm_model_->probB[i];
00386       }
00387 
00388       // label
00389       if (svm_model_->label) {
00390         MarControlAccessor acc_label(ctrl_label_, NOUPDATE);
00391         realvec& label = acc_label.to<mrs_realvec>();
00392         n = svm_model_->nr_class;
00393         label.stretch(n);
00394         for (int i=0; i<n; ++i)
00395           label(i)=svm_model_->label[i];
00396       }
00397 
00398       // nSV
00399       if (svm_model_->nSV) {
00400         MarControlAccessor acc_nSV(ctrl_nSV_, NOUPDATE);
00401         realvec& nSV = acc_nSV.to<mrs_realvec>();
00402         n = svm_model_->nr_class;
00403         nSV.stretch(n);
00404         for (int i=0; i<n; ++i)
00405           nSV(i)=svm_model_->nSV[i];
00406       }
00407 
00408 //          if (svm_model_->nSV) {
00409 //              n = svm_model_->nr_class;
00410 //              realvec nSV(n);
00411 //              for (int i=0; i<n; ++i)
00412 //                  nSV(i)=svm_model_->nSV[i];
00413 //
00414 //              ctrl_nSV_->setValue(nSV, NOUPDATE);
00415 //          }
00416 
00417       // nr_class
00418       ctrl_nr_class_->setValue(svm_model_->nr_class, NOUPDATE);
00419 
00420       // l
00421       ctrl_l_->setValue(svm_model_->l, NOUPDATE);
00422     }
00423   }
00424 }
00425 
00426 
00427 
00428 
00429 
00430 void SVMClassifier::myProcess(realvec& in, realvec& out)
00431 {
00432 
00433   if (training_) {
00434     // Training here means simply inserting all input vectors to
00435     // the WekaData instances_. The actual training of libsvm classes
00436     // happens when we update the mode control to "predict".
00437 
00438     if (!was_training_) {
00439       instances_.Create(inObservations_);
00440       trained_ = false;
00441 
00442     }
00443 
00444     instances_.Append(in);
00445     out(0,0) = in(inObservations_-1, 0);
00446     out(1,0) = in(inObservations_-1, 0);
00447 
00448 
00449   } else {  // predict
00450 
00451     if (!trained_) {
00452       if (was_training_) {
00453         ;
00454       } else {
00455         // Init libsvm structures and load data from
00456         // network controls into libsvm in cased they had been stored
00457 
00458         svm_prob_.y = NULL;
00459         svm_prob_.x = NULL;
00460         svm_model_ = Malloc(svm_model,1);
00461         svm_model_->param.svm_type = svm_;
00462         svm_model_->param.weight_label = NULL;
00463         svm_model_->param.weight = NULL;
00464         svm_model_->param.kernel_type = kernel_;
00465         svm_model_->param.degree = ctrl_degree_->to<mrs_natural>();
00466         svm_model_->param.gamma = ctrl_gamma_->to<mrs_natural>();
00467         svm_model_->param.coef0 = ctrl_coef0_->to<mrs_natural>();
00468         svm_model_->param.nu = ctrl_nu_->to<mrs_real>();
00469         svm_model_->param.cache_size = ctrl_cache_size_->to<mrs_natural>();
00470         svm_model_->param.C = ctrl_C_->to<mrs_real>();
00471         svm_model_->param.eps = ctrl_eps_->to<mrs_real>();
00472         svm_model_->param.p = ctrl_p_->to<mrs_real>();
00473         svm_model_->param.shrinking = ctrl_shrinking_->to<mrs_bool>();
00474         svm_model_->param.probability = ctrl_probability_->to<mrs_bool>();
00475         svm_model_->param.nr_weight = ctrl_nr_weight_->to<mrs_natural>();
00476 
00477         {
00478           MarControlAccessor acc_classPerms(ctrl_classPerms_);
00479           realvec& classPerms = acc_classPerms.to<mrs_realvec>();
00480           classPerms_.clear();
00481           for (mrs_natural i=0; i < classPerms.getSize(); ++i)
00482           {
00483             classPerms_.push_back((mrs_natural)classPerms(i));
00484           }
00485         }
00486 
00487         MRSDEBUG ("svm_model_->param.svm_type = " << svm_model_->param.svm_type);
00488         MRSDEBUG ("svm_model_->param.kernel_type = " << svm_model_->param.kernel_type);
00489         MRSDEBUG ("svm_model_->param.degree = " << svm_model_->param.degree);
00490         MRSDEBUG ("svm_model_->param.gamma = " << svm_model_->param.gamma);
00491         MRSDEBUG ("svm_model_->param.coef0 = " << svm_model_->param.coef0);
00492         MRSDEBUG ("svm_model_->param.nu = " << svm_model_->param.nu);
00493         MRSDEBUG ("svm_model_->param.cache_size = " << svm_model_->param.cache_size);
00494         MRSDEBUG ("svm_model_->param.C = " << svm_model_->param.C);
00495         MRSDEBUG ("svm_model_->param.eps = " << svm_model_->param.eps);
00496         MRSDEBUG ("svm_model_->param.p = " << svm_model_->param.p);
00497         MRSDEBUG ("svm_model_->param.shrinking = " << svm_model_->param.shrinking);
00498         MRSDEBUG ("svm_model_->param.probability = " << svm_model_->param.probability);
00499         MRSDEBUG ("svm_model_->param.nr_weight = " << svm_model_->param.nr_weight);
00500         MRSDEBUG ("svm_model_->param.weight_label = " << svm_model_->param.weight_label);
00501         MRSDEBUG ("svm_model_->param.weight = " << svm_model_->param.weight);
00502 
00503         int n = ctrl_nr_class_->to<mrs_natural>();
00504         int l = ctrl_l_->to<mrs_natural>();
00505         int m = n*(n-1)/2;
00506 
00507         svm_model_->nr_class = n;
00508         svm_model_->l = l;
00509 
00510         MRSDEBUG ("svm_model_->nr_class = " << svm_model_->nr_class);
00511         MRSDEBUG ("svm_model_->l = " << svm_model_->l);
00512 
00513 
00514 
00516 
00517         if (ctrl_rho_->to<realvec>().getSize()) //rho
00518         {
00519           svm_model_->rho = Malloc(double,m);
00520           for (int i=0; i<m; ++i)
00521             svm_model_->rho[i]= ctrl_rho_->to<realvec>()(i);
00522         }
00523         else
00524           svm_model_->rho = NULL;
00525 
00526         if (ctrl_probA_->to<realvec>().getSize()) //probA
00527         {
00528           svm_model_->probA = Malloc(double,m);
00529           for (int i=0; i<m; ++i)
00530             svm_model_->probA[i] = ctrl_probA_->to<realvec>()(i);
00531         }
00532         else
00533           svm_model_->probA = NULL;
00534 
00535         if (ctrl_probB_->to<realvec>().getSize()) //probB
00536         {
00537           svm_model_->probB = Malloc(double,m);
00538           for (int i=0; i<m; ++i)
00539             svm_model_->probB[i]=ctrl_probB_->to<realvec>()(i);
00540         }
00541         else
00542           svm_model_->probB = NULL;
00543 
00544         if (ctrl_label_->to<realvec>().getSize()) //label
00545         {
00546           svm_model_->label = Malloc(int,n);
00547           for (int i=0; i<n; ++i)
00548             svm_model_->label[i]
00549             = (int)ctrl_label_->to<realvec>()(i);
00550         }
00551         else
00552           svm_model_->label = NULL;
00553 
00554         if (ctrl_nSV_->to<realvec>().getSize()) //nr_sv
00555         {
00556           svm_model_->nSV = Malloc(int,n);
00557           for (int i=0; i<n; ++i)
00558             svm_model_->nSV[i]=(int)ctrl_nSV_->to<realvec>()(i);
00559         }
00560         else
00561           svm_model_->nSV = NULL;
00562 
00563 #ifdef MARSYAS_LOG_DEBUGS
00564         if (svm_model_->rho) {
00565           MRSDEBUG("svm_model_->rho =");
00566           for (int i=0; i<m; ++i)
00567             cout << " "<< svm_model_->rho[i];
00568           cout << endl;;
00569         }
00570 
00571         if (svm_model_->probA) {
00572           MRSDEBUG("svm_model_->probA =");
00573           for (int i=0; i<m; ++i)
00574             cout << " " << svm_model_->probA[i];
00575           cout << endl;;
00576         }
00577 
00578         if (svm_model_->probB) {
00579           MRSDEBUG("svm_model_->probB =");
00580           for (int i=0; i<m; ++i)
00581             cout << " " << svm_model_->probB[i];
00582           cout << endl;;
00583         }
00584 
00585         if (svm_model_->label) {
00586           MRSDEBUG("svm_model_->label =");
00587           for (int i=0; i<n; ++i)
00588             cout << " " << svm_model_->label[i];
00589           cout << endl;;
00590         }
00591 
00592         if (svm_model_->nSV) {
00593           MRSDEBUG("svm_model_->nSV =");
00594           for (int i=0; i<n; ++i)
00595             cout << " " << svm_model_->nSV[i];
00596           cout << endl;;
00597         }
00598 
00599 #endif
00600         --n;
00601         m = ctrl_SV_->to<realvec>().getCols();
00602 
00603         if (ctrl_sv_coef_->to<realvec>().getSize()) // sv_coef
00604         {
00605           svm_model_->sv_coef = Malloc(double *,n);
00606           for (int i=0; i<n; ++i)
00607             svm_model_->sv_coef[i] = Malloc(double,l);
00608           for (int i=0; i<l; ++i)
00609             for (int k=0; k<n; k++)
00610               svm_model_->sv_coef[k][i]
00611               =ctrl_sv_coef_->to<realvec>()(k, i);
00612         }
00613         svm_model_->sv_indices = Malloc(int, n);
00614         for (int i=0; i<l; i++) {
00615           // FIXME: interface with libsvm
00616         }
00617 
00618         if (ctrl_SV_->to<realvec>().getSize()) // SV
00619         {
00620           svm_model_->SV = Malloc(svm_node*,l);
00621           svm_node *x_space=NULL;
00622           if (l>0) {
00623             x_space = Malloc(svm_node, 2*m*l);
00624             num_nodes++;
00625           }
00626           int j=0;
00627           for (int i=0; i<l; ++i) {
00628             svm_model_->SV[i] = &x_space[j];
00629             for (int k = 0; k < m; k++) {
00630               x_space[j].index = k+1;
00631               x_space[j].value = ctrl_SV_->to<realvec>()(i, k);
00632               ++j;
00633             }
00634             x_space[j++].index = -1;
00635           }
00636         }
00637 
00638 #ifdef MARSYAS_LOG_DEBUGS
00639 
00640         MRSDEBUG ("svm_model_->SV = ");
00641 
00642         {
00643           for (int i=0; i<l; ++i) {
00644             for (int j=0; j<n; j++) {
00645               if (svm_model_->sv_coef)
00646                 cout <<svm_model_->sv_coef[j][i] << " ";
00647               if (svm_model_->SV) {
00648                 const svm_node *p = svm_model_->SV[i];
00649                 while (p->index != -1) {
00650                   cout << p->index << ":";
00651                   cout << p->value << " ";
00652                   p++;
00653                 }
00654               }
00655               cout << endl;;
00656             }
00657           }
00658         }
00659 #endif
00660         svm_model_->free_sv = 1;
00661         MRSDEBUG ("svm_model_->free_sv = " << svm_model_->free_sv);
00662         trained_ = true;
00663 
00664 #ifdef MARSYAS_LOG_DEBUGS
00665         {
00666           MRSDEBUG ("mini/maxi : ");
00667           realvec mini = ctrl_minimums_->to<mrs_realvec>();
00668           realvec maxi = ctrl_maximums_->to<mrs_realvec>();
00669           for (int i=0; i<inObservations_ -1; ++i)
00670             cout << "mini(" << i << ")" << mini(i) << "  maxi("
00671                  << i << ")" << maxi(i) << endl;
00672         }
00673 #endif
00674       } // if(!was_training)
00675     }// if(!trained)
00676 
00677 
00679 
00680     struct svm_node* xv = new svm_node[inObservations_];
00681     double* probs = new double[svm_model_->nr_class];
00682 
00683     // Get the minimum and maximum values to which the
00684     // training set was normalized.
00685     realvec mini = ctrl_minimums_->to<mrs_realvec>();
00686     realvec maxi = ctrl_maximums_->to<mrs_realvec>();
00687 
00688     // Scale our input the same as our trainingset
00689     for (int i=0; i<inObservations_ -1; ++i)
00690       in(i, 0) = (in(i, 0) - mini(i)) / (maxi(i) - mini(i));
00691 
00692     // Copy our input to an SV structure
00693     for (int j=0; j < inObservations_; ++j) {
00694       if (j < inObservations_ -1) {
00695         xv[j].index = j+1;
00696         xv[j].value = in(j, 0);
00697       } else {
00698         // The last index value in the SV is always set to -1
00699         xv[j].index = -1;
00700         xv[j].value = 0.0;
00701       }
00702     }
00703 
00704     double prediction = 0.0;
00705 
00706     if (ctrl_probability_->to<mrs_bool>())
00707       prediction = svm_predict_probability(svm_model_, xv, probs);
00708     else
00709       prediction = svm_predict(svm_model_, xv);
00710 
00711 
00712     // Output
00713     if (getctrl("mrs_bool/output_classPerms")->isTrue()) {
00714       for (int i=0; i < svm_model_->nr_class; ++i) {
00715         out(2 + classPerms_[i], 0) = probs[i];
00716       }
00717     }
00718 
00719     out(0,0) = (mrs_real)prediction;
00720     out(1,0) = in(inObservations_-1,0);
00721 
00722 
00723     // Cleanup
00724     delete [] xv;
00725     delete [] probs;
00726   }
00727   was_training_ = training_;
00728 }