// Eigen-unsupported 3.3.3
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {

/** \class TensorContraction
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor contraction class.
  */
namespace internal {

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
                               typename remove_const<typename RhsXprType::Scalar>::type>::ResScalar Scalar;

  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;

  // From NumDims below: each contraction pair removes one dimension from each side.
  static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
  static const int Layout = traits<LhsXprType>::Layout;

  enum {
    Flags = 0
  };
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};

template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef Device_ Device;

  // From NumDims below.
  static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
};

}  // end namespace internal

template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
    typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
                                           typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
        const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}

    EIGEN_DEVICE_FUNC
    const Indices& indices() const { return m_indices; }

    /** \returns the nested expressions */
    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Indices m_indices;
};
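
// A minimal usage sketch (not part of this header; names and sizes are
// illustrative). Contracting two rank-2 tensors over one axis pair is an
// ordinary matrix product. Requires #include <unsupported/Eigen/CXX11/Tensor>.
//
//   Eigen::Tensor<float, 2> a(3, 4), b(4, 5);
//   a.setRandom();
//   b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   Eigen::Tensor<float, 2> c = a.contract(b, dims);  // c(i,j) = sum_k a(i,k) * b(k,j)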


template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    IsAligned = true,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = true
  };

  // Most of the code assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, we "cheat" by swapping the LHS and RHS: if we want
  // to compute A * B = C, where A is LHS and B is RHS, the code will pretend
  // B is LHS and A is RHS.
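  // (Why swapping works, as a sketch: a RowMajor buffer read as if it were
  // ColMajor yields the transposed matrix, so computing C = A * B on RowMajor
  // data is the same as computing C^T = B^T * A^T on ColMajor data. Swapping
  // the operands therefore lets the ColMajor kernels write the RowMajor
  // result directly.)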
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);

    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, we keep using the existing dimensions.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      // We keep the pairs of contracting indices.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, we need to reverse the existing dimensions.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      // We need to flip all the pairs of contracting indices as well as
      // reversing the dimensions. Note that first/second are also swapped,
      // since the LHS and RHS roles were swapped above.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first;
      }
    }
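    // Worked example (illustrative): contracting axis 2 of a rank-3 RowMajor
    // LHS with axis 0 of a rank-2 RowMajor RHS, i.e. op.indices() = {(2, 0)}.
    // After the swap, LDims = 2 (the original RHS) and RDims = 3 (the original
    // LHS), so the pair becomes (LDims - 1 - 0, RDims - 1 - 2) = (1, 0):
    // axis 1 of the reversed RHS is the original axis 0, and axis 0 of the
    // reversed LHS is the original axis 2, as required.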

    // Check for duplicate axes and make sure the first index in eval_op_indices
    // is increasing. Using O(n^2) sorting is OK since ContractDims is small.
    for (int i = 0; i < ContractDims; i++) {
      for (int j = i + 1; j < ContractDims; j++) {
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          numext::swap(eval_op_indices[j], eval_op_indices[i]);
        }
      }
    }

    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims - 1; ++i) {
      lhs_strides[i + 1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims - 1; ++i) {
      rhs_strides[i + 1] = rhs_strides[i] * eval_right_dims[i];
    }

    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
    if (m_k_strides.size() > 0) m_k_strides[0] = 1;

    m_i_size = 1;
    m_j_size = 1;
    m_k_size = 1;

    // To compute the output dimensions, we simply concatenate the
    // non-contracting dimensions of the left and then the right tensor.
    // Additionally, we compute the strides corresponding to the left
    // non-contracting dimensions and right non-contracting dimensions.
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // Find out whether we are contracting on index i of the left tensor.
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // Add the dimension size to the output dimensions.
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx + 1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx + 1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      bool contracting = false;
      // Find out whether we are contracting on index i of the right tensor.
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx + 1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx + 1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }
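    // Worked example (illustrative): a ColMajor rank-3 left tensor with
    // dimensions (2, 3, 4), contracting on axis 1. Then lhs_strides = {1, 2, 6},
    // the non-contracting axes are 0 and 2, and the loop above produces
    // m_left_nocontract_strides = {1, 6}, m_i_strides = {1, 2} and
    // m_i_size = 2 * 4 = 8. Because the second non-contracting axis (tensor
    // axis 2) does not coincide with its matrix position (1),
    // m_lhs_inner_dim_contiguous is set to false.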

    // Now compute the strides corresponding to the contracting dimensions. We
    // assumed above that non-contracting axes are represented in the same order
    // in the matrix as they are in the tensor. This is not the case for
    // contracting axes. As the contracting axes must be of the same size in
    // each tensor, we'll only look at the first tensor here.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i + 1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i + 1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i - 1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }

    // If the layout is RowMajor, we need to reverse m_dimensions.
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      // Evaluate directly into the caller-provided buffer; returning false
      // tells the caller no intermediate result buffer was allocated.
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar*>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

  // Dispatch to the evalProduct() instantiation matching the layout of the
  // operands: the three booleans select one of 2^3 = 8 specializations.
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    }
    else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      }
      else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        }
        else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

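  // Matrix-vector path: used when the right-hand side has no remaining
  // non-contracted extent (m_j_size == 1), so the contraction reduces to a
  // matrix * vector product of size (m_i_size x m_k_size) * (m_k_size x 1).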
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
    const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, lhs_alignment> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, rhs_alignment> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // Zero out the result buffer (which must be of size at least
    // rows * sizeof(Scalar)).
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, false, RhsScalar, RhsMapper, false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }

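  // General matrix-matrix path: maps both tensors to matrices through
  // TensorContractionInputMapper and runs Eigen's blocked GEBP kernel on
  // them, following the Goto/van de Geijn blocking scheme (pack an mc x kc
  // panel of the LHS and a kc x nc panel of the RHS, then multiply block by
  // block so that the working set stays in cache).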
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // Columns in the left side, rows in the right side.
    const Index k = this->m_k_size;

    // Rows in the left side.
    const Index m = this->m_i_size;

    // Columns in the right side.
    const Index n = this->m_j_size;

    // Zero out the result buffer (which must be of size at least
    // m * n * sizeof(Scalar)).
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    // Define mr, nr, and all of the data mapper types.
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // Declare the GEBP packing and kernel structs.
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;

    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;

    // Initialize the data mappers.
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Sizes of the blocks to load in cache. See the Goto paper for details.
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar*>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar*>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));

    for (Index i2 = 0; i2 < m; i2 += mc) {
      const Index actual_mc = numext::mini(i2 + mc, m) - i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // Make sure we don't overshoot the right edge of the left matrix,
        // then pack a vertical panel.
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        // Series of horizontal blocks.
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // Make sure we don't overshoot the right edge of the right matrix,
          // then pack the block.
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // Call gebp (the matrix kernel). The parameters here are copied
          // from Eigen's GEMM implementation.
          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; }

 protected:
  // Prevent assignment.
  TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);
  Dimensions m_dimensions;

  contract_t m_k_strides;
  contract_t m_left_contracting_strides;
  contract_t m_right_contracting_strides;

  bool m_lhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_reordered;

  left_nocontract_t m_i_strides;
  right_nocontract_t m_j_strides;
  left_nocontract_t m_left_nocontract_strides;
  right_nocontract_t m_right_nocontract_strides;

  Index m_i_size;
  Index m_j_size;
  Index m_k_size;

  TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
  TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
  const Device& m_device;
  Scalar* m_result;
};


// Evaluator for the default device.
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout
  };

  // Most of the code assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, we "cheat" by swapping the LHS and RHS: if we want
  // to compute A * B = C, where A is LHS and B is RHS, the code will pretend
  // B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  // Could we use NumDimensions here?
  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }

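  // Chooses between the matrix-vector and matrix-matrix kernels: when the
  // right-hand side has no remaining non-contracted extent (m_j_size == 1),
  // the contraction degenerates to a matrix * vector product.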
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H