TensorPatch.h
00001 // This file is part of Eigen, a lightweight C++ template library
00002 // for linear algebra.
00003 //
00004 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
00005 //
00006 // This Source Code Form is subject to the terms of the Mozilla
00007 // Public License v. 2.0. If a copy of the MPL was not distributed
00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
00009 
00010 #ifndef EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
00011 #define EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
00012 
00013 namespace Eigen {
00014 
00022 namespace internal {
00023 template<typename PatchDim, typename XprType>
00024 struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType>
00025 {
00026   typedef typename XprType::Scalar Scalar;
00027   typedef traits<XprType> XprTraits;
00028   typedef typename XprTraits::StorageKind StorageKind;
00029   typedef typename XprTraits::Index Index;
00030   typedef typename XprType::Nested Nested;
00031   typedef typename remove_reference<Nested>::type _Nested;
00032   static const int NumDimensions = XprTraits::NumDimensions + 1;
00033   static const int Layout = XprTraits::Layout;
00034 };
00035 
00036 template<typename PatchDim, typename XprType>
00037 struct eval<TensorPatchOp<PatchDim, XprType>, Eigen::Dense>
00038 {
00039   typedef const TensorPatchOp<PatchDim, XprType>& type;
00040 };
00041 
00042 template<typename PatchDim, typename XprType>
00043 struct nested<TensorPatchOp<PatchDim, XprType>, 1, typename eval<TensorPatchOp<PatchDim, XprType> >::type>
00044 {
00045   typedef TensorPatchOp<PatchDim, XprType> type;
00046 };
00047 
00048 }  // end namespace internal
00049 
00050 
00051 
00052 template<typename PatchDim, typename XprType>
00053 class TensorPatchOp : public TensorBase<TensorPatchOp<PatchDim, XprType>, ReadOnlyAccessors>
00054 {
00055   public:
00056   typedef typename Eigen::internal::traits<TensorPatchOp>::Scalar Scalar;
00057   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
00058   typedef typename XprType::CoeffReturnType CoeffReturnType;
00059   typedef typename Eigen::internal::nested<TensorPatchOp>::type Nested;
00060   typedef typename Eigen::internal::traits<TensorPatchOp>::StorageKind StorageKind;
00061   typedef typename Eigen::internal::traits<TensorPatchOp>::Index Index;
00062 
00063   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPatchOp(const XprType& expr, const PatchDim& patch_dims)
00064       : m_xpr(expr), m_patch_dims(patch_dims) {}
00065 
00066     EIGEN_DEVICE_FUNC
00067     const PatchDim& patch_dims() const { return m_patch_dims; }
00068 
00069     EIGEN_DEVICE_FUNC
00070     const typename internal::remove_all<typename XprType::Nested>::type&
00071     expression() const { return m_xpr; }
00072 
00073   protected:
00074     typename XprType::Nested m_xpr;
00075     const PatchDim m_patch_dims;
00076 };
00077 
00078 
00079 // Eval as rvalue
00080 template<typename PatchDim, typename ArgType, typename Device>
00081 struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
00082 {
00083   typedef TensorPatchOp<PatchDim, ArgType> XprType;
00084   typedef typename XprType::Index Index;
00085   static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value + 1;
00086   typedef DSizes<Index, NumDims> Dimensions;
00087   typedef typename XprType::Scalar Scalar;
00088   typedef typename XprType::CoeffReturnType CoeffReturnType;
00089   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
00090   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
00091 
00092 
00093   enum {
00094     IsAligned = false,
00095     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
00096     Layout = TensorEvaluator<ArgType, Device>::Layout,
00097     CoordAccess = false,
00098     RawAccess = false
00099  };
00100 
00101   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
00102       : m_impl(op.expression(), device)
00103   {
00104     Index num_patches = 1;
00105     const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
00106     const PatchDim& patch_dims = op.patch_dims();
00107     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
00108       for (int i = 0; i < NumDims-1; ++i) {
00109         m_dimensions[i] = patch_dims[i];
00110         num_patches *= (input_dims[i] - patch_dims[i] + 1);
00111       }
00112       m_dimensions[NumDims-1] = num_patches;
00113 
00114       m_inputStrides[0] = 1;
00115       m_patchStrides[0] = 1;
00116       for (int i = 1; i < NumDims-1; ++i) {
00117         m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
00118         m_patchStrides[i] = m_patchStrides[i-1] * (input_dims[i-1] - patch_dims[i-1] + 1);
00119       }
00120       m_outputStrides[0] = 1;
00121       for (int i = 1; i < NumDims; ++i) {
00122         m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
00123       }
00124     } else {
00125       for (int i = 0; i < NumDims-1; ++i) {
00126         m_dimensions[i+1] = patch_dims[i];
00127         num_patches *= (input_dims[i] - patch_dims[i] + 1);
00128       }
00129       m_dimensions[0] = num_patches;
00130 
00131       m_inputStrides[NumDims-2] = 1;
00132       m_patchStrides[NumDims-2] = 1;
00133       for (int i = NumDims-3; i >= 0; --i) {
00134         m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
00135         m_patchStrides[i] = m_patchStrides[i+1] * (input_dims[i+1] - patch_dims[i+1] + 1);
00136       }
00137       m_outputStrides[NumDims-1] = 1;
00138       for (int i = NumDims-2; i >= 0; --i) {
00139         m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
00140       }
00141     }
00142   }
00143 
00144   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
00145 
00146   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
00147     m_impl.evalSubExprsIfNeeded(NULL);
00148     return true;
00149   }
00150 
00151   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
00152     m_impl.cleanup();
00153   }
00154 
00155   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
00156   {
00157     Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
00158     // Find the location of the first element of the patch.
00159     Index patchIndex = index / m_outputStrides[output_stride_index];
00160     // Find the offset of the element wrt the location of the first element.
00161     Index patchOffset = index - patchIndex * m_outputStrides[output_stride_index];
00162     Index inputIndex = 0;
00163     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
00164       for (int i = NumDims - 2; i > 0; --i) {
00165         const Index patchIdx = patchIndex / m_patchStrides[i];
00166         patchIndex -= patchIdx * m_patchStrides[i];
00167         const Index offsetIdx = patchOffset / m_outputStrides[i];
00168         patchOffset -= offsetIdx * m_outputStrides[i];
00169         inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
00170       }
00171     } else {
00172       for (int i = 0; i < NumDims - 2; ++i) {
00173         const Index patchIdx = patchIndex / m_patchStrides[i];
00174         patchIndex -= patchIdx * m_patchStrides[i];
00175         const Index offsetIdx = patchOffset / m_outputStrides[i+1];
00176         patchOffset -= offsetIdx * m_outputStrides[i+1];
00177         inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
00178       }
00179     }
00180     inputIndex += (patchIndex + patchOffset);
00181     return m_impl.coeff(inputIndex);
00182   }
00183 
00184   template<int LoadMode>
00185   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
00186   {
00187     EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
00188     eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
00189 
00190     Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
00191     Index indices[2] = {index, index + PacketSize - 1};
00192     Index patchIndices[2] = {indices[0] / m_outputStrides[output_stride_index],
00193                              indices[1] / m_outputStrides[output_stride_index]};
00194     Index patchOffsets[2] = {indices[0] - patchIndices[0] * m_outputStrides[output_stride_index],
00195                              indices[1] - patchIndices[1] * m_outputStrides[output_stride_index]};
00196 
00197     Index inputIndices[2] = {0, 0};
00198     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
00199       for (int i = NumDims - 2; i > 0; --i) {
00200         const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
00201                                    patchIndices[1] / m_patchStrides[i]};
00202         patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
00203         patchIndices[1] -= patchIdx[1] * m_patchStrides[i];
00204 
00205         const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i],
00206                                     patchOffsets[1] / m_outputStrides[i]};
00207         patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i];
00208         patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i];
00209 
00210         inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
00211         inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
00212       }
00213     } else {
00214       for (int i = 0; i < NumDims - 2; ++i) {
00215         const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
00216                                    patchIndices[1] / m_patchStrides[i]};
00217         patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
00218         patchIndices[1] -= patchIdx[1] * m_patchStrides[i];
00219 
00220         const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i+1],
00221                                     patchOffsets[1] / m_outputStrides[i+1]};
00222         patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i+1];
00223         patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i+1];
00224 
00225         inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
00226         inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
00227       }
00228     }
00229     inputIndices[0] += (patchIndices[0] + patchOffsets[0]);
00230     inputIndices[1] += (patchIndices[1] + patchOffsets[1]);
00231 
00232     if (inputIndices[1] - inputIndices[0] == PacketSize - 1) {
00233       PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]);
00234       return rslt;
00235     }
00236     else {
00237       EIGEN_ALIGN_MAX CoeffReturnType values[PacketSize];
00238       values[0] = m_impl.coeff(inputIndices[0]);
00239       values[PacketSize-1] = m_impl.coeff(inputIndices[1]);
00240       for (int i = 1; i < PacketSize-1; ++i) {
00241         values[i] = coeff(index+i);
00242       }
00243       PacketReturnType rslt = internal::pload<PacketReturnType>(values);
00244       return rslt;
00245     }
00246   }
00247 
00248   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
00249     const double compute_cost = NumDims * (TensorOpCost::DivCost<Index>() +
00250                                            TensorOpCost::MulCost<Index>() +
00251                                            2 * TensorOpCost::AddCost<Index>());
00252     return m_impl.costPerCoeff(vectorized) +
00253            TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
00254   }
00255 
00256   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
00257 
00258  protected:
00259   Dimensions m_dimensions;
00260   array<Index, NumDims> m_outputStrides;
00261   array<Index, NumDims-1> m_inputStrides;
00262   array<Index, NumDims-1> m_patchStrides;
00263 
00264   TensorEvaluator<ArgType, Device> m_impl;
00265 };
00266 
00267 } // end namespace Eigen
00268 
00269 #endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
 All Classes Functions Variables Typedefs Enumerator