// Eigen 3.3.3
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SOLVETRIANGULAR_H 00011 #define EIGEN_SOLVETRIANGULAR_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00017 // Forward declarations: 00018 // The following two routines are implemented in the products/TriangularSolver*.h files 00019 template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder> 00020 struct triangular_solve_vector; 00021 00022 template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder> 00023 struct triangular_solve_matrix; 00024 00025 // small helper struct extracting some traits on the underlying solver operation 00026 template<typename Lhs, typename Rhs, int Side> 00027 class trsolve_traits 00028 { 00029 private: 00030 enum { 00031 RhsIsVectorAtCompileTime = (Side==OnTheLeft ? Rhs::ColsAtCompileTime : Rhs::RowsAtCompileTime)==1 00032 }; 00033 public: 00034 enum { 00035 Unrolling = (RhsIsVectorAtCompileTime && Rhs::SizeAtCompileTime != Dynamic && Rhs::SizeAtCompileTime <= 8) 00036 ? CompleteUnrolling : NoUnrolling, 00037 RhsVectors = RhsIsVectorAtCompileTime ? 
1 : Dynamic 00038 }; 00039 }; 00040 00041 template<typename Lhs, typename Rhs, 00042 int Side, // can be OnTheLeft/OnTheRight 00043 int Mode, // can be Upper/Lower | UnitDiag 00044 int Unrolling = trsolve_traits<Lhs,Rhs,Side>::Unrolling, 00045 int RhsVectors = trsolve_traits<Lhs,Rhs,Side>::RhsVectors 00046 > 00047 struct triangular_solver_selector; 00048 00049 template<typename Lhs, typename Rhs, int Side, int Mode> 00050 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,1> 00051 { 00052 typedef typename Lhs::Scalar LhsScalar; 00053 typedef typename Rhs::Scalar RhsScalar; 00054 typedef blas_traits<Lhs> LhsProductTraits; 00055 typedef typename LhsProductTraits::ExtractType ActualLhsType; 00056 typedef Map<Matrix<RhsScalar,Dynamic,1>, Aligned> MappedRhs; 00057 static void run(const Lhs& lhs, Rhs& rhs) 00058 { 00059 ActualLhsType actualLhs = LhsProductTraits::extract(lhs); 00060 00061 // FIXME find a way to allow an inner stride if packet_traits<Scalar>::size==1 00062 00063 bool useRhsDirectly = Rhs::InnerStrideAtCompileTime==1 || rhs.innerStride()==1; 00064 00065 ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhs,rhs.size(), 00066 (useRhsDirectly ? rhs.data() : 0)); 00067 00068 if(!useRhsDirectly) 00069 MappedRhs(actualRhs,rhs.size()) = rhs; 00070 00071 triangular_solve_vector<LhsScalar, RhsScalar, Index, Side, Mode, LhsProductTraits::NeedToConjugate, 00072 (int(Lhs::Flags) & RowMajorBit) ? 
RowMajor : ColMajor> 00073 ::run(actualLhs.cols(), actualLhs.data(), actualLhs.outerStride(), actualRhs); 00074 00075 if(!useRhsDirectly) 00076 rhs = MappedRhs(actualRhs, rhs.size()); 00077 } 00078 }; 00079 00080 // the rhs is a matrix 00081 template<typename Lhs, typename Rhs, int Side, int Mode> 00082 struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic> 00083 { 00084 typedef typename Rhs::Scalar Scalar; 00085 typedef blas_traits<Lhs> LhsProductTraits; 00086 typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; 00087 00088 static void run(const Lhs& lhs, Rhs& rhs) 00089 { 00090 typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsProductTraits::extract(lhs); 00091 00092 const Index size = lhs.rows(); 00093 const Index othersize = Side==OnTheLeft? rhs.cols() : rhs.rows(); 00094 00095 typedef internal::gemm_blocking_space<(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar, 00096 Rhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxRowsAtCompileTime,4> BlockingType; 00097 00098 BlockingType blocking(rhs.rows(), rhs.cols(), size, 1, false); 00099 00100 triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor, 00101 (Rhs::Flags&RowMajorBit) ? 
RowMajor : ColMajor> 00102 ::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking); 00103 } 00104 }; 00105 00106 /*************************************************************************** 00107 * meta-unrolling implementation 00108 ***************************************************************************/ 00109 00110 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size, 00111 bool Stop = LoopIndex==Size> 00112 struct triangular_solver_unroller; 00113 00114 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size> 00115 struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,false> { 00116 enum { 00117 IsLower = ((Mode&Lower)==Lower), 00118 DiagIndex = IsLower ? LoopIndex : Size - LoopIndex - 1, 00119 StartIndex = IsLower ? 0 : DiagIndex+1 00120 }; 00121 static void run(const Lhs& lhs, Rhs& rhs) 00122 { 00123 if (LoopIndex>0) 00124 rhs.coeffRef(DiagIndex) -= lhs.row(DiagIndex).template segment<LoopIndex>(StartIndex).transpose() 00125 .cwiseProduct(rhs.template segment<LoopIndex>(StartIndex)).sum(); 00126 00127 if(!(Mode & UnitDiag)) 00128 rhs.coeffRef(DiagIndex) /= lhs.coeff(DiagIndex,DiagIndex); 00129 00130 triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex+1,Size>::run(lhs,rhs); 00131 } 00132 }; 00133 00134 template<typename Lhs, typename Rhs, int Mode, int LoopIndex, int Size> 00135 struct triangular_solver_unroller<Lhs,Rhs,Mode,LoopIndex,Size,true> { 00136 static void run(const Lhs&, Rhs&) {} 00137 }; 00138 00139 template<typename Lhs, typename Rhs, int Mode> 00140 struct triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,CompleteUnrolling,1> { 00141 static void run(const Lhs& lhs, Rhs& rhs) 00142 { triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } 00143 }; 00144 00145 template<typename Lhs, typename Rhs, int Mode> 00146 struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> { 00147 static void 
run(const Lhs& lhs, Rhs& rhs) 00148 { 00149 Transpose<const Lhs> trLhs(lhs); 00150 Transpose<Rhs> trRhs(rhs); 00151 00152 triangular_solver_unroller<Transpose<const Lhs>,Transpose<Rhs>, 00153 ((Mode&Upper)==Upper ? Lower : Upper) | (Mode&UnitDiag), 00154 0,Rhs::SizeAtCompileTime>::run(trLhs,trRhs); 00155 } 00156 }; 00157 00158 } // end namespace internal 00159 00160 /*************************************************************************** 00161 * TriangularView methods 00162 ***************************************************************************/ 00163 00164 #ifndef EIGEN_PARSED_BY_DOXYGEN 00165 template<typename MatrixType, unsigned int Mode> 00166 template<int Side, typename OtherDerived> 00167 void TriangularViewImpl<MatrixType,Mode,Dense>::solveInPlace(const MatrixBase<OtherDerived>& _other) const 00168 { 00169 OtherDerived& other = _other.const_cast_derived(); 00170 eigen_assert( derived().cols() == derived().rows() && ((Side==OnTheLeft && derived().cols() == other.rows()) || (Side==OnTheRight && derived().cols() == other.cols())) ); 00171 eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper|Lower))); 00172 00173 enum { copy = (internal::traits<OtherDerived>::Flags & RowMajorBit) && OtherDerived::IsVectorAtCompileTime && OtherDerived::SizeAtCompileTime!=1}; 00174 typedef typename internal::conditional<copy, 00175 typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>::type OtherCopy; 00176 OtherCopy otherCopy(other); 00177 00178 internal::triangular_solver_selector<MatrixType, typename internal::remove_reference<OtherCopy>::type, 00179 Side, Mode>::run(derived().nestedExpression(), otherCopy); 00180 00181 if (copy) 00182 other = otherCopy; 00183 } 00184 00185 template<typename Derived, unsigned int Mode> 00186 template<int Side, typename Other> 00187 const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other> 00188 TriangularViewImpl<Derived,Mode,Dense>::solve(const MatrixBase<Other>& other) const 
00189 { 00190 return internal::triangular_solve_retval<Side,TriangularViewType,Other>(derived(), other.derived()); 00191 } 00192 #endif 00193 00194 namespace internal { 00195 00196 00197 template<int Side, typename TriangularType, typename Rhs> 00198 struct traits<triangular_solve_retval<Side, TriangularType, Rhs> > 00199 { 00200 typedef typename internal::plain_matrix_type_column_major<Rhs>::type ReturnType; 00201 }; 00202 00203 template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval 00204 : public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> > 00205 { 00206 typedef typename remove_all<typename Rhs::Nested>::type RhsNestedCleaned; 00207 typedef ReturnByValue<triangular_solve_retval> Base; 00208 00209 triangular_solve_retval(const TriangularType& tri, const Rhs& rhs) 00210 : m_triangularMatrix(tri), m_rhs(rhs) 00211 {} 00212 00213 inline Index rows() const { return m_rhs.rows(); } 00214 inline Index cols() const { return m_rhs.cols(); } 00215 00216 template<typename Dest> inline void evalTo(Dest& dst) const 00217 { 00218 if(!is_same_dense(dst,m_rhs)) 00219 dst = m_rhs; 00220 m_triangularMatrix.template solveInPlace<Side>(dst); 00221 } 00222 00223 protected: 00224 const TriangularType& m_triangularMatrix; 00225 typename Rhs::Nested m_rhs; 00226 }; 00227 00228 } // namespace internal 00229 00230 } // end namespace Eigen 00231 00232 #endif // EIGEN_SOLVETRIANGULAR_H