Marsyas
0.6.0-alpha
|
00001 00008 #ifndef MARSYAS_newWEKASOURCE_H 00009 #define MARSYAS_newWEKASOURCE_H 00010 00011 #include <marsyas/system/MarSystem.h> 00012 #include <marsyas/WekaData.h> 00013 #include <list> 00014 #include <vector> 00015 #include <iostream> 00016 #include <cstdlib> 00017 #include <cstring> 00018 //using namespace std; 00019 00020 namespace Marsyas 00021 { 00022 class WekaFoldData : public WekaData 00023 { 00024 public: 00025 ~WekaFoldData() {} 00026 00027 typedef enum 00028 { 00029 None, 00030 Training, 00031 Predict 00032 } nextMode; 00033 00034 private: 00035 mrs_natural foldCount_; 00036 00037 mrs_real rstep_; 00038 // mrs_natural predictSum_; 00039 mrs_natural excludeSectionStart_; 00040 mrs_natural excludeSectionEnd_; 00041 00042 mrs_natural iteration_; 00043 mrs_natural currentIndex_; 00044 // mrs_natural predictIndex_; 00045 00046 public: 00047 void SetupkFoldSections(const WekaData& data, mrs_natural foldCount, mrs_natural classAttr=-1) 00048 { 00049 MRSASSERT(foldCount>0); 00050 foldCount_ = foldCount; 00051 00052 //create the dataset with same number of columns as input data 00053 this->Create(data.getCols()); 00054 if(classAttr<0) 00055 { //if no class specified, copy all data into this instance 00056 this->assign(data.begin(), data.end()); 00057 } 00058 else 00059 { //otherwise only copy rows that match input class into this dataset 00060 for(mrs_natural ii=0; ii<(mrs_natural)data.size(); ++ii) 00061 if(data.GetClass(ii)==classAttr) 00062 this->Append(data[ii]); 00063 }//else 00064 00065 //setup fold sections 00066 this->Reset(); 00067 00068 }//SetupkFoldSections 00069 00070 //setup the fold sections for this dataset. 00071 void Reset() 00072 { 00073 00074 00075 this->Shuffle(); 00076 00077 rstep_ = (mrs_real)this->size() / (mrs_real)foldCount_; 00078 00079 if (foldCount_ > (mrs_natural)this->size()) 00080 { 00081 std::cout << "Folds exceed number of instances" << std::endl; 00082 std::cout << "foldCount_ = " << foldCount_ << std::endl; 00083 std::cout << "size = " << this->size() << std::endl;; 00084 exit(1); 00085 } 00086 00087 00088 iteration_ = 0; 00089 00090 excludeSectionStart_ = 0; 00091 excludeSectionEnd_ = ((mrs_natural)rstep_) - 1; 00092 currentIndex_ = excludeSectionEnd_ + 1; 00093 } 00094 00095 std::vector<mrs_real> *Next(nextMode& next) 00096 { 00097 00098 std::vector<mrs_real> *ret = this->at(currentIndex_); 00099 00100 if(currentIndex_ == excludeSectionEnd_) 00101 { 00102 iteration_++; 00103 if(iteration_ >= foldCount_) 00104 { 00105 next = None; 00106 return ret; 00107 }//if 00108 00109 excludeSectionStart_ = excludeSectionEnd_ + 1; 00110 if(iteration_ == (foldCount_ - 1)) 00111 { 00112 excludeSectionEnd_ = (mrs_natural)this->size() - 1; 00113 currentIndex_ = 0; 00114 } 00115 else 00116 { 00117 excludeSectionEnd_ = ((mrs_natural)((iteration_+1) * rstep_)) - 1; 00118 currentIndex_ = excludeSectionEnd_ + 1; 00119 } 00120 00121 00122 next = Training; 00123 return ret; 00124 }//if 00125 00126 currentIndex_++; 00127 00128 00129 if(currentIndex_ >= (mrs_natural)this->size()) 00130 currentIndex_ = 0; 00131 00132 if(currentIndex_ >= excludeSectionStart_ && currentIndex_ <= excludeSectionEnd_) 00133 next = Predict; 00134 else 00135 next = Training; 00136 00137 00138 return ret; 00139 00140 }//Next 00141 00142 }; 00143 00144 typedef enum 00145 { 00146 None, 00147 kFoldStratified, 00148 kFoldNonStratified, 00149 UseTestSet, 00150 PercentageSplit, 00151 OutputInstancePair 00152 } ValidationModeEnum; 00153 00154 class marsyas_EXPORT WekaSource : public MarSystem 00155 { 00156 public: 00157 WekaSource(std::string name); 00158 WekaSource(const WekaSource& a); 00159 ~WekaSource(); 00160 00161 MarSystem *clone()const; 00162 void myProcess(realvec& in, realvec& out); 00163 00164 private: 00165 void addControls(); 00166 void myUpdate(MarControlPtr sender); 00167 00168 //control values 00169 std::string filename_; //name of arff file to read 00170 std::string attributesToInclude_; //list of attributes to include in dataset 00171 00172 //these are the class names froun in the arff file header 00173 std::vector<std::string>classesFound_; 00174 // if there are no classes, we're doing regression 00175 MarControlPtr ctrl_regression_; 00176 00177 std::string relation_; 00178 00179 //these are the attribute names found in the arff file header 00180 std::vector<std::string>attributesFound_; 00181 00182 //Holds the actual attribute data read from the arff file 00183 WekaData data_; 00184 00185 //an array of bools that specify if an attribute from the arff file should be included 00186 //in the dataset. 00187 std::vector<bool>attributesIncluded_; 00188 00189 //the list of attributes that are to be included in the dataset 00190 std::vector<std::string>attributesIncludedList_; 00191 00192 //the validation mode enum to use 00193 ValidationModeEnum validationModeEnum_; 00194 00195 //Common validation method data members 00196 mrs_natural currentIndex_; 00197 00198 //kFold Stratified validation method data members 00199 mrs_natural foldCount_; 00200 WekaFoldData foldData_; 00201 WekaFoldData::nextMode foldCurrentMode_; 00202 WekaFoldData::nextMode foldNextMode_; 00203 00204 //kFold NonStratified validation method data members 00205 std::vector<WekaFoldData> foldClassData_; 00206 mrs_natural foldClassDataIndex_; 00207 00208 //UseTestSet validation method data members 00209 WekaData useTestSetData_; 00210 00211 //PercentageSplit validation method data members 00212 mrs_natural percentageIndex_; 00213 00214 void handleDefault(bool trainMode, realvec& out); 00215 void handleInstancePair(realvec& out); 00216 void handleFoldingNonStratifiedValidation(bool trainMode, realvec &out); 00217 void handleFoldingStratifiedValidation(bool trainMode, realvec &out); 00218 void handleUseTestSet(bool trainMode, realvec &out); 00219 void handlePercentageSplit(bool trainMode, realvec &out); 00220 00221 private: 00222 mrs_natural findClass(const char *className)const; 00223 mrs_natural findAttribute(const char *attribute)const; 00224 mrs_natural parseAttribute(const char *attribute)const; 00225 00226 void parseAttributesToInclude(const std::string& attributesToInclude); 00227 void loadFile(const std::string& filename, const std::string& attributesToExtract, WekaData& data); 00228 void parseHeader(std::ifstream& mis, const std::string& filename, const std::string& attributesToExtract); 00229 void parseData(std::ifstream& mis, const std::string& filename, WekaData& data); 00230 00231 };//class WekaSource 00232 }//namespace Marsyas 00233 00234 #endif