SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MultitaskLinearMachine.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/MultitaskLinearMachine.h>
00011 #include <shogun/lib/slep/slep_solver.h>
00012 #include <shogun/lib/slep/slep_options.h>
00013 
00014 #include <map>
00015 #include <vector>
00016 
00017 using namespace std;
00018 
00019 namespace shogun
00020 {
00021 
00022 CMultitaskLinearMachine::CMultitaskLinearMachine() :
00023     CLinearMachine(), m_current_task(0),
00024     m_task_relation(NULL)
00025 {
00026     register_parameters();
00027 }
00028 
00029 CMultitaskLinearMachine::CMultitaskLinearMachine(
00030      CDotFeatures* train_features,
00031      CLabels* train_labels, CTaskRelation* task_relation) :
00032     CLinearMachine(), m_current_task(0), m_task_relation(NULL)
00033 {
00034     set_features(train_features);
00035     set_labels(train_labels);
00036     set_task_relation(task_relation);
00037     register_parameters();
00038 }
00039 
00040 CMultitaskLinearMachine::~CMultitaskLinearMachine()
00041 {
00042     SG_UNREF(m_task_relation);
00043 }
00044 
00045 void CMultitaskLinearMachine::register_parameters()
00046 {
00047     SG_ADD((CSGObject**)&m_task_relation, "task_relation", "task relation", MS_NOT_AVAILABLE);
00048 }
00049 
00050 int32_t CMultitaskLinearMachine::get_current_task() const
00051 {
00052     return m_current_task;
00053 }
00054 
00055 void CMultitaskLinearMachine::set_current_task(int32_t task)
00056 {
00057     ASSERT(task>=0)
00058     ASSERT(task<m_tasks_w.num_cols)
00059     m_current_task = task;
00060 }
00061 
00062 CTaskRelation* CMultitaskLinearMachine::get_task_relation() const
00063 {
00064     SG_REF(m_task_relation);
00065     return m_task_relation;
00066 }
00067 
00068 void CMultitaskLinearMachine::set_task_relation(CTaskRelation* task_relation)
00069 {
00070     SG_REF(task_relation);
00071     SG_UNREF(m_task_relation);
00072     m_task_relation = task_relation;
00073 }
00074 
00075 bool CMultitaskLinearMachine::train_machine(CFeatures* data)
00076 {
00077     SG_NOTIMPLEMENTED
00078     return false;
00079 }
00080 
00081 void CMultitaskLinearMachine::post_lock(CLabels* labels, CFeatures* features_)
00082 {
00083     set_features((CDotFeatures*)features_);
00084     int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
00085     SGVector<index_t>* tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices();
00086 
00087     m_tasks_indices.clear();
00088     for (int32_t i=0; i<n_tasks; i++)
00089     {
00090         set<index_t> indices_set;
00091         SGVector<index_t> task_indices = tasks_indices[i];
00092         for (int32_t j=0; j<task_indices.vlen; j++)
00093             indices_set.insert(task_indices[j]);
00094 
00095         m_tasks_indices.push_back(indices_set);
00096     }
00097 
00098     SG_FREE(tasks_indices);
00099 }
00100 
00101 bool CMultitaskLinearMachine::train_locked(SGVector<index_t> indices)
00102 {
00103     int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
00104     ASSERT((int)m_tasks_indices.size()==n_tasks)
00105     vector< vector<index_t> > cutted_task_indices;
00106     for (int32_t i=0; i<n_tasks; i++)
00107         cutted_task_indices.push_back(vector<index_t>());
00108     for (int32_t i=0; i<indices.vlen; i++)
00109     {
00110         for (int32_t j=0; j<n_tasks; j++)
00111         {
00112             if (m_tasks_indices[j].count(indices[i]))
00113             {
00114                 cutted_task_indices[j].push_back(indices[i]);
00115                 break;
00116             }
00117         }
00118     }
00119     SGVector<index_t>* tasks = SG_MALLOC(SGVector<index_t>, n_tasks);
00120     for (int32_t i=0; i<n_tasks; i++)
00121     {
00122         tasks[i]=SGVector<index_t>(cutted_task_indices[i].size());
00123         for (int32_t j=0; j<(int)cutted_task_indices[i].size(); j++)
00124             tasks[i][j] = cutted_task_indices[i][j];
00125         //tasks[i].display_vector();
00126     }
00127     bool res = train_locked_implementation(tasks);
00128     SG_FREE(tasks);
00129     return res;
00130 }
00131 
00132 bool CMultitaskLinearMachine::train_locked_implementation(SGVector<index_t>* tasks)
00133 {
00134     SG_NOTIMPLEMENTED
00135     return false;
00136 }
00137 
00138 CBinaryLabels* CMultitaskLinearMachine::apply_locked_binary(SGVector<index_t> indices)
00139 {
00140     int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
00141     SGVector<float64_t> result(indices.vlen);
00142     result.zero();
00143     for (int32_t i=0; i<indices.vlen; i++)
00144     {
00145         for (int32_t j=0; j<n_tasks; j++)
00146         {
00147             if (m_tasks_indices[j].count(indices[i]))
00148             {
00149                 set_current_task(j);
00150                 result[i] = apply_one(indices[i]);
00151                 break;
00152             }
00153         }
00154     }
00155     return new CBinaryLabels(result);
00156 }
00157 
00158 float64_t CMultitaskLinearMachine::apply_one(int32_t i)
00159 {
00160     SG_NOTIMPLEMENTED
00161     return 0.0;
00162 }
00163 
00164 SGVector<float64_t> CMultitaskLinearMachine::apply_get_outputs(CFeatures* data)
00165 {
00166     if (data)
00167     {
00168         if (!data->has_property(FP_DOT))
00169             SG_ERROR("Specified features are not of type CDotFeatures\n")
00170 
00171         set_features((CDotFeatures*) data);
00172     }
00173 
00174     if (!features)
00175         return SGVector<float64_t>();
00176 
00177     int32_t num=features->get_num_vectors();
00178     ASSERT(num>0)
00179     float64_t* out=SG_MALLOC(float64_t, num);
00180     for (int32_t i=0; i<num; i++)
00181         out[i] = apply_one(i);
00182 
00183     return SGVector<float64_t>(out,num);
00184 }
00185 
00186 SGVector<float64_t> CMultitaskLinearMachine::get_w() const
00187 {
00188     SGVector<float64_t> w_(m_tasks_w.num_rows);
00189     for (int32_t i=0; i<w_.vlen; i++)
00190         w_[i] = m_tasks_w(i,m_current_task);
00191     return w_;
00192 }
00193 
00194 void CMultitaskLinearMachine::set_w(const SGVector<float64_t> src_w)
00195 {
00196     for (int32_t i=0; i<m_tasks_w.num_rows; i++)
00197         m_tasks_w(i,m_current_task) = src_w[i];
00198 }
00199 
00200 void CMultitaskLinearMachine::set_bias(float64_t b)
00201 {
00202     m_tasks_c[m_current_task] = b;
00203 }
00204 
00205 float64_t CMultitaskLinearMachine::get_bias()
00206 {
00207     return m_tasks_c[m_current_task];
00208 }
00209 
00210 SGVector<index_t>* CMultitaskLinearMachine::get_subset_tasks_indices()
00211 {
00212     int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks();
00213     SGVector<index_t>* tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices();
00214 
00215     CSubsetStack* sstack = features->get_subset_stack();
00216     map<index_t,index_t> subset_inv_map = map<index_t,index_t>();
00217     for (int32_t i=0; i<sstack->get_size(); i++)
00218         subset_inv_map[sstack->subset_idx_conversion(i)] = i;
00219 
00220     SGVector<index_t>* subset_tasks_indices = SG_MALLOC(SGVector<index_t>, n_tasks);
00221     for (int32_t i=0; i<n_tasks; i++)
00222     {
00223         SGVector<index_t> task = tasks_indices[i];
00224         //task.display_vector("task");
00225         vector<index_t> cutted = vector<index_t>();
00226         for (int32_t j=0; j<task.vlen; j++)
00227         {
00228             if (subset_inv_map.count(task[j]))
00229                 cutted.push_back(subset_inv_map[task[j]]);
00230         }
00231         SGVector<index_t> cutted_task(cutted.size());
00232         for (int32_t j=0; j<cutted_task.vlen; j++)
00233             cutted_task[j] = cutted[j];
00234         //cutted_task.display_vector("cutted");
00235         subset_tasks_indices[i] = cutted_task;
00236     }
00237     SG_FREE(tasks_indices);
00238 
00239     return subset_tasks_indices;
00240 }
00241 
00242 
00243 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation