SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SpectrumMismatchRBFKernel.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include <vector>
00013 
00014 #include <shogun/lib/common.h>
00015 #include <shogun/io/SGIO.h>
00016 #include <shogun/lib/Signal.h>
00017 #include <shogun/lib/Trie.h>
00018 #include <shogun/base/Parallel.h>
00019 
00020 #include <shogun/kernel/string/SpectrumMismatchRBFKernel.h>
00021 #include <shogun/features/Features.h>
00022 #include <shogun/features/StringFeatures.h>
00023 
00024 #include <vector>
00025 #include <string>
00026 
00027 #include <assert.h>
00028 
00029 #ifndef WIN32
00030 #include <pthread.h>
00031 #endif
00032 
00033 using namespace shogun;
00034 
00035 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel() :
00036         CStringKernel<char>(0)
00037 {
00038     init();
00039     register_params();
00040 }
00041 
00042 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(int32_t size,
00043         float64_t* AA_matrix_, int32_t nr, int32_t nc, int32_t degree_,
00044         int32_t max_mismatch_, float64_t width_) :
00045         CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(
00046                 max_mismatch_), width(width_)
00047 {
00048     init();
00049     target_letter_0=-1;
00050     set_AA_matrix(AA_matrix_, nr, nc);
00051     register_params();
00052 }
00053 
00054 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(CStringFeatures<char>* l,
00055         CStringFeatures<char>* r, int32_t size, float64_t* AA_matrix_,
00056         int32_t nr, int32_t nc, int32_t degree_, int32_t max_mismatch_,
00057         float64_t width_) :
00058         CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(
00059                 max_mismatch_), width(width_)
00060 {
00061     target_letter_0=-1;
00062 
00063     set_AA_matrix(AA_matrix_, nr, nc);
00064     init(l, r);
00065     register_params();
00066 }
00067 
00068 CSpectrumMismatchRBFKernel::~CSpectrumMismatchRBFKernel()
00069 {
00070     cleanup();
00071     SG_UNREF(kernel_matrix);
00072 }
00073 
00074 bool CSpectrumMismatchRBFKernel::init(CFeatures* l, CFeatures* r)
00075 {
00076     int32_t lhs_changed=(lhs!=l);
00077     int32_t rhs_changed=(rhs!=r);
00078 
00079     CStringKernel<char>::init(l, r);
00080 
00081     SG_DEBUG("lhs_changed: %i\n", lhs_changed)
00082     SG_DEBUG("rhs_changed: %i\n", rhs_changed)
00083 
00084     CStringFeatures<char>* sf_l=(CStringFeatures<char>*)l;
00085     CStringFeatures<char>* sf_r=(CStringFeatures<char>*)r;
00086 
00087     SG_UNREF(alphabet);
00088     alphabet=sf_l->get_alphabet();
00089     CAlphabet* ralphabet=sf_r->get_alphabet();
00090 
00091     if (!((alphabet->get_alphabet()==DNA) || (alphabet->get_alphabet()==RNA)))
00092         properties&=((uint64_t)(-1))^(KP_LINADD|KP_BATCHEVALUATION);
00093 
00094     ASSERT(ralphabet->get_alphabet()==alphabet->get_alphabet())
00095     SG_UNREF(ralphabet);
00096 
00097     compute_all();
00098 
00099     return init_normalizer();
00100 }
00101 
00102 void CSpectrumMismatchRBFKernel::cleanup()
00103 {
00104 
00105     SG_UNREF(alphabet);
00106     alphabet=NULL;
00107 
00108     CKernel::cleanup();
00109 }
00110 
00111 float64_t CSpectrumMismatchRBFKernel::AA_helper(std::string &path,
00112         const char* joint_seq, unsigned int index)
00113 {
00114     float64_t diff=0.0;
00115 
00116     for (unsigned int i=0; i<path.size(); i++)
00117     {
00118         if (path[i]!=joint_seq[index+i])
00119         {
00120             diff+=AA_matrix.matrix[(path[i]-1)*128+path[i]-1];
00121             diff-=2*AA_matrix.matrix[(path[i]-1)*128+joint_seq[index+i]-1];
00122             diff+=AA_matrix.matrix[(joint_seq[index+i]-1)*128+joint_seq[index+i]
00123                     -1];
00124         }
00125     }
00126 
00127     return exp(-diff/width);
00128 }
00129 
00130 void CSpectrumMismatchRBFKernel::compute_helper_all(const char *joint_seq,
00131         std::vector<struct joint_list_struct> &joint_list, std::string path,
00132         unsigned int d)
00133 {
00134     const char* AA="ACDEFGHIKLMNPQRSTVWY";
00135     const unsigned int num_AA=strlen(AA);
00136 
00137     assert(path.size()==d);
00138 
00139     for (unsigned int i=0; i<num_AA; i++)
00140     {
00141         std::vector<struct joint_list_struct> joint_list_;
00142 
00143         if (d==0)
00144             SG_PRINT("i=%i: ", i);
00145         if (d==0&&target_letter_0!=-1&&(int)i!=target_letter_0)
00146             continue;
00147 
00148         if (d==1)
00149         {
00150             SG_PRINT("*");
00151         }
00152         if (d==2)
00153         {
00154             SG_PRINT("+");
00155         }
00156 
00157         for (unsigned int j=0; j<joint_list.size(); j++)
00158         {
00159             if (joint_seq[joint_list[j].index+d]!=AA[i])
00160             {
00161                 if (joint_list[j].mismatch+1<=(unsigned int)max_mismatch)
00162                 {
00163                     struct joint_list_struct list_item;
00164                     list_item=joint_list[j];
00165                     list_item.mismatch=joint_list[j].mismatch+1;
00166                     joint_list_.push_back(list_item);
00167                 }
00168             }
00169             else
00170                 joint_list_.push_back(joint_list[j]);
00171         }
00172 
00173         if (joint_list_.size()>0)
00174         {
00175             std::string path_=path+AA[i];
00176 
00177             if (d+1<(unsigned int)degree)
00178             {
00179                 compute_helper_all(joint_seq, joint_list_, path_, d+1);
00180             }
00181             else
00182             {
00183                 CDynamicArray<float64_t> feats;
00184                 feats.resize_array(kernel_matrix->get_dim1());
00185                 feats.set_const(0);
00186 
00187                 for (unsigned int j=0; j<joint_list_.size(); j++)
00188                 {
00189                     if (width==0.0)
00190                     {
00191                         feats[joint_list_[j].ex_index]++;
00192                         //if (joint_mismatch_[j]==0)
00193                         //  feats[joint_ex_index_[j]]+=3 ;
00194                     }
00195                     else
00196                     {
00197                         if (joint_list_[j].mismatch!=0)
00198                             feats[joint_list_[j].ex_index]+=AA_helper(path_,
00199                                     joint_seq, joint_list_[j].index);
00200                         else
00201                             feats[joint_list_[j].ex_index]++;
00202                     }
00203                 }
00204 
00205                 std::vector<int> idx;
00206                 for (int r=0; r<feats.get_array_size(); r++)
00207                     if (feats[r]!=0.0)
00208                         idx.push_back(r);
00209 
00210                 for (unsigned int r=0; r<idx.size(); r++)
00211                     for (unsigned int s=r; s<idx.size(); s++)
00212                         if (s==r)
00213                             kernel_matrix->set_element(
00214                                     feats[idx[r]]*feats[idx[s]]
00215                                             +kernel_matrix->get_element(idx[r],
00216                                                     idx[s]), idx[r], idx[s]);
00217                         else
00218                         {
00219                             kernel_matrix->set_element(
00220                                     feats[idx[r]]*feats[idx[s]]
00221                                             +kernel_matrix->get_element(idx[r],
00222                                                     idx[s]), idx[r], idx[s]);
00223                             kernel_matrix->set_element(
00224                                     feats[idx[r]]*feats[idx[s]]
00225                                             +kernel_matrix->get_element(idx[s],
00226                                                     idx[r]), idx[s], idx[r]);
00227                         }
00228             }
00229         }
00230         if (d==0)
00231             SG_PRINT("\n");
00232     }
00233 }
00234 
00235 void CSpectrumMismatchRBFKernel::compute_all()
00236 {
00237     std::string joint_seq;
00238     std::vector<struct joint_list_struct> joint_list;
00239 
00240     assert(lhs->get_num_vectors()==rhs->get_num_vectors());
00241     kernel_matrix->resize_array(lhs->get_num_vectors(), lhs->get_num_vectors());
00242     kernel_matrix_length=lhs->get_num_vectors()*rhs->get_num_vectors();
00243     for (int i=0; i<lhs->get_num_vectors(); i++)
00244         for (int j=0; j<lhs->get_num_vectors(); j++)
00245             kernel_matrix->set_element(0, i, j);
00246 
00247     for (int i=0; i<lhs->get_num_vectors(); i++)
00248     {
00249         int32_t alen;
00250         bool free_avec;
00251         char* avec=((CStringFeatures<char>*)lhs)->get_feature_vector(i, alen,
00252                 free_avec);
00253 
00254         for (int apos=0; apos+degree-1<alen; apos++)
00255         {
00256             struct joint_list_struct list_item;
00257             list_item.ex_index=i;
00258             list_item.index=apos+joint_seq.size();
00259             list_item.mismatch=0;
00260 
00261             joint_list.push_back(list_item);
00262         }
00263         joint_seq+=std::string(avec, alen);
00264 
00265         ((CStringFeatures<char>*)lhs)->free_feature_vector(avec, i, free_avec);
00266     }
00267 
00268     compute_helper_all(joint_seq.c_str(), joint_list, "", 0);
00269 }
00270 
00271 float64_t CSpectrumMismatchRBFKernel::compute(int32_t idx_a, int32_t idx_b)
00272 {
00273     return kernel_matrix->element(idx_a, idx_b);
00274 }
00275 
00276 bool CSpectrumMismatchRBFKernel::set_AA_matrix(float64_t* AA_matrix_,
00277         int32_t nr, int32_t nc)
00278 {
00279     if (AA_matrix_)
00280     {
00281         if (nr!=128 || nc!=128)
00282             SG_ERROR("AA_matrix should be of shape 128x128\n")
00283 
00284         AA_matrix=SGMatrix<float64_t>(nc, nr);
00285         SG_DEBUG("Setting AA_matrix\n")
00286         memcpy(AA_matrix.matrix, AA_matrix_, 128*128*sizeof(float64_t));
00287         return true;
00288     }
00289 
00290     return false;
00291 }
00292 
00293 bool CSpectrumMismatchRBFKernel::set_max_mismatch(int32_t max)
00294 {
00295     max_mismatch=max;
00296 
00297     if (lhs!=NULL&&rhs!=NULL)
00298         return init(lhs, rhs);
00299     else
00300         return true;
00301 }
00302 
00303 void CSpectrumMismatchRBFKernel::register_params()
00304 {
00305     SG_ADD(&degree, "degree", "degree of the kernel", MS_AVAILABLE);
00306     SG_ADD(&AA_matrix, "AA_matrix", "128*128 scalar product matrix",
00307             MS_NOT_AVAILABLE);
00308     SG_ADD(&width, "width", "width of Gaussian", MS_AVAILABLE);
00309     SG_ADD(&target_letter_0, "target_letter_0", "target letter 0",
00310             MS_NOT_AVAILABLE);
00311     SG_ADD(&initialized, "initialized", "the mark of initialization status",
00312             MS_NOT_AVAILABLE);
00313     SG_ADD((CSGObject** )&kernel_matrix, "kernel_matrix",
00314             "the kernel matrix with its length "
00315                     "defined by the number of vectors of the string features",
00316             MS_NOT_AVAILABLE);
00317 }
00318 
00319 void CSpectrumMismatchRBFKernel::register_alphabet()
00320 {
00321     SG_ADD((CSGObject** )&alphabet, "alphabet", "the alphabet used by kernel",
00322             MS_NOT_AVAILABLE);
00323 }
00324 
00325 void CSpectrumMismatchRBFKernel::init()
00326 {
00327     alphabet=NULL;
00328     degree=0;
00329     max_mismatch=0;
00330     width=0.0;
00331     kernel_matrix=new CDynamicArray<float64_t>();
00332     initialized=false;
00333     target_letter_0=0;
00334 }
00335 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation