SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Soumyajit De 00008 */ 00009 00010 #include <shogun/lib/common.h> 00011 00012 #ifdef HAVE_EIGEN3 00013 00014 #include <shogun/lib/SGVector.h> 00015 #include <shogun/lib/Time.h> 00016 #include <shogun/mathematics/eigen3.h> 00017 #include <shogun/mathematics/linalg/linop/LinearOperator.h> 00018 #include <shogun/mathematics/linalg/linsolver/ConjugateGradientSolver.h> 00019 #include <shogun/mathematics/linalg/linsolver/IterativeSolverIterator.h> 00020 00021 using namespace Eigen; 00022 00023 namespace shogun 00024 { 00025 00026 CConjugateGradientSolver::CConjugateGradientSolver() 00027 : CIterativeLinearSolver<float64_t>() 00028 { 00029 SG_GCDEBUG("%s created (%p)\n", this->get_name(), this); 00030 } 00031 00032 CConjugateGradientSolver::CConjugateGradientSolver(bool store_residuals) 00033 : CIterativeLinearSolver<float64_t>(store_residuals) 00034 { 00035 SG_GCDEBUG("%s created (%p)\n", this->get_name(), this); 00036 } 00037 00038 CConjugateGradientSolver::~CConjugateGradientSolver() 00039 { 00040 SG_GCDEBUG("%s destroyed (%p)\n", this->get_name(), this); 00041 } 00042 00043 SGVector<float64_t> CConjugateGradientSolver::solve( 00044 CLinearOperator<float64_t>* A, SGVector<float64_t> b) 00045 { 00046 SG_DEBUG("CConjugateGradientSolve::solve(): Entering..\n"); 00047 00048 // sanity check 00049 REQUIRE(A, "Operator is NULL!\n"); 00050 REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch!\n"); 00051 00052 // the final solution vector, initial guess is 0 00053 SGVector<float64_t> result(b.vlen); 00054 result.set_const(0.0); 00055 00056 // the rest of the part hinges on eigen3 for computing norms 00057 Map<VectorXd> x(result.vector, result.vlen); 00058 Map<VectorXd> b_map(b.vector, b.vlen); 00059 00060 // direction vector 00061 SGVector<float64_t> p_(result.vlen); 00062 Map<VectorXd> p(p_.vector, p_.vlen); 00063 00064 // residual r_i=b-Ax_i, here x_0=[0], so r_0=b 00065 VectorXd r=b_map; 00066 00067 // initial direction is same as residual 00068 p=r; 00069 00070 // the iterator for this iterative solver 00071 IterativeSolverIterator<float64_t> it(b_map, m_max_iteration_limit, 00072 m_relative_tolerence, m_absolute_tolerence); 00073 00074 // CG iteration begins 00075 float64_t r_norm2=r.dot(r); 00076 00077 // start the timer 00078 CTime time; 00079 time.start(); 00080 00081 // set the residuals to zero 00082 if (m_store_residuals) 00083 m_residuals.set_const(0.0); 00084 00085 for (it.begin(r); !it.end(r); ++it) 00086 { 00087 SG_DEBUG("CG iteration %d, residual norm %f\n", 00088 it.get_iter_info().iteration_count, 00089 it.get_iter_info().residual_norm); 00090 00091 if (m_store_residuals) 00092 { 00093 m_residuals[it.get_iter_info().iteration_count] 00094 =it.get_iter_info().residual_norm; 00095 } 00096 00097 // apply linear operator to the direction vector 00098 SGVector<float64_t> Ap_=A->apply(p_); 00099 Map<VectorXd> Ap(Ap_.vector, Ap_.vlen); 00100 00101 // compute p^{T}Ap, if zero, failure 00102 float64_t p_dot_Ap=p.dot(Ap); 00103 if (p_dot_Ap==0.0) 00104 break; 00105 00106 // compute the alpha parameter of CG 00107 float64_t alpha=r_norm2/p_dot_Ap; 00108 00109 // update the solution vector and residual 00110 // x_{i}=x_{i-1}+\alpha_{i}p 00111 x+=alpha*p; 00112 00113 // r_{i}=r_{i-1}-\alpha_{i}p 00114 r-=alpha*Ap; 00115 00116 // compute new ||r||_{2}, if zero, converged 00117 float64_t r_norm2_i=r.dot(r); 00118 if (r_norm2_i==0.0) 00119 break; 00120 00121 // compute the beta parameter of CG 00122 float64_t beta=r_norm2_i/r_norm2; 00123 00124 // update direction, and ||r||_{2} 00125 r_norm2=r_norm2_i; 00126 p=r+beta*p; 00127 } 00128 00129 float64_t elapsed=time.cur_time_diff(); 00130 00131 if (!it.succeeded(r)) 00132 SG_WARNING("Did not converge!\n"); 00133 00134 SG_INFO("Iteration took %ld times, residual norm=%.20lf, time elapsed=%lf\n", 00135 it.get_iter_info().iteration_count, it.get_iter_info().residual_norm, elapsed); 00136 00137 SG_DEBUG("CConjugateGradientSolve::solve(): Leaving..\n"); 00138 return result; 00139 } 00140 00141 } 00142 #endif // HAVE_EIGEN3