SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Gunnar Raetsch 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/lib/common.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/kernel/string/SalzbergWordStringKernel.h> 00014 #include <shogun/features/Features.h> 00015 #include <shogun/features/StringFeatures.h> 00016 #include <shogun/labels/Labels.h> 00017 #include <shogun/labels/BinaryLabels.h> 00018 #include <shogun/classifier/PluginEstimate.h> 00019 00020 using namespace shogun; 00021 00022 CSalzbergWordStringKernel::CSalzbergWordStringKernel() 00023 : CStringKernel<uint16_t>(0) 00024 { 00025 init(); 00026 } 00027 00028 CSalzbergWordStringKernel::CSalzbergWordStringKernel(int32_t size, CPluginEstimate* pie, CLabels* labels) 00029 : CStringKernel<uint16_t>(size) 00030 { 00031 init(); 00032 estimate=pie; 00033 00034 if (labels) 00035 set_prior_probs_from_labels(labels); 00036 } 00037 00038 CSalzbergWordStringKernel::CSalzbergWordStringKernel( 00039 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, 00040 CPluginEstimate* pie, CLabels* labels) 00041 : CStringKernel<uint16_t>(10),estimate(pie) 00042 { 00043 init(); 00044 estimate=pie; 00045 00046 if (labels) 00047 set_prior_probs_from_labels(labels); 00048 00049 init(l, r); 00050 } 00051 00052 CSalzbergWordStringKernel::~CSalzbergWordStringKernel() 00053 { 00054 cleanup(); 00055 } 00056 00057 bool CSalzbergWordStringKernel::init(CFeatures* p_l, CFeatures* p_r) 00058 { 00059 CStringKernel<uint16_t>::init(p_l,p_r); 00060 CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l; 00061 ASSERT(l) 00062 CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r; 00063 ASSERT(r) 00064 00065 int32_t i; 00066 initialized=false; 00067 00068 if (sqrtdiag_lhs!=sqrtdiag_rhs) 00069 SG_FREE(sqrtdiag_rhs); 00070 sqrtdiag_rhs=NULL; 00071 SG_FREE(sqrtdiag_lhs); 00072 sqrtdiag_lhs=NULL; 00073 if (ld_mean_lhs!=ld_mean_rhs) 00074 SG_FREE(ld_mean_rhs); 00075 ld_mean_rhs=NULL; 00076 SG_FREE(ld_mean_lhs); 00077 ld_mean_lhs=NULL; 00078 00079 sqrtdiag_lhs=SG_MALLOC(float64_t, l->get_num_vectors()); 00080 ld_mean_lhs=SG_MALLOC(float64_t, l->get_num_vectors()); 00081 00082 for (i=0; i<l->get_num_vectors(); i++) 00083 sqrtdiag_lhs[i]=1; 00084 00085 if (l==r) 00086 { 00087 sqrtdiag_rhs=sqrtdiag_lhs; 00088 ld_mean_rhs=ld_mean_lhs; 00089 } 00090 else 00091 { 00092 sqrtdiag_rhs=SG_MALLOC(float64_t, r->get_num_vectors()); 00093 for (i=0; i<r->get_num_vectors(); i++) 00094 sqrtdiag_rhs[i]=1; 00095 00096 ld_mean_rhs=SG_MALLOC(float64_t, r->get_num_vectors()); 00097 } 00098 00099 float64_t* l_ld_mean_lhs=ld_mean_lhs; 00100 float64_t* l_ld_mean_rhs=ld_mean_rhs; 00101 00102 //from our knowledge first normalize variance to 1 and then norm=1 does the job 00103 if (!initialized) 00104 { 00105 int32_t num_vectors=l->get_num_vectors(); 00106 num_symbols=(int32_t) l->get_num_symbols(); 00107 int32_t llen=l->get_vector_length(0); 00108 int32_t rlen=r->get_vector_length(0); 00109 num_params=(int32_t) llen*l->get_num_symbols(); 00110 int32_t num_params2=(int32_t) llen*l->get_num_symbols()+rlen*r->get_num_symbols(); 00111 if ((!estimate) || (!estimate->check_models())) 00112 { 00113 SG_ERROR("no estimate available\n") 00114 return false ; 00115 } ; 00116 if (num_params2!=estimate->get_num_params()) 00117 { 00118 SG_ERROR("number of parameters of estimate and feature representation do not match\n") 00119 return false ; 00120 } ; 00121 00122 SG_FREE(variance); 00123 SG_FREE(mean); 00124 mean=SG_MALLOC(float64_t, num_params); 00125 ASSERT(mean) 00126 variance=SG_MALLOC(float64_t, num_params); 00127 ASSERT(variance) 00128 00129 for (i=0; i<num_params; i++) 00130 { 00131 mean[i]=0; 00132 variance[i]=0; 00133 } 00134 00135 00136 // compute mean 00137 for (i=0; i<num_vectors; i++) 00138 { 00139 int32_t len; 00140 bool free_vec; 00141 uint16_t* vec=l->get_feature_vector(i, len, free_vec); 00142 00143 for (int32_t j=0; j<len; j++) 00144 { 00145 int32_t idx=compute_index(j, vec[j]); 00146 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(vec[j], j) ; 00147 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(vec[j], j) ; 00148 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ; 00149 00150 mean[idx] += value/num_vectors ; 00151 } 00152 l->free_feature_vector(vec, i, free_vec); 00153 } 00154 00155 // compute variance 00156 for (i=0; i<num_vectors; i++) 00157 { 00158 int32_t len; 00159 bool free_vec; 00160 uint16_t* vec=l->get_feature_vector(i, len, free_vec); 00161 00162 for (int32_t j=0; j<len; j++) 00163 { 00164 for (int32_t k=0; k<4; k++) 00165 { 00166 int32_t idx=compute_index(j, k); 00167 if (k!=vec[j]) 00168 variance[idx]+=mean[idx]*mean[idx]/num_vectors; 00169 else 00170 { 00171 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(vec[j], j) ; 00172 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(vec[j], j) ; 00173 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ; 00174 00175 variance[idx] += CMath::sq(value-mean[idx])/num_vectors; 00176 } 00177 } 00178 } 00179 l->free_feature_vector(vec, i, free_vec); 00180 } 00181 00182 00183 // compute sum_i m_i^2/s_i^2 00184 sum_m2_s2=0 ; 00185 for (i=0; i<num_params; i++) 00186 { 00187 if (variance[i]<1e-14) // then it is likely to be numerical inaccuracy 00188 variance[i]=1 ; 00189 00190 //fprintf(stderr, "%i: mean=%1.2e std=%1.2e\n", i, mean[i], std[i]) ; 00191 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ; 00192 } ; 00193 } 00194 00195 // compute sum of 00196 //result -= feature*mean[a_idx]/variance[a_idx] ; 00197 00198 for (i=0; i<l->get_num_vectors(); i++) 00199 { 00200 int32_t alen ; 00201 bool free_avec; 00202 uint16_t* avec=l->get_feature_vector(i, alen, free_avec); 00203 float64_t result=0 ; 00204 for (int32_t j=0; j<alen; j++) 00205 { 00206 int32_t a_idx = compute_index(j, avec[j]) ; 00207 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(avec[j], j) ; 00208 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(avec[j], j) ; 00209 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ; 00210 00211 if (variance[a_idx]!=0) 00212 result-=value*mean[a_idx]/variance[a_idx]; 00213 } 00214 ld_mean_lhs[i]=result ; 00215 00216 l->free_feature_vector(avec, i, free_avec); 00217 } 00218 00219 if (ld_mean_lhs!=ld_mean_rhs) 00220 { 00221 // compute sum of 00222 //result -= feature*mean[b_idx]/variance[b_idx] ; 00223 for (i=0; i<r->get_num_vectors(); i++) 00224 { 00225 int32_t alen; 00226 bool free_avec; 00227 uint16_t* avec=r->get_feature_vector(i, alen, free_avec); 00228 float64_t result=0; 00229 00230 for (int32_t j=0; j<alen; j++) 00231 { 00232 int32_t a_idx = compute_index(j, avec[j]) ; 00233 float64_t theta_p=1/estimate->log_derivative_pos_obsolete( 00234 avec[j], j) ; 00235 float64_t theta_n=1/estimate->log_derivative_neg_obsolete( 00236 avec[j], j) ; 00237 float64_t value=(theta_p/(pos_prior*theta_p+neg_prior*theta_n)); 00238 00239 result -= value*mean[a_idx]/variance[a_idx] ; 00240 } 00241 00242 ld_mean_rhs[i]=result; 00243 r->free_feature_vector(avec, i, free_avec); 00244 } 00245 } 00246 00247 //warning hacky 00248 // 00249 this->lhs=l; 00250 this->rhs=l; 00251 ld_mean_lhs = l_ld_mean_lhs ; 00252 ld_mean_rhs = l_ld_mean_lhs ; 00253 00254 //compute normalize to 1 values 00255 for (i=0; i<lhs->get_num_vectors(); i++) 00256 { 00257 sqrtdiag_lhs[i]=sqrt(compute(i,i)); 00258 00259 //trap divide by zero exception 00260 if (sqrtdiag_lhs[i]==0) 00261 sqrtdiag_lhs[i]=1e-16; 00262 } 00263 00264 // if lhs is different from rhs (train/test data) 00265 // compute also the normalization for rhs 00266 if (sqrtdiag_lhs!=sqrtdiag_rhs) 00267 { 00268 this->lhs=r; 00269 this->rhs=r; 00270 ld_mean_lhs = l_ld_mean_rhs ; 00271 ld_mean_rhs = l_ld_mean_rhs ; 00272 00273 //compute normalize to 1 values 00274 for (i=0; i<rhs->get_num_vectors(); i++) 00275 { 00276 sqrtdiag_rhs[i]=sqrt(compute(i,i)); 00277 00278 //trap divide by zero exception 00279 if (sqrtdiag_rhs[i]==0) 00280 sqrtdiag_rhs[i]=1e-16; 00281 } 00282 } 00283 00284 this->lhs=l; 00285 this->rhs=r; 00286 ld_mean_lhs = l_ld_mean_lhs ; 00287 ld_mean_rhs = l_ld_mean_rhs ; 00288 00289 initialized = true ; 00290 return init_normalizer(); 00291 } 00292 00293 void CSalzbergWordStringKernel::cleanup() 00294 { 00295 SG_FREE(variance); 00296 variance=NULL; 00297 00298 SG_FREE(mean); 00299 mean=NULL; 00300 00301 if (sqrtdiag_lhs != sqrtdiag_rhs) 00302 SG_FREE(sqrtdiag_rhs); 00303 sqrtdiag_rhs=NULL; 00304 00305 SG_FREE(sqrtdiag_lhs); 00306 sqrtdiag_lhs=NULL; 00307 00308 if (ld_mean_lhs!=ld_mean_rhs) 00309 SG_FREE(ld_mean_rhs); 00310 ld_mean_rhs=NULL; 00311 00312 SG_FREE(ld_mean_lhs); 00313 ld_mean_lhs=NULL; 00314 00315 CKernel::cleanup(); 00316 } 00317 00318 float64_t CSalzbergWordStringKernel::compute(int32_t idx_a, int32_t idx_b) 00319 { 00320 int32_t alen, blen; 00321 bool free_avec, free_bvec; 00322 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec); 00323 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec); 00324 // can only deal with strings of same length 00325 ASSERT(alen==blen) 00326 00327 float64_t result = sum_m2_s2 ; // does not contain 0-th element 00328 00329 for (int32_t i=0; i<alen; i++) 00330 { 00331 if (avec[i]==bvec[i]) 00332 { 00333 int32_t a_idx = compute_index(i, avec[i]) ; 00334 00335 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(avec[i], i) ; 00336 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(avec[i], i) ; 00337 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ; 00338 00339 result += value*value/variance[a_idx] ; 00340 } 00341 } 00342 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ; 00343 00344 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec); 00345 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec); 00346 00347 if (initialized) 00348 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ; 00349 00350 return result; 00351 } 00352 00353 void CSalzbergWordStringKernel::set_prior_probs_from_labels(CLabels* labels) 00354 { 00355 ASSERT(labels) 00356 ASSERT(labels->get_label_type() == LT_BINARY) 00357 labels->ensure_valid(); 00358 00359 int32_t num_pos=0, num_neg=0; 00360 for (int32_t i=0; i<labels->get_num_labels(); i++) 00361 { 00362 if (((CBinaryLabels*) labels)->get_int_label(i)==1) 00363 num_pos++; 00364 if (((CBinaryLabels*) labels)->get_int_label(i)==-1) 00365 num_neg++; 00366 } 00367 00368 SG_INFO("priors: pos=%1.3f (%i) neg=%1.3f (%i)\n", 00369 (float64_t) num_pos/(num_pos+num_neg), num_pos, 00370 (float64_t) num_neg/(num_pos+num_neg), num_neg); 00371 00372 set_prior_probs( 00373 (float64_t)num_pos/(num_pos+num_neg), 00374 (float64_t)num_neg/(num_pos+num_neg)); 00375 } 00376 00377 void CSalzbergWordStringKernel::init() 00378 { 00379 estimate=NULL; 00380 mean=NULL; 00381 variance=NULL; 00382 00383 sqrtdiag_lhs=NULL; 00384 sqrtdiag_rhs=NULL; 00385 00386 ld_mean_lhs=NULL; 00387 ld_mean_rhs=NULL; 00388 00389 num_params=0; 00390 num_symbols=0; 00391 sum_m2_s2=0; 00392 pos_prior=0.5; 00393 00394 neg_prior=0.5; 00395 initialized=false; 00396 }