// Eigen 3.3.3
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 00011 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H 00012 00013 namespace Eigen { 00014 00015 namespace internal { 00016 00017 00018 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major 00019 template<typename Lhs, typename Rhs, typename ResultType> 00020 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance) 00021 { 00022 // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); 00023 00024 typedef typename remove_all<Lhs>::type::Scalar Scalar; 00025 typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex; 00026 00027 // make sure to call innerSize/outerSize since we fake the storage order. 
00028 Index rows = lhs.innerSize(); 00029 Index cols = rhs.outerSize(); 00030 //Index size = lhs.outerSize(); 00031 eigen_assert(lhs.outerSize() == rhs.innerSize()); 00032 00033 // allocate a temporary buffer 00034 AmbiVector<Scalar,StorageIndex> tempVector(rows); 00035 00036 // mimics a resizeByInnerOuter: 00037 if(ResultType::IsRowMajor) 00038 res.resize(cols, rows); 00039 else 00040 res.resize(rows, cols); 00041 00042 evaluator<Lhs> lhsEval(lhs); 00043 evaluator<Rhs> rhsEval(rhs); 00044 00045 // estimate the number of non zero entries 00046 // given a rhs column containing Y non zeros, we assume that the respective Y columns 00047 // of the lhs differs in average of one non zeros, thus the number of non zeros for 00048 // the product of a rhs column with the lhs is X+Y where X is the average number of non zero 00049 // per column of the lhs. 00050 // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) 00051 Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate(); 00052 00053 res.reserve(estimated_nnz_prod); 00054 double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols())); 00055 for (Index j=0; j<cols; ++j) 00056 { 00057 // FIXME: 00058 //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); 00059 // let's do a more accurate determination of the nnz ratio for the current column j of res 00060 tempVector.init(ratioColRes); 00061 tempVector.setZero(); 00062 for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) 00063 { 00064 // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) 00065 tempVector.restart(); 00066 Scalar x = rhsIt.value(); 00067 for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) 00068 { 00069 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; 00070 } 00071 } 00072 res.startVec(j); 00073 for (typename 
AmbiVector<Scalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it) 00074 res.insertBackByOuterInner(j,it.index()) = it.value(); 00075 } 00076 res.finalize(); 00077 } 00078 00079 template<typename Lhs, typename Rhs, typename ResultType, 00080 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, 00081 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, 00082 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> 00083 struct sparse_sparse_product_with_pruning_selector; 00084 00085 template<typename Lhs, typename Rhs, typename ResultType> 00086 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> 00087 { 00088 typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar; 00089 typedef typename ResultType::RealScalar RealScalar; 00090 00091 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00092 { 00093 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 00094 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); 00095 res.swap(_res); 00096 } 00097 }; 00098 00099 template<typename Lhs, typename Rhs, typename ResultType> 00100 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> 00101 { 00102 typedef typename ResultType::RealScalar RealScalar; 00103 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00104 { 00105 // we need a col-major matrix to hold the result 00106 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType; 00107 SparseTemporaryType _res(res.rows(), res.cols()); 00108 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); 00109 res = _res; 00110 } 00111 }; 00112 00113 template<typename Lhs, typename Rhs, typename ResultType> 00114 struct 
sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> 00115 { 00116 typedef typename ResultType::RealScalar RealScalar; 00117 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00118 { 00119 // let's transpose the product to get a column x column product 00120 typename remove_all<ResultType>::type _res(res.rows(), res.cols()); 00121 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); 00122 res.swap(_res); 00123 } 00124 }; 00125 00126 template<typename Lhs, typename Rhs, typename ResultType> 00127 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> 00128 { 00129 typedef typename ResultType::RealScalar RealScalar; 00130 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00131 { 00132 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; 00133 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; 00134 ColMajorMatrixLhs colLhs(lhs); 00135 ColMajorMatrixRhs colRhs(rhs); 00136 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance); 00137 00138 // let's transpose the product to get a column x column product 00139 // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; 00140 // SparseTemporaryType _res(res.cols(), res.rows()); 00141 // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); 00142 // res = _res.transpose(); 00143 } 00144 }; 00145 00146 template<typename Lhs, typename Rhs, typename ResultType> 00147 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor> 00148 { 00149 typedef typename ResultType::RealScalar RealScalar; 00150 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& 
res, const RealScalar& tolerance) 00151 { 00152 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs; 00153 RowMajorMatrixLhs rowLhs(lhs); 00154 sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance); 00155 } 00156 }; 00157 00158 template<typename Lhs, typename Rhs, typename ResultType> 00159 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor> 00160 { 00161 typedef typename ResultType::RealScalar RealScalar; 00162 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00163 { 00164 typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs; 00165 RowMajorMatrixRhs rowRhs(rhs); 00166 sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance); 00167 } 00168 }; 00169 00170 template<typename Lhs, typename Rhs, typename ResultType> 00171 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor> 00172 { 00173 typedef typename ResultType::RealScalar RealScalar; 00174 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00175 { 00176 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs; 00177 ColMajorMatrixRhs colRhs(rhs); 00178 internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance); 00179 } 00180 }; 00181 00182 template<typename Lhs, typename Rhs, typename ResultType> 00183 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor> 00184 { 00185 typedef typename ResultType::RealScalar RealScalar; 00186 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) 00187 { 00188 typedef 
SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs; 00189 ColMajorMatrixLhs colLhs(lhs); 00190 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance); 00191 } 00192 }; 00193 00194 } // end namespace internal 00195 00196 } // end namespace Eigen 00197 00198 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H