SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _DICEKERNELNORMALIZER_H___ 00012 #define _DICEKERNELNORMALIZER_H___ 00013 00014 #include <shogun/kernel/normalizer/KernelNormalizer.h> 00015 #include <shogun/kernel/string/CommWordStringKernel.h> 00016 00017 namespace shogun 00018 { 00026 class CDiceKernelNormalizer : public CKernelNormalizer 00027 { 00028 public: 00033 CDiceKernelNormalizer(bool use_opt_diag=false) : CKernelNormalizer(), 00034 diag_lhs(NULL), num_diag_lhs(0), diag_rhs(NULL), num_diag_rhs(0), 00035 use_optimized_diagonal_computation(use_opt_diag) 00036 { 00037 m_parameters->add_vector(&diag_lhs, &num_diag_lhs, "diag_lhs", 00038 "K(x,x) for left hand side examples."); 00039 m_parameters->add_vector(&diag_rhs, &num_diag_rhs, "diag_rhs", 00040 "K(x,x) for right hand side examples."); 00041 SG_ADD(&use_optimized_diagonal_computation, 00042 "use_optimized_diagonal_computation", 00043 "flat if optimized diagonal computation is used", MS_NOT_AVAILABLE); 00044 } 00045 00047 virtual ~CDiceKernelNormalizer() 00048 { 00049 SG_FREE(diag_lhs); 00050 SG_FREE(diag_rhs); 00051 } 00052 00055 virtual bool init(CKernel* k) 00056 { 00057 ASSERT(k) 00058 num_diag_lhs=k->get_num_vec_lhs(); 00059 num_diag_rhs=k->get_num_vec_rhs(); 00060 ASSERT(num_diag_lhs>0) 00061 ASSERT(num_diag_rhs>0) 00062 00063 CFeatures* old_lhs=k->lhs; 00064 CFeatures* old_rhs=k->rhs; 00065 00066 k->lhs=old_lhs; 00067 k->rhs=old_lhs; 00068 bool r1=alloc_and_compute_diag(k, diag_lhs, num_diag_lhs); 00069 00070 k->lhs=old_rhs; 00071 k->rhs=old_rhs; 00072 bool r2=alloc_and_compute_diag(k, diag_rhs, num_diag_rhs); 00073 00074 k->lhs=old_lhs; 00075 k->rhs=old_rhs; 00076 00077 return r1 && r2; 00078 } 00079 00085 virtual float64_t normalize( 00086 float64_t value, int32_t idx_lhs, int32_t idx_rhs) 00087 { 00088 float64_t diag_sum=diag_lhs[idx_lhs]*diag_rhs[idx_rhs]; 00089 return 2*value/diag_sum; 00090 } 00091 00096 virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs) 00097 { 00098 SG_ERROR("linadd not supported with Dice normalization.\n") 00099 return 0; 00100 } 00101 00106 virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs) 00107 { 00108 SG_ERROR("linadd not supported with Dice normalization.\n") 00109 return 0; 00110 } 00111 00117 virtual const char* get_name() const { 00118 return "DiceKernelNormalizer"; } 00119 00120 public: 00125 bool alloc_and_compute_diag(CKernel* k, float64_t* &v, int32_t num) 00126 { 00127 SG_FREE(v); 00128 v=SG_MALLOC(float64_t, num); 00129 00130 for (int32_t i=0; i<num; i++) 00131 { 00132 if (k->get_kernel_type() == K_COMMWORDSTRING) 00133 { 00134 if (use_optimized_diagonal_computation) 00135 v[i]=((CCommWordStringKernel*) k)->compute_diag(i); 00136 else 00137 v[i]=((CCommWordStringKernel*) k)->compute_helper(i,i, true); 00138 } 00139 else 00140 v[i]=k->compute(i,i); 00141 00142 if (v[i]==0.0) 00143 v[i]=1e-16; /* avoid divide by zero exception */ 00144 } 00145 00146 return (v!=NULL); 00147 } 00148 00149 protected: 00151 float64_t* diag_lhs; 00153 int32_t num_diag_lhs; 00154 00156 float64_t* diag_rhs; 00158 int32_t num_diag_rhs; 00159 00161 bool use_optimized_diagonal_computation; 00162 }; 00163 } 00164 #endif