SHOGUN  v3.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
JADiag.cpp
Go to the documentation of this file.
#ifdef HAVE_EIGEN3

#include <shogun/mathematics/ajd/JADiag.h>

#include <cmath>

#include <shogun/base/init.h>

#include <shogun/mathematics/Math.h>
#include <shogun/mathematics/eigen3.h>

using namespace shogun;
using namespace Eigen;
00012 
00013 void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[],
00014              float64_t *logdet, float64_t *decr, float64_t *result);
00015 
00016 SGMatrix<float64_t> CJADiag::diagonalize(SGNDArray<float64_t> C, SGMatrix<float64_t> V0,
00017                         double eps, int itermax)
00018 {
00019     int d = C.dims[0];
00020     int L = C.dims[2];
00021 
00022     // check that the input matrices are pos def
00023     for (int i = 0; i < L; i++)
00024     {
00025         Map<MatrixXd> Ci(C.get_matrix(i),d,d);
00026 
00027         EigenSolver<MatrixXd> eig;
00028         eig.compute(Ci);
00029 
00030         MatrixXd D = eig.pseudoEigenvalueMatrix();
00031 
00032         for (int j = 0; j < d; j++)
00033         {
00034             if (D(j,j) < 0)
00035             {
00036                 SG_SERROR("Input Matrix %d is not Positive-definite\n", i)
00037             }
00038         }
00039     }
00040 
00041     SGMatrix<float64_t> V;
00042     if (V0.num_rows == d && V0.num_cols == d)
00043         V = V0.clone();
00044     else
00045         V = SGMatrix<float64_t>::create_identity_matrix(d,1);
00046 
00047     VectorXd w(L);
00048     w.setOnes();
00049 
00050     MatrixXd ctot(d, d*L);
00051     for (int i = 0; i < L; i++)
00052     {
00053         Map<MatrixXd> Ci(C.get_matrix(i),d,d);
00054         ctot.block(0,i*d,d,d) = Ci;
00055     }
00056 
00057     int iter = 0;
00058     float64_t decr = 1;
00059     float64_t logdet = log(5.184e17);
00060     float64_t result = 0;
00061     std::vector<float64_t> crit;
00062     while (decr > eps && iter < itermax)
00063     {
00064         if(logdet == 0)// is NA
00065         {
00066             SG_SERROR("log det does not exist\n")
00067             break;
00068         }
00069 
00070         jadiagw(ctot.data(),
00071                 w.data(),
00072                 &d, &L,
00073                 V.matrix,
00074                 &logdet,
00075                 &decr,
00076                 &result);
00077 
00078         crit.push_back(result);
00079         iter = iter + 1;
00080     }
00081 
00082     if (iter == itermax)
00083         SG_SERROR("Convergence not reached\n")
00084 
00085     return V;
00086 
00087 }
00088 
00089 void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[],
00090         float64_t *logdet, float64_t *decr, float64_t *result)
00091 {
00092     int n = *ptn;
00093     int m = *ptm;
00094     //int   i1,j1;
00095     int n2 = n*n, mn2 = m*n2,
00096     i, ic, ii, ij, j, jc, jj, k, k0;
00097     float64_t  sumweigh, p2, q1, p, q,
00098     alpha, beta, gamma, a12, a21, /*tiny,*/ det;
00099     register float64_t tmp1, tmp2, tmp, weigh;
00100 
00101     for (sumweigh = 0, i = 0; i < m; i++)
00102         sumweigh += w[i];
00103 
00104     det = 1;
00105     *decr = 0;
00106 
00107     for (i = 1, ic = n; i < n ; i++, ic += n)
00108     {
00109         for (j = jc = 0; j < i; j++, jc += n)
00110         {
00111             ii = i + ic;
00112             jj = j + jc;
00113             ij = i + jc;
00114 
00115             for (q1 = p2 = p = q = 0, k0 = k = 0; k0 < m; k0++, k += n2)
00116             {
00117                 weigh = w[k0];
00118                 tmp1 = c[ii+k];
00119                 tmp2 = c[jj+k];
00120                 tmp = c[ij+k];
00121                 p += weigh*tmp/tmp1;
00122                 q += weigh*tmp/tmp2;
00123                 q1 += weigh*tmp1/tmp2;
00124                 p2 += weigh*tmp2/tmp1;
00125             }
00126 
00127             q1 /= sumweigh;
00128             p2 /= sumweigh;
00129             p /= sumweigh;
00130             q /= sumweigh;
00131             beta = 1 - p2*q1;// p1 = q2 = 1
00132 
00133             if (q1 <= p2)// the same as q1*q2 <= p1*p2
00134             {
00135                 alpha = p2*q - p;// q2 = 1
00136 
00137                 if (fabs(alpha) - beta < 10e-20)// beta <= 0 always
00138                 {
00139                     beta = -1;
00140                     gamma = p/p2;
00141                 }
00142                 else
00143                 {
00144                     gamma = - (p*beta + alpha)/p2;// p1 = 1
00145                 }
00146 
00147                 *decr += sumweigh*(p*p - alpha*alpha/beta)/p2;
00148             }
00149             else
00150             {
00151                 gamma = p*q1 - q;// p1 = 1
00152 
00153                 if (fabs(gamma) - beta < 10e-20)// beta <= 0 always
00154                 {
00155                     beta = -1;
00156                     alpha = q/q1;
00157                 }
00158                 else
00159                 {
00160                     alpha = - (q*beta + gamma)/q1;// q2 = 1
00161                 }
00162 
00163                 *decr += sumweigh*(q*q - gamma*gamma/beta)/q1;
00164             }
00165 
00166             tmp = (beta - sqrt(beta*beta - 4*alpha*gamma))/2;
00167             a12 = gamma/tmp;
00168             a21 = alpha/tmp;
00169 
00170             for (k = 0; k < mn2; k += n2)
00171             {
00172                 for (ii = i, jj = j; ii < ij; ii += n, jj += n)
00173                 {
00174                     tmp = c[ii+k];
00175                     c[ii+k] += a12*c[jj+k];
00176                     c[jj+k] += a21*tmp;
00177                 }// at exit ii = ij = i + jc
00178 
00179                 tmp = c[i+ic+k];
00180                 c[i+ic+k] += a12*(2*c[ij+k] + a12*c[jj+k]);
00181                 c[jj+k] += a21*c[ij+k];
00182                 c[ij+k] += a21*tmp;// = element of index j,i
00183 
00184                 for (; ii < ic; ii += n, jj++)
00185                 {
00186                     tmp = c[ii+k];
00187                     c[ii+k] += a12*c[jj+k];
00188                     c[jj+k] += a21*tmp;
00189                 }
00190 
00191                 for (; ++ii, ++jj < jc+n; )
00192                 {
00193                     tmp = c[ii+k];
00194                     c[ii+k] += a12*c[jj+k];
00195                     c[jj+k] += a21*tmp;
00196                 }
00197 
00198             }
00199 
00200             for (k = 0; k < n2; k += n)
00201             {
00202                 tmp = a[i+k];
00203                 a[i+k] += a12*a[j+k];
00204                 a[j+k] += a21*tmp;
00205             }
00206 
00207             det *= 1 - a12*a21;// compute determinant
00208         }
00209     }
00210 
00211     *logdet += 2*sumweigh*log(det);
00212 
00213     for (tmp = 0, k0 = k = 0; k0 < m; k0++, k += n2)
00214     {
00215         for (det = 1, ii = 0; ii < n2; ii += n+1)
00216         {
00217             det *= c[ii+k];
00218             tmp += w[k0]*log(det);
00219         }
00220     }
00221 
00222     *result = tmp - *logdet;
00223 
00224     return;
00225 }
00226 #endif //HAVE_EIGEN3
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation