SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
StratifiedCrossValidationSplitting.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011-2012 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
00012 #include <shogun/labels/Labels.h>
00013 #include <shogun/labels/BinaryLabels.h>
00014 #include <shogun/labels/MulticlassLabels.h>
00015 
00016 using namespace shogun;
00017 
00018 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting() :
00019     CSplittingStrategy()
00020 {
00021     m_rng = sg_rand;
00022 }
00023 
00024 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting(
00025         CLabels* labels, index_t num_subsets) :
00026     CSplittingStrategy(labels, num_subsets)
00027 {
00028     /* check for "stupid" combinations of label numbers and num_subsets.
00029      * if there are of a class less labels than num_subsets, the class will not
00030      * appear in every subset, leading to subsets of only one class in the
00031      * extreme case of a two class labeling. */
00032     SGVector<float64_t> classes;
00033 
00034     int32_t num_classes=2;
00035     if (labels->get_label_type() == LT_MULTICLASS)
00036     {
00037         num_classes=((CMulticlassLabels*) labels)->get_num_classes();
00038         classes=((CMulticlassLabels*) labels)->get_unique_labels();
00039     }
00040     else if (labels->get_label_type() == LT_BINARY)
00041     {
00042         classes=SGVector<float64_t>(2);
00043         classes[0]=-1;
00044         classes[1]=+1;
00045     }
00046     else
00047     {
00048         SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n")
00049     }
00050 
00051     SGVector<index_t> labels_per_class(num_classes);
00052 
00053     for (index_t i=0; i<num_classes; ++i)
00054     {
00055         labels_per_class.vector[i]=0;
00056         for (index_t j=0; j<labels->get_num_labels(); ++j)
00057         {
00058             if (classes.vector[i]==((CDenseLabels*) labels)->get_label(j))
00059                 labels_per_class.vector[i]++;
00060         }
00061     }
00062 
00063     for (index_t i=0; i<num_classes; ++i)
00064     {
00065         if (labels_per_class.vector[i]<num_subsets)
00066         {
00067             SG_WARNING("There are only %d labels of class %.18g, but %d "
00068                     "subsets. Labels of that class will not appear in every "
00069                     "subset!\n", labels_per_class.vector[i], classes.vector[i], num_subsets);
00070         }
00071     }
00072 
00073     m_rng = sg_rand;
00074 }
00075 
00076 void CStratifiedCrossValidationSplitting::build_subsets()
00077 {
00078     /* ensure that subsets are empty and set flag to filled */
00079     reset_subsets();
00080     m_is_filled=true;
00081 
00082     SGVector<float64_t> unique_labels;
00083 
00084     if (m_labels->get_label_type() == LT_MULTICLASS)
00085     {
00086         unique_labels=((CMulticlassLabels*) m_labels)->get_unique_labels();
00087     }
00088     else if (m_labels->get_label_type() == LT_BINARY)
00089     {
00090         unique_labels=SGVector<float64_t>(2);
00091         unique_labels[0]=-1;
00092         unique_labels[1]=+1;
00093     }
00094     else
00095     {
00096         SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n")
00097     }
00098 
00099     /* for every label, build set for indices */
00100     CDynamicObjectArray label_indices;
00101     for (index_t i=0; i<unique_labels.vlen; ++i)
00102         label_indices.append_element(new CDynamicArray<index_t> ());
00103 
00104     /* fill set with indices, for each label type ... */
00105     for (index_t i=0; i<unique_labels.vlen; ++i)
00106     {
00107         /* ... iterate over all labels and add indices with same label to set */
00108         for (index_t j=0; j<m_labels->get_num_labels(); ++j)
00109         {
00110             if (((CDenseLabels*) m_labels)->get_label(j)==unique_labels.vector[i])
00111             {
00112                 CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
00113                         label_indices.get_element(i);
00114                 current->append_element(j);
00115                 SG_UNREF(current);
00116             }
00117         }
00118     }
00119 
00120     /* shuffle created label sets */
00121     for (index_t i=0; i<label_indices.get_num_elements(); ++i)
00122     {
00123         CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
00124                 label_indices.get_element(i);
00125 
00126         // external random state important for threads
00127         current->shuffle(m_rng);
00128 
00129         SG_UNREF(current);
00130     }
00131 
00132     /* distribute labels to subsets for all label types */
00133     index_t target_set=0;
00134     for (index_t i=0; i<unique_labels.vlen; ++i)
00135     {
00136         /* current index set for current label */
00137         CDynamicArray<index_t>* current=(CDynamicArray<index_t>*)
00138                 label_indices.get_element(i);
00139 
00140         for (index_t j=0; j<current->get_num_elements(); ++j)
00141         {
00142             CDynamicArray<index_t>* next=(CDynamicArray<index_t>*)
00143                     m_subset_indices->get_element(target_set++);
00144             next->append_element(current->get_element(j));
00145             target_set%=m_subset_indices->get_num_elements();
00146             SG_UNREF(next);
00147         }
00148 
00149         SG_UNREF(current);
00150     }
00151 
00152     /* finally shuffle to avoid that subsets with low indices have more
00153      * elements, which happens if the number of class labels is not equal to
00154      * the number of subsets (external random state important for threads) */
00155     m_subset_indices->shuffle(m_rng);
00156 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation