SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Soumyajit De 00008 */ 00009 00010 #include <shogun/lib/common.h> 00011 00012 #ifdef HAVE_EIGEN3 00013 00014 #include <shogun/lib/SGVector.h> 00015 #include <shogun/lib/Time.h> 00016 #include <shogun/mathematics/eigen3.h> 00017 #include <shogun/mathematics/Math.h> 00018 #include <shogun/mathematics/linalg/linop/LinearOperator.h> 00019 #include <shogun/mathematics/linalg/linsolver/CGMShiftedFamilySolver.h> 00020 #include <shogun/mathematics/linalg/linsolver/IterativeSolverIterator.h> 00021 00022 using namespace Eigen; 00023 00024 namespace shogun 00025 { 00026 00027 CCGMShiftedFamilySolver::CCGMShiftedFamilySolver() 00028 : CIterativeShiftedLinearFamilySolver<float64_t, complex128_t>() 00029 { 00030 } 00031 00032 CCGMShiftedFamilySolver::CCGMShiftedFamilySolver(bool store_residuals) 00033 : CIterativeShiftedLinearFamilySolver<float64_t, complex128_t>(store_residuals) 00034 { 00035 } 00036 00037 CCGMShiftedFamilySolver::~CCGMShiftedFamilySolver() 00038 { 00039 } 00040 00041 SGVector<float64_t> CCGMShiftedFamilySolver::solve( 00042 CLinearOperator<float64_t>* A, SGVector<float64_t> b) 00043 { 00044 SGVector<complex128_t> shifts(1); 00045 shifts[0]=0.0; 00046 SGVector<complex128_t> weights(1); 00047 weights[0]=1.0; 00048 00049 return solve_shifted_weighted(A, b, shifts, weights).get_real(); 00050 } 00051 00052 SGVector<complex128_t> CCGMShiftedFamilySolver::solve_shifted_weighted( 00053 CLinearOperator<float64_t>* A, SGVector<float64_t> b, 00054 SGVector<complex128_t> shifts, SGVector<complex128_t> weights) 00055 { 00056 SG_DEBUG("Entering\n"); 00057 00058 // sanity check 00059 REQUIRE(A, "Operator is NULL!\n"); 00060 REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch! [%d vs %d]\n", 00061 A->get_dimension(), b.vlen); 00062 REQUIRE(shifts.vector,"Shifts are not initialized!\n"); 00063 REQUIRE(weights.vector,"Weights are not initialized!\n"); 00064 REQUIRE(shifts.vlen==weights.vlen, "Number of shifts and number of " 00065 "weights are not equal! [%d vs %d]\n", shifts.vlen, weights.vlen); 00066 00067 // the solution matrix, one column per shift, initial guess 0 for all 00068 MatrixXcd x_sh=MatrixXcd::Zero(b.vlen, shifts.vlen); 00069 MatrixXcd p_sh=MatrixXcd::Zero(b.vlen, shifts.vlen); 00070 00071 // non-shifted direction 00072 SGVector<float64_t> p_(b.vlen); 00073 00074 // the rest of the part hinges on eigen3 for computing norms 00075 Map<VectorXd> b_map(b.vector, b.vlen); 00076 Map<VectorXd> p(p_.vector, p_.vlen); 00077 00078 // residual r_i=b-Ax_i, here x_0=[0], so r_0=b 00079 VectorXd r=b_map; 00080 00081 // initial direction is same as residual 00082 p=r; 00083 p_sh=r.replicate(1, shifts.vlen).cast<complex128_t>(); 00084 00085 // non shifted initializers 00086 float64_t r_norm2=r.dot(r); 00087 float64_t beta_old=1.0; 00088 float64_t alpha=1.0; 00089 00090 // shifted quantities 00091 SGVector<complex128_t> alpha_sh(shifts.vlen); 00092 SGVector<complex128_t> beta_sh(shifts.vlen); 00093 SGVector<complex128_t> zeta_sh_old(shifts.vlen); 00094 SGVector<complex128_t> zeta_sh_cur(shifts.vlen); 00095 SGVector<complex128_t> zeta_sh_new(shifts.vlen); 00096 00097 // shifted initializers 00098 zeta_sh_old.set_const(1.0); 00099 zeta_sh_cur.set_const(1.0); 00100 00101 // the iterator for this iterative solver 00102 IterativeSolverIterator<float64_t> it(r, m_max_iteration_limit, 00103 m_relative_tolerence, m_absolute_tolerence); 00104 00105 // start the timer 00106 CTime time; 00107 time.start(); 00108 00109 // set the residuals to zero 00110 if (m_store_residuals) 00111 m_residuals.set_const(0.0); 00112 00113 // CG iteration begins 00114 for (it.begin(r); !it.end(r); ++it) 00115 { 00116 00117 SG_DEBUG("CG iteration %d, residual norm %f\n", 00118 it.get_iter_info().iteration_count, 00119 it.get_iter_info().residual_norm); 00120 00121 if (m_store_residuals) 00122 { 00123 m_residuals[it.get_iter_info().iteration_count] 00124 =it.get_iter_info().residual_norm; 00125 } 00126 00127 // apply linear operator to the direction vector 00128 SGVector<float64_t> Ap_=A->apply(p_); 00129 Map<VectorXd> Ap(Ap_.vector, Ap_.vlen); 00130 00131 // compute p^{T}Ap, if zero, failure 00132 float64_t p_dot_Ap=p.dot(Ap); 00133 if (p_dot_Ap==0.0) 00134 break; 00135 00136 // compute the beta parameter of CG_M 00137 float64_t beta=-r_norm2/p_dot_Ap; 00138 00139 // compute the zeta-shifted parameter of CG_M 00140 compute_zeta_sh_new(zeta_sh_old, zeta_sh_cur, shifts, beta_old, beta, 00141 alpha, zeta_sh_new); 00142 00143 // compute beta-shifted parameter of CG_M 00144 compute_beta_sh(zeta_sh_new, zeta_sh_cur, beta, beta_sh); 00145 00146 // update the solution vector and residual 00147 for (index_t i=0; i<shifts.vlen; ++i) 00148 x_sh.col(i)-=beta_sh[i]*p_sh.col(i); 00149 00150 // r_{i}=r_{i-1}+\beta_{i}Ap 00151 r+=beta*Ap; 00152 00153 // compute new ||r||_{2}, if zero, converged 00154 float64_t r_norm2_i=r.dot(r); 00155 if (r_norm2_i==0.0) 00156 break; 00157 00158 // compute the alpha parameter of CG_M 00159 alpha=r_norm2_i/r_norm2; 00160 00161 // update ||r||_{2} 00162 r_norm2=r_norm2_i; 00163 00164 // update direction 00165 p=r+alpha*p; 00166 00167 compute_alpha_sh(zeta_sh_new, zeta_sh_cur, beta_sh, beta, alpha, alpha_sh); 00168 00169 for (index_t i=0; i<shifts.vlen; ++i) 00170 { 00171 p_sh.col(i)*=alpha_sh[i]; 00172 p_sh.col(i)+=zeta_sh_new[i]*r; 00173 } 00174 00175 // update parameters 00176 for (index_t i=0; i<shifts.vlen; ++i) 00177 { 00178 zeta_sh_old[i]=zeta_sh_cur[i]; 00179 zeta_sh_cur[i]=zeta_sh_new[i]; 00180 } 00181 beta_old=beta; 00182 } 00183 00184 float64_t elapsed=time.cur_time_diff(); 00185 00186 if (!it.succeeded(r)) 00187 SG_WARNING("Did not converge!\n"); 00188 00189 SG_INFO("Iteration took %ld times, residual norm=%.20lf, time elapsed=%lf\n", 00190 it.get_iter_info().iteration_count, it.get_iter_info().residual_norm, elapsed); 00191 00192 // compute the final result vector multiplied by weights 00193 SGVector<complex128_t> result(b.vlen); 00194 result.set_const(0.0); 00195 Map<VectorXcd> x(result.vector, result.vlen); 00196 00197 for (index_t i=0; i<x_sh.cols(); ++i) 00198 x+=x_sh.col(i)*weights[i]; 00199 00200 SG_DEBUG("Leaving\n"); 00201 return result; 00202 } 00203 00204 } 00205 #endif // HAVE_EIGEN3