![]() |
Eigen
3.3.3
|
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_CONJUGATE_GRADIENT_H 00011 #define EIGEN_CONJUGATE_GRADIENT_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00026 template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner> 00027 EIGEN_DONT_INLINE 00028 void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, 00029 const Preconditioner& precond, Index& iters, 00030 typename Dest::RealScalar& tol_error) 00031 { 00032 using std::sqrt; 00033 using std::abs; 00034 typedef typename Dest::RealScalar RealScalar; 00035 typedef typename Dest::Scalar Scalar; 00036 typedef Matrix<Scalar,Dynamic,1> VectorType; 00037 00038 RealScalar tol = tol_error; 00039 Index maxIters = iters; 00040 00041 Index n = mat.cols(); 00042 00043 VectorType residual = rhs - mat * x; //initial residual 00044 00045 RealScalar rhsNorm2 = rhs.squaredNorm(); 00046 if(rhsNorm2 == 0) 00047 { 00048 x.setZero(); 00049 iters = 0; 00050 tol_error = 0; 00051 return; 00052 } 00053 RealScalar threshold = tol*tol*rhsNorm2; 00054 RealScalar residualNorm2 = residual.squaredNorm(); 00055 if (residualNorm2 < threshold) 00056 { 00057 iters = 0; 00058 tol_error = sqrt(residualNorm2 / rhsNorm2); 00059 return; 00060 } 00061 00062 VectorType p(n); 00063 p = precond.solve(residual); // initial search direction 00064 00065 VectorType z(n), tmp(n); 00066 RealScalar absNew = numext::real(residual.dot(p)); // the square of the absolute value of r scaled by invM 00067 Index i = 0; 00068 while(i < maxIters) 00069 { 00070 tmp.noalias() = mat * p; // the bottleneck of the algorithm 00071 00072 Scalar alpha = absNew / p.dot(tmp); // the amount we travel on dir 00073 x += alpha * p; // update solution 00074 residual -= alpha * tmp; // update residual 00075 00076 residualNorm2 = residual.squaredNorm(); 00077 if(residualNorm2 < threshold) 00078 break; 00079 00080 z = precond.solve(residual); // approximately solve for "A z = residual" 00081 00082 RealScalar absOld = absNew; 00083 absNew = numext::real(residual.dot(z)); // update the absolute value of r 00084 RealScalar beta = absNew / absOld; // calculate the Gram-Schmidt value used to create the new search direction 00085 p = z + beta * p; // update search direction 00086 i++; 00087 } 00088 tol_error = sqrt(residualNorm2 / rhsNorm2); 00089 iters = i; 00090 } 00091 00092 } 00093 00094 template< typename _MatrixType, int _UpLo=Lower, 00095 typename _Preconditioner = DiagonalPreconditioner<typename _MatrixType::Scalar> > 00096 class ConjugateGradient; 00097 00098 namespace internal { 00099 00100 template< typename _MatrixType, int _UpLo, typename _Preconditioner> 00101 struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> > 00102 { 00103 typedef _MatrixType MatrixType; 00104 typedef _Preconditioner Preconditioner; 00105 }; 00106 00107 } 00108 00156 template< typename _MatrixType, int _UpLo, typename _Preconditioner> 00157 class ConjugateGradient : public IterativeSolverBase<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> > 00158 { 00159 typedef IterativeSolverBase<ConjugateGradient> Base; 00160 using Base::matrix; 00161 using Base::m_error; 00162 using Base::m_iterations; 00163 using Base::m_info; 00164 using Base::m_isInitialized; 00165 public: 00166 typedef _MatrixType MatrixType; 00167 typedef typename MatrixType::Scalar Scalar; 00168 typedef typename MatrixType::RealScalar RealScalar; 00169 typedef _Preconditioner Preconditioner; 00170 00171 enum { 00172 UpLo = _UpLo 00173 }; 00174 00175 public: 00176 00178 ConjugateGradient() : Base() {} 00179 00190 template<typename MatrixDerived> 00191 explicit ConjugateGradient(const EigenBase<MatrixDerived>& A) : Base(A.derived()) {} 00192 00193 ~ConjugateGradient() {} 00194 00196 template<typename Rhs,typename Dest> 00197 void _solve_with_guess_impl(const Rhs& b, Dest& x) const 00198 { 00199 typedef typename Base::MatrixWrapper MatrixWrapper; 00200 typedef typename Base::ActualMatrixType ActualMatrixType; 00201 enum { 00202 TransposeInput = (!MatrixWrapper::MatrixFree) 00203 && (UpLo==(Lower|Upper)) 00204 && (!MatrixType::IsRowMajor) 00205 && (!NumTraits<Scalar>::IsComplex) 00206 }; 00207 typedef typename internal::conditional<TransposeInput,Transpose<const ActualMatrixType>, ActualMatrixType const&>::type RowMajorWrapper; 00208 EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(MatrixWrapper::MatrixFree,UpLo==(Lower|Upper)),MATRIX_FREE_CONJUGATE_GRADIENT_IS_COMPATIBLE_WITH_UPPER_UNION_LOWER_MODE_ONLY); 00209 typedef typename internal::conditional<UpLo==(Lower|Upper), 00210 RowMajorWrapper, 00211 typename MatrixWrapper::template ConstSelfAdjointViewReturnType<UpLo>::Type 00212 >::type SelfAdjointWrapper; 00213 m_iterations = Base::maxIterations(); 00214 m_error = Base::m_tolerance; 00215 00216 for(Index j=0; j<b.cols(); ++j) 00217 { 00218 m_iterations = Base::maxIterations(); 00219 m_error = Base::m_tolerance; 00220 00221 typename Dest::ColXpr xj(x,j); 00222 RowMajorWrapper row_mat(matrix()); 00223 internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error); 00224 } 00225 00226 m_isInitialized = true; 00227 m_info = m_error <= Base::m_tolerance ? Success : NoConvergence; 00228 } 00229 00231 using Base::_solve_impl; 00232 template<typename Rhs,typename Dest> 00233 void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const 00234 { 00235 x.setZero(); 00236 _solve_with_guess_impl(b.derived(),x); 00237 } 00238 00239 protected: 00240 00241 }; 00242 00243 } // end namespace Eigen 00244 00245 #endif // EIGEN_CONJUGATE_GRADIENT_H