SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2007-2009 Vojtech Franc 00008 * Written (W) 2007-2009 Soeren Sonnenburg 00009 * Copyright (C) 2007-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _SVMOCAS_H___ 00013 #define _SVMOCAS_H___ 00014 00015 #include <shogun/lib/common.h> 00016 #include <shogun/machine/LinearMachine.h> 00017 #include <shogun/lib/external/libocas.h> 00018 #include <shogun/features/DotFeatures.h> 00019 #include <shogun/labels/Labels.h> 00020 00021 namespace shogun 00022 { 00023 #ifndef DOXYGEN_SHOULD_SKIP_THIS 00024 enum E_SVM_TYPE 00025 { 00026 SVM_OCAS = 0, 00027 SVM_BMRM = 1 00028 }; 00029 #endif 00030 00032 class CSVMOcas : public CLinearMachine 00033 { 00034 public: 00035 00037 MACHINE_PROBLEM_TYPE(PT_BINARY); 00038 00040 CSVMOcas(); 00041 00046 CSVMOcas(E_SVM_TYPE type); 00047 00054 CSVMOcas( 00055 float64_t C, CDotFeatures* traindat, 00056 CLabels* trainlab); 00057 virtual ~CSVMOcas(); 00058 00063 virtual EMachineType get_classifier_type() { return CT_SVMOCAS; } 00064 00071 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00072 00077 inline float64_t get_C1() { return C1; } 00078 00083 inline float64_t get_C2() { return C2; } 00084 00089 inline void set_epsilon(float64_t eps) { epsilon=eps; } 00090 00095 inline float64_t get_epsilon() { return epsilon; } 00096 00101 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00102 00107 inline bool get_bias_enabled() { return use_bias; } 00108 00113 inline void set_bufsize(int32_t sz) { bufsize=sz; } 00114 00119 inline int32_t get_bufsize() { return bufsize; } 00120 00125 virtual float64_t compute_primal_objective() const; 00126 00127 protected: 00136 static void compute_W( 00137 float64_t *sq_norm_W, float64_t *dp_WoldW, float64_t *alpha, 00138 uint32_t nSel, void* ptr); 00139 00146 static float64_t update_W(float64_t t, void* ptr ); 00147 00156 static int add_new_cut( 00157 float64_t *new_col_H, uint32_t *new_cut, uint32_t cut_length, 00158 uint32_t nSel, void* ptr ); 00159 00165 static int compute_output( float64_t *output, void* ptr ); 00166 00173 static int sort( float64_t* vals, float64_t* data, uint32_t size); 00174 00176 static inline void print(ocas_return_value_T value) 00177 { 00178 return; 00179 } 00180 00181 protected: 00190 virtual bool train_machine(CFeatures* data=NULL); 00191 00193 inline const char* get_name() const { return "SVMOcas"; } 00194 private: 00195 void init(); 00196 00197 protected: 00199 bool use_bias; 00201 int32_t bufsize; 00203 float64_t C1; 00205 float64_t C2; 00207 float64_t epsilon; 00209 E_SVM_TYPE method; 00210 00212 float64_t* old_w; 00214 float64_t old_bias; 00216 float64_t* tmp_a_buf; 00218 SGVector<float64_t> lab; 00219 00222 float64_t** cp_value; 00224 uint32_t** cp_index; 00226 uint32_t* cp_nz_dims; 00228 float64_t* cp_bias; 00229 00231 float64_t primal_objective; 00232 }; 00233 } 00234 #endif