SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Written (W) 2012 Heiko Strathmann 00009 * Copyright (C) 2012 Chiyuan Zhang 00010 */ 00011 00012 #include <shogun/lib/Set.h> 00013 #include <shogun/machine/KernelMulticlassMachine.h> 00014 00015 using namespace shogun; 00016 00017 void CKernelMulticlassMachine::store_model_features() 00018 { 00019 CKernel *kernel= m_kernel; 00020 if (!kernel) 00021 SG_ERROR("%s::store_model_features(): kernel is needed to store SV " 00022 "features.\n", get_name()); 00023 00024 CFeatures* lhs=kernel->get_lhs(); 00025 CFeatures* rhs=kernel->get_rhs(); 00026 if (!lhs) 00027 { 00028 SG_ERROR("%s::store_model_features(): kernel lhs is needed to store " 00029 "SV features.\n", get_name()); 00030 } 00031 00032 /* this map will be abused as a map */ 00033 CSet<index_t> all_sv; 00034 for (index_t i=0; i<m_machines->get_num_elements(); ++i) 00035 { 00036 CKernelMachine *machine=(CKernelMachine *)get_machine(i); 00037 for (index_t j=0; j<machine->get_num_support_vectors(); ++j) 00038 all_sv.add(machine->get_support_vector(j)); 00039 00040 SG_UNREF(machine); 00041 } 00042 00043 /* convert map to vector of SV */ 00044 SGVector<index_t> sv_idx(all_sv.get_num_elements()); 00045 for (index_t i=0; i<sv_idx.vlen; ++i) 00046 sv_idx[i]=*all_sv.get_element_ptr(i); 00047 00048 CFeatures* sv_features=lhs->copy_subset(sv_idx); 00049 00050 /* now, features are replaced by concatenated SV features */ 00051 kernel->init(sv_features, rhs); 00052 00053 /* was SG_REF'ed by copy_subset */ 00054 SG_UNREF(sv_features); 00055 00056 /* now the old SV indices have to be mapped to the new features */ 00057 00058 /* update SV of all machines */ 00059 for (int32_t i=0; i<m_machines->get_num_elements(); ++i) 00060 { 00061 CKernelMachine *machine=(CKernelMachine *)get_machine(i); 00062 00063 /* for each machine, replace SV by index in sv_idx array */ 00064 for (int32_t j=0; j<machine->get_num_support_vectors(); ++j) 00065 { 00066 /* get index of SV in old features */ 00067 index_t current_sv_idx=machine->get_support_vector(j); 00068 00069 /* the position of this old index in the map is the position of 00070 * the SV in the new features */ 00071 index_t new_sv_idx=all_sv.index_of(current_sv_idx); 00072 00073 machine->set_support_vector(j, new_sv_idx); 00074 } 00075 00076 SG_UNREF(machine); 00077 } 00078 00079 SG_UNREF(lhs); 00080 SG_UNREF(rhs); 00081 } 00082 00083 CKernelMulticlassMachine::CKernelMulticlassMachine() : CMulticlassMachine(), m_kernel(NULL) 00084 { 00085 SG_ADD((CSGObject**)&m_kernel,"kernel", "The kernel to be used", MS_AVAILABLE); 00086 } 00087 00094 CKernelMulticlassMachine::CKernelMulticlassMachine(CMulticlassStrategy *strategy, CKernel* kernel, CKernelMachine* machine, CLabels* labs) : 00095 CMulticlassMachine(strategy,(CMachine*)machine,labs), m_kernel(NULL) 00096 { 00097 set_kernel(kernel); 00098 SG_ADD((CSGObject**)&m_kernel,"kernel", "The kernel to be used", MS_AVAILABLE); 00099 } 00100 00102 CKernelMulticlassMachine::~CKernelMulticlassMachine() 00103 { 00104 SG_UNREF(m_kernel); 00105 } 00106 00111 void CKernelMulticlassMachine::set_kernel(CKernel* k) 00112 { 00113 ((CKernelMachine*)m_machine)->set_kernel(k); 00114 SG_REF(k); 00115 SG_UNREF(m_kernel); 00116 m_kernel=k; 00117 } 00118 00119 CKernel* CKernelMulticlassMachine::get_kernel() 00120 { 00121 SG_REF(m_kernel); 00122 return m_kernel; 00123 } 00124 00125 bool CKernelMulticlassMachine::init_machine_for_train(CFeatures* data) 00126 { 00127 if (data) 00128 m_kernel->init(data,data); 00129 00130 ((CKernelMachine*)m_machine)->set_kernel(m_kernel); 00131 00132 return true; 00133 } 00134 00135 bool CKernelMulticlassMachine::init_machines_for_apply(CFeatures* data) 00136 { 00137 if (data) 00138 { 00139 /* set data to rhs for this kernel */ 00140 CFeatures* lhs=m_kernel->get_lhs(); 00141 m_kernel->init(lhs, data); 00142 SG_UNREF(lhs); 00143 } 00144 00145 /* set kernel to all sub-machines */ 00146 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00147 { 00148 CKernelMachine *machine= 00149 (CKernelMachine*)m_machines->get_element(i); 00150 machine->set_kernel(m_kernel); 00151 SG_UNREF(machine); 00152 } 00153 00154 return true; 00155 } 00156 00157 bool CKernelMulticlassMachine::is_ready() 00158 { 00159 if (m_kernel && m_kernel->get_num_vec_lhs() && m_kernel->get_num_vec_rhs()) 00160 return true; 00161 00162 return false; 00163 } 00164 00165 CMachine* CKernelMulticlassMachine::get_machine_from_trained(CMachine* machine) 00166 { 00167 return new CKernelMachine((CKernelMachine*)machine); 00168 } 00169 00170 int32_t CKernelMulticlassMachine::get_num_rhs_vectors() 00171 { 00172 return m_kernel->get_num_vec_rhs(); 00173 } 00174 00175 void CKernelMulticlassMachine::add_machine_subset(SGVector<index_t> subset) 00176 { 00177 SG_NOTIMPLEMENTED 00178 } 00179 00180 void CKernelMulticlassMachine::remove_machine_subset() 00181 { 00182 SG_NOTIMPLEMENTED 00183 } 00184 00185