// Eigen-unsupported 3.3.3
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 00011 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H 00012 00013 namespace Eigen { 00014 00030 namespace internal { 00031 template<typename NullaryOp, typename XprType> 00032 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > 00033 : traits<XprType> 00034 { 00035 typedef traits<XprType> XprTraits; 00036 typedef typename XprType::Scalar Scalar; 00037 typedef typename XprType::Nested XprTypeNested; 00038 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; 00039 static const int NumDimensions = XprTraits::NumDimensions; 00040 static const int Layout = XprTraits::Layout; 00041 00042 enum { 00043 Flags = 0 00044 }; 00045 }; 00046 00047 } // end namespace internal 00048 00049 00050 00051 template<typename NullaryOp, typename XprType> 00052 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors> 00053 { 00054 public: 00055 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; 00056 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 00057 typedef typename XprType::CoeffReturnType CoeffReturnType; 00058 typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested; 00059 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind; 00060 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index; 00061 00062 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp()) 00063 : m_xpr(xpr), m_functor(func) {} 00064 00065 
EIGEN_DEVICE_FUNC 00066 const typename internal::remove_all<typename XprType::Nested>::type& 00067 nestedExpression() const { return m_xpr; } 00068 00069 EIGEN_DEVICE_FUNC 00070 const NullaryOp& functor() const { return m_functor; } 00071 00072 protected: 00073 typename XprType::Nested m_xpr; 00074 const NullaryOp m_functor; 00075 }; 00076 00077 00078 00079 namespace internal { 00080 template<typename UnaryOp, typename XprType> 00081 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > 00082 : traits<XprType> 00083 { 00084 // TODO(phli): Add InputScalar, InputPacket. Check references to 00085 // current Scalar/Packet to see if the intent is Input or Output. 00086 typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar; 00087 typedef traits<XprType> XprTraits; 00088 typedef typename XprType::Nested XprTypeNested; 00089 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; 00090 static const int NumDimensions = XprTraits::NumDimensions; 00091 static const int Layout = XprTraits::Layout; 00092 }; 00093 00094 template<typename UnaryOp, typename XprType> 00095 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense> 00096 { 00097 typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type; 00098 }; 00099 00100 template<typename UnaryOp, typename XprType> 00101 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type> 00102 { 00103 typedef TensorCwiseUnaryOp<UnaryOp, XprType> type; 00104 }; 00105 00106 } // end namespace internal 00107 00108 00109 00110 template<typename UnaryOp, typename XprType> 00111 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors> 00112 { 00113 public: 00114 // TODO(phli): Add InputScalar, InputPacket. Check references to 00115 // current Scalar/Packet to see if the intent is Input or Output. 
00116 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar; 00117 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 00118 typedef Scalar CoeffReturnType; 00119 typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested; 00120 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind; 00121 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index; 00122 00123 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp()) 00124 : m_xpr(xpr), m_functor(func) {} 00125 00126 EIGEN_DEVICE_FUNC 00127 const UnaryOp& functor() const { return m_functor; } 00128 00130 EIGEN_DEVICE_FUNC 00131 const typename internal::remove_all<typename XprType::Nested>::type& 00132 nestedExpression() const { return m_xpr; } 00133 00134 protected: 00135 typename XprType::Nested m_xpr; 00136 const UnaryOp m_functor; 00137 }; 00138 00139 00140 namespace internal { 00141 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 00142 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > 00143 { 00144 // Type promotion to handle the case where the types of the lhs and the rhs 00145 // are different. 00146 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to 00147 // current Scalar/Packet to see if the intent is Inputs or Output. 
00148 typedef typename result_of< 00149 BinaryOp(typename LhsXprType::Scalar, 00150 typename RhsXprType::Scalar)>::type Scalar; 00151 typedef traits<LhsXprType> XprTraits; 00152 typedef typename promote_storage_type< 00153 typename traits<LhsXprType>::StorageKind, 00154 typename traits<RhsXprType>::StorageKind>::ret StorageKind; 00155 typedef typename promote_index_type< 00156 typename traits<LhsXprType>::Index, 00157 typename traits<RhsXprType>::Index>::type Index; 00158 typedef typename LhsXprType::Nested LhsNested; 00159 typedef typename RhsXprType::Nested RhsNested; 00160 typedef typename remove_reference<LhsNested>::type _LhsNested; 00161 typedef typename remove_reference<RhsNested>::type _RhsNested; 00162 static const int NumDimensions = XprTraits::NumDimensions; 00163 static const int Layout = XprTraits::Layout; 00164 00165 enum { 00166 Flags = 0 00167 }; 00168 }; 00169 00170 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 00171 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense> 00172 { 00173 typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type; 00174 }; 00175 00176 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 00177 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type> 00178 { 00179 typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type; 00180 }; 00181 00182 } // end namespace internal 00183 00184 00185 00186 template<typename BinaryOp, typename LhsXprType, typename RhsXprType> 00187 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors> 00188 { 00189 public: 00190 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to 00191 // current Scalar/Packet to see if the intent is Inputs or Output. 
00192 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar; 00193 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 00194 typedef Scalar CoeffReturnType; 00195 typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested; 00196 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind; 00197 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index; 00198 00199 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) 00200 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} 00201 00202 EIGEN_DEVICE_FUNC 00203 const BinaryOp& functor() const { return m_functor; } 00204 00206 EIGEN_DEVICE_FUNC 00207 const typename internal::remove_all<typename LhsXprType::Nested>::type& 00208 lhsExpression() const { return m_lhs_xpr; } 00209 00210 EIGEN_DEVICE_FUNC 00211 const typename internal::remove_all<typename RhsXprType::Nested>::type& 00212 rhsExpression() const { return m_rhs_xpr; } 00213 00214 protected: 00215 typename LhsXprType::Nested m_lhs_xpr; 00216 typename RhsXprType::Nested m_rhs_xpr; 00217 const BinaryOp m_functor; 00218 }; 00219 00220 00221 namespace internal { 00222 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 00223 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> > 00224 { 00225 // Type promotion to handle the case where the types of the args are different. 
00226 typedef typename result_of< 00227 TernaryOp(typename Arg1XprType::Scalar, 00228 typename Arg2XprType::Scalar, 00229 typename Arg3XprType::Scalar)>::type Scalar; 00230 typedef traits<Arg1XprType> XprTraits; 00231 typedef typename traits<Arg1XprType>::StorageKind StorageKind; 00232 typedef typename traits<Arg1XprType>::Index Index; 00233 typedef typename Arg1XprType::Nested Arg1Nested; 00234 typedef typename Arg2XprType::Nested Arg2Nested; 00235 typedef typename Arg3XprType::Nested Arg3Nested; 00236 typedef typename remove_reference<Arg1Nested>::type _Arg1Nested; 00237 typedef typename remove_reference<Arg2Nested>::type _Arg2Nested; 00238 typedef typename remove_reference<Arg3Nested>::type _Arg3Nested; 00239 static const int NumDimensions = XprTraits::NumDimensions; 00240 static const int Layout = XprTraits::Layout; 00241 00242 enum { 00243 Flags = 0 00244 }; 00245 }; 00246 00247 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 00248 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense> 00249 { 00250 typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type; 00251 }; 00252 00253 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 00254 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type> 00255 { 00256 typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type; 00257 }; 00258 00259 } // end namespace internal 00260 00261 00262 00263 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> 00264 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors> 00265 { 00266 public: 00267 typedef typename 
Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar; 00268 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 00269 typedef Scalar CoeffReturnType; 00270 typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested; 00271 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind; 00272 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index; 00273 00274 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp()) 00275 : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {} 00276 00277 EIGEN_DEVICE_FUNC 00278 const TernaryOp& functor() const { return m_functor; } 00279 00281 EIGEN_DEVICE_FUNC 00282 const typename internal::remove_all<typename Arg1XprType::Nested>::type& 00283 arg1Expression() const { return m_arg1_xpr; } 00284 00285 EIGEN_DEVICE_FUNC 00286 const typename internal::remove_all<typename Arg2XprType::Nested>::type& 00287 arg2Expression() const { return m_arg2_xpr; } 00288 00289 EIGEN_DEVICE_FUNC 00290 const typename internal::remove_all<typename Arg3XprType::Nested>::type& 00291 arg3Expression() const { return m_arg3_xpr; } 00292 00293 protected: 00294 typename Arg1XprType::Nested m_arg1_xpr; 00295 typename Arg2XprType::Nested m_arg2_xpr; 00296 typename Arg3XprType::Nested m_arg3_xpr; 00297 const TernaryOp m_functor; 00298 }; 00299 00300 00301 namespace internal { 00302 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 00303 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > 00304 : traits<ThenXprType> 00305 { 00306 typedef typename traits<ThenXprType>::Scalar Scalar; 00307 typedef traits<ThenXprType> XprTraits; 00308 typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind, 00309 typename traits<ElseXprType>::StorageKind>::ret StorageKind; 00310 typedef typename 
promote_index_type<typename traits<ElseXprType>::Index, 00311 typename traits<ThenXprType>::Index>::type Index; 00312 typedef typename IfXprType::Nested IfNested; 00313 typedef typename ThenXprType::Nested ThenNested; 00314 typedef typename ElseXprType::Nested ElseNested; 00315 static const int NumDimensions = XprTraits::NumDimensions; 00316 static const int Layout = XprTraits::Layout; 00317 }; 00318 00319 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 00320 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense> 00321 { 00322 typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type; 00323 }; 00324 00325 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 00326 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type> 00327 { 00328 typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type; 00329 }; 00330 00331 } // end namespace internal 00332 00333 00334 template<typename IfXprType, typename ThenXprType, typename ElseXprType> 00335 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors> 00336 { 00337 public: 00338 typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar; 00339 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; 00340 typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType, 00341 typename ElseXprType::CoeffReturnType>::ret CoeffReturnType; 00342 typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested; 00343 typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind; 00344 typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index; 00345 00346 EIGEN_DEVICE_FUNC 00347 TensorSelectOp(const IfXprType& a_condition, 00348 const ThenXprType& a_then, 00349 const ElseXprType& a_else) 00350 : m_condition(a_condition), 
m_then(a_then), m_else(a_else) 00351 { } 00352 00353 EIGEN_DEVICE_FUNC 00354 const IfXprType& ifExpression() const { return m_condition; } 00355 00356 EIGEN_DEVICE_FUNC 00357 const ThenXprType& thenExpression() const { return m_then; } 00358 00359 EIGEN_DEVICE_FUNC 00360 const ElseXprType& elseExpression() const { return m_else; } 00361 00362 protected: 00363 typename IfXprType::Nested m_condition; 00364 typename ThenXprType::Nested m_then; 00365 typename ElseXprType::Nested m_else; 00366 }; 00367 00368 00369 } // end namespace Eigen 00370 00371 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H