SHOGUN
v3.2.0
|
00001 /* 00002 * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights 00003 * embodied in the content of this file are licensed under the BSD 00004 * (revised) open source license. 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * Written (W) 2011 Shashwat Lal Das 00012 * Adaptation of Vowpal Wabbit v5.1. 00013 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society. 00014 */ 00015 00016 #ifndef _VOWPALWABBIT_H__ 00017 #define _VOWPALWABBIT_H__ 00018 00019 #include <shogun/classifier/vw/vw_common.h> 00020 #include <shogun/classifier/vw/learners/VwAdaptiveLearner.h> 00021 #include <shogun/classifier/vw/learners/VwNonAdaptiveLearner.h> 00022 #include <shogun/classifier/vw/VwRegressor.h> 00023 00024 #include <shogun/features/streaming/StreamingVwFeatures.h> 00025 #include <shogun/machine/OnlineLinearMachine.h> 00026 00027 namespace shogun 00028 { 00038 class CVowpalWabbit: public COnlineLinearMachine 00039 { 00040 public: 00041 00043 MACHINE_PROBLEM_TYPE(PT_BINARY); 00044 00048 CVowpalWabbit(); 00049 00056 CVowpalWabbit(CStreamingVwFeatures* feat); 00057 00061 CVowpalWabbit(CVowpalWabbit *vw); 00062 00066 ~CVowpalWabbit(); 00067 00072 void reinitialize_weights(); 00073 00082 void set_no_training(bool dont_train) { no_training = dont_train; } 00083 00089 void set_adaptive(bool adaptive_learning); 00090 00097 void set_exact_adaptive_norm(bool exact_adaptive); 00098 00104 void set_num_passes(int32_t passes) 00105 { 00106 env->num_passes = passes; 00107 } 00108 00114 void load_regressor(char* file_name); 00115 00122 void set_regressor_out(char* file_name, bool is_text = true); 00123 00129 void set_prediction_out(char* file_name); 00130 00137 void add_quadratic_pair(char* pair); 00138 00144 virtual bool train_machine(CFeatures* feat = NULL); 00145 00153 virtual float32_t predict_and_finalize(VwExample* ex); 00154 00163 float32_t compute_exact_norm(VwExample* &ex, float32_t& sum_abs_x); 00164 00177 float32_t compute_exact_norm_quad(float32_t* weights, VwFeature& page_feature, v_array<VwFeature> &offer_features, 00178 vw_size_t mask, float32_t g, float32_t& sum_abs_x); 00179 00185 virtual CVwEnvironment* get_env() 00186 { 00187 SG_REF(env); 00188 return env; 00189 } 00190 00196 virtual const char* get_name() const { return "VowpalWabbit"; } 00197 00202 virtual void set_learner(); 00203 00207 CVwLearner* get_learner() { return learner; } 00208 00209 private: 00215 virtual void init(CStreamingVwFeatures* feat = NULL); 00216 00224 virtual float32_t inline_l1_predict(VwExample* &ex); 00225 00233 virtual float32_t inline_predict(VwExample* &ex); 00234 00242 virtual float32_t finalize_prediction(float32_t ret); 00243 00249 virtual void output_example(VwExample* &ex); 00250 00256 virtual void print_update(VwExample* &ex); 00257 00266 virtual void output_prediction(int32_t f, float32_t res, float32_t weight, v_array<char> tag); 00267 00273 void set_verbose(bool verbose); 00274 00275 protected: 00277 CStreamingVwFeatures* features; 00278 00280 CVwEnvironment* env; 00281 00283 CVwLearner* learner; 00284 00286 CVwRegressor* reg; 00287 00288 private: 00290 bool quiet; 00291 00293 bool no_training; 00294 00296 float32_t dump_interval; 00298 float32_t sum_loss_since_last_dump; 00300 float64_t old_weighted_examples; 00301 00303 char* reg_name; 00305 bool reg_dump_text; 00306 00308 bool save_predictions; 00310 int32_t prediction_fd; 00311 }; 00312 00313 } 00314 #endif // _VOWPALWABBIT_H__