SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Jacob Walker 00008 * 00009 * Adapted from WeightedDegreeRBFKernel.cpp 00010 */ 00011 00012 #include <shogun/lib/common.h> 00013 #include <shogun/kernel/LinearARDKernel.h> 00014 #include <shogun/features/Features.h> 00015 #include <shogun/io/SGIO.h> 00016 00017 using namespace shogun; 00018 00019 CLinearARDKernel::CLinearARDKernel() : CDotKernel() 00020 { 00021 init(); 00022 } 00023 00024 00025 CLinearARDKernel::CLinearARDKernel(int32_t size) : CDotKernel(size) 00026 { 00027 init(); 00028 } 00029 00030 CLinearARDKernel::CLinearARDKernel(CDenseFeatures<float64_t>* l, 00031 CDenseFeatures<float64_t>* r, int32_t size) : CDotKernel(size) 00032 { 00033 init(); 00034 init(l,r); 00035 } 00036 00037 void CLinearARDKernel::init() 00038 { 00039 SG_ADD(&m_weights, "weights", "Feature weights", MS_AVAILABLE, 00040 GRADIENT_AVAILABLE); 00041 } 00042 00043 CLinearARDKernel::~CLinearARDKernel() 00044 { 00045 CKernel::cleanup(); 00046 } 00047 00048 bool CLinearARDKernel::init(CFeatures* l, CFeatures* r) 00049 { 00050 CDotKernel::init(l, r); 00051 00052 init_ft_weights(); 00053 00054 return init_normalizer(); 00055 } 00056 00057 void CLinearARDKernel::init_ft_weights() 00058 { 00059 if (!lhs || !rhs) 00060 return; 00061 00062 int32_t alen, blen; 00063 00064 alen=((CDenseFeatures<float64_t>*) lhs)->get_num_features(); 00065 blen=((CDenseFeatures<float64_t>*) rhs)->get_num_features(); 00066 00067 REQUIRE(alen==blen, "Number of right and left hand features must be the " 00068 "same\n") 00069 00070 if (m_weights.vlen != alen) 00071 { 00072 m_weights=SGVector<float64_t>(alen); 00073 m_weights.set_const(1.0); 00074 } 00075 00076 SG_DEBUG("Initialized weights for LinearARDKernel (%p).\n", this) 00077 } 00078 00079 void CLinearARDKernel::set_weight(float64_t w, index_t i) 00080 { 00081 if (i >= m_weights.vlen) 00082 { 00083 SG_ERROR("Index %i out of range for LinearARDKernel."\ 00084 "Number of features is %i.\n", i, m_weights.vlen); 00085 } 00086 00087 m_weights[i]=w; 00088 } 00089 00090 float64_t CLinearARDKernel::get_weight(index_t i) 00091 { 00092 if (i >= m_weights.vlen) 00093 { 00094 SG_ERROR("Index %i out of range for LinearARDKernel."\ 00095 "Number of features is %i.\n", i, m_weights.vlen); 00096 } 00097 00098 return m_weights[i]; 00099 } 00100 00101 float64_t CLinearARDKernel::compute(int32_t idx_a, int32_t idx_b) 00102 { 00103 REQUIRE(lhs && rhs, "Features not set!\n") 00104 00105 SGVector<float64_t> avec= 00106 ((CDenseFeatures<float64_t>*) lhs)->get_feature_vector(idx_a); 00107 SGVector<float64_t> bvec= 00108 ((CDenseFeatures<float64_t>*) rhs)->get_feature_vector(idx_b); 00109 00110 REQUIRE(avec.vlen==bvec.vlen, "Number of right and left hand " 00111 "features must be the same\n") 00112 00113 float64_t result=0; 00114 00115 for (index_t i = 0; i < avec.vlen; i++) 00116 result += avec[i]*bvec[i]*m_weights[i]*m_weights[i]; 00117 00118 return result; 00119 } 00120 00121 SGMatrix<float64_t> CLinearARDKernel::get_parameter_gradient( 00122 const TParameter* param, index_t index) 00123 { 00124 REQUIRE(lhs && rhs, "Features not set!\n") 00125 00126 if (!strcmp(param->m_name, "weights")) 00127 { 00128 SGMatrix<float64_t> derivative(num_lhs, num_rhs); 00129 00130 for (index_t j=0; j<num_lhs; j++) 00131 { 00132 for (index_t k=0; k<num_rhs; k++) 00133 { 00134 SGVector<float64_t> avec= 00135 ((CDenseFeatures<float64_t>*) lhs)->get_feature_vector(j); 00136 SGVector<float64_t> bvec= 00137 ((CDenseFeatures<float64_t>*) rhs)->get_feature_vector(k); 00138 00139 REQUIRE(avec.vlen==bvec.vlen, "Number of right and left hand " 00140 "features must be the same\n"); 00141 00142 derivative(j,k)=avec[index]*bvec[index]*m_weights[index]; 00143 } 00144 } 00145 return derivative; 00146 } 00147 else 00148 { 00149 SG_ERROR("Can't compute derivative wrt %s parameter\n", param->m_name); 00150 return SGMatrix<float64_t>(); 00151 } 00152 }