SHOGUN
v3.2.0
|
00001 #include "AmariIndex.h" 00002 00003 #ifdef HAVE_EIGEN3 00004 00005 #include <shogun/mathematics/Math.h> 00006 #include <shogun/mathematics/eigen3.h> 00007 00008 using namespace shogun; 00009 using namespace Eigen; 00010 00011 float64_t amari_index(SGMatrix<float64_t> SGW, SGMatrix<float64_t> SGA, bool standardize) 00012 { 00013 Map<MatrixXd> W(SGW.matrix,SGW.num_rows,SGW.num_cols); 00014 Map<MatrixXd> A(SGA.matrix,SGA.num_rows,SGA.num_cols); 00015 00016 REQUIRE(W.rows() == W.cols(), "amari_index - W must be square\n") 00017 REQUIRE(A.rows() == A.cols(), "amari_index - A must be square\n") 00018 REQUIRE(W.rows() == A.rows(), "amari_index - A and W must be the same size\n") 00019 REQUIRE(W.rows() >= 2, "amari_index - input must be at least 2x2\n") 00020 00021 // normalizing both mixing matrices 00022 if (standardize) 00023 { 00024 for (int r = 0; r < W.rows(); r++) 00025 { 00026 W.row(r).normalize(); 00027 if (W.row(r).maxCoeff() < -1*W.row(r).minCoeff()) 00028 W.row(r) *= -1; 00029 } 00030 00031 A = A.inverse(); 00032 for (int r = 0; r < A.rows(); r++) 00033 { 00034 A.row(r).normalize(); 00035 if (A.row(r).maxCoeff() < -1*A.row(r).minCoeff()) 00036 A.row(r) *= -1; 00037 } 00038 A = A.inverse(); 00039 00040 bool swap = false; 00041 do 00042 { 00043 swap = false; 00044 for (int j = 1; j < A.cols(); j++) 00045 { 00046 if (A(0,j) < A(0,j-1)) 00047 { 00048 A.col(j).swap(A.col(j-1)); 00049 swap = true; 00050 } 00051 } 00052 00053 } while(swap); 00054 } 00055 00056 // calculating the permutation matrix 00057 MatrixXd P = (W * A).cwiseAbs(); 00058 int k = P.rows(); 00059 00060 // summing the error in the permutation matrix 00061 MatrixXd E1(k,k); 00062 for (int r = 0; r < k; r++) 00063 E1.row(r) = P.row(r) / P.row(r).maxCoeff(); 00064 00065 float64_t row_error = (E1.rowwise().sum().array()-1).sum(); 00066 00067 MatrixXd E2(k,k); 00068 for (int c = 0; c < k; c++) 00069 E2.col(c) = P.col(c) / P.col(c).maxCoeff(); 00070 00071 float64_t col_error = (E2.colwise().sum().array()-1).sum(); 00072 00073 return 1.0 / (float)(2*k*(k-1)) * (row_error + col_error); 00074 00075 } 00076 #endif //HAVE_EIGEN3