SHOGUN
v3.2.0
|
00001 #ifndef _SGDQN_H___ 00002 #define _SGDQN_H___ 00003 00004 /* 00005 SVM with Quasi-Newton stochastic gradient 00006 Copyright (C) 2009- Antoine Bordes 00007 00008 This program is free software; you can redistribute it and/or 00009 modify it under the terms of the GNU Lesser General Public 00010 License as published by the Free Software Foundation; either 00011 version 2.1 of the License, or (at your option) any later version. 00012 00013 This program is distributed in the hope that it will be useful, 00014 but WITHOUT ANY WARRANTY; without even the implied warranty of 00015 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00016 GNU General Public License for more details. 00017 00018 You should have received a copy of the GNU Lesser General Public 00019 License along with this library; if not, write to the Free Software 00020 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 00021 00022 Shogun adjustments (w) 2011 Siddharth Kherada 00023 */ 00024 00025 #include <shogun/lib/common.h> 00026 #include <shogun/machine/LinearMachine.h> 00027 #include <shogun/features/DotFeatures.h> 00028 #include <shogun/labels/Labels.h> 00029 #include <shogun/loss/LossFunction.h> 00030 00031 namespace shogun 00032 { 00034 class CSGDQN : public CLinearMachine 00035 { 00036 public: 00037 00039 MACHINE_PROBLEM_TYPE(PT_BINARY); 00040 00042 CSGDQN(); 00043 00048 CSGDQN(float64_t C); 00049 00056 CSGDQN( 00057 float64_t C, CDotFeatures* traindat, 00058 CLabels* trainlab); 00059 00060 virtual ~CSGDQN(); 00061 00066 virtual EMachineType get_classifier_type() { return CT_SGDQN; } 00067 00076 virtual bool train(CFeatures* data=NULL); 00077 00084 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00085 00090 inline float64_t get_C1() { return C1; } 00091 00096 inline float64_t get_C2() { return C2; } 00097 00102 inline void set_epochs(int32_t e) { epochs=e; } 00103 00108 inline int32_t get_epochs() { return epochs; } 00109 00111 void compute_ratio(float64_t* W,float64_t* W_1,float64_t* B,float64_t* dst,int32_t dim,float64_t regularizer_lambda,float64_t loss); 00112 00114 void combine_and_clip(float64_t* Bc,float64_t* B,int32_t dim,float64_t c1,float64_t c2,float64_t v1,float64_t v2); 00115 00120 void set_loss_function(CLossFunction* loss_func); 00121 00126 inline CLossFunction* get_loss_function() { SG_REF(loss); return loss; } 00127 00129 virtual const char* get_name() const { return "SGDQN"; } 00130 00131 protected: 00133 void calibrate(); 00134 00135 private: 00136 void init(); 00137 00138 private: 00139 float64_t t; 00140 float64_t C1; 00141 float64_t C2; 00142 int32_t epochs; 00143 int32_t skip; 00144 int32_t count; 00145 00146 CLossFunction* loss; 00147 }; 00148 } 00149 #endif