SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef __PLIF_H__ 00012 #define __PLIF_H__ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/lib/SGVector.h> 00016 #include <shogun/mathematics/Math.h> 00017 #include <shogun/structure/PlifBase.h> 00018 00019 namespace shogun 00020 { 00021 00023 enum ETransformType 00024 { 00026 T_LINEAR, 00028 T_LOG, 00030 T_LOG_PLUS1, 00032 T_LOG_PLUS3, 00034 T_LINEAR_PLUS3 00035 }; 00036 00038 class CPlif: public CPlifBase 00039 { 00040 public: 00045 CPlif(int32_t len=0); 00046 virtual ~CPlif(); 00047 00049 void init_penalty_struct_cache(); 00050 00057 float64_t lookup_penalty_svm( 00058 float64_t p_value, float64_t *d_values) const; 00059 00066 float64_t lookup_penalty( 00067 float64_t p_value, float64_t* svm_values) const; 00068 00075 float64_t lookup_penalty(int32_t p_value, float64_t* svm_values) const; 00076 00082 inline float64_t lookup(float64_t p_value) 00083 { 00084 ASSERT(use_svm == 0) 00085 return lookup_penalty(p_value, NULL); 00086 } 00087 00089 void penalty_clear_derivative(); 00090 00097 void penalty_add_derivative_svm( 00098 float64_t p_value, float64_t* svm_values, float64_t factor) ; 00099 00106 void penalty_add_derivative(float64_t p_value, float64_t* svm_values, float64_t factor); 00107 00113 const float64_t * get_cum_derivative(int32_t & p_len) const 00114 { 00115 p_len = len; 00116 return cum_derivatives.vector; 00117 } 00118 00124 bool set_transform_type(const char *type_str); 00125 00130 const char* get_transform_type() 00131 { 00132 if (transform== T_LINEAR) 00133 return "linear"; 00134 else if (transform== T_LOG) 00135 return "log"; 00136 else if (transform== T_LOG_PLUS1) 00137 return "log(+1)"; 00138 else if (transform== T_LOG_PLUS3) 00139 return "log(+3)"; 00140 else if (transform== T_LINEAR_PLUS3) 00141 return "(+3)"; 00142 else 00143 SG_ERROR("wrong type") 00144 return ""; 00145 } 00146 00147 00152 void set_id(int32_t p_id) 00153 { 00154 id=p_id; 00155 } 00156 00161 int32_t get_id() const 00162 { 00163 return id; 00164 } 00165 00170 int32_t get_max_id() const 00171 { 00172 return get_id(); 00173 } 00174 00179 void set_use_svm(int32_t p_use_svm) 00180 { 00181 invalidate_cache(); 00182 use_svm=p_use_svm; 00183 } 00184 00189 int32_t get_use_svm() const 00190 { 00191 return use_svm; 00192 } 00193 00198 virtual bool uses_svm_values() const 00199 { 00200 return (get_use_svm()!=0); 00201 } 00202 00207 void set_use_cache(int32_t p_use_cache) 00208 { 00209 invalidate_cache(); 00210 use_cache=p_use_cache; 00211 } 00212 00215 void invalidate_cache() 00216 { 00217 SG_FREE(cache); 00218 cache=NULL; 00219 } 00220 00225 int32_t get_use_cache() 00226 { 00227 return use_cache; 00228 } 00229 00236 void set_plif( 00237 int32_t p_len, float64_t *p_limits, float64_t* p_penalties) 00238 { 00239 ASSERT(len==p_len) 00240 00241 for (int32_t i=0; i<len; i++) 00242 { 00243 limits[i]=p_limits[i]; 00244 penalties[i]=p_penalties[i]; 00245 } 00246 00247 invalidate_cache(); 00248 penalty_clear_derivative(); 00249 } 00250 00255 void set_plif_limits(SGVector<float64_t> p_limits) 00256 { 00257 ASSERT(len==p_limits.vlen) 00258 00259 limits = p_limits; 00260 00261 invalidate_cache(); 00262 penalty_clear_derivative(); 00263 } 00264 00265 00270 void set_plif_penalty(SGVector<float64_t> p_penalties) 00271 { 00272 ASSERT(len==p_penalties.vlen) 00273 00274 penalties = p_penalties; 00275 00276 invalidate_cache(); 00277 penalty_clear_derivative(); 00278 } 00279 00284 void set_plif_length(int32_t p_len) 00285 { 00286 if (len!=p_len) 00287 { 00288 len=p_len; 00289 00290 SG_DEBUG("set_plif len=%i\n", p_len) 00291 limits = SGVector<float64_t>(len); 00292 penalties = SGVector<float64_t>(len); 00293 cum_derivatives = SGVector<float64_t>(len); 00294 } 00295 00296 for (int32_t i=0; i<len; i++) 00297 { 00298 limits[i]=0.0; 00299 penalties[i]=0.0; 00300 cum_derivatives[i]=0.0; 00301 } 00302 00303 invalidate_cache(); 00304 penalty_clear_derivative(); 00305 } 00306 00311 SGVector<float64_t> get_plif_limits() 00312 { 00313 return limits; 00314 } 00315 00320 SGVector<float64_t> get_plif_penalties() 00321 { 00322 return penalties; 00323 } 00324 00329 inline void set_max_value(float64_t p_max_value) 00330 { 00331 max_value=p_max_value; 00332 invalidate_cache(); 00333 } 00334 00339 virtual float64_t get_max_value() const 00340 { 00341 return max_value; 00342 } 00343 00348 inline void set_min_value(float64_t p_min_value) 00349 { 00350 min_value=p_min_value; 00351 invalidate_cache(); 00352 } 00353 00358 virtual float64_t get_min_value() const 00359 { 00360 return min_value; 00361 } 00362 00367 void set_plif_name(char *p_name); 00368 00373 char* get_plif_name() const; 00374 00379 bool get_do_calc(); 00380 00385 void set_do_calc(bool b); 00386 00390 void get_used_svms(int32_t* num_svms, int32_t* svm_ids); 00391 00396 inline int32_t get_plif_len() 00397 { 00398 return len; 00399 } 00400 00405 virtual void list_plif() const 00406 { 00407 SG_PRINT("CPlif(min_value=%1.2f, max_value=%1.2f, use_svm=%i)\n", min_value, max_value, use_svm) 00408 } 00409 00415 static void delete_penalty_struct(CPlif** PEN, int32_t P); 00416 00418 virtual const char* get_name() const { return "Plif"; } 00419 00420 protected: 00422 int32_t len; 00424 SGVector<float64_t> limits; 00426 SGVector<float64_t> penalties; 00428 SGVector<float64_t> cum_derivatives; 00430 float64_t max_value; 00432 float64_t min_value; 00434 float64_t *cache; 00436 enum ETransformType transform; 00438 int32_t id; 00440 char * name; 00442 int32_t use_svm; 00444 bool use_cache; 00448 bool do_calc; 00449 }; 00450 } 00451 #endif