SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/classifier/svm/GPBTSVM.h> 00012 #include <shogun/lib/external/gpdt.h> 00013 #include <shogun/lib/external/gpdtsolve.h> 00014 #include <shogun/io/SGIO.h> 00015 #include <shogun/labels/BinaryLabels.h> 00016 00017 using namespace shogun; 00018 00019 CGPBTSVM::CGPBTSVM() 00020 : CSVM(), model(NULL) 00021 { 00022 } 00023 00024 CGPBTSVM::CGPBTSVM(float64_t C, CKernel* k, CLabels* lab) 00025 : CSVM(C, k, lab), model(NULL) 00026 { 00027 } 00028 00029 CGPBTSVM::~CGPBTSVM() 00030 { 00031 SG_FREE(model); 00032 } 00033 00034 bool CGPBTSVM::train_machine(CFeatures* data) 00035 { 00036 float64_t* solution; /* store the solution found */ 00037 QPproblem prob; /* object containing the solvers */ 00038 00039 ASSERT(kernel) 00040 ASSERT(m_labels && m_labels->get_num_labels()) 00041 ASSERT(m_labels->get_label_type() == LT_BINARY) 00042 if (data) 00043 { 00044 if (m_labels->get_num_labels() != data->get_num_vectors()) 00045 SG_ERROR("Number of training vectors does not match number of labels\n") 00046 kernel->init(data, data); 00047 } 00048 00049 SGVector<int32_t> lab=((CBinaryLabels*) m_labels)->get_int_labels(); 00050 prob.KER=new sKernel(kernel, lab.vlen); 00051 prob.y=lab.vector; 00052 prob.ell=lab.vlen; 00053 SG_INFO("%d trainlabels\n", prob.ell) 00054 00055 // /*** set options defaults ***/ 00056 prob.delta = epsilon; 00057 prob.maxmw = kernel->get_cache_size(); 00058 prob.verbosity = 0; 00059 prob.preprocess_size = -1; 00060 prob.projection_projector = -1; 00061 prob.c_const = get_C1(); 00062 prob.chunk_size = get_qpsize(); 00063 prob.linadd = get_linadd_enabled(); 00064 00065 if (prob.chunk_size < 2) prob.chunk_size = 2; 00066 if (prob.q <= 0) prob.q = prob.chunk_size / 3; 00067 if (prob.q < 2) prob.q = 2; 00068 if (prob.q > prob.chunk_size) prob.q = prob.chunk_size; 00069 prob.q = prob.q & (~1); 00070 if (prob.maxmw < 5) 00071 prob.maxmw = 5; 00072 00073 /*** set the problem description for final report ***/ 00074 SG_INFO("\nTRAINING PARAMETERS:\n") 00075 SG_INFO("\tNumber of training documents: %d\n", prob.ell) 00076 SG_INFO("\tq: %d\n", prob.chunk_size) 00077 SG_INFO("\tn: %d\n", prob.q) 00078 SG_INFO("\tC: %lf\n", prob.c_const) 00079 SG_INFO("\tkernel type: %d\n", prob.ker_type) 00080 SG_INFO("\tcache size: %dMb\n", prob.maxmw) 00081 SG_INFO("\tStopping tolerance: %lf\n", prob.delta) 00082 00083 // /*** compute the number of cache rows up to maxmw Mb. ***/ 00084 if (prob.preprocess_size == -1) 00085 prob.preprocess_size = (int32_t) ( (float64_t)prob.chunk_size * 1.5 ); 00086 00087 if (prob.projection_projector == -1) 00088 { 00089 if (prob.chunk_size <= 20) prob.projection_projector = 0; 00090 else prob.projection_projector = 1; 00091 } 00092 00093 /*** compute the problem solution *******************************************/ 00094 solution = SG_MALLOC(float64_t, prob.ell); 00095 prob.gpdtsolve(solution); 00096 /****************************************************************************/ 00097 00098 CSVM::set_objective(prob.objective_value); 00099 00100 int32_t num_sv=0; 00101 int32_t bsv=0; 00102 int32_t i=0; 00103 int32_t k=0; 00104 00105 for (i = 0; i < prob.ell; i++) 00106 { 00107 if (solution[i] > prob.DELTAsv) 00108 { 00109 num_sv++; 00110 if (solution[i] > (prob.c_const - prob.DELTAsv)) bsv++; 00111 } 00112 } 00113 00114 create_new_model(num_sv); 00115 set_bias(prob.bee); 00116 00117 SG_INFO("SV: %d BSV = %d\n", num_sv, bsv) 00118 00119 for (i = 0; i < prob.ell; i++) 00120 { 00121 if (solution[i] > prob.DELTAsv) 00122 { 00123 set_support_vector(k, i); 00124 set_alpha(k++, solution[i]*((CBinaryLabels*) m_labels)->get_label(i)); 00125 } 00126 } 00127 00128 delete prob.KER; 00129 SG_FREE(solution); 00130 00131 return true; 00132 }