SHOGUN
v3.2.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2013 Soumyajit De 00008 */ 00009 00010 #include <shogun/lib/common.h> 00011 00012 #ifdef HAVE_EIGEN3 00013 00014 #include <shogun/lib/SGVector.h> 00015 #include <shogun/lib/Time.h> 00016 #include <shogun/mathematics/eigen3.h> 00017 #include <shogun/mathematics/Math.h> 00018 #include <shogun/mathematics/linalg/linop/LinearOperator.h> 00019 #include <shogun/mathematics/linalg/linsolver/ConjugateOrthogonalCGSolver.h> 00020 #include <shogun/mathematics/linalg/linsolver/IterativeSolverIterator.h> 00021 using namespace Eigen; 00022 00023 namespace shogun 00024 { 00025 00026 CConjugateOrthogonalCGSolver::CConjugateOrthogonalCGSolver() 00027 : CIterativeLinearSolver<complex128_t, float64_t>() 00028 { 00029 SG_GCDEBUG("%s created (%p)\n", this->get_name(), this); 00030 } 00031 00032 CConjugateOrthogonalCGSolver::CConjugateOrthogonalCGSolver(bool store_residuals) 00033 : CIterativeLinearSolver<complex128_t, float64_t>(store_residuals) 00034 { 00035 SG_GCDEBUG("%s created (%p)\n", this->get_name(), this); 00036 } 00037 00038 CConjugateOrthogonalCGSolver::~CConjugateOrthogonalCGSolver() 00039 { 00040 SG_GCDEBUG("%s destroyed (%p)\n", this->get_name(), this); 00041 } 00042 00043 SGVector<complex128_t> CConjugateOrthogonalCGSolver::solve( 00044 CLinearOperator<complex128_t>* A, SGVector<float64_t> b) 00045 { 00046 SG_DEBUG("CConjugateOrthogonalCGSolver::solve(): Entering..\n"); 00047 00048 // sanity check 00049 REQUIRE(A, "Operator is NULL!\n"); 00050 REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch!\n, %d vs %d", 00051 A->get_dimension(), b.vlen); 00052 00053 // the final solution vector, initial guess is 0 00054 SGVector<complex128_t> result(b.vlen); 00055 result.set_const(0.0); 00056 00057 // the rest of the part hinges on eigen3 for computing norms 00058 Map<VectorXcd> x(result.vector, result.vlen); 00059 Map<VectorXd> b_map(b.vector, b.vlen); 00060 00061 // direction vector 00062 SGVector<complex128_t> p_(result.vlen); 00063 Map<VectorXcd> p(p_.vector, p_.vlen); 00064 00065 // residual r_i=b-Ax_i, here x_0=[0], so r_0=b 00066 VectorXcd r=b_map.cast<complex128_t>(); 00067 00068 // initial direction is same as residual 00069 p=r; 00070 00071 // the iterator for this iterative solver 00072 IterativeSolverIterator<complex128_t> it(r, m_max_iteration_limit, 00073 m_relative_tolerence, m_absolute_tolerence); 00074 00075 // start the timer 00076 CTime time; 00077 time.start(); 00078 00079 // set the residuals to zero 00080 if (m_store_residuals) 00081 m_residuals.set_const(0.0); 00082 00083 // CG iteration begins 00084 complex128_t r_norm2=r.transpose()*r; 00085 00086 for (it.begin(r); !it.end(r); ++it) 00087 { 00088 SG_DEBUG("CG iteration %d, residual norm %f\n", 00089 it.get_iter_info().iteration_count, 00090 it.get_iter_info().residual_norm); 00091 00092 if (m_store_residuals) 00093 { 00094 m_residuals[it.get_iter_info().iteration_count] 00095 =it.get_iter_info().residual_norm; 00096 } 00097 00098 // apply linear operator to the direction vector 00099 SGVector<complex128_t> Ap_=A->apply(p_); 00100 Map<VectorXcd> Ap(Ap_.vector, Ap_.vlen); 00101 00102 // compute p^{T}Ap, if zero, failure 00103 complex128_t p_T_times_Ap=p.transpose()*Ap; 00104 if (p_T_times_Ap==0.0) 00105 break; 00106 00107 // compute the alpha parameter of CG 00108 complex128_t alpha=r_norm2/p_T_times_Ap; 00109 00110 // update the solution vector and residual 00111 // x_{i}=x_{i-1}+\alpha_{i}p 00112 x+=alpha*p; 00113 00114 // r_{i}=r_{i-1}-\alpha_{i}p 00115 r-=alpha*Ap; 00116 00117 // compute new ||r||_{2}, if zero, converged 00118 complex128_t r_norm2_i=r.transpose()*r; 00119 if (r_norm2_i==0.0) 00120 break; 00121 00122 // compute the beta parameter of CG 00123 complex128_t beta=r_norm2_i/r_norm2; 00124 00125 // update direction, and ||r||_{2} 00126 r_norm2=r_norm2_i; 00127 p=r+beta*p; 00128 } 00129 00130 float64_t elapsed=time.cur_time_diff(); 00131 00132 if (!it.succeeded(r)) 00133 SG_WARNING("Did not converge!\n"); 00134 00135 SG_INFO("Iteration took %ld times, residual norm=%.20lf, time elapsed=%lf\n", 00136 it.get_iter_info().iteration_count, it.get_iter_info().residual_norm, elapsed); 00137 00138 SG_DEBUG("CConjugateOrthogonalCGSolver::solve(): Leaving..\n"); 00139 return result; 00140 } 00141 00142 } 00143 #endif // HAVE_EIGEN3