SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #include <shogun/multiclass/MulticlassLogisticRegression.h> 00012 #ifdef HAVE_EIGEN3 00013 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00014 #include <shogun/io/SGIO.h> 00015 #include <shogun/mathematics/Math.h> 00016 #include <shogun/labels/MulticlassLabels.h> 00017 #include <shogun/lib/slep/slep_mc_plain_lr.h> 00018 00019 using namespace shogun; 00020 00021 CMulticlassLogisticRegression::CMulticlassLogisticRegression() : 00022 CLinearMulticlassMachine() 00023 { 00024 init_defaults(); 00025 } 00026 00027 CMulticlassLogisticRegression::CMulticlassLogisticRegression(float64_t z, CDotFeatures* feats, CLabels* labs) : 00028 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),feats,NULL,labs) 00029 { 00030 init_defaults(); 00031 set_z(z); 00032 } 00033 00034 void CMulticlassLogisticRegression::init_defaults() 00035 { 00036 set_z(0.1); 00037 set_epsilon(1e-2); 00038 set_max_iter(10000); 00039 } 00040 00041 void CMulticlassLogisticRegression::register_parameters() 00042 { 00043 SG_ADD(&m_z, "m_z", "regularization constant",MS_AVAILABLE); 00044 SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE); 00045 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE); 00046 } 00047 00048 CMulticlassLogisticRegression::~CMulticlassLogisticRegression() 00049 { 00050 } 00051 00052 bool CMulticlassLogisticRegression::train_machine(CFeatures* data) 00053 { 00054 if (data) 00055 set_features((CDotFeatures*)data); 00056 00057 REQUIRE(m_features, "%s::train_machine(): No features attached!\n"); 00058 REQUIRE(m_labels, "%s::train_machine(): No labels attached!\n"); 00059 REQUIRE(m_labels->get_label_type()==LT_MULTICLASS, "%s::train_machine(): " 00060 "Attached labels are no multiclass labels\n"); 00061 REQUIRE(m_multiclass_strategy, "%s::train_machine(): No multiclass strategy" 00062 " attached!\n"); 00063 00064 int32_t n_classes = ((CMulticlassLabels*)m_labels)->get_num_classes(); 00065 int32_t n_feats = m_features->get_dim_feature_space(); 00066 00067 slep_options options = slep_options::default_options(); 00068 if (m_machines->get_num_elements()!=0) 00069 { 00070 SGMatrix<float64_t> all_w_old(n_feats, n_classes); 00071 SGVector<float64_t> all_c_old(n_classes); 00072 for (int32_t i=0; i<n_classes; i++) 00073 { 00074 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i); 00075 SGVector<float64_t> w = machine->get_w(); 00076 for (int32_t j=0; j<n_feats; j++) 00077 all_w_old(j,i) = w[j]; 00078 all_c_old[i] = machine->get_bias(); 00079 SG_UNREF(machine); 00080 } 00081 options.last_result = new slep_result_t(all_w_old,all_c_old); 00082 m_machines->reset_array(); 00083 } 00084 options.tolerance = m_epsilon; 00085 options.max_iter = m_max_iter; 00086 slep_result_t result = slep_mc_plain_lr(m_features,(CMulticlassLabels*)m_labels,m_z,options); 00087 00088 SGMatrix<float64_t> all_w = result.w; 00089 SGVector<float64_t> all_c = result.c; 00090 for (int32_t i=0; i<n_classes; i++) 00091 { 00092 SGVector<float64_t> w(n_feats); 00093 for (int32_t j=0; j<n_feats; j++) 00094 w[j] = all_w(j,i); 00095 float64_t c = all_c[i]; 00096 CLinearMachine* machine = new CLinearMachine(); 00097 machine->set_w(w); 00098 machine->set_bias(c); 00099 m_machines->push_back(machine); 00100 } 00101 return true; 00102 } 00103 #endif /* HAVE_EIGEN3 */