SHOGUN
v3.2.0
|
#ifdef HAVE_EIGEN3

#include <shogun/mathematics/ajd/JADiag.h>

#include <shogun/base/init.h>

#include <shogun/mathematics/Math.h>
#include <shogun/mathematics/eigen3.h>

using namespace shogun;
using namespace Eigen;

void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[],
		float64_t *logdet, float64_t *decr, float64_t *result);

/**
 * Approximate joint diagonalization of a set of matrices by repeated sweeps
 * of jadiagw().
 *
 * @param C       3-d array; C.dims[0] gives the matrix size d and C.dims[2]
 *                the number L of matrices. Matrix i is read through
 *                C.get_matrix(i) as a d x d column-major block.
 * @param V0      initial guess for the diagonalizer; used only when it is
 *                d x d, otherwise the d x d identity is used.
 * @param eps     sweeps stop once the decrease reported by jadiagw() is <= eps.
 * @param itermax maximum number of sweeps.
 * @return the d x d matrix V after the last sweep (V0 itself is not modified;
 *         a clone is updated).
 *
 * Errors are reported with SG_SERROR when an input matrix has a negative
 * (pseudo-)eigenvalue, when the log-determinant term becomes exactly 0, and
 * when itermax sweeps were done without reaching the eps threshold.
 */
SGMatrix<float64_t> CJADiag::diagonalize(SGNDArray<float64_t> C, SGMatrix<float64_t> V0,
		double eps, int itermax)
{
	int d = C.dims[0];
	int L = C.dims[2];

	// check that the input matrices are positive (semi-)definite.
	// NOTE(review): only strictly negative diagonal entries of the
	// pseudo-eigenvalue matrix are rejected, so a zero eigenvalue passes
	// despite the message; a SelfAdjointEigenSolver would be the natural
	// choice if the inputs are guaranteed symmetric -- confirm with callers.
	for (int i = 0; i < L; i++)
	{
		Map<MatrixXd> Ci(C.get_matrix(i),d,d);

		EigenSolver<MatrixXd> eig;
		eig.compute(Ci);

		MatrixXd D = eig.pseudoEigenvalueMatrix();

		for (int j = 0; j < d; j++)
		{
			if (D(j,j) < 0)
			{
				SG_SERROR("Input Matrix %d is not Positive-definite\n", i)
			}
		}
	}

	// starting point: caller-supplied V0 if it has the right shape
	SGMatrix<float64_t> V;
	if (V0.num_rows == d && V0.num_cols == d)
		V = V0.clone();
	else
		V = SGMatrix<float64_t>::create_identity_matrix(d,1);

	// all matrices are weighted equally
	VectorXd w(L);
	w.setOnes();

	// jadiagw() expects the L matrices side by side in one contiguous
	// column-major buffer: matrix i occupies columns [i*d, i*d + d).
	MatrixXd ctot(d, d*L);
	for (int i = 0; i < L; i++)
	{
		Map<MatrixXd> Ci(C.get_matrix(i),d,d);
		ctot.block(0,i*d,d,d) = Ci;
	}

	int iter = 0;
	float64_t decr = 1;
	// NOTE(review): initial value inherited from the reference implementation;
	// it only offsets the reported criterion -- meaning of the constant unverified.
	float64_t logdet = log(5.184e17);
	float64_t result = 0;
	// criterion value after each sweep (kept for inspection/debugging)
	std::vector<float64_t> crit;
	while (decr > eps && iter < itermax)
	{
		if (logdet == 0)// is NA
		{
			SG_SERROR("log det does not exist\n")
			break;
		}

		jadiagw(ctot.data(),
				w.data(),
				&d, &L,
				V.matrix,
				&logdet,
				&decr,
				&result);

		crit.push_back(result);
		iter = iter + 1;
	}

	// Only complain when the sweep budget ran out WITHOUT meeting the
	// threshold; converging exactly on the last allowed sweep is a success.
	if (iter == itermax && decr > eps)
		SG_SERROR("Convergence not reached\n")

	return V;

}

/**
 * One sweep of Pham's weighted joint approximate diagonalization: for every
 * index pair (i, j) with j < i, a 2x2 transform is computed from weighted
 * averages of entry ratios and applied to all matrices and to a.
 *
 * @param c      m matrices of size n x n stored contiguously, column-major
 *               (matrix k starts at offset k*n*n); transformed in place.
 * @param w      m weights, one per matrix.
 * @param ptn    pointer to n.
 * @param ptm    pointer to m.
 * @param a      n x n column-major matrix whose rows i and j are combined
 *               with each pairwise transform; updated in place.
 * @param logdet running log-determinant term; incremented by
 *               2 * sum(w) * log(product of the pairwise determinants).
 * @param decr   output: accumulated decrease estimate for this sweep.
 * @param result output: sum_k w[k] * log(prod of diagonal of C_k) - *logdet.
 */
void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[],
		float64_t *logdet, float64_t *decr, float64_t *result)
{
	int n = *ptn;
	int m = *ptm;
	int n2 = n*n, mn2 = m*n2,
		i, ic, ii, ij, j, jc, jj, k, k0;
	float64_t sumweigh, p2, q1, p, q,
		alpha, beta, gamma, a12, a21, det;
	// plain locals: the `register` specifier is deprecated in C++11 and
	// removed (ill-formed) in C++17.
	float64_t tmp1, tmp2, tmp, weigh;

	for (sumweigh = 0, i = 0; i < m; i++)
		sumweigh += w[i];

	det = 1;
	*decr = 0;

	// visit every pair (i, j) with j < i; ic = i*n and jc = j*n are column offsets
	for (i = 1, ic = n; i < n ; i++, ic += n)
	{
		for (j = jc = 0; j < i; j++, jc += n)
		{
			ii = i + ic;// offset of element (i,i)
			jj = j + jc;// offset of element (j,j)
			ij = i + jc;// offset of element (i,j)

			// weighted sums of entry ratios over all m matrices
			for (q1 = p2 = p = q = 0, k0 = k = 0; k0 < m; k0++, k += n2)
			{
				weigh = w[k0];
				tmp1 = c[ii+k];
				tmp2 = c[jj+k];
				tmp = c[ij+k];
				p += weigh*tmp/tmp1;
				q += weigh*tmp/tmp2;
				q1 += weigh*tmp1/tmp2;
				p2 += weigh*tmp2/tmp1;
			}

			q1 /= sumweigh;
			p2 /= sumweigh;
			p /= sumweigh;
			q /= sumweigh;
			beta = 1 - p2*q1;// p1 = q2 = 1

			if (q1 <= p2)// the same as q1*q2 <= p1*p2
			{
				alpha = p2*q - p;// q2 = 1

				if (fabs(alpha) - beta < 10e-20)// beta <= 0 always
				{
					beta = -1;
					gamma = p/p2;
				}
				else
				{
					gamma = - (p*beta + alpha)/p2;// p1 = 1
				}

				*decr += sumweigh*(p*p - alpha*alpha/beta)/p2;
			}
			else
			{
				gamma = p*q1 - q;// p1 = 1

				if (fabs(gamma) - beta < 10e-20)// beta <= 0 always
				{
					beta = -1;
					alpha = q/q1;
				}
				else
				{
					alpha = - (q*beta + gamma)/q1;// q2 = 1
				}

				*decr += sumweigh*(q*q - gamma*gamma/beta)/q1;
			}

			// off-diagonal coefficients of the 2x2 transform for pair (i, j)
			tmp = (beta - sqrt(beta*beta - 4*alpha*gamma))/2;
			a12 = gamma/tmp;
			a21 = alpha/tmp;

			// apply the transform to rows/columns i and j of every matrix
			for (k = 0; k < mn2; k += n2)
			{
				for (ii = i, jj = j; ii < ij; ii += n, jj += n)
				{
					tmp = c[ii+k];
					c[ii+k] += a12*c[jj+k];
					c[jj+k] += a21*tmp;
				}// at exit ii = ij = i + jc

				tmp = c[i+ic+k];
				c[i+ic+k] += a12*(2*c[ij+k] + a12*c[jj+k]);
				c[jj+k] += a21*c[ij+k];
				c[ij+k] += a21*tmp;// = element of index j,i

				for (; ii < ic; ii += n, jj++)
				{
					tmp = c[ii+k];
					c[ii+k] += a12*c[jj+k];
					c[jj+k] += a21*tmp;
				}

				for (; ++ii, ++jj < jc+n; )
				{
					tmp = c[ii+k];
					c[ii+k] += a12*c[jj+k];
					c[jj+k] += a21*tmp;
				}

			}

			// combine rows i and j of a with the same transform
			for (k = 0; k < n2; k += n)
			{
				tmp = a[i+k];
				a[i+k] += a12*a[j+k];
				a[j+k] += a21*tmp;
			}

			det *= 1 - a12*a21;// compute determinant
		}
	}

	*logdet += 2*sumweigh*log(det);

	// criterion: weighted sum over matrices of log(product of diagonal)
	for (tmp = 0, k0 = k = 0; k0 < m; k0++, k += n2)
	{
		for (det = 1, ii = 0; ii < n2; ii += n+1)
		{
			det *= c[ii+k];
		}
		// accumulate once per matrix, after its full diagonal product is
		// formed (previously this ran inside the loop above and summed logs
		// of partial products)
		tmp += w[k0]*log(det);
	}

	*result = tmp - *logdet;

	return;
}
#endif //HAVE_EIGEN3