// Eigen-unsupported 3.3.3
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H 00011 #define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H 00012 00013 namespace Eigen { 00014 00022 namespace internal { 00023 template<typename CustomUnaryFunc, typename XprType> 00024 struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > 00025 { 00026 typedef typename XprType::Scalar Scalar; 00027 typedef typename XprType::StorageKind StorageKind; 00028 typedef typename XprType::Index Index; 00029 typedef typename XprType::Nested Nested; 00030 typedef typename remove_reference<Nested>::type _Nested; 00031 static const int NumDimensions = traits<XprType>::NumDimensions; 00032 static const int Layout = traits<XprType>::Layout; 00033 }; 00034 00035 template<typename CustomUnaryFunc, typename XprType> 00036 struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense> 00037 { 00038 typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type; 00039 }; 00040 00041 template<typename CustomUnaryFunc, typename XprType> 00042 struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > 00043 { 00044 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type; 00045 }; 00046 00047 } // end namespace internal 00048 00049 00050 00051 template<typename CustomUnaryFunc, typename XprType> 00052 class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors> 00053 { 00054 public: 00055 typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar; 00056 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 00057 typedef typename 
XprType::CoeffReturnType CoeffReturnType; 00058 typedef typename internal::nested<TensorCustomUnaryOp>::type Nested; 00059 typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind; 00060 typedef typename internal::traits<TensorCustomUnaryOp>::Index Index; 00061 00062 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func) 00063 : m_expr(expr), m_func(func) {} 00064 00065 EIGEN_DEVICE_FUNC 00066 const CustomUnaryFunc& func() const { return m_func; } 00067 00068 EIGEN_DEVICE_FUNC 00069 const typename internal::remove_all<typename XprType::Nested>::type& 00070 expression() const { return m_expr; } 00071 00072 protected: 00073 typename XprType::Nested m_expr; 00074 const CustomUnaryFunc m_func; 00075 }; 00076 00077 00078 // Eval as rvalue 00079 template<typename CustomUnaryFunc, typename XprType, typename Device> 00080 struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device> 00081 { 00082 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> ArgType; 00083 typedef typename internal::traits<ArgType>::Index Index; 00084 static const int NumDims = internal::traits<ArgType>::NumDimensions; 00085 typedef DSizes<Index, NumDims> Dimensions; 00086 typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar; 00087 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; 00088 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 00089 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; 00090 00091 enum { 00092 IsAligned = false, 00093 PacketAccess = (internal::packet_traits<Scalar>::size > 1), 00094 BlockAccess = false, 00095 Layout = TensorEvaluator<XprType, Device>::Layout, 00096 CoordAccess = false, // to be implemented 00097 RawAccess = false 00098 }; 00099 00100 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device) 00101 : 
m_op(op), m_device(device), m_result(NULL) 00102 { 00103 m_dimensions = op.func().dimensions(op.expression()); 00104 } 00105 00106 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 00107 00108 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { 00109 if (data) { 00110 evalTo(data); 00111 return false; 00112 } else { 00113 m_result = static_cast<CoeffReturnType*>( 00114 m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); 00115 evalTo(m_result); 00116 return true; 00117 } 00118 } 00119 00120 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { 00121 if (m_result != NULL) { 00122 m_device.deallocate(m_result); 00123 m_result = NULL; 00124 } 00125 } 00126 00127 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 00128 return m_result[index]; 00129 } 00130 00131 template<int LoadMode> 00132 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { 00133 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); 00134 } 00135 00136 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { 00137 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate. 
00138 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); 00139 } 00140 00141 EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } 00142 00143 protected: 00144 EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { 00145 TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result( 00146 data, m_dimensions); 00147 m_op.func().eval(m_op.expression(), result, m_device); 00148 } 00149 00150 Dimensions m_dimensions; 00151 const ArgType m_op; 00152 const Device& m_device; 00153 CoeffReturnType* m_result; 00154 }; 00155 00156 00157 00165 namespace internal { 00166 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 00167 struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > 00168 { 00169 typedef typename internal::promote_storage_type<typename LhsXprType::Scalar, 00170 typename RhsXprType::Scalar>::ret Scalar; 00171 typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType, 00172 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType; 00173 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind, 00174 typename traits<RhsXprType>::StorageKind>::ret StorageKind; 00175 typedef typename promote_index_type<typename traits<LhsXprType>::Index, 00176 typename traits<RhsXprType>::Index>::type Index; 00177 typedef typename LhsXprType::Nested LhsNested; 00178 typedef typename RhsXprType::Nested RhsNested; 00179 typedef typename remove_reference<LhsNested>::type _LhsNested; 00180 typedef typename remove_reference<RhsNested>::type _RhsNested; 00181 static const int NumDimensions = traits<LhsXprType>::NumDimensions; 00182 static const int Layout = traits<LhsXprType>::Layout; 00183 }; 00184 00185 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 00186 struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense> 00187 { 00188 typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, 
RhsXprType>& type; 00189 }; 00190 00191 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 00192 struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > 00193 { 00194 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type; 00195 }; 00196 00197 } // end namespace internal 00198 00199 00200 00201 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> 00202 class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors> 00203 { 00204 public: 00205 typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar; 00206 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 00207 typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType; 00208 typedef typename internal::nested<TensorCustomBinaryOp>::type Nested; 00209 typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind; 00210 typedef typename internal::traits<TensorCustomBinaryOp>::Index Index; 00211 00212 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func) 00213 00214 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {} 00215 00216 EIGEN_DEVICE_FUNC 00217 const CustomBinaryFunc& func() const { return m_func; } 00218 00219 EIGEN_DEVICE_FUNC 00220 const typename internal::remove_all<typename LhsXprType::Nested>::type& 00221 lhsExpression() const { return m_lhs_xpr; } 00222 00223 EIGEN_DEVICE_FUNC 00224 const typename internal::remove_all<typename RhsXprType::Nested>::type& 00225 rhsExpression() const { return m_rhs_xpr; } 00226 00227 protected: 00228 typename LhsXprType::Nested m_lhs_xpr; 00229 typename RhsXprType::Nested m_rhs_xpr; 00230 const CustomBinaryFunc m_func; 00231 }; 00232 00233 00234 // Eval as rvalue 00235 template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device> 
00236 struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device> 00237 { 00238 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> XprType; 00239 typedef typename internal::traits<XprType>::Index Index; 00240 static const int NumDims = internal::traits<XprType>::NumDimensions; 00241 typedef DSizes<Index, NumDims> Dimensions; 00242 typedef typename XprType::Scalar Scalar; 00243 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; 00244 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; 00245 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; 00246 00247 enum { 00248 IsAligned = false, 00249 PacketAccess = (internal::packet_traits<Scalar>::size > 1), 00250 BlockAccess = false, 00251 Layout = TensorEvaluator<LhsXprType, Device>::Layout, 00252 CoordAccess = false, // to be implemented 00253 RawAccess = false 00254 }; 00255 00256 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) 00257 : m_op(op), m_device(device), m_result(NULL) 00258 { 00259 m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression()); 00260 } 00261 00262 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 00263 00264 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { 00265 if (data) { 00266 evalTo(data); 00267 return false; 00268 } else { 00269 m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); 00270 evalTo(m_result); 00271 return true; 00272 } 00273 } 00274 00275 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { 00276 if (m_result != NULL) { 00277 m_device.deallocate(m_result); 00278 m_result = NULL; 00279 } 00280 } 00281 00282 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { 00283 return m_result[index]; 00284 } 
00285 00286 template<int LoadMode> 00287 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const { 00288 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index); 00289 } 00290 00291 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { 00292 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate. 00293 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); 00294 } 00295 00296 EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; } 00297 00298 protected: 00299 EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { 00300 TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions); 00301 m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device); 00302 } 00303 00304 Dimensions m_dimensions; 00305 const XprType m_op; 00306 const Device& m_device; 00307 CoeffReturnType* m_result; 00308 }; 00309 00310 00311 } // end namespace Eigen 00312 00313 #endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H