SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskClusteredLogisticRegression.h> 00011 #include <shogun/lib/malsar/malsar_clustered.h> 00012 #include <shogun/lib/malsar/malsar_options.h> 00013 #include <shogun/lib/SGVector.h> 00014 00015 namespace shogun 00016 { 00017 00018 CMultitaskClusteredLogisticRegression::CMultitaskClusteredLogisticRegression() : 00019 CMultitaskLogisticRegression(), m_rho1(0.0), m_rho2(0.0) 00020 { 00021 } 00022 00023 CMultitaskClusteredLogisticRegression::CMultitaskClusteredLogisticRegression( 00024 float64_t rho1, float64_t rho2, CDotFeatures* train_features, 00025 CBinaryLabels* train_labels, CTaskGroup* task_group, int32_t n_clusters) : 00026 CMultitaskLogisticRegression(0.0,train_features,train_labels,(CTaskRelation*)task_group) 00027 { 00028 set_rho1(rho1); 00029 set_rho2(rho2); 00030 set_num_clusters(n_clusters); 00031 } 00032 00033 int32_t CMultitaskClusteredLogisticRegression::get_rho1() const 00034 { 00035 return m_rho1; 00036 } 00037 00038 int32_t CMultitaskClusteredLogisticRegression::get_rho2() const 00039 { 00040 return m_rho2; 00041 } 00042 00043 void CMultitaskClusteredLogisticRegression::set_rho1(float64_t rho1) 00044 { 00045 m_rho1 = rho1; 00046 } 00047 00048 void CMultitaskClusteredLogisticRegression::set_rho2(float64_t rho2) 00049 { 00050 m_rho2 = rho2; 00051 } 00052 00053 int32_t CMultitaskClusteredLogisticRegression::get_num_clusters() const 00054 { 00055 return m_num_clusters; 00056 } 00057 00058 void CMultitaskClusteredLogisticRegression::set_num_clusters(int32_t num_clusters) 00059 { 00060 m_num_clusters = num_clusters; 00061 } 00062 00063 CMultitaskClusteredLogisticRegression::~CMultitaskClusteredLogisticRegression() 00064 { 00065 } 00066 00067 bool CMultitaskClusteredLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks) 00068 { 00069 SGVector<float64_t> y(m_labels->get_num_labels()); 00070 for (int32_t i=0; i<y.vlen; i++) 00071 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00072 00073 malsar_options options = malsar_options::default_options(); 00074 options.termination = m_termination; 00075 options.tolerance = m_tolerance; 00076 options.max_iter = m_max_iter; 00077 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00078 options.tasks_indices = tasks; 00079 options.n_clusters = m_num_clusters; 00080 00081 #ifdef HAVE_EIGEN3 00082 #ifndef HAVE_CXX11 00083 malsar_result_t model = malsar_clustered( 00084 features, y.vector, m_rho1, m_rho2, options); 00085 00086 m_tasks_w = model.w; 00087 m_tasks_c = model.c; 00088 #else 00089 SG_WARNING("Clustered LR is unstable with C++11\n") 00090 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 00091 m_tasks_w.set_const(0); 00092 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00093 m_tasks_c.set_const(0); 00094 #endif 00095 #else 00096 SG_WARNING("Please install Eigen3 to use MultitaskClusteredLogisticRegression\n") 00097 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 00098 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00099 #endif 00100 return true; 00101 } 00102 00103 bool CMultitaskClusteredLogisticRegression::train_machine(CFeatures* data) 00104 { 00105 if (data && (CDotFeatures*)data) 00106 set_features((CDotFeatures*)data); 00107 00108 ASSERT(features) 00109 ASSERT(m_labels) 00110 ASSERT(m_task_relation) 00111 00112 SGVector<float64_t> y(m_labels->get_num_labels()); 00113 for (int32_t i=0; i<y.vlen; i++) 00114 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00115 00116 malsar_options options = malsar_options::default_options(); 00117 options.termination = m_termination; 00118 options.tolerance = m_tolerance; 00119 options.max_iter = m_max_iter; 00120 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00121 options.tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices(); 00122 options.n_clusters = m_num_clusters; 00123 00124 #ifdef HAVE_EIGEN3 00125 #ifndef HAVE_CXX11 00126 malsar_result_t model = malsar_clustered( 00127 features, y.vector, m_rho1, m_rho2, options); 00128 00129 m_tasks_w = model.w; 00130 m_tasks_c = model.c; 00131 #else 00132 SG_WARNING("Clustered LR is unstable with C++11\n") 00133 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 00134 m_tasks_w.set_const(0); 00135 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00136 m_tasks_c.set_const(0); 00137 #endif 00138 #else 00139 SG_WARNING("Please install Eigen3 to use MultitaskClusteredLogisticRegression\n") 00140 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 00141 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00142 #endif 00143 00144 SG_FREE(options.tasks_indices); 00145 00146 return true; 00147 } 00148 00149 }