SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 00012 #include <stdio.h> 00013 00014 #include <shogun/lib/config.h> 00015 #include <shogun/io/SGIO.h> 00016 #include <shogun/structure/Plif.h> 00017 #include <shogun/lib/memory.h> 00018 00019 //#define PLIF_DEBUG 00020 00021 using namespace shogun; 00022 00023 CPlif::CPlif(int32_t l) 00024 : CPlifBase() 00025 { 00026 limits=SGVector<float64_t>(); 00027 penalties=SGVector<float64_t>(); 00028 cum_derivatives=SGVector<float64_t>(); 00029 id=-1; 00030 transform=T_LINEAR; 00031 name=NULL; 00032 max_value=0; 00033 min_value=0; 00034 cache=NULL; 00035 use_svm=0; 00036 use_cache=false; 00037 len=0; 00038 do_calc = true; 00039 if (l>0) 00040 set_plif_length(l); 00041 } 00042 00043 CPlif::~CPlif() 00044 { 00045 SG_FREE(name); 00046 SG_FREE(cache); 00047 } 00048 00049 bool CPlif::set_transform_type(const char *type_str) 00050 { 00051 invalidate_cache(); 00052 00053 if (strcmp(type_str, "linear")==0) 00054 transform = T_LINEAR ; 00055 else if (strcmp(type_str, "")==0) 00056 transform = T_LINEAR ; 00057 else if (strcmp(type_str, "log")==0) 00058 transform = T_LOG ; 00059 else if (strcmp(type_str, "log(+1)")==0) 00060 transform = T_LOG_PLUS1 ; 00061 else if (strcmp(type_str, "log(+3)")==0) 00062 transform = T_LOG_PLUS3 ; 00063 else if (strcmp(type_str, "(+3)")==0) 00064 transform = T_LINEAR_PLUS3 ; 00065 else 00066 { 00067 SG_ERROR("unknown transform type (%s)\n", type_str) 00068 return false ; 00069 } 00070 return true ; 00071 } 00072 00073 void CPlif::init_penalty_struct_cache() 00074 { 00075 if (!use_cache) 00076 return ; 00077 if (cache || use_svm) 00078 return ; 00079 if (max_value<=0) 00080 return ; 00081 00082 float64_t* local_cache=SG_MALLOC(float64_t, ((int32_t) max_value) + 2); 00083 00084 if (local_cache) 00085 { 00086 for (int32_t i=0; i<=max_value; i++) 00087 { 00088 if (i<min_value) 00089 local_cache[i] = -CMath::INFTY ; 00090 else 00091 local_cache[i] = lookup_penalty(i, NULL) ; 00092 } 00093 } 00094 this->cache=local_cache ; 00095 } 00096 00097 void CPlif::set_plif_name(char *p_name) 00098 { 00099 SG_FREE(name); 00100 name=get_strdup(p_name); 00101 } 00102 00103 char* CPlif::get_plif_name() const 00104 { 00105 if (name) 00106 return name; 00107 else 00108 { 00109 char buf[20]; 00110 sprintf(buf, "plif%i", id); 00111 return get_strdup(buf); 00112 } 00113 } 00114 00115 void CPlif::delete_penalty_struct(CPlif** PEN, int32_t P) 00116 { 00117 for (int32_t i=0; i<P; i++) 00118 delete PEN[i] ; 00119 SG_FREE(PEN); 00120 } 00121 00122 float64_t CPlif::lookup_penalty_svm( 00123 float64_t p_value, float64_t *d_values) const 00124 { 00125 ASSERT(use_svm>0) 00126 float64_t d_value=d_values[use_svm-1] ; 00127 #ifdef PLIF_DEBUG 00128 SG_PRINT("%s.lookup_penalty_svm(%f)\n", get_name(), d_value) 00129 #endif 00130 00131 if (!do_calc) 00132 return d_value; 00133 switch (transform) 00134 { 00135 case T_LINEAR: 00136 break ; 00137 case T_LOG: 00138 d_value = log(d_value) ; 00139 break ; 00140 case T_LOG_PLUS1: 00141 d_value = log(d_value+1) ; 00142 break ; 00143 case T_LOG_PLUS3: 00144 d_value = log(d_value+3) ; 00145 break ; 00146 case T_LINEAR_PLUS3: 00147 d_value = d_value+3 ; 00148 break ; 00149 default: 00150 SG_ERROR("unknown transform\n") 00151 break ; 00152 } 00153 00154 int32_t idx = 0 ; 00155 float64_t ret ; 00156 for (int32_t i=0; i<len; i++) 00157 if (limits[i]<=d_value) 00158 idx++ ; 00159 else 00160 break ; // assume it is monotonically increasing 00161 00162 #ifdef PLIF_DEBUG 00163 SG_PRINT(" -> idx = %i ", idx) 00164 #endif 00165 00166 if (idx==0) 00167 ret=penalties[0] ; 00168 else if (idx==len) 00169 ret=penalties[len-1] ; 00170 else 00171 { 00172 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]* 00173 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ; 00174 #ifdef PLIF_DEBUG 00175 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f)", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) 00176 #endif 00177 } 00178 #ifdef PLIF_DEBUG 00179 SG_PRINT(" -> ret=%1.3f\n", ret) 00180 #endif 00181 00182 return ret ; 00183 } 00184 00185 float64_t CPlif::lookup_penalty(int32_t p_value, float64_t* svm_values) const 00186 { 00187 if (use_svm) 00188 return lookup_penalty_svm(p_value, svm_values) ; 00189 00190 if ((p_value<min_value) || (p_value>max_value)) 00191 { 00192 //SG_PRINT("Feature:%s, %s.lookup_penalty(%i): return -inf min_value: %f, max_value: %f\n", name, get_name(), p_value, min_value, max_value) 00193 return -CMath::INFTY ; 00194 } 00195 if (!do_calc) 00196 return p_value; 00197 if (cache!=NULL && (p_value>=0) && (p_value<=max_value)) 00198 { 00199 float64_t ret=cache[p_value] ; 00200 return ret ; 00201 } 00202 return lookup_penalty((float64_t) p_value, svm_values) ; 00203 } 00204 00205 float64_t CPlif::lookup_penalty(float64_t p_value, float64_t* svm_values) const 00206 { 00207 if (use_svm) 00208 return lookup_penalty_svm(p_value, svm_values) ; 00209 00210 #ifdef PLIF_DEBUG 00211 SG_PRINT("%s.lookup_penalty(%f)\n", get_name(), p_value) 00212 #endif 00213 00214 00215 if ((p_value<min_value) || (p_value>max_value)) 00216 { 00217 //SG_PRINT("Feature:%s, %s.lookup_penalty(%f): return -inf min_value: %f, max_value: %f\n", name, get_name(), p_value, min_value, max_value) 00218 return -CMath::INFTY ; 00219 } 00220 00221 if (!do_calc) 00222 return p_value; 00223 00224 float64_t d_value = (float64_t) p_value ; 00225 switch (transform) 00226 { 00227 case T_LINEAR: 00228 break ; 00229 case T_LOG: 00230 d_value = log(d_value) ; 00231 break ; 00232 case T_LOG_PLUS1: 00233 d_value = log(d_value+1) ; 00234 break ; 00235 case T_LOG_PLUS3: 00236 d_value = log(d_value+3) ; 00237 break ; 00238 case T_LINEAR_PLUS3: 00239 d_value = d_value+3 ; 00240 break ; 00241 default: 00242 SG_ERROR("unknown transform\n") 00243 break ; 00244 } 00245 00246 #ifdef PLIF_DEBUG 00247 SG_PRINT(" -> value = %1.4f ", d_value) 00248 #endif 00249 00250 int32_t idx = 0 ; 00251 float64_t ret ; 00252 for (int32_t i=0; i<len; i++) 00253 if (limits[i]<=d_value) 00254 idx++ ; 00255 else 00256 break ; // assume it is monotonically increasing 00257 00258 #ifdef PLIF_DEBUG 00259 SG_PRINT(" -> idx = %i ", idx) 00260 #endif 00261 00262 if (idx==0) 00263 ret=penalties[0] ; 00264 else if (idx==len) 00265 ret=penalties[len-1] ; 00266 else 00267 { 00268 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]* 00269 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ; 00270 #ifdef PLIF_DEBUG 00271 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f) ", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) 00272 #endif 00273 } 00274 //if (p_value>=30 && p_value<150) 00275 //SG_PRINT("%s %i(%i) -> %1.2f\n", PEN->name, p_value, idx, ret) 00276 #ifdef PLIF_DEBUG 00277 SG_PRINT(" -> ret=%1.3f\n", ret) 00278 #endif 00279 00280 return ret ; 00281 } 00282 00283 void CPlif::penalty_clear_derivative() 00284 { 00285 for (int32_t i=0; i<len; i++) 00286 cum_derivatives[i]=0.0 ; 00287 } 00288 00289 void CPlif::penalty_add_derivative(float64_t p_value, float64_t* svm_values, float64_t factor) 00290 { 00291 if (use_svm) 00292 { 00293 penalty_add_derivative_svm(p_value, svm_values, factor) ; 00294 return ; 00295 } 00296 00297 if ((p_value<min_value) || (p_value>max_value)) 00298 { 00299 return ; 00300 } 00301 float64_t d_value = (float64_t) p_value ; 00302 switch (transform) 00303 { 00304 case T_LINEAR: 00305 break ; 00306 case T_LOG: 00307 d_value = log(d_value) ; 00308 break ; 00309 case T_LOG_PLUS1: 00310 d_value = log(d_value+1) ; 00311 break ; 00312 case T_LOG_PLUS3: 00313 d_value = log(d_value+3) ; 00314 break ; 00315 case T_LINEAR_PLUS3: 00316 d_value = d_value+3 ; 00317 break ; 00318 default: 00319 SG_ERROR("unknown transform\n") 00320 break ; 00321 } 00322 00323 int32_t idx = 0 ; 00324 for (int32_t i=0; i<len; i++) 00325 if (limits[i]<=d_value) 00326 idx++ ; 00327 else 00328 break ; // assume it is monotonically increasing 00329 00330 if (idx==0) 00331 cum_derivatives[0]+= factor ; 00332 else if (idx==len) 00333 cum_derivatives[len-1]+= factor ; 00334 else 00335 { 00336 cum_derivatives[idx] += factor * (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ; 00337 cum_derivatives[idx-1]+= factor*(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ; 00338 } 00339 } 00340 00341 void CPlif::penalty_add_derivative_svm(float64_t p_value, float64_t *d_values, float64_t factor) 00342 { 00343 ASSERT(use_svm>0) 00344 float64_t d_value=d_values[use_svm-1] ; 00345 00346 if (d_value<-1e+20) 00347 return; 00348 00349 switch (transform) 00350 { 00351 case T_LINEAR: 00352 break ; 00353 case T_LOG: 00354 d_value = log(d_value) ; 00355 break ; 00356 case T_LOG_PLUS1: 00357 d_value = log(d_value+1) ; 00358 break ; 00359 case T_LOG_PLUS3: 00360 d_value = log(d_value+3) ; 00361 break ; 00362 case T_LINEAR_PLUS3: 00363 d_value = d_value+3 ; 00364 break ; 00365 default: 00366 SG_ERROR("unknown transform\n") 00367 break ; 00368 } 00369 00370 int32_t idx = 0 ; 00371 for (int32_t i=0; i<len; i++) 00372 if (limits[i]<=d_value) 00373 idx++ ; 00374 else 00375 break ; // assume it is monotonically increasing 00376 00377 if (idx==0) 00378 cum_derivatives[0]+=factor ; 00379 else if (idx==len) 00380 cum_derivatives[len-1]+=factor ; 00381 else 00382 { 00383 cum_derivatives[idx] += factor*(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ; 00384 cum_derivatives[idx-1] += factor*(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ; 00385 } 00386 } 00387 00388 void CPlif::get_used_svms(int32_t* num_svms, int32_t* svm_ids) 00389 { 00390 if (use_svm) 00391 { 00392 svm_ids[(*num_svms)] = use_svm; 00393 (*num_svms)++; 00394 } 00395 SG_PRINT("->use_svm:%i plif_id:%i name:%s trans_type:%s ",use_svm, get_id(), get_name(), get_transform_type()) 00396 } 00397 00398 bool CPlif::get_do_calc() 00399 { 00400 return do_calc; 00401 } 00402 00403 void CPlif::set_do_calc(bool b) 00404 { 00405 do_calc = b;; 00406 }