// TensorEvaluator.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
#define EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H

namespace Eigen {
00026 // Generic evaluator
00027 template<typename Derived, typename Device>
00028 struct TensorEvaluator
00029 {
00030   typedef typename Derived::Index Index;
00031   typedef typename Derived::Scalar Scalar;
00032   typedef typename Derived::Scalar CoeffReturnType;
00033   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
00034   typedef typename Derived::Dimensions Dimensions;
00035 
00036   // NumDimensions is -1 for variable dim tensors
00037   static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
00038                                internal::traits<Derived>::NumDimensions : 0;
00039 
00040   enum {
00041     IsAligned = Derived::IsAligned,
00042     PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
00043     Layout = Derived::Layout,
00044     CoordAccess = NumCoords > 0,
00045     RawAccess = true
00046   };
00047 
00048   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
00049       : m_data(const_cast<typename internal::traits<Derived>::template MakePointer<Scalar>::Type>(m.data())), m_dims(m.dimensions()), m_device(device), m_impl(m)
00050   { }
00051 
00052   // Used for accessor extraction in SYCL Managed TensorMap:
00053   const Derived& derived() const { return m_impl; }
00054   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }
00055 
00056   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* dest) {
00057     if (dest) {
00058       m_device.memcpy((void*)dest, m_data, sizeof(Scalar) * m_dims.TotalSize());
00059       return false;
00060     }
00061     return true;
00062   }
00063 
00064   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
00065 
00066   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
00067     eigen_assert(m_data);
00068     return m_data[index];
00069   }
00070 
00071   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
00072     eigen_assert(m_data);
00073     return m_data[index];
00074   }
00075 
00076   template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
00077   PacketReturnType packet(Index index) const
00078   {
00079     return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
00080   }
00081 
00082   template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
00083   void writePacket(Index index, const PacketReturnType& x)
00084   {
00085     return internal::pstoret<Scalar, PacketReturnType, StoreMode>(m_data + index, x);
00086   }
00087 
00088   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
00089     eigen_assert(m_data);
00090     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
00091       return m_data[m_dims.IndexOfColMajor(coords)];
00092     } else {
00093       return m_data[m_dims.IndexOfRowMajor(coords)];
00094     }
00095   }
00096 
00097   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) {
00098     eigen_assert(m_data);
00099     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
00100       return m_data[m_dims.IndexOfColMajor(coords)];
00101     } else {
00102       return m_data[m_dims.IndexOfRowMajor(coords)];
00103     }
00104   }
00105 
00106   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
00107     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
00108                         internal::unpacket_traits<PacketReturnType>::size);
00109   }
00110 
00111   EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<Scalar>::Type data() const { return m_data; }
00112 
00114   const Device& device() const{return m_device;}
00115 
00116  protected:
00117   typename internal::traits<Derived>::template MakePointer<Scalar>::Type m_data;
00118   Dimensions m_dims;
00119   const Device& m_device;
00120   const Derived& m_impl;
00121 };
// NOTE(review): anonymous namespace in a header gives each TU its own copy of
// these helpers; kept as-is to preserve the existing linkage behavior.
namespace {
// Load a scalar from memory.  The generic version is a plain dereference.
template <typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T loadConstant(const T* address) {
  return *address;
}
// Use the texture cache on CUDA devices whenever possible
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
// __ldg routes the read through the read-only data (texture) cache, which is
// beneficial for data that is never written during the kernel.
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float loadConstant(const float* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
double loadConstant(const double* address) {
  return __ldg(address);
}
// Eigen::half has no __ldg overload: load its raw 16-bit payload and rewrap.
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
Eigen::half loadConstant(const Eigen::half* address) {
  return Eigen::half(half_impl::raw_uint16_to_half(__ldg(&address->x)));
}
#endif
}
// Default evaluator for rvalues
// Read-only counterpart of the generic evaluator above: identical forwarding
// to raw storage, except coefficient loads go through loadConstant() so CUDA
// devices can use the read-only texture cache.
template<typename Derived, typename Device>
struct TensorEvaluator<const Derived, Device>
{
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    CoordAccess = NumCoords > 0,
    RawAccess = true
  };

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(m.data()), m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

  // If a destination buffer is provided and the scalar type is safe to copy
  // with a raw memcpy (no initialization required), copy the data there and
  // return false ("already evaluated").  Otherwise return true so the caller
  // evaluates coefficient by coefficient.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization && data) {
      m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar));
      return false;
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    return loadConstant(m_data+index);
  }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    // ploadt_ro is the read-only variant of ploadt (texture cache on CUDA).
    return internal::ploadt_ro<PacketReturnType, LoadMode>(m_data + index);
  }

  // Coordinate-based access: map the coordinate array to a linear index
  // according to the tensor's storage order.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    const Index index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_dims.IndexOfColMajor(coords)
                        : m_dims.IndexOfRowMajor(coords);
    return loadConstant(m_data+index);
  }

  // Reading a coefficient is modeled as one load of its size; vectorized
  // accesses amortize over the packet width.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<const Scalar>::Type data() const { return m_data; }

  const Device& device() const{return m_device;}

 protected:
  typename internal::traits<Derived>::template MakePointer<const Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};



00225 // -------------------- CwiseNullaryOp --------------------
00226 
00227 template<typename NullaryOp, typename ArgType, typename Device>
00228 struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
00229 {
00230   typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;
00231 
00232   enum {
00233     IsAligned = true,
00234     PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
00235     Layout = TensorEvaluator<ArgType, Device>::Layout,
00236     CoordAccess = false,  // to be implemented
00237     RawAccess = false
00238   };
00239 
00240   EIGEN_DEVICE_FUNC
00241   TensorEvaluator(const XprType& op, const Device& device)
00242       : m_functor(op.functor()), m_argImpl(op.nestedExpression(), device), m_wrapper()
00243   { }
00244 
00245   typedef typename XprType::Index Index;
00246   typedef typename XprType::Scalar Scalar;
00247   typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
00248   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
00249   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
00250   typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
00251 
00252   EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
00253 
00254   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; }
00255   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
00256 
00257   EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
00258   {
00259     return m_wrapper(m_functor, index);
00260   }
00261 
00262   template<int LoadMode>
00263   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
00264   {
00265     return m_wrapper.template packetOp<PacketReturnType, Index>(m_functor, index);
00266   }
00267 
00268   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
00269   costPerCoeff(bool vectorized) const {
00270     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
00271                         internal::unpacket_traits<PacketReturnType>::size);
00272   }
00273 
00274   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
00275 
00277   const TensorEvaluator<ArgType, Device>& impl() const { return m_argImpl; }
00279   NullaryOp functor() const { return m_functor; }
00280 
00281 
00282  private:
00283   const NullaryOp m_functor;
00284   TensorEvaluator<ArgType, Device> m_argImpl;
00285   const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper;
00286 };


00290 // -------------------- CwiseUnaryOp --------------------
00291 
00292 template<typename UnaryOp, typename ArgType, typename Device>
00293 struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
00294 {
00295   typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType;
00296 
00297   enum {
00298     IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
00299     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
00300     Layout = TensorEvaluator<ArgType, Device>::Layout,
00301     CoordAccess = false,  // to be implemented
00302     RawAccess = false
00303   };
00304 
00305   EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
00306     : m_functor(op.functor()),
00307       m_argImpl(op.nestedExpression(), device)
00308   { }
00309 
00310   typedef typename XprType::Index Index;
00311   typedef typename XprType::Scalar Scalar;
00312   typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
00313   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
00314   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
00315   typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
00316 
00317   EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
00318 
00319   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
00320     m_argImpl.evalSubExprsIfNeeded(NULL);
00321     return true;
00322   }
00323   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
00324     m_argImpl.cleanup();
00325   }
00326 
00327   EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
00328   {
00329     return m_functor(m_argImpl.coeff(index));
00330   }
00331 
00332   template<int LoadMode>
00333   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
00334   {
00335     return m_functor.packetOp(m_argImpl.template packet<LoadMode>(index));
00336   }
00337 
00338   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
00339     const double functor_cost = internal::functor_traits<UnaryOp>::Cost;
00340     return m_argImpl.costPerCoeff(vectorized) +
00341         TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
00342   }
00343 
00344   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
00345 
00347   const TensorEvaluator<ArgType, Device> & impl() const { return m_argImpl; }
00349   UnaryOp functor() const { return m_functor; }
00350 
00351 
00352  private:
00353   const UnaryOp m_functor;
00354   TensorEvaluator<ArgType, Device> m_argImpl;
00355 };
00358 // -------------------- CwiseBinaryOp --------------------
00359 
00360 template<typename BinaryOp, typename LeftArgType, typename RightArgType, typename Device>
00361 struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType>, Device>
00362 {
00363   typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType;
00364 
00365   enum {
00366     IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
00367     PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
00368                    internal::functor_traits<BinaryOp>::PacketAccess,
00369     Layout = TensorEvaluator<LeftArgType, Device>::Layout,
00370     CoordAccess = false,  // to be implemented
00371     RawAccess = false
00372   };
00373 
00374   EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
00375     : m_functor(op.functor()),
00376       m_leftImpl(op.lhsExpression(), device),
00377       m_rightImpl(op.rhsExpression(), device)
00378   {
00379     EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
00380     eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
00381   }
00382 
00383   typedef typename XprType::Index Index;
00384   typedef typename XprType::Scalar Scalar;
00385   typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
00386   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
00387   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
00388   typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;
00389 
00390   EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
00391   {
00392     // TODO: use right impl instead if right impl dimensions are known at compile time.
00393     return m_leftImpl.dimensions();
00394   }
00395 
00396   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
00397     m_leftImpl.evalSubExprsIfNeeded(NULL);
00398     m_rightImpl.evalSubExprsIfNeeded(NULL);
00399     return true;
00400   }
00401   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
00402     m_leftImpl.cleanup();
00403     m_rightImpl.cleanup();
00404   }
00405 
00406   EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
00407   {
00408     return m_functor(m_leftImpl.coeff(index), m_rightImpl.coeff(index));
00409   }
00410   template<int LoadMode>
00411   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
00412   {
00413     return m_functor.packetOp(m_leftImpl.template packet<LoadMode>(index), m_rightImpl.template packet<LoadMode>(index));
00414   }
00415 
00416   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
00417   costPerCoeff(bool vectorized) const {
00418     const double functor_cost = internal::functor_traits<BinaryOp>::Cost;
00419     return m_leftImpl.costPerCoeff(vectorized) +
00420            m_rightImpl.costPerCoeff(vectorized) +
00421            TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
00422   }
00423 
00424   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
00426   const TensorEvaluator<LeftArgType, Device>& left_impl() const { return m_leftImpl; }
00428   const TensorEvaluator<RightArgType, Device>& right_impl() const { return m_rightImpl; }
00430   BinaryOp functor() const { return m_functor; }
00431 
00432  private:
00433   const BinaryOp m_functor;
00434   TensorEvaluator<LeftArgType, Device> m_leftImpl;
00435   TensorEvaluator<RightArgType, Device> m_rightImpl;
00436 };
00438 // -------------------- CwiseTernaryOp --------------------
00439 
00440 template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
00441 struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
00442 {
00443   typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;
00444 
00445   enum {
00446     IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
00447     PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
00448                    internal::functor_traits<TernaryOp>::PacketAccess,
00449     Layout = TensorEvaluator<Arg1Type, Device>::Layout,
00450     CoordAccess = false,  // to be implemented
00451     RawAccess = false
00452   };
00453 
00454   EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
00455     : m_functor(op.functor()),
00456       m_arg1Impl(op.arg1Expression(), device),
00457       m_arg2Impl(op.arg2Expression(), device),
00458       m_arg3Impl(op.arg3Expression(), device)
00459   {
00460     EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
00461 
00462     EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
00463                          typename internal::traits<Arg2Type>::StorageKind>::value),
00464                         STORAGE_KIND_MUST_MATCH)
00465     EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
00466                          typename internal::traits<Arg3Type>::StorageKind>::value),
00467                         STORAGE_KIND_MUST_MATCH)
00468     EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
00469                          typename internal::traits<Arg2Type>::Index>::value),
00470                         STORAGE_INDEX_MUST_MATCH)
00471     EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
00472                          typename internal::traits<Arg3Type>::Index>::value),
00473                         STORAGE_INDEX_MUST_MATCH)
00474 
00475     eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
00476   }
00477 
00478   typedef typename XprType::Index Index;
00479   typedef typename XprType::Scalar Scalar;
00480   typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
00481   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
00482   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
00483   typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;
00484 
00485   EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
00486   {
00487     // TODO: use arg2 or arg3 dimensions if they are known at compile time.
00488     return m_arg1Impl.dimensions();
00489   }
00490 
00491   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
00492     m_arg1Impl.evalSubExprsIfNeeded(NULL);
00493     m_arg2Impl.evalSubExprsIfNeeded(NULL);
00494     m_arg3Impl.evalSubExprsIfNeeded(NULL);
00495     return true;
00496   }
00497   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
00498     m_arg1Impl.cleanup();
00499     m_arg2Impl.cleanup();
00500     m_arg3Impl.cleanup();
00501   }
00502 
00503   EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
00504   {
00505     return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
00506   }
00507   template<int LoadMode>
00508   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
00509   {
00510     return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
00511                               m_arg2Impl.template packet<LoadMode>(index),
00512                               m_arg3Impl.template packet<LoadMode>(index));
00513   }
00514 
00515   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
00516   costPerCoeff(bool vectorized) const {
00517     const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
00518     return m_arg1Impl.costPerCoeff(vectorized) +
00519            m_arg2Impl.costPerCoeff(vectorized) +
00520            m_arg3Impl.costPerCoeff(vectorized) +
00521            TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
00522   }
00523 
00524   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
00525 
00527   const TensorEvaluator<Arg1Type, Device> & arg1Impl() const { return m_arg1Impl; }
00529   const TensorEvaluator<Arg2Type, Device>& arg2Impl() const { return m_arg2Impl; }
00531   const TensorEvaluator<Arg3Type, Device>& arg3Impl() const { return m_arg3Impl; }
00532 
00533  private:
00534   const TernaryOp m_functor;
00535   TensorEvaluator<Arg1Type, Device> m_arg1Impl;
00536   TensorEvaluator<Arg2Type, Device> m_arg2Impl;
00537   TensorEvaluator<Arg3Type, Device> m_arg3Impl;
00538 };
// -------------------- SelectOp --------------------

// Evaluator for the select expression: per coefficient, returns the "then"
// value where the condition is true and the "else" value otherwise.
template<typename IfArgType, typename ThenArgType, typename ElseArgType, typename Device>
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device>
{
  typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
  typedef typename XprType::Scalar Scalar;

  enum {
    IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
    // Vectorization additionally requires a pblend implementation for Scalar.
    PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess &
                   internal::packet_traits<Scalar>::HasBlend,
    Layout = TensorEvaluator<IfArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_condImpl(op.ifExpression(), device),
      m_thenImpl(op.thenExpression(), device),
      m_elseImpl(op.elseExpression(), device)
  {
    // All three branches must share one storage order and identical dimensions.
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ThenArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ElseArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
    eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use then or else impl instead if they happen to be known at compile time.
    return m_condImpl.dimensions();
  }

  // Let all three branches materialize themselves if needed; the selection
  // itself is always applied lazily.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_condImpl.evalSubExprsIfNeeded(NULL);
    m_thenImpl.evalSubExprsIfNeeded(NULL);
    m_elseImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_condImpl.cleanup();
    m_thenImpl.cleanup();
    m_elseImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
  {
    // The condition is read coefficient-by-coefficient into a selector mask,
    // then both branch packets are loaded and blended.  Note that both
    // branches are evaluated regardless of the condition.
    internal::Selector<PacketSize> select;
    for (Index i = 0; i < PacketSize; ++i) {
      select.select[i] = m_condImpl.coeff(index+i);
    }
    return internal::pblend(select,
                            m_thenImpl.template packet<LoadMode>(index),
                            m_elseImpl.template packet<LoadMode>(index));
  }

  // Condition cost plus the more expensive of the two branches (the cheaper
  // branch is assumed to overlap with it).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_condImpl.costPerCoeff(vectorized) +
           m_thenImpl.costPerCoeff(vectorized)
        .cwiseMax(m_elseImpl.costPerCoeff(vectorized));
  }

  // Lazily evaluated expressions have no backing buffer.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType* data() const { return NULL; }
  const TensorEvaluator<IfArgType, Device> & cond_impl() const { return m_condImpl; }
  const TensorEvaluator<ThenArgType, Device>& then_impl() const { return m_thenImpl; }
  const TensorEvaluator<ElseArgType, Device>& else_impl() const { return m_elseImpl; }

 private:
  TensorEvaluator<IfArgType, Device> m_condImpl;
  TensorEvaluator<ThenArgType, Device> m_thenImpl;
  TensorEvaluator<ElseArgType, Device> m_elseImpl;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H