SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2006-2009 Soeren Sonnenburg 00008 * Copyright (C) 2006-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/classifier/svm/SVMLin.h> 00012 #include <shogun/labels/Labels.h> 00013 #include <shogun/mathematics/Math.h> 00014 #include <shogun/lib/external/ssl.h> 00015 #include <shogun/machine/LinearMachine.h> 00016 #include <shogun/features/DotFeatures.h> 00017 #include <shogun/labels/Labels.h> 00018 #include <shogun/labels/BinaryLabels.h> 00019 00020 using namespace shogun; 00021 00022 CSVMLin::CSVMLin() 00023 : CLinearMachine(), C1(1), C2(1), epsilon(1e-5), use_bias(true) 00024 { 00025 } 00026 00027 CSVMLin::CSVMLin( 00028 float64_t C, CDotFeatures* traindat, CLabels* trainlab) 00029 : CLinearMachine(), C1(C), C2(C), epsilon(1e-5), use_bias(true) 00030 { 00031 set_features(traindat); 00032 set_labels(trainlab); 00033 } 00034 00035 00036 CSVMLin::~CSVMLin() 00037 { 00038 } 00039 00040 bool CSVMLin::train_machine(CFeatures* data) 00041 { 00042 ASSERT(m_labels) 00043 00044 if (data) 00045 { 00046 if (!data->has_property(FP_DOT)) 00047 SG_ERROR("Specified features are not of type CDotFeatures\n") 00048 set_features((CDotFeatures*) data); 00049 } 00050 00051 ASSERT(features) 00052 00053 SGVector<float64_t> train_labels=((CBinaryLabels*) m_labels)->get_labels(); 00054 int32_t num_feat=features->get_dim_feature_space(); 00055 int32_t num_vec=features->get_num_vectors(); 00056 00057 ASSERT(num_vec==train_labels.vlen) 00058 00059 struct options Options; 00060 struct data Data; 00061 struct vector_double Weights; 00062 struct vector_double Outputs; 00063 00064 Data.l=num_vec; 00065 Data.m=num_vec; 00066 Data.u=0; 00067 Data.n=num_feat+1; 00068 Data.nz=num_feat+1; 00069 Data.Y=train_labels.vector; 00070 Data.features=features; 00071 Data.C = SG_MALLOC(float64_t, Data.l); 00072 00073 Options.algo = SVM; 00074 Options.lambda=1/(2*get_C1()); 00075 Options.lambda_u=1/(2*get_C1()); 00076 Options.S=10000; 00077 Options.R=0.5; 00078 Options.epsilon = get_epsilon(); 00079 Options.cgitermax=10000; 00080 Options.mfnitermax=50; 00081 Options.Cp = get_C2()/get_C1(); 00082 Options.Cn = 1; 00083 00084 if (use_bias) 00085 Options.bias=1.0; 00086 else 00087 Options.bias=0.0; 00088 00089 for (int32_t i=0;i<num_vec;i++) 00090 { 00091 if(train_labels.vector[i]>0) 00092 Data.C[i]=Options.Cp; 00093 else 00094 Data.C[i]=Options.Cn; 00095 } 00096 ssl_train(&Data, &Options, &Weights, &Outputs); 00097 ASSERT(Weights.vec && Weights.d==num_feat+1) 00098 00099 float64_t sgn=train_labels.vector[0]; 00100 for (int32_t i=0; i<num_feat+1; i++) 00101 Weights.vec[i]*=sgn; 00102 00103 set_w(SGVector<float64_t>(Weights.vec, num_feat)); 00104 set_bias(Weights.vec[num_feat]); 00105 00106 SG_FREE(Data.C); 00107 SG_FREE(Outputs.vec); 00108 return true; 00109 }