SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MulticlassMachine.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2011 Soeren Sonnenburg
00008  * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn
00009  * Written (W) 2013 Shell Hu and Heiko Strathmann
00010  * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia
00011  */
00012 
00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00014 #include <shogun/machine/LinearMachine.h>
00015 #include <shogun/machine/KernelMachine.h>
00016 #include <shogun/machine/MulticlassMachine.h>
00017 #include <shogun/base/Parameter.h>
00018 #include <shogun/labels/MulticlassLabels.h>
00019 #include <shogun/labels/RegressionLabels.h>
00020 #include <shogun/mathematics/Statistics.h>
00021 
00022 using namespace shogun;
00023 
00024 CMulticlassMachine::CMulticlassMachine()
00025 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()),
00026     m_machine(NULL)
00027 {
00028     SG_REF(m_multiclass_strategy);
00029     register_parameters();
00030 }
00031 
00032 CMulticlassMachine::CMulticlassMachine(
00033         CMulticlassStrategy *strategy,
00034         CMachine* machine, CLabels* labs)
00035 : CBaseMulticlassMachine(), m_multiclass_strategy(strategy)
00036 {
00037     SG_REF(strategy);
00038     set_labels(labs);
00039     SG_REF(machine);
00040     m_machine = machine;
00041     register_parameters();
00042 
00043     if (labs)
00044         init_strategy();
00045 }
00046 
00047 CMulticlassMachine::~CMulticlassMachine()
00048 {
00049     SG_UNREF(m_multiclass_strategy);
00050     SG_UNREF(m_machine);
00051 }
00052 
00053 void CMulticlassMachine::set_labels(CLabels* lab)
00054 {
00055     CMachine::set_labels(lab);
00056     if (lab)
00057         init_strategy();
00058 }
00059 
00060 void CMulticlassMachine::register_parameters()
00061 {
00062     SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
00063     SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
00064 }
00065 
00066 void CMulticlassMachine::init_strategy()
00067 {
00068     int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00069     m_multiclass_strategy->set_num_classes(num_classes);
00070 }
00071 
00072 CBinaryLabels* CMulticlassMachine::get_submachine_outputs(int32_t i)
00073 {
00074     CMachine *machine = (CMachine*)m_machines->get_element(i);
00075     ASSERT(machine)
00076     CBinaryLabels* output = machine->apply_binary();
00077     SG_UNREF(machine);
00078     return output;
00079 }
00080 
00081 float64_t CMulticlassMachine::get_submachine_output(int32_t i, int32_t num)
00082 {
00083     CMachine *machine = get_machine(i);
00084     float64_t output = 0.0;
00085     // dirty hack
00086     if (dynamic_cast<CLinearMachine*>(machine))
00087         output = ((CLinearMachine*)machine)->apply_one(num);
00088     if (dynamic_cast<CKernelMachine*>(machine))
00089         output = ((CKernelMachine*)machine)->apply_one(num);
00090     SG_UNREF(machine);
00091     return output;
00092 }
00093 
00094 CMulticlassLabels* CMulticlassMachine::apply_multiclass(CFeatures* data)
00095 {
00096     SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n",
00097             get_name(), data ? data->get_name() : "NULL", data);
00098 
00099     CMulticlassLabels* return_labels=NULL;
00100 
00101     if (data)
00102         init_machines_for_apply(data);
00103     else
00104         init_machines_for_apply(NULL);
00105 
00106     if (is_ready())
00107     {
00108         /* num vectors depends on whether data is provided */
00109         int32_t num_vectors=data ? data->get_num_vectors() :
00110                 get_num_rhs_vectors();
00111 
00112         int32_t num_machines=m_machines->get_num_elements();
00113         if (num_machines <= 0)
00114             SG_ERROR("num_machines = %d, did you train your machine?", num_machines)
00115 
00116         CMulticlassLabels* result=new CMulticlassLabels(num_vectors);
00117 
00118         // if outputs are prob, only one confidence for each class
00119         int32_t num_classes=m_multiclass_strategy->get_num_classes();
00120         EProbHeuristicType heuris = get_prob_heuris();
00121 
00122         if (heuris!=PROB_HEURIS_NONE)
00123             result->allocate_confidences_for(num_classes);
00124         else
00125             result->allocate_confidences_for(num_machines);
00126 
00127         CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
00128         SGVector<float64_t> As(num_machines);
00129         SGVector<float64_t> Bs(num_machines);
00130 
00131         for (int32_t i=0; i<num_machines; ++i)
00132         {
00133             outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
00134 
00135             if (heuris==OVA_SOFTMAX)
00136             {
00137                 CStatistics::SigmoidParamters params = CStatistics::fit_sigmoid(outputs[i]->get_values());
00138                 As[i] = params.a;
00139                 Bs[i] = params.b;
00140             }
00141 
00142             if (heuris!=PROB_HEURIS_NONE && heuris!=OVA_SOFTMAX)
00143                 outputs[i]->scores_to_probabilities(0,0);
00144         }
00145 
00146         SGVector<float64_t> output_for_i(num_machines);
00147         SGVector<float64_t> r_output_for_i(num_machines);
00148         if (heuris!=PROB_HEURIS_NONE)
00149             r_output_for_i.resize_vector(num_classes);
00150 
00151         for (int32_t i=0; i<num_vectors; i++)
00152         {
00153             for (int32_t j=0; j<num_machines; j++)
00154                 output_for_i[j] = outputs[j]->get_value(i);
00155 
00156             if (heuris==PROB_HEURIS_NONE)
00157             {
00158                 r_output_for_i = output_for_i;
00159             }
00160             else
00161             {
00162                 if (heuris==OVA_SOFTMAX)
00163                     m_multiclass_strategy->rescale_outputs(output_for_i,As,Bs);
00164                 else
00165                     m_multiclass_strategy->rescale_outputs(output_for_i);
00166 
00167                 // only first num_classes are returned
00168                 for (int32_t r=0; r<num_classes; r++)
00169                     r_output_for_i[r] = output_for_i[r];
00170 
00171                 SG_DEBUG("%s::apply_multiclass(): sum(r_output_for_i) = %f\n",
00172                     get_name(), SGVector<float64_t>::sum(r_output_for_i.vector,num_classes));
00173             }
00174 
00175             // use rescaled outputs for label decision
00176             result->set_label(i, m_multiclass_strategy->decide_label(r_output_for_i));
00177             result->set_multiclass_confidences(i, r_output_for_i);
00178         }
00179 
00180         for (int32_t i=0; i < num_machines; ++i)
00181             SG_UNREF(outputs[i]);
00182 
00183         SG_FREE(outputs);
00184 
00185         return_labels=result;
00186     }
00187     else
00188         SG_ERROR("Not ready")
00189 
00190 
00191     SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n",
00192                 get_name(), data ? data->get_name() : "NULL", data);
00193     return return_labels;
00194 }
00195 
00196 CMulticlassMultipleOutputLabels* CMulticlassMachine::apply_multiclass_multiple_output(CFeatures* data, int32_t n_outputs)
00197 {
00198     CMulticlassMultipleOutputLabels* return_labels=NULL;
00199 
00200     if (data)
00201         init_machines_for_apply(data);
00202     else
00203         init_machines_for_apply(NULL);
00204 
00205     if (is_ready())
00206     {
00207         /* num vectors depends on whether data is provided */
00208         int32_t num_vectors=data ? data->get_num_vectors() :
00209                 get_num_rhs_vectors();
00210 
00211         int32_t num_machines=m_machines->get_num_elements();
00212         if (num_machines <= 0)
00213             SG_ERROR("num_machines = %d, did you train your machine?", num_machines)
00214         REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available")
00215 
00216         CMulticlassMultipleOutputLabels* result=new CMulticlassMultipleOutputLabels(num_vectors);
00217         CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
00218 
00219         for (int32_t i=0; i < num_machines; ++i)
00220             outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
00221 
00222         SGVector<float64_t> output_for_i(num_machines);
00223         for (int32_t i=0; i<num_vectors; i++)
00224         {
00225             for (int32_t j=0; j<num_machines; j++)
00226                 output_for_i[j] = outputs[j]->get_value(i);
00227 
00228             result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs));
00229         }
00230 
00231         for (int32_t i=0; i < num_machines; ++i)
00232             SG_UNREF(outputs[i]);
00233 
00234         SG_FREE(outputs);
00235 
00236         return_labels=result;
00237     }
00238     else
00239         SG_ERROR("Not ready")
00240 
00241     return return_labels;
00242 }
00243 
00244 bool CMulticlassMachine::train_machine(CFeatures* data)
00245 {
00246     ASSERT(m_multiclass_strategy)
00247 
00248     if ( !data && !is_ready() )
00249         SG_ERROR("Please provide training data.\n")
00250     else
00251         init_machine_for_train(data);
00252 
00253     m_machines->reset_array();
00254     CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors());
00255     SG_REF(train_labels);
00256     m_machine->set_labels(train_labels);
00257 
00258     m_multiclass_strategy->train_start(CLabelsFactory::to_multiclass(m_labels), train_labels);
00259     while (m_multiclass_strategy->train_has_more())
00260     {
00261         SGVector<index_t> subset=m_multiclass_strategy->train_prepare_next();
00262         if (subset.vlen)
00263         {
00264             train_labels->add_subset(subset);
00265             add_machine_subset(subset);
00266         }
00267 
00268         m_machine->train();
00269         m_machines->push_back(get_machine_from_trained(m_machine));
00270 
00271         if (subset.vlen)
00272         {
00273             train_labels->remove_subset();
00274             remove_machine_subset();
00275         }
00276     }
00277 
00278     m_multiclass_strategy->train_stop();
00279     SG_UNREF(train_labels);
00280 
00281     return true;
00282 }
00283 
00284 float64_t CMulticlassMachine::apply_one(int32_t vec_idx)
00285 {
00286     init_machines_for_apply(NULL);
00287 
00288     ASSERT(m_machines->get_num_elements()>0)
00289     SGVector<float64_t> outputs(m_machines->get_num_elements());
00290 
00291     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00292         outputs[i] = get_submachine_output(i, vec_idx);
00293 
00294     float64_t result = m_multiclass_strategy->decide_label(outputs);
00295 
00296     return result;
00297 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation