SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
KernelMulticlassMachine.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Written (W) 2012 Heiko Strathmann
00009  * Copyright (C) 2012 Chiyuan Zhang
00010  */
00011 
00012 #include <shogun/lib/Set.h>
00013 #include <shogun/machine/KernelMulticlassMachine.h>
00014 
00015 using namespace shogun;
00016 
00017 void CKernelMulticlassMachine::store_model_features()
00018 {
00019     CKernel *kernel= m_kernel;
00020     if (!kernel)
00021         SG_ERROR("%s::store_model_features(): kernel is needed to store SV "
00022                 "features.\n", get_name());
00023 
00024     CFeatures* lhs=kernel->get_lhs();
00025     CFeatures* rhs=kernel->get_rhs();
00026     if (!lhs)
00027     {
00028         SG_ERROR("%s::store_model_features(): kernel lhs is needed to store "
00029         "SV features.\n", get_name());
00030     }
00031 
00032     /* this map will be abused as a map */
00033     CSet<index_t> all_sv;
00034     for (index_t i=0; i<m_machines->get_num_elements(); ++i)
00035     {
00036         CKernelMachine *machine=(CKernelMachine *)get_machine(i);
00037         for (index_t j=0; j<machine->get_num_support_vectors(); ++j)
00038             all_sv.add(machine->get_support_vector(j));
00039 
00040         SG_UNREF(machine);
00041     }
00042 
00043     /* convert map to vector of SV */
00044     SGVector<index_t> sv_idx(all_sv.get_num_elements());
00045     for (index_t i=0; i<sv_idx.vlen; ++i)
00046         sv_idx[i]=*all_sv.get_element_ptr(i);
00047 
00048     CFeatures* sv_features=lhs->copy_subset(sv_idx);
00049 
00050     /* now, features are replaced by concatenated SV features */
00051     kernel->init(sv_features, rhs);
00052 
00053     /* was SG_REF'ed by copy_subset */
00054     SG_UNREF(sv_features);
00055 
00056     /* now the old SV indices have to be mapped to the new features */
00057 
00058     /* update SV of all machines */
00059     for (int32_t i=0; i<m_machines->get_num_elements(); ++i)
00060     {
00061         CKernelMachine *machine=(CKernelMachine *)get_machine(i);
00062 
00063         /* for each machine, replace SV by index in sv_idx array */
00064         for (int32_t j=0; j<machine->get_num_support_vectors(); ++j)
00065         {
00066             /* get index of SV in old features */
00067             index_t current_sv_idx=machine->get_support_vector(j);
00068 
00069             /* the position of this old index in the map is the position of
00070              * the SV in the new features */
00071             index_t new_sv_idx=all_sv.index_of(current_sv_idx);
00072 
00073             machine->set_support_vector(j, new_sv_idx);
00074         }
00075 
00076         SG_UNREF(machine);
00077     }
00078 
00079     SG_UNREF(lhs);
00080     SG_UNREF(rhs);
00081 }
00082 
00083 CKernelMulticlassMachine::CKernelMulticlassMachine() : CMulticlassMachine(), m_kernel(NULL)
00084 {
00085     SG_ADD((CSGObject**)&m_kernel,"kernel", "The kernel to be used", MS_AVAILABLE);
00086 }
00087 
00094 CKernelMulticlassMachine::CKernelMulticlassMachine(CMulticlassStrategy *strategy, CKernel* kernel, CKernelMachine* machine, CLabels* labs) :
00095     CMulticlassMachine(strategy,(CMachine*)machine,labs), m_kernel(NULL)
00096 {
00097     set_kernel(kernel);
00098     SG_ADD((CSGObject**)&m_kernel,"kernel", "The kernel to be used", MS_AVAILABLE);
00099 }
00100 
00102 CKernelMulticlassMachine::~CKernelMulticlassMachine()
00103 {
00104     SG_UNREF(m_kernel);
00105 }
00106 
00111 void CKernelMulticlassMachine::set_kernel(CKernel* k)
00112 {
00113     ((CKernelMachine*)m_machine)->set_kernel(k);
00114     SG_REF(k);
00115     SG_UNREF(m_kernel);
00116     m_kernel=k;
00117 }
00118 
00119 CKernel* CKernelMulticlassMachine::get_kernel()
00120 {
00121     SG_REF(m_kernel);
00122     return m_kernel;
00123 }
00124 
00125 bool CKernelMulticlassMachine::init_machine_for_train(CFeatures* data)
00126 {
00127     if (data)
00128         m_kernel->init(data,data);
00129 
00130     ((CKernelMachine*)m_machine)->set_kernel(m_kernel);
00131 
00132     return true;
00133 }
00134 
00135 bool CKernelMulticlassMachine::init_machines_for_apply(CFeatures* data)
00136 {
00137     if (data)
00138     {
00139         /* set data to rhs for this kernel */
00140         CFeatures* lhs=m_kernel->get_lhs();
00141         m_kernel->init(lhs, data);
00142         SG_UNREF(lhs);
00143     }
00144 
00145     /* set kernel to all sub-machines */
00146     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00147     {
00148         CKernelMachine *machine=
00149                 (CKernelMachine*)m_machines->get_element(i);
00150         machine->set_kernel(m_kernel);
00151         SG_UNREF(machine);
00152     }
00153 
00154     return true;
00155 }
00156 
00157 bool CKernelMulticlassMachine::is_ready()
00158 {
00159     if (m_kernel && m_kernel->get_num_vec_lhs() && m_kernel->get_num_vec_rhs())
00160             return true;
00161 
00162     return false;
00163 }
00164 
00165 CMachine* CKernelMulticlassMachine::get_machine_from_trained(CMachine* machine)
00166 {
00167     return new CKernelMachine((CKernelMachine*)machine);
00168 }
00169 
00170 int32_t CKernelMulticlassMachine::get_num_rhs_vectors()
00171 {
00172     return m_kernel->get_num_vec_rhs();
00173 }
00174 
00175 void CKernelMulticlassMachine::add_machine_subset(SGVector<index_t> subset)
00176 {
00177     SG_NOTIMPLEMENTED
00178 }
00179 
00180 void CKernelMulticlassMachine::remove_machine_subset()
00181 {
00182     SG_NOTIMPLEMENTED
00183 }
00184 
00185 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation