// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskLogisticRegression.h> 00011 #include <shogun/lib/slep/slep_solver.h> 00012 #include <shogun/lib/slep/slep_options.h> 00013 00014 namespace shogun 00015 { 00016 00017 CMultitaskLogisticRegression::CMultitaskLogisticRegression() : 00018 CMultitaskLinearMachine() 00019 { 00020 initialize_parameters(); 00021 register_parameters(); 00022 } 00023 00024 CMultitaskLogisticRegression::CMultitaskLogisticRegression( 00025 float64_t z, CDotFeatures* train_features, 00026 CBinaryLabels* train_labels, CTaskRelation* task_relation) : 00027 CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation) 00028 { 00029 initialize_parameters(); 00030 register_parameters(); 00031 set_z(z); 00032 } 00033 00034 CMultitaskLogisticRegression::~CMultitaskLogisticRegression() 00035 { 00036 } 00037 00038 void CMultitaskLogisticRegression::register_parameters() 00039 { 00040 SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE); 00041 SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE); 00042 SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE); 00043 SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE); 00044 SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE); 00045 SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE); 00046 } 00047 00048 void CMultitaskLogisticRegression::initialize_parameters() 00049 { 00050 set_z(0.0); 00051 set_q(2.0); 00052 set_termination(0); 00053 set_regularization(0); 00054 set_tolerance(1e-3); 00055 set_max_iter(1000); 00056 } 00057 00058 bool 
CMultitaskLogisticRegression::train_machine(CFeatures* data) 00059 { 00060 if (data && (CDotFeatures*)data) 00061 set_features((CDotFeatures*)data); 00062 00063 ASSERT(features) 00064 ASSERT(m_labels) 00065 00066 SGVector<float64_t> y(m_labels->get_num_labels()); 00067 for (int32_t i=0; i<y.vlen; i++) 00068 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00069 00070 slep_options options = slep_options::default_options(); 00071 options.n_tasks = m_task_relation->get_num_tasks(); 00072 options.tasks_indices = m_task_relation->get_tasks_indices(); 00073 options.q = m_q; 00074 options.regularization = m_regularization; 00075 options.termination = m_termination; 00076 options.tolerance = m_tolerance; 00077 options.max_iter = m_max_iter; 00078 00079 ETaskRelationType relation_type = m_task_relation->get_relation_type(); 00080 switch (relation_type) 00081 { 00082 case TASK_GROUP: 00083 { 00084 //CTaskGroup* task_group = (CTaskGroup*)m_task_relation; 00085 options.mode = MULTITASK_GROUP; 00086 options.loss = LOGISTIC; 00087 slep_result_t result = slep_solver(features, y.vector, m_z, options); 00088 m_tasks_w = result.w; 00089 m_tasks_c = result.c; 00090 } 00091 break; 00092 case TASK_TREE: 00093 { 00094 CTaskTree* task_tree = (CTaskTree*)m_task_relation; 00095 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t(); 00096 options.ind_t = ind_t.vector; 00097 options.n_nodes = ind_t.vlen / 3; 00098 options.mode = MULTITASK_TREE; 00099 options.loss = LOGISTIC; 00100 slep_result_t result = slep_solver(features, y.vector, m_z, options); 00101 m_tasks_w = result.w; 00102 m_tasks_c = result.c; 00103 } 00104 break; 00105 default: 00106 SG_ERROR("Not supported task relation type\n") 00107 } 00108 SG_FREE(options.tasks_indices); 00109 00110 return true; 00111 } 00112 00113 bool CMultitaskLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks) 00114 { 00115 ASSERT(features) 00116 ASSERT(m_labels) 00117 00118 SGVector<float64_t> y(m_labels->get_num_labels()); 00119 
for (int32_t i=0; i<y.vlen; i++) 00120 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00121 00122 slep_options options = slep_options::default_options(); 00123 options.n_tasks = m_task_relation->get_num_tasks(); 00124 options.tasks_indices = tasks; 00125 options.q = m_q; 00126 options.regularization = m_regularization; 00127 options.termination = m_termination; 00128 options.tolerance = m_tolerance; 00129 options.max_iter = m_max_iter; 00130 00131 ETaskRelationType relation_type = m_task_relation->get_relation_type(); 00132 switch (relation_type) 00133 { 00134 case TASK_GROUP: 00135 { 00136 //CTaskGroup* task_group = (CTaskGroup*)m_task_relation; 00137 options.mode = MULTITASK_GROUP; 00138 options.loss = LOGISTIC; 00139 slep_result_t result = slep_solver(features, y.vector, m_z, options); 00140 m_tasks_w = result.w; 00141 m_tasks_c = result.c; 00142 } 00143 break; 00144 case TASK_TREE: 00145 { 00146 CTaskTree* task_tree = (CTaskTree*)m_task_relation; 00147 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t(); 00148 options.ind_t = ind_t.vector; 00149 options.n_nodes = ind_t.vlen / 3; 00150 options.mode = MULTITASK_TREE; 00151 options.loss = LOGISTIC; 00152 slep_result_t result = slep_solver(features, y.vector, m_z, options); 00153 m_tasks_w = result.w; 00154 m_tasks_c = result.c; 00155 } 00156 break; 00157 default: 00158 SG_ERROR("Not supported task relation type\n") 00159 } 00160 return true; 00161 } 00162 00163 float64_t CMultitaskLogisticRegression::apply_one(int32_t i) 00164 { 00165 float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows); 00166 //float64_t ep = CMath::exp(-(dot + m_tasks_c[m_current_task])); 00167 //return 2.0/(1.0+ep) - 1.0; 00168 return dot + m_tasks_c[m_current_task]; 00169 } 00170 00171 int32_t CMultitaskLogisticRegression::get_max_iter() const 00172 { 00173 return m_max_iter; 00174 } 00175 int32_t CMultitaskLogisticRegression::get_regularization() const 00176 { 00177 return 
m_regularization; 00178 } 00179 int32_t CMultitaskLogisticRegression::get_termination() const 00180 { 00181 return m_termination; 00182 } 00183 float64_t CMultitaskLogisticRegression::get_tolerance() const 00184 { 00185 return m_tolerance; 00186 } 00187 float64_t CMultitaskLogisticRegression::get_z() const 00188 { 00189 return m_z; 00190 } 00191 float64_t CMultitaskLogisticRegression::get_q() const 00192 { 00193 return m_q; 00194 } 00195 00196 void CMultitaskLogisticRegression::set_max_iter(int32_t max_iter) 00197 { 00198 ASSERT(max_iter>=0) 00199 m_max_iter = max_iter; 00200 } 00201 void CMultitaskLogisticRegression::set_regularization(int32_t regularization) 00202 { 00203 ASSERT(regularization==0 || regularization==1) 00204 m_regularization = regularization; 00205 } 00206 void CMultitaskLogisticRegression::set_termination(int32_t termination) 00207 { 00208 ASSERT(termination>=0 && termination<=4) 00209 m_termination = termination; 00210 } 00211 void CMultitaskLogisticRegression::set_tolerance(float64_t tolerance) 00212 { 00213 ASSERT(tolerance>0.0) 00214 m_tolerance = tolerance; 00215 } 00216 void CMultitaskLogisticRegression::set_z(float64_t z) 00217 { 00218 m_z = z; 00219 } 00220 void CMultitaskLogisticRegression::set_q(float64_t q) 00221 { 00222 m_q = q; 00223 } 00224 00225 }