TensorArgMax.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
00005 //                    Benoit Steiner <benoit.steiner.goog@gmail.com>
00006 //
00007 // This Source Code Form is subject to the terms of the Mozilla
00008 // Public License v. 2.0. If a copy of the MPL was not distributed
00009 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00010 
00011 #ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
00012 #define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
00013 
00014 namespace Eigen {
00015 namespace internal {
00016 
00024 template<typename XprType>
00025 struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
00026 {
00027   typedef traits<XprType> XprTraits;
00028   typedef typename XprTraits::StorageKind StorageKind;
00029   typedef typename XprTraits::Index Index;
00030   typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
00031   typedef typename XprType::Nested Nested;
00032   typedef typename remove_reference<Nested>::type _Nested;
00033   static const int NumDimensions = XprTraits::NumDimensions;
00034   static const int Layout = XprTraits::Layout;
00035 };
00036 
00037 template<typename XprType>
00038 struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
00039 {
00040   typedef const TensorIndexTupleOp<XprType>& type;
00041 };
00042 
00043 template<typename XprType>
00044 struct nested<TensorIndexTupleOp<XprType>, 1,
00045               typename eval<TensorIndexTupleOp<XprType> >::type>
00046 {
00047   typedef TensorIndexTupleOp<XprType> type;
00048 };
00049 
00050 }  // end namespace internal
00051 
00052 template<typename XprType>
00053 class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
00054 {
00055   public:
00056   typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
00057   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00058   typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
00059   typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
00060   typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
00061   typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;
00062 
00063   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
00064       : m_xpr(expr) {}
00065 
00066   EIGEN_DEVICE_FUNC
00067   const typename internal::remove_all<typename XprType::Nested>::type&
00068   expression() const { return m_xpr; }
00069 
00070   protected:
00071     typename XprType::Nested m_xpr;
00072 };
00073 
00074 // Eval as rvalue
00075 template<typename ArgType, typename Device>
00076 struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
00077 {
00078   typedef TensorIndexTupleOp<ArgType> XprType;
00079   typedef typename XprType::Index Index;
00080   typedef typename XprType::Scalar Scalar;
00081   typedef typename XprType::CoeffReturnType CoeffReturnType;
00082 
00083   typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
00084   static const int NumDims = internal::array_size<Dimensions>::value;
00085 
00086   enum {
00087     IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
00088     PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
00089     BlockAccess = false,
00090     Layout = TensorEvaluator<ArgType, Device>::Layout,
00091     CoordAccess = false,  // to be implemented
00092     RawAccess = false
00093   };
00094 
00095   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
00096       : m_impl(op.expression(), device) { }
00097 
00098   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
00099     return m_impl.dimensions();
00100   }
00101 
00102   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
00103     m_impl.evalSubExprsIfNeeded(NULL);
00104     return true;
00105   }
00106   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
00107     m_impl.cleanup();
00108   }
00109 
00110   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
00111   {
00112     return CoeffReturnType(index, m_impl.coeff(index));
00113   }
00114 
00115   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
00116   costPerCoeff(bool vectorized) const {
00117     return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
00118   }
00119 
00120   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
00121 
00122  protected:
00123   TensorEvaluator<ArgType, Device> m_impl;
00124 };
00125 
00126 namespace internal {
00127 
00134 template<typename ReduceOp, typename Dims, typename XprType>
00135 struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
00136 {
00137   typedef traits<XprType> XprTraits;
00138   typedef typename XprTraits::StorageKind StorageKind;
00139   typedef typename XprTraits::Index Index;
00140   typedef Index Scalar;
00141   typedef typename XprType::Nested Nested;
00142   typedef typename remove_reference<Nested>::type _Nested;
00143   static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
00144   static const int Layout = XprTraits::Layout;
00145 };
00146 
00147 template<typename ReduceOp, typename Dims, typename XprType>
00148 struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
00149 {
00150   typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type;
00151 };
00152 
00153 template<typename ReduceOp, typename Dims, typename XprType>
00154 struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
00155               typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
00156 {
00157   typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
00158 };
00159 
00160 }  // end namespace internal
00161 
00162 template<typename ReduceOp, typename Dims, typename XprType>
00163 class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
00164 {
00165   public:
00166   typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
00167   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00168   typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
00169   typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
00170   typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
00171   typedef Index CoeffReturnType;
00172 
00173   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
00174                                                           const ReduceOp& reduce_op,
00175                                                           const int return_dim,
00176                                                           const Dims& reduce_dims)
00177       : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
00178 
00179   EIGEN_DEVICE_FUNC
00180   const typename internal::remove_all<typename XprType::Nested>::type&
00181   expression() const { return m_xpr; }
00182 
00183   EIGEN_DEVICE_FUNC
00184   const ReduceOp& reduce_op() const { return m_reduce_op; }
00185 
00186   EIGEN_DEVICE_FUNC
00187   const Dims& reduce_dims() const { return m_reduce_dims; }
00188 
00189   EIGEN_DEVICE_FUNC
00190   int return_dim() const { return m_return_dim; }
00191 
00192   protected:
00193     typename XprType::Nested m_xpr;
00194     const ReduceOp m_reduce_op;
00195     const int m_return_dim;
00196     const Dims m_reduce_dims;
00197 };
00198 
00199 // Eval as rvalue
00200 template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
00201 struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
00202 {
00203   typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
00204   typedef typename XprType::Index Index;
00205   typedef typename XprType::Scalar Scalar;
00206   typedef typename XprType::CoeffReturnType CoeffReturnType;
00207   typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
00208   typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
00209   typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType> , Device>::Dimensions InputDimensions;
00210   static const int NumDims = internal::array_size<InputDimensions>::value;
00211   typedef array<Index, NumDims> StrideDims;
00212 
00213   enum {
00214     IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
00215     PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
00216     BlockAccess = false,
00217     Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
00218     CoordAccess = false,  // to be implemented
00219     RawAccess = false
00220   };
00221 
00222   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
00223       : m_orig_impl(op.expression(), device),
00224         m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
00225         m_return_dim(op.return_dim()) {
00226 
00227     gen_strides(m_orig_impl.dimensions(), m_strides);
00228     if (Layout == static_cast<int>(ColMajor)) {
00229       const Index total_size = internal::array_prod(m_orig_impl.dimensions());
00230       m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
00231     } else {
00232       const Index total_size = internal::array_prod(m_orig_impl.dimensions());
00233       m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
00234     }
00235     m_stride_div = m_strides[m_return_dim];
00236   }
00237 
00238   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
00239     return m_impl.dimensions();
00240   }
00241 
00242   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
00243     m_impl.evalSubExprsIfNeeded(NULL);
00244     return true;
00245   }
00246   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
00247     m_impl.cleanup();
00248   }
00249 
00250   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
00251     const TupleType v = m_impl.coeff(index);
00252     return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
00253   }
00254 
00255   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
00256 
00257   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
00258   costPerCoeff(bool vectorized) const {
00259     const double compute_cost = 1.0 +
00260         (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
00261     return m_orig_impl.costPerCoeff(vectorized) +
00262            m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
00263   }
00264 
00265  private:
00266   EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
00267     if (m_return_dim < 0) {
00268       return;  // Won't be using the strides.
00269     }
00270     eigen_assert(m_return_dim < NumDims &&
00271                  "Asking to convert index to a dimension outside of the rank");
00272 
00273     // Calculate m_stride_div and m_stride_mod, which are used to
00274     // calculate the value of an index w.r.t. the m_return_dim.
00275     if (Layout == static_cast<int>(ColMajor)) {
00276       strides[0] = 1;
00277       for (int i = 1; i < NumDims; ++i) {
00278         strides[i] = strides[i-1] * dims[i-1];
00279       }
00280     } else {
00281       strides[NumDims-1] = 1;
00282       for (int i = NumDims - 2; i >= 0; --i) {
00283         strides[i] = strides[i+1] * dims[i+1];
00284       }
00285     }
00286   }
00287 
00288  protected:
00289   TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
00290   TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
00291   const int m_return_dim;
00292   StrideDims m_strides;
00293   Index m_stride_mod;
00294   Index m_stride_div;
00295 };
00296 
00297 } // end namespace Eigen
00298 
00299 #endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
 All Classes Functions Variables Typedefs Enumerator