SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MultitaskROCEvaluation.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/transfer/multitask/MultitaskROCEvaluation.h>
00011 #include <shogun/mathematics/Math.h>
00012 
00013 #include <set>
00014 #include <vector>
00015 
00016 using namespace std;
00017 using namespace shogun;
00018 
00019 void CMultitaskROCEvaluation::set_indices(SGVector<index_t> indices)
00020 {
00021     indices.display_vector("indices");
00022     ASSERT(m_task_relation)
00023 
00024     set<index_t> indices_set;
00025     for (int32_t i=0; i<indices.vlen; i++)
00026         indices_set.insert(indices[i]);
00027 
00028     if (m_num_tasks>0)
00029     {
00030         SG_FREE(m_tasks_indices);
00031     }
00032     m_num_tasks = m_task_relation->get_num_tasks();
00033     m_tasks_indices = SG_MALLOC(SGVector<index_t>, m_num_tasks);
00034 
00035     SGVector<index_t>* tasks_indices = m_task_relation->get_tasks_indices();
00036     for (int32_t t=0; t<m_num_tasks; t++)
00037     {
00038         vector<index_t> task_indices_cut;
00039         SGVector<index_t> task_indices = tasks_indices[t];
00040         //task_indices.display_vector("task indices");
00041         for (int32_t i=0; i<task_indices.vlen; i++)
00042         {
00043             if (indices_set.count(task_indices[i]))
00044             {
00045                 //SG_SPRINT("%d is in %d task\n",task_indices[i],t)
00046                 task_indices_cut.push_back(task_indices[i]);
00047             }
00048         }
00049 
00050         SGVector<index_t> cutted(task_indices_cut.size());
00051         for (int32_t i=0; i<cutted.vlen; i++)
00052             cutted[i] = task_indices_cut[i];
00053         //cutted.display_vector("cutted");
00054         m_tasks_indices[t] = cutted;
00055     }
00056     SG_FREE(tasks_indices);
00057 }
00058 
00059 float64_t CMultitaskROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00060 {
00061     //SG_SPRINT("Evaluate\n")
00062     predicted->remove_all_subsets();
00063     ground_truth->remove_all_subsets();
00064     float64_t result = 0.0;
00065     for (int32_t t=0; t<m_num_tasks; t++)
00066     {
00067         //SG_SPRINT("%d task", t)
00068         //m_tasks_indices[t].display_vector();
00069         predicted->add_subset(m_tasks_indices[t]);
00070         ground_truth->add_subset(m_tasks_indices[t]);
00071         result += evaluate_roc(predicted,ground_truth)/m_tasks_indices[t].vlen;
00072         predicted->remove_subset();
00073         ground_truth->remove_subset();
00074     }
00075     return result;
00076 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation