SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskLinearMachine.h> 00011 #include <shogun/lib/slep/slep_solver.h> 00012 #include <shogun/lib/slep/slep_options.h> 00013 00014 #include <map> 00015 #include <vector> 00016 00017 using namespace std; 00018 00019 namespace shogun 00020 { 00021 00022 CMultitaskLinearMachine::CMultitaskLinearMachine() : 00023 CLinearMachine(), m_current_task(0), 00024 m_task_relation(NULL) 00025 { 00026 register_parameters(); 00027 } 00028 00029 CMultitaskLinearMachine::CMultitaskLinearMachine( 00030 CDotFeatures* train_features, 00031 CLabels* train_labels, CTaskRelation* task_relation) : 00032 CLinearMachine(), m_current_task(0), m_task_relation(NULL) 00033 { 00034 set_features(train_features); 00035 set_labels(train_labels); 00036 set_task_relation(task_relation); 00037 register_parameters(); 00038 } 00039 00040 CMultitaskLinearMachine::~CMultitaskLinearMachine() 00041 { 00042 SG_UNREF(m_task_relation); 00043 } 00044 00045 void CMultitaskLinearMachine::register_parameters() 00046 { 00047 SG_ADD((CSGObject**)&m_task_relation, "task_relation", "task relation", MS_NOT_AVAILABLE); 00048 } 00049 00050 int32_t CMultitaskLinearMachine::get_current_task() const 00051 { 00052 return m_current_task; 00053 } 00054 00055 void CMultitaskLinearMachine::set_current_task(int32_t task) 00056 { 00057 ASSERT(task>=0) 00058 ASSERT(task<m_tasks_w.num_cols) 00059 m_current_task = task; 00060 } 00061 00062 CTaskRelation* CMultitaskLinearMachine::get_task_relation() const 00063 { 00064 SG_REF(m_task_relation); 00065 return m_task_relation; 00066 } 00067 00068 void CMultitaskLinearMachine::set_task_relation(CTaskRelation* task_relation) 00069 { 00070 SG_REF(task_relation); 00071 SG_UNREF(m_task_relation); 00072 m_task_relation = task_relation; 00073 } 00074 00075 bool CMultitaskLinearMachine::train_machine(CFeatures* data) 00076 { 00077 SG_NOTIMPLEMENTED 00078 return false; 00079 } 00080 00081 void CMultitaskLinearMachine::post_lock(CLabels* labels, CFeatures* features_) 00082 { 00083 set_features((CDotFeatures*)features_); 00084 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00085 SGVector<index_t>* tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices(); 00086 00087 m_tasks_indices.clear(); 00088 for (int32_t i=0; i<n_tasks; i++) 00089 { 00090 set<index_t> indices_set; 00091 SGVector<index_t> task_indices = tasks_indices[i]; 00092 for (int32_t j=0; j<task_indices.vlen; j++) 00093 indices_set.insert(task_indices[j]); 00094 00095 m_tasks_indices.push_back(indices_set); 00096 } 00097 00098 SG_FREE(tasks_indices); 00099 } 00100 00101 bool CMultitaskLinearMachine::train_locked(SGVector<index_t> indices) 00102 { 00103 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00104 ASSERT((int)m_tasks_indices.size()==n_tasks) 00105 vector< vector<index_t> > cutted_task_indices; 00106 for (int32_t i=0; i<n_tasks; i++) 00107 cutted_task_indices.push_back(vector<index_t>()); 00108 for (int32_t i=0; i<indices.vlen; i++) 00109 { 00110 for (int32_t j=0; j<n_tasks; j++) 00111 { 00112 if (m_tasks_indices[j].count(indices[i])) 00113 { 00114 cutted_task_indices[j].push_back(indices[i]); 00115 break; 00116 } 00117 } 00118 } 00119 SGVector<index_t>* tasks = SG_MALLOC(SGVector<index_t>, n_tasks); 00120 for (int32_t i=0; i<n_tasks; i++) 00121 { 00122 tasks[i]=SGVector<index_t>(cutted_task_indices[i].size()); 00123 for (int32_t j=0; j<(int)cutted_task_indices[i].size(); j++) 00124 tasks[i][j] = cutted_task_indices[i][j]; 00125 //tasks[i].display_vector(); 00126 } 00127 bool res = train_locked_implementation(tasks); 00128 SG_FREE(tasks); 00129 return res; 00130 } 00131 00132 bool CMultitaskLinearMachine::train_locked_implementation(SGVector<index_t>* tasks) 00133 { 00134 SG_NOTIMPLEMENTED 00135 return false; 00136 } 00137 00138 CBinaryLabels* CMultitaskLinearMachine::apply_locked_binary(SGVector<index_t> indices) 00139 { 00140 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00141 SGVector<float64_t> result(indices.vlen); 00142 result.zero(); 00143 for (int32_t i=0; i<indices.vlen; i++) 00144 { 00145 for (int32_t j=0; j<n_tasks; j++) 00146 { 00147 if (m_tasks_indices[j].count(indices[i])) 00148 { 00149 set_current_task(j); 00150 result[i] = apply_one(indices[i]); 00151 break; 00152 } 00153 } 00154 } 00155 return new CBinaryLabels(result); 00156 } 00157 00158 float64_t CMultitaskLinearMachine::apply_one(int32_t i) 00159 { 00160 SG_NOTIMPLEMENTED 00161 return 0.0; 00162 } 00163 00164 SGVector<float64_t> CMultitaskLinearMachine::apply_get_outputs(CFeatures* data) 00165 { 00166 if (data) 00167 { 00168 if (!data->has_property(FP_DOT)) 00169 SG_ERROR("Specified features are not of type CDotFeatures\n") 00170 00171 set_features((CDotFeatures*) data); 00172 } 00173 00174 if (!features) 00175 return SGVector<float64_t>(); 00176 00177 int32_t num=features->get_num_vectors(); 00178 ASSERT(num>0) 00179 float64_t* out=SG_MALLOC(float64_t, num); 00180 for (int32_t i=0; i<num; i++) 00181 out[i] = apply_one(i); 00182 00183 return SGVector<float64_t>(out,num); 00184 } 00185 00186 SGVector<float64_t> CMultitaskLinearMachine::get_w() const 00187 { 00188 SGVector<float64_t> w_(m_tasks_w.num_rows); 00189 for (int32_t i=0; i<w_.vlen; i++) 00190 w_[i] = m_tasks_w(i,m_current_task); 00191 return w_; 00192 } 00193 00194 void CMultitaskLinearMachine::set_w(const SGVector<float64_t> src_w) 00195 { 00196 for (int32_t i=0; i<m_tasks_w.num_rows; i++) 00197 m_tasks_w(i,m_current_task) = src_w[i]; 00198 } 00199 00200 void CMultitaskLinearMachine::set_bias(float64_t b) 00201 { 00202 m_tasks_c[m_current_task] = b; 00203 } 00204 00205 float64_t CMultitaskLinearMachine::get_bias() 00206 { 00207 return m_tasks_c[m_current_task]; 00208 } 00209 00210 SGVector<index_t>* CMultitaskLinearMachine::get_subset_tasks_indices() 00211 { 00212 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00213 SGVector<index_t>* tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices(); 00214 00215 CSubsetStack* sstack = features->get_subset_stack(); 00216 map<index_t,index_t> subset_inv_map = map<index_t,index_t>(); 00217 for (int32_t i=0; i<sstack->get_size(); i++) 00218 subset_inv_map[sstack->subset_idx_conversion(i)] = i; 00219 00220 SGVector<index_t>* subset_tasks_indices = SG_MALLOC(SGVector<index_t>, n_tasks); 00221 for (int32_t i=0; i<n_tasks; i++) 00222 { 00223 SGVector<index_t> task = tasks_indices[i]; 00224 //task.display_vector("task"); 00225 vector<index_t> cutted = vector<index_t>(); 00226 for (int32_t j=0; j<task.vlen; j++) 00227 { 00228 if (subset_inv_map.count(task[j])) 00229 cutted.push_back(subset_inv_map[task[j]]); 00230 } 00231 SGVector<index_t> cutted_task(cutted.size()); 00232 for (int32_t j=0; j<cutted_task.vlen; j++) 00233 cutted_task[j] = cutted[j]; 00234 //cutted_task.display_vector("cutted"); 00235 subset_tasks_indices[i] = cutted_task; 00236 } 00237 SG_FREE(tasks_indices); 00238 00239 return subset_tasks_indices; 00240 } 00241 00242 00243 }