KroneckerTensorProduct.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de>
00005 // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de>
00006 // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
00007 //
00008 // This Source Code Form is subject to the terms of the Mozilla
00009 // Public License v. 2.0. If a copy of the MPL was not distributed
00010 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00011 
00012 #ifndef KRONECKER_TENSOR_PRODUCT_H
00013 #define KRONECKER_TENSOR_PRODUCT_H
00014 
00015 namespace Eigen {
00016 
00024 template<typename Derived>
00025 class KroneckerProductBase : public ReturnByValue<Derived>
00026 {
00027   private:
00028     typedef typename internal::traits<Derived> Traits;
00029     typedef typename Traits::Scalar Scalar;
00030 
00031   protected:
00032     typedef typename Traits::Lhs Lhs;
00033     typedef typename Traits::Rhs Rhs;
00034 
00035   public:
00037     KroneckerProductBase(const Lhs& A, const Rhs& B)
00038       : m_A(A), m_B(B)
00039     {}
00040 
00041     inline Index rows() const { return m_A.rows() * m_B.rows(); }
00042     inline Index cols() const { return m_A.cols() * m_B.cols(); }
00043 
00048     Scalar coeff(Index row, Index col) const
00049     {
00050       return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
00051              m_B.coeff(row % m_B.rows(), col % m_B.cols());
00052     }
00053 
00058     Scalar coeff(Index i) const
00059     {
00060       EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
00061       return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
00062     }
00063 
00064   protected:
00065     typename Lhs::Nested m_A;
00066     typename Rhs::Nested m_B;
00067 };
00068 
00081 template<typename Lhs, typename Rhs>
00082 class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> >
00083 {
00084   private:
00085     typedef KroneckerProductBase<KroneckerProduct> Base;
00086     using Base::m_A;
00087     using Base::m_B;
00088 
00089   public:
00091     KroneckerProduct(const Lhs& A, const Rhs& B)
00092       : Base(A, B)
00093     {}
00094 
00096     template<typename Dest> void evalTo(Dest& dst) const;
00097 };
00098 
00114 template<typename Lhs, typename Rhs>
00115 class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> >
00116 {
00117   private:
00118     typedef KroneckerProductBase<KroneckerProductSparse> Base;
00119     using Base::m_A;
00120     using Base::m_B;
00121 
00122   public:
00124     KroneckerProductSparse(const Lhs& A, const Rhs& B)
00125       : Base(A, B)
00126     {}
00127 
00129     template<typename Dest> void evalTo(Dest& dst) const;
00130 };
00131 
00132 template<typename Lhs, typename Rhs>
00133 template<typename Dest>
00134 void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
00135 {
00136   const int BlockRows = Rhs::RowsAtCompileTime,
00137             BlockCols = Rhs::ColsAtCompileTime;
00138   const Index Br = m_B.rows(),
00139               Bc = m_B.cols();
00140   for (Index i=0; i < m_A.rows(); ++i)
00141     for (Index j=0; j < m_A.cols(); ++j)
00142       Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B;
00143 }
00144 
00145 template<typename Lhs, typename Rhs>
00146 template<typename Dest>
00147 void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
00148 {
00149   Index Br = m_B.rows(), Bc = m_B.cols();
00150   dst.resize(this->rows(), this->cols());
00151   dst.resizeNonZeros(0);
00152   
00153   // 1 - evaluate the operands if needed:
00154   typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1;
00155   typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned;
00156   const Lhs1 lhs1(m_A);
00157   typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1;
00158   typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned;
00159   const Rhs1 rhs1(m_B);
00160     
00161   // 2 - construct respective iterators
00162   typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator;
00163   typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator;
00164   
00165   // compute number of non-zeros per innervectors of dst
00166   {
00167     // TODO VectorXi is not necessarily big enough!
00168     VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
00169     for (Index kA=0; kA < m_A.outerSize(); ++kA)
00170       for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
00171         nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
00172       
00173     VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
00174     for (Index kB=0; kB < m_B.outerSize(); ++kB)
00175       for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
00176         nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
00177     
00178     Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose();
00179     dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
00180   }
00181 
00182   for (Index kA=0; kA < m_A.outerSize(); ++kA)
00183   {
00184     for (Index kB=0; kB < m_B.outerSize(); ++kB)
00185     {
00186       for (LhsInnerIterator itA(lhs1,kA); itA; ++itA)
00187       {
00188         for (RhsInnerIterator itB(rhs1,kB); itB; ++itB)
00189         {
00190           Index i = itA.row() * Br + itB.row(),
00191                 j = itA.col() * Bc + itB.col();
00192           dst.insert(i,j) = itA.value() * itB.value();
00193         }
00194       }
00195     }
00196   }
00197 }
00198 
00199 namespace internal {
00200 
00201 template<typename _Lhs, typename _Rhs>
00202 struct traits<KroneckerProduct<_Lhs,_Rhs> >
00203 {
00204   typedef typename remove_all<_Lhs>::type Lhs;
00205   typedef typename remove_all<_Rhs>::type Rhs;
00206   typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
00207   typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
00208 
00209   enum {
00210     Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
00211     Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
00212     MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
00213     MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret
00214   };
00215 
00216   typedef Matrix<Scalar,Rows,Cols> ReturnType;
00217 };
00218 
00219 template<typename _Lhs, typename _Rhs>
00220 struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
00221 {
00222   typedef MatrixXpr XprKind;
00223   typedef typename remove_all<_Lhs>::type Lhs;
00224   typedef typename remove_all<_Rhs>::type Rhs;
00225   typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
00226   typedef typename cwise_promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind, scalar_product_op<typename Lhs::Scalar, typename Rhs::Scalar> >::ret StorageKind;
00227   typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex;
00228 
00229   enum {
00230     LhsFlags = Lhs::Flags,
00231     RhsFlags = Rhs::Flags,
00232 
00233     RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
00234     ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
00235     MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
00236     MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
00237 
00238     EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
00239     RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
00240 
00241     Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
00242           | EvalBeforeNestingBit,
00243     CoeffReadCost = HugeCost
00244   };
00245 
00246   typedef SparseMatrix<Scalar, 0, StorageIndex> ReturnType;
00247 };
00248 
00249 } // end namespace internal
00250 
00270 template<typename A, typename B>
00271 KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<B>& b)
00272 {
00273   return KroneckerProduct<A, B>(a.derived(), b.derived());
00274 }
00275 
00297 template<typename A, typename B>
00298 KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenBase<B>& b)
00299 {
00300   return KroneckerProductSparse<A,B>(a.derived(), b.derived());
00301 }
00302 
00303 } // end namespace Eigen
00304 
00305 #endif // KRONECKER_TENSOR_PRODUCT_H
 All Classes Functions Variables Typedefs Enumerator