TensorCustomOp.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
00011 #define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
00012 
00013 namespace Eigen {
00014 
00022 namespace internal {
00023 template<typename CustomUnaryFunc, typename XprType>
00024 struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
00025 {
00026   typedef typename XprType::Scalar Scalar;
00027   typedef typename XprType::StorageKind StorageKind;
00028   typedef typename XprType::Index Index;
00029   typedef typename XprType::Nested Nested;
00030   typedef typename remove_reference<Nested>::type _Nested;
00031   static const int NumDimensions = traits<XprType>::NumDimensions;
00032   static const int Layout = traits<XprType>::Layout;
00033 };
00034 
00035 template<typename CustomUnaryFunc, typename XprType>
00036 struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
00037 {
00038   typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
00039 };
00040 
00041 template<typename CustomUnaryFunc, typename XprType>
00042 struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
00043 {
00044   typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
00045 };
00046 
00047 }  // end namespace internal
00048 
00049 
00050 
00051 template<typename CustomUnaryFunc, typename XprType>
00052 class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
00053 {
00054   public:
00055   typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
00056   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00057   typedef typename XprType::CoeffReturnType CoeffReturnType;
00058   typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
00059   typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
00060   typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
00061 
00062   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
00063       : m_expr(expr), m_func(func) {}
00064 
00065   EIGEN_DEVICE_FUNC
00066   const CustomUnaryFunc& func() const { return m_func; }
00067 
00068   EIGEN_DEVICE_FUNC
00069   const typename internal::remove_all<typename XprType::Nested>::type&
00070   expression() const { return m_expr; }
00071 
00072   protected:
00073     typename XprType::Nested m_expr;
00074     const CustomUnaryFunc m_func;
00075 };
00076 
00077 
00078 // Eval as rvalue
00079 template<typename CustomUnaryFunc, typename XprType, typename Device>
00080 struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
00081 {
00082   typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType;
00083   typedef typename internal::traits<ArgType>::Index Index;
00084   static const int NumDims = internal::traits<ArgType>::NumDimensions;
00085   typedef DSizes<Index, NumDims> Dimensions;
00086   typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
00087   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
00088   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
00089   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
00090 
00091   enum {
00092     IsAligned = false,
00093     PacketAccess = (internal::packet_traits<Scalar>::size > 1),
00094     BlockAccess = false,
00095     Layout = TensorEvaluator<XprType, Device>::Layout,
00096     CoordAccess = false,  // to be implemented
00097     RawAccess = false
00098   };
00099 
00100   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
00101       : m_op(op), m_device(device), m_result(NULL)
00102   {
00103     m_dimensions = op.func().dimensions(op.expression());
00104   }
00105 
00106   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
00107 
00108   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
00109     if (data) {
00110       evalTo(data);
00111       return false;
00112     } else {
00113       m_result = static_cast<CoeffReturnType*>(
00114           m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
00115       evalTo(m_result);
00116       return true;
00117     }
00118   }
00119 
00120   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
00121     if (m_result != NULL) {
00122       m_device.deallocate(m_result);
00123       m_result = NULL;
00124     }
00125   }
00126 
00127   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
00128     return m_result[index];
00129   }
00130 
00131   template<int LoadMode>
00132   EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
00133     return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
00134   }
00135 
00136   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
00137     // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
00138     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
00139   }
00140 
00141   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
00142 
00143  protected:
00144   EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
00145     TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(
00146         data, m_dimensions);
00147     m_op.func().eval(m_op.expression(), result, m_device);
00148   }
00149 
00150   Dimensions m_dimensions;
00151   const ArgType m_op;
00152   const Device& m_device;
00153   CoeffReturnType* m_result;
00154 };
00155 
00156 
00157 
00165 namespace internal {
00166 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
00167 struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
00168 {
00169   typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
00170                                                   typename RhsXprType::Scalar>::ret Scalar;
00171   typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
00172                                                   typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
00173   typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
00174                                         typename traits<RhsXprType>::StorageKind>::ret StorageKind;
00175   typedef typename promote_index_type<typename traits<LhsXprType>::Index,
00176                                       typename traits<RhsXprType>::Index>::type Index;
00177   typedef typename LhsXprType::Nested LhsNested;
00178   typedef typename RhsXprType::Nested RhsNested;
00179   typedef typename remove_reference<LhsNested>::type _LhsNested;
00180   typedef typename remove_reference<RhsNested>::type _RhsNested;
00181   static const int NumDimensions = traits<LhsXprType>::NumDimensions;
00182   static const int Layout = traits<LhsXprType>::Layout;
00183 };
00184 
00185 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
00186 struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
00187 {
00188   typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
00189 };
00190 
00191 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
00192 struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
00193 {
00194   typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
00195 };
00196 
00197 }  // end namespace internal
00198 
00199 
00200 
00201 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
00202 class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
00203 {
00204   public:
00205   typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
00206   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00207   typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
00208   typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
00209   typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
00210   typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
00211 
00212   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
00213 
00214       : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
00215 
00216   EIGEN_DEVICE_FUNC
00217   const CustomBinaryFunc& func() const { return m_func; }
00218 
00219   EIGEN_DEVICE_FUNC
00220   const typename internal::remove_all<typename LhsXprType::Nested>::type&
00221   lhsExpression() const { return m_lhs_xpr; }
00222 
00223   EIGEN_DEVICE_FUNC
00224   const typename internal::remove_all<typename RhsXprType::Nested>::type&
00225   rhsExpression() const { return m_rhs_xpr; }
00226 
00227   protected:
00228     typename LhsXprType::Nested m_lhs_xpr;
00229     typename RhsXprType::Nested m_rhs_xpr;
00230     const CustomBinaryFunc m_func;
00231 };
00232 
00233 
00234 // Eval as rvalue
00235 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
00236 struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
00237 {
00238   typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType;
00239   typedef typename internal::traits<XprType>::Index Index;
00240   static const int NumDims = internal::traits<XprType>::NumDimensions;
00241   typedef DSizes<Index, NumDims> Dimensions;
00242   typedef typename XprType::Scalar Scalar;
00243   typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
00244   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
00245   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
00246 
00247   enum {
00248     IsAligned = false,
00249     PacketAccess = (internal::packet_traits<Scalar>::size > 1),
00250     BlockAccess = false,
00251     Layout = TensorEvaluator<LhsXprType, Device>::Layout,
00252     CoordAccess = false,  // to be implemented
00253     RawAccess = false
00254   };
00255 
00256   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
00257       : m_op(op), m_device(device), m_result(NULL)
00258   {
00259     m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
00260   }
00261 
00262   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
00263 
00264   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
00265     if (data) {
00266       evalTo(data);
00267       return false;
00268     } else {
00269       m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
00270       evalTo(m_result);
00271       return true;
00272     }
00273   }
00274 
00275   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
00276     if (m_result != NULL) {
00277       m_device.deallocate(m_result);
00278       m_result = NULL;
00279     }
00280   }
00281 
00282   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
00283     return m_result[index];
00284   }
00285 
00286   template<int LoadMode>
00287   EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
00288     return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
00289   }
00290 
00291   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
00292     // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
00293     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
00294   }
00295 
00296   EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
00297 
00298  protected:
00299   EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
00300     TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions);
00301     m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
00302   }
00303 
00304   Dimensions m_dimensions;
00305   const XprType m_op;
00306   const Device& m_device;
00307   CoeffReturnType* m_result;
00308 };
00309 
00310 
00311 } // end namespace Eigen
00312 
00313 #endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
 All Classes Functions Variables Typedefs Enumerator