// Eigen-unsupported 3.3.3
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de> 00005 // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de> 00006 // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net> 00007 // 00008 // This Source Code Form is subject to the terms of the Mozilla 00009 // Public License v. 2.0. If a copy of the MPL was not distributed 00010 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00011 00012 #ifndef KRONECKER_TENSOR_PRODUCT_H 00013 #define KRONECKER_TENSOR_PRODUCT_H 00014 00015 namespace Eigen { 00016 00024 template<typename Derived> 00025 class KroneckerProductBase : public ReturnByValue<Derived> 00026 { 00027 private: 00028 typedef typename internal::traits<Derived> Traits; 00029 typedef typename Traits::Scalar Scalar; 00030 00031 protected: 00032 typedef typename Traits::Lhs Lhs; 00033 typedef typename Traits::Rhs Rhs; 00034 00035 public: 00037 KroneckerProductBase(const Lhs& A, const Rhs& B) 00038 : m_A(A), m_B(B) 00039 {} 00040 00041 inline Index rows() const { return m_A.rows() * m_B.rows(); } 00042 inline Index cols() const { return m_A.cols() * m_B.cols(); } 00043 00048 Scalar coeff(Index row, Index col) const 00049 { 00050 return m_A.coeff(row / m_B.rows(), col / m_B.cols()) * 00051 m_B.coeff(row % m_B.rows(), col % m_B.cols()); 00052 } 00053 00058 Scalar coeff(Index i) const 00059 { 00060 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived); 00061 return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size()); 00062 } 00063 00064 protected: 00065 typename Lhs::Nested m_A; 00066 typename Rhs::Nested m_B; 00067 }; 00068 00081 template<typename Lhs, typename Rhs> 00082 class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs,Rhs> > 00083 { 00084 private: 00085 typedef KroneckerProductBase<KroneckerProduct> Base; 00086 using Base::m_A; 00087 using Base::m_B; 00088 00089 public: 00091 KroneckerProduct(const Lhs& 
A, const Rhs& B) 00092 : Base(A, B) 00093 {} 00094 00096 template<typename Dest> void evalTo(Dest& dst) const; 00097 }; 00098 00114 template<typename Lhs, typename Rhs> 00115 class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs,Rhs> > 00116 { 00117 private: 00118 typedef KroneckerProductBase<KroneckerProductSparse> Base; 00119 using Base::m_A; 00120 using Base::m_B; 00121 00122 public: 00124 KroneckerProductSparse(const Lhs& A, const Rhs& B) 00125 : Base(A, B) 00126 {} 00127 00129 template<typename Dest> void evalTo(Dest& dst) const; 00130 }; 00131 00132 template<typename Lhs, typename Rhs> 00133 template<typename Dest> 00134 void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const 00135 { 00136 const int BlockRows = Rhs::RowsAtCompileTime, 00137 BlockCols = Rhs::ColsAtCompileTime; 00138 const Index Br = m_B.rows(), 00139 Bc = m_B.cols(); 00140 for (Index i=0; i < m_A.rows(); ++i) 00141 for (Index j=0; j < m_A.cols(); ++j) 00142 Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B; 00143 } 00144 00145 template<typename Lhs, typename Rhs> 00146 template<typename Dest> 00147 void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const 00148 { 00149 Index Br = m_B.rows(), Bc = m_B.cols(); 00150 dst.resize(this->rows(), this->cols()); 00151 dst.resizeNonZeros(0); 00152 00153 // 1 - evaluate the operands if needed: 00154 typedef typename internal::nested_eval<Lhs,Dynamic>::type Lhs1; 00155 typedef typename internal::remove_all<Lhs1>::type Lhs1Cleaned; 00156 const Lhs1 lhs1(m_A); 00157 typedef typename internal::nested_eval<Rhs,Dynamic>::type Rhs1; 00158 typedef typename internal::remove_all<Rhs1>::type Rhs1Cleaned; 00159 const Rhs1 rhs1(m_B); 00160 00161 // 2 - construct respective iterators 00162 typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator; 00163 typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator; 00164 00165 // compute number of non-zeros per innervectors of dst 00166 { 00167 // TODO 
VectorXi is not necessarily big enough! 00168 VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols()); 00169 for (Index kA=0; kA < m_A.outerSize(); ++kA) 00170 for (LhsInnerIterator itA(lhs1,kA); itA; ++itA) 00171 nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++; 00172 00173 VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols()); 00174 for (Index kB=0; kB < m_B.outerSize(); ++kB) 00175 for (RhsInnerIterator itB(rhs1,kB); itB; ++itB) 00176 nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++; 00177 00178 Matrix<int,Dynamic,Dynamic,ColMajor> nnzAB = nnzB * nnzA.transpose(); 00179 dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size())); 00180 } 00181 00182 for (Index kA=0; kA < m_A.outerSize(); ++kA) 00183 { 00184 for (Index kB=0; kB < m_B.outerSize(); ++kB) 00185 { 00186 for (LhsInnerIterator itA(lhs1,kA); itA; ++itA) 00187 { 00188 for (RhsInnerIterator itB(rhs1,kB); itB; ++itB) 00189 { 00190 Index i = itA.row() * Br + itB.row(), 00191 j = itA.col() * Bc + itB.col(); 00192 dst.insert(i,j) = itA.value() * itB.value(); 00193 } 00194 } 00195 } 00196 } 00197 } 00198 00199 namespace internal { 00200 00201 template<typename _Lhs, typename _Rhs> 00202 struct traits<KroneckerProduct<_Lhs,_Rhs> > 00203 { 00204 typedef typename remove_all<_Lhs>::type Lhs; 00205 typedef typename remove_all<_Rhs>::type Rhs; 00206 typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; 00207 typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex; 00208 00209 enum { 00210 Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret, 00211 Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret, 00212 MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret, 00213 MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, 
traits<Rhs>::MaxColsAtCompileTime>::ret 00214 }; 00215 00216 typedef Matrix<Scalar,Rows,Cols> ReturnType; 00217 }; 00218 00219 template<typename _Lhs, typename _Rhs> 00220 struct traits<KroneckerProductSparse<_Lhs,_Rhs> > 00221 { 00222 typedef MatrixXpr XprKind; 00223 typedef typename remove_all<_Lhs>::type Lhs; 00224 typedef typename remove_all<_Rhs>::type Rhs; 00225 typedef typename ScalarBinaryOpTraits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; 00226 typedef typename cwise_promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind, scalar_product_op<typename Lhs::Scalar, typename Rhs::Scalar> >::ret StorageKind; 00227 typedef typename promote_index_type<typename Lhs::StorageIndex, typename Rhs::StorageIndex>::type StorageIndex; 00228 00229 enum { 00230 LhsFlags = Lhs::Flags, 00231 RhsFlags = Rhs::Flags, 00232 00233 RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret, 00234 ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret, 00235 MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret, 00236 MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret, 00237 00238 EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit), 00239 RemovedBits = ~(EvalToRowMajor ? 
0 : RowMajorBit), 00240 00241 Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) 00242 | EvalBeforeNestingBit, 00243 CoeffReadCost = HugeCost 00244 }; 00245 00246 typedef SparseMatrix<Scalar, 0, StorageIndex> ReturnType; 00247 }; 00248 00249 } // end namespace internal 00250 00270 template<typename A, typename B> 00271 KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<B>& b) 00272 { 00273 return KroneckerProduct<A, B>(a.derived(), b.derived()); 00274 } 00275 00297 template<typename A, typename B> 00298 KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenBase<B>& b) 00299 { 00300 return KroneckerProductSparse<A,B>(a.derived(), b.derived()); 00301 } 00302 00303 } // end namespace Eigen 00304 00305 #endif // KRONECKER_TENSOR_PRODUCT_H