SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #include <shogun/multiclass/MulticlassTreeGuidedLogisticRegression.h> 00012 #ifdef HAVE_EIGEN3 00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00014 #include <shogun/mathematics/Math.h> 00015 #include <shogun/labels/MulticlassLabels.h> 00016 #include <shogun/lib/slep/slep_mc_tree_lr.h> 00017 00018 using namespace shogun; 00019 00020 CMulticlassTreeGuidedLogisticRegression::CMulticlassTreeGuidedLogisticRegression() : 00021 CLinearMulticlassMachine() 00022 { 00023 init_defaults(); 00024 } 00025 00026 CMulticlassTreeGuidedLogisticRegression::CMulticlassTreeGuidedLogisticRegression(float64_t z, CDotFeatures* feats, CLabels* labs, CIndexBlockTree* tree) : 00027 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),feats,NULL,labs) 00028 { 00029 init_defaults(); 00030 set_z(z); 00031 set_index_tree(tree); 00032 } 00033 00034 void CMulticlassTreeGuidedLogisticRegression::init_defaults() 00035 { 00036 m_index_tree = NULL; 00037 set_z(0.1); 00038 set_epsilon(1e-2); 00039 set_max_iter(10000); 00040 } 00041 00042 void CMulticlassTreeGuidedLogisticRegression::register_parameters() 00043 { 00044 SG_ADD(&m_z, "m_z", "regularization constant",MS_AVAILABLE); 00045 SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE); 00046 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE); 00047 } 00048 00049 CMulticlassTreeGuidedLogisticRegression::~CMulticlassTreeGuidedLogisticRegression() 00050 { 00051 SG_UNREF(m_index_tree); 00052 } 00053 00054 bool CMulticlassTreeGuidedLogisticRegression::train_machine(CFeatures* data) 00055 { 00056 if (data) 00057 set_features((CDotFeatures*)data); 00058 00059 ASSERT(m_features) 00060 ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS) 00061 ASSERT(m_multiclass_strategy) 00062 ASSERT(m_index_tree) 00063 00064 int32_t n_classes = ((CMulticlassLabels*)m_labels)->get_num_classes(); 00065 int32_t n_feats = m_features->get_dim_feature_space(); 00066 00067 slep_options options = slep_options::default_options(); 00068 if (m_machines->get_num_elements()!=0) 00069 { 00070 SGMatrix<float64_t> all_w_old(n_feats, n_classes); 00071 SGVector<float64_t> all_c_old(n_classes); 00072 for (int32_t i=0; i<n_classes; i++) 00073 { 00074 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i); 00075 SGVector<float64_t> w = machine->get_w(); 00076 for (int32_t j=0; j<n_feats; j++) 00077 all_w_old(j,i) = w[j]; 00078 all_c_old[i] = machine->get_bias(); 00079 SG_UNREF(machine); 00080 } 00081 options.last_result = new slep_result_t(all_w_old,all_c_old); 00082 m_machines->reset_array(); 00083 } 00084 if (m_index_tree->is_general()) 00085 { 00086 SGVector<float64_t> G = m_index_tree->get_SLEP_G(); 00087 options.G = G.vector; 00088 } 00089 SGVector<float64_t> ind_t = m_index_tree->get_SLEP_ind_t(); 00090 options.ind_t = ind_t.vector; 00091 options.n_nodes = ind_t.size()/3; 00092 options.tolerance = m_epsilon; 00093 options.max_iter = m_max_iter; 00094 slep_result_t result = slep_mc_tree_lr(m_features,(CMulticlassLabels*)m_labels,m_z,options); 00095 00096 SGMatrix<float64_t> all_w = result.w; 00097 SGVector<float64_t> all_c = result.c; 00098 for (int32_t i=0; i<n_classes; i++) 00099 { 00100 SGVector<float64_t> w(n_feats); 00101 for (int32_t j=0; j<n_feats; j++) 00102 w[j] = all_w(j,i); 00103 float64_t c = all_c[i]; 00104 CLinearMachine* machine = new CLinearMachine(); 00105 machine->set_w(w); 00106 machine->set_bias(c); 00107 m_machines->push_back(machine); 00108 } 00109 return true; 00110 } 00111 #endif /* HAVE_EIGEN3 */