SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
OnlineLinearMachine.h
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _ONLINELINEARCLASSIFIER_H__
00012 #define _ONLINELINEARCLASSIFIER_H__
00013 
00014 #include <shogun/lib/common.h>
00015 #include <shogun/labels/Labels.h>
00016 #include <shogun/labels/RegressionLabels.h>
00017 #include <shogun/features/streaming/StreamingDotFeatures.h>
00018 #include <shogun/machine/Machine.h>
00019 
00020 #include <stdio.h>
00021 
00022 namespace shogun
00023 {
00050 class COnlineLinearMachine : public CMachine
00051 {
00052     public:
00054         COnlineLinearMachine();
00055         virtual ~COnlineLinearMachine();
00056 
00062         virtual void get_w(float32_t*& dst_w, int32_t& dst_dims)
00063         {
00064             ASSERT(w && w_dim>0)
00065             dst_w=w;
00066             dst_dims=w_dim;
00067         }
00068 
00075         virtual void get_w(float64_t*& dst_w, int32_t& dst_dims)
00076         {
00077             ASSERT(w && w_dim>0)
00078             dst_w=SG_MALLOC(float64_t, w_dim);
00079             for (int32_t i=0; i<w_dim; i++)
00080                 dst_w[i]=w[i];
00081             dst_dims=w_dim;
00082         }
00083 
00088         virtual SGVector<float32_t> get_w()
00089         {
00090             float32_t * dst_w = SG_MALLOC(float32_t, w_dim);
00091             for (int32_t i=0; i<w_dim; i++)
00092                 dst_w[i]=w[i];
00093             return SGVector<float32_t>(dst_w, w_dim);
00094         }
00095 
00101         virtual void set_w(float32_t* src_w, int32_t src_w_dim)
00102         {
00103             SG_FREE(w);
00104             w=SG_MALLOC(float32_t, src_w_dim);
00105             memcpy(w, src_w, size_t(src_w_dim)*sizeof(float32_t));
00106             w_dim=src_w_dim;
00107         }
00108 
00115         virtual void set_w(float64_t* src_w, int32_t src_w_dim)
00116         {
00117             SG_FREE(w);
00118             w=SG_MALLOC(float32_t, src_w_dim);
00119             for (int32_t i=0; i<src_w_dim; i++)
00120                 w[i] = src_w[i];
00121             w_dim=src_w_dim;
00122         }
00123 
00128         virtual void set_bias(float32_t b)
00129         {
00130             bias=b;
00131         }
00132 
00137         virtual float32_t get_bias()
00138         {
00139             return bias;
00140         }
00141 
00146         virtual void set_features(CStreamingDotFeatures* feat)
00147         {
00148             SG_REF(feat);
00149             SG_UNREF(features);
00150             features=feat;
00151         }
00152 
00159         virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
00160 
00167         virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
00168 
00170         virtual float64_t apply_one(int32_t vec_idx)
00171         {
00172             SG_NOTIMPLEMENTED
00173             return CMath::INFTY;
00174         }
00175 
00184         virtual float32_t apply_one(float32_t* vec, int32_t len);
00185 
00191         virtual float32_t apply_to_current_example();
00192 
00197         virtual CStreamingDotFeatures* get_features() { SG_REF(features); return features; }
00198 
00204         virtual const char* get_name() const { return "OnlineLinearMachine"; }
00205 
00209         virtual void start_train() { }
00210 
00214         virtual void stop_train() { }
00215 
00225         virtual void train_example(CStreamingDotFeatures *feature, float64_t label) { SG_NOTIMPLEMENTED }
00226 
00227     protected:
00236         virtual bool train_machine(CFeatures* data=NULL);
00237 
00243         SGVector<float64_t> apply_get_outputs(CFeatures* data);
00244 
00246         virtual bool train_require_labels() const { return false; }
00247 
00248     protected:
00250         int32_t w_dim;
00252         float32_t* w;
00254         float32_t bias;
00256         CStreamingDotFeatures* features;
00257 };
00258 }
00259 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation