![]() |
Eigen
3.3.3
|
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_SPARSEDENSEPRODUCT_H
#define EIGEN_SPARSEDENSEPRODUCT_H

namespace Eigen { 

namespace internal {

// An outer product mixing a sparse and a dense operand (in either order)
// is given the Sparse storage kind.
template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; };
template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; };

// Kernel computing res += alpha * lhs * rhs for a sparse lhs and dense rhs/res.
// The implementation is chosen through two defaulted parameters:
//  - LhsStorageOrder: RowMajor when SparseLhsType::Flags has RowMajorBit, else ColMajor;
//  - ColPerCol: true when the dense rhs is not row-major, or has exactly one
//    column at compile time. The ColPerCol==true kernels walk rhs one column at a time.
// The primary template is only declared; every usable variant is a specialization.
// NOTE(review): the (RowMajor,true), (RowMajor,false) and (ColMajor,false)
// specializations require AlphaType == DenseResType::Scalar; only (ColMajor,true)
// accepts an arbitrary AlphaType. Other combinations hit the incomplete primary template.
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
         typename AlphaType,
         int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
         bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
struct sparse_time_dense_product_impl;

// Row-major sparse lhs, rhs processed column by column.
// For every rhs column c and every lhs row i:
//   res(i,c) += alpha * sum_k lhs(i,k) * rhs(k,c)
// where k runs over the stored entries of lhs row i. Each row only writes its own
// res(i,c), so rows can be split across OpenMP threads when OpenMP is enabled.
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true>
{
  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
  typedef typename internal::remove_all<DenseResType>::type Res;
  typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
  typedef evaluator<Lhs> LhsEval;

  // Adds alpha * lhs * rhs into res; res is not cleared first.
  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
  {
    LhsEval lhsEval(lhs);
    
    // Number of outer vectors of lhs, i.e. the rows here since lhs is row-major.
    Index n = lhs.outerSize();
#ifdef EIGEN_HAS_OPENMP
    // Set up Eigen's parallel state before querying the thread count.
    // See Eigen::initParallel / Eigen::nbThreads.
    Eigen::initParallel();
    Index threads = Eigen::nbThreads();
#endif
    
    for(Index c=0; c<rhs.cols(); ++c)
    {
#ifdef EIGEN_HAS_OPENMP
      // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems.
      // It basically represents the minimal amount of work to be done to be worth it.
      if(threads>1 && lhsEval.nonZerosEstimate() > 20000)
      {
        // Dynamic scheduling with chunk size ceil(n / (4*threads)).
        // Iteration i only writes res(i,c), so threads never write the same coefficient.
        #pragma omp parallel for schedule(dynamic,(n+threads*4-1)/(threads*4)) num_threads(threads)
        for(Index i=0; i<n; ++i)
          processRow(lhsEval,rhs,res,alpha,i,c);
      }
      // This `else` binds to the braced sequential block after #endif.
      // Without OpenMP, that block is just an ordinary scope that always runs.
      else
#endif
      {
        for(Index i=0; i<n; ++i)
          processRow(lhsEval,rhs,res,alpha,i,c);
      }
    }
  }
  
  // Dot product of lhs row i with rhs column `col`, accumulated into a local.
  // alpha is applied once per row rather than once per stored entry,
  // and the result is added into res(i,col).
  static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha, Index i, Index col)
  {
    typename Res::Scalar tmp(0);
    for(LhsInnerIterator it(lhsEval,i); it ;++it)
      tmp += it.value() * rhs.coeff(it.index(),col);
    res.coeffRef(i,col) += alpha * tmp;
  }
  
};

// FIXME: what is the purpose of the following specialization? Is it for the BlockedSparse format?
// -> let's disable it for now as it is conflicting with generic scalar*matrix and matrix*scalar operators
// template<typename T1, typename T2/*, int _Options, typename _StrideType*/>
// struct ScalarBinaryOpTraits<T1, Ref<T2/*, _Options, _StrideType*/> >
// {
//   enum {
//     Defined = 1
//   };
//   typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType;
// };

// Column-major sparse lhs, rhs processed column by column.
// For every rhs column c, each lhs column j is scaled by alpha*rhs(j,c) and
// scattered into res(:,c):
//   res(i,c) += lhs(i,j) * (alpha * rhs(j,c))   for each stored lhs(i,j).
// This is the only kernel here that accepts an AlphaType different from
// Res::Scalar. The scaled rhs coefficient has the type given by
// ScalarBinaryOpTraits<AlphaType, Rhs::Scalar>::ReturnType.
// There is no OpenMP path. Different columns j may scatter into the same
// res(i,c), so a naive split over j would race.
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true>
{
  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
  typedef typename internal::remove_all<DenseResType>::type Res;
  typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;

  // Adds alpha * lhs * rhs into res; res is not cleared first.
  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
  {
    evaluator<Lhs> lhsEval(lhs);
    for(Index c=0; c<rhs.cols(); ++c)
    {
      for(Index j=0; j<lhs.outerSize(); ++j)
      {
//        typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
        // Scale once per lhs column, then reuse for every stored entry in that column.
        typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c));
        for(LhsInnerIterator it(lhsEval,j); it ;++it)
          res.coeffRef(it.index(),c) += it.value() * rhs_j;
      }
    }
  }
};

// Row-major sparse lhs with a row-major rhs that is not a single compile-time
// column (ColPerCol == false). Works on whole rows:
//   res.row(j) += (alpha * lhs(j,k)) * rhs.row(k)   for each stored lhs(j,k).
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false>
{
  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
  typedef typename internal::remove_all<DenseResType>::type Res;
  typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;

  // Adds alpha * lhs * rhs into res; res is not cleared first.
  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
  {
    evaluator<Lhs> lhsEval(lhs);
    for(Index j=0; j<lhs.outerSize(); ++j)
    {
      // Row view of the destination row j; updates through it write into res.
      typename Res::RowXpr res_j(res.row(j));
      for(LhsInnerIterator it(lhsEval,j); it ;++it)
        res_j += (alpha*it.value()) * rhs.row(it.index());
    }
  }
};

// Column-major sparse lhs with a row-major rhs that is not a single compile-time
// column (ColPerCol == false). For each lhs column j, rhs.row(j) is scaled and
// added into res.row(i) for every stored lhs(i,j):
//   res.row(i) += (alpha * lhs(i,j)) * rhs.row(j)
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false>
{
  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
  typedef typename internal::remove_all<DenseResType>::type Res;
  typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;

  // Adds alpha * lhs * rhs into res; res is not cleared first.
  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
  {
    evaluator<Lhs> lhsEval(lhs);
    for(Index j=0; j<lhs.outerSize(); ++j)
    {
      typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
      for(LhsInnerIterator it(lhsEval,j); it ;++it)
        res.row(it.index()) += (alpha*it.value()) * rhs_j;
    }
  }
};

// Entry point: res += alpha * lhs * rhs for a sparse lhs and dense rhs/res.
// The kernel is picked by the defaulted LhsStorageOrder / ColPerCol parameters
// of sparse_time_dense_product_impl.
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
{
  sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
}

} // end namespace internal

namespace internal {

// Sparse * dense product accumulated into a dense destination.
template<typename Lhs, typename Rhs, int ProductType>
struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> >
{
  typedef typename Product<Lhs,Rhs>::Scalar Scalar;
  
  // dst += alpha * lhs * rhs
  template<typename Dest>
  static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
  {
    // The second nested_eval argument depends on the operands' storage orders.
    // Presumably it is the expected number of reads per coefficient, which lets
    // nested_eval decide whether to evaluate into a temporary (see internal::nested_eval).
    typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested;
    typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==0) ? 1 : Dynamic>::type RhsNested;
    LhsNested lhsNested(lhs);
    RhsNested rhsNested(rhs);
    internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha);
  }
};

// Triangular views of sparse matrices reuse the plain sparse * dense implementation.
template<typename Lhs, typename Rhs, int ProductType>
struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType>
  : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
{};

// Dense * sparse product, computed by transposition:
//   dst^T += alpha * rhs^T * lhs^T
// This puts the sparse operand on the left, as sparse_time_dense_product expects.
template<typename Lhs, typename Rhs, int ProductType>
struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
  : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> >
{
  typedef typename Product<Lhs,Rhs>::Scalar Scalar;
  
  // dst += alpha * lhs * rhs
  template<typename Dst>
  static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
  {
    typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? Dynamic : 1>::type LhsNested;
    typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type RhsNested;
    LhsNested lhsNested(lhs);
    RhsNested rhsNested(rhs);
    
    // transpose everything
    Transpose<Dst> dstT(dst);
    internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha);
  }
};

// Dense * triangular-view-of-sparse reuses the plain dense * sparse implementation.
template<typename Lhs, typename Rhs, int ProductType>
struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType>
  : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
{};

// Evaluator for the outer product of a sparse and a dense operand.
// One operand (Lhs1) is iterated entry by entry. The other (ActualRhs) provides,
// for each outer index, a single scaling factor.
//  - NeedToTranspose == false: iterate LhsT, scale by RhsT; result flags are column-major.
//  - NeedToTranspose == true : iterate RhsT, scale by LhsT; result flags have RowMajorBit.
// Only iteration through InnerIterator is provided; there is no coeff() accessor.
template<typename LhsT, typename RhsT, bool NeedToTranspose>
struct sparse_dense_outer_product_evaluator
{
protected:
  typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1;
  typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs;
  typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType;
  
  // if the actual left-hand side is a dense vector,
  // then build a sparse-view so that we can seamlessly iterate over it.
  typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
            Lhs1, SparseView<Lhs1> >::type ActualLhs;
  // How the iterated operand is stored in m_lhs. A sparse Lhs1 is held by const
  // reference; a dense Lhs1 is held as a SparseView object by value.
  typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
            Lhs1 const&, SparseView<Lhs1> >::type LhsArg;
            
  typedef evaluator<ActualLhs> LhsEval;
  typedef evaluator<ActualRhs> RhsEval;
  typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator;
  typedef typename ProdXprType::Scalar Scalar;
  
public:
  enum {
    Flags = NeedToTranspose ? RowMajorBit : 0,
    // Reports a very high per-coefficient read cost.
    // This evaluator only supports iteration, not coeff().
    CoeffReadCost = HugeCost
  };
  
  // Iterates the entries of inner vector `outer` of the product. It walks the
  // stored entries of the iterated operand (outer index 0 of m_lhsXprImpl) and
  // multiplies each one by m_factor, the outer-th coefficient of the other operand.
  // NOTE(review): assumes Lhs1 is a vector whose entries all live in outer index 0.
  class InnerIterator : public LhsIterator
  {
  public:
    // m_empty must be initialized before m_factor. The Sparse overload of get()
    // may set m_empty = true, and members are initialized in declaration order.
    InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer)
      : LhsIterator(xprEval.m_lhsXprImpl, 0),
        m_outer(outer),
        m_empty(false),
        m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind() ))
    {}
    
    EIGEN_STRONG_INLINE Index outer() const { return m_outer; }
    // Map (outer, inner) back to (row, col) depending on the result orientation.
    EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); }
    EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; }

    EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; }
    // An empty factor, i.e. no usable sparse coefficient, makes the whole
    // inner vector appear empty.
    EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); }
    
  protected:
    // Factor lookup when the other operand is dense: plain coefficient read.
    Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const
    {
      return rhs.coeff(outer);
    }
    
    // Factor lookup when the other operand is sparse. Uses the first stored entry
    // of its outer-th inner vector, but only if it sits at inner index 0 and is
    // non-zero. Otherwise it flags the iterator as empty, so an explicitly stored
    // zero counts as empty too.
    // NOTE(review): presumably each outer index of that operand holds at most one
    // entry, at inner index 0 (a vector) — confirm.
    Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse())
    {
      typename RhsEval::InnerIterator it(rhs, outer);
      if (it && it.index()==0 && it.value()!=Scalar(0))
        return it.value();
      m_empty = true;
      return Scalar(0);
    }
    
    Index m_outer;
    bool m_empty;
    Scalar m_factor;
  };
  
  // Non-transposed case: arguments arrive as (LhsT, RhsT) == (Lhs1, ActualRhs).
  // m_lhs must be declared before m_lhsXprImpl, which is built from it.
  sparse_dense_outer_product_evaluator(const Lhs1 &lhs, const ActualRhs &rhs)
     : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
  {
    EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
  }
  
  // transpose case
  // Arguments arrive as (LhsT, RhsT) == (ActualRhs, Lhs1).
  sparse_dense_outer_product_evaluator(const ActualRhs &rhs, const Lhs1 &lhs)
     : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
  {
    EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
  }
    
protected:
  const LhsArg m_lhs;
  evaluator<ActualLhs> m_lhsXprImpl;
  evaluator<ActualRhs> m_rhsXprImpl;
};

// sparse * dense outer product
// The sparse lhs is iterated directly when it is column-major. When it is
// row-major (NeedToTranspose = Lhs::IsRowMajor), the dense rhs is iterated
// instead, through a SparseView, and the result is row-major.
// The constructor always passes (lhs, rhs). The base class resolves the
// argument order through its two constructor overloads.
template<typename Lhs, typename Rhs>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape>
  : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor>
{
  typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base;
  
  typedef Product<Lhs, Rhs> XprType;
  typedef typename XprType::PlainObject PlainObject;

  explicit product_evaluator(const XprType& xpr)
    : Base(xpr.lhs(), xpr.rhs())
  {}
  
};

// dense * sparse outer product
// The orientation follows the sparse rhs (NeedToTranspose = Rhs::IsRowMajor).
// When true, the sparse rhs is iterated and the dense lhs supplies the factors.
template<typename Lhs, typename Rhs>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape>
  : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor>
{
  typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base;
  
  typedef Product<Lhs, Rhs> XprType;
  typedef typename XprType::PlainObject PlainObject;

  explicit product_evaluator(const XprType& xpr)
    : Base(xpr.lhs(), xpr.rhs())
  {}
  
};

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_SPARSEDENSEPRODUCT_H