TensorContractionMapper.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
00011 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
00012 
00013 namespace Eigen {
00014 
00015 namespace internal {
00016 
// Tags identifying which operand of a contraction a mapper reads from.
// Enumerator values are implicit: Rhs is 0 and Lhs is 1.
enum {
  Rhs,
  Lhs
};
00021 
00022 /*
00023  * Implementation of the Eigen blas_data_mapper class for tensors.
00024  */
00025 
00026 template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
00027   enum {
00028     DirectOffsets = false
00029   };
00030 
00031   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { }
00032 
00033   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
00034     eigen_assert(false && "unsupported");
00035   }
00036 
00037   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }
00038 
00039  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
00040  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
00041   {
00042     return m_tensor.template packet<LoadMode>(index);
00043   }
00044 
00045 
00046  private:
00047   const Tensor m_tensor;
00048 };
00049 
00050 template <typename Tensor> struct CoeffLoader<Tensor, true> {
00051   enum {
00052     DirectOffsets = true
00053   };
00054 
00055   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}
00056 
00057   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
00058     m_data += offset;
00059   }
00060 
00061   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }
00062 
00063  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
00064  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
00065   {
00066     return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
00067   }
00068  private:
00069   typedef typename Tensor::Scalar Scalar;
00070   const Scalar* m_data;
00071 };
00072 
/*
 * SimpleTensorContractionMapper exposes one operand of a tensor contraction
 * as a 2D (row, col) matrix view. It converts matrix coordinates into linear
 * indices of the underlying tensor and reads coefficients through a
 * CoeffLoader.
 *
 * For side == Lhs, the row coordinate indexes the non-contracting dimensions
 * and the column coordinate indexes the contracting ones. For side == Rhs the
 * roles are swapped (see computeIndex).
 *
 * Template parameters, as used by this class:
 *   side                 - Lhs or Rhs.
 *   nocontract_t         - array type of strides for non-contracting dims.
 *   contract_t           - array type of strides for contracting dims.
 *   inner_dim_contiguous - when true, the index-0 stride of the side's
 *                          "inner" set (nocontract for Lhs, contract for
 *                          Rhs) is asserted to be 1 and the multiply is
 *                          skipped.
 *   Alignment            - only consulted by firstAligned().
 *   packet_size          - not used in this class; consumed by derived
 *                          mappers.
 */
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous, int Alignment>
class SimpleTensorContractionMapper {
  public:
  // m_ij_strides / m_k_strides are used as divisors to split a matrix
  // coordinate into per-dimension digits. m_nocontract_strides /
  // m_contract_strides weight those digits into a linear tensor index.
  EIGEN_DEVICE_FUNC
  SimpleTensorContractionMapper(const Tensor& tensor,
                                const nocontract_t& nocontract_strides,
                                const nocontract_t& ij_strides,
                                const contract_t& contract_strides,
                                const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }

  // True when the coefficient loader can fold offsets into its data pointer
  // (i.e. the evaluator has raw access).
  enum {
    DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets
  };

  // Forwards to the coefficient loader. Without raw access the loader only
  // calls eigen_assert(false).
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  // No-op: this mapper does not issue prefetches.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }

  // Reads the coefficient at (row, 0).
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  // Reads the coefficient at matrix position (row, col).
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

  // Converts matrix position (row, col) into a linear tensor index.
  // The non-contracting coordinate (row for Lhs, col for Rhs) is split into
  // digits from the highest array entry down to entry 1 using m_ij_strides,
  // and each digit is weighted by m_nocontract_strides. The remainder is
  // handled by entry 0. The contracting coordinate is decomposed the same
  // way with m_k_strides / m_contract_strides.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    // Only add the non-contracting remainder when the tensor has more
    // dimensions than are being contracted.
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    Index contract_val = left ? col : row;
    if(array_size<contract_t>::value > 0) {
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

  // Computes, in one pass, the linear indices of (row, col) and of
  // (row + distance, col). Returns them as IndexPair(first, second).
  // NOTE(review): unlike computeIndex, the non-contracting digit loop is
  // skipped entirely when the tensor has no non-contracting dimensions. The
  // two functions agree as long as array_size<nocontract_t> <= 1 in that
  // case; confirm.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value> 0) {
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  // Returns 0 (aligned from the start) when Alignment == Aligned, side is Lhs
  // and inner_dim_contiguous holds; otherwise returns `size`.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride, i.e. when
    // we're dealing with the lhs with inner_dim_contiguous. This is because
    // the matrix-vector product relies on the stride when dealing with
    // aligned inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  // Column stride of the matrix view: m_contract_strides[0] for a contiguous
  // Lhs with at least one contracting dimension, otherwise 1.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};
00215 
00216 
00217 template<typename Scalar, typename Index, int side,
00218          typename Tensor,
00219          typename nocontract_t, typename contract_t,
00220          int packet_size, bool inner_dim_contiguous,
00221          bool inner_dim_reordered, int Alignment>
00222 class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment>
00223 {
00224  public:
00225   typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper;
00226 
00227   EIGEN_DEVICE_FUNC
00228   BaseTensorContractionMapper(const Tensor& tensor,
00229                               const nocontract_t& nocontract_strides,
00230                               const nocontract_t& ij_strides,
00231                               const contract_t& contract_strides,
00232                               const contract_t& k_strides) :
00233   ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
00234 
00235   typedef typename Tensor::PacketReturnType Packet;
00236   typedef typename unpacket_traits<Packet>::half HalfPacket;
00237 
00238   template <int AlignmentType>
00239   EIGEN_DEVICE_FUNC
00240   EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
00241     // whole method makes column major assumption
00242 
00243     // don't need to add offsets for now (because operator handles that)
00244     // current code assumes packet size must be a multiple of 2
00245     EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
00246 
00247     if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
00248       const Index index = this->computeIndex(i, j);
00249       eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
00250       return this->m_tensor.template packet<AlignmentType>(index);
00251     }
00252 
00253     const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
00254     const Index first = indexPair.first;
00255     const Index last = indexPair.second;
00256 
00257     // We can always do optimized packet reads from left hand side right now, because
00258     // the vertical matrix dimension on the left hand side is never contracting.
00259     // On the right hand side we need to check if the contracting dimensions may have
00260     // been shuffled first.
00261     if (Tensor::PacketAccess &&
00262         (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
00263         (last - first) == (packet_size - 1)) {
00264 
00265       return this->m_tensor.template packet<AlignmentType>(first);
00266     }
00267 
00268     EIGEN_ALIGN_MAX Scalar data[packet_size];
00269 
00270     data[0] = this->m_tensor.coeff(first);
00271     for (Index k = 1; k < packet_size - 1; k += 2) {
00272       const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
00273       data[k] = this->m_tensor.coeff(internal_pair.first);
00274       data[k + 1] = this->m_tensor.coeff(internal_pair.second);
00275     }
00276     data[packet_size - 1] = this->m_tensor.coeff(last);
00277 
00278     return pload<Packet>(data);
00279   }
00280 
00281   template <int AlignmentType>
00282   EIGEN_DEVICE_FUNC
00283   EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
00284     // whole method makes column major assumption
00285 
00286     // don't need to add offsets for now (because operator handles that)
00287     const Index half_packet_size = unpacket_traits<HalfPacket>::size;
00288     if (half_packet_size == packet_size) {
00289       return loadPacket<AlignmentType>(i, j);
00290     }
00291     EIGEN_ALIGN_MAX Scalar data[half_packet_size];
00292     for (Index k = 0; k < half_packet_size; k++) {
00293       data[k] = operator()(i + k, j);
00294     }
00295     return pload<HalfPacket>(data);
00296   }
00297 };
00298 
00299 
00300 template<typename Scalar, typename Index, int side,
00301          typename Tensor,
00302          typename nocontract_t, typename contract_t,
00303          bool inner_dim_contiguous,
00304          bool inner_dim_reordered, int Alignment>
00305 class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
00306 {
00307  public:
00308   typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper;
00309 
00310   EIGEN_DEVICE_FUNC
00311   BaseTensorContractionMapper(const Tensor& tensor,
00312                               const nocontract_t& nocontract_strides,
00313                               const nocontract_t& ij_strides,
00314                               const contract_t& contract_strides,
00315                               const contract_t& k_strides) :
00316   ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
00317 
00318   typedef typename Tensor::PacketReturnType Packet;
00319   template <int> EIGEN_DEVICE_FUNC
00320   EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
00321     EIGEN_ALIGN_MAX Scalar data[1];
00322     data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
00323     return pload<typename Tensor::PacketReturnType>(data);
00324   }
00325   template <int> EIGEN_DEVICE_FUNC
00326   EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
00327     return loadPacket(i, j);
00328   }
00329 };
00330 
00331 
00332 template<typename Scalar, typename Index, int side,
00333          typename Tensor,
00334          typename nocontract_t, typename contract_t,
00335          int packet_size,
00336          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
00337 class TensorContractionSubMapper {
00338  public:
00339   typedef typename Tensor::PacketReturnType Packet;
00340   typedef typename unpacket_traits<Packet>::half HalfPacket;
00341 
00342   typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
00343   typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
00344   typedef Self LinearMapper;
00345 
00346   enum {
00347     // We can use direct offsets iff the parent mapper supports then and we can compute the strides.
00348     // TODO: we should also enable direct offsets for the Rhs case.
00349     UseDirectOffsets = ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
00350   };
00351 
00352   EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
00353       : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
00354     // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
00355     // this offset every time we attempt to access a coefficient.
00356     if (UseDirectOffsets) {
00357       Index stride = m_base_mapper.stride();
00358       m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
00359     }
00360   }
00361 
00362   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
00363     if (UseDirectOffsets) {
00364       return m_base_mapper(i, 0);
00365     }
00366     return m_base_mapper(i + m_vert_offset, m_horiz_offset);
00367   }
00368   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
00369     if (UseDirectOffsets) {
00370       return m_base_mapper(i, j);
00371     }
00372     return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
00373   }
00374 
00375   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
00376     if (UseDirectOffsets) {
00377       return m_base_mapper.template loadPacket<Alignment>(i, 0);
00378     }
00379     return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
00380   }
00381   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
00382     if (UseDirectOffsets) {
00383       return m_base_mapper.template loadPacket<Alignment>(i, j);
00384     }
00385     return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
00386   }
00387 
00388   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
00389     if (UseDirectOffsets) {
00390       return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
00391     }
00392     return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
00393   }
00394 
00395   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
00396     if (UseDirectOffsets) {
00397       m_base_mapper.storePacket(i, 0, p);
00398     }
00399     m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
00400   }
00401 
00402   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
00403     if (UseDirectOffsets) {
00404       return LinearMapper(m_base_mapper, i, j);
00405     }
00406     return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
00407   }
00408 
00409   template <typename PacketT, int AlignmentType>
00410   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
00411     EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
00412     const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
00413     if (UseDirectOffsets) {
00414      return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
00415     }
00416     return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
00417   }
00418 
00419   template <typename Packet>
00420   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
00421     return false;
00422   }
00423 
00424  private:
00425   ParentMapper m_base_mapper;
00426   const Index m_vert_offset;
00427   const Index m_horiz_offset;
00428 };
00429 
00430 
00431 template<typename Scalar_, typename Index, int side,
00432          typename Tensor,
00433          typename nocontract_t, typename contract_t,
00434          int packet_size,
00435          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
00436 class TensorContractionInputMapper
00437   : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
00438 
00439  public:
00440   typedef Scalar_ Scalar;
00441   typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
00442   typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
00443   typedef SubMapper VectorMapper;
00444 
00445   EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
00446                                const nocontract_t& nocontract_strides,
00447                                const nocontract_t& ij_strides,
00448                                const contract_t& contract_strides,
00449                                const contract_t& k_strides)
00450       : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
00451 
00452   EIGEN_DEVICE_FUNC
00453   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
00454     return SubMapper(*this, i, j);
00455   }
00456 
00457   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
00458     return VectorMapper(*this, i, j);
00459   }
00460 };
00461 
00462 
00463 
00464 }  // end namespace internal
00465 }  // end namespace Eigen
00466 
00467 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
 All Classes Functions Variables Typedefs Enumerator