// SHOGUN v3.2.0
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskLeastSquaresRegression.h> 00011 #include <shogun/transfer/multitask/TaskGroup.h> 00012 #include <shogun/transfer/multitask/TaskTree.h> 00013 #include <shogun/lib/slep/slep_solver.h> 00014 #include <shogun/lib/slep/slep_options.h> 00015 00016 namespace shogun 00017 { 00018 00019 CMultitaskLeastSquaresRegression::CMultitaskLeastSquaresRegression() : 00020 CMultitaskLinearMachine() 00021 { 00022 initialize_parameters(); 00023 register_parameters(); 00024 } 00025 00026 CMultitaskLeastSquaresRegression::CMultitaskLeastSquaresRegression( 00027 float64_t z, CDotFeatures* train_features, 00028 CRegressionLabels* train_labels, CTaskRelation* task_relation) : 00029 CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation) 00030 { 00031 set_z(z); 00032 initialize_parameters(); 00033 register_parameters(); 00034 } 00035 00036 CMultitaskLeastSquaresRegression::~CMultitaskLeastSquaresRegression() 00037 { 00038 } 00039 00040 void CMultitaskLeastSquaresRegression::register_parameters() 00041 { 00042 SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE); 00043 SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE); 00044 SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE); 00045 SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE); 00046 SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE); 00047 SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE); 00048 } 00049 00050 void CMultitaskLeastSquaresRegression::initialize_parameters() 00051 { 00052 set_z(0.0); 00053 set_q(2.0); 00054 
set_termination(0); 00055 set_regularization(0); 00056 set_tolerance(1e-3); 00057 set_max_iter(1000); 00058 } 00059 00060 bool CMultitaskLeastSquaresRegression::train_locked_implementation(SGVector<index_t>* tasks) 00061 { 00062 SG_NOTIMPLEMENTED 00063 return false; 00064 } 00065 00066 float64_t CMultitaskLeastSquaresRegression::apply_one(int32_t i) 00067 { 00068 float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows); 00069 return dot + m_tasks_c[m_current_task]; 00070 } 00071 00072 int32_t CMultitaskLeastSquaresRegression::get_max_iter() const 00073 { 00074 return m_max_iter; 00075 } 00076 int32_t CMultitaskLeastSquaresRegression::get_regularization() const 00077 { 00078 return m_regularization; 00079 } 00080 int32_t CMultitaskLeastSquaresRegression::get_termination() const 00081 { 00082 return m_termination; 00083 } 00084 float64_t CMultitaskLeastSquaresRegression::get_tolerance() const 00085 { 00086 return m_tolerance; 00087 } 00088 float64_t CMultitaskLeastSquaresRegression::get_z() const 00089 { 00090 return m_z; 00091 } 00092 float64_t CMultitaskLeastSquaresRegression::get_q() const 00093 { 00094 return m_q; 00095 } 00096 00097 void CMultitaskLeastSquaresRegression::set_max_iter(int32_t max_iter) 00098 { 00099 ASSERT(max_iter>=0) 00100 m_max_iter = max_iter; 00101 } 00102 void CMultitaskLeastSquaresRegression::set_regularization(int32_t regularization) 00103 { 00104 ASSERT(regularization==0 || regularization==1) 00105 m_regularization = regularization; 00106 } 00107 void CMultitaskLeastSquaresRegression::set_termination(int32_t termination) 00108 { 00109 ASSERT(termination>=0 && termination<=4) 00110 m_termination = termination; 00111 } 00112 void CMultitaskLeastSquaresRegression::set_tolerance(float64_t tolerance) 00113 { 00114 ASSERT(tolerance>0.0) 00115 m_tolerance = tolerance; 00116 } 00117 void CMultitaskLeastSquaresRegression::set_z(float64_t z) 00118 { 00119 m_z = z; 00120 } 00121 void 
CMultitaskLeastSquaresRegression::set_q(float64_t q) 00122 { 00123 m_q = q; 00124 } 00125 00126 bool CMultitaskLeastSquaresRegression::train_machine(CFeatures* data) 00127 { 00128 if (data && (CDotFeatures*)data) 00129 set_features((CDotFeatures*)data); 00130 00131 ASSERT(features) 00132 ASSERT(m_labels) 00133 00134 SGVector<float64_t> y = ((CRegressionLabels*)m_labels)->get_labels(); 00135 00136 slep_options options = slep_options::default_options(); 00137 options.n_tasks = m_task_relation->get_num_tasks(); 00138 options.tasks_indices = m_task_relation->get_tasks_indices(); 00139 options.q = m_q; 00140 options.regularization = m_regularization; 00141 options.termination = m_termination; 00142 options.tolerance = m_tolerance; 00143 options.max_iter = m_max_iter; 00144 00145 ETaskRelationType relation_type = m_task_relation->get_relation_type(); 00146 switch (relation_type) 00147 { 00148 case TASK_GROUP: 00149 { 00150 //CTaskGroup* task_group = (CTaskGroup*)m_task_relation; 00151 options.mode = MULTITASK_GROUP; 00152 options.loss = LEAST_SQUARES; 00153 m_tasks_w = slep_solver(features, y.vector, m_z, options).w; 00154 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00155 m_tasks_c.zero(); 00156 } 00157 break; 00158 case TASK_TREE: 00159 { 00160 CTaskTree* task_tree = (CTaskTree*)m_task_relation; 00161 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t(); 00162 options.ind_t = ind_t.vector; 00163 options.n_nodes = ind_t.vlen/3; 00164 options.mode = MULTITASK_TREE; 00165 options.loss = LEAST_SQUARES; 00166 m_tasks_w = slep_solver(features, y.vector, m_z, options).w; 00167 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00168 m_tasks_c.zero(); 00169 } 00170 break; 00171 default: 00172 SG_ERROR("Not supported task relation type\n") 00173 } 00174 00175 SG_FREE(options.tasks_indices); 00176 00177 return true; 00178 } 00179 00180 }