SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MultitaskLogisticRegression.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/MultitaskLogisticRegression.h>
00011 #include <shogun/lib/slep/slep_solver.h>
00012 #include <shogun/lib/slep/slep_options.h>
00013 
00014 namespace shogun
00015 {
00016 
00017 CMultitaskLogisticRegression::CMultitaskLogisticRegression() :
00018     CMultitaskLinearMachine()
00019 {
00020     initialize_parameters();
00021     register_parameters();
00022 }
00023 
00024 CMultitaskLogisticRegression::CMultitaskLogisticRegression(
00025      float64_t z, CDotFeatures* train_features,
00026      CBinaryLabels* train_labels, CTaskRelation* task_relation) :
00027     CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation)
00028 {
00029     initialize_parameters();
00030     register_parameters();
00031     set_z(z);
00032 }
00033 
00034 CMultitaskLogisticRegression::~CMultitaskLogisticRegression()
00035 {
00036 }
00037 
00038 void CMultitaskLogisticRegression::register_parameters()
00039 {
00040     SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE);
00041     SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE);
00042     SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE);
00043     SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE);
00044     SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE);
00045     SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE);
00046 }
00047 
00048 void CMultitaskLogisticRegression::initialize_parameters()
00049 {
00050     set_z(0.0);
00051     set_q(2.0);
00052     set_termination(0);
00053     set_regularization(0);
00054     set_tolerance(1e-3);
00055     set_max_iter(1000);
00056 }
00057 
00058 bool CMultitaskLogisticRegression::train_machine(CFeatures* data)
00059 {
00060     if (data && (CDotFeatures*)data)
00061         set_features((CDotFeatures*)data);
00062 
00063     ASSERT(features)
00064     ASSERT(m_labels)
00065 
00066     SGVector<float64_t> y(m_labels->get_num_labels());
00067     for (int32_t i=0; i<y.vlen; i++)
00068         y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00069 
00070     slep_options options = slep_options::default_options();
00071     options.n_tasks = m_task_relation->get_num_tasks();
00072     options.tasks_indices = m_task_relation->get_tasks_indices();
00073     options.q = m_q;
00074     options.regularization = m_regularization;
00075     options.termination = m_termination;
00076     options.tolerance = m_tolerance;
00077     options.max_iter = m_max_iter;
00078 
00079     ETaskRelationType relation_type = m_task_relation->get_relation_type();
00080     switch (relation_type)
00081     {
00082         case TASK_GROUP:
00083         {
00084             //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
00085             options.mode = MULTITASK_GROUP;
00086             options.loss = LOGISTIC;
00087             slep_result_t result = slep_solver(features, y.vector, m_z, options);
00088             m_tasks_w = result.w;
00089             m_tasks_c = result.c;
00090         }
00091         break;
00092         case TASK_TREE:
00093         {
00094             CTaskTree* task_tree = (CTaskTree*)m_task_relation;
00095             SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
00096             options.ind_t = ind_t.vector;
00097             options.n_nodes = ind_t.vlen / 3;
00098             options.mode = MULTITASK_TREE;
00099             options.loss = LOGISTIC;
00100             slep_result_t result = slep_solver(features, y.vector, m_z, options);
00101             m_tasks_w = result.w;
00102             m_tasks_c = result.c;
00103         }
00104         break;
00105         default:
00106             SG_ERROR("Not supported task relation type\n")
00107     }
00108     SG_FREE(options.tasks_indices);
00109 
00110     return true;
00111 }
00112 
00113 bool CMultitaskLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks)
00114 {
00115     ASSERT(features)
00116     ASSERT(m_labels)
00117 
00118     SGVector<float64_t> y(m_labels->get_num_labels());
00119     for (int32_t i=0; i<y.vlen; i++)
00120         y[i] = ((CBinaryLabels*)m_labels)->get_label(i);
00121 
00122     slep_options options = slep_options::default_options();
00123     options.n_tasks = m_task_relation->get_num_tasks();
00124     options.tasks_indices = tasks;
00125     options.q = m_q;
00126     options.regularization = m_regularization;
00127     options.termination = m_termination;
00128     options.tolerance = m_tolerance;
00129     options.max_iter = m_max_iter;
00130 
00131     ETaskRelationType relation_type = m_task_relation->get_relation_type();
00132     switch (relation_type)
00133     {
00134         case TASK_GROUP:
00135         {
00136             //CTaskGroup* task_group = (CTaskGroup*)m_task_relation;
00137             options.mode = MULTITASK_GROUP;
00138             options.loss = LOGISTIC;
00139             slep_result_t result = slep_solver(features, y.vector, m_z, options);
00140             m_tasks_w = result.w;
00141             m_tasks_c = result.c;
00142         }
00143         break;
00144         case TASK_TREE:
00145         {
00146             CTaskTree* task_tree = (CTaskTree*)m_task_relation;
00147             SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t();
00148             options.ind_t = ind_t.vector;
00149             options.n_nodes = ind_t.vlen / 3;
00150             options.mode = MULTITASK_TREE;
00151             options.loss = LOGISTIC;
00152             slep_result_t result = slep_solver(features, y.vector, m_z, options);
00153             m_tasks_w = result.w;
00154             m_tasks_c = result.c;
00155         }
00156         break;
00157         default:
00158             SG_ERROR("Not supported task relation type\n")
00159     }
00160     return true;
00161 }
00162 
00163 float64_t CMultitaskLogisticRegression::apply_one(int32_t i)
00164 {
00165     float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows);
00166     //float64_t ep = CMath::exp(-(dot + m_tasks_c[m_current_task]));
00167     //return 2.0/(1.0+ep) - 1.0;
00168     return dot + m_tasks_c[m_current_task];
00169 }
00170 
00171 int32_t CMultitaskLogisticRegression::get_max_iter() const
00172 {
00173     return m_max_iter;
00174 }
00175 int32_t CMultitaskLogisticRegression::get_regularization() const
00176 {
00177     return m_regularization;
00178 }
00179 int32_t CMultitaskLogisticRegression::get_termination() const
00180 {
00181     return m_termination;
00182 }
00183 float64_t CMultitaskLogisticRegression::get_tolerance() const
00184 {
00185     return m_tolerance;
00186 }
00187 float64_t CMultitaskLogisticRegression::get_z() const
00188 {
00189     return m_z;
00190 }
00191 float64_t CMultitaskLogisticRegression::get_q() const
00192 {
00193     return m_q;
00194 }
00195 
00196 void CMultitaskLogisticRegression::set_max_iter(int32_t max_iter)
00197 {
00198     ASSERT(max_iter>=0)
00199     m_max_iter = max_iter;
00200 }
00201 void CMultitaskLogisticRegression::set_regularization(int32_t regularization)
00202 {
00203     ASSERT(regularization==0 || regularization==1)
00204     m_regularization = regularization;
00205 }
00206 void CMultitaskLogisticRegression::set_termination(int32_t termination)
00207 {
00208     ASSERT(termination>=0 && termination<=4)
00209     m_termination = termination;
00210 }
00211 void CMultitaskLogisticRegression::set_tolerance(float64_t tolerance)
00212 {
00213     ASSERT(tolerance>0.0)
00214     m_tolerance = tolerance;
00215 }
00216 void CMultitaskLogisticRegression::set_z(float64_t z)
00217 {
00218     m_z = z;
00219 }
00220 void CMultitaskLogisticRegression::set_q(float64_t q)
00221 {
00222     m_q = q;
00223 }
00224 
00225 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation