SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Viktor Gal 00008 * Copyright (C) 2013 Viktor Gal 00009 */ 00010 00011 #include <shogun/ensemble/WeightedMajorityVote.h> 00012 #include <shogun/base/Parameter.h> 00013 #include <shogun/lib/SGMatrix.h> 00014 #include <map> 00015 00016 using namespace shogun; 00017 00018 CWeightedMajorityVote::CWeightedMajorityVote() 00019 : CCombinationRule() 00020 { 00021 init(); 00022 register_parameters(); 00023 } 00024 00025 CWeightedMajorityVote::CWeightedMajorityVote(SGVector<float64_t>& weights) 00026 : CCombinationRule() 00027 { 00028 init(); 00029 register_parameters(); 00030 m_weights = weights; 00031 } 00032 00033 CWeightedMajorityVote::~CWeightedMajorityVote() 00034 { 00035 00036 } 00037 00038 SGVector<float64_t> CWeightedMajorityVote::combine(const SGMatrix<float64_t>& ensemble_result) const 00039 { 00040 REQUIRE(m_weights.vlen == ensemble_result.num_cols, "The number of results and weights does not match!"); 00041 SGVector<float64_t> mv(ensemble_result.num_rows); 00042 for (index_t i = 0; i < ensemble_result.num_rows; ++i) 00043 { 00044 SGVector<float64_t> rv = ensemble_result.get_row_vector(i); 00045 mv[i] = combine(rv); 00046 } 00047 00048 return mv; 00049 } 00050 00051 float64_t CWeightedMajorityVote::combine(const SGVector<float64_t>& ensemble_result) const 00052 { 00053 return weighted_combine(ensemble_result); 00054 } 00055 00056 float64_t CWeightedMajorityVote::weighted_combine(const SGVector<float64_t>& ensemble_result) const 00057 { 00058 REQUIRE(m_weights.vlen == ensemble_result.vlen, "The number of results and weights does not match!"); 00059 std::map<index_t, float64_t> freq; 00060 std::map<index_t, float64_t>::iterator it; 00061 index_t max_label = -100; 00062 float64_t max = CMath::ALMOST_NEG_INFTY; 00063 00064 for (index_t i = 0; i < ensemble_result.vlen; ++i) 00065 { 00066 if (CMath::is_nan(ensemble_result[i])) 00067 continue; 00068 00069 it = freq.find(ensemble_result[i]); 00070 if (it == freq.end()) 00071 { 00072 freq.insert(std::make_pair(ensemble_result[i], m_weights[i])); 00073 if (max < m_weights[i]) 00074 { 00075 max_label = ensemble_result[i]; 00076 max = m_weights[i]; 00077 } 00078 } 00079 else 00080 { 00081 it->second += m_weights[i]; 00082 if (max < it->second) 00083 { 00084 max_label = it->first; 00085 max = it->second; 00086 } 00087 } 00088 } 00089 00090 return max_label; 00091 } 00092 00093 void CWeightedMajorityVote::set_weights(SGVector<float64_t>& w) 00094 { 00095 m_weights = w; 00096 } 00097 00098 SGVector<float64_t> CWeightedMajorityVote::get_weights() const 00099 { 00100 return m_weights; 00101 } 00102 00103 void CWeightedMajorityVote::init() 00104 { 00105 m_weights = SGVector<float64_t>(); 00106 } 00107 00108 void CWeightedMajorityVote::register_parameters() 00109 { 00110 SG_ADD(&m_weights, "weights", "Weights for the majority vote", MS_AVAILABLE); 00111 }