Eigen  3.3.3
SparseSparseProductWithPruning.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00011 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
00012 
00013 namespace Eigen { 
00014 
00015 namespace internal {
00016 
00017 
00018 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
00019 template<typename Lhs, typename Rhs, typename ResultType>
00020 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
00021 {
00022   // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
00023 
00024   typedef typename remove_all<Lhs>::type::Scalar Scalar;
00025   typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex;
00026 
00027   // make sure to call innerSize/outerSize since we fake the storage order.
00028   Index rows = lhs.innerSize();
00029   Index cols = rhs.outerSize();
00030   //Index size = lhs.outerSize();
00031   eigen_assert(lhs.outerSize() == rhs.innerSize());
00032 
00033   // allocate a temporary buffer
00034   AmbiVector<Scalar,StorageIndex> tempVector(rows);
00035 
00036   // mimics a resizeByInnerOuter:
00037   if(ResultType::IsRowMajor)
00038     res.resize(cols, rows);
00039   else
00040     res.resize(rows, cols);
00041   
00042   evaluator<Lhs> lhsEval(lhs);
00043   evaluator<Rhs> rhsEval(rhs);
00044   
00045   // estimate the number of non zero entries
00046   // given a rhs column containing Y non zeros, we assume that the respective Y columns
00047   // of the lhs differs in average of one non zeros, thus the number of non zeros for
00048   // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
00049   // per column of the lhs.
00050   // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
00051   Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
00052 
00053   res.reserve(estimated_nnz_prod);
00054   double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
00055   for (Index j=0; j<cols; ++j)
00056   {
00057     // FIXME:
00058     //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
00059     // let's do a more accurate determination of the nnz ratio for the current column j of res
00060     tempVector.init(ratioColRes);
00061     tempVector.setZero();
00062     for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
00063     {
00064       // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
00065       tempVector.restart();
00066       Scalar x = rhsIt.value();
00067       for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
00068       {
00069         tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
00070       }
00071     }
00072     res.startVec(j);
00073     for (typename AmbiVector<Scalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
00074       res.insertBackByOuterInner(j,it.index()) = it.value();
00075   }
00076   res.finalize();
00077 }
00078 
00079 template<typename Lhs, typename Rhs, typename ResultType,
00080   int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
00081   int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
00082   int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
00083 struct sparse_sparse_product_with_pruning_selector;
00084 
00085 template<typename Lhs, typename Rhs, typename ResultType>
00086 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
00087 {
00088   typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
00089   typedef typename ResultType::RealScalar RealScalar;
00090 
00091   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
00092   {
00093     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00094     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
00095     res.swap(_res);
00096   }
00097 };
00098 
00099 template<typename Lhs, typename Rhs, typename ResultType>
00100 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
00101 {
00102   typedef typename ResultType::RealScalar RealScalar;
00103   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
00104   {
00105     // we need a col-major matrix to hold the result
00106     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType;
00107     SparseTemporaryType _res(res.rows(), res.cols());
00108     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
00109     res = _res;
00110   }
00111 };
00112 
00113 template<typename Lhs, typename Rhs, typename ResultType>
00114 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
00115 {
00116   typedef typename ResultType::RealScalar RealScalar;
00117   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
00118   {
00119     // let's transpose the product to get a column x column product
00120     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
00121     internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
00122     res.swap(_res);
00123   }
00124 };
00125 
00126 template<typename Lhs, typename Rhs, typename ResultType>
00127 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
00128 {
00129   typedef typename ResultType::RealScalar RealScalar;
00130   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
00131   {
00132     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
00133     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
00134     ColMajorMatrixLhs colLhs(lhs);
00135     ColMajorMatrixRhs colRhs(rhs);
00136     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
00137 
00138     // let's transpose the product to get a column x column product
00139 //     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
00140 //     SparseTemporaryType _res(res.cols(), res.rows());
00141 //     sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
00142 //     res = _res.transpose();
00143   }
00144 };
00145 
00146 template<typename Lhs, typename Rhs, typename ResultType>
00147 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
00148 {
00149   typedef typename ResultType::RealScalar RealScalar;
00150   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
00151   {
00152     typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
00153     RowMajorMatrixLhs rowLhs(lhs);
00154     sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
00155   }
00156 };
00157 
00158 template<typename Lhs, typename Rhs, typename ResultType>
00159 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
00160 {
00161   typedef typename ResultType::RealScalar RealScalar;
00162   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
00163   {
00164     typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
00165     RowMajorMatrixRhs rowRhs(rhs);
00166     sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
00167   }
00168 };
00169 
00170 template<typename Lhs, typename Rhs, typename ResultType>
00171 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
00172 {
00173   typedef typename ResultType::RealScalar RealScalar;
00174   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
00175   {
00176     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
00177     ColMajorMatrixRhs colRhs(rhs);
00178     internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
00179   }
00180 };
00181 
00182 template<typename Lhs, typename Rhs, typename ResultType>
00183 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
00184 {
00185   typedef typename ResultType::RealScalar RealScalar;
00186   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
00187   {
00188     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
00189     ColMajorMatrixLhs colLhs(lhs);
00190     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
00191   }
00192 };
00193 
00194 } // end namespace internal
00195 
00196 } // end namespace Eigen
00197 
00198 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends