Eigen  3.3.3
SparseTriangularView.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 // Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr>
00006 //
00007 // This Source Code Form is subject to the terms of the Mozilla
00008 // Public License v. 2.0. If a copy of the MPL was not distributed
00009 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00010 
00011 #ifndef EIGEN_SPARSE_TRIANGULARVIEW_H
00012 #define EIGEN_SPARSE_TRIANGULARVIEW_H
00013 
00014 namespace Eigen {
00015 
// Specialization of TriangularViewImpl for the Sparse storage kind.
// It derives from SparseMatrixBase<TriangularView<MatrixType,Mode> > and declares the
// solve entry points used on a sparse triangular view: _solve_impl() and two
// solveInPlace() overloads (one per kind of right-hand side).
00025 template<typename MatrixType, unsigned int Mode> class TriangularViewImpl<MatrixType,Mode,Sparse>
00026   : public SparseMatrixBase<TriangularView<MatrixType,Mode> >
00027 {
      // Compile-time constants computed from Mode and MatrixType::Flags:
      //   SkipFirst   - nonzero for (Lower and not row-major) or (Upper and row-major);
      //   SkipLast    - the negation of SkipFirst;
      //   SkipDiag    - 1 when Mode has the ZeroDiag bit, else 0;
      //   HasUnitDiag - 1 when Mode has the UnitDiag bit, else 0.
      // None of them is referenced by the code visible in this class body; the
      // unary_evaluator specialization later in this file defines the same set for
      // its InnerIterator.
00028     enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit))
00029                     || ((Mode&Upper) &&  (MatrixType::Flags&RowMajorBit)),
00030            SkipLast = !SkipFirst,
00031            SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
00032            HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
00033     };
00034     
      // Shorthand for the concrete view type this implementation class belongs to.
00035     typedef TriangularView<MatrixType,Mode> TriangularViewType;
00036     
00037   protected:
00038     // dummy solve function to make TriangularView happy.
      // Only declared here; no definition appears in this file.
00039     void solve() const;
00040 
00041     typedef SparseMatrixBase<TriangularViewType> Base;
00042   public:
00043     
00044     EIGEN_SPARSE_PUBLIC_INTERFACE(TriangularViewType)
00045     
      // Nested-expression type of MatrixType, plus the same type with the reference
      // removed (NonRef) and with all qualifiers removed (Cleaned).
00046     typedef typename MatrixType::Nested MatrixTypeNested;
00047     typedef typename internal::remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
00048     typedef typename internal::remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
00049 
      // Solves with *this as the triangular factor and writes the result into dst.
      //   rhs - right-hand side (read only).
      //   dst - destination; overwritten with the solution.
      // The work is delegated to solveInPlace(dst), so rhs is first copied into dst.
      // The copy is skipped only when RhsType and DstType are the same type AND
      // internal::extract_data() returns the same value for both, i.e. when dst
      // already is rhs (presumably: both refer to the same coefficient storage).
00050     template<typename RhsType, typename DstType>
00051     EIGEN_DEVICE_FUNC
00052     EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const {
00053       if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs)))
00054         dst = rhs;
00055       this->solveInPlace(dst);
00056     }
00057 
      // In-place solve taking a MatrixBase argument; `other` is both input and output.
      // Declared here only; the definition is not part of this file.
00059     template<typename OtherDerived> void solveInPlace(MatrixBase<OtherDerived>& other) const;
00060 
      // In-place solve taking a SparseMatrixBase argument; `other` is both input and
      // output. Declared here only; the definition is not part of this file.
00062     template<typename OtherDerived> void solveInPlace(SparseMatrixBase<OtherDerived>& other) const;
00063   
00064 };
00065 
00066 namespace internal {
00067 
// Evaluator for a TriangularView whose argument is traversed through iterators
// (IteratorBased). It wraps the argument's own evaluator and exposes an InnerIterator
// that only reports the entries belonging to the triangular part selected by Mode,
// inserting an implicit 1 on the diagonal when Mode has the UnitDiag bit.
00068 template<typename ArgType, unsigned int Mode>
00069 struct unary_evaluator<TriangularView<ArgType,Mode>, IteratorBased>
00070  : evaluator_base<TriangularView<ArgType,Mode> >
00071 {
00072   typedef TriangularView<ArgType,Mode> XprType;
00073   
00074 protected:
00075   
00076   typedef typename XprType::Scalar Scalar;
00077   typedef typename XprType::StorageIndex StorageIndex;
      // Iterator of the wrapped argument; InnerIterator below derives from it.
00078   typedef typename evaluator<ArgType>::InnerIterator EvalIterator;
00079   
      // Compile-time traversal switches:
      //   SkipFirst   - nonzero for (Lower and not row-major) or (Upper and row-major):
      //                 the wanted entries are the ones at/after inner index == outer,
      //                 so the leading entries of each inner vector are skipped;
      //   SkipLast    - the negation of SkipFirst: the wanted entries come first and
      //                 iteration stops at the diagonal;
      //   SkipDiag    - 1 when Mode has ZeroDiag: the diagonal entry is excluded;
      //   HasUnitDiag - 1 when Mode has UnitDiag: the stored diagonal entry is ignored
      //                 and a value of 1 is reported in its place.
00080   enum { SkipFirst = ((Mode&Lower) && !(ArgType::Flags&RowMajorBit))
00081                     || ((Mode&Upper) &&  (ArgType::Flags&RowMajorBit)),
00082          SkipLast = !SkipFirst,
00083          SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
00084          HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
00085   };
00086   
00087 public:
00088   
      // Read cost is taken from the argument's evaluator; Flags from the view type.
00089   enum {
00090     CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
00091     Flags = XprType::Flags
00092   };
00093     
      // Builds the evaluator of the nested expression and additionally keeps a
      // reference to that expression (InnerIterator reads its innerSize()).
      // NOTE(review): m_arg is a reference, so the nested expression has to stay
      // alive for as long as this evaluator is used - confirm against callers.
00094   explicit unary_evaluator(const XprType &xpr) : m_argImpl(xpr.nestedExpression()), m_arg(xpr.nestedExpression()) {}
00095   
      // Forwards the argument's estimate unchanged; entries lying outside the
      // triangular part are not subtracted.
00096   inline Index nonZerosEstimate() const {
00097     return m_argImpl.nonZerosEstimate();
00098   }
00099   
      // Iterator over one inner vector of the triangular view.
      // NOTE(review): the comparisons against outer() below assume the underlying
      // iterator visits inner indices in increasing order - confirm.
00100   class InnerIterator : public EvalIterator
00101   {
00102       typedef EvalIterator Base;
00103     public:
00104 
          // Positions the iterator on the first element of inner vector `outer` that
          // belongs to the view.
          // m_containsDiag is true when outer < innerSize() of the nested expression,
          // i.e. when this inner vector has a position whose inner index equals outer;
          // only then may the implicit unit-diagonal element be reported.
00105       EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& xprEval, Index outer)
00106         : Base(xprEval.m_argImpl,outer), m_returnOne(false), m_containsDiag(Base::outer()<xprEval.m_arg.innerSize())
00107       {
00108         if(SkipFirst)
00109         {
              // Drop the leading entries with inner index below outer; the entry at
              // index == outer is dropped too when the diagonal is implicit
              // (HasUnitDiag) or excluded (SkipDiag). m_returnOne is still false
              // here, so (*this) and this->index() reduce to the Base versions.
00110           while((*this) && ((HasUnitDiag||SkipDiag)  ? this->index()<=outer : this->index()<outer))
00111             Base::operator++();
              // With a unit diagonal, the implicit 1 is reported first, ahead of the
              // remaining stored entries.
00112           if(HasUnitDiag)
00113             m_returnOne = m_containsDiag;
00114         }
            // SkipFirst is 0 here: the wanted stored entries precede the diagonal. If
            // the underlying iterator is already exhausted, or already at/after the
            // diagonal, nothing stored precedes it: step over the current stored entry
            // (if any) and report the implicit 1 right away.
00115         else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer()))
00116         {
              // (!SkipFirst) is always true in this branch.
00117           if((!SkipFirst) && Base::operator bool())
00118             Base::operator++();
00119           m_returnOne = m_containsDiag;
00120         }
00121       }
00122 
          // Advances to the next element of the view.
00123       EIGEN_STRONG_INLINE InnerIterator& operator++()
00124       {
            // Leaving the implicit 1 only clears the flag; the underlying iterator
            // already points at whatever follows it.
00125         if(HasUnitDiag && m_returnOne)
00126           m_returnOne = false;
00127         else
00128         {
00129           Base::operator++();
              // Unit diagonal with the diagonal coming last (SkipFirst == 0): once the
              // underlying iterator ends or reaches/passes index == outer, step over
              // that stored entry (if any) and switch to the implicit 1.
00130           if(HasUnitDiag && (!SkipFirst) && ((!Base::operator bool()) || Base::index()>=Base::outer()))
00131           {
00132             if((!SkipFirst) && Base::operator bool())
00133               Base::operator++();
00134             m_returnOne = m_containsDiag;
00135           }
00136         }
00137         return *this;
00138       }
00139       
          // True while the iterator refers to an element of the view:
          //  - always while the implicit 1 is the current element;
          //  - with SkipFirst, as long as the underlying iterator is valid;
          //  - otherwise, as long as the underlying iterator is valid and its index is
          //    below outer() (SkipDiag) or at most outer() (diagonal included).
00140       EIGEN_STRONG_INLINE operator bool() const
00141       {
00142         if(HasUnitDiag && m_returnOne)
00143           return true;
00144         if(SkipFirst) return  Base::operator bool();
00145         else
00146         {
00147           if (SkipDiag) return (Base::operator bool() && this->index() < this->outer());
00148           else return (Base::operator bool() && this->index() <= this->outer());
00149         }
00150       }
00151 
00152 //       inline Index row() const { return (ArgType::Flags&RowMajorBit ? Base::outer() : this->index()); }
00153 //       inline Index col() const { return (ArgType::Flags&RowMajorBit ? this->index() : Base::outer()); }
          // Inner index of the current element: outer() (converted to StorageIndex)
          // for the implicit 1, the underlying iterator's index otherwise.
00154       inline StorageIndex index() const
00155       {
00156         if(HasUnitDiag && m_returnOne)  return internal::convert_index<StorageIndex>(Base::outer());
00157         else                            return Base::index();
00158       }
          // Value of the current element: Scalar(1) for the implicit diagonal element,
          // the underlying iterator's value otherwise.
00159       inline Scalar value() const
00160       {
00161         if(HasUnitDiag && m_returnOne)  return Scalar(1);
00162         else                            return Base::value();
00163       }
00164 
00165     protected:
          // True while the current element is the implicit unit-diagonal 1.
00166       bool m_returnOne;
          // True when this inner vector has a diagonal position (outer < innerSize()).
00167       bool m_containsDiag;
00168     private:
          // Declared private and not defined in this file; presumably this hides any
          // valueRef() of the underlying iterator so elements cannot be written
          // through the view - confirm.
00169       Scalar& valueRef();
00170   };
00171   
00172 protected:
      // Evaluator of the nested expression; source of the underlying iterators.
00173   evaluator<ArgType> m_argImpl;
      // The nested expression itself, kept by reference for innerSize().
00174   const ArgType& m_arg;
00175 };
00176 
00177 } // end namespace internal
00178 
// Returns a read-only triangular view of this sparse expression.
// Mode is the compile-time template argument forwarded to TriangularView; the view is
// constructed from derived() and returned by value as
// const TriangularView<const Derived, Mode>.
00179 template<typename Derived>
00180 template<int Mode>
00181 inline const TriangularView<const Derived, Mode>
00182 SparseMatrixBase<Derived>::triangularView() const
00183 {
00184   return TriangularView<const Derived, Mode>(derived());
00185 }
00186 
00187 } // end namespace Eigen
00188 
00189 #endif // EIGEN_SPARSE_TRIANGULARVIEW_H
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends