Marsyas  0.6.0-alpha
/usr/src/RPM/BUILD/marsyas-0.6.0/src/marsyas/marsystems/BeatPhase.cpp
Go to the documentation of this file.
00001 /*
00002 ** Copyright (C) 1998-2010 George Tzanetakis <gtzan@cs.uvic.ca>
00003 **
00004 ** This program is free software; you can redistribute it and/or modify
00005 ** it under the terms of the GNU General Public License as published by
00006 ** the Free Software Foundation; either version 2 of the License, or
00007 ** (at your option) any later version.
00008 **
00009 ** This program is distributed in the hope that it will be useful,
00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00012 ** GNU General Public License for more details.
00013 **
00014 ** You should have received a copy of the GNU General Public License
00015 ** along with this program; if not, write to the Free Software
00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00017 */
00018 
00019 #include "BeatPhase.h"
00020 #include "../common_source.h"
00021 
00022 using std::ostringstream;
00023 using std::cout;
00024 using std::endl;
00025 
00026 using namespace Marsyas;
00027 
00028 BeatPhase::BeatPhase(mrs_string name):MarSystem("BeatPhase", name)
00029 {
00030   addControls();
00031   sampleCount_ = 0;
00032   current_beat_location_ = 0.0;
00033   pinSamples_ = 0;
00034 
00035 
00036 }
00037 
00038 BeatPhase::BeatPhase(const BeatPhase& a) : MarSystem(a)
00039 {
00040   ctrl_tempo_candidates_ = getctrl("mrs_realvec/tempo_candidates");
00041 
00042   ctrl_tempos_ = getctrl("mrs_realvec/tempos");
00043   ctrl_temposcores_ = getctrl("mrs_realvec/tempo_scores");
00044   ctrl_phase_tempo_ = getctrl("mrs_real/phase_tempo");
00045   ctrl_ground_truth_tempo_ = getctrl("mrs_real/ground_truth_tempo");
00046   ctrl_beats_ = getctrl("mrs_realvec/beats");
00047   ctrl_bhopSize_ = getctrl("mrs_natural/bhopSize");
00048   ctrl_bwinSize_ = getctrl("mrs_natural/bwinSize");
00049   ctrl_timeDomain_ = getctrl("mrs_realvec/timeDomain");
00050   ctrl_nCandidates_ = getctrl("mrs_natural/nCandidates");
00051   ctrl_beatOutput_ = getctrl("mrs_realvec/beatOutput");
00052   ctrl_factor_ = getctrl("mrs_real/factor");
00053 
00054   sampleCount_ = 0;
00055   current_beat_location_ = 0.0;
00056   pinSamples_ = 0;
00057 
00058 }
00059 
00060 BeatPhase::~BeatPhase()
00061 {
00062 }
00063 
00064 MarSystem*
00065 BeatPhase::clone() const
00066 {
00067   return new BeatPhase(*this);
00068 }
00069 
00070 void
00071 BeatPhase::addControls()
00072 {
00073   mrs_natural nCandidates = 8;
00074 
00075   //Add specific controls needed by this MarSystem.
00076   addctrl("mrs_realvec/tempo_candidates", realvec(nCandidates), ctrl_tempo_candidates_);
00077   addctrl("mrs_realvec/tempos", realvec(nCandidates), ctrl_tempos_);
00078   addctrl("mrs_realvec/tempo_scores", realvec(nCandidates), ctrl_temposcores_);
00079 
00080   addctrl("mrs_real/phase_tempo", 100.0, ctrl_phase_tempo_);
00081   addctrl("mrs_real/ground_truth_tempo", 100.0, ctrl_ground_truth_tempo_);
00082   addctrl("mrs_realvec/beats", realvec(), ctrl_beats_);
00083   addctrl("mrs_natural/bhopSize", 64, ctrl_bhopSize_);
00084   addctrl("mrs_natural/bwinSize", 1024, ctrl_bwinSize_);
00085   addctrl("mrs_realvec/timeDomain", realvec(), ctrl_timeDomain_);
00086   addctrl("mrs_natural/nCandidates", nCandidates, ctrl_nCandidates_);
00087   setctrlState("mrs_natural/nCandidates", true);
00088   addctrl("mrs_realvec/beatOutput", realvec(), ctrl_beatOutput_);
00089   addctrl("mrs_real/factor", 4.0, ctrl_factor_);
00090 }
00091 
00092 void
00093 BeatPhase::myUpdate(MarControlPtr sender)
00094 {
00095   // no need to do anything BeatPhase-specific in myUpdate
00096   MarSystem::myUpdate(sender);
00097 
00098 
00099   inSamples_ = getctrl("mrs_natural/inSamples")->to<mrs_natural>();
00100   mrs_natural nCandidates = getctrl("mrs_natural/nCandidates")->to<mrs_natural>();
00101   factor_ = getctrl("mrs_real/factor")->to<mrs_real>();
00102 
00103 
00104   MarControlAccessor acc_t(ctrl_tempos_);
00105   mrs_realvec& tempos = acc_t.to<mrs_realvec>();
00106   tempos.stretch(nCandidates);
00107 
00108   MarControlAccessor acc_ts(ctrl_temposcores_);
00109   mrs_realvec& temposcores = acc_ts.to<mrs_realvec>();
00110   temposcores.stretch(nCandidates);
00111 
00112   MarControlAccessor acc_tc(ctrl_tempo_candidates_);
00113   mrs_realvec& tempocandidates = acc_tc.to<mrs_realvec>();
00114   tempocandidates.stretch(nCandidates * 2);
00115 
00116 
00117 
00118   if (pinSamples_ != inSamples_)
00119   {
00120     {
00121       MarControlAccessor acc(ctrl_beats_);
00122       mrs_realvec& beats = acc.to<mrs_realvec>();
00123       beats.create(inSamples_);
00124 
00125       // Output all the beats that are detected via a MarControl
00126       MarControlAccessor beatOutputAcc(ctrl_beatOutput_);
00127       mrs_realvec& beatOutput = beatOutputAcc.to<mrs_realvec>();
00128       beatOutput.create(inSamples_);
00129     }
00130   }
00131 
00132   pinSamples_ = inSamples_;
00133 
00134 
00135 
00136 }
00137 
00138 
00139 void
00140 BeatPhase::myProcess(realvec& in, realvec& out)
00141 {
00142   mrs_natural o,t;
00143 
00144 
00145 
00146   // mrs_real ground_truth_tempo = ctrl_ground_truth_tempo_->to<mrs_real>();
00147   // used for evaluation experiments
00148 
00149 
00150   // The tempo candidates and their scores
00151   MarControlAccessor acctc(ctrl_tempo_candidates_);
00152   mrs_realvec& tempo_candidates = acctc.to<mrs_realvec>();
00153 
00154 
00155   MarControlAccessor acct(ctrl_tempos_);
00156   mrs_realvec& tempos = acct.to<mrs_realvec>();
00157   MarControlAccessor accts(ctrl_temposcores_);
00158   mrs_realvec& tempo_scores = accts.to<mrs_realvec>();
00159 
00160 
00161   // Demultiplex candidates and scores
00162   for (int i=0; i < tempo_candidates.getSize()/2; i++)
00163   {
00164     tempos(i) = tempo_candidates(2*i+1) / factor_;
00165     tempo_scores(i) = tempo_candidates(2*i);
00166   }
00167 
00168   // normalize to pdf
00169   tempo_scores /= tempo_scores.sum();
00170 
00171   // holds the tempo scores based on cross-correlation with pulse train
00172   mrs_realvec onset_scores;
00173   onset_scores.create(tempo_scores.getSize());
00174   // holds the best matching phase for each tempo candidate
00175   mrs_realvec tempo_phases;
00176   tempo_phases.create(tempo_scores.getSize());
00177 
00178   /*
00179   // make sure the tempo candidates are reasonable
00180   for (int k=0; k < tempos.getSize(); k++)
00181   {
00182     if (tempos(k) < 50.0)
00183     tempos(k) = 0;
00184     if (tempos(k) > 200)
00185     tempos(k) = 0;
00186 
00187   }
00188   */
00189 
00190   // The winSize and hopSize of the onset strength function
00191   // needed to output correct beat location times
00192   // mrs_natural bwinSize = ctrl_bwinSize_->to<mrs_natural>();
00193   mrs_natural bhopSize = ctrl_bhopSize_->to<mrs_natural>();
00194 
00195   MarControlAccessor acc(ctrl_beats_);
00196   mrs_realvec& beats = acc.to<mrs_realvec>();
00197 
00198 
00199   for (o=0; o < inObservations_; o++)
00200   {
00201     for (t = 0; t < inSamples_; t++)
00202     {
00203       out (o,t) = in(o,t);
00204       beats (o,t) = 0.0;
00205     }
00206   }
00207 
00208   mrs_real tempo;
00209   mrs_real period;
00210   mrs_natural phase;
00211   mrs_real cross_correlation = 0.0;
00212   mrs_real max_crco=0.0;
00213   mrs_natural max_phase = 0;
00214   mrs_realvec phase_correlations;
00215 
00216 
00217   // loop for cross-correlating
00218   // pulse trains with onset strength function
00219   // for each tempo shift the pulse train until the best match is found
00220   // tempo_scores holds the values of the cross-correlation
00221   // and onset_scores holds the variance of cross-correlation among different phases
00222   // for a particular tempo candidate
00223   for (o=0; o < inObservations_; o++)
00224   {
00225     for (int k=0; k < tempos.getSize(); k++)
00226     {
00227       max_crco = 0.0;
00228 
00229       tempo = tempos(k);
00230       period = 2.0 * osrate_ * 60.0 / tempo; // flux hopSize is half the winSize
00231       mrs_natural period_int = (mrs_natural)(period+0.5);
00232 
00233 
00234       if (period_int > 1)
00235       {
00236         //cout<<tempo<<"\t"<<period_int;
00237         phase_correlations.create( period_int );
00238 
00239         for (phase=inSamples_-1; phase > inSamples_-1-period_int; phase--)
00240         {
00241           cross_correlation = 0.0;
00242           // correlate with pulse train with half-beats and double beats
00243           for (int b=0; b < 4; b++)
00244           {
00245             mrs_natural temp_t;
00246             temp_t = phase - b * period_int;
00247 
00248             // 4 beats
00249             if (temp_t >= 0) {
00250               cross_correlation += in(o, temp_t);
00251               //cout<<"\t"<<temp_t;
00252             }
00253 
00254             // slow down by 2.0
00255             temp_t = phase - b * period_int * 2;
00256             if (temp_t >= 0) {
00257               cross_correlation += 0.5 * in(o, temp_t);
00258               //cout<<"\t"<<temp_t;
00259             }
00260 
00261             // slow down by 3
00262             temp_t = phase - b * period_int * 3 / 2;
00263             if (temp_t >= 0) {
00264               cross_correlation += 0.5 * in(o, temp_t);
00265               //cout<<"\t"<<temp_t;
00266             }
00267             //cout<<endl;
00268 
00269 
00270           }
00271 
00272           // quarter beats
00273           /* for (int b = 0; b < 8; b++)
00274           {
00275             mrs_natural temp_t;
00276             temp_t = phase - b * 0.25 * period;
00277             if (temp_t >= 0)
00278             {
00279               cross_correlation += (0.2 * in(o, temp_t));
00280             }
00281           }
00282           */
00283 
00284           phase_correlations(inSamples_-1-phase) = cross_correlation;
00285           if (cross_correlation > max_crco)
00286           {
00287             max_crco = cross_correlation;
00288             max_phase = phase;
00289           }
00290 
00291         }
00292         onset_scores(k) = phase_correlations.var();
00293         tempo_scores(k) = max_crco;
00294         tempo_phases(k) = max_phase;
00295         beats.setval(0.0);
00296 
00297         //printf("\t%f\t%f", max_crco, onset_scores(k));
00298         //cout<<endl;
00299       }
00300     }
00301   }
00302 
00303 
00304   // renormalize scores
00305   onset_scores /= onset_scores.sum();
00306   tempo_scores /= tempo_scores.sum();
00307 
00308   // combine the cross-correlation score with the variance
00309   for (int i=0; i < tempo_scores.getSize(); i++)
00310     tempo_scores(i) = onset_scores(i) + tempo_scores(i);
00311 
00312   // renormalize
00313   tempo_scores /= tempo_scores.sum();
00314 
00315   // pick the maximum scoring tempo candidate
00316   mrs_real max_score = 0.0;
00317   int max_i=0;
00318   for (int i= 0; i < tempos.getSize(); i++)
00319   {
00320     if (tempo_scores(i) > max_score)
00321     {
00322       max_score = tempo_scores(i);
00323       max_i = i;
00324     }
00325   }
00326 
00327   // return the best tempo candidate and the
00328   // corresponding score in both the tempo vector
00329   // as well as the control phase_tempo
00330   mrs_real swap_tempo = tempos(0);
00331   mrs_real swap_score = tempo_scores(0);
00332   tempos(0) = tempos(max_i);
00333   tempo_scores(0) = tempo_scores(max_i);
00334 
00335   if (max_i != 0) {
00336     tempos(max_i) = swap_tempo;
00337     tempo_scores(max_i) = swap_score;
00338   }
00339 
00340   max_score = 0.0;
00341   max_i=0;
00342   for (int i= 1; i < tempos.getSize(); i++)
00343   {
00344     if (tempo_scores(i) > max_score)
00345     {
00346       max_score = tempo_scores(i);
00347       max_i = i;
00348     }
00349   }
00350 
00351   swap_tempo = tempos(1);
00352   swap_score = tempo_scores(1);
00353   tempos(1) = tempos(max_i);
00354   tempo_scores(1) = tempo_scores(max_i);
00355   if (max_i != 1) {
00356     tempos(max_i) = swap_tempo;
00357     tempo_scores(max_i) = swap_score;
00358   }
00359 
00360   ctrl_phase_tempo_->setValue(tempos(max_i));
00361 
00362 
00363   // select a tempo for the beat locations
00364   // with doubling heuristic if tempo < 75 BPM
00365   tempo = tempos(0);
00366   if (tempo < 70.0)
00367     tempo = tempo * 2;
00368   mrs_real beat_length = 60.0 / tempo;
00369 
00370   if (tempo >= 50)
00371     period = 2.0 * osrate_ * 60.0 / tempo; // flux hopSize is half the winSize
00372   else
00373     period = 0;
00374 
00375   period = (mrs_natural)(period+0.5);
00376   // Place the beats in the right location in the onset detection function
00377   for (int b=0; b < 4; b++) {
00378     mrs_natural temp_t = (mrs_natural) (tempo_phases(max_i) - b * period);
00379     if (temp_t >= 0) {
00380       beats(0,temp_t) = -0.5;
00381     }
00382   }
00383 
00384 
00385   //mrs_natural prev_sample_count;
00386   //prev_sample_count = sampleCount_;
00387 
00388   // Output all the detected beats to it's own MarControl
00389   int total_beats = 0;
00390   MarControlAccessor beatOutputAcc(ctrl_beatOutput_);
00391   mrs_realvec& beatOutput = beatOutputAcc.to<mrs_realvec>();
00392   for (t = 0; t < inSamples_; t++)
00393   {
00394     beatOutput(t) = 0.0;
00395   }
00396 
00397   // output the beats
00398   mrs_real beat_location;
00399   for (int t = inSamples_-1-2*bhopSize; t < inSamples_; t++)
00400   {
00401     if (beats(0,t) == -0.5)
00402     {
00403       beat_location = (sampleCount_ + t -(inSamples_-1 -bhopSize)) / (2.0 * osrate_);
00404       if ((beat_location > current_beat_location_)&&((beat_location - current_beat_location_) > beat_length * 0.75))
00405       {
00406         if (ctrl_verbose_->isTrue()) {
00407           MRSMSG(beat_location << "\t"
00408                  << beat_location + 0.02 << " b");
00409         }
00410         beatOutput(total_beats) = beat_location;
00411         current_beat_location_ = beat_location;
00412         total_beats++;
00413       }
00414     }
00415   }
00416   sampleCount_ += bhopSize;
00417 }