SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Written (W) 2012 Heiko Strathmann 00009 */ 00010 00011 #include <shogun/evaluation/CrossValidationPrintOutput.h> 00012 #include <shogun/machine/LinearMachine.h> 00013 #include <shogun/machine/LinearMulticlassMachine.h> 00014 #include <shogun/machine/KernelMachine.h> 00015 #include <shogun/machine/KernelMulticlassMachine.h> 00016 #include <shogun/kernel/CombinedKernel.h> 00017 #include <shogun/classifier/mkl/MKL.h> 00018 #include <shogun/classifier/mkl/MKLMulticlass.h> 00019 00020 using namespace shogun; 00021 00022 void CCrossValidationPrintOutput::init_num_runs(index_t num_runs, 00023 const char* prefix) 00024 { 00025 SG_PRINT("%scross validation number of runs %d\n", prefix, num_runs) 00026 } 00027 00029 void CCrossValidationPrintOutput::init_num_folds(index_t num_folds, 00030 const char* prefix) 00031 { 00032 SG_PRINT("%scross validation number of folds %d\n", prefix, num_folds) 00033 } 00034 00035 void CCrossValidationPrintOutput::update_run_index(index_t run_index, 00036 const char* prefix) 00037 { 00038 SG_PRINT("%scross validation run %d\n", prefix, run_index) 00039 } 00040 00041 void CCrossValidationPrintOutput::update_fold_index(index_t fold_index, 00042 const char* prefix) 00043 { 00044 SG_PRINT("%sfold %d\n", prefix, fold_index) 00045 } 00046 00047 void CCrossValidationPrintOutput::update_train_indices( 00048 SGVector<index_t> indices, const char* prefix) 00049 { 00050 indices.display_vector("train_indices", prefix); 00051 } 00052 00053 void CCrossValidationPrintOutput::update_test_indices( 00054 SGVector<index_t> indices, const char* prefix) 00055 { 00056 indices.display_vector("test_indices", prefix); 00057 } 00058 00059 void CCrossValidationPrintOutput::update_trained_machine( 00060 CMachine* machine, const char* prefix) 00061 { 00062 if (dynamic_cast<CLinearMachine*>(machine)) 00063 { 00064 CLinearMachine* linear_machine=(CLinearMachine*)machine; 00065 linear_machine->get_w().display_vector("learned_w", prefix); 00066 SG_PRINT("%slearned_bias=%f\n", prefix, linear_machine->get_bias()) 00067 } 00068 00069 if (dynamic_cast<CKernelMachine*>(machine)) 00070 { 00071 CKernelMachine* kernel_machine=(CKernelMachine*)machine; 00072 kernel_machine->get_alphas().display_vector("learned_alphas", prefix); 00073 SG_PRINT("%slearned_bias=%f\n", prefix, kernel_machine->get_bias()) 00074 } 00075 00076 if (dynamic_cast<CLinearMulticlassMachine*>(machine) 00077 || dynamic_cast<CKernelMulticlassMachine*>(machine)) 00078 { 00079 /* append one tab to prefix */ 00080 char* new_prefix=append_tab_to_string(prefix); 00081 00082 CMulticlassMachine* mc_machine=(CMulticlassMachine*)machine; 00083 for (int i=0; i<mc_machine->get_num_machines(); i++) 00084 { 00085 CMachine* sub_machine=mc_machine->get_machine(i); 00086 //SG_PRINT("%smulti-class machine %d:\n", i, sub_machine) 00087 this->update_trained_machine(sub_machine, new_prefix); 00088 SG_UNREF(sub_machine); 00089 } 00090 00091 /* clean up */ 00092 SG_FREE(new_prefix); 00093 } 00094 00095 if (dynamic_cast<CMKL*>(machine)) 00096 { 00097 CMKL* mkl=(CMKL*)machine; 00098 CCombinedKernel* kernel=dynamic_cast<CCombinedKernel*>( 00099 mkl->get_kernel()); 00100 kernel->get_subkernel_weights().display_vector("MKL sub-kernel weights", 00101 prefix); 00102 SG_UNREF(kernel); 00103 } 00104 00105 if (dynamic_cast<CMKLMulticlass*>(machine)) 00106 { 00107 CMKLMulticlass* mkl=(CMKLMulticlass*)machine; 00108 CCombinedKernel* kernel=dynamic_cast<CCombinedKernel*>( 00109 mkl->get_kernel()); 00110 kernel->get_subkernel_weights().display_vector("MKL sub-kernel weights", 00111 prefix); 00112 SG_UNREF(kernel); 00113 } 00114 } 00115 00116 void CCrossValidationPrintOutput::update_test_result(CLabels* results, 00117 const char* prefix) 00118 { 00119 results->get_values().display_vector("test_labels", prefix); 00120 } 00121 00122 void CCrossValidationPrintOutput::update_test_true_result(CLabels* results, 00123 const char* prefix) 00124 { 00125 results->get_values().display_vector("true_labels", prefix); 00126 } 00127 00128 void CCrossValidationPrintOutput::update_evaluation_result(float64_t result, 00129 const char* prefix) 00130 { 00131 SG_PRINT("%sevaluation result=%f\n", prefix, result) 00132 } 00133 00134 char* CCrossValidationPrintOutput::append_tab_to_string(const char* string) 00135 { 00136 /* allocate memory, concatenate and add termination character */ 00137 index_t len=strlen(string); 00138 char* new_prefix=SG_MALLOC(char, len+2); 00139 memcpy(new_prefix, string, sizeof(char*)*len); 00140 new_prefix[len]='\t'; 00141 new_prefix[len+1]='\0'; 00142 00143 return new_prefix; 00144 }