TensorBroadcasting.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {

/** \class TensorBroadcasting
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor broadcasting class.
  *
  */
namespace internal {
template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
};

template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};

template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <std::size_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif
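
// Illustrative values (a sketch, not part of the original header). With the
// C++11 Sizes implementation above, the trait flags any input whose static
// dimension list describes exactly one coefficient:
//   is_input_scalar<Sizes<> >::value      // true: rank-0 input
//   is_input_scalar<Sizes<1, 1> >::value  // true: total_size == 1
//   is_input_scalar<Sizes<2, 3> >::value  // false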

}  // end namespace internal



template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
        : m_xpr(expr), m_broadcast(broadcast) {}

    EIGEN_DEVICE_FUNC
    const Broadcast& broadcast() const { return m_broadcast; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Broadcast m_broadcast;
};
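
// Usage sketch (illustrative, not part of the original header). A
// TensorBroadcastingOp is normally built via TensorBase::broadcast(), which
// tiles each dimension of its input by the corresponding factor:
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   t.setConstant(1.0f);
//   Eigen::array<Eigen::Index, 2> bcast{{2, 2}};
//   Eigen::Tensor<float, 2> r = t.broadcast(bcast);  // r has dimensions 4 x 6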


// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = true,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    : m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // The broadcasting op doesn't change the rank of the tensor. One can't broadcast a scalar
    // and store the result in a scalar. Instead, reshape the scalar into an N-D tensor
    // (N >= 1) with a single element first, and then broadcast.
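    // For instance (illustrative, not part of the original source):
    //   Tensor<float, 0> s; s() = 7.0f;
    //   array<Index, 1> one{{1}}, four{{4}};
    //   Tensor<float, 1> v = s.reshape(one).broadcast(four);  // v == [7, 7, 7, 7]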
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    const Broadcast& broadcast = op.broadcast();
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * broadcast[i];
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return coeffColMajor(index);
    } else {
      return coeffRowMajor(index);
    }
  }

  // TODO: attempt to speed this up. The integer divisions and modulo are slow.
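  // Worked example (illustrative, not part of the original source): input
  // dimensions (2, 3), broadcast factors (2, 1), column-major layout. The
  // output is (4, 3), with output strides (1, 4) and input strides (1, 2).
  // For output index 5: idx = 5 / 4 = 1 in dim 1 (not broadcast, so it maps
  // through unchanged), leaving remainder 1 in dim 0, which wraps to
  // 1 % 2 = 1. The input index is therefore 1 * 2 + 1 = 3, i.e. input
  // coordinate (1, 1).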
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims-1]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return packetColMajor<LoadMode>(index);
    } else {
      return packetRowMajor<LoadMode>(index);
    }
  }

  // Ignore the LoadMode and always use unaligned loads since we can't guarantee
  // the alignment at compile time.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
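    // If the whole packet falls within the innermost input dimension, a single
    // unaligned packet load suffices; otherwise the packet straddles a broadcast
    // boundary and the coefficients are gathered one by one below.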
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffColMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffRowMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (NumDims > 0) {
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

 protected:
  const Broadcast m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H