SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Alexander Binder 00008 * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 * 00010 * Update to patch 0.10.0 - thanks to Eric aka Yoo (thereisnoknife@gmail.com) 00011 * 00012 */ 00013 00014 #ifndef MKLMulticlassGRADIENT_H_ 00015 #define MKLMulticlassGRADIENT_H_ 00016 00017 #include <vector> 00018 #include <cmath> 00019 #include <cassert> 00020 #include <shogun/base/SGObject.h> 00021 #include <shogun/classifier/mkl/MKLMulticlassOptimizationBase.h> 00022 00023 00024 namespace shogun 00025 { 00031 class MKLMulticlassGradient: public MKLMulticlassOptimizationBase 00032 { 00033 public: 00037 MKLMulticlassGradient(); 00041 virtual ~MKLMulticlassGradient(); 00042 00046 MKLMulticlassGradient(MKLMulticlassGradient & gl); 00047 00051 MKLMulticlassGradient operator=(MKLMulticlassGradient & gl); 00052 00059 virtual void setup(const int32_t numkernels2); 00060 00069 virtual void addconstraint(const ::std::vector<float64_t> & normw2, 00070 const float64_t sumofpositivealphas); 00071 00077 virtual void computeweights(std::vector<float64_t> & weights2); 00078 00080 virtual const char* get_name() const { return "MKLMulticlassGradient"; } 00081 00085 virtual void set_mkl_norm(float64_t norm); 00086 00087 protected: 00094 void linesearch2(std::vector<float64_t> & finalbeta,const std::vector<float64_t> & oldweights); 00095 00102 void genbetas( ::std::vector<float64_t> & weights ,const ::std::vector<float64_t> & gammas); 00103 00111 void gengammagradient( ::std::vector<float64_t> & gammagradient ,const ::std::vector<float64_t> & gammas,const int32_t dim); 00112 00119 float64_t objectives(const ::std::vector<float64_t> & weights, const int32_t index); 00120 00127 void linesearch(std::vector<float64_t> & finalbeta,const std::vector<float64_t> & oldweights); 00128 00129 protected: 00131 int32_t numkernels; 00132 00133 00135 ::std::vector< ::std::vector<float64_t> > normsofsubkernels; 00137 ::std::vector< float64_t > sumsofalphas ; 00139 float64_t pnorm; 00140 }; 00141 } 00142 00143 #endif