SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #include <shogun/multiclass/ecoc/ECOCStrategy.h> 00012 #include <shogun/labels/BinaryLabels.h> 00013 #include <shogun/labels/MulticlassLabels.h> 00014 00015 using namespace shogun; 00016 00017 CECOCStrategy::CECOCStrategy() : CMulticlassStrategy() 00018 { 00019 init(); 00020 } 00021 00022 CECOCStrategy::CECOCStrategy(CECOCEncoder *encoder, CECOCDecoder *decoder) 00023 : CMulticlassStrategy() 00024 { 00025 init(); 00026 m_encoder=encoder; 00027 m_decoder=decoder; 00028 SG_REF(m_encoder); 00029 SG_REF(decoder); 00030 } 00031 00032 void CECOCStrategy::init() 00033 { 00034 m_encoder=NULL; 00035 m_decoder=NULL; 00036 00037 SG_ADD((CSGObject **)&m_encoder, "encoder", "ECOC Encoder", MS_NOT_AVAILABLE); 00038 SG_ADD((CSGObject **)&m_decoder, "decoder", "ECOC Decoder", MS_NOT_AVAILABLE); 00039 } 00040 00041 CECOCStrategy::~CECOCStrategy() 00042 { 00043 SG_UNREF(m_encoder); 00044 SG_UNREF(m_decoder); 00045 } 00046 00047 void CECOCStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels) 00048 { 00049 CMulticlassStrategy::train_start(orig_labels, train_labels); 00050 00051 m_codebook = m_encoder->create_codebook(m_num_classes); 00052 } 00053 00054 bool CECOCStrategy::train_has_more() 00055 { 00056 return m_train_iter < m_codebook.num_rows; 00057 } 00058 00059 SGVector<int32_t> CECOCStrategy::train_prepare_next() 00060 { 00061 SGVector<int32_t> subset(m_orig_labels->get_num_labels(), false); 00062 int32_t tot=0; 00063 for (int32_t i=0; i < m_orig_labels->get_num_labels(); ++i) 00064 { 00065 int32_t label = ((CMulticlassLabels*) m_orig_labels)->get_int_label(i); 00066 switch (m_codebook(m_train_iter, label)) 00067 { 00068 case -1: 00069 ((CBinaryLabels*) m_train_labels)->set_label(i, -1); 00070 subset[tot++]=i; 00071 break; 00072 case 1: 00073 ((CBinaryLabels*) m_train_labels)->set_label(i, 1); 00074 subset[tot++]=i; 00075 break; 00076 default: 00077 // 0 means ignore 00078 break; 00079 } 00080 } 00081 00082 CMulticlassStrategy::train_prepare_next(); 00083 return SGVector<int32_t>(subset.vector, tot, true); 00084 } 00085 00086 int32_t CECOCStrategy::decide_label(SGVector<float64_t> outputs) 00087 { 00088 return m_decoder->decide_label(outputs, m_codebook); 00089 } 00090 00091 int32_t CECOCStrategy::get_num_machines() 00092 { 00093 return m_codebook.num_cols; 00094 }