SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Written (W) 2012 Heiko Strathmann 00009 */ 00010 00011 #include <shogun/evaluation/CrossValidationMKLStorage.h> 00012 #include <shogun/kernel/CombinedKernel.h> 00013 #include <shogun/classifier/mkl/MKL.h> 00014 #include <shogun/classifier/mkl/MKLMulticlass.h> 00015 00016 using namespace shogun; 00017 00018 void CCrossValidationMKLStorage::update_trained_machine( 00019 CMachine* machine, const char* prefix) 00020 { 00021 REQUIRE(machine, "%s::update_trained_machine(): Provided Machine is NULL!\n", 00022 get_name()); 00023 00024 CMKL* mkl=dynamic_cast<CMKL*>(machine); 00025 CMKLMulticlass* mkl_multiclass=dynamic_cast<CMKLMulticlass*>(machine); 00026 REQUIRE(mkl || mkl_multiclass, "%s::update_trained_machine(): This method is only usable " 00027 "with CMKL derived machines. This one is \"%s\"\n", get_name(), 00028 machine->get_name()); 00029 00030 CKernel* kernel = NULL; 00031 if (mkl) 00032 kernel = mkl->get_kernel(); 00033 else 00034 kernel = mkl_multiclass->get_kernel(); 00035 00036 REQUIRE(kernel, "%s::update_trained_machine(): No kernel assigned to " 00037 "machine of type \"%s\"\n", get_name(), machine->get_name()); 00038 00039 CCombinedKernel* combined_kernel=dynamic_cast<CCombinedKernel*>(kernel); 00040 REQUIRE(combined_kernel, "%s::update_trained_machine(): This method is only" 00041 " usable with CCombinedKernel on machines. This one is \"s\"\n", 00042 get_name(), kernel->get_name()); 00043 00044 SGVector<float64_t> w=combined_kernel->get_subkernel_weights(); 00045 00046 /* evtl re-allocate memory (different number of runs from evaluation before) */ 00047 if (m_mkl_weights.num_rows!=w.vlen || 00048 m_mkl_weights.num_cols!=m_num_folds*m_num_runs) 00049 { 00050 if (m_mkl_weights.matrix) 00051 { 00052 SG_DEBUG("deleting memory for mkl weight matrix\n") 00053 m_mkl_weights=SGMatrix<float64_t>(); 00054 } 00055 } 00056 00057 /* evtl allocate memory (first call) */ 00058 if (!m_mkl_weights.matrix) 00059 { 00060 SG_DEBUG("allocating memory for mkl weight matrix\n") 00061 m_mkl_weights=SGMatrix<float64_t>(w.vlen,m_num_folds*m_num_runs); 00062 } 00063 00064 /* put current mkl weights into matrix, copy memory vector wise to make 00065 * things fast. Compute index of address to where vector goes */ 00066 00067 /* number of runs is w.vlen*m_num_folds shift */ 00068 index_t run_shift=m_current_run_index*w.vlen*m_num_folds; 00069 00070 /* fold shift is m_current_fold_index*w-vlen */ 00071 index_t fold_shift=m_current_fold_index*w.vlen; 00072 00073 /* add both index shifts */ 00074 index_t first_idx=run_shift+fold_shift; 00075 SG_DEBUG("run %d, fold %d, matrix index %d\n",m_current_run_index, 00076 m_current_fold_index, first_idx); 00077 00078 /* copy memory */ 00079 memcpy(&m_mkl_weights.matrix[first_idx], w.vector, 00080 w.vlen*sizeof(float64_t)); 00081 00082 SG_UNREF(kernel); 00083 }