SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
VwRegressor.cpp
Go to the documentation of this file.
00001 /*
00002  * Copyright (c) 2009 Yahoo! Inc.  All rights reserved.  The copyrights
00003  * embodied in the content of this file are licensed under the BSD
00004  * (revised) open source license.
00005  *
00006  * This program is free software; you can redistribute it and/or modify
00007  * it under the terms of the GNU General Public License as published by
00008  * the Free Software Foundation; either version 3 of the License, or
00009  * (at your option) any later version.
00010  *
00011  * Written (W) 2011 Shashwat Lal Das
00012  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
00013  */
00014 
00015 #include <shogun/classifier/vw/VwRegressor.h>
00016 #include <shogun/loss/SquaredLoss.h>
00017 #include <shogun/io/IOBuffer.h>
00018 
00019 using namespace shogun;
00020 
00021 CVwRegressor::CVwRegressor()
00022     : CSGObject()
00023 {
00024     weight_vectors = NULL;
00025     loss = new CSquaredLoss();
00026     init(NULL);
00027 }
00028 
00029 CVwRegressor::CVwRegressor(CVwEnvironment* env_to_use)
00030     : CSGObject()
00031 {
00032     weight_vectors = NULL;
00033     loss = new CSquaredLoss();
00034     init(env_to_use);
00035 }
00036 
00037 CVwRegressor::~CVwRegressor()
00038 {
00039     // TODO: the number of weight_vectors depends on num_threads
00040     // this should be reimplemented using SGVector (for reference counting)
00041     SG_FREE(weight_vectors);
00042     SG_UNREF(loss);
00043     SG_UNREF(env);
00044 }
00045 
00046 void CVwRegressor::init(CVwEnvironment* env_to_use)
00047 {
00048     if (!env_to_use)
00049         env_to_use = new CVwEnvironment();
00050 
00051     env = env_to_use;
00052     SG_REF(env);
00053 
00054     // For each feature, there should be 'stride' number of
00055     // elements in the weight vector
00056     vw_size_t length = ((vw_size_t) 1) << env->num_bits;
00057     env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1;
00058 
00059     // Only one learning thread for now
00060     vw_size_t num_threads = 1;
00061     weight_vectors = SG_MALLOC(float32_t*, num_threads);
00062 
00063     for (vw_size_t i = 0; i < num_threads; i++)
00064     {
00065         weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads);
00066 
00067         if (env->random_weights)
00068         {
00069             for (vw_size_t j = 0; j < length/num_threads; j++)
00070                 weight_vectors[i][j] = CMath::random(-0.5, 0.5);
00071         }
00072 
00073         if (env->initial_weight != 0.)
00074             for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride)
00075                 weight_vectors[i][j] = env->initial_weight;
00076 
00077         if (env->adaptive)
00078             for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride)
00079                 weight_vectors[i][j] = 1;
00080     }
00081 }
00082 
00083 // TODO: remove this, as we have serialization FW
00084 void CVwRegressor::dump_regressor(char* reg_name, bool as_text)
00085 {
00086     CIOBuffer io_temp;
00087     int32_t f = io_temp.open_file(reg_name,'w');
00088 
00089     if (f < 0)
00090         SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name)
00091 
00092     const char* vw_version = env->vw_version;
00093     vw_size_t v_length = env->v_length;
00094 
00095     if (!as_text)
00096     {
00097         // Write version info
00098         io_temp.write_file((char*)&v_length, sizeof(v_length));
00099         io_temp.write_file(vw_version,v_length);
00100 
00101         // Write max and min labels
00102         io_temp.write_file((char*)&env->min_label, sizeof(env->min_label));
00103         io_temp.write_file((char*)&env->max_label, sizeof(env->max_label));
00104 
00105         // Write weight vector bits information
00106         io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits));
00107         io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits));
00108 
00109         // For paired namespaces forming quadratic features
00110         int32_t len = env->pairs.get_num_elements();
00111         io_temp.write_file((char *)&len, sizeof(len));
00112 
00113         for (int32_t k = 0; k < env->pairs.get_num_elements(); k++)
00114             io_temp.write_file(env->pairs.get_element(k), 2);
00115 
00116         // ngram and skips information
00117         io_temp.write_file((char*)&env->ngram, sizeof(env->ngram));
00118         io_temp.write_file((char*)&env->skips, sizeof(env->skips));
00119     }
00120     else
00121     {
00122         // Write as human readable form
00123         char buff[512];
00124         int32_t len;
00125 
00126         len = sprintf(buff, "Version %s\n", vw_version);
00127         io_temp.write_file(buff, len);
00128         len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label);
00129         io_temp.write_file(buff, len);
00130         len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits);
00131         io_temp.write_file(buff, len);
00132 
00133         if (env->pairs.get_num_elements() > 0)
00134         {
00135             len = sprintf(buff, "\n");
00136             io_temp.write_file(buff, len);
00137         }
00138 
00139         len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips);
00140         io_temp.write_file(buff, len);
00141     }
00142 
00143     uint32_t length = 1 << env->num_bits;
00144     vw_size_t num_threads = env->num_threads();
00145     vw_size_t stride = env->stride;
00146 
00147     // Write individual weights
00148     for(uint32_t i = 0; i < length; i++)
00149     {
00150         float32_t v;
00151         v = weight_vectors[i%num_threads][stride*(i/num_threads)];
00152         if (v != 0.)
00153         {
00154             if (!as_text)
00155             {
00156                 io_temp.write_file((char *)&i, sizeof (i));
00157                 io_temp.write_file((char *)&v, sizeof (v));
00158             }
00159             else
00160             {
00161                 char buff[512];
00162                 int32_t len = sprintf(buff, "%d:%f\n", i, v);
00163                 io_temp.write_file(buff, len);
00164             }
00165         }
00166     }
00167 
00168     io_temp.close_file();
00169 }
00170 
00171 // TODO: remove this, as we have serialization FW
00172 void CVwRegressor::load_regressor(char* file)
00173 {
00174     CIOBuffer source;
00175     int32_t fd = source.open_file(file, 'r');
00176 
00177     if (fd < 0)
00178         SG_SERROR("Unable to open file for loading regressor!\n")
00179 
00180     // Read version info
00181     vw_size_t v_length;
00182     source.read_file((char*)&v_length, sizeof(v_length));
00183     char* t = SG_MALLOC(char, v_length);
00184     source.read_file(t,v_length);
00185     if (strcmp(t,env->vw_version) != 0)
00186     {
00187         SG_FREE(t);
00188         SG_SERROR("Regressor source has an incompatible VW version!\n")
00189     }
00190     SG_FREE(t);
00191 
00192     // Read min and max label
00193     source.read_file((char*)&env->min_label, sizeof(env->min_label));
00194     source.read_file((char*)&env->max_label, sizeof(env->max_label));
00195 
00196     // Read num_bits, multiple sources are not supported
00197     vw_size_t local_num_bits;
00198     source.read_file((char *)&local_num_bits, sizeof(local_num_bits));
00199 
00200     if ((vw_size_t) env->num_bits != local_num_bits)
00201         SG_SERROR("Wrong number of bits in regressor source!\n")
00202 
00203     env->num_bits = local_num_bits;
00204 
00205     vw_size_t local_thread_bits;
00206     source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits));
00207 
00208     env->thread_bits = local_thread_bits;
00209 
00210     int32_t len;
00211     source.read_file((char *)&len, sizeof(len));
00212 
00213     // Read paired namespace information
00214     DynArray<char*> local_pairs;
00215     for (; len > 0; len--)
00216     {
00217         char pair[3];
00218         source.read_file(pair, sizeof(char)*2);
00219         pair[2]='\0';
00220         local_pairs.push_back(pair);
00221     }
00222 
00223     env->pairs = local_pairs;
00224 
00225     // Initialize the weight vector
00226     if (weight_vectors)
00227         SG_FREE(weight_vectors);
00228     init(env);
00229 
00230     vw_size_t local_ngram;
00231     source.read_file((char*)&local_ngram, sizeof(local_ngram));
00232     vw_size_t local_skips;
00233     source.read_file((char*)&local_skips, sizeof(local_skips));
00234 
00235     env->ngram = local_ngram;
00236     env->skips = local_skips;
00237 
00238     // Read individual weights
00239     vw_size_t stride = env->stride;
00240     while (true)
00241     {
00242         uint32_t hash;
00243         ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash));
00244         if (hash_bytes <= 0)
00245             break;
00246 
00247         float32_t w = 0.;
00248         ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t));
00249         if (weight_bytes <= 0)
00250             break;
00251 
00252         vw_size_t num_threads = env->num_threads();
00253 
00254         weight_vectors[hash % num_threads][(hash*stride)/num_threads]
00255             = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w;
00256     }
00257     source.close_file();
00258 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation