SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MulticlassOneVsRestStrategy.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00012 #include <shogun/labels/BinaryLabels.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 #include <shogun/mathematics/Math.h>
00015 
00016 using namespace shogun;
00017 
00018 CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy()
00019     : CMulticlassStrategy()
00020 {
00021 }
00022 
00023 CMulticlassOneVsRestStrategy::CMulticlassOneVsRestStrategy(EProbHeuristicType prob_heuris)
00024     : CMulticlassStrategy(prob_heuris)
00025 {
00026 }
00027 
00028 SGVector<int32_t> CMulticlassOneVsRestStrategy::train_prepare_next()
00029 {
00030     for (int32_t i=0; i < m_orig_labels->get_num_labels(); ++i)
00031     {
00032         if (((CMulticlassLabels*) m_orig_labels)->get_int_label(i)==m_train_iter)
00033             ((CBinaryLabels*) m_train_labels)->set_label(i, +1.0);
00034         else
00035             ((CBinaryLabels*) m_train_labels)->set_label(i, -1.0);
00036     }
00037 
00038     // increase m_train_iter *after* setting labels
00039     CMulticlassStrategy::train_prepare_next();
00040 
00041     return SGVector<int32_t>();
00042 }
00043 
00044 int32_t CMulticlassOneVsRestStrategy::decide_label(SGVector<float64_t> outputs)
00045 {
00046     if (m_rejection_strategy && m_rejection_strategy->reject(outputs))
00047         return CDenseLabels::REJECTION_LABEL;
00048 
00049     return SGVector<float64_t>::arg_max(outputs.vector, 1, outputs.vlen);
00050 }
00051 
00052 SGVector<index_t> CMulticlassOneVsRestStrategy::decide_label_multiple_output(SGVector<float64_t> outputs, int32_t n_outputs)
00053 {
00054     float64_t* outputs_ = SG_MALLOC(float64_t, outputs.vlen);
00055     int32_t* indices_ = SG_MALLOC(int32_t, outputs.vlen);
00056     for (int32_t i=0; i<outputs.vlen; i++)
00057     {
00058         outputs_[i] = outputs[i];
00059         indices_[i] = i;
00060     }
00061     CMath::qsort_backward_index(outputs_,indices_,outputs.vlen);
00062     SGVector<index_t> result(n_outputs);
00063     for (int32_t i=0; i<n_outputs; i++)
00064         result[i] = indices_[i];
00065     SG_FREE(outputs_);
00066     SG_FREE(indices_);
00067     return result;
00068 }
00069 
00070 void CMulticlassOneVsRestStrategy::rescale_outputs(SGVector<float64_t> outputs)
00071 {
00072     switch(get_prob_heuris_type())
00073     {
00074         case OVA_NORM:
00075             rescale_heuris_norm(outputs);
00076             break;
00077         case OVA_SOFTMAX:
00078             SG_ERROR("%s::rescale_outputs(): Need to specify sigmoid parameters!\n", get_name());
00079             break;
00080         case PROB_HEURIS_NONE:
00081             break;
00082         default:
00083             SG_ERROR("%s::rescale_outputs(): Unknown OVA probability heuristic type!\n", get_name());
00084             break;
00085     }
00086 }
00087 
00088 void CMulticlassOneVsRestStrategy::rescale_outputs(SGVector<float64_t> outputs,
00089         const SGVector<float64_t> As, const SGVector<float64_t> Bs)
00090 {
00091     if (get_prob_heuris_type()==OVA_SOFTMAX)
00092         rescale_heuris_softmax(outputs,As,Bs);
00093     else
00094         rescale_outputs(outputs);
00095 }
00096 
00097 void CMulticlassOneVsRestStrategy::rescale_heuris_norm(SGVector<float64_t> outputs)
00098 {
00099     if (m_num_classes != outputs.vlen)
00100     {
00101         SG_ERROR("%s::rescale_heuris_norm(): size(outputs) = %d != m_num_classes = %d\n",
00102                 get_name(), outputs.vlen, m_num_classes);
00103     }
00104 
00105     float64_t norm = SGVector<float64_t>::sum(outputs);
00106     norm += 1E-10;
00107     for (int32_t i=0; i<outputs.vlen; i++)
00108         outputs[i] /= norm;
00109 }
00110 
00111 void CMulticlassOneVsRestStrategy::rescale_heuris_softmax(SGVector<float64_t> outputs,
00112         const SGVector<float64_t> As, const SGVector<float64_t> Bs)
00113 {
00114     if (m_num_classes != outputs.vlen)
00115     {
00116         SG_ERROR("%s::rescale_heuris_softmax(): size(outputs) = %d != m_num_classes = %d\n",
00117                 get_name(), outputs.vlen, m_num_classes);
00118     }
00119 
00120     for (int32_t i=0; i<outputs.vlen; i++)
00121         outputs[i] = CMath::exp(-As[i]*outputs[i]-Bs[i]);
00122 
00123     float64_t norm = SGVector<float64_t>::sum(outputs);
00124     norm += 1E-10;
00125     for (int32_t i=0; i<outputs.vlen; i++)
00126         outputs[i] /= norm;
00127 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation