Eigen  3.3.3
BlasUtil.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_BLASUTIL_H
00011 #define EIGEN_BLASUTIL_H
00012 
00013 // This file contains many lightweight helper classes used to
00014 // implement and control fast level 2 and level 3 BLAS-like routines.
00015 
00016 namespace Eigen {
00017 
00018 namespace internal {
00019 
00020 // forward declarations
// Declared only; the primary templates are defined outside of this header.
template<
  typename LhsScalar,
  typename RhsScalar,
  typename Index,
  typename DataMapper,
  int mr,
  int nr,
  bool ConjugateLhs = false,
  bool ConjugateRhs = false>
struct gebp_kernel;

template<
  typename Scalar,
  typename Index,
  typename DataMapper,
  int nr,
  int StorageOrder,
  bool Conjugate = false,
  bool PanelMode = false>
struct gemm_pack_rhs;

template<
  typename Scalar,
  typename Index,
  typename DataMapper,
  int Pack1,
  int Pack2,
  int StorageOrder,
  bool Conjugate = false,
  bool PanelMode = false>
struct gemm_pack_lhs;

// No default arguments here: all eight parameters must be spelled out.
template<
  typename Index,
  typename LhsScalar,
  int LhsStorageOrder,
  bool ConjugateLhs,
  typename RhsScalar,
  int RhsStorageOrder,
  bool ConjugateRhs,
  int ResStorageOrder>
struct general_matrix_matrix_product;
00036 
00037 template<typename Index,
00038          typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
00039          typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
00040 struct general_matrix_vector_product;
00041 
00042 
00043 template<bool Conjugate> struct conj_if;
00044 
00045 template<> struct conj_if<true> {
00046   template<typename T>
00047   inline T operator()(const T& x) const { return numext::conj(x); }
00048   template<typename T>
00049   inline T pconj(const T& x) const { return internal::pconj(x); }
00050 };
00051 
00052 template<> struct conj_if<false> {
00053   template<typename T>
00054   inline const T& operator()(const T& x) const { return x; }
00055   template<typename T>
00056   inline const T& pconj(const T& x) const { return x; }
00057 };
00058 
00059 // Generic implementation for custom complex types.
00060 template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs>
00061 struct conj_helper
00062 {
00063   typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar;
00064 
00065   EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const
00066   { return padd(c, pmul(x,y)); }
00067 
00068   EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const
00069   { return conj_if<ConjLhs>()(x) *  conj_if<ConjRhs>()(y); }
00070 };
00071 
00072 template<typename Scalar> struct conj_helper<Scalar,Scalar,false,false>
00073 {
00074   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const { return internal::pmadd(x,y,c); }
00075   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const { return internal::pmul(x,y); }
00076 };
00077 
00078 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false,true>
00079 {
00080   typedef std::complex<RealScalar> Scalar;
00081   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00082   { return c + pmul(x,y); }
00083 
00084   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00085   { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::imag(x)*numext::real(y) - numext::real(x)*numext::imag(y)); }
00086 };
00087 
00088 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,false>
00089 {
00090   typedef std::complex<RealScalar> Scalar;
00091   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00092   { return c + pmul(x,y); }
00093 
00094   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00095   { return Scalar(numext::real(x)*numext::real(y) + numext::imag(x)*numext::imag(y), numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
00096 };
00097 
00098 template<typename RealScalar> struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true,true>
00099 {
00100   typedef std::complex<RealScalar> Scalar;
00101   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
00102   { return c + pmul(x,y); }
00103 
00104   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
00105   { return Scalar(numext::real(x)*numext::real(y) - numext::imag(x)*numext::imag(y), - numext::real(x)*numext::imag(y) - numext::imag(x)*numext::real(y)); }
00106 };
00107 
00108 template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
00109 {
00110   typedef std::complex<RealScalar> Scalar;
00111   EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
00112   { return padd(c, pmul(x,y)); }
00113   EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
00114   { return conj_if<Conj>()(x)*y; }
00115 };
00116 
00117 template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
00118 {
00119   typedef std::complex<RealScalar> Scalar;
00120   EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
00121   { return padd(c, pmul(x,y)); }
00122   EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
00123   { return x*conj_if<Conj>()(y); }
00124 };
00125 
00126 template<typename From,typename To> struct get_factor {
00127   EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
00128 };
00129 
00130 template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
00131   EIGEN_DEVICE_FUNC
00132   static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
00133 };
00134 
00135 
00136 template<typename Scalar, typename Index>
00137 class BlasVectorMapper {
00138   public:
00139   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}
00140 
00141   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
00142     return m_data[i];
00143   }
00144   template <typename Packet, int AlignmentType>
00145   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const {
00146     return ploadt<Packet, AlignmentType>(m_data + i);
00147   }
00148 
00149   template <typename Packet>
00150   EIGEN_DEVICE_FUNC bool aligned(Index i) const {
00151     return (UIntPtr(m_data+i)%sizeof(Packet))==0;
00152   }
00153 
00154   protected:
00155   Scalar* m_data;
00156 };
00157 
00158 template<typename Scalar, typename Index, int AlignmentType>
00159 class BlasLinearMapper {
00160   public:
00161   typedef typename packet_traits<Scalar>::type Packet;
00162   typedef typename packet_traits<Scalar>::half HalfPacket;
00163 
00164   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
00165 
00166   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
00167     internal::prefetch(&operator()(i));
00168   }
00169 
00170   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
00171     return m_data[i];
00172   }
00173 
00174   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
00175     return ploadt<Packet, AlignmentType>(m_data + i);
00176   }
00177 
00178   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
00179     return ploadt<HalfPacket, AlignmentType>(m_data + i);
00180   }
00181 
00182   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const {
00183     pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
00184   }
00185 
00186   protected:
00187   Scalar *m_data;
00188 };
00189 
00190 // Lightweight helper class to access matrix coefficients.
00191 template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
00192 class blas_data_mapper {
00193   public:
00194   typedef typename packet_traits<Scalar>::type Packet;
00195   typedef typename packet_traits<Scalar>::half HalfPacket;
00196 
00197   typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
00198   typedef BlasVectorMapper<Scalar, Index> VectorMapper;
00199 
00200   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
00201 
00202   EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
00203   getSubMapper(Index i, Index j) const {
00204     return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
00205   }
00206 
00207   EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
00208     return LinearMapper(&operator()(i, j));
00209   }
00210 
00211   EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
00212     return VectorMapper(&operator()(i, j));
00213   }
00214 
00215 
00216   EIGEN_DEVICE_FUNC
00217   EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
00218     return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
00219   }
00220 
00221   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
00222     return ploadt<Packet, AlignmentType>(&operator()(i, j));
00223   }
00224 
00225   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
00226     return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
00227   }
00228 
00229   template<typename SubPacket>
00230   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
00231     pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
00232   }
00233 
00234   template<typename SubPacket>
00235   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
00236     return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
00237   }
00238 
00239   EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
00240   EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }
00241 
00242   EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
00243     if (UIntPtr(m_data)%sizeof(Scalar)) {
00244       return -1;
00245     }
00246     return internal::first_default_aligned(m_data, size);
00247   }
00248 
00249   protected:
00250   Scalar* EIGEN_RESTRICT m_data;
00251   const Index m_stride;
00252 };
00253 
00254 // lightweight helper class to access matrix coefficients (const version)
00255 template<typename Scalar, typename Index, int StorageOrder>
00256 class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
00257   public:
00258   EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}
00259 
00260   EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
00261     return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
00262   }
00263 };
00264 
00265 
00266 /* Helper class to analyze the factors of a Product expression.
00267  * In particular it allows to pop out operator-, scalar multiples,
00268  * and conjugate */
00269 template<typename XprType> struct blas_traits
00270 {
00271   typedef typename traits<XprType>::Scalar Scalar;
00272   typedef const XprType& ExtractType;
00273   typedef XprType _ExtractType;
00274   enum {
00275     IsComplex = NumTraits<Scalar>::IsComplex,
00276     IsTransposed = false,
00277     NeedToConjugate = false,
00278     HasUsableDirectAccess = (    (int(XprType::Flags)&DirectAccessBit)
00279                               && (   bool(XprType::IsVectorAtCompileTime)
00280                                   || int(inner_stride_at_compile_time<XprType>::ret) == 1)
00281                              ) ?  1 : 0
00282   };
00283   typedef typename conditional<bool(HasUsableDirectAccess),
00284     ExtractType,
00285     typename _ExtractType::PlainObject
00286     >::type DirectLinearAccessType;
00287   static inline ExtractType extract(const XprType& x) { return x; }
00288   static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
00289 };
00290 
00291 // pop conjugate
00292 template<typename Scalar, typename NestedXpr>
00293 struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
00294  : blas_traits<NestedXpr>
00295 {
00296   typedef blas_traits<NestedXpr> Base;
00297   typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
00298   typedef typename Base::ExtractType ExtractType;
00299 
00300   enum {
00301     IsComplex = NumTraits<Scalar>::IsComplex,
00302     NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
00303   };
00304   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00305   static inline Scalar extractScalarFactor(const XprType& x) { return conj(Base::extractScalarFactor(x.nestedExpression())); }
00306 };
00307 
00308 // pop scalar multiple
00309 template<typename Scalar, typename NestedXpr, typename Plain>
00310 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
00311  : blas_traits<NestedXpr>
00312 {
00313   typedef blas_traits<NestedXpr> Base;
00314   typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
00315   typedef typename Base::ExtractType ExtractType;
00316   static inline ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
00317   static inline Scalar extractScalarFactor(const XprType& x)
00318   { return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
00319 };
00320 template<typename Scalar, typename NestedXpr, typename Plain>
00321 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
00322  : blas_traits<NestedXpr>
00323 {
00324   typedef blas_traits<NestedXpr> Base;
00325   typedef CwiseBinaryOp<scalar_product_op<Scalar>, NestedXpr, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
00326   typedef typename Base::ExtractType ExtractType;
00327   static inline ExtractType extract(const XprType& x) { return Base::extract(x.lhs()); }
00328   static inline Scalar extractScalarFactor(const XprType& x)
00329   { return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other; }
00330 };
00331 template<typename Scalar, typename Plain1, typename Plain2>
00332 struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>,
00333                                                             const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > >
00334  : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> >
00335 {};
00336 
00337 // pop opposite
00338 template<typename Scalar, typename NestedXpr>
00339 struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
00340  : blas_traits<NestedXpr>
00341 {
00342   typedef blas_traits<NestedXpr> Base;
00343   typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
00344   typedef typename Base::ExtractType ExtractType;
00345   static inline ExtractType extract(const XprType& x) { return Base::extract(x.nestedExpression()); }
00346   static inline Scalar extractScalarFactor(const XprType& x)
00347   { return - Base::extractScalarFactor(x.nestedExpression()); }
00348 };
00349 
00350 // pop/push transpose
00351 template<typename NestedXpr>
00352 struct blas_traits<Transpose<NestedXpr> >
00353  : blas_traits<NestedXpr>
00354 {
00355   typedef typename NestedXpr::Scalar Scalar;
00356   typedef blas_traits<NestedXpr> Base;
00357   typedef Transpose<NestedXpr> XprType;
00358   typedef Transpose<const typename Base::_ExtractType>  ExtractType; // const to get rid of a compile error; anyway blas traits are only used on the RHS
00359   typedef Transpose<const typename Base::_ExtractType> _ExtractType;
00360   typedef typename conditional<bool(Base::HasUsableDirectAccess),
00361     ExtractType,
00362     typename ExtractType::PlainObject
00363     >::type DirectLinearAccessType;
00364   enum {
00365     IsTransposed = Base::IsTransposed ? 0 : 1
00366   };
00367   static inline ExtractType extract(const XprType& x) { return ExtractType(Base::extract(x.nestedExpression())); }
00368   static inline Scalar extractScalarFactor(const XprType& x) { return Base::extractScalarFactor(x.nestedExpression()); }
00369 };
00370 
00371 template<typename T>
00372 struct blas_traits<const T>
00373      : blas_traits<T>
00374 {};
00375 
00376 template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
00377 struct extract_data_selector {
00378   static const typename T::Scalar* run(const T& m)
00379   {
00380     return blas_traits<T>::extract(m).data();
00381   }
00382 };
00383 
00384 template<typename T>
00385 struct extract_data_selector<T,false> {
00386   static typename T::Scalar* run(const T&) { return 0; }
00387 };
00388 
00389 template<typename T> const typename T::Scalar* extract_data(const T& m)
00390 {
00391   return extract_data_selector<T>::run(m);
00392 }
00393 
00394 } // end namespace internal
00395 
00396 } // end namespace Eigen
00397 
00398 #endif // EIGEN_BLASUTIL_H
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends