SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MulticlassOneVsOneStrategy.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Written (W) 2013 Shell Hu and Heiko Strathmann
00009  * Copyright (C) 2012 Chiyuan Zhang
00010  */
00011 
00012 #include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
00013 #include <shogun/labels/BinaryLabels.h>
00014 #include <shogun/labels/MulticlassLabels.h>
00015 
00016 using namespace shogun;
00017 
00018 CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy()
00019     :CMulticlassStrategy(), m_num_machines(0), m_num_samples(SGVector<int32_t>())
00020 {
00021     register_parameters();
00022 }
00023 
00024 CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy(EProbHeuristicType prob_heuris)
00025     :CMulticlassStrategy(prob_heuris), m_num_machines(0), m_num_samples(SGVector<int32_t>())
00026 {
00027     register_parameters();
00028 }
00029 
00030 void CMulticlassOneVsOneStrategy::register_parameters()
00031 {
00032     //SG_ADD(&m_num_samples, "num_samples", "Number of samples in each training machine", MS_NOT_AVAILABLE);
00033     SG_WARNING("%s::CMulticlassOneVsOneStrategy(): register parameters!\n", get_name());
00034 }
00035 
00036 void CMulticlassOneVsOneStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels)
00037 {
00038     CMulticlassStrategy::train_start(orig_labels, train_labels);
00039     m_num_machines=m_num_classes*(m_num_classes-1)/2;
00040 
00041     m_train_pair_idx_1 = 0;
00042     m_train_pair_idx_2 = 1;
00043 
00044     m_num_samples.resize_vector(m_num_machines);
00045 }
00046 
00047 bool CMulticlassOneVsOneStrategy::train_has_more()
00048 {
00049     return m_train_iter < m_num_machines;
00050 }
00051 
00052 SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next()
00053 {
00054     CMulticlassStrategy::train_prepare_next();
00055 
00056     SGVector<int32_t> subset(m_orig_labels->get_num_labels());
00057     int32_t tot=0;
00058     for (int32_t k=0; k < m_orig_labels->get_num_labels(); ++k)
00059     {
00060         if (((CMulticlassLabels*) m_orig_labels)->get_int_label(k)==m_train_pair_idx_1)
00061         {
00062             ((CBinaryLabels*) m_train_labels)->set_label(k, +1.0);
00063             subset[tot]=k;
00064             tot++;
00065         }
00066         else if (((CMulticlassLabels*) m_orig_labels)->get_int_label(k)==m_train_pair_idx_2)
00067         {
00068             ((CBinaryLabels*) m_train_labels)->set_label(k, -1.0);
00069             subset[tot]=k;
00070             tot++;
00071         }
00072     }
00073 
00074     m_train_pair_idx_2++;
00075     if (m_train_pair_idx_2 >= m_num_classes)
00076     {
00077         m_train_pair_idx_1++;
00078         m_train_pair_idx_2=m_train_pair_idx_1+1;
00079     }
00080 
00081     // collect num samples each machine
00082     m_num_samples[m_train_iter-1] = tot;
00083 
00084     subset.resize_vector(tot);
00085     return subset;
00086 }
00087 
00088 int32_t CMulticlassOneVsOneStrategy::decide_label(SGVector<float64_t> outputs)
00089 {
00090     // if OVO with prob outputs, find max posterior
00091     if (outputs.vlen==m_num_classes)
00092         return SGVector<float64_t>::arg_max(outputs.vector, 1, outputs.vlen);
00093 
00094     int32_t s=0;
00095     SGVector<int32_t> votes(m_num_classes);
00096     SGVector<int32_t> dec_vals(m_num_classes);
00097     votes.zero();
00098     dec_vals.zero();
00099 
00100     for (int32_t i=0; i<m_num_classes; i++)
00101     {
00102         for (int32_t j=i+1; j<m_num_classes; j++)
00103         {
00104             if (outputs[s]>0)
00105             {
00106                 votes[i]++;
00107                 dec_vals[i] += CMath::abs(outputs[s]);
00108             }
00109             else
00110             {
00111                 votes[j]++;
00112                 dec_vals[j] += CMath::abs(outputs[s]);
00113             }
00114             s++;
00115         }
00116     }
00117 
00118     int32_t i_max=0;
00119     int32_t vote_max=-1;
00120     float64_t dec_val_max=-1;
00121 
00122     for (int32_t i=0; i < m_num_classes; ++i)
00123     {
00124         if (votes[i] > vote_max)
00125         {
00126             i_max = i;
00127             vote_max = votes[i];
00128             dec_val_max = dec_vals[i];
00129         }
00130         else if (votes[i] == vote_max)
00131         {
00132             if (dec_vals[i] > dec_val_max)
00133             {
00134                 i_max = i;
00135                 dec_val_max = dec_vals[i];
00136             }
00137         }
00138     }
00139 
00140     return i_max;
00141 }
00142 
00143 void CMulticlassOneVsOneStrategy::rescale_outputs(SGVector<float64_t> outputs)
00144 {
00145     if (m_num_machines < 1)
00146         return;
00147 
00148     SGVector<int32_t> indx1(m_num_machines);
00149     SGVector<int32_t> indx2(m_num_machines);
00150 
00151     int32_t tot = 0;
00152     for (int32_t j=0; j<m_num_classes; j++)
00153     {
00154         for (int32_t k=j+1; k<m_num_classes; k++)
00155         {
00156             indx1[tot] = j;
00157             indx2[tot] = k;
00158             tot++;
00159         }
00160     }
00161 
00162     if(tot!=m_num_machines)
00163         SG_ERROR("%s::rescale_output(): size(outputs) is not num_machines.\n", get_name());
00164 
00165     switch(get_prob_heuris_type())
00166     {
00167         case OVO_PRICE:
00168             rescale_heuris_price(outputs,indx1,indx2);
00169             break;
00170         case OVO_HASTIE:
00171             rescale_heuris_hastie(outputs,indx1,indx2);
00172             break;
00173         case OVO_HAMAMURA:
00174             rescale_heuris_hamamura(outputs,indx1,indx2);
00175             break;
00176         case PROB_HEURIS_NONE:
00177             break;
00178         default:
00179             SG_ERROR("%s::rescale_outputs(): Unknown OVO probability heuristic type!\n", get_name());
00180             break;
00181     }
00182 }
00183 
00184 void CMulticlassOneVsOneStrategy::rescale_heuris_price(SGVector<float64_t> outputs,
00185         const SGVector<int32_t> indx1, const SGVector<int32_t> indx2)
00186 {
00187     if (m_num_machines != outputs.vlen)
00188     {
00189         SG_ERROR("%s::rescale_heuris_price(): size(outputs) = %d != m_num_machines = %d\n",
00190                 get_name(), outputs.vlen, m_num_machines);
00191     }
00192 
00193     SGVector<float64_t> new_outputs(m_num_classes);
00194     new_outputs.zero();
00195 
00196     for (int32_t j=0; j<m_num_classes; j++)
00197     {
00198         for (int32_t m=0; m<m_num_machines; m++)
00199         {
00200             if (indx1[m]==j)
00201                 new_outputs[j] += 1.0 / (outputs[m]+1E-12);
00202             if (indx2[m]==j)
00203                 new_outputs[j] += 1.0 / (1.0-outputs[m]+1E-12);
00204         }
00205 
00206         new_outputs[j] = 1.0 / (new_outputs[j] - m_num_classes + 2);
00207     }
00208 
00209     //outputs.resize_vector(m_num_classes);
00210 
00211     float64_t norm = SGVector<float64_t>::sum(new_outputs);
00212     for (int32_t i=0; i<new_outputs.vlen; i++)
00213         outputs[i] = new_outputs[i] / norm;
00214 }
00215 
00216 void CMulticlassOneVsOneStrategy::rescale_heuris_hastie(SGVector<float64_t> outputs,
00217         const SGVector<int32_t> indx1, const SGVector<int32_t> indx2)
00218 {
00219     if (m_num_machines != outputs.vlen)
00220     {
00221         SG_ERROR("%s::rescale_heuris_hastie(): size(outputs) = %d != m_num_machines = %d\n",
00222                 get_name(), outputs.vlen, m_num_machines);
00223     }
00224 
00225     SGVector<float64_t> new_outputs(m_num_classes);
00226     new_outputs.zero();
00227 
00228     for (int32_t j=0; j<m_num_classes; j++)
00229     {
00230         for (int32_t m=0; m<m_num_machines; m++)
00231         {
00232             if (indx1[m]==j)
00233                 new_outputs[j] += outputs[m];
00234             if (indx2[m]==j)
00235                 new_outputs[j] += 1.0-outputs[m];
00236         }
00237 
00238         new_outputs[j] *= 2.0 / (m_num_classes * (m_num_classes - 1));
00239         new_outputs[j] += 1E-10;
00240     }
00241 
00242     SGVector<float64_t> mu(m_num_machines);
00243     SGVector<float64_t> prev_outputs(m_num_classes);
00244     float64_t gap = 1.0;
00245 
00246     while (gap > 1E-12)
00247     {
00248         prev_outputs = new_outputs.clone();
00249 
00250         for (int32_t m=0; m<m_num_machines; m++)
00251             mu[m] = new_outputs[indx1[m]] / (new_outputs[indx1[m]] + new_outputs[indx2[m]]);
00252 
00253         for (int32_t j=0; j<m_num_classes; j++)
00254         {
00255             float64_t numerator = 0.0;
00256             float64_t denominator = 0.0;
00257             for (int32_t m=0; m<m_num_machines; m++)
00258             {
00259                 if (indx1[m]==j)
00260                 {
00261                     numerator += m_num_samples[m] * outputs[m];
00262                     denominator += m_num_samples[m] * mu[m];
00263                 }
00264 
00265                 if (indx2[m]==j)
00266                 {
00267                     numerator += m_num_samples[m] * (1.0-outputs[m]);
00268                     denominator += m_num_samples[m] * (1.0-mu[m]);
00269                 }
00270             }
00271 
00272             // update posterior
00273             new_outputs[j] *= numerator / denominator;
00274         }
00275 
00276         float64_t norm = SGVector<float64_t>::sum(new_outputs);
00277         for (int32_t i=0; i<new_outputs.vlen; i++)
00278             new_outputs[i] /= norm;
00279 
00280         // gap is Euclidean distance
00281         for (int32_t i=0; i<new_outputs.vlen; i++)
00282             prev_outputs[i] -= new_outputs[i];
00283 
00284         gap = SGVector<float64_t>::qsq(prev_outputs.vector, prev_outputs.vlen, 2);
00285         SG_DEBUG("[Hastie's heuristic] gap = %.12f\n", gap);
00286     }
00287 
00288     for (int32_t i=0; i<new_outputs.vlen; i++)
00289         outputs[i] = new_outputs[i];
00290 }
00291 
00292 void CMulticlassOneVsOneStrategy::rescale_heuris_hamamura(SGVector<float64_t> outputs,
00293         const SGVector<int32_t> indx1, const SGVector<int32_t> indx2)
00294 {
00295     if (m_num_machines != outputs.vlen)
00296     {
00297         SG_ERROR("%s::rescale_heuris_hamamura(): size(outputs) = %d != m_num_machines = %d\n",
00298                 get_name(), outputs.vlen, m_num_machines);
00299     }
00300 
00301     SGVector<float64_t> new_outputs(m_num_classes);
00302     SGVector<float64_t>::fill_vector(new_outputs.vector, new_outputs.vlen, 1.0);
00303 
00304     for (int32_t j=0; j<m_num_classes; j++)
00305     {
00306         for (int32_t m=0; m<m_num_machines; m++)
00307         {
00308             if (indx1[m]==j)
00309                 new_outputs[j] *= outputs[m];
00310             if (indx2[m]==j)
00311                 new_outputs[j] *= 1-outputs[m];
00312         }
00313 
00314         new_outputs[j] += 1E-10;
00315     }
00316 
00317     float64_t norm = SGVector<float64_t>::sum(new_outputs);
00318 
00319     for (int32_t i=0; i<new_outputs.vlen; i++)
00320         outputs[i] = new_outputs[i] / norm;
00321 }
00322 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation