SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ 00013 #define _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ 00014 00015 #include <shogun/lib/common.h> 00016 #include <shogun/kernel/string/StringKernel.h> 00017 #include <shogun/kernel/string/WeightedDegreeStringKernel.h> 00018 #include <shogun/lib/Trie.h> 00019 00020 namespace shogun 00021 { 00022 00023 class CSVM; 00024 00048 class CWeightedDegreePositionStringKernel: public CStringKernel<char> 00049 { 00050 public: 00052 CWeightedDegreePositionStringKernel(); 00053 00061 CWeightedDegreePositionStringKernel( 00062 int32_t size, int32_t degree, 00063 int32_t max_mismatch=0, int32_t mkl_stepsize=1); 00064 00074 CWeightedDegreePositionStringKernel( 00075 int32_t size, SGVector<float64_t> weights, int32_t degree, 00076 int32_t max_mismatch, SGVector<int32_t> shifts, 00077 int32_t mkl_stepsize=1); 00078 00085 CWeightedDegreePositionStringKernel( 00086 CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t degree); 00087 00088 virtual ~CWeightedDegreePositionStringKernel(); 00089 00096 virtual bool init(CFeatures* l, CFeatures* r); 00097 00099 virtual void cleanup(); 00100 00105 virtual EKernelType get_kernel_type() { return K_WEIGHTEDDEGREEPOS; } 00106 00111 virtual const char* get_name() const { return "WeightedDegreePositionStringKernel"; } 00112 00120 virtual bool init_optimization( 00121 int32_t p_count, int32_t *IDX, float64_t * alphas) 00122 { 00123 return init_optimization(p_count, IDX, alphas, -1); 00124 } 00125 00137 virtual bool init_optimization( 00138 int32_t count, int32_t *IDX, float64_t * alphas, int32_t tree_num, 00139 int32_t upto_tree=-1); 00140 00145 virtual bool delete_optimization(); 00146 00152 virtual float64_t compute_optimized(int32_t idx) 00153 { 00154 ASSERT(get_is_initialized()) 00155 ASSERT(alphabet) 00156 ASSERT(alphabet->get_alphabet()==DNA || alphabet->get_alphabet()==RNA) 00157 return compute_by_tree(idx); 00158 } 00159 00164 static void* compute_batch_helper(void* p); 00165 00176 virtual void compute_batch( 00177 int32_t num_vec, int32_t* vec_idx, float64_t* target, 00178 int32_t num_suppvec, int32_t* IDX, float64_t* alphas, 00179 float64_t factor=1.0); 00180 00184 virtual void clear_normal() 00185 { 00186 if ((opt_type==FASTBUTMEMHUNGRY) && (tries.get_use_compact_terminal_nodes())) 00187 { 00188 tries.set_use_compact_terminal_nodes(false) ; 00189 SG_DEBUG("disabling compact trie nodes with FASTBUTMEMHUNGRY\n") 00190 } 00191 00192 if (get_is_initialized()) 00193 { 00194 if (opt_type==SLOWBUTMEMEFFICIENT) 00195 tries.delete_trees(true); 00196 else if (opt_type==FASTBUTMEMHUNGRY) 00197 tries.delete_trees(false); // still buggy 00198 else 00199 SG_ERROR("unknown optimization type\n") 00200 00201 set_is_initialized(false); 00202 } 00203 } 00204 00210 virtual void add_to_normal(int32_t idx, float64_t weight) 00211 { 00212 add_example_to_tree(idx, weight); 00213 set_is_initialized(true); 00214 } 00215 00220 virtual int32_t get_num_subkernels() 00221 { 00222 if (position_weights!=NULL) 00223 return (int32_t) ceil(1.0*seq_length/mkl_stepsize) ; 00224 if (length==0) 00225 return (int32_t) ceil(1.0*get_degree()/mkl_stepsize); 00226 return (int32_t) ceil(1.0*get_degree()*length/mkl_stepsize) ; 00227 } 00228 00234 inline void compute_by_subkernel( 00235 int32_t idx, float64_t * subkernel_contrib) 00236 { 00237 if (get_is_initialized()) 00238 { 00239 compute_by_tree(idx, subkernel_contrib); 00240 return ; 00241 } 00242 00243 SG_ERROR("CWeightedDegreePositionStringKernel optimization not initialized\n") 00244 } 00245 00251 inline const float64_t* get_subkernel_weights(int32_t& num_weights) 00252 { 00253 num_weights = get_num_subkernels() ; 00254 00255 SG_FREE(weights_buffer); 00256 weights_buffer = SG_MALLOC(float64_t, num_weights); 00257 00258 if (position_weights!=NULL) 00259 for (int32_t i=0; i<num_weights; i++) 00260 weights_buffer[i] = position_weights[i*mkl_stepsize] ; 00261 else 00262 for (int32_t i=0; i<num_weights; i++) 00263 weights_buffer[i] = weights[i*mkl_stepsize] ; 00264 00265 return weights_buffer ; 00266 } 00267 00272 virtual void set_subkernel_weights(SGVector<float64_t> w) 00273 { 00274 float64_t* weights2=w.vector; 00275 int32_t num_weights2=w.vlen; 00276 00277 int32_t num_weights = get_num_subkernels() ; 00278 if (num_weights!=num_weights2) 00279 SG_ERROR("number of weights do not match\n") 00280 00281 if (position_weights!=NULL) 00282 for (int32_t i=0; i<num_weights; i++) 00283 for (int32_t j=0; j<mkl_stepsize; j++) 00284 { 00285 if (i*mkl_stepsize+j<seq_length) 00286 position_weights[i*mkl_stepsize+j] = weights2[i] ; 00287 } 00288 else if (length==0) 00289 { 00290 for (int32_t i=0; i<num_weights; i++) 00291 for (int32_t j=0; j<mkl_stepsize; j++) 00292 if (i*mkl_stepsize+j<get_degree()) 00293 weights[i*mkl_stepsize+j] = weights2[i] ; 00294 } 00295 else 00296 { 00297 for (int32_t i=0; i<num_weights; i++) 00298 for (int32_t j=0; j<mkl_stepsize; j++) 00299 if (i*mkl_stepsize+j<get_degree()*length) 00300 weights[i*mkl_stepsize+j] = weights2[i] ; 00301 } 00302 } 00303 00304 // other kernel tree operations 00310 float64_t* compute_abs_weights(int32_t & len); 00311 00316 bool is_tree_initialized() { return tree_initialized; } 00317 00322 inline int32_t get_max_mismatch() { return max_mismatch; } 00323 00328 inline int32_t get_degree() { return degree; } 00329 00335 inline float64_t *get_degree_weights(int32_t& d, int32_t& len) 00336 { 00337 d=degree; 00338 len=length; 00339 return weights; 00340 } 00341 00347 inline float64_t *get_weights(int32_t& num_weights) 00348 { 00349 if (position_weights!=NULL) 00350 { 00351 num_weights = seq_length ; 00352 return position_weights ; 00353 } 00354 if (length==0) 00355 num_weights = degree ; 00356 else 00357 num_weights = degree*length ; 00358 return weights; 00359 } 00360 00366 inline float64_t *get_position_weights(int32_t& len) 00367 { 00368 len=seq_length; 00369 return position_weights; 00370 } 00371 00376 void set_shifts(SGVector<int32_t> shifts); 00377 00382 bool set_weights(SGMatrix<float64_t> new_weights); 00383 00388 virtual bool set_wd_weights(); 00389 00395 virtual void set_position_weights(SGVector<float64_t> pws); 00396 00404 bool set_position_weights_lhs(float64_t* pws, int32_t len, int32_t num); 00405 00413 bool set_position_weights_rhs(float64_t* pws, int32_t len, int32_t num); 00414 00419 bool init_block_weights(); 00420 00425 bool init_block_weights_from_wd(); 00426 00431 bool init_block_weights_from_wd_external(); 00432 00437 bool init_block_weights_const(); 00438 00443 bool init_block_weights_linear(); 00444 00449 bool init_block_weights_sqpoly(); 00450 00455 bool init_block_weights_cubicpoly(); 00456 00461 bool init_block_weights_exp(); 00462 00467 bool init_block_weights_log(); 00468 00473 bool delete_position_weights() 00474 { 00475 SG_FREE(position_weights); 00476 position_weights=NULL; 00477 return true; 00478 } 00479 00484 bool delete_position_weights_lhs() 00485 { 00486 SG_FREE(position_weights_lhs); 00487 position_weights_lhs=NULL; 00488 return true; 00489 } 00490 00495 bool delete_position_weights_rhs() 00496 { 00497 SG_FREE(position_weights_rhs); 00498 position_weights_rhs=NULL; 00499 return true; 00500 } 00501 00507 virtual float64_t compute_by_tree(int32_t idx); 00508 00514 virtual void compute_by_tree(int32_t idx, float64_t* LevelContrib); 00515 00528 float64_t* compute_scoring( 00529 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00530 float64_t* target, int32_t num_suppvec, int32_t* IDX, 00531 float64_t* weights); 00532 00541 char* compute_consensus( 00542 int32_t &num_feat, int32_t num_suppvec, int32_t* IDX, 00543 float64_t* alphas); 00544 00556 float64_t* extract_w( 00557 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00558 float64_t* w_result, int32_t num_suppvec, int32_t* IDX, 00559 float64_t* alphas); 00560 00573 float64_t* compute_POIM( 00574 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00575 float64_t* poim_result, int32_t num_suppvec, int32_t* IDX, 00576 float64_t* alphas, float64_t* distrib); 00577 00582 void prepare_POIM2(SGMatrix<float64_t> distrib); 00583 00590 void compute_POIM2(int32_t max_degree, CSVM* svm); 00591 00596 SGVector<float64_t> get_POIM2(); 00597 00599 void cleanup_POIM2(); 00600 00601 protected: 00603 void create_empty_tries(); 00604 00610 virtual void add_example_to_tree( 00611 int32_t idx, float64_t weight); 00612 00619 void add_example_to_single_tree( 00620 int32_t idx, float64_t weight, int32_t tree_num); 00621 00630 virtual float64_t compute(int32_t idx_a, int32_t idx_b); 00631 00640 float64_t compute_with_mismatch( 00641 char* avec, int32_t alen, char* bvec, int32_t blen); 00642 00651 float64_t compute_without_mismatch( 00652 char* avec, int32_t alen, char* bvec, int32_t blen); 00653 00662 float64_t compute_without_mismatch_matrix( 00663 char* avec, int32_t alen, char* bvec, int32_t blen); 00664 00675 float64_t compute_without_mismatch_position_weights( 00676 char* avec, float64_t *posweights_lhs, int32_t alen, 00677 char* bvec, float64_t *posweights_rhs, int32_t blen); 00678 00680 virtual void remove_lhs(); 00681 00690 virtual void load_serializable_post() throw (ShogunException); 00691 00692 private: 00695 void init(); 00696 00697 protected: 00699 float64_t* weights; 00701 int32_t weights_degree; 00703 int32_t weights_length; 00704 00706 float64_t* position_weights; 00708 int32_t position_weights_len; 00709 00711 float64_t* position_weights_lhs; 00713 int32_t position_weights_lhs_len; 00715 float64_t* position_weights_rhs; 00717 int32_t position_weights_rhs_len; 00719 bool* position_mask; 00720 00722 float64_t* weights_buffer; 00724 int32_t mkl_stepsize; 00725 00727 int32_t degree; 00729 int32_t length; 00730 00732 int32_t max_mismatch; 00734 int32_t seq_length; 00735 00737 int32_t *shift; 00739 int32_t shift_len; 00741 int32_t max_shift; 00742 00744 bool block_computation; 00745 00747 float64_t* block_weights; 00749 EWDKernType type; 00751 int32_t which_degree; 00752 00754 CTrie<DNATrie> tries; 00756 CTrie<POIMTrie> poim_tries; 00757 00759 bool tree_initialized; 00761 bool use_poim_tries; 00762 00764 float64_t* m_poim_distrib; 00766 float64_t* m_poim; 00767 00769 int32_t m_poim_num_sym; 00771 int32_t m_poim_num_feat; 00773 int32_t m_poim_result_len; 00774 00776 CAlphabet* alphabet; 00777 }; 00778 } 00779 #endif /* _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H__ */