SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #include <algorithm> 00012 00013 #include <shogun/mathematics/Math.h> 00014 #include <shogun/labels/BinaryLabels.h> 00015 #include <shogun/labels/MulticlassLabels.h> 00016 #include <shogun/multiclass/ecoc/ECOCDiscriminantEncoder.h> 00017 00018 using namespace std; 00019 using namespace shogun; 00020 00021 CECOCDiscriminantEncoder::CECOCDiscriminantEncoder() 00022 { 00023 init(); 00024 } 00025 00026 CECOCDiscriminantEncoder::~CECOCDiscriminantEncoder() 00027 { 00028 SG_UNREF(m_features); 00029 SG_UNREF(m_labels); 00030 } 00031 00032 void CECOCDiscriminantEncoder::init() 00033 { 00034 // default parameters 00035 m_iterations = 25; 00036 m_num_trees = 1; 00037 00038 // init values 00039 m_features = NULL; 00040 m_labels = NULL; 00041 00042 // parameters 00043 00044 SG_ADD(&m_iterations, "iterations", "number of iterations in SFFS", MS_NOT_AVAILABLE); 00045 } 00046 00047 void CECOCDiscriminantEncoder::set_features(CDenseFeatures<float64_t> *features) 00048 { 00049 SG_REF(features); 00050 SG_UNREF(m_features); 00051 m_features = features; 00052 } 00053 00054 void CECOCDiscriminantEncoder::set_labels(CLabels *labels) 00055 { 00056 SG_REF(labels); 00057 SG_UNREF(m_labels); 00058 m_labels = labels; 00059 } 00060 00061 SGMatrix<int32_t> CECOCDiscriminantEncoder::create_codebook(int32_t num_classes) 00062 { 00063 if (!m_features || !m_labels) 00064 SG_ERROR("Need features and labels to learn the codebook") 00065 00066 m_feats = m_features->get_feature_matrix(); 00067 m_codebook = SGMatrix<int32_t>(m_num_trees * (num_classes-1), num_classes); 00068 m_codebook.zero(); 00069 m_code_idx = 0; 00070 00071 for (int32_t itree = 0; itree < m_num_trees; ++itree) 00072 { 00073 vector<int32_t> classes(num_classes); 00074 for (int32_t i=0; i < num_classes; ++i) 00075 classes[i] = i; 00076 00077 binary_partition(classes); 00078 } 00079 00080 m_feats = SGMatrix<float64_t>(); // release memory 00081 return m_codebook; 00082 } 00083 00084 void CECOCDiscriminantEncoder::binary_partition(const vector<int32_t>& classes) 00085 { 00086 if (classes.size() > 2) 00087 { 00088 int32_t isplit = classes.size()/2; 00089 vector<int32_t> part1(classes.begin(), classes.begin()+isplit); 00090 vector<int32_t> part2(classes.begin()+isplit, classes.end()); 00091 run_sffs(part1, part2); 00092 for (size_t i=0; i < part1.size(); ++i) 00093 m_codebook(m_code_idx, part1[i]) = +1; 00094 for (size_t i=0; i < part2.size(); ++i) 00095 m_codebook(m_code_idx, part2[i]) = -1; 00096 m_code_idx++; 00097 00098 if (part1.size() > 1) 00099 binary_partition(part1); 00100 if (part2.size() > 1) 00101 binary_partition(part2); 00102 } 00103 else // only two classes 00104 { 00105 m_codebook(m_code_idx, classes[0]) = +1; 00106 m_codebook(m_code_idx, classes[1]) = -1; 00107 m_code_idx++; 00108 } 00109 } 00110 00111 void CECOCDiscriminantEncoder::run_sffs(vector<int32_t>& part1, vector<int32_t>& part2) 00112 { 00113 set<int32_t> idata1; 00114 set<int32_t> idata2; 00115 00116 for (int32_t i=0; i < m_labels->get_num_labels(); ++i) 00117 { 00118 if (find(part1.begin(), part1.end(), ((CMulticlassLabels*) m_labels)->get_int_label(i)) != part1.end()) 00119 idata1.insert(i); 00120 else if (find(part2.begin(), part2.end(), ((CMulticlassLabels*) m_labels)->get_int_label(i)) != part2.end()) 00121 idata2.insert(i); 00122 } 00123 00124 float64_t MI = compute_MI(idata1, idata2); 00125 for (int32_t i=0; i < m_iterations; ++i) 00126 { 00127 if (i % 2 == 0) 00128 MI = sffs_iteration(MI, part1, idata1, part2, idata2); 00129 else 00130 MI = sffs_iteration(MI, part2, idata2, part1, idata1); 00131 } 00132 } 00133 00134 float64_t CECOCDiscriminantEncoder::sffs_iteration(float64_t MI, vector<int32_t>& part1, set<int32_t>& idata1, 00135 vector<int32_t>& part2, set<int32_t>& idata2) 00136 { 00137 if (part1.size() <= 1) 00138 return MI; 00139 00140 int32_t iclas = CMath::random(0, int32_t(part1.size()-1)); 00141 int32_t clas = part1[iclas]; 00142 00143 // move clas from part1 to part2 00144 for (int32_t i=0; i < m_labels->get_num_labels(); ++i) 00145 { 00146 if (((CMulticlassLabels*) m_labels)->get_int_label(i) == clas) 00147 { 00148 idata1.erase(i); 00149 idata2.insert(i); 00150 } 00151 } 00152 00153 float64_t new_MI = compute_MI(idata1, idata2); 00154 if (new_MI < MI) 00155 { 00156 part2.push_back(clas); 00157 part1.erase(part1.begin() + iclas); 00158 return new_MI; 00159 } 00160 else 00161 { 00162 // revert changes 00163 for (int32_t i=0; i < m_labels->get_num_labels(); ++i) 00164 { 00165 if (((CMulticlassLabels*) m_labels)->get_int_label(i) == clas) 00166 { 00167 idata2.erase(i); 00168 idata1.insert(i); 00169 } 00170 } 00171 return MI; 00172 } 00173 00174 } 00175 00176 float64_t CECOCDiscriminantEncoder::compute_MI(const set<int32_t>& idata1, const set<int32_t>& idata2) 00177 { 00178 float64_t MI = 0; 00179 00180 int32_t hist1[10]; 00181 int32_t hist2[10]; 00182 00183 for (int32_t i=0; i < m_feats.num_rows; ++i) 00184 { 00185 float64_t max_val = m_feats(i, 0); 00186 float64_t min_val = m_feats(i, 0); 00187 for (int32_t j=1; j < m_feats.num_cols; ++j) 00188 { 00189 max_val = max(max_val, m_feats(i, j)); 00190 min_val = min(min_val, m_feats(i, j)); 00191 } 00192 00193 if (max_val - min_val < 1e-10) 00194 max_val = min_val + 1; // avoid divide by zero error 00195 00196 compute_hist(i, max_val, min_val, idata1, hist1); 00197 compute_hist(i, max_val, min_val, idata2, hist2); 00198 00199 float64_t MI_i = 0; 00200 for (int j=0; j < 10; ++j) 00201 MI_i += (hist1[j]-hist2[j])*(hist1[j]-hist2[j]); 00202 MI += CMath::sqrt(MI_i); 00203 } 00204 00205 return MI; 00206 } 00207 00208 void CECOCDiscriminantEncoder::compute_hist(int32_t i, float64_t max_val, float64_t min_val, 00209 const set<int32_t>& idata, int32_t *hist) 00210 { 00211 // hist of 0:0.1:1 00212 fill(hist, hist+10, 0); 00213 00214 for (set<int32_t>::const_iterator it = idata.begin(); it != idata.end(); ++it) 00215 { 00216 float64_t val = (m_feats(i, *it) - min_val) / (max_val - min_val); 00217 int32_t pos = min(9, static_cast<int32_t>(val*10)); 00218 hist[pos]++; 00219 } 00220 }