SHOGUN
v3.2.0
|
00001 /* 00002 * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights 00003 * embodied in the content of this file are licensed under the BSD 00004 * (revised) open source license. 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * Written (W) 2011 Shashwat Lal Das 00012 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society. 00013 */ 00014 00015 #include <shogun/classifier/vw/VwRegressor.h> 00016 #include <shogun/loss/SquaredLoss.h> 00017 #include <shogun/io/IOBuffer.h> 00018 00019 using namespace shogun; 00020 00021 CVwRegressor::CVwRegressor() 00022 : CSGObject() 00023 { 00024 weight_vectors = NULL; 00025 loss = new CSquaredLoss(); 00026 init(NULL); 00027 } 00028 00029 CVwRegressor::CVwRegressor(CVwEnvironment* env_to_use) 00030 : CSGObject() 00031 { 00032 weight_vectors = NULL; 00033 loss = new CSquaredLoss(); 00034 init(env_to_use); 00035 } 00036 00037 CVwRegressor::~CVwRegressor() 00038 { 00039 // TODO: the number of weight_vectors depends on num_threads 00040 // this should be reimplemented using SGVector (for reference counting) 00041 SG_FREE(weight_vectors); 00042 SG_UNREF(loss); 00043 SG_UNREF(env); 00044 } 00045 00046 void CVwRegressor::init(CVwEnvironment* env_to_use) 00047 { 00048 if (!env_to_use) 00049 env_to_use = new CVwEnvironment(); 00050 00051 env = env_to_use; 00052 SG_REF(env); 00053 00054 // For each feature, there should be 'stride' number of 00055 // elements in the weight vector 00056 vw_size_t length = ((vw_size_t) 1) << env->num_bits; 00057 env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1; 00058 00059 // Only one learning thread for now 00060 vw_size_t num_threads = 1; 00061 weight_vectors = SG_MALLOC(float32_t*, num_threads); 00062 00063 for (vw_size_t i = 0; i < num_threads; i++) 00064 { 00065 weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads); 00066 00067 if (env->random_weights) 00068 { 00069 for (vw_size_t j = 0; j < length/num_threads; j++) 00070 weight_vectors[i][j] = CMath::random(-0.5, 0.5); 00071 } 00072 00073 if (env->initial_weight != 0.) 00074 for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride) 00075 weight_vectors[i][j] = env->initial_weight; 00076 00077 if (env->adaptive) 00078 for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride) 00079 weight_vectors[i][j] = 1; 00080 } 00081 } 00082 00083 // TODO: remove this, as we have serialization FW 00084 void CVwRegressor::dump_regressor(char* reg_name, bool as_text) 00085 { 00086 CIOBuffer io_temp; 00087 int32_t f = io_temp.open_file(reg_name,'w'); 00088 00089 if (f < 0) 00090 SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name) 00091 00092 const char* vw_version = env->vw_version; 00093 vw_size_t v_length = env->v_length; 00094 00095 if (!as_text) 00096 { 00097 // Write version info 00098 io_temp.write_file((char*)&v_length, sizeof(v_length)); 00099 io_temp.write_file(vw_version,v_length); 00100 00101 // Write max and min labels 00102 io_temp.write_file((char*)&env->min_label, sizeof(env->min_label)); 00103 io_temp.write_file((char*)&env->max_label, sizeof(env->max_label)); 00104 00105 // Write weight vector bits information 00106 io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits)); 00107 io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits)); 00108 00109 // For paired namespaces forming quadratic features 00110 int32_t len = env->pairs.get_num_elements(); 00111 io_temp.write_file((char *)&len, sizeof(len)); 00112 00113 for (int32_t k = 0; k < env->pairs.get_num_elements(); k++) 00114 io_temp.write_file(env->pairs.get_element(k), 2); 00115 00116 // ngram and skips information 00117 io_temp.write_file((char*)&env->ngram, sizeof(env->ngram)); 00118 io_temp.write_file((char*)&env->skips, sizeof(env->skips)); 00119 } 00120 else 00121 { 00122 // Write as human readable form 00123 char buff[512]; 00124 int32_t len; 00125 00126 len = sprintf(buff, "Version %s\n", vw_version); 00127 io_temp.write_file(buff, len); 00128 len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label); 00129 io_temp.write_file(buff, len); 00130 len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits); 00131 io_temp.write_file(buff, len); 00132 00133 if (env->pairs.get_num_elements() > 0) 00134 { 00135 len = sprintf(buff, "\n"); 00136 io_temp.write_file(buff, len); 00137 } 00138 00139 len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips); 00140 io_temp.write_file(buff, len); 00141 } 00142 00143 uint32_t length = 1 << env->num_bits; 00144 vw_size_t num_threads = env->num_threads(); 00145 vw_size_t stride = env->stride; 00146 00147 // Write individual weights 00148 for(uint32_t i = 0; i < length; i++) 00149 { 00150 float32_t v; 00151 v = weight_vectors[i%num_threads][stride*(i/num_threads)]; 00152 if (v != 0.) 00153 { 00154 if (!as_text) 00155 { 00156 io_temp.write_file((char *)&i, sizeof (i)); 00157 io_temp.write_file((char *)&v, sizeof (v)); 00158 } 00159 else 00160 { 00161 char buff[512]; 00162 int32_t len = sprintf(buff, "%d:%f\n", i, v); 00163 io_temp.write_file(buff, len); 00164 } 00165 } 00166 } 00167 00168 io_temp.close_file(); 00169 } 00170 00171 // TODO: remove this, as we have serialization FW 00172 void CVwRegressor::load_regressor(char* file) 00173 { 00174 CIOBuffer source; 00175 int32_t fd = source.open_file(file, 'r'); 00176 00177 if (fd < 0) 00178 SG_SERROR("Unable to open file for loading regressor!\n") 00179 00180 // Read version info 00181 vw_size_t v_length; 00182 source.read_file((char*)&v_length, sizeof(v_length)); 00183 char* t = SG_MALLOC(char, v_length); 00184 source.read_file(t,v_length); 00185 if (strcmp(t,env->vw_version) != 0) 00186 { 00187 SG_FREE(t); 00188 SG_SERROR("Regressor source has an incompatible VW version!\n") 00189 } 00190 SG_FREE(t); 00191 00192 // Read min and max label 00193 source.read_file((char*)&env->min_label, sizeof(env->min_label)); 00194 source.read_file((char*)&env->max_label, sizeof(env->max_label)); 00195 00196 // Read num_bits, multiple sources are not supported 00197 vw_size_t local_num_bits; 00198 source.read_file((char *)&local_num_bits, sizeof(local_num_bits)); 00199 00200 if ((vw_size_t) env->num_bits != local_num_bits) 00201 SG_SERROR("Wrong number of bits in regressor source!\n") 00202 00203 env->num_bits = local_num_bits; 00204 00205 vw_size_t local_thread_bits; 00206 source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits)); 00207 00208 env->thread_bits = local_thread_bits; 00209 00210 int32_t len; 00211 source.read_file((char *)&len, sizeof(len)); 00212 00213 // Read paired namespace information 00214 DynArray<char*> local_pairs; 00215 for (; len > 0; len--) 00216 { 00217 char pair[3]; 00218 source.read_file(pair, sizeof(char)*2); 00219 pair[2]='\0'; 00220 local_pairs.push_back(pair); 00221 } 00222 00223 env->pairs = local_pairs; 00224 00225 // Initialize the weight vector 00226 if (weight_vectors) 00227 SG_FREE(weight_vectors); 00228 init(env); 00229 00230 vw_size_t local_ngram; 00231 source.read_file((char*)&local_ngram, sizeof(local_ngram)); 00232 vw_size_t local_skips; 00233 source.read_file((char*)&local_skips, sizeof(local_skips)); 00234 00235 env->ngram = local_ngram; 00236 env->skips = local_skips; 00237 00238 // Read individual weights 00239 vw_size_t stride = env->stride; 00240 while (true) 00241 { 00242 uint32_t hash; 00243 ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash)); 00244 if (hash_bytes <= 0) 00245 break; 00246 00247 float32_t w = 0.; 00248 ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t)); 00249 if (weight_bytes <= 0) 00250 break; 00251 00252 vw_size_t num_threads = env->num_threads(); 00253 00254 weight_vectors[hash % num_threads][(hash*stride)/num_threads] 00255 = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w; 00256 } 00257 source.close_file(); 00258 }