SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011-2012 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/StratifiedCrossValidationSplitting.h> 00012 #include <shogun/labels/Labels.h> 00013 #include <shogun/labels/BinaryLabels.h> 00014 #include <shogun/labels/MulticlassLabels.h> 00015 00016 using namespace shogun; 00017 00018 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting() : 00019 CSplittingStrategy() 00020 { 00021 m_rng = sg_rand; 00022 } 00023 00024 CStratifiedCrossValidationSplitting::CStratifiedCrossValidationSplitting( 00025 CLabels* labels, index_t num_subsets) : 00026 CSplittingStrategy(labels, num_subsets) 00027 { 00028 /* check for "stupid" combinations of label numbers and num_subsets. 00029 * if there are of a class less labels than num_subsets, the class will not 00030 * appear in every subset, leading to subsets of only one class in the 00031 * extreme case of a two class labeling. */ 00032 SGVector<float64_t> classes; 00033 00034 int32_t num_classes=2; 00035 if (labels->get_label_type() == LT_MULTICLASS) 00036 { 00037 num_classes=((CMulticlassLabels*) labels)->get_num_classes(); 00038 classes=((CMulticlassLabels*) labels)->get_unique_labels(); 00039 } 00040 else if (labels->get_label_type() == LT_BINARY) 00041 { 00042 classes=SGVector<float64_t>(2); 00043 classes[0]=-1; 00044 classes[1]=+1; 00045 } 00046 else 00047 { 00048 SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n") 00049 } 00050 00051 SGVector<index_t> labels_per_class(num_classes); 00052 00053 for (index_t i=0; i<num_classes; ++i) 00054 { 00055 labels_per_class.vector[i]=0; 00056 for (index_t j=0; j<labels->get_num_labels(); ++j) 00057 { 00058 if (classes.vector[i]==((CDenseLabels*) labels)->get_label(j)) 00059 labels_per_class.vector[i]++; 00060 } 00061 } 00062 00063 for (index_t i=0; i<num_classes; ++i) 00064 { 00065 if (labels_per_class.vector[i]<num_subsets) 00066 { 00067 SG_WARNING("There are only %d labels of class %.18g, but %d " 00068 "subsets. Labels of that class will not appear in every " 00069 "subset!\n", labels_per_class.vector[i], classes.vector[i], num_subsets); 00070 } 00071 } 00072 00073 m_rng = sg_rand; 00074 } 00075 00076 void CStratifiedCrossValidationSplitting::build_subsets() 00077 { 00078 /* ensure that subsets are empty and set flag to filled */ 00079 reset_subsets(); 00080 m_is_filled=true; 00081 00082 SGVector<float64_t> unique_labels; 00083 00084 if (m_labels->get_label_type() == LT_MULTICLASS) 00085 { 00086 unique_labels=((CMulticlassLabels*) m_labels)->get_unique_labels(); 00087 } 00088 else if (m_labels->get_label_type() == LT_BINARY) 00089 { 00090 unique_labels=SGVector<float64_t>(2); 00091 unique_labels[0]=-1; 00092 unique_labels[1]=+1; 00093 } 00094 else 00095 { 00096 SG_ERROR("Multiclass or binary labels required for stratified crossvalidation\n") 00097 } 00098 00099 /* for every label, build set for indices */ 00100 CDynamicObjectArray label_indices; 00101 for (index_t i=0; i<unique_labels.vlen; ++i) 00102 label_indices.append_element(new CDynamicArray<index_t> ()); 00103 00104 /* fill set with indices, for each label type ... */ 00105 for (index_t i=0; i<unique_labels.vlen; ++i) 00106 { 00107 /* ... iterate over all labels and add indices with same label to set */ 00108 for (index_t j=0; j<m_labels->get_num_labels(); ++j) 00109 { 00110 if (((CDenseLabels*) m_labels)->get_label(j)==unique_labels.vector[i]) 00111 { 00112 CDynamicArray<index_t>* current=(CDynamicArray<index_t>*) 00113 label_indices.get_element(i); 00114 current->append_element(j); 00115 SG_UNREF(current); 00116 } 00117 } 00118 } 00119 00120 /* shuffle created label sets */ 00121 for (index_t i=0; i<label_indices.get_num_elements(); ++i) 00122 { 00123 CDynamicArray<index_t>* current=(CDynamicArray<index_t>*) 00124 label_indices.get_element(i); 00125 00126 // external random state important for threads 00127 current->shuffle(m_rng); 00128 00129 SG_UNREF(current); 00130 } 00131 00132 /* distribute labels to subsets for all label types */ 00133 index_t target_set=0; 00134 for (index_t i=0; i<unique_labels.vlen; ++i) 00135 { 00136 /* current index set for current label */ 00137 CDynamicArray<index_t>* current=(CDynamicArray<index_t>*) 00138 label_indices.get_element(i); 00139 00140 for (index_t j=0; j<current->get_num_elements(); ++j) 00141 { 00142 CDynamicArray<index_t>* next=(CDynamicArray<index_t>*) 00143 m_subset_indices->get_element(target_set++); 00144 next->append_element(current->get_element(j)); 00145 target_set%=m_subset_indices->get_num_elements(); 00146 SG_UNREF(next); 00147 } 00148 00149 SG_UNREF(current); 00150 } 00151 00152 /* finally shuffle to avoid that subsets with low indices have more 00153 * elements, which happens if the number of class labels is not equal to 00154 * the number of subsets (external random state important for threads) */ 00155 m_subset_indices->shuffle(m_rng); 00156 }