SHOGUN
v3.2.0
|
00001 /* 00002 * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights 00003 * embodied in the content of this file are licensed under the BSD 00004 * (revised) open source license. 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * Written (W) 2011 Shashwat Lal Das 00012 * Adaptation of Vowpal Wabbit v5.1. 00013 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society. 00014 */ 00015 00016 #include <shogun/classifier/vw/cache/VwNativeCacheReader.h> 00017 00018 using namespace shogun; 00019 00020 CVwNativeCacheReader::CVwNativeCacheReader() 00021 : CVwCacheReader(), char_size(2) 00022 { 00023 init(); 00024 } 00025 00026 CVwNativeCacheReader::CVwNativeCacheReader(char * fname, CVwEnvironment* env_to_use) 00027 : CVwCacheReader(fname, env_to_use), char_size(2) 00028 { 00029 init(); 00030 buf.use_file(fd); 00031 check_cache_metadata(); 00032 } 00033 00034 CVwNativeCacheReader::CVwNativeCacheReader(int32_t f, CVwEnvironment* env_to_use) 00035 : CVwCacheReader(f, env_to_use), char_size(2) 00036 { 00037 init(); 00038 buf.use_file(fd); 00039 check_cache_metadata(); 00040 } 00041 00042 CVwNativeCacheReader::~CVwNativeCacheReader() 00043 { 00044 buf.close_file(); 00045 } 00046 00047 void CVwNativeCacheReader::set_file(int32_t f) 00048 { 00049 if (fd > 0) 00050 buf.close_file(); 00051 00052 fd = f; 00053 buf.use_file(fd); 00054 check_cache_metadata(); 00055 } 00056 00057 void CVwNativeCacheReader::init() 00058 { 00059 neg_1 = 1; 00060 general = 2; 00061 } 00062 00063 void CVwNativeCacheReader::check_cache_metadata() 00064 { 00065 const char* vw_version=env->vw_version; 00066 vw_size_t numbits = env->num_bits; 00067 00068 vw_size_t v_length; 00069 buf.read_file((char*)&v_length, sizeof(v_length)); 00070 if(v_length > 29) 00071 SG_SERROR("Cache version too long, cache file is probably invalid.\n") 00072 00073 char* t=SG_MALLOC(char, v_length); 00074 buf.read_file(t,v_length); 00075 if (strcmp(t,vw_version) != 0) 00076 { 00077 SG_FREE(t); 00078 SG_SERROR("Cache has possibly incompatible version!\n") 00079 } 00080 SG_FREE(t); 00081 00082 vw_size_t cache_numbits = 0; 00083 if (buf.read_file(&cache_numbits, sizeof(vw_size_t)) < ssize_t(sizeof(vw_size_t))) 00084 return; 00085 00086 if (cache_numbits != numbits) 00087 SG_SERROR("Bug encountered in caching! Bits used for weight in cache: %d.\n", cache_numbits) 00088 } 00089 00090 char* CVwNativeCacheReader::run_len_decode(char *p, vw_size_t& i) 00091 { 00092 // Read an int32_t 7 bits at a time. 00093 vw_size_t count = 0; 00094 while(*p & 128)\ 00095 i = i | ((*(p++) & 127) << 7*count++); 00096 i = i | (*(p++) << 7*count); 00097 return p; 00098 } 00099 00100 char* CVwNativeCacheReader::bufread_label(VwLabel* const ld, char* c) 00101 { 00102 ld->label = *(float32_t*)c; 00103 c += sizeof(ld->label); 00104 set_minmax(ld->label); 00105 00106 ld->weight = *(float32_t*)c; 00107 c += sizeof(ld->weight); 00108 ld->initial = *(float32_t*)c; 00109 c += sizeof(ld->initial); 00110 00111 return c; 00112 } 00113 00114 vw_size_t CVwNativeCacheReader::read_cached_label(VwLabel* const ld) 00115 { 00116 char *c; 00117 vw_size_t total = sizeof(ld->label)+sizeof(ld->weight)+sizeof(ld->initial); 00118 if (buf.buf_read(c, total) < total) 00119 return 0; 00120 c = bufread_label(ld,c); 00121 00122 return total; 00123 } 00124 00125 vw_size_t CVwNativeCacheReader::read_cached_tag(VwExample* const ae) 00126 { 00127 char* c; 00128 vw_size_t tag_size; 00129 if (buf.buf_read(c, sizeof(tag_size)) < sizeof(tag_size)) 00130 return 0; 00131 tag_size = *(vw_size_t*)c; 00132 c += sizeof(tag_size); 00133 00134 buf.set(c); 00135 if (buf.buf_read(c, tag_size) < tag_size) 00136 return 0; 00137 00138 ae->tag.erase(); 00139 ae->tag.push_many(c, tag_size); 00140 return tag_size+sizeof(tag_size); 00141 } 00142 00143 bool CVwNativeCacheReader::read_cached_example(VwExample* const ae) 00144 { 00145 vw_size_t mask = env->mask; 00146 vw_size_t total = read_cached_label(ae->ld); 00147 if (total == 0) 00148 return false; 00149 if (read_cached_tag(ae) == 0) 00150 return false; 00151 00152 char* c; 00153 unsigned char num_indices = 0; 00154 if (buf.buf_read(c, sizeof(num_indices)) < sizeof(num_indices)) 00155 return false; 00156 num_indices = *(unsigned char*)c; 00157 c += sizeof(num_indices); 00158 00159 buf.set(c); 00160 00161 for (; num_indices > 0; num_indices--) 00162 { 00163 vw_size_t temp; 00164 unsigned char index = 0; 00165 temp = buf.buf_read(c, sizeof(index) + sizeof(vw_size_t)); 00166 00167 if (temp < sizeof(index) + sizeof(vw_size_t)) 00168 SG_SERROR("Truncated example! %d < %d bytes expected.\n", 00169 temp, char_size + sizeof(vw_size_t)); 00170 00171 index = *(unsigned char*) c; 00172 c += sizeof(index); 00173 ae->indices.push((vw_size_t) index); 00174 00175 v_array<VwFeature>* ours = ae->atomics+index; 00176 float64_t* our_sum_feat_sq = ae->sum_feat_sq+index; 00177 vw_size_t storage = *(vw_size_t *)c; 00178 c += sizeof(vw_size_t); 00179 00180 buf.set(c); 00181 total += storage; 00182 if (buf.buf_read(c, storage) < storage) 00183 SG_SERROR("Truncated example! Wanted %d bytes!\n", storage) 00184 00185 char *end = c + storage; 00186 00187 vw_size_t last = 0; 00188 00189 for (; c!=end; ) 00190 { 00191 VwFeature f = {1., 0}; 00192 temp = f.weight_index; 00193 c = run_len_decode(c, temp); 00194 f.weight_index = temp; 00195 00196 if (f.weight_index & neg_1) 00197 f.x = -1.; 00198 else if (f.weight_index & general) 00199 { 00200 f.x = ((one_float*)c)->f; 00201 c += sizeof(float32_t); 00202 } 00203 00204 *our_sum_feat_sq += f.x*f.x; 00205 00206 vw_size_t diff = f.weight_index >> 2; 00207 int32_t s_diff = ZigZagDecode(diff); 00208 if (s_diff < 0) 00209 ae->sorted = false; 00210 00211 f.weight_index = last + s_diff; 00212 last = f.weight_index; 00213 f.weight_index = f.weight_index & mask; 00214 00215 ours->push(f); 00216 } 00217 buf.set(c); 00218 } 00219 00220 return true; 00221 }