SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
WeightedMajorityVote.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2013 Viktor Gal
00008  * Copyright (C) 2013 Viktor Gal
00009  */
00010 
00011 #include <shogun/ensemble/WeightedMajorityVote.h>
00012 #include <shogun/base/Parameter.h>
00013 #include <shogun/lib/SGMatrix.h>
00014 #include <map>
00015 
00016 using namespace shogun;
00017 
00018 CWeightedMajorityVote::CWeightedMajorityVote()
00019     : CCombinationRule()
00020 {
00021     init();
00022     register_parameters();
00023 }
00024 
00025 CWeightedMajorityVote::CWeightedMajorityVote(SGVector<float64_t>& weights)
00026     : CCombinationRule()
00027 {
00028     init();
00029     register_parameters();
00030     m_weights = weights;
00031 }
00032 
00033 CWeightedMajorityVote::~CWeightedMajorityVote()
00034 {
00035 
00036 }
00037 
00038 SGVector<float64_t> CWeightedMajorityVote::combine(const SGMatrix<float64_t>& ensemble_result) const
00039 {
00040     REQUIRE(m_weights.vlen == ensemble_result.num_cols, "The number of results and weights does not match!");
00041     SGVector<float64_t> mv(ensemble_result.num_rows);
00042     for (index_t i = 0; i < ensemble_result.num_rows; ++i)
00043     {
00044         SGVector<float64_t> rv = ensemble_result.get_row_vector(i);
00045         mv[i] = combine(rv);
00046     }
00047 
00048     return mv;
00049 }
00050 
00051 float64_t CWeightedMajorityVote::combine(const SGVector<float64_t>& ensemble_result) const
00052 {
00053     return weighted_combine(ensemble_result);
00054 }
00055 
00056 float64_t CWeightedMajorityVote::weighted_combine(const SGVector<float64_t>& ensemble_result) const
00057 {
00058     REQUIRE(m_weights.vlen == ensemble_result.vlen, "The number of results and weights does not match!");
00059     std::map<index_t, float64_t> freq;
00060     std::map<index_t, float64_t>::iterator it;
00061     index_t max_label = -100;
00062     float64_t max = CMath::ALMOST_NEG_INFTY;
00063 
00064     for (index_t i = 0; i < ensemble_result.vlen; ++i)
00065     {
00066         if (CMath::is_nan(ensemble_result[i]))
00067             continue;
00068 
00069         it = freq.find(ensemble_result[i]);
00070         if (it == freq.end())
00071         {
00072             freq.insert(std::make_pair(ensemble_result[i], m_weights[i]));
00073             if (max < m_weights[i])
00074             {
00075                 max_label = ensemble_result[i];
00076                 max = m_weights[i];
00077             }
00078         }
00079         else
00080         {
00081             it->second += m_weights[i];
00082             if (max < it->second)
00083             {
00084                 max_label = it->first;
00085                 max = it->second;
00086             }
00087         }
00088     }
00089 
00090     return max_label;
00091 }
00092 
00093 void CWeightedMajorityVote::set_weights(SGVector<float64_t>& w)
00094 {
00095     m_weights = w;
00096 }
00097 
00098 SGVector<float64_t> CWeightedMajorityVote::get_weights() const
00099 {
00100     return m_weights;
00101 }
00102 
00103 void CWeightedMajorityVote::init()
00104 {
00105     m_weights = SGVector<float64_t>();
00106 }
00107 
00108 void CWeightedMajorityVote::register_parameters()
00109 {
00110     SG_ADD(&m_weights, "weights", "Weights for the majority vote", MS_AVAILABLE);
00111 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation