TensorExpr.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
00011 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
00012 
00013 namespace Eigen {
00014 
00030 namespace internal {
00031 template<typename NullaryOp, typename XprType>
00032 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
00033     : traits<XprType>
00034 {
00035   typedef traits<XprType> XprTraits;
00036   typedef typename XprType::Scalar Scalar;
00037   typedef typename XprType::Nested XprTypeNested;
00038   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
00039   static const int NumDimensions = XprTraits::NumDimensions;
00040   static const int Layout = XprTraits::Layout;
00041 
00042   enum {
00043     Flags = 0
00044   };
00045 };
00046 
00047 }  // end namespace internal
00048 
00049 
00050 
00051 template<typename NullaryOp, typename XprType>
00052 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
00053 {
00054   public:
00055     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
00056     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00057     typedef typename XprType::CoeffReturnType CoeffReturnType;
00058     typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
00059     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
00060     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
00061 
00062     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
00063         : m_xpr(xpr), m_functor(func) {}
00064 
00065     EIGEN_DEVICE_FUNC
00066     const typename internal::remove_all<typename XprType::Nested>::type&
00067     nestedExpression() const { return m_xpr; }
00068 
00069     EIGEN_DEVICE_FUNC
00070     const NullaryOp& functor() const { return m_functor; }
00071 
00072   protected:
00073     typename XprType::Nested m_xpr;
00074     const NullaryOp m_functor;
00075 };
00076 
00077 
00078 
00079 namespace internal {
00080 template<typename UnaryOp, typename XprType>
00081 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
00082     : traits<XprType>
00083 {
00084   // TODO(phli): Add InputScalar, InputPacket.  Check references to
00085   // current Scalar/Packet to see if the intent is Input or Output.
00086   typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
00087   typedef traits<XprType> XprTraits;
00088   typedef typename XprType::Nested XprTypeNested;
00089   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
00090   static const int NumDimensions = XprTraits::NumDimensions;
00091   static const int Layout = XprTraits::Layout;
00092 };
00093 
00094 template<typename UnaryOp, typename XprType>
00095 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
00096 {
00097   typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
00098 };
00099 
00100 template<typename UnaryOp, typename XprType>
00101 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
00102 {
00103   typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
00104 };
00105 
00106 }  // end namespace internal
00107 
00108 
00109 
00110 template<typename UnaryOp, typename XprType>
00111 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
00112 {
00113   public:
00114     // TODO(phli): Add InputScalar, InputPacket.  Check references to
00115     // current Scalar/Packet to see if the intent is Input or Output.
00116     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
00117     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00118     typedef Scalar CoeffReturnType;
00119     typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
00120     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
00121     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
00122 
00123     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
00124       : m_xpr(xpr), m_functor(func) {}
00125 
00126     EIGEN_DEVICE_FUNC
00127     const UnaryOp& functor() const { return m_functor; }
00128 
00130     EIGEN_DEVICE_FUNC
00131     const typename internal::remove_all<typename XprType::Nested>::type&
00132     nestedExpression() const { return m_xpr; }
00133 
00134   protected:
00135     typename XprType::Nested m_xpr;
00136     const UnaryOp m_functor;
00137 };
00138 
00139 
00140 namespace internal {
00141 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
00142 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
00143 {
00144   // Type promotion to handle the case where the types of the lhs and the rhs
00145   // are different.
00146   // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
00147   // current Scalar/Packet to see if the intent is Inputs or Output.
00148   typedef typename result_of<
00149       BinaryOp(typename LhsXprType::Scalar,
00150                typename RhsXprType::Scalar)>::type Scalar;
00151   typedef traits<LhsXprType> XprTraits;
00152   typedef typename promote_storage_type<
00153       typename traits<LhsXprType>::StorageKind,
00154       typename traits<RhsXprType>::StorageKind>::ret StorageKind;
00155   typedef typename promote_index_type<
00156       typename traits<LhsXprType>::Index,
00157       typename traits<RhsXprType>::Index>::type Index;
00158   typedef typename LhsXprType::Nested LhsNested;
00159   typedef typename RhsXprType::Nested RhsNested;
00160   typedef typename remove_reference<LhsNested>::type _LhsNested;
00161   typedef typename remove_reference<RhsNested>::type _RhsNested;
00162   static const int NumDimensions = XprTraits::NumDimensions;
00163   static const int Layout = XprTraits::Layout;
00164 
00165   enum {
00166     Flags = 0
00167   };
00168 };
00169 
00170 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
00171 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
00172 {
00173   typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
00174 };
00175 
00176 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
00177 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
00178 {
00179   typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
00180 };
00181 
00182 }  // end namespace internal
00183 
00184 
00185 
00186 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
00187 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
00188 {
00189   public:
00190     // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
00191     // current Scalar/Packet to see if the intent is Inputs or Output.
00192     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
00193     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00194     typedef Scalar CoeffReturnType;
00195     typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
00196     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
00197     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
00198 
00199     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
00200         : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
00201 
00202     EIGEN_DEVICE_FUNC
00203     const BinaryOp& functor() const { return m_functor; }
00204 
00206     EIGEN_DEVICE_FUNC
00207     const typename internal::remove_all<typename LhsXprType::Nested>::type&
00208     lhsExpression() const { return m_lhs_xpr; }
00209 
00210     EIGEN_DEVICE_FUNC
00211     const typename internal::remove_all<typename RhsXprType::Nested>::type&
00212     rhsExpression() const { return m_rhs_xpr; }
00213 
00214   protected:
00215     typename LhsXprType::Nested m_lhs_xpr;
00216     typename RhsXprType::Nested m_rhs_xpr;
00217     const BinaryOp m_functor;
00218 };
00219 
00220 
00221 namespace internal {
00222 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
00223 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
00224 {
00225   // Type promotion to handle the case where the types of the args are different.
00226   typedef typename result_of<
00227       TernaryOp(typename Arg1XprType::Scalar,
00228                 typename Arg2XprType::Scalar,
00229                 typename Arg3XprType::Scalar)>::type Scalar;
00230   typedef traits<Arg1XprType> XprTraits;
00231   typedef typename traits<Arg1XprType>::StorageKind StorageKind;
00232   typedef typename traits<Arg1XprType>::Index Index;
00233   typedef typename Arg1XprType::Nested Arg1Nested;
00234   typedef typename Arg2XprType::Nested Arg2Nested;
00235   typedef typename Arg3XprType::Nested Arg3Nested;
00236   typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
00237   typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
00238   typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
00239   static const int NumDimensions = XprTraits::NumDimensions;
00240   static const int Layout = XprTraits::Layout;
00241 
00242   enum {
00243     Flags = 0
00244   };
00245 };
00246 
00247 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
00248 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
00249 {
00250   typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
00251 };
00252 
00253 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
00254 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
00255 {
00256   typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
00257 };
00258 
00259 }  // end namespace internal
00260 
00261 
00262 
00263 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
00264 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
00265 {
00266   public:
00267     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
00268     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00269     typedef Scalar CoeffReturnType;
00270     typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
00271     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
00272     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
00273 
00274     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
00275         : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
00276 
00277     EIGEN_DEVICE_FUNC
00278     const TernaryOp& functor() const { return m_functor; }
00279 
00281     EIGEN_DEVICE_FUNC
00282     const typename internal::remove_all<typename Arg1XprType::Nested>::type&
00283     arg1Expression() const { return m_arg1_xpr; }
00284 
00285     EIGEN_DEVICE_FUNC
00286     const typename internal::remove_all<typename Arg2XprType::Nested>::type&
00287     arg2Expression() const { return m_arg2_xpr; }
00288 
00289     EIGEN_DEVICE_FUNC
00290     const typename internal::remove_all<typename Arg3XprType::Nested>::type&
00291     arg3Expression() const { return m_arg3_xpr; }
00292 
00293   protected:
00294     typename Arg1XprType::Nested m_arg1_xpr;
00295     typename Arg2XprType::Nested m_arg2_xpr;
00296     typename Arg3XprType::Nested m_arg3_xpr;
00297     const TernaryOp m_functor;
00298 };
00299 
00300 
00301 namespace internal {
00302 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
00303 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
00304     : traits<ThenXprType>
00305 {
00306   typedef typename traits<ThenXprType>::Scalar Scalar;
00307   typedef traits<ThenXprType> XprTraits;
00308   typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
00309                                         typename traits<ElseXprType>::StorageKind>::ret StorageKind;
00310   typedef typename promote_index_type<typename traits<ElseXprType>::Index,
00311                                       typename traits<ThenXprType>::Index>::type Index;
00312   typedef typename IfXprType::Nested IfNested;
00313   typedef typename ThenXprType::Nested ThenNested;
00314   typedef typename ElseXprType::Nested ElseNested;
00315   static const int NumDimensions = XprTraits::NumDimensions;
00316   static const int Layout = XprTraits::Layout;
00317 };
00318 
00319 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
00320 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
00321 {
00322   typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
00323 };
00324 
00325 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
00326 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
00327 {
00328   typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
00329 };
00330 
00331 }  // end namespace internal
00332 
00333 
00334 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
00335 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
00336 {
00337   public:
00338     typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
00339     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00340     typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
00341                                                     typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
00342     typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
00343     typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
00344     typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
00345 
00346     EIGEN_DEVICE_FUNC
00347     TensorSelectOp(const IfXprType& a_condition,
00348                    const ThenXprType& a_then,
00349                    const ElseXprType& a_else)
00350       : m_condition(a_condition), m_then(a_then), m_else(a_else)
00351     { }
00352 
00353     EIGEN_DEVICE_FUNC
00354     const IfXprType& ifExpression() const { return m_condition; }
00355 
00356     EIGEN_DEVICE_FUNC
00357     const ThenXprType& thenExpression() const { return m_then; }
00358 
00359     EIGEN_DEVICE_FUNC
00360     const ElseXprType& elseExpression() const { return m_else; }
00361 
00362   protected:
00363     typename IfXprType::Nested m_condition;
00364     typename ThenXprType::Nested m_then;
00365     typename ElseXprType::Nested m_else;
00366 };
00367 
00368 
00369 } // end namespace Eigen
00370 
00371 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
 All Classes Functions Variables Typedefs Enumerator