SHOGUN
v3.2.0
|
00001 /* 00002 * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights 00003 * embodied in the content of this file are licensed under the BSD 00004 * (revised) open source license. 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * Written (W) 2011 Shashwat Lal Das 00012 * Adaptation of Vowpal Wabbit v5.1. 00013 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society. 00014 */ 00015 00016 #include <shogun/classifier/vw/VwParser.h> 00017 #include <shogun/classifier/vw/cache/VwNativeCacheWriter.h> 00018 00019 using namespace shogun; 00020 00021 CVwParser::CVwParser() 00022 : CSGObject() 00023 { 00024 env = new CVwEnvironment(); 00025 hasher = CHash::MurmurHashString; 00026 write_cache = false; 00027 cache_writer = NULL; 00028 } 00029 00030 CVwParser::CVwParser(CVwEnvironment* env_to_use) 00031 : CSGObject() 00032 { 00033 ASSERT(env_to_use) 00034 00035 env = env_to_use; 00036 hasher = CHash::MurmurHashString; 00037 write_cache = false; 00038 cache_writer = NULL; 00039 SG_REF(env); 00040 } 00041 00042 CVwParser::~CVwParser() 00043 { 00044 SG_UNREF(env); 00045 SG_UNREF(cache_writer); 00046 } 00047 00048 int32_t CVwParser::read_features(CIOBuffer* buf, VwExample*& ae) 00049 { 00050 char *line=NULL; 00051 int32_t num_chars = buf->read_line(line); 00052 if (num_chars == 0) 00053 return num_chars; 00054 00055 /* Mark begin and end of example in the buffer */ 00056 substring example_string = {line, line + num_chars}; 00057 00058 /* Channels containing separate namespaces/label information*/ 00059 channels.erase(); 00060 00061 /* Split at '|' character */ 00062 tokenize('|', example_string, channels); 00063 00064 /* If first char is not '|', then the first channel contains label data */ 00065 substring* feature_start = &channels[1]; 00066 00067 if (*line == '|') 00068 feature_start = &channels[0]; /* Unlabelled data */ 00069 else 00070 { 00071 /* First channel has label info */ 00072 substring label_space = channels[0]; 00073 char* tab_location = safe_index(label_space.start, '\t', label_space.end); 00074 if (tab_location != label_space.end) 00075 label_space.start = tab_location+1; 00076 00077 /* Split the label space on spaces */ 00078 tokenize(' ',label_space,words); 00079 if (words.index() > 0 && words.last().end == label_space.end) //The last field is a tag, so record and strip it off 00080 { 00081 substring tag = words.pop(); 00082 ae->tag.push_many(tag.start, tag.end - tag.start); 00083 } 00084 00085 ae->ld->label_from_substring(words); 00086 set_minmax(ae->ld->label); 00087 } 00088 00089 vw_size_t mask = env->mask; 00090 00091 /* Now parse the individual channels, i.e., namespaces */ 00092 for (substring* i = feature_start; i != channels.end; i++) 00093 { 00094 substring channel = *i; 00095 00096 tokenize(' ',channel, words); 00097 if (words.begin == words.end) 00098 continue; 00099 00100 /* Set default scale value for channel */ 00101 float32_t channel_v = 1.; 00102 vw_size_t channel_hash; 00103 00104 /* Index by which to refer to the namespace */ 00105 vw_size_t index = 0; 00106 bool new_index = false; 00107 vw_size_t feature_offset = 0; 00108 00109 if (channel.start[0] != ' ') 00110 { 00111 /* Nonanonymous namespace specified */ 00112 feature_offset++; 00113 feature_value(words[0], name, channel_v); 00114 00115 if (name.index() > 0) 00116 { 00117 index = (unsigned char)(*name[0].start); 00118 if (ae->atomics[index].begin == ae->atomics[index].end) 00119 { 00120 ae->sum_feat_sq[index] = 0; 00121 new_index = true; 00122 } 00123 } 00124 channel_hash = hasher(name[0], hash_base); 00125 } 00126 else 00127 { 00128 /* Use default namespace with index below */ 00129 index = (unsigned char)' '; 00130 if (ae->atomics[index].begin == ae->atomics[index].end) 00131 { 00132 ae->sum_feat_sq[index] = 0; 00133 new_index = true; 00134 } 00135 channel_hash = 0; 00136 } 00137 00138 for (substring* j = words.begin+feature_offset; j != words.end; j++) 00139 { 00140 /* Get individual features and multiply by scale value */ 00141 float32_t v = 0.0; 00142 feature_value(*j, name, v); 00143 v *= channel_v; 00144 00145 /* Hash feature */ 00146 vw_size_t word_hash = (hasher(name[0], channel_hash)) & mask; 00147 VwFeature f = {v,word_hash}; 00148 ae->sum_feat_sq[index] += v*v; 00149 ae->atomics[index].push(f); 00150 } 00151 00152 /* Add index to list of indices if required */ 00153 if (new_index && ae->atomics[index].begin != ae->atomics[index].end) 00154 ae->indices.push(index); 00155 00156 } 00157 00158 if (write_cache) 00159 cache_writer->cache_example(ae); 00160 00161 return num_chars; 00162 } 00163 00164 int32_t CVwParser::read_svmlight_features(CIOBuffer* buf, VwExample*& ae) 00165 { 00166 char *line=NULL; 00167 int32_t num_chars = buf->read_line(line); 00168 if (num_chars == 0) 00169 return num_chars; 00170 00171 /* Mark begin and end of example in the buffer */ 00172 substring example_string = {line, line + num_chars}; 00173 00174 vw_size_t mask = env->mask; 00175 tokenize(' ', example_string, words); 00176 00177 ae->ld->label = SGIO::float_of_substring(words[0]); 00178 ae->ld->weight = 1.; 00179 ae->ld->initial = 0.; 00180 set_minmax(ae->ld->label); 00181 00182 substring* feature_start = &words[1]; 00183 00184 vw_size_t index = (unsigned char)' '; // Any default namespace is ok 00185 vw_size_t channel_hash = 0; 00186 ae->sum_feat_sq[index] = 0; 00187 ae->indices.push(index); 00188 /* Now parse the individual features */ 00189 for (substring* i = feature_start; i != words.end; i++) 00190 { 00191 float32_t v; 00192 feature_value(*i, name, v); 00193 00194 vw_size_t word_hash = (hasher(name[0], channel_hash)) & mask; 00195 VwFeature f = {v,word_hash}; 00196 ae->sum_feat_sq[index] += v*v; 00197 ae->atomics[index].push(f); 00198 } 00199 00200 if (write_cache) 00201 cache_writer->cache_example(ae); 00202 00203 return num_chars; 00204 } 00205 00206 int32_t CVwParser::read_dense_features(CIOBuffer* buf, VwExample*& ae) 00207 { 00208 char *line=NULL; 00209 int32_t num_chars = buf->read_line(line); 00210 if (num_chars == 0) 00211 return num_chars; 00212 00213 // Mark begin and end of example in the buffer 00214 substring example_string = {line, line + num_chars}; 00215 00216 vw_size_t mask = env->mask; 00217 tokenize(' ', example_string, words); 00218 00219 ae->ld->label = SGIO::float_of_substring(words[0]); 00220 ae->ld->weight = 1.; 00221 ae->ld->initial = 0.; 00222 set_minmax(ae->ld->label); 00223 00224 substring* feature_start = &words[1]; 00225 00226 vw_size_t index = (unsigned char)' '; 00227 00228 ae->sum_feat_sq[index] = 0; 00229 ae->indices.push(index); 00230 // Now parse individual features 00231 int32_t j=0; 00232 for (substring* i = feature_start; i != words.end; i++) 00233 { 00234 float32_t v = SGIO::float_of_substring(*i); 00235 vw_size_t word_hash = j & mask; 00236 VwFeature f = {v,word_hash}; 00237 ae->sum_feat_sq[index] += v*v; 00238 ae->atomics[index].push(f); 00239 j++; 00240 } 00241 00242 if (write_cache) 00243 cache_writer->cache_example(ae); 00244 00245 return num_chars; 00246 } 00247 00248 void CVwParser::init_cache(char * fname, EVwCacheType type) 00249 { 00250 char* file_name = fname; 00251 char default_cache_name[] = "vw_cache.dat.cache"; 00252 00253 if (!fname) 00254 file_name = default_cache_name; 00255 00256 write_cache = true; 00257 cache_type = type; 00258 00259 switch (type) 00260 { 00261 case C_NATIVE: 00262 cache_writer = new CVwNativeCacheWriter(file_name, env); 00263 return; 00264 case C_PROTOBUF: 00265 SG_ERROR("Protocol buffers cache support is not implemented yet.\n") 00266 } 00267 00268 SG_ERROR("Unexpected cache type specified!\n") 00269 } 00270 00271 void CVwParser::feature_value(substring &s, v_array<substring>& feat_name, float32_t &v) 00272 { 00273 // Get the value of the feature in the substring 00274 tokenize(':', s, feat_name); 00275 00276 switch (feat_name.index()) 00277 { 00278 // If feature value is not specified, assume 1.0 00279 case 0: 00280 case 1: 00281 v = 1.; 00282 break; 00283 case 2: 00284 v = SGIO::float_of_substring(feat_name[1]); 00285 if (CMath::is_nan(v)) 00286 SG_SERROR("error NaN value for feature %s! Terminating!\n", 00287 SGIO::c_string_of_substring(feat_name[0])); 00288 break; 00289 default: 00290 SG_SERROR("Examples with a weird name, i.e., '%s'\n", 00291 SGIO::c_string_of_substring(s)); 00292 } 00293 } 00294 00295 void CVwParser::tokenize(char delim, substring s, v_array<substring>& ret) 00296 { 00297 ret.erase(); 00298 char *last = s.start; 00299 for (; s.start != s.end; s.start++) 00300 { 00301 if (*s.start == delim) 00302 { 00303 if (s.start != last) 00304 { 00305 substring temp = {last,s.start}; 00306 ret.push(temp); 00307 } 00308 last = s.start+1; 00309 } 00310 } 00311 if (s.start != last) 00312 { 00313 substring final = {last, s.start}; 00314 ret.push(final); 00315 } 00316 }