SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Written (W) 2013 Shell Hu and Heiko Strathmann 00009 * Copyright (C) 2012 Chiyuan Zhang 00010 */ 00011 00012 #include <shogun/multiclass/MulticlassOneVsOneStrategy.h> 00013 #include <shogun/labels/BinaryLabels.h> 00014 #include <shogun/labels/MulticlassLabels.h> 00015 00016 using namespace shogun; 00017 00018 CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy() 00019 :CMulticlassStrategy(), m_num_machines(0), m_num_samples(SGVector<int32_t>()) 00020 { 00021 register_parameters(); 00022 } 00023 00024 CMulticlassOneVsOneStrategy::CMulticlassOneVsOneStrategy(EProbHeuristicType prob_heuris) 00025 :CMulticlassStrategy(prob_heuris), m_num_machines(0), m_num_samples(SGVector<int32_t>()) 00026 { 00027 register_parameters(); 00028 } 00029 00030 void CMulticlassOneVsOneStrategy::register_parameters() 00031 { 00032 //SG_ADD(&m_num_samples, "num_samples", "Number of samples in each training machine", MS_NOT_AVAILABLE); 00033 SG_WARNING("%s::CMulticlassOneVsOneStrategy(): register parameters!\n", get_name()); 00034 } 00035 00036 void CMulticlassOneVsOneStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels) 00037 { 00038 CMulticlassStrategy::train_start(orig_labels, train_labels); 00039 m_num_machines=m_num_classes*(m_num_classes-1)/2; 00040 00041 m_train_pair_idx_1 = 0; 00042 m_train_pair_idx_2 = 1; 00043 00044 m_num_samples.resize_vector(m_num_machines); 00045 } 00046 00047 bool CMulticlassOneVsOneStrategy::train_has_more() 00048 { 00049 return m_train_iter < m_num_machines; 00050 } 00051 00052 SGVector<int32_t> CMulticlassOneVsOneStrategy::train_prepare_next() 00053 { 00054 CMulticlassStrategy::train_prepare_next(); 00055 00056 SGVector<int32_t> subset(m_orig_labels->get_num_labels()); 00057 int32_t tot=0; 00058 for (int32_t k=0; k < m_orig_labels->get_num_labels(); ++k) 00059 { 00060 if (((CMulticlassLabels*) m_orig_labels)->get_int_label(k)==m_train_pair_idx_1) 00061 { 00062 ((CBinaryLabels*) m_train_labels)->set_label(k, +1.0); 00063 subset[tot]=k; 00064 tot++; 00065 } 00066 else if (((CMulticlassLabels*) m_orig_labels)->get_int_label(k)==m_train_pair_idx_2) 00067 { 00068 ((CBinaryLabels*) m_train_labels)->set_label(k, -1.0); 00069 subset[tot]=k; 00070 tot++; 00071 } 00072 } 00073 00074 m_train_pair_idx_2++; 00075 if (m_train_pair_idx_2 >= m_num_classes) 00076 { 00077 m_train_pair_idx_1++; 00078 m_train_pair_idx_2=m_train_pair_idx_1+1; 00079 } 00080 00081 // collect num samples each machine 00082 m_num_samples[m_train_iter-1] = tot; 00083 00084 subset.resize_vector(tot); 00085 return subset; 00086 } 00087 00088 int32_t CMulticlassOneVsOneStrategy::decide_label(SGVector<float64_t> outputs) 00089 { 00090 // if OVO with prob outputs, find max posterior 00091 if (outputs.vlen==m_num_classes) 00092 return SGVector<float64_t>::arg_max(outputs.vector, 1, outputs.vlen); 00093 00094 int32_t s=0; 00095 SGVector<int32_t> votes(m_num_classes); 00096 SGVector<int32_t> dec_vals(m_num_classes); 00097 votes.zero(); 00098 dec_vals.zero(); 00099 00100 for (int32_t i=0; i<m_num_classes; i++) 00101 { 00102 for (int32_t j=i+1; j<m_num_classes; j++) 00103 { 00104 if (outputs[s]>0) 00105 { 00106 votes[i]++; 00107 dec_vals[i] += CMath::abs(outputs[s]); 00108 } 00109 else 00110 { 00111 votes[j]++; 00112 dec_vals[j] += CMath::abs(outputs[s]); 00113 } 00114 s++; 00115 } 00116 } 00117 00118 int32_t i_max=0; 00119 int32_t vote_max=-1; 00120 float64_t dec_val_max=-1; 00121 00122 for (int32_t i=0; i < m_num_classes; ++i) 00123 { 00124 if (votes[i] > vote_max) 00125 { 00126 i_max = i; 00127 vote_max = votes[i]; 00128 dec_val_max = dec_vals[i]; 00129 } 00130 else if (votes[i] == vote_max) 00131 { 00132 if (dec_vals[i] > dec_val_max) 00133 { 00134 i_max = i; 00135 dec_val_max = dec_vals[i]; 00136 } 00137 } 00138 } 00139 00140 return i_max; 00141 } 00142 00143 void CMulticlassOneVsOneStrategy::rescale_outputs(SGVector<float64_t> outputs) 00144 { 00145 if (m_num_machines < 1) 00146 return; 00147 00148 SGVector<int32_t> indx1(m_num_machines); 00149 SGVector<int32_t> indx2(m_num_machines); 00150 00151 int32_t tot = 0; 00152 for (int32_t j=0; j<m_num_classes; j++) 00153 { 00154 for (int32_t k=j+1; k<m_num_classes; k++) 00155 { 00156 indx1[tot] = j; 00157 indx2[tot] = k; 00158 tot++; 00159 } 00160 } 00161 00162 if(tot!=m_num_machines) 00163 SG_ERROR("%s::rescale_output(): size(outputs) is not num_machines.\n", get_name()); 00164 00165 switch(get_prob_heuris_type()) 00166 { 00167 case OVO_PRICE: 00168 rescale_heuris_price(outputs,indx1,indx2); 00169 break; 00170 case OVO_HASTIE: 00171 rescale_heuris_hastie(outputs,indx1,indx2); 00172 break; 00173 case OVO_HAMAMURA: 00174 rescale_heuris_hamamura(outputs,indx1,indx2); 00175 break; 00176 case PROB_HEURIS_NONE: 00177 break; 00178 default: 00179 SG_ERROR("%s::rescale_outputs(): Unknown OVO probability heuristic type!\n", get_name()); 00180 break; 00181 } 00182 } 00183 00184 void CMulticlassOneVsOneStrategy::rescale_heuris_price(SGVector<float64_t> outputs, 00185 const SGVector<int32_t> indx1, const SGVector<int32_t> indx2) 00186 { 00187 if (m_num_machines != outputs.vlen) 00188 { 00189 SG_ERROR("%s::rescale_heuris_price(): size(outputs) = %d != m_num_machines = %d\n", 00190 get_name(), outputs.vlen, m_num_machines); 00191 } 00192 00193 SGVector<float64_t> new_outputs(m_num_classes); 00194 new_outputs.zero(); 00195 00196 for (int32_t j=0; j<m_num_classes; j++) 00197 { 00198 for (int32_t m=0; m<m_num_machines; m++) 00199 { 00200 if (indx1[m]==j) 00201 new_outputs[j] += 1.0 / (outputs[m]+1E-12); 00202 if (indx2[m]==j) 00203 new_outputs[j] += 1.0 / (1.0-outputs[m]+1E-12); 00204 } 00205 00206 new_outputs[j] = 1.0 / (new_outputs[j] - m_num_classes + 2); 00207 } 00208 00209 //outputs.resize_vector(m_num_classes); 00210 00211 float64_t norm = SGVector<float64_t>::sum(new_outputs); 00212 for (int32_t i=0; i<new_outputs.vlen; i++) 00213 outputs[i] = new_outputs[i] / norm; 00214 } 00215 00216 void CMulticlassOneVsOneStrategy::rescale_heuris_hastie(SGVector<float64_t> outputs, 00217 const SGVector<int32_t> indx1, const SGVector<int32_t> indx2) 00218 { 00219 if (m_num_machines != outputs.vlen) 00220 { 00221 SG_ERROR("%s::rescale_heuris_hastie(): size(outputs) = %d != m_num_machines = %d\n", 00222 get_name(), outputs.vlen, m_num_machines); 00223 } 00224 00225 SGVector<float64_t> new_outputs(m_num_classes); 00226 new_outputs.zero(); 00227 00228 for (int32_t j=0; j<m_num_classes; j++) 00229 { 00230 for (int32_t m=0; m<m_num_machines; m++) 00231 { 00232 if (indx1[m]==j) 00233 new_outputs[j] += outputs[m]; 00234 if (indx2[m]==j) 00235 new_outputs[j] += 1.0-outputs[m]; 00236 } 00237 00238 new_outputs[j] *= 2.0 / (m_num_classes * (m_num_classes - 1)); 00239 new_outputs[j] += 1E-10; 00240 } 00241 00242 SGVector<float64_t> mu(m_num_machines); 00243 SGVector<float64_t> prev_outputs(m_num_classes); 00244 float64_t gap = 1.0; 00245 00246 while (gap > 1E-12) 00247 { 00248 prev_outputs = new_outputs.clone(); 00249 00250 for (int32_t m=0; m<m_num_machines; m++) 00251 mu[m] = new_outputs[indx1[m]] / (new_outputs[indx1[m]] + new_outputs[indx2[m]]); 00252 00253 for (int32_t j=0; j<m_num_classes; j++) 00254 { 00255 float64_t numerator = 0.0; 00256 float64_t denominator = 0.0; 00257 for (int32_t m=0; m<m_num_machines; m++) 00258 { 00259 if (indx1[m]==j) 00260 { 00261 numerator += m_num_samples[m] * outputs[m]; 00262 denominator += m_num_samples[m] * mu[m]; 00263 } 00264 00265 if (indx2[m]==j) 00266 { 00267 numerator += m_num_samples[m] * (1.0-outputs[m]); 00268 denominator += m_num_samples[m] * (1.0-mu[m]); 00269 } 00270 } 00271 00272 // update posterior 00273 new_outputs[j] *= numerator / denominator; 00274 } 00275 00276 float64_t norm = SGVector<float64_t>::sum(new_outputs); 00277 for (int32_t i=0; i<new_outputs.vlen; i++) 00278 new_outputs[i] /= norm; 00279 00280 // gap is Euclidean distance 00281 for (int32_t i=0; i<new_outputs.vlen; i++) 00282 prev_outputs[i] -= new_outputs[i]; 00283 00284 gap = SGVector<float64_t>::qsq(prev_outputs.vector, prev_outputs.vlen, 2); 00285 SG_DEBUG("[Hastie's heuristic] gap = %.12f\n", gap); 00286 } 00287 00288 for (int32_t i=0; i<new_outputs.vlen; i++) 00289 outputs[i] = new_outputs[i]; 00290 } 00291 00292 void CMulticlassOneVsOneStrategy::rescale_heuris_hamamura(SGVector<float64_t> outputs, 00293 const SGVector<int32_t> indx1, const SGVector<int32_t> indx2) 00294 { 00295 if (m_num_machines != outputs.vlen) 00296 { 00297 SG_ERROR("%s::rescale_heuris_hamamura(): size(outputs) = %d != m_num_machines = %d\n", 00298 get_name(), outputs.vlen, m_num_machines); 00299 } 00300 00301 SGVector<float64_t> new_outputs(m_num_classes); 00302 SGVector<float64_t>::fill_vector(new_outputs.vector, new_outputs.vlen, 1.0); 00303 00304 for (int32_t j=0; j<m_num_classes; j++) 00305 { 00306 for (int32_t m=0; m<m_num_machines; m++) 00307 { 00308 if (indx1[m]==j) 00309 new_outputs[j] *= outputs[m]; 00310 if (indx2[m]==j) 00311 new_outputs[j] *= 1-outputs[m]; 00312 } 00313 00314 new_outputs[j] += 1E-10; 00315 } 00316 00317 float64_t norm = SGVector<float64_t>::sum(new_outputs); 00318 00319 for (int32_t i=0; i<new_outputs.vlen; i++) 00320 outputs[i] = new_outputs[i] / norm; 00321 } 00322